import logging
import os
from typing import Any, Union, List, Optional, Tuple, Dict

import open_clip
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2 as cv2

from .gem_wrapper import GEMWrapper


# Architecture name -> list of pretrained-weight tags reported by
# available_models(). Names use the "/" spelling (e.g. "ViT-B/16");
# create_gem_model rewrites "/" to "-" before handing the name to open_clip.
_MODELS = {
    # ---- ViT-B/32 ----
    "ViT-B/32": [
        "openai", "laion400m_e31", "laion400m_e32",
        "laion2b_e16", "laion2b_s34b_b79k",
    ],
    "ViT-B/32-quickgelu": ["metaclip_400m", "metaclip_fullcc"],
    # ---- ViT-B/16 ----
    "ViT-B/16": [
        "openai", "laion400m_e31", "laion400m_e32", "laion2b_s34b_b88k",
    ],
    "ViT-B/16-quickgelu": ["metaclip_400m", "metaclip_fullcc"],
    "ViT-B/16-plus-240": ["laion400m_e31", "laion400m_e32"],
    # ---- ViT-L/14 ----
    "ViT-L/14": [
        "openai", "laion400m_e31", "laion400m_e32", "laion2b_s32b_b82k",
    ],
    "ViT-L/14-quickgelu": ["metaclip_400m", "metaclip_fullcc"],
    "ViT-L/14-336": ["openai"],
}

def available_models() -> str:
    """Return a printable listing of the available GEM-VL models.

    The result is a single string (not a list): one newline-terminated line
    per (architecture, pretrained tag) pair in ``_MODELS``, in dict/list
    order, formatted as ``"<architecture left-justified to 20 chars>: <tag>"``.
    """
    rows = [
        f"{arch.ljust(20)}: {tag}\n"
        for arch, tags in _MODELS.items()
        for tag in tags
    ]
    return "".join(rows)

def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    """Return the open_clip tokenizer for ``model_name``.

    Thin wrapper around ``open_clip.get_tokenizer``. Architecture names may use
    the "/" spelling listed in ``_MODELS`` (e.g. ``"ViT-B/16"``). As
    ``create_gem_model`` does, the name is converted to open_clip's "-"
    spelling (``"ViT-B-16"``) so open_clip can find the model config.
    Schema-prefixed identifiers containing ":" (e.g. ``"hf-hub:org/repo"``)
    are passed through unchanged, because their "/" is part of the id.

    Args:
        model_name: architecture name or schema-prefixed identifier.
        context_length: optional context-length override, forwarded to open_clip.
        **kwargs: forwarded unchanged to ``open_clip.get_tokenizer``.
    """
    if model_name and ':' not in model_name:
        model_name = model_name.replace("/", "-")
    return open_clip.get_tokenizer(model_name=model_name, context_length=context_length, **kwargs)


def get_gem_img_transform(
        img_size: Union[int, Tuple[int, int]] = (448, 448),
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
):
    """Build the torchvision preprocessing pipeline for GEM input images.

    Steps, in order: bicubic ``Resize(size=img_size)``, ``ToTensor()``, then
    ``Normalize(mean, std)``. When ``mean``/``std`` are not given (or are
    falsy), the OpenAI CLIP dataset statistics are used.

    Returns:
        A ``transforms.Compose`` wrapping the three steps above.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    steps = [
        transforms.Resize(size=img_size, interpolation=bicubic),
        transforms.ToTensor(),
        transforms.Normalize(mean or OPENAI_DATASET_MEAN, std or OPENAI_DATASET_STD),
    ]
    return transforms.Compose(steps)


def create_gem_model(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        ss_attn_iter: int = 1,
        ss_attn_temp: Optional[float] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Load an open_clip model and wrap it in a ``GEMWrapper``.

    Args:
        model_name: architecture name. The "/" spelling used in ``_MODELS``
            (e.g. ``"ViT-B/16"``) is converted to open_clip's "-" spelling.
            Names containing ":" (e.g. ``"hf-hub:org/repo"``) are used as-is,
            because their "/" is part of the identifier.
        pretrained: pretrained-weight tag forwarded to ``open_clip.create_model``.
        gem_depth: forwarded to ``GEMWrapper`` as ``depth``.
        ss_attn_iter: forwarded to ``GEMWrapper``.
        ss_attn_temp: forwarded to ``GEMWrapper``.
        precision, device, jit, force_quick_gelu, force_custom_text,
        force_patch_dropout, force_image_size, force_preprocess_cfg,
        pretrained_image, pretrained_hf, cache_dir, output_dict,
        require_pretrained, **model_kwargs:
            forwarded to ``open_clip.create_model`` under the same names.

    Returns:
        The ``GEMWrapper`` built around the open_clip model, using the
        tokenizer from ``open_clip.get_tokenizer`` for the same model name.
    """
    if ':' not in model_name:
        model_name = model_name.replace("/", "-")
    logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
    # Forward by keyword so a change in open_clip's positional order cannot
    # silently shift arguments into the wrong parameters.
    open_clip_model = open_clip.create_model(
        model_name,
        pretrained=pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )
    tokenizer = open_clip.get_tokenizer(model_name=model_name)

    gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
                           ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
    logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
    return gem_model

