Use pretrained model loaded on start
Browse files
app.py
CHANGED
@@ -9,7 +9,44 @@ from transformers import *
|
|
9 |
from blurr.text.data.all import *
|
10 |
from blurr.text.modeling.all import *
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def get_summary(text, sequences_num):
    """Summarize *text* with the module-level `learn` and return the top candidate.

    Args:
        text: Source document to summarize.
        sequences_num: Number of candidate summaries to generate; only the
            first (highest-ranked) one is returned.
    """
    gen_kwargs = {
        "early_stopping": True,
        "num_beams": 4,
        "num_return_sequences": sequences_num,
    }
    candidates = learn.blurr_summarize(text, **gen_kwargs)
    return candidates[0]
|
|
|
9 |
from blurr.text.data.all import *
|
10 |
from blurr.text.modeling.all import *
|
11 |
|
12 |
# --- One-time application setup, executed at import/app start -----------------
# Loads the pretrained summarization model once so request handlers can reuse
# the global `learn` object instead of loading it per call.

import nltk

# Fetch the punkt sentence tokenizer data quietly (no-op if already present).
nltk.download('punkt', quiet=True)

# A small (1%) slice of CNN/DailyMail — used only to build the DataLoaders the
# fastai Learner requires; no training happens at startup.
# NOTE(review): `datasets` and `pd` are presumably re-exported by the wildcard
# imports at the top of the file — confirm they resolve.
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
df = pd.DataFrame(raw_data)

# Distilled BART checkpoint fine-tuned for CNN/DailyMail summarization.
pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)

# Default generation kwargs for this config/model on the summarization task,
# wired into the batch tokenize transform (source capped at 256 tokens,
# target at 130).
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')
hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
    hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs
)

# DataBlock: article text -> highlights target; small batch size since this
# only needs to produce valid DataLoaders for inference.
blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
dls = dblock.dataloaders(df, bs=2)

# Metric spec consumed by Seq2SeqMetricsCallback; only relevant if `fit_cbs`
# is ever passed to a training call.
seq2seq_metrics = {
    'rouge': {
        'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
        'returns': ["rouge1", "rouge2", "rougeL"]
    },
    'bertscore': {
        'compute_kwargs': { 'lang': 'en' },
        'returns': ["precision", "recall", "f1"]
    }
}

model = BaseModelWrapper(hf_model)
learn_cbs = [BaseModelCallback]
fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

# Global Learner used by get_summary(); mixed precision via to_fp16().
learn = Learner(dls,
                model,
                opt_func=ranger,
                loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,
                splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)).to_fp16()

# Create the optimizer and freeze the body so the model is ready for
# inference (and for fine-tuning the head, if ever trained).
learn.create_opt()
learn.freeze()
|
50 |
|
51 |
def get_summary(text, sequences_num):
    """Return the best beam-search summary of *text*.

    Uses the globally constructed `learn` (loaded once at app start);
    *sequences_num* controls how many candidates are generated, of which
    only the first is returned.
    """
    results = learn.blurr_summarize(
        text,
        early_stopping=True,
        num_beams=4,
        num_return_sequences=sequences_num,
    )
    return results[0]
|