"""Fine-tune a causal LM with LoRA/PEFT using TRL's SFTTrainer.

Reads training rows from ../data/trainv2.csv, wraps them in the project's
CustomDataset, converts them to a Hugging Face Dataset with an explicit
schema, renders each record through the prompt template, and runs
supervised fine-tuning. The resulting adapter weights are saved to
``config.output_dir``.
"""

import pandas as pd
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig, PeftModel, get_peft_model
from model import load_model
import config  # Make sure you have a valid config module
from dataset import CustomDataset, template_dataset
from datasets import Dataset, Features, Value, Sequence


def _build_dataset(tokenizer):
    """Load the CSV training data and return a templated HF Dataset.

    Each record is expected to carry "instruction", "context" and
    "response" fields; ``template_dataset`` folds them into the single
    "text" column that SFTTrainer consumes.
    """
    raw_df = pd.read_csv('../data/trainv2.csv')  # data
    data_list = raw_df.to_dict(orient='records')

    # Project-specific wrapper around the raw records — presumably
    # cleans/normalizes them; verify against dataset.CustomDataset.
    custom_dataset = CustomDataset(data_list)
    clean_df = pd.DataFrame(
        custom_dataset.data,
        columns=["instruction", "context", "response"],
    )

    # Explicit schema so from_pandas does not have to infer dtypes.
    features = Features({
        "instruction": Value("string"),
        "context": Value("string"),
        "response": Value("string"),
    })

    # Create a Hugging Face Dataset from the Pandas DataFrame.
    hugging_face_dataset = Dataset.from_pandas(clean_df, features=features)

    # Render each example through the prompt template; the original
    # columns are dropped so only the templated "text" field remains.
    return hugging_face_dataset.map(
        lambda example: template_dataset(example, tokenizer),
        remove_columns=list(hugging_face_dataset.features),
    )


def _build_training_arguments():
    """Assemble TrainingArguments entirely from the project config module."""
    return TrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        optim=config.optim,
        save_steps=config.save_steps,
        logging_steps=config.logging_steps,
        learning_rate=config.learning_rate,
        fp16=config.fp16,
        bf16=config.bf16,
        max_grad_norm=config.max_grad_norm,
        max_steps=config.max_steps,
        warmup_ratio=config.warmup_ratio,
        group_by_length=config.group_by_length,
        lr_scheduler_type=config.lr_scheduler_type,
        report_to="tensorboard",
    )


def main():
    """Run the full SFT pipeline: load model, build data, train, save."""
    model, tokenizer, peft_config = load_model(config.model_name)

    dataset = _build_dataset(tokenizer)
    print("----training data structure----", dataset)

    training_arguments = _build_training_arguments()

    # SFTTrainer applies the LoRA adapters (peft_config) on top of the
    # base model and trains on the templated "text" column.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=config.max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=config.packing,
    )

    print("**************** TRAINING STARTED ****************")
    trainer.train()

    # Persists only the (small) PEFT adapter weights, not the full base model.
    trainer.model.save_pretrained(config.output_dir)
    print("**************** TRAINING OVER ****************")


if __name__ == '__main__':
    main()