def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        ss_attn_iter: int = 1,
        ss_attn_temp: Optional[float] = None,
        **model_kwargs,
):
    """Create a GEM model together with its image preprocessing transform.

    All named arguments are forwarded by keyword to ``create_gem_model``.
    ``ss_attn_iter`` and ``ss_attn_temp`` are keyword parameters placed after
    the original ones, so existing positional callers are unaffected; their
    defaults match ``create_gem_model``.

    Extra keyword arguments are split between the two outputs:
    ``img_size``, ``mean`` and ``std`` configure ``get_gem_img_transform``,
    and every other extra keyword goes to ``create_gem_model``.

    Returns:
        ``(gem_model, transform)``.
    """
    # Route transform options away from the model factory, which would reject
    # them. get_gem_img_transform in turn accepts only these three names.
    transform_kwargs = {
        key: model_kwargs.pop(key)
        for key in ('img_size', 'mean', 'std')
        if key in model_kwargs
    }
    # Forward by keyword: create_gem_model declares ss_attn_iter/ss_attn_temp
    # between gem_depth and precision, so positional forwarding misaligns
    # every later argument.
    gem_model = create_gem_model(
        model_name,
        pretrained=pretrained,
        gem_depth=gem_depth,
        ss_attn_iter=ss_attn_iter,
        ss_attn_temp=ss_attn_temp,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )

    transform = get_gem_img_transform(**transform_kwargs)
    return gem_model, transform

def visualize(image, text, logits, alpha=0.6, save_path=None):
    """Show an image, then one heat-map overlay per text label.

    Args:
        image: a ``PIL.Image.Image`` or a ``torch.Tensor``. A tensor is assumed
            to be a CHW image (optionally with a leading batch dim of 1),
            normalized with the OpenAI CLIP mean/std as done by
            ``get_gem_img_transform`` with its defaults. The normalization is
            undone before plotting. Either kind is converted to RGB and resized
            to the heat-map resolution so the overlay lines up.
        text: labels, one per heat map, used as plot titles and in saved file
            names. A single ``str`` is treated as one label.
        logits: tensor of heat maps whose last two dims are (height, width),
            e.g. ``(num_labels, H, W)``, ``(1, num_labels, H, W)`` or a single
            ``(H, W)`` map. Values are expected in [0, 1]; values outside that
            range are clipped before color mapping.
        alpha: weight of the heat map in the blend (the image gets ``1 - alpha``).
        save_path: if not ``None``, each overlay is also written as
            ``heatmap_<label>.png``. A str/path-like value is the output
            directory (created if missing; ``""`` means the current directory).
            Any other non-``None`` value saves into the current directory.

    Raises:
        TypeError: if ``image`` is neither a PIL image nor a tensor.
    """
    if logits.ndim > 3:
        logits = logits.squeeze(0)
    # Tensors are laid out (..., height, width); PIL sizes are (width, height).
    height, width = logits.shape[-2:]

    if isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)
        std = torch.tensor(OPENAI_DATASET_STD)[:, None, None]
        mean = torch.tensor(OPENAI_DATASET_MEAN)[:, None, None]
        # Undo the normalization. Clamp because bicubic resizing can overshoot
        # [0, 1], and an unclamped cast to uint8 would wrap around.
        image_unnormed = (image.detach().cpu().float() * std + mean).clamp(0, 1)
        image = Image.fromarray((image_unnormed.permute(1, 2, 0).numpy() * 255).astype('uint8'))
    elif not isinstance(image, Image.Image):
        raise TypeError(f'image should be either of type PIL.Image.Image or torch.Tensor but found {type(image)}')
    image = image.convert('RGB').resize((width, height))

    # plot image
    plt.imshow(image)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    heat_values = logits.detach().cpu().float().numpy()
    if heat_values.ndim == 2:
        # A single map: add a label axis so it is iterated as one map, not row by row.
        heat_values = heat_values[None]
    heat_u8 = (np.clip(heat_values, 0, 1) * 255).astype('uint8')
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    heat_maps = [cv2.applyColorMap(heat, cv2.COLORMAP_JET) for heat in heat_u8]

    if isinstance(text, str):
        text = [text]

    save_dir = None
    if save_path is not None:
        save_dir = os.fspath(save_path) if isinstance(save_path, (str, os.PathLike)) else ''
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    for heat_map, cls_name in zip(heat_maps, text):
        blended = np.clip((1 - alpha) * img_cv + alpha * heat_map, 0, 255)
        viz = cv2.cvtColor(blended.astype('uint8'), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.title(cls_name)
        plt.axis('off')
        plt.tight_layout()
        if save_dir is not None:
            # Save before show(): some backends close the figure inside show(),
            # so a later savefig() would write an empty canvas.
            plt.savefig(os.path.join(save_dir, f'heatmap_{cls_name}.png'))
        plt.show()