import gradio as gr
import os
import time
from PIL import Image
import torch
import whisperx


from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from models.vision_projector_model import VisionProjector
from config import VisionProjectorConfig, app_config as cfg

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# CLIP vision tower and its matching image pre-processor (same checkpoint).
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPVisionModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPImageProcessor.from_pretrained(CLIP_CHECKPOINT)

# Projector applied to the CLIP hidden states in ``preprocess_fn``; its
# weights come from the checkpoint file named in the app config.
vision_projector = VisionProjector(VisionProjectorConfig())
# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
ckpt = torch.load(cfg['vision_projector_file'], map_location=torch.device(device))
vision_projector.load_state_dict(ckpt['model_state_dict'])
# The app only runs inference. A freshly constructed module starts in training
# mode, so switch it to eval mode (assumes VisionProjector is a torch.nn.Module,
# which its load_state_dict usage suggests).
vision_projector.eval()

# Base phi-2 causal language model, loaded in float32.
phi_base_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    trust_remote_code=True,
    torch_dtype=torch.float32,
    return_dict=True,
    low_cpu_mem_usage=True,
)

from peft import PeftModel

# Attach the adapter stored under ``models/phi_adapter``, merge it into the
# base model, and move the merged model to the selected device.
phi_new_model = "models/phi_adapter"
phi_model = PeftModel.from_pretrained(phi_base_model, phi_new_model)
phi_model = phi_model.merge_and_unload()
phi_model = phi_model.to(device)

# float16 is only usable on the GPU backend; on CPU the speech model has to
# be loaded with float32, otherwise loading fails when no GPU is available.
compute_type = 'float32' if device == 'cpu' else 'float16'

audi_model = whisperx.load_model("small", device, compute_type=compute_type)

# Tokenizer matching the phi-2 model; the unknown token doubles as padding.
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.unk_token


### app functions ##

# Module-level conversation state shared by every handler below.
#
# context       -- current context payload: an uploaded file / audio path,
#                  a text string, a projected image tensor, or an error message
# context_type  -- '' (nothing yet), 'image' / 'audio' / 'text' (just added),
#                  '[image]' / '[audio]' / '[text]' (acknowledged by ``bot``),
#                  or 'error'
# query         -- pending prompt built from the user's question
context = None
context_type = ''
query = ''

# Flags telling ``bot`` what the latest event produced, plus a marker that is
# raised while a reply is being streamed.
context_added = query_added = bot_active = False

def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag of a like/dislike event."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)


def add_text(history, text):
    """Route a submitted chat message to the context slot or the query slot.

    The message is interpreted against the module-level conversation state:

    * text containing ``</context>`` becomes the new text context (the marker
      is replaced by a space);
    * otherwise, when no context has been registered yet, an error prompt
      asking for one is queued;
    * otherwise, when a context has been acknowledged (``'[text]'``,
      ``'[image]'`` or ``'[audio]'``), the text is wrapped into the
      ``Human### ... AI### `` prompt and stored as the pending query;
    * in any other state an error prompt is queued.

    Args:
        history: Chat history as a list of ``(user, bot)`` pairs.
        text: Contents of the textbox at submit time.

    Returns:
        A tuple of the history extended with ``(text, None)`` and a cleared,
        non-interactive ``gr.Textbox`` (re-enabled once the reply is done).
    """
    global context, context_type, context_added, query, query_added

    if '</context>' in text:
        text = text.replace('</context>', ' ')
        context = text
        context_type = 'text'
        context_added = True
        query_added = False
    elif not context_type:
        # Backslashes are doubled so the Markdown escape ``\<`` reaches the
        # chat window without relying on an invalid Python escape sequence.
        context = ("**Please add context (upload image/audio or enter text "
                   "followed by \\</context\\>)**")
        context_type = 'error'
        context_added = True
        query_added = False
    elif context_type in ('[text]', '[image]', '[audio]'):
        query = 'Human### ' + text + '\n' + 'AI### '
        query_added = True
        context_added = False
    else:
        # Mark the state as an error so that ``preprocess_fn`` and ``bot``
        # treat ``context`` as a message rather than as real context data.
        context = "**Please provide a valid context**"
        context_type = 'error'
        context_added = True
        query_added = False

    history = history + [(text, None)]

    return history, gr.Textbox(value="", interactive=False)


