sentiment / goodies /train.py
dejanseo's picture
Upload 4 files
de55574 verified
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
from transformers import DataCollatorWithPadding
from datasets import load_metric, Dataset
import torch
import wandb
# Tweakable settings, gathered in one place for easy editing
model_name = "albert-base-v2"  # pretrained checkpoint to fine-tune
num_labels = 7  # Number of sentiment labels
output_dir = "./albert_sentiment_model"  # where checkpoints and the final model go
data_file = "data.csv"  # CSV file read by the preprocessing step
wandb_entity = "dejan"  # wandb entity that owns the run
batch_size = 8  # used for both the train and eval per-device batch size
num_train_epochs = 30
learning_rate = 5e-5

# Start the Weights & Biases run
wandb.init(project="sentiment_classification", entity=wandb_entity)
# Load and preprocess the dataset.
# The file is read without a header row; its columns are named 'text' and 'label'.
df = pd.read_csv(data_file, header=None, names=['text', 'label'])

# Rows with a missing text or label cannot be used for training, and a missing
# text is read as NaN (a float), which would crash the string clean-up below.
# reset_index returns a fresh frame, so the column assignment that follows
# does not operate on a slice of the original.
df = df.dropna(subset=['text', 'label']).reset_index(drop=True)

# Remove leading instructions and prompts (assuming we know the prompt structure):
# keep only what follows the last occurrence of the instruction sentence.
_PROMPT_TAIL = 'Write nothing but the article text. Do not include the sentiment in the text of the article.'
df['text'] = df['text'].astype(str).apply(lambda text: text.split(_PROMPT_TAIL)[-1].strip())

# Display the cleaned data
print(df.head())
# Hold out 20% of the rows for validation; the fixed seed keeps the split reproducible.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'].tolist(), df['label'].tolist(), test_size=0.2, random_state=42
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Truncate only. Padding is deliberately NOT applied here: padding=True would pad
# every example to the longest sequence in the whole split, whereas the
# DataCollatorWithPadding handed to the Trainer pads each batch only to that
# batch's longest sequence, which is considerably cheaper.
train_encodings = tokenizer(train_texts, truncation=True)
val_encodings = tokenizer(val_texts, truncation=True)

# Wrap the encodings and labels in datasets.Dataset objects for the Trainer.
train_dataset = Dataset.from_dict({
    'input_ids': train_encodings['input_ids'],
    'attention_mask': train_encodings['attention_mask'],
    'labels': train_labels,
})
val_dataset = Dataset.from_dict({
    'input_ids': val_encodings['input_ids'],
    'attention_mask': val_encodings['attention_mask'],
    'labels': val_labels,
})
# Collator that pads the examples of each batch using the tokenizer
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Evaluation metrics, loaded by name in a single pass.
# NOTE(review): load_metric is deprecated in newer `datasets` releases in favour
# of the separate `evaluate` package; confirm against the installed version.
accuracy_metric, precision_metric, recall_metric, f1_metric = (
    load_metric(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)
def compute_metrics(eval_pred):
    """Score one evaluation pass and mirror the scores to wandb.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The predicted class of each
            example is the argmax of its logits over the last axis.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall`` and
        ``f1``. Precision, recall and F1 use the ``weighted`` average.
    """
    logits, labels = eval_pred
    predictions = torch.tensor(logits).argmax(dim=-1)

    scores = {
        "accuracy": accuracy_metric.compute(
            predictions=predictions, references=labels
        )["accuracy"],
        "precision": precision_metric.compute(
            predictions=predictions, references=labels, average='weighted'
        )["precision"],
        "recall": recall_metric.compute(
            predictions=predictions, references=labels, average='weighted'
        )["recall"],
        "f1": f1_metric.compute(
            predictions=predictions, references=labels, average='weighted'
        )["f1"],
    }

    # The same values are sent to wandb under "eval_"-prefixed names.
    wandb.log({f"eval_{name}": value for name, value in scores.items()})
    return scores
# Training configuration handed to the Trainer
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
training_args = TrainingArguments(
    # Output locations
    output_dir=output_dir,
    logging_dir="./logs",
    # Schedule and optimisation
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    lr_scheduler_type="linear",
    warmup_steps=500,
    weight_decay=0.01,
    # Batch sizes
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Logging
    logging_strategy="steps",
    logging_steps=10,
    report_to="wandb",
    # Evaluate and checkpoint on the same 500-step cadence
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    # Best-model selection
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)
# Early stopping callback
class EarlyStoppingCallback(TrainerCallback):
    """Stop training when the monitored evaluation metric stops improving.

    The monitored metric is ``args.metric_for_best_model`` as passed to the
    hook (not a module-level global), and the comparison direction follows
    ``args.greater_is_better``, so the callback works for loss-like and
    score-like metrics alike.

    Args:
        patience: Number of consecutive evaluations without improvement
            after which training is stopped.

    Attributes set on the instance:
        best_metric: Best value of the monitored metric seen so far.
        best_model_checkpoint: ``state.global_step`` at which it was seen.
        epochs_no_improve: Consecutive evaluations without improvement
            (counts evaluations, despite the name).
    """

    def __init__(self, patience=2):
        self.patience = patience
        self.best_metric = None
        self.best_model_checkpoint = None
        self.epochs_no_improve = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Update the improvement counter and request a stop when patience runs out."""
        metric_name = args.metric_for_best_model or "eval_loss"
        # Evaluation metrics are reported with an "eval_" prefix, while the
        # configured name may omit it.
        if not metric_name.startswith("eval_"):
            metric_name = f"eval_{metric_name}"

        eval_metric = (metrics or {}).get(metric_name)
        if eval_metric is None:
            # Nothing to compare against; skip this evaluation instead of crashing.
            print(f"Early stopping skipped: metric '{metric_name}' not found in evaluation results.")
            return

        if self.best_metric is None:
            improved = True
        elif args.greater_is_better:
            improved = eval_metric > self.best_metric
        else:
            improved = eval_metric < self.best_metric

        if improved:
            self.best_metric = eval_metric
            self.best_model_checkpoint = state.global_step
            self.epochs_no_improve = 0
            return

        self.epochs_no_improve += 1
        if self.epochs_no_improve >= self.patience:
            print(f"Stopping early after {self.epochs_no_improve} evaluations with no improvement.")
            control.should_training_stop = True
# Build the classifier, the early-stopping callback and the Trainer
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
early_stopping = EarlyStoppingCallback(patience=2)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

# Run training, then write the resulting model to output_dir
trainer.train()
trainer.save_model(output_dir)

# Close the wandb run
wandb.finish()

print(f"Training completed. Model saved to {output_dir}")