import math
import torch
import torchvision.transforms as T
from os import path
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
from torchmetrics.functional import accuracy
from timm import create_model, list_models
from timm.models.vision_transformer import VisionTransformer
from torchvision.datasets import ImageFolder

from utils import AverageMeter
from lightning import LightningDataModule, LightningModule
from huggingface_hub import PyTorchModelHubMixin, login
import torch.nn as nn
from lora import LoRA_qkv


# Images are resized to PRE_SIZE and then cropped to IMG_SIZE (random crop for
# training, center crop for validation).
PRE_SIZE = (256, 256)
IMG_SIZE = (224, 224)

# Per-channel normalisation statistics (the commonly used ImageNet mean/std).
# NOTE(review): some timm ViT checkpoints were pretrained with mean=std=0.5;
# confirm these values match the pretrained config of the model actually used.
STATS = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

# Dataset and checkpoint folders sit next to this source file.
_MODULE_DIR = path.dirname(__file__)
DATASET_DIRECTORY = path.join(_MODULE_DIR, "datasets")
CHECKPOINT_DIRECTORY = path.join(_MODULE_DIR, "checkpoints")


def _build_transform(crop_cls):
    """Return the pipeline resize -> ``crop_cls(IMG_SIZE)`` -> tensor -> normalise."""
    return T.Compose([
        T.Resize(PRE_SIZE),
        crop_cls(IMG_SIZE),
        T.ToTensor(),
        T.Normalize(**STATS),
    ])


# Preprocessing pipeline per dataset split, keyed by split name.
TRANSFORMS = {
    "train": _build_transform(T.RandomCrop),
    "val": _build_transform(T.CenterCrop),
}



