awitecki commited on
Commit
03ed2aa
·
1 Parent(s): 8037565

Another fix approach

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py CHANGED
@@ -9,6 +9,29 @@ from fastai.callback.all import *
9
  from fastai.learner import *
10
  from fastai.optimizer import *
11
  from transformers import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  learn = load_learner(fname='cnn_model.pkl')
14
 
 
9
  from fastai.learner import *
10
  from fastai.optimizer import *
11
  from transformers import *
12
+ import nltk
13
+
14
+ nltk.download('punkt', quiet=True)
15
+
16
+ pre_trained_model_name = "sshleifer/distilbart-cnn-6-6"
17
+ hf_arch, hf_config, hf_tokenizer, hf_model = BlurrText().get_hf_objects(pre_trained_model_name, model_cls=BartForConditionalGeneration)
18
+ summarization_metrics = {
19
+ "rouge": {
20
+ "compute_kwargs": {"rouge_types": ["rouge1", "rouge2", "rougeL", "rougeLsum"], "use_stemmer": True}, "returns": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
21
+ },
22
+ "bertscore": {
23
+ "compute_kwargs": {"lang": "en"}, "returns": ["precision", "recall", "f1"]
24
+ }
25
+ }
26
+ translation_metrics = {
27
+ "bleu": {"returns": "bleu"},
28
+ "meteor": {"returns": "maeteor"},
29
+ "sacrebleu": {"returns": "score"},
30
+ }
31
+
32
+ model = BaseModelWrapper(hf_model)
33
+ learn_cbs = [BaseModelCallback]
34
+ fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=summarization_metrics, calc_every="last_epoch")]
35
 
36
  learn = load_learner(fname='cnn_model.pkl')
37