awitecki committed
Commit d65b86c
1 Parent(s): 07c8d3a

Use pretrained model loaded on start
Files changed (1)
  1. app.py  +38 -1
app.py CHANGED
@@ -9,7 +9,44 @@ from transformers import *
 from blurr.text.data.all import *
 from blurr.text.modeling.all import *
 
-learn = load_learner(fname='model.pkl')
+import nltk
+nltk.download('punkt', quiet=True)
+
+raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')
+df = pd.DataFrame(raw_data)
+pretrained_model_name = "sshleifer/distilbart-cnn-6-6"
+hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
+text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization')
+hf_batch_tfm = Seq2SeqBatchTokenizeTransform(
+    hf_arch, hf_config, hf_tokenizer, hf_model, max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs
+)
+
+blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=hf_batch_tfm), noop)
+dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
+dls = dblock.dataloaders(df, bs=2)
+seq2seq_metrics = {
+    'rouge': {
+        'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
+        'returns': ["rouge1", "rouge2", "rougeL"]
+    },
+    'bertscore': {
+        'compute_kwargs': { 'lang': 'en' },
+        'returns': ["precision", "recall", "f1"]
+    }
+}
+model = BaseModelWrapper(hf_model)
+learn_cbs = [BaseModelCallback]
+fit_cbs = [Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]
+
+learn = Learner(dls,
+                model,
+                opt_func=ranger,
+                loss_func=CrossEntropyLossFlat(),
+                cbs=learn_cbs,
+                splitter=partial(blurr_seq2seq_splitter, arch=hf_arch)).to_fp16()
+
+learn.create_opt()
+learn.freeze()
 
 def get_summary(text, sequences_num):
     return learn.blurr_summarize(text, early_stopping=True, num_beams=4, num_return_sequences=sequences_num)[0]
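
For context, a minimal sketch of how the startup path introduced above could be exercised once app.py has finished loading. The sample article text and the direct call to get_summary below are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch (not from the commit): after the module-level
# setup above has built `learn`, the helper can be called directly.
sample_article = (
    "The city council voted on Tuesday to expand the bike-lane network, "
    "citing a sharp rise in cycling over the past two years. "
    "Construction is expected to begin next spring."
)

# `get_summary` wraps blurr's `Learner.blurr_summarize`; `sequences_num`
# controls how many candidate summaries are generated, and the helper
# returns the first one.
summary = get_summary(sample_article, sequences_num=1)
print(summary)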