Update README.md
README.md CHANGED
````diff
@@ -104,6 +104,7 @@ Predicting all summaries:
 import json
 import torch
 from transformers import MBartTokenizer, MBartForConditionalGeneration
+from datasets import load_dataset
 
 
 def gen_batch(inputs, batch_size):
@@ -115,26 +116,19 @@ def gen_batch(inputs, batch_size):
 
 def predict(
     model_name,
-    test_file,
-    predictions_file,
-    targets_file,
+    input_records,
+    output_file,
     max_source_tokens_count=600,
-    use_cuda=True,
     batch_size=4
 ):
-    inputs = []
-    targets = []
-    with open(test_file, "r") as r:
-        for line in r:
-            record = json.loads(line)
-            inputs.append(record["text"])
-            targets.append(record["summary"].replace("\n", " "))
-
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
     tokenizer = MBartTokenizer.from_pretrained(model_name)
-    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
     model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
+
     predictions = []
     for batch in gen_batch(inputs, batch_size):
+        texts = [r["text"] for r in batch]
         input_ids = tokenizer(
             batch,
             return_tensors="pt",
@@ -142,22 +136,21 @@ def predict(
             truncation=True,
             max_length=max_source_tokens_count
         )["input_ids"].to(device)
+
         output_ids = model.generate(
             input_ids=input_ids,
-
+            no_repeat_ngram_size=4
         )
         summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         for s in summaries:
             print(s)
         predictions.extend(summaries)
-    with open(predictions_file, "w") as w:
+    with open(output_file, "w") as w:
         for p in predictions:
             w.write(p.strip().replace("\n", " ") + "\n")
-    with open(targets_file, "w") as w:
-        for t in targets:
-            w.write(t.strip().replace("\n", " ") + "\n")
 
-
+gazeta_test = load_dataset('IlyaGusev/gazeta', script_version="v1.0")["test"]
+predict("IlyaGusev/mbart_ru_sum_gazeta", list(gazeta_test), "mbart_predictions.txt")
 ```
 
 Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
````
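One wrinkle in the updated snippet as committed: the loop still iterates over `inputs`, which no longer exists after the signature change to `input_records`, and the new `texts` list is built but never passed to the tokenizer. A minimal fix for that loop (two renames, abbreviated to the tokenizer arguments visible in the diff):

```python
    # Batch over dataset records and tokenize their "text" fields.
    for batch in gen_batch(input_records, batch_size):  # was: gen_batch(inputs, ...)
        texts = [r["text"] for r in batch]
        input_ids = tokenizer(
            texts,  # was: batch
            return_tensors="pt",
            truncation=True,
            max_length=max_source_tokens_count
        )["input_ids"].to(device)
```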
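For a quick sanity check without cloning summarus, ROUGE can also be computed directly over the prediction file. This is a rough sketch, not the evaluate.py protocol linked above: it assumes a datasets 1.x environment (where `load_metric` and `script_version` exist, matching the snippet above) plus the `rouge_score` package, and its default English stemmer means the absolute numbers will not match the officially reported scores.

```python
from datasets import load_dataset, load_metric

# Reference summaries from the same test split the predictions were made on.
gazeta_test = load_dataset("IlyaGusev/gazeta", script_version="v1.0")["test"]
references = [r["summary"].replace("\n", " ") for r in gazeta_test]

# Predictions written by predict() above, one summary per line.
with open("mbart_predictions.txt") as f:
    predictions = [line.strip() for line in f]

rouge = load_metric("rouge")
scores = rouge.compute(predictions=predictions, references=references)
for key in ("rouge1", "rouge2", "rougeL"):
    print(key, round(scores[key].mid.fmeasure, 4))
```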