jjuarez committed
Commit ecbf4e6
1 Parent(s): 6846b58

Update app.py

Files changed (1)
  1. app.py +22 -28

app.py CHANGED
@@ -8,17 +8,17 @@ import nltk
 nltk.download("punkt")
 raw_dataset = load_dataset("scientific_papers", "pubmed")
 metric = evaluate.load("rouge")
-model_checkpoint = "t5-small"
+model_checkpoint = "google/flan-t5-small"
 tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
-if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
+if model_checkpoint in ["google/flan-t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
     prefix = "summarize: "
 else:
     prefix = ""
 
 # preprocessing function
-max_input_length = 256
-max_target_length = 64
+max_input_length = 512
+max_target_length = 128
 def preprocess_function(examples):
     inputs = [prefix + doc for doc in examples["article"]]
     model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
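Note: the hunk ends at the input side of preprocess_function; the target-tokenization lines (25-30 of app.py) fall outside the diff context. For orientation, here is a minimal sketch of what that elided body conventionally looks like for this dataset. The "abstract" field is real in scientific_papers, but the code itself is an assumption, not the file's contents:

# Hypothetical completion of preprocess_function, for illustration only;
# the file's actual lines 25-30 are simply not shown in this diff.
labels = tokenizer(
    text_target=examples["abstract"],  # scientific_papers pairs "article" with "abstract"
    max_length=max_target_length,
    truncation=True,
)
model_inputs["labels"] = labels["input_ids"]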
@@ -31,23 +31,23 @@ def preprocess_function(examples):
     return model_inputs
 
 for split in ["train", "validation", "test"]:
-    raw_dataset[split] = raw_dataset[split].select([n for n in np.random.randint(0, len(raw_dataset[split]) - 1, 200)])
-
+    raw_dataset[split] = raw_dataset[split].select([n for n in np.random.randint(0, len(raw_dataset[split]) - 1, 1_000)])
 tokenized_dataset = raw_dataset.map(preprocess_function, batched=True)
 
+
 model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
 
-batch_size = 4
+batch_size = 8
 
 args = Seq2SeqTrainingArguments(
     f"{model_checkpoint}-scientific_papers",
     evaluation_strategy="epoch",
-    learning_rate=3e-5,
+    learning_rate=2e-5,
     per_device_train_batch_size=batch_size,
     per_device_eval_batch_size=batch_size,
     weight_decay=0.01,
     save_total_limit=3,
-    num_train_epochs=0.5,
+    num_train_epochs=1,
     predict_with_generate=True,
     # fp16=True,
     push_to_hub=False,
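Aside: np.random.randint(0, len(raw_dataset[split]) - 1, k) draws indices with replacement, so the subsample can contain duplicate rows, and since the upper bound is exclusive the last example can never be selected. A duplicate-free variant using the datasets API, offered as a sketch rather than a suggested patch:

# Sketch: k distinct rows per split via shuffle + select; the seed is
# arbitrary, included only to make the subsample reproducible.
for split in ["train", "validation", "test"]:
    raw_dataset[split] = raw_dataset[split].shuffle(seed=42).select(range(1_000))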
@@ -69,40 +69,35 @@ def compute_metrics(eval_pred):
     result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
     # Extract a few results
     result = {key: value * 100 for key, value in result.items()}
-    # Add mean generated length
+    # Add mean generated length
     prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
     result["gen_len"] = np.mean(prediction_lens)
     return {k: round(v, 4) for k, v in result.items()}
 
-
-# Define the training and evaluation datasets
-train_dataset = tokenized_dataset["train"]
-eval_dataset = tokenized_dataset["validation"]
-
-# Create the trainer object
 trainer = Seq2SeqTrainer(
-    model=model,
-    args=args,
-    train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
-    data_collator=data_collator,
-    compute_metrics=compute_metrics,
+    model,
+    args,
+    train_dataset=tokenized_dataset["train"],
+    eval_dataset=tokenized_dataset["validation"],
+    data_collator=data_collator,
+    tokenizer=tokenizer,
+    compute_metrics=compute_metrics
 )
-
-# Train the model
 trainer.train()
 
 # Define the input and output interface of the app
+import gradio as gr
+
 def summarizer(input_text):
     inputs = [prefix + input_text]
     model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
     summary_ids = model.generate(
         input_ids=model_inputs["input_ids"],
         attention_mask=model_inputs["attention_mask"],
-        num_beams=6,
-        length_penalty=2.5,
+        num_beams=4,
+        length_penalty=2.0,
         max_length=max_target_length + 2, # +2 from original because we start at step=1 and stop before max_length
-        repetition_penalty=3.5,
+        repetition_penalty=2.0,
         early_stopping=True,
         use_cache=True
     )
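Note: summarizer() is also cut off; the decode-and-return tail after model.generate(...) lies outside the hunk. A plausible ending, assuming the usual batch_decode idiom (hypothetical, not the file's actual code):

    # Assumed tail of summarizer(), indented inside the function:
    # decode the single beam-search output and strip special tokens.
    summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]
    return summary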
@@ -119,4 +114,3 @@ iface = gr.Interface(
     theme="gray"
 )
 iface.launch()
-
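The final hunk shows only the tail of the Gradio wiring (the theme and the launch call); the gr.Interface(...) arguments at lines 114-118 sit outside the context. A plausible reconstruction, in which fn, inputs, outputs, and the label strings are all guesses:

# Guessed construction, for illustration; only theme="gray", the closing
# parenthesis, and iface.launch() are actually visible in the diff.
iface = gr.Interface(
    fn=summarizer,                        # the app's summarization function
    inputs=gr.Textbox(label="Article"),   # assumed input widget
    outputs=gr.Textbox(label="Summary"),  # assumed output widget
    theme="gray",
)
iface.launch()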
 
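compute_metrics enters the diff mid-function as well: the decoding of eval_pred (roughly lines 62-68 of app.py) is elided. The standard preamble in Hugging Face summarization examples looks like the following, offered as an assumption about the elided code, not a quotation of it:

# Assumed preamble of compute_metrics: decode predictions, swap the -100
# label padding back to real pad tokens, then newline-separate sentences,
# since rougeLsum scores at sentence level (this is also why nltk's punkt
# model is downloaded at the top of app.py).
predictions, labels = eval_pred
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = ["\n".join(nltk.sent_tokenize(p.strip())) for p in decoded_preds]
decoded_labels = ["\n".join(nltk.sent_tokenize(l.strip())) for l in decoded_labels]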
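One name the diff never defines: data_collator, which is passed to Seq2SeqTrainer in both the old and new code. It is presumably created in an elided region; in scripts of this shape it is conventionally a DataCollatorForSeq2Seq, sketched here under that assumption:

# Presumed definition (not visible in any hunk): pads inputs and labels
# dynamically per batch and prepares decoder_input_ids for the model.
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)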
 
 