def add_file(history, file):
    """Register an uploaded image as the new context and show it in the chat.

    Args:
        history: Chat history as a list of ``(user, bot)`` pairs.
        file: Uploaded file object; its ``name`` attribute is displayed.

    Returns:
        A new history list ending with the ``((file.name,), None)`` entry.
    """
    global context_added, context, context_type, query_added

    # A new upload replaces any previous context and cancels a pending query.
    context, context_type = file, 'image'
    context_added, query_added = True, False

    uploaded_entry = ((file.name,), None)
    return history + [uploaded_entry]


def audio_upload(history, audio_file):
    """Register a recorded or uploaded audio file as the new context.

    Args:
        history: Chat history as a list of ``(user, bot)`` pairs.
        audio_file: Path of the audio file; a falsy value means there is
            nothing to register.

    Returns:
        ``history`` unchanged when ``audio_file`` is falsy, otherwise a new
        list ending with the ``((audio_file,), None)`` entry.
    """
    global context, context_type, context_added, query, query_added

    if not audio_file:
        return history

    # The transcript is produced later, in ``preprocess_fn``.
    context, context_type = audio_file, 'audio'
    context_added, query_added = True, False

    return history + [((audio_file,), None)]


def _transcribe_audio(audio_file):
    """Return the transcript of ``audio_file``, or ``None`` if none was produced.

    The audio is transcribed with the WhisperX model and then aligned. A
    missing language, missing segments, a failed alignment, or an empty
    aligned result all yield ``None``.
    """
    audio = whisperx.load_audio(audio_file)
    result = audi_model.transcribe(audio, batch_size=1)

    language = result.get('language')
    segments = result.get('segments')
    if not language or not segments:
        return None
    print(language)

    try:
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
        aligned = whisperx.align(segments, model_a, metadata, audio, device,
                                 return_char_alignments=False)
    except Exception as exc:
        # Alignment is best-effort, but report the reason instead of hiding it.
        print(f"Audio alignment failed: {exc!r}")
        return None

    # Keep the text of every segment so longer recordings are not cut off
    # after their first segment.
    pieces = [(segment.get('text') or '').strip() for segment in aligned.get('segments') or []]
    transcript = ' '.join(piece for piece in pieces if piece)
    return transcript or None


def preprocess_fn(history):
    """Convert a freshly added image/audio context into the form ``bot`` uses.

    Does nothing unless ``context_added`` is set. For an image context the
    file is encoded with CLIP and the second-to-last hidden state is passed
    through the vision projector; the resulting tensor replaces ``context``.
    For an audio context the file is transcribed and the transcript replaces
    ``context``; if no transcript can be produced the state is switched to
    ``'error'`` with a message for the user. Text and error contexts are left
    untouched.

    Args:
        history: Chat history; passed through unchanged.

    Returns:
        The same ``history`` object.
    """
    global context, context_added, context_type, query_added

    if not context_added:
        return history

    if context_type == 'image':
        # Close the image file as soon as the processor has consumed it.
        with Image.open(context) as image:
            inputs = clip_processor(images=image, return_tensors="pt")

        # Inference only: avoid building an autograd graph that would be kept
        # alive for as long as the global ``context`` tensor exists.
        with torch.no_grad():
            clip_output = clip_model(**inputs, output_hidden_states=True)
            image_features = clip_output.hidden_states[-2]
            context = vision_projector(image_features)

    elif context_type == 'audio':
        transcript = _transcribe_audio(context)
        if transcript:
            print(transcript)
            context = transcript
        else:
            context_type = 'error'
            context = "**Please provide a valid audio file / context**"
        context_added = True
        query_added = False

    return history

def _embed_text(text):
    """Tokenize ``text`` and return its phi-2 input embeddings (batch of one)."""
    token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.int32).unsqueeze(0).to(device)
    with torch.no_grad():
        return phi_model.get_input_embeddings()(token_ids)


def _generate_reply(inputs_embeds):
    """Generate from ``inputs_embeds`` with phi-2 and decode the output tokens."""
    out = phi_model.generate(inputs_embeds=inputs_embeds, min_new_tokens=10, max_new_tokens=50,
                             bos_token_id=tokenizer.bos_token_id)
    return tokenizer.decode(out[0], skip_special_tokens=True)


