import os

import gradio as gr
import requests
import torch
from peft import PeftModel
from textwrap import fill
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

## using Falcon 7b Instruct via the hosted Inference API
FALCON_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
hf_token = os.getenv("HUGGINGFACE_TOKEN")
HEADERS = {"Authorization": f"Bearer {hf_token}"}  # f-string so the token is actually interpolated


def falcon_query(payload):
    response = requests.post(FALCON_API_URL, headers=HEADERS, json=payload)
    return response.json()


def falcon_inference(input_text):
    payload = {"inputs": input_text}
    return falcon_query(payload)


## using Mistral via the hosted Inference API
MISTRAL_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"


def mistral_query(payload):
    response = requests.post(MISTRAL_API_URL, headers=HEADERS, json=payload)
    return response.json()


def mistral_inference(input_text):
    payload = {"inputs": input_text}
    return mistral_query(payload)


# Function to wrap long text for display, preserving existing line breaks
def wrap_text(text, width=90):
    lines = text.split('\n')
    wrapped_lines = [fill(line, width=width) for line in lines]
    return '\n'.join(wrapped_lines)


class ChatbotInterface:
    """Base class: one chatbot pane with a textbox, a submit button, and a clear button."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical-related information."):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        self.chat_history = []
        with gr.Row() as row:
            row.justify = "end"
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
            clear = gr.ClearButton([self.msg, self.chatbot])
        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def respond(self, msg, chatbot):
        raise NotImplementedError


class GaiaMinimed(ChatbotInterface):
    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical-related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        # tokenizer, peft_model, and device are module-level globals created in the
        # __main__ block below.
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        response = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True,
        )
        # Decode only the newly generated tokens so the prompt is not echoed back.
        response_text = tokenizer.decode(response[0][input_ids.shape[-1]:], skip_special_tokens=True)
        self.chat_history.append([msg, response_text])
        return "", self.chat_history


class FalconBot(ChatbotInterface):
    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical-related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        falcon_response = falcon_inference(msg)
        falcon_output = falcon_response[0]["generated_text"]
        self.chat_history.append([msg, falcon_output])
        # Return the full history: gr.Chatbot expects a list of [user, bot] pairs.
        return "", self.chat_history


class MistralBot(ChatbotInterface):
    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical-related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        mistral_response = mistral_inference(msg)
        mistral_output = mistral_response[0]["generated_text"]
        self.chat_history.append([msg, mistral_output])
        # Return the full history: gr.Chatbot expects a list of [user, bot] pairs.
        return "", self.chat_history
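# --- Optional: resilient Inference API calls (illustrative sketch) ---
# While a hosted model is cold, the Hugging Face Inference API typically answers
# with HTTP 503 and an "estimated_time" hint instead of generations. The helper
# below is a minimal retry wrapper added for illustration; `query_with_retry` is
# hypothetical and is not wired into the bots above.
import time


def query_with_retry(api_url, payload, max_retries=3):
    """Retry an Inference API call while the model warms up (assumes the
    503 + "estimated_time" convention described above)."""
    response = None
    for _ in range(max_retries):
        response = requests.post(api_url, headers=HEADERS, json=payload)
        if response.status_code != 503:
            return response.json()
        # Sleep roughly as long as the API suggests before trying again.
        time.sleep(response.json().get("estimated_time", 10))
    return response.json()

# Example use (hypothetical), matching the response shape the bots index into:
#   result = query_with_retry(FALCON_API_URL, {"inputs": "What causes anemia?"})
#   text = result[0]["generated_text"]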
if __name__ == "__main__":
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the base model's ID and the fine-tuned adapter repository
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    # Instantiate the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")

    # Specify the configuration class for the model
    model_config = AutoConfig.from_pretrained(base_model_id)

    # Load the base weights, attach the PEFT adapter, and move to the device
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)
    peft_model = peft_model.to(device)

    with gr.Blocks() as demo:
        with gr.Row() as intro:
            gr.Markdown(
                """
                # MedChat: Your Medical Assistant Chatbot

                Welcome to MedChat, your friendly medical assistant chatbot! 🩺 Interact with
                three specialized chatbots, each backed by a model tuned or prompted for
                medical question answering. Whether you have health-related questions, need
                medical advice, or just want to learn more about your well-being, MedChat is
                here to help!

                ## How It Works

                Simply type your medical query or concern, and MedChat will respond.

                ## Explore and Compare

                Feel like experimenting? Anything you type is mirrored into the other two
                chatbots' textboxes, so you can submit the same question in every tab and
                compare the answers from the different models side by side.

                _Ready to get started? Type your question and let's begin!_
                """
            )
        with gr.Row() as row:
            with gr.Column() as col1:
                with gr.Tab("GaiaMinimed") as gaia:
                    gaia_bot = GaiaMinimed("GaiaMinimed")
            with gr.Column() as col2:
                with gr.Tab("MistralMed") as mistral:
                    mistral_bot = MistralBot("MistralMed")
                with gr.Tab("Falcon-7B") as falcon7b:
                    falcon_bot = FalconBot("Falcon-7B")

        # Mirror text typed in any one textbox into the other two.
        gaia_bot.msg.input(fn=lambda s: (s, s), inputs=gaia_bot.msg,
                           outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=lambda s: (s, s), inputs=mistral_bot.msg,
                              outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=lambda s: (s, s), inputs=falcon_bot.msg,
                             outputs=[gaia_bot.msg, mistral_bot.msg])

    demo.launch()
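# --- Optional: request queueing (sketch) ---
# Under load, Gradio's built-in queue can order and rate-limit concurrent
# events. A variant of the launch call above (not part of the original script):
#
#     demo.queue()
#     demo.launch()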