import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image
import clip
import numpy as np
import cv2
import gradio as gr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def top_filtering(logits, top_k=0, top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering.

    Args:
        logits: logits distribution of shape (vocabulary size,)
        top_k: <=0 means no filtering; otherwise keep only the top-k tokens
        top_p: <=0.0 means no filtering; otherwise keep the smallest set of tokens
            whose cumulative probability exceeds top_p (nucleus filtering)
        threshold: minimal logit value a token must have to be kept
        filter_value: value assigned to filtered-out logits
    """
    assert logits.dim() == 1  # only works for batch size 1 for now; batching could be added but would obfuscate the code
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits
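
# Illustrative sketch (toy logits, not part of the demo pipeline): how
# top_filtering composes with sampling. top_k=2 keeps only the two highest
# logits; nucleus filtering then keeps the smallest prefix of sorted tokens
# whose cumulative probability exceeds top_p.
#
#   toy_logits = torch.tensor([3.0, 1.5, 0.2, -1.0])
#   filtered = top_filtering(toy_logits.clone(), top_k=2, top_p=0.9)
#   next_token = torch.multinomial(F.softmax(filtered, dim=-1), 1)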

class ImageEncoder(nn.Module):

    def __init__(self):
        super().__init__()

        self.encoder, _ = clip.load("ViT-B/16", device=device)   # CLIP already loads in eval mode

    def forward(self, x):
        """
        Expects a tensor of shape (batch_size, 3, 224, 224).
        Returns patch-grid features of shape (batch_size, 196, 768).
        """
        with torch.no_grad():
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # prepend the CLS token embedding
            cls_embed = self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            x = torch.cat([cls_embed, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD    (N, 197, 768), incl. the CLS token
            grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])  # drop the CLS token -> (N, 196, 768)

        return grid_feats.float()
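
# Shape sanity check (a minimal sketch, assuming a 224x224 input as documented
# above): ViT-B/16 tiles a 224x224 image into a 14x14 patch grid, so after
# dropping the CLS token the encoder returns (batch, 196, 768) features.
#
#   feats = image_encoder(torch.zeros(1, 3, 224, 224, device=device))  # image_encoder is defined further below
#   assert feats.shape == (1, 196, 768)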

def change_requires_grad(model, req_grad):
    for p in model.parameters():
        p.requires_grad = req_grad

def load_checkpoint(ckpt_path, epoch):
    
    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)        # load tokenizer
    model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)   # load model with config
    return tokenizer, model
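
# Assumed on-disk checkpoint layout (inferred from the names above, not
# verified here): each entry is a HuggingFace save_pretrained() directory.
#   ACTX_p/nle_gpt2_tokenizer_0/   <- tokenizer files
#   ACTX_p/nle_model_5/            <- GPT-2 weights + config (epoch 5)
# Note that ckpt_path must end with '/' since the paths are joined by plain
# string concatenation.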

def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
    
    SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')   # GPT-2 BPE token for ' because' (with leading space)
    max_len = 20   # maximum number of generated tokens
    current_output = []
    img_embeddings = image_encoder(img)   # image_encoder is the module-level CLIP encoder defined below
    always_exp = False   # flips to True once ' because' is generated, i.e. the explanation has started
        
    with torch.no_grad():
        
        for step in range(max_len):
            
            outputs = model(input_ids=input_ids, 
                            past_key_values=None, 
                            attention_mask=None, 
                            token_type_ids=segment_ids, 
                            position_ids=None, 
                            encoder_hidden_states=img_embeddings, 
                            encoder_attention_mask=None, 
                            labels=None, 
                            use_cache=False, 
                            output_attentions=True,
                            return_dict=True)
            
            lm_logits = outputs.logits 
            xa_maps = outputs.cross_attentions
            logits = lm_logits[0, -1, :] / temperature   # temperature / top_k / top_p / no_sample are module-level decoding settings
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)   # greedy if no_sample, else sample
            
            if prev.item() in special_tokens_ids:
                break
            
            # Assign the segment id for the new token: tokens before ' because' get the
            # answer segment id; ' because' and everything after it get the explanation
            # segment id.
            if not always_exp:

                if prev.item() != because_token:
                    new_segment = special_tokens_ids[-2]   # answer segment
                else:
                    new_segment = special_tokens_ids[-1]   # explanation segment
                    always_exp = True
            else:
                new_segment = special_tokens_ids[-1]   # explanation segment
                
            new_segment = torch.LongTensor([new_segment]).to(device)
            current_output.append(prev.item())
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)
                
        decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()

    return decoded_sequences, xa_maps
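
# A generated sequence typically reads "<answer> because <explanation>", e.g.
# (illustrative, not an actual model output):
#   "playing guitar because he is strumming the strings of an acoustic guitar"
# inference() below relies on this pattern and splits on "because".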

def get_inputs(tokenizer):
    a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<answer>', '<explanation>'])
    tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
    segment_ids = [a_segment_id] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)

    return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)
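
# The prompt therefore decodes to "<|endoftext|>the answer is", with every
# position tagged with the <answer> segment id; sample_sequences() continues
# generation from there. A quick check (sketch):
#
#   ids, segs = get_inputs(tokenizer)
#   print(tokenizer.decode(ids[0]))   # -> "<|endoftext|>the answer is"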
    

img_size = 224
ckpt_path = 'ACTX_p/'
load_from_epoch = 5
no_sample = True       # greedy decoding; set to False to sample with top-k / nucleus filtering
top_k = 0
top_p = 0.9
temperature = 1

image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)
tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()

img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])   # ImageNet statistics

def inference(raw_image):

    oimg = raw_image.convert('RGB').resize((224, 224))
    img = img_transform(oimg).unsqueeze(0).to(device)
    input_ids, segment_ids = get_inputs(tokenizer)
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
    # Average the last cross-attention layer over heads and take the first query
    # position's attention over the 196 image patches as a 14x14 saliency map.
    last_am = xa_maps[-1].mean(1)[0]
    mask = last_am[0, :].reshape(14, 14).cpu().numpy()
    mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
    attention_map = (mask * oimg).astype("uint8")
    split_seq = seq.split("because")
    answer = split_seq[0].strip()
    explanation = "because " + split_seq[-1].strip() if len(split_seq) > 1 else ""
    return answer, explanation, Image.fromarray(attention_map)

inputs = [gr.Image(type='pil', label="Load the image of your interest")]
outputs = [gr.Textbox(label="What action is this?"),
           gr.Textbox(label="Textual Explanation"),
           gr.Image(type='pil', label="Visual Explanation")]

title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"
gr.Interface(inference, inputs, outputs, title=title).launch()