def bot(history):
    """Produce the bot's reply to the latest event and stream it into the chat.

    * If a context was just added, the reply is either the stored error
      message or an acknowledgement (preceded by the transcript for audio),
      and the context type is marked as acknowledged by wrapping it in
      brackets, e.g. ``'image'`` -> ``'[image]'``.
    * If a query was just added, the reply is generated by phi-2 from the
      projected image embeddings plus the query (image context) or from the
      context text followed by the query (text/audio context).

    Args:
        history: Chat history as a list of ``[user, bot]`` rows; the bot cell
            of the last row is filled in place.

    Yields:
        ``history`` after each appended character of the reply. Nothing is
        yielded when there is no reply or no row to write it into.
    """
    global context, context_added, query, context_type, query_added, bot_active

    response = ''
    if context_added:
        context_added = False
        if context_type == 'error':
            response = context
            query = ''
        elif context_type in ('image', 'audio', 'text'):
            if context_type == 'audio':
                # "\U0001F5E3" is the speaking-head emoji, written as an
                # escape so it survives any source-file encoding mishap.
                response = 'Context: \n\U0001F5E3 "_' + context.strip() + '_"\n\n'
            response += "**Please proceed with your queries**"
            query = ''
            context_type = '[' + context_type + ']'
    elif query_added:
        query_added = False
        if context_type == '[image]':
            inputs_embeds = torch.cat([context.to(device), _embed_text(query)], dim=1)
            response = _generate_reply(inputs_embeds)
        elif context_type in ('[text]', '[audio]'):
            response = _generate_reply(_embed_text(context + query))
        else:
            query = ''
            response = "**Please provide a valid context**"

    if not response or not history or len(history[-1]) <= 1:
        return

    bot_active = True
    try:
        history[-1][1] = ""
        for character in response:
            history[-1][1] += character
            time.sleep(0.05)  # typing effect
            yield history

        time.sleep(0.5)
    finally:
        # Always lower the flag, even if the stream is closed or fails midway.
        bot_active = False
    


def clear_fn():
    """Reset the conversation state and empty the chat window.

    Returns:
        A component-keyed update that sets the ``chatbot`` value to ``None``.
    """
    global context_added, context_type, context, query, query_added

    context, context_type, query = None, '', ''
    context_added = query_added = False

    return {chatbot: None}


with gr.Blocks() as app:
    # Backslashes are doubled so the Markdown escapes ``\<`` / ``\>`` are
    # emitted without relying on invalid Python escape sequences.
    gr.Markdown(
        """
        # ContextGPT - A Multimodal chatbot
        ### Upload image or audio to add a context. And then ask questions.
        ### You can also enter text followed by \\</context\\> to set the context.
        """
    )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Press enter to send ",
            container=False,
        )

    with gr.Row():
        aud = gr.Audio(sources=['microphone', 'upload'], type='filepath', max_length=100, show_download_button=True,
                       show_share_button=True)
        # "\U0001F4F7" is the camera emoji used as the upload button label.
        btn = gr.UploadButton("\U0001F4F7", file_types=["image"])

    with gr.Row():
        clear = gr.Button("Clear")

    def _then_respond(event):
        """Chain context preprocessing and the streamed bot reply after ``event``."""
        return event.then(
            preprocess_fn, chatbot, chatbot
        ).then(
            bot, chatbot, chatbot, api_name="bot_response"
        )

    # Text submit: record the message, reply, then unlock the textbox again
    # (``add_text`` makes it non-interactive while the reply is produced).
    txt_msg = _then_respond(txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False))
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Image upload.
    file_msg = _then_respond(btn.upload(add_file, [chatbot, btn], [chatbot], queue=False))

    chatbot.like(print_like_dislike, None, None)
    clear.click(clear_fn, None, chatbot, queue=False)

    # Audio: both a finished microphone recording and an uploaded file.
    _then_respond(aud.stop_recording(audio_upload, [chatbot, aud], [chatbot], queue=False))
    _then_respond(aud.upload(audio_upload, [chatbot, aud], [chatbot], queue=False))

# Enable the request queue, then start serving the app.
app.queue().launch()