import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from textwrap import fill

# Helper for wrapping model output so long lines stay readable
def wrap_text(text, width=90):
    """Wrap each newline-delimited line of `text` at `width` columns."""
    lines = text.split('\n')
    wrapped_lines = [fill(line, width=width) for line in lines]
    return '\n'.join(wrapped_lines)
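
# Example (illustrative):
#   wrap_text("word " * 60)   # one long line re-flowed at 90 columns
#   wrap_text("a\nb")         # short lines pass through unchanged -> "a\nb"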

def multimodal_prompt(user_input, system_prompt):
    """
    Generates text with the PEFT model, given a user input and a system prompt.
    (Despite the name, this helper is text-only.) It relies on the module-level
    `tokenizer`, `peft_model`, and `device` defined in the __main__ block below.
    Args:
        user_input: The user's input text to generate a response for.
        system_prompt: Optional system prompt.
    Returns:
        A string containing the generated text in the Falcon-like format.
    """
    # Combine user input and system prompt; the doubled braces render as a
    # literal {{ ... }} wrapper around the system prompt in the template
    formatted_input = f"{{{{ {system_prompt} }}}}\nUser: {user_input}\nFalcon:"

    # Encode the input text and move it to the model's device
    encoded = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    model_inputs = encoded.to(device)

    # Generate a response (sampling at a low temperature keeps answers focused;
    # max_length counts prompt plus generated tokens). The beam-search-only
    # early_stopping flag is omitted: it defaults to False and only produces
    # a warning under sampling.
    output = peft_model.generate(
        **model_inputs,
        max_length=500,
        use_cache=True,
        bos_token_id=peft_model.config.bos_token_id,
        eos_token_id=peft_model.config.eos_token_id,
        pad_token_id=peft_model.config.eos_token_id,
        temperature=0.4,
        do_sample=True
    )

    # Decode the response
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return response_text
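
# Example usage (illustrative; assumes `tokenizer`, `peft_model`, and `device`
# have been set up as in the __main__ block below; the query is hypothetical):
#   reply = multimodal_prompt(
#       "What are common triggers for migraines?",
#       "You are an expert medical analyst.",
#   )
#   print(wrap_text(reply))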

class ChatbotInterface:
    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        self.chat_history = []

        with gr.Row():
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)

        self.clear = gr.ClearButton([self.msg, self.chatbot])
        # ClearButton only resets the displayed components, so also drop the
        # stored history to keep the two in sync
        self.clear.click(lambda: self.chat_history.clear())

        # Submit via the button or by pressing Enter in the textbox
        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])
        self.msg.submit(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])
    
    def respond(self, msg, history):
        # Build the prompt in the same Falcon-style template used above
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False
        ).to(device)
        response = peft_model.generate(
            input_ids=input_ids,
            max_length=900,
            use_cache=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True
        )
        # Decode only the newly generated tokens so the prompt (including the
        # system prompt) is not echoed back into the chat window
        response_text = tokenizer.decode(
            response[0][input_ids.shape[-1]:], skip_special_tokens=True
        )

        # gr.Chatbot renders a list of [user_message, bot_response] pairs
        self.chat_history.append([msg, response_text])

        return "", self.chat_history

if __name__ == "__main__":
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the base model's ID
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    # Instantiate the Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
    
    # Load the base model with its configuration, then attach the PEFT adapter
    # (base weights come from the base repo; model_directory holds the adapter)
    model_config = AutoConfig.from_pretrained(base_model_id)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, config=model_config, trust_remote_code=True)
    peft_model = PeftModel.from_pretrained(base_model, model_directory)
    peft_model.to(device)
    peft_model.eval()
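
    # Optional variant (a sketch, assuming the `accelerate` package is
    # installed): memory-constrained or multi-GPU machines could load the base
    # model with automatic device placement instead of a manual .to(device):
    #   base_model = AutoModelForCausalLM.from_pretrained(
    #       base_model_id, config=model_config,
    #       trust_remote_code=True, device_map="auto"
    #   )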
    
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(
                """
                # MedChat: Your Medical Assistant Chatbot
            
                Welcome to MedChat, your friendly medical assistant chatbot! 🩺
            
                Dive into a world of medical expertise where you can interact with three specialized chatbots trained on medical data. Whether you have health-related questions, need medical advice, or just want to learn more about your well-being, MedChat is here to help!
            
                ## How it Works
                Simply type your medical query or concern, and let MedChat's advanced algorithms provide you with accurate and reliable responses. 
            
                ## Explore and Compare
                Feel like experimenting? Type your question once and it is mirrored into every chatbot's textbox, so you can submit the same query to all three and compare their answers. It's a unique opportunity to weigh the insights from different models and choose the one that suits your needs best.
            
                _Ready to get started? Type your question and let's begin!_
                """
            )
        # NOTE: all three tabs currently answer with the single PEFT model
        # loaded above; the tab names only change the persona in the prompt
        with gr.Row():
            with gr.Column():
                with gr.Tab("GaiaMiniMed"):
                    gaia_bot = ChatbotInterface("GaiaMiniMed")
            with gr.Column():
                with gr.Tab("MistralMed"):
                    mistral_bot = ChatbotInterface("MistralMed")
                with gr.Tab("Falcon-7B"):
                    falcon_bot = ChatbotInterface("Falcon-7B")
        
        # Mirror text typed in one tab into the other two so the same question
        # can be submitted everywhere. Using .input (which fires only on direct
        # user edits) avoids the feedback loop that .change would trigger when
        # the mirrored textboxes update each other.
        gaia_bot.msg.input(fn=lambda s: (s, s), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=lambda s: (s, s), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=lambda s: (s, s), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])
                
    demo.launch()