from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import train_on_responses_only

def finetune_model(model, tokenizer, dataset, learning_rate, batch_size, num_epochs):
    """
    Fine-tune a model on a given dataset, using CUDA if available.
    This version supports fine-tuning of quantized models using PEFT and Unsloth optimizations.
    
    Args:
    model: The pre-trained model to fine-tune.
    tokenizer: The tokenizer associated with the model.
    dataset: The dataset to use for fine-tuning.
    learning_rate (float): Learning rate for optimization.
    batch_size (int): Number of training examples used in one iteration.
    num_epochs (int): Number of complete passes through the dataset.
    
    Returns:
    SFTTrainer: The trained model wrapped in an SFTTrainer object.
    """
    # Prepare the model for training
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj",],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Set up the trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=model.config.max_position_embeddings,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        dataset_num_proc=2,
        packing=False,
        args=TrainingArguments(
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir="outputs",
        ),
    )

    # Apply train_on_responses_only
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
        response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
    )

    # Train the model
    trainer.train()

    return trainer