minjibi commited on
Commit
c9b5ac1
·
1 Parent(s): 68e81e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -16,8 +16,6 @@ tokenizer = MT5TokenizerFast.from_pretrained(
16
  )
17
 
18
  def predict(text):
19
-
20
- # with torch.no_grad():
21
  input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
22
  generated_ids = model.generate(
23
  input_ids=input_ids,
@@ -30,7 +28,7 @@ def predict(text):
30
  top_k=20, #default 20
31
  num_return_sequences=3,
32
  )
33
-
34
  preds = [
35
  tokenizer.decode(
36
  g,
@@ -39,10 +37,12 @@ def predict(text):
39
  )
40
  for g in generated_ids
41
  ]
42
-
43
  output = ['Q: ' + text for text in preds]
44
- final_str = '\n'.join([f"{i+1}. Question: {s.split('Answer')[0].strip()}\n Answer{s.split('Answer')[1].strip()}" for i, s in enumerate(output)])
45
- return final_str
 
 
46
 
47
  # text_to_predict = predict(text)
48
  # predicted = ['Q: ' + text for text in predict(text_to_predict)]
 
16
  )
17
 
18
  def predict(text):
 
 
19
  input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
20
  generated_ids = model.generate(
21
  input_ids=input_ids,
 
28
  top_k=20, #default 20
29
  num_return_sequences=3,
30
  )
31
+
32
  preds = [
33
  tokenizer.decode(
34
  g,
 
37
  )
38
  for g in generated_ids
39
  ]
40
+
41
  output = ['Q: ' + text for text in preds]
42
+ final_str = '\n'.join([f"{i+1}. {s}" for i, s in enumerate(output)])
43
+ #final_str = '\n'.join([f"{i+1}. Question: {s.split('Answer')[0].strip()}\n Answer{s.split('Answer')[1].strip()}" for i, s in enumerate(output)])
44
+
45
+ return final_str
46
 
47
  # text_to_predict = predict(text)
48
  # predicted = ['Q: ' + text for text in predict(text_to_predict)]