def train(
    push_to_hub: bool,
    num_epoch: int,
    train_batch_size: int,
    eval_batch_size: int,
):
    import torch
    import numpy as np

    # 1. Dataset
    from datasets import load_dataset
    dataset = load_dataset("Adapting/abstract-keyphrases")
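    # The preprocessing below assumes the dataset provides an 'Abstract'
    # column (input text) and a 'Keywords' column (';'-separated target
    # keyphrases).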
    # 2. Model
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from lrt.clustering.models import KeyBartAdapter

    tokenizer = AutoTokenizer.from_pretrained("Adapting/KeyBartAdapter")
    '''
    Or you can just use the initial model weights from Hugging Face:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "Adapting/KeyBartAdapter",
        revision='9c3ed39c6ed5c7e141363e892d77cf8f589d5999',
    )
    '''
    model = KeyBartAdapter(256)
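    # The 256 passed to KeyBartAdapter is presumably the adapter's hidden
    # (bottleneck) dimension; check lrt.clustering.models for the exact
    # meaning of this argument.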
    # 3. Preprocess dataset
    dataset = dataset.shuffle()

    def preprocess_function(examples):
        inputs = examples['Abstract']
        targets = examples['Keywords']
        model_inputs = tokenizer(inputs, truncation=True)

        # Set up the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
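    # No padding is applied at tokenization time; DataCollatorForSeq2Seq
    # (set up in step 5) pads each batch dynamically and pads labels with
    # -100 so padded positions are ignored by the loss. Recent transformers
    # releases also accept tokenizer(inputs, text_target=targets, ...) in
    # place of the as_target_tokenizer() context manager.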
    # 4. Evaluation metrics
    def compute_metrics(eval_preds):
        preds = eval_preds.predictions
        labels = eval_preds.label_ids
        if isinstance(preds, tuple):
            preds = preds[0]
        print(preds.shape)  # debug: inspect prediction shape
        if len(preds.shape) == 3:
            # Received logits rather than generated token ids; take the argmax.
            preds = preds.argmax(axis=-1)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing: keyphrases are ';'-separated.
        decoded_preds = [a.strip().split(';') for a in decoded_preds]
        decoded_labels = [a.strip().split(';') for a in decoded_labels]

        precs, recalls, f_scores = [], [], []
        num_match, num_pred, num_gold = [], [], []
        for pred, label in zip(decoded_preds, decoded_labels):
            pred_set = set(pred)
            label_set = set(label)
            match_set = label_set.intersection(pred_set)
            p = float(len(match_set)) / float(len(pred_set)) if len(pred_set) > 0 else 0.0
            r = float(len(match_set)) / float(len(label_set)) if len(label_set) > 0 else 0.0
            f1 = float(2 * (p * r)) / (p + r) if (p + r) > 0 else 0.0
            precs.append(p)
            recalls.append(r)
            f_scores.append(f1)
            num_match.append(len(match_set))
            num_pred.append(len(pred_set))
            num_gold.append(len(label_set))

            # print(f'raw_PRED: {raw_pred}')
            print(f'PRED: num={len(pred_set)} - {pred_set}')
            print(f'GT: num={len(label_set)} - {label_set}')
            print(f'p={p}, r={r}, f1={f1}')
            print('-' * 20)

        result = {
            'precision@M': np.mean(precs) * 100.0,
            'recall@M': np.mean(recalls) * 100.0,
            'fscore@M': np.mean(f_scores) * 100.0,
            'num_match': np.mean(num_match),
            'num_pred': np.mean(num_pred),
            'num_gold': np.mean(num_gold),
        }
        result = {k: round(v, 2) for k, v in result.items()}
        return result
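    # These are exact-match set-overlap metrics over the full list of
    # predicted phrases (the '@M' in the metric names), macro-averaged
    # across examples.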
    # 5. Train
    from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    model_name = 'KeyBartAdapter'
    args = Seq2SeqTrainingArguments(
        model_name,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=num_epoch,
        logging_steps=4,
        load_best_model_at_end=True,
        metric_for_best_model='fscore@M',
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),  # speeds up training on modern GPUs
        # eval_accumulation_steps=10,
    )
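    # load_best_model_at_end requires save_strategy to match
    # evaluation_strategy (both "epoch" here); the Trainer resolves
    # metric_for_best_model by prepending "eval_" to 'fscore@M' automatically.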
    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["train"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
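    # Note: the train split doubles as the eval set here, presumably because
    # the dataset ships only a 'train' split; substitute a held-out split if
    # one is available.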
    trainer.train()

    # 6. Push to hub
    if push_to_hub:
        commit_msg = f'{model_name}_{num_epoch}'
        tokenizer.push_to_hub(commit_message=commit_msg, repo_id=model_name)
        model.push_to_hub(commit_message=commit_msg, repo_id=model_name)

    return model, tokenizer
if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the project root importable so the `lrt` package can be found.
    project_root = Path(__file__).parent.parent.parent.absolute()
    sys.path.append(str(project_root))

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", help="number of epochs", default=30)
    parser.add_argument("--train_batch_size", help="training batch size", default=16)
    parser.add_argument("--eval_batch_size", help="evaluation batch size", default=16)
    parser.add_argument("--push", help="whether to push the model to the hub", action='store_true')
    args = parser.parse_args()
    print(args)

    model, tokenizer = train(
        push_to_hub=bool(args.push),
        num_epoch=int(args.epoch),
        train_batch_size=int(args.train_batch_size),
        eval_batch_size=int(args.eval_batch_size),
    )
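    # Example invocation (the script name is illustrative):
    #   python train.py --epoch 30 --train_batch_size 16 --eval_batch_size 16 --push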