import os

import numpy as np
import hypertune

from transformers import (
    AutoTokenizer,
    EvalPrediction,
    Trainer,
    TrainingArguments,
    default_data_collator,
    TrainerCallback,
)

from trainer import model, metadata, utils


class HPTuneCallback(TrainerCallback):
    """
    A custom callback that reports a metric to the hyperparameter tuning
    service at the end of each evaluation (here, once per epoch).
    """

    def __init__(self, metric_tag, metric_value):
        super(HPTuneCallback, self).__init__()
        self.metric_tag = metric_tag
        self.metric_value = metric_value
        self.hpt = hypertune.HyperTune()

    def on_evaluate(self, args, state, control, **kwargs):
        print(f"HP metric {self.metric_tag}={kwargs['metrics'][self.metric_value]}")
        self.hpt.report_hyperparameter_tuning_metric(
            hyperparameter_metric_tag=self.metric_tag,
            metric_value=kwargs['metrics'][self.metric_value],
            global_step=state.epoch)


def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    # Convert logits to predicted class ids and report plain accuracy.
    preds = np.argmax(preds, axis=1)
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}


def train(args, model, train_dataset, test_dataset):
    """Load the pretrained tokenizer, configure the Hugging Face Trainer, and
    run the training loop.

    Args:
      args: arguments from the runner that set the training hyperparameters
      model: the neural network to train
      train_dataset: the training dataset
      test_dataset: the test dataset used for evaluation
    """
    tokenizer = AutoTokenizer.from_pretrained(
        metadata.PRETRAINED_MODEL_NAME,
        use_fast=True,
    )

    training_args = TrainingArguments(
        evaluation_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        weight_decay=args.weight_decay,
        output_dir=os.path.join("/tmp", args.model_name),
    )

    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Report eval accuracy to the hyperparameter tuning service when requested.
    if args.hp_tune == "y":
        trainer.add_callback(HPTuneCallback("accuracy", "eval_accuracy"))

    trainer.train()

    return trainer


def run(args):
    """Load the data, train, evaluate, and export the model for serving.

    Args:
      args: experiment parameters.
    """
    train_dataset, test_dataset = utils.load_data(args)

    # Infer the number of target classes from the training labels.
    label_list = train_dataset.unique("label")
    num_labels = len(label_list)

    text_classifier = model.create(num_labels=num_labels)

    trainer = train(args, text_classifier, train_dataset, test_dataset)

    metrics = trainer.evaluate(eval_dataset=test_dataset)
    trainer.save_metrics("all", metrics)

    trainer.save_model(os.path.join("/tmp", args.model_name))

    # Copy the exported model to the GCS bucket when a job directory is given.
    if args.job_dir:
        utils.save_model(args)
    else:
        print(f"Saved model files at {os.path.join('/tmp', args.model_name)}")
        print("To save model files to a GCS bucket, specify a job_dir starting with gs://")
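# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal sketch of a command-line entry point, assuming the runner builds an
# argparse namespace and hands it to run(). The flag names simply mirror the
# attributes this module reads from `args` (learning_rate, batch_size,
# num_epochs, weight_decay, model_name, hp_tune, job_dir); defaults are
# illustrative only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune a text classifier")
    parser.add_argument("--model-name", dest="model_name", default="text-classifier")
    parser.add_argument("--learning-rate", dest="learning_rate", type=float, default=2e-5)
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=16)
    parser.add_argument("--num-epochs", dest="num_epochs", type=int, default=2)
    parser.add_argument("--weight-decay", dest="weight_decay", type=float, default=0.01)
    parser.add_argument("--hp-tune", dest="hp_tune", default="n",
                        help="'y' to report metrics for hyperparameter tuning")
    parser.add_argument("--job-dir", dest="job_dir", default=None,
                        help="GCS location (gs://...) to export the trained model to")

    run(parser.parse_args())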