class myDataModule(LightningDataModule):
    """
    Lightning DataModule for loading and preparing the image dataset.

    The dataset lives in ``DATASET_DIRECTORY/<ds_name>`` and must contain
    ``train`` and ``val`` sub-directories in torchvision ``ImageFolder``
    layout (one sub-folder per class).

    Args:
        ds_name (str): Name of the dataset directory.
        batch_size (int): Batch size for data loaders.
        num_workers (int): Number of workers for data loaders; ``0`` loads
            batches in the main process.
    """
    def __init__(self, ds_name: str = "deities", batch_size: int = 32, num_workers: int = 8):
        super().__init__()

        self.ds_path = path.join(DATASET_DIRECTORY, ds_name)
        assert path.exists(self.ds_path), f"Dataset {ds_name} not found in {DATASET_DIRECTORY}."

        self.ds_name = ds_name
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``"fit"`` (or ``None``) builds both splits and records ``num_classes``.
        ``"validate"`` only needs the validation split.
        """
        if stage in ("fit", None):
            self.train_ds = ImageFolder(root=path.join(self.ds_path, 'train'), transform=TRANSFORMS['train'])
            # Number of classes, taken from the training split's class folders.
            self.num_classes = len(self.train_ds.classes)
        if stage in ("fit", "validate", None):
            self.val_ds = ImageFolder(root=path.join(self.ds_path, 'val'), transform=TRANSFORMS['val'])

    def train_dataloader(self) -> DataLoader:
        """Training loader that rebalances classes with a weighted sampler.

        Each sample is drawn (with replacement) with probability proportional
        to ``1 / size_of_its_class``, so rare classes are seen as often as
        common ones.
        """
        # ImageFolder keeps every sample's label in ``targets``. Reading it
        # avoids loading and transforming every image just to count labels.
        labels = self.train_ds.targets
        class_samples = [0] * self.num_classes
        for label in labels:
            class_samples[label] += 1
        weights = [1.0 / class_samples[label] for label in labels]
        self.sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        return DataLoader(dataset=self.train_ds, batch_size=self.batch_size,
                          sampler=self.sampler, num_workers=self.num_workers,
                          # persistent_workers is only valid with worker processes.
                          persistent_workers=self.num_workers > 0)

    def val_dataloader(self) -> DataLoader:
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(dataset=self.val_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers,
                          persistent_workers=self.num_workers > 0)
    



class myModule(LightningModule, PyTorchModelHubMixin):
    """
    Lightning Module for training and evaluating the Image classification model.

    Wraps a pretrained timm Vision Transformer. The backbone can be frozen and
    optionally extended with LoRA adapters on each block's ``attn.qkv``
    projection. Its classifier head is always replaced with a freshly created
    ``num_classes``-way head.

    Args:
        model_name (str): Name of the timm Vision Transformer model.
        num_classes (int): Number of classes in the dataset.
        freeze_flag (bool): Flag to freeze the base model parameters.
        use_lora (bool): Flag to use LoRA (Low-Rank Adaptation) for fine-tuning.
            Requires ``freeze_flag=True``.
        rank (int): Rank for LoRA; must be set (non-zero) if use_lora is True.
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        push_to_hf (bool): Flag to push the model to the Huggingface Hub when
            fitting ends.
        commit_message (str): Commit message used for the Hub push.
        repo_id (str): Huggingface repo id.
    """
    def __init__(self,
                 model_name: str = "vit_tiny_patch16_224",
                 num_classes: int = 25,
                 freeze_flag: bool = True,
                 use_lora: bool = False,
                 rank: int = None,
                 learning_rate: float = 3e-4,
                 weight_decay: float = 2e-5,
                 push_to_hf: bool = True,
                 commit_message: str = "my model",
                 repo_id: str = "Yegiiii/ideityfy"
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.model_name = model_name
        self.num_classes = num_classes
        self.freeze_flag = freeze_flag
        self.rank = rank
        self.use_lora = use_lora
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.push_to_hf = push_to_hf
        self.commit_message = commit_message
        self.repo_id = repo_id

        assert model_name in list_models(), f"Timm model name {model_name} not available."
        timm_model = create_model(model_name, pretrained=True)
        assert isinstance(timm_model, VisionTransformer), f"{model_name} not a Vision Transformer."
        self.model = timm_model

        if freeze_flag:
            # Freeze every parameter created so far (the pretrained backbone).
            # Modules added below (LoRA adapters, new head) stay trainable.
            self.freeze()

        if use_lora:
            # Add LoRA matrices to the Timm model
            assert freeze_flag, "Set freeze_flag to True for using LoRA fine-tuning."
            assert rank, "Rank can't be None."
            self.add_lora()

        # Replace the pretrained head with a new num_classes-way classifier.
        self.model.reset_classifier(num_classes)

        # Loss function
        self.criterion = CrossEntropyLoss()

        # Validation metrics
        self.top1_acc = AverageMeter()
        self.top3_acc = AverageMeter()
        self.top5_acc = AverageMeter()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier logits for the image batch ``x``."""
        return self.model(x)

    def on_fit_start(self) -> None:
        """Check that the head size matches the datamodule's class count."""
        num_classes = self.trainer.datamodule.num_classes
        assert num_classes == self.num_classes, (
            f"Number of classes provided in the argument ({self.num_classes}) is not matching "
            f"the number of classes in the dataset ({num_classes})."
        )

    def on_fit_end(self) -> None:
        """Optionally log in to and push the trained model to the Hugging Face Hub."""
        if self.push_to_hf:
            login()
            self.push_to_hub(repo_id=self.repo_id, commit_message=self.commit_message)

    def configure_optimizers(self):
        """AdamW over the trainable parameters, with cosine annealing per epoch down to 1e-6."""
        optimizer = AdamW(params=[p for p in self.model.parameters() if p.requires_grad],
                          lr=self.learning_rate, weight_decay=self.weight_decay)

        scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs, 1e-6)
        return ([optimizer], [scheduler])

    def shared_step(self, x: torch.Tensor, y: torch.Tensor):
        """Forward pass plus cross-entropy loss; returns ``(logits, loss)``."""
        logits = self(x)
        loss = self.criterion(logits, y)
        return logits, loss

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        x, y = batch
        _, loss = self.shared_step(x, y)

        self.log("train_loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return loss

    def _topk_accuracy(self, logits: torch.Tensor, y: torch.Tensor, k: int) -> torch.Tensor:
        """Support-weighted multiclass top-``k`` accuracy of one batch.

        ``k`` is clamped to ``num_classes`` because torchmetrics rejects
        ``top_k > num_classes`` (e.g. top-5 on a 3-class dataset).
        """
        return accuracy(logits, y, task="multiclass", average="weighted",
                        top_k=min(k, self.num_classes), num_classes=self.num_classes)

    def validation_step(self, batch, batch_idx) -> dict:
        x, y = batch
        logits, loss = self.shared_step(x, y)

        top1 = self._topk_accuracy(logits, y, 1)
        top3 = self._topk_accuracy(logits, y, 3)
        top5 = self._topk_accuracy(logits, y, 5)

        # Keep the meters updated for any code that reads them.
        self.top1_acc(val=top1)
        self.top3_acc(val=top3)
        self.top5_acc(val=top5)

        # Log per-batch values. With on_epoch=True, Lightning itself averages
        # them over the epoch, weighted by batch_size. Logging the meters'
        # ``.avg`` (presumably a running mean -- see utils.AverageMeter) would
        # average already-averaged values a second time.
        metric_dict = {
            "val_loss": loss,
            "top1_acc": top1,
            "top3_acc": top3,
            "top5_acc": top5
        }

        self.log_dict(metric_dict, prog_bar=True, logger=True, on_epoch=True,
                      batch_size=x.size(0))
        return metric_dict

    def on_validation_epoch_end(self) -> None:
        """Clear the accuracy meters so each validation epoch starts fresh."""
        self.top1_acc.reset()
        self.top3_acc.reset()
        self.top5_acc.reset()

    def add_lora(self):
        """Wrap each ViT block's ``attn.qkv`` with LoRA adapters for query and value.

        For every block, four bias-free linear layers of rank ``self.rank`` are
        created: an A/B pair for the query and an A/B pair for the value. The
        block's original ``qkv`` layer is then replaced by ``LoRA_qkv`` built
        from it and those adapters. A weights get Kaiming-uniform init and B
        weights are zeroed, so each adapter's B·A product starts at zero.
        """
        self.w_As = []
        self.w_Bs = []

        for blk in self.model.blocks:
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            lora_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
            lora_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
            self.w_As.extend([lora_a_linear_q, lora_a_linear_v])
            self.w_Bs.extend([lora_b_linear_q, lora_b_linear_v])
            blk.attn.qkv = LoRA_qkv(w_qkv_linear, lora_a_linear_q,
                                    lora_b_linear_q, lora_a_linear_v, lora_b_linear_v)

        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)



if __name__ == "__main__":
    # Smoke check: download the Hub copy of the dataset and print its splits.
    # (A model summary can be printed with torchinfo, e.g.
    #  summary(myModule(freeze_flag=False), (1, 3, 224, 224)).)
    from datasets import load_dataset

    hub_dataset = load_dataset("Yegiiii/deities")
    print(hub_dataset)