import os
import streamlit as st
from langchain.llms import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

class UserInterface():
    """Streamlit front-end for chatting with models hosted on the Hugging Face Hub.

    Constructing the object draws the page banner/header and the sidebar
    controls (model, temperature, token length) and stores the selections on
    the instance. ``form_data`` then renders the chat history and handles the
    message submitted in the current script run.

    Attributes:
        models: Repo id of the model picked in the sidebar select box.
        temperature: Sampling temperature picked in the sidebar.
        max_token_length: Value forwarded to the model as ``max_new_tokens``.
        model_kwargs: Keyword arguments passed to ``HuggingFaceHub``.
    """

    # Repo ids offered in the sidebar select box.
    MODEL_NAMES = (
        "HuggingFaceH4/zephyr-7b-beta",
        "Sharathhebbar24/chat_gpt2_dpo",
        "Sharathhebbar24/chat_gpt2",
        "Sharathhebbar24/math_gpt2_sft",
        "Sharathhebbar24/math_gpt2",
        "Sharathhebbar24/convo_bot_gpt_v1",
        "Sharathhebbar24/Instruct_GPT",
        "Sharathhebbar24/Mistral-7B-v0.1-sharded",
        "Sharathhebbar24/llama_chat_small_7b",
        "Deci/DeciCoder-6B",
        "Deci/DeciLM-7B-instruct",
        "Deci/DeciCoder-1b",
        "Deci/DeciLM-7B-instruct-GGUF",
        "Open-Orca/Mistral-7B-OpenOrca",
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Sharathhebbar24/llama_7b_chat",
        "CultriX/MistralTrix-v1",
        "ahxt/LiteLlama-460M-1T",
        "gorilla-llm/gorilla-7b-hf-delta-v0",
        "codeparrot/codeparrot",
    )

    # Deliberately NOT an f-string: ``{context}`` and ``{question}`` are filled
    # in by ``PromptTemplate`` at run time, so braces typed by the user are
    # treated as plain text instead of being parsed as template fields.
    PROMPT_TEMPLATE = (
        "<|system|>\nYou are a intelligent chatbot and expertise in {context}.</s>\n"
        "<|user|>\n{question}.\n<|assistant|>"
    )

    # Typing this word (any letter case) wipes the conversation.
    CLEAR_COMMAND = "clear"

    # Environment variable the token is read from / the one it is exported to.
    TOKEN_SOURCE_ENV = "HF_KEY"
    TOKEN_TARGET_ENV = "HUGGINGFACEHUB_API_TOKEN"

    def __init__(self, ):
        """Draw the static page elements and the sidebar controls."""
        st.warning("Warning: Some models may not work and some models may require GPU to run")
        st.text("An Open Source Chat Application")
        st.header("Open LLMs")

        self.models = st.sidebar.selectbox(
            label="Choose your models",
            options=self.MODEL_NAMES,
            help="Choose your model",
        )

        self.temperature = st.sidebar.slider(
            label='Temperature',
            min_value=0.1,
            max_value=1.0,
            step=0.1,
            value=0.5,
            help="Set the temperature to get accurate or random result"
        )

        self.max_token_length = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=2048,
            step=16,
            value=64,
            help="Set max tokens to generate maximum amount of text output"
        )

        self.model_kwargs = {
            "temperature": self.temperature,
            "max_new_tokens": self.max_token_length
        }

        # Assigning None into os.environ raises TypeError, so only export the
        # token when one is actually configured and tell the user otherwise.
        if not self._export_api_token():
            st.warning(
                f"No Hugging Face API token found: set the {self.TOKEN_SOURCE_ENV} "
                "environment variable to be able to query models."
            )

    @classmethod
    def _export_api_token(cls, env=None):
        """Copy the token from ``HF_KEY`` to ``HUGGINGFACEHUB_API_TOKEN``.

        Args:
            env: Mutable mapping to read from and write to. Defaults to
                ``os.environ``.

        Returns:
            True when the target variable holds a non-empty token afterwards,
            False when no token is configured at all.
        """
        if env is None:
            env = os.environ
        token = env.get(cls.TOKEN_SOURCE_ENV)
        if token:
            env[cls.TOKEN_TARGET_ENV] = token
            return True
        # Fall back to a token that was already exported under the target name.
        return bool(env.get(cls.TOKEN_TARGET_ENV))

    @classmethod
    def _is_clear_command(cls, question: str) -> bool:
        """Return True when *question* is the history-clearing command."""
        return question.strip().lower() == cls.CLEAR_COMMAND

    def _generate_answer(self, question: str, context: str) -> str:
        """Send *question* and *context* to the selected model and return its text."""
        prompt = PromptTemplate(
            template=self.PROMPT_TEMPLATE,
            input_variables=[
                'question',
                'context'
            ]
        )
        llm = HuggingFaceHub(
            repo_id=self.models,
            model_kwargs=self.model_kwargs
        )
        llm_chain = LLMChain(
            prompt=prompt,
            llm=llm,
        )
        return llm_chain.run({
            "question": question,
            "context": context
        })

    def form_data(self) -> None:
        """Render the chat history and process the message submitted this run.

        Does nothing beyond rendering when no message was submitted. The
        ``clear`` command deletes the stored history without contacting the
        model. Any exception raised along the way is shown to the user with
        ``st.error`` instead of propagating.
        """
        try:
            if "messages" not in st.session_state:
                st.session_state.messages = []

            st.write(f"You are using {self.models} model")

            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            context = st.sidebar.text_input(
                label="Context",
                help="Context lets you know on what the answer should be generated"
            )

            question = st.chat_input(key="question")

            # Nothing submitted in this run: skip building the model client.
            if not question:
                return

            # Handle the command before any (slow, remote) model call is made.
            if self._is_clear_command(question):
                del st.session_state.messages
                return

            result = self._generate_answer(question, context)

            # The active prompt does not ask the model to emit this marker;
            # the collapse is kept for models/prompts that still produce it.
            if "Out of Context" in result:
                result = "Out of Context"

            user_message = f"Context: {context}\n\nQuestion: {question}"
            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": user_message
                }
            )
            with st.chat_message("user"):
                st.write(user_message)

            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": result
                }
            )
            with st.chat_message('assistant'):
                st.markdown(result)

        except Exception as e:
            # Top-level UI boundary: surface the failure in the page.
            st.error(e, icon="🚨")

# Script entry point: build the page (banner, header, and sidebar controls),
# then render the chat history and process the current chat input.
model = UserInterface()
model.form_data()