Another fix approach
app.py CHANGED
@@ -9,6 +9,29 @@ from fastai.callback.all import *
 from fastai.learner import *
 from fastai.optimizer import *
 from transformers import *
+import nltk
+
+nltk.download('punkt', quiet=True)
+
+pre_trained_model_name = "sshleifer/distilbart-cnn-6-6"
+hf_arch, hf_config, hf_tokenizer, hf_model = BlurrText().get_hf_objects(pre_trained_model_name, model_cls=BartForConditionalGeneration)
+summarization_metrics = {
+    "rouge": {
+        "compute_kwargs": {"rouge_types": ["rouge1", "rouge2", "rougeL", "rougeLsum"], "use_stemmer": True}, "returns": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
+    },
+    "bertscore": {
+        "compute_kwargs": {"lang": "en"}, "returns": ["precision", "recall", "f1"]
+    }
+}
+translation_metrics = {
+    "bleu": {"returns": "bleu"},
+    "meteor": {"returns": "meteor"},
+    "sacrebleu": {"returns": "score"},
+}
+
+model = BaseModelWrapper(hf_model)
+learn_cbs = [BaseModelCallback]
+fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=summarization_metrics, calc_every="last_epoch")]
 
 learn = load_learner(fname='cnn_model.pkl')
 
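For context, the objects referenced in the added lines (BlurrText, BaseModelWrapper, BaseModelCallback, Seq2SeqMetricsCallback) come from the blurr library, and the commit presumably recreates them before load_learner so that the classes baked into the cnn_model.pkl pickle can be resolved. A minimal inference sketch is below; the blurr import path, the blurr_generate helper, and the sample inputs are assumptions about the installed blurr version, not part of this commit.

# Minimal usage sketch (assumptions: blurr is installed, its wildcard import
# below matches the installed version, and cnn_model.pkl was exported from a
# blurr summarization Learner).
from fastai.learner import load_learner
from blurr.text.modeling.all import *  # assumed path; older blurr releases use blurr.modeling.all

# Importing the blurr objects first lets fastai unpickle the exported Learner.
learn = load_learner(fname='cnn_model.pkl')

article = "Example news article text to summarize ..."  # placeholder input
# blurr patches Learner with a generation helper; the exact method name
# (blurr_generate vs. blurr_summarize) and kwargs depend on the blurr version.
summary = learn.blurr_generate(article, early_stopping=True, num_beams=4)
print(summary)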