dappyx commited on
Commit
4634935
·
verified ·
1 Parent(s): 31a60dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -21,12 +21,10 @@ def qa_pipeline(text, question):
21
  }
22
 
23
  # Выполнение предсказания
24
- with torch.no_grad():
25
- outputs = model(**batch)
26
-
27
- # Извлечение логитов начала и конца ответа
28
- start_logits = outputs.start_logits
29
- end_logits = outputs.end_logits
30
 
31
  # Нахождение индексов начала и конца ответа
32
  start_index = torch.argmax(start_logits, dim=-1).item()
 
21
  }
22
 
23
  # Выполнение предсказания
24
+ start_logits, end_logits, loss = model(batch)
25
+
26
+ start_index = torch.argmax(start_logits, dim=-1).item()
27
+ end_index = torch.argmax(end_logits, dim=-1).item()
 
 
28
 
29
  # Нахождение индексов начала и конца ответа
30
  start_index = torch.argmax(start_logits, dim=-1).item()