def train(
    push_to_hub: bool,
    num_epoch: int,
    train_batch_size: int,
    eval_batch_size: int,
):
    import torch
    import numpy as np

    # 1. Dataset
    from datasets import load_dataset
    dataset = load_dataset("Adapting/abstract-keyphrases")
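    # The loaded dataset exposes 'Abstract' (source text) and 'Keywords' (';'-separated keyphrases) columns.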

    # 2. Model
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from lrt.clustering.models import KeyBartAdapter
    tokenizer = AutoTokenizer.from_pretrained("Adapting/KeyBartAdapter")
    '''
    Alternatively, load the initial model weights directly from the Hugging Face Hub:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "Adapting/KeyBartAdapter",
        revision='9c3ed39c6ed5c7e141363e892d77cf8f589d5999'
    )
    '''
    model = KeyBartAdapter(256)

    # 3. Preprocess the dataset
    dataset = dataset.shuffle()
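    # Tokenize abstracts as encoder inputs and keyword strings as decoder labels.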
    def preprocess_function(examples):
        inputs = examples['Abstract']
        targets = examples['Keywords']
        model_inputs = tokenizer(inputs, truncation=True)
        # Set up the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
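
    # remove_columns drops the raw text columns so only the tokenized fields reach the trainer.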
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # 4. Evaluation metrics
    def compute_metrics(eval_preds):
        preds = eval_preds.predictions
        labels = eval_preds.label_ids
        if isinstance(preds, tuple):
            preds = preds[0]
        print(preds.shape)
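        # If predictions arrive as logits of shape (batch, seq_len, vocab), reduce them to token ids.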
        if len(preds.shape) == 3:
            preds = preds.argmax(axis=-1)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Some simple post-processing
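        # Keyphrases are ';'-separated strings; split them so predictions and references compare as sets.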
        decoded_preds = [a.strip().split(';') for a in decoded_preds]
        decoded_labels = [a.strip().split(';') for a in decoded_labels]

        precs, recalls, f_scores = [], [], []
        num_match, num_pred, num_gold = [], [], []
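        # Per-example precision/recall/F1 over exact keyphrase matches, macro-averaged below.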
        for pred, label in zip(decoded_preds, decoded_labels):
            pred_set = set(pred)
            label_set = set(label)
            match_set = label_set.intersection(pred_set)
            p = float(len(match_set)) / float(len(pred_set)) if len(pred_set) > 0 else 0.0
            r = float(len(match_set)) / float(len(label_set)) if len(label_set) > 0 else 0.0
            f1 = float(2 * (p * r)) / (p + r) if (p + r) > 0 else 0.0
            precs.append(p)
            recalls.append(r)
            f_scores.append(f1)
            num_match.append(len(match_set))
            num_pred.append(len(pred_set))
            num_gold.append(len(label_set))
            # print(f'raw_PRED: {raw_pred}')
            print(f'PRED: num={len(pred_set)} - {pred_set}')
            print(f'GT: num={len(label_set)} - {label_set}')
            print(f'p={p}, r={r}, f1={f1}')
            print('-' * 20)

        result = {
            'precision@M': np.mean(precs) * 100.0,
            'recall@M': np.mean(recalls) * 100.0,
            'fscore@M': np.mean(f_scores) * 100.0,
            'num_match': np.mean(num_match),
            'num_pred': np.mean(num_pred),
            'num_gold': np.mean(num_gold),
        }
        result = {k: round(v, 2) for k, v in result.items()}
        return result

    # 5. Train
    from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
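    # DataCollatorForSeq2Seq pads each batch dynamically and pads labels with -100 so they are ignored by the loss.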
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    model_name = 'KeyBartAdapter'
    args = Seq2SeqTrainingArguments(
        model_name,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=num_epoch,
        logging_steps=4,
        load_best_model_at_end=True,
        metric_for_best_model='fscore@M',
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),  # mixed precision speeds up training on modern GPUs
        # eval_accumulation_steps=10,
    )
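
    # predict_with_generate=True makes evaluation decode via model.generate(), so compute_metrics
    # receives generated token ids. Note that eval_dataset below is the training split itself, so
    # the reported metrics measure fit to the training data rather than held-out performance.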
    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["train"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # 6. Push to the Hub
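    # Uploads both the tokenizer and the fine-tuned adapter model under repo_id=model_name;
    # the commit message records the model name and the number of training epochs.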
    if push_to_hub:
        commit_msg = f'{model_name}_{num_epoch}'
        tokenizer.push_to_hub(commit_message=commit_msg, repo_id=model_name)
        model.push_to_hub(commit_message=commit_msg, repo_id=model_name)

    return model, tokenizer


if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the project root importable so that `lrt.clustering.models` resolves.
    project_root = Path(__file__).parent.parent.parent.absolute()
    sys.path.append(str(project_root))

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", help="number of epochs", default=30)
    parser.add_argument("--train_batch_size", help="training batch size", default=16)
    parser.add_argument("--eval_batch_size", help="evaluation batch size", default=16)
    parser.add_argument("--push", help="whether to push the model to the hub", action='store_true')
    args = parser.parse_args()
    print(args)

    model, tokenizer = train(
        push_to_hub=bool(args.push),
        num_epoch=int(args.epoch),
        train_batch_size=int(args.train_batch_size),
        eval_batch_size=int(args.eval_batch_size),
    )
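
# Example invocation (the script name/path is illustrative; adjust it to where this file lives in the repo):
#   python train_keybart_adapter.py --epoch 30 --train_batch_size 16 --eval_batch_size 16 --push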