from unsloth import FastLanguageModel
import torch
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Load the base model in 4-bit (QLoRA-style) via Unsloth
model_name = "deepseek-ai/deepseek-coder-6.7b-instruct"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,   # maximum context length used during training
    dtype=None,            # None lets Unsloth auto-detect (bf16 if supported, else fp16)
    load_in_4bit=True,     # 4-bit quantization to reduce GPU memory usage
)

# Load the dataset (tatsu-lab/alpaca ships instruction/input/output columns plus a
# pre-formatted "text" column that combines them into a single prompt string)
dataset = load_dataset("tatsu-lab/alpaca")

# Append the EOS token so the model learns when to stop; SFTTrainer trains on "text"
dataset = dataset.map(lambda x: {"text": x["text"] + tokenizer.eos_token})

# Hold out a small evaluation split so step-based evaluation has data to run on
dataset = dataset["train"].train_test_split(test_size=0.05, seed=42)

# Apply LoRA adapters to the model
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                                 # LoRA rank
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_alpha=32,                        # LoRA scaling factor
    lora_dropout=0.1,                     # dropout applied to the adapter layers
    bias="none",                          # do not train bias terms
)

# Helper to check bfloat16 support (Unsloth also exports an is_bfloat16_supported helper)
def is_bfloat16_supported():
    return torch.cuda.is_bf16_supported()

# Define training arguments
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,     # effective batch size of 2 * 8 = 16 per device
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,                # keep only the three most recent checkpoints
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir="logs",
    logging_steps=100,
    fp16=not is_bfloat16_supported(),  # fall back to fp16 if bf16 is unavailable
    bf16=is_bfloat16_supported(),      # prefer bf16 when the GPU supports it
    report_to="none",                  # disable external experiment trackers
    num_train_epochs=3,
)

# Set up the SFT Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="text",   # column produced by the formatting step above
    max_seq_length=2048,         # should match the model's max_seq_length
    args=training_args,
    tokenizer=tokenizer,
)

# Start fine-tuning
trainer.train()

# Save the fine-tuned LoRA adapters and tokenizer
# (save_pretrained on a PEFT model writes only the adapter weights, not the full base model)
model.save_pretrained("fine_tuned_deepseek_coder")
tokenizer.save_pretrained("fine_tuned_deepseek_coder")
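
# Optional: quick sanity-check generation with the fine-tuned adapters.
# A minimal sketch; the prompt below assumes the same Alpaca-style template as the
# training texts, and max_new_tokens is an arbitrary choice.
FastLanguageModel.for_inference(model)  # switch Unsloth into its faster inference mode
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nWrite a Python function that reverses a string.\n\n### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))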