import numpy as np # this should come first to mitigate mlk-service bug
from src.models.utils import get_image_arr, load_model
from src.data import TAIMGANTokenizer
from torchvision import transforms
from src.config import config_dict
from pathlib import Path
from enum import IntEnum, auto
from PIL import Image
import gradio as gr
import torch
from src.models.modules import (
    VGGEncoder,
    InceptionEncoder,
    TextEncoder,
    Generator
)

##########
# PARAMS #
##########

# Image geometry: images are treated as RGB and resized to IMG_HW x IMG_HW
# before being fed to the encoders.
IMG_CHANS = 3  # colour channels (RGB)
IMG_HW = 256  # target height == target width
# Text-encoder sizes: HIDDEN_DIM is the LSTM hidden size in ONE direction,
# and the embedding length C spans both directions, hence the factor of 2.
HIDDEN_DIM = 128
C = HIDDEN_DIM * 2

# Generator hyper-parameters, read from the project configuration in the
# order Ng, condition_dim, noise_dim.
Ng, cond_dim, z_dim = (
    config_dict[key] for key in ("Ng", "condition_dim", "noise_dim")
)


###############
# LOAD MODELS #
###############

# Registry of available model bundles, keyed by the name shown in the UI.
# Each entry starts with only its weights directory; the loop below fills
# in the tokenizer and network instances.
models = {
    "COCO": {"dir": "weights/coco"},
    "Bird": {"dir": "weights/bird"},
    "UTKFace": {"dir": "weights/utkface"},
}

for model_name, entry in models.items():
    weights_dir = entry["dir"]
    # The vocabulary (and therefore the text encoder's embedding table size)
    # comes from the captions pickle stored next to the weights.
    tokenizer = TAIMGANTokenizer(captions_path=f"{weights_dir}/captions.pickle")
    vocab_size = len(tokenizer.word_to_ix)
    # Build every network in eval mode; keyword arguments are evaluated
    # left to right, so construction order matches the key order below.
    entry.update(
        tokenizer=tokenizer,
        generator=Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval(),
        lstm=TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval(),
        vgg=VGGEncoder().eval(),
        inception=InceptionEncoder(D=C).eval(),
    )
    # Restore trained weights on CPU. No discriminator is needed for
    # inference, and the VGG encoder is not passed to load_model here.
    load_model(
        generator=entry["generator"],
        discriminator=None,
        image_encoder=entry["inception"],
        text_encoder=entry["lstm"],
        output_dir=Path(weights_dir),
        device=torch.device("cpu"),
    )


def change_image_with_text(image: Image.Image, text: str, model_name: str) -> Image.Image:
    """
    Generate a new image from ``image`` guided by the caption ``text``.

    The selected model bundle encodes the caption with its LSTM text encoder,
    extracts VGG and Inception features from the input image, and feeds both
    (plus random noise) to the generator. Nothing is written to disk.

    :param PIL.Image.Image image: Input image (any PIL mode; converted to RGB)
    :param str text: Desired caption
    :param str model_name: Key of ``models`` selecting which weights to use
    :return: The generated image
    :rtype: PIL.Image.Image
    :raises ValueError: If ``model_name`` is not a known model
    """
    if model_name not in models:
        raise ValueError(
            f"Unknown model {model_name!r}; choose one of {list(models)}"
        )
    bundle = models[model_name]
    tokenizer = bundle["tokenizer"]
    G = bundle["generator"]
    lstm = bundle["lstm"]
    inception = bundle["inception"]
    vgg = bundle["vgg"]

    # The normalisation below uses 3-channel statistics, so RGBA / grayscale
    # uploads must be converted first or Normalize fails on channel count.
    image = image.convert("RGB")

    # Pure inference: disable autograd so no computation graph is kept.
    with torch.no_grad():
        # Uniform noise in [0, 1), batch of one.
        # NOTE(review): GAN noise is often drawn with torch.randn; confirm
        # this matches how the generator was trained.
        noise = torch.rand(z_dim).unsqueeze(0)
        # Tokenize the caption (batch of one); mask is True at pad positions.
        tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        mask = (tokens == tokenizer.pad_token_id)
        word_embs, sent_embs = lstm(tokens)
        # Convert to a tensor, resize to IMG_HW x IMG_HW and scale to [-1, 1].
        image_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((IMG_HW, IMG_HW)),
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5)
            )
        ])(image).unsqueeze(0)
        # Visual features of the source image.
        vgg_features = vgg(image_tensor)
        local_features, global_features = inception(image_tensor)
        # Generate the new image conditioned on text and image features.
        fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                             local_features, vgg_features, mask)
    # Denormalize the first (only) image in the batch and return it as PIL.
    return Image.fromarray(get_image_arr(fake_image)[0])


##########
# GRADIO #
##########
demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(),
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
        # gr.Dropdown works in both. Defaulting to the first model avoids
        # passing None as the model name when the user never picks one.
        gr.Dropdown(choices=list(models.keys()), value=next(iter(models))),
    ],
    outputs=gr.Image(type="pil"),
    examples=[
        ["src/data/stubs/bird.jpg", "black bird with blue wings", "Bird"],
        ["src/data/stubs/lady.jpg", "lady with blue eyes", "UTKFace"],
        ["src/data/stubs/bird.jpg", "white bird with black wings", "Bird"]
    ]
)
demo.launch(debug=True)