import yaml
import os
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch


def load_config():
    try:
        with open('training_config.yml', 'r') as file:
            config = yaml.safe_load(file)
        if not config:
            raise ValueError("Empty configuration file")
        return config
    except FileNotFoundError:
        raise FileNotFoundError("training_config.yml not found")
    except yaml.YAMLError as e:
        raise ValueError(f"Error parsing training_config.yml: {e}")


def validate_config(config):
    training = config.get('training', {})
    try:
        # Explicitly convert all numeric values to their proper types
        validated = {
            'learning_rate': float(str(training.get('learning_rate', 5e-5)).strip()),
            'num_train_epochs': int(str(training.get('num_train_epochs', 3)).strip()),
            'per_device_train_batch_size': int(str(training.get('per_device_train_batch_size', 2)).strip()),
            'gradient_accumulation_steps': int(str(training.get('gradient_accumulation_steps', 4)).strip()),
            'save_steps': int(str(training.get('save_steps', 1000)).strip()),
            'eval_steps': int(str(training.get('eval_steps', 500)).strip()),
            'max_length': int(str(training.get('max_length', 512)).strip())
        }
        # Print values for debugging
        print("Validated config values:", validated)
        return validated
    except (ValueError, TypeError) as e:
        raise ValueError(f"Error converting config values: {str(e)}")


def train():
    try:
        config = load_config()
        training_params = validate_config(config)

        # Login to Hugging Face
        token = os.environ.get('HUGGINGFACE_TOKEN')
        if not token:
            raise ValueError("HUGGINGFACE_TOKEN not set")
        login(token)

        # Load tokenizer
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config['base_model'])
        # Causal LM tokenizers often ship without a pad token; reuse EOS for padding
        tokenizer.pad_token = tokenizer.eos_token

        # Load dataset
        print("Loading dataset...")
        raw_dataset = load_dataset('json', data_files=config['dataset'])
        print(f"Dataset loaded: {raw_dataset}")

        # Prepare the texts first
        def prepare_texts(examples):
            texts = []
            for prompt, response in zip(examples['prompt'], examples['response']):
                text = f"### Question: {str(prompt)}\n\n### Answer: {str(response)}"
                texts.append(text)
            return {'text': texts}

        # Convert to the format we want
        processed_dataset = raw_dataset.map(
            prepare_texts,
            batched=True,
            remove_columns=raw_dataset['train'].column_names
        )
        print(f"Processed dataset: {processed_dataset}")

        # Now tokenize, padding/truncating to the configured max_length
        def tokenize_function(examples):
            return tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=training_params['max_length']
            )

        tokenized_dataset = processed_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['text']
        )
        print(f"Tokenized dataset: {tokenized_dataset}")

        # Load model in 8-bit (requires bitsandbytes); newer transformers versions
        # prefer passing a BitsAndBytesConfig via quantization_config instead
        print("Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            config['base_model'],
            load_in_8bit=True,
            device_map="auto"
        )
        model = prepare_model_for_kbit_training(model)

        # LoRA configuration
        print("Configuring LoRA...")
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
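
        # With mlm=False, DataCollatorForLanguageModeling batches for standard causal LM
        # training: labels are a copy of input_ids with pad positions set to -100.
        # Note: because pad_token was set to eos_token above, EOS positions are masked too.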
        # Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Training arguments
        print("Setting up training arguments...")
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=training_params['num_train_epochs'],
            per_device_train_batch_size=training_params['per_device_train_batch_size'],
            gradient_accumulation_steps=training_params['gradient_accumulation_steps'],
            learning_rate=training_params['learning_rate'],
            save_steps=training_params['save_steps'],
            eval_steps=training_params['eval_steps'],  # no effect here: no eval dataset or eval strategy is set
            logging_steps=100,
            remove_unused_columns=False,
            push_to_hub=True,
            hub_model_id=config['model_id']
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=data_collator,
            tokenizer=tokenizer
        )

        print("Starting training...")
        trainer.train()

        print("Pushing model to hub...")
        trainer.push_to_hub()

        return "Training completed successfully!"

    except Exception as e:
        print(f"Error details: {str(e)}")
        return f"Error during training: {str(e)}"
if __name__ == "__main__":
    train()