# from langchain.chains import ConversationalRetrievalChain
# from langchain.chains.question_answering import load_qa_chain
# from langchain.chains import RetrievalQA
# from langchain.memory import ConversationBufferMemory
# from langchain.memory import ConversationTokenBufferMemory
# from langchain.llms import HuggingFacePipeline
# # from langchain import PromptTemplate
# from langchain.prompts import PromptTemplate
# from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# from langchain.vectorstores import Chroma
# from chromadb.utils import embedding_functions
# from langchain.embeddings import SentenceTransformerEmbeddings
# from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    GitLoader,
    NotebookLoader,
    OnlinePDFLoader,
    PythonLoader,
    TextLoader,
    UnstructuredFileLoader,
    UnstructuredHTMLLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
    PyPDFLoader,
    UnstructuredMarkdownLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredPowerPointLoader,
    UnstructuredODTLoader,
    NotebookLoader,
    UnstructuredFileLoader
)
# from transformers import (
#     AutoModelForCausalLM,
#     AutoTokenizer,
#     StoppingCriteria,
#     StoppingCriteriaList,
#     pipeline,
#     GenerationConfig,
#     TextStreamer,
#     pipeline
# )
# from langchain.llms import HuggingFaceHub
import torch
# from transformers import BitsAndBytesConfig
import os
# from langchain.llms import CTransformers
import streamlit as st
# from langchain.document_loaders.base import BaseLoader
# from langchain.schema import Document
# import gradio as gr
import tempfile
import timeit
import textwrap
# from chromadb.utils import embedding_functions
# from tqdm import tqdm
# tqdm(disable=True, total=0)  # initialise internal lock

# tqdm.write("test")

from langchain import PromptTemplate, LLMChain
from langchain.llms import CTransformers
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceBgeEmbeddings
from io import BytesIO
from langchain.document_loaders import PyPDFLoader




# def load_model():
#     config = {'max_new_tokens': 1024,
#               'repetition_penalty': 1.1,
#               'temperature': 0.1,
#               'top_k': 50,
#               'top_p': 0.9,
#               'stream': True,
#               'threads': int(os.cpu_count() / 2)
#             }
    
#     llm = CTransformers(
#         model = "TheBloke/zephyr-7B-beta-GGUF",
#         model_file = "zephyr-7b-beta.Q4_0.gguf",
#         callbacks=[StreamingStdOutCallbackHandler()],
#         lib="avx2", #for CPU use
#         **config
#         # model_type=model_type,
#         # max_new_tokens=max_new_tokens,  # type: ignore
#         # temperature=temperature,  # type: ignore
#     )
#     return llm

# def create_vector_database(loaded_documents):
#     # DB_DIR: str = os.path.join(ABS_PATH, "db")
#     """
#     Creates a vector database using document loaders and embeddings.
#     This function loads data from PDF, markdown and text files in the 'data/' directory,
#     splits the loaded documents into chunks, transforms them into embeddings using HuggingFace,
#     and finally persists the embeddings into a Chroma vector database.
#     """
#     # Split loaded documents into chunks
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function = len)
#     chunked_documents = text_splitter.split_documents(loaded_documents)

#     # embeddings = HuggingFaceEmbeddings(
#     #     model_name="sentence-transformers/all-MiniLM-L6-v2"
#     #     # model_name = "sentence-transformers/all-mpnet-base-v2"
#     # )
#     embeddings = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

#     # embeddings = HuggingFaceBgeEmbeddings(
#     #     model_name = "BAAI/bge-large-en"
#     # )
    
#     # model_name = "BAAI/bge-large-en"
#     # model_kwargs = {'device': 'cpu'}
#     # encode_kwargs = {'normalize_embeddings': False}
#     # embeddings = HuggingFaceBgeEmbeddings(
#     # model_name=model_name,
#     # model_kwargs=model_kwargs,
#     # encode_kwargs=encode_kwargs
#     # )
    
#     persist_directory = 'db'
#     # Create and persist a Chroma vector database from the chunked documents
#     db = Chroma.from_documents(
#         documents=chunked_documents,
#         embedding=embeddings,
#         persist_directory=persist_directory
#         # persist_directory=DB_DIR,
#     )
#     db.persist()
#     # db = Chroma(persist_directory=persist_directory, 
#     #               embedding_function=embedding)
#     return db
    

# def set_custom_prompt():
#     """
#     Prompt template for retrieval for each vectorstore
#     """
#     prompt_template = """Use the following pieces of information to answer the user's question.
#     If you don't know the answer, just say that you don't know, don't try to make up an answer.
#     Context: {context}
#     Question: {question}
#     Only return the helpful answer below and nothing else.
#     Helpful answer:
#     """

#     prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
#     return prompt
    
# def create_chain(llm, prompt, db):
#     """
#     Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.
#     This function initializes a ConversationalRetrievalChain object with a specific chain type and configurations,
#     and returns this  chain. The retriever is set up to return the top 3 results (k=3).
#     Args:
#         llm (any): The language model to be used in the RetrievalQA.
#         prompt (str): The prompt to be used in the chain type.
#         db (any): The database to be used as the 
#         retriever.
#     Returns:
#         ConversationalRetrievalChain: The initialized conversational chain.
#     """
#     memory = ConversationTokenBufferMemory(llm=llm, memory_key="chat_history", return_messages=True, input_key='question', output_key='answer')
#     # chain = ConversationalRetrievalChain.from_llm(
#     #     llm=llm,
#     #     chain_type="stuff",
#     #     retriever=db.as_retriever(search_kwargs={"k": 3}),
#     #     return_source_documents=True,
#     #     max_tokens_limit=256,
#     #     combine_docs_chain_kwargs={"prompt": prompt},
#     #     condense_question_prompt=CONDENSE_QUESTION_PROMPT,
#     #     memory=memory,
#     # )
#     # chain = RetrievalQA.from_chain_type(llm=llm,
#     #                                    chain_type='stuff',
#     #                                    retriever=db.as_retriever(search_kwargs={'k': 3}),
#     #                                    return_source_documents=True,
#     #                                    chain_type_kwargs={'prompt': prompt}
#     #                                    )
#     chain = RetrievalQA.from_chain_type(llm=llm,
#                                        chain_type='stuff',
#                                        retriever=db.as_retriever(search_kwargs={'k': 3}),
#                                        return_source_documents=True
#                                        )
#     return chain

# def create_retrieval_qa_bot(loaded_documents):
#     # if not os.path.exists(persist_dir):
#     #       raise FileNotFoundError(f"No directory found at {persist_dir}")

#     try:
#         llm = load_model()  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to load model: {str(e)}")

#     try:
#         prompt = set_custom_prompt()  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to get prompt: {str(e)}")

#     # try:
#     #     CONDENSE_QUESTION_PROMPT = set_custom_prompt_condense()  # Assuming this function exists and works as expected
#     # except Exception as e:
#     #     raise Exception(f"Failed to get condense prompt: {str(e)}")

#     try:
#         db = create_vector_database(loaded_documents)  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to get database: {str(e)}")

#     try:
#         # qa = create_chain(
#         #     llm=llm, prompt=prompt,CONDENSE_QUESTION_PROMPT=CONDENSE_QUESTION_PROMPT, db=db
#         # )  # Assuming this function exists and works as expected
#         qa = create_chain(
#             llm=llm, prompt=prompt, db=db
#         )  # Assuming this function exists and works as expected
#     except Exception as e:
#         raise Exception(f"Failed to create retrieval QA chain: {str(e)}")

#     return qa

# def wrap_text_preserve_newlines(text, width=110):
#     # Split the input text into lines based on newline characters
#     lines = text.split('\n')

#     # Wrap each line individually
#     wrapped_lines = [textwrap.fill(line, width=width) for line in lines]

#     # Join the wrapped lines back together using newline characters
#     wrapped_text = '\n'.join(wrapped_lines)

#     return wrapped_text

# def retrieve_bot_answer(query, loaded_documents):
#     """
#     Retrieves the answer to a given query using a QA bot.
#     This function creates an instance of a QA bot, passes the query to it,
#     and returns the bot's response.
#     Args:
#         query (str): The question to be answered by the QA bot.
#     Returns:
#         dict: The QA bot's response, typically a dictionary with response details.
#     """
#     qa_bot_instance = create_retrieval_qa_bot(loaded_documents)
#     # bot_response = qa_bot_instance({"question": query})
#     bot_response = qa_bot_instance({"query": query})
#     # Check if the 'answer' key exists in the bot_response dictionary
#     # if 'answer' in bot_response:
#     #     # answer = bot_response['answer']
#     #     return bot_response
#     # else:
#     #     raise KeyError("Expected 'answer' key in bot_response, but it was not found.")
#     # result = bot_response['answer']
    
#     # result = bot_response['result']
#     # sources = []
#     # for source in bot_response["source_documents"]:
#     #     sources.append(source.metadata['source'])
#     # return result, sources

#     result = wrap_text_preserve_newlines(bot_response['result'])
#     for source in bot_response["source_documents"]:
#         sources.append(source.metadata['source'])
#     return result, sources

def _wrap_text_preserve_newlines(text, width=110):
    """Wrap each line of *text* to *width* columns, keeping existing newlines.

    Args:
        text: The text to wrap.
        width: Maximum line width passed to ``textwrap.fill``.

    Returns:
        The wrapped text with the original line breaks preserved.
    """
    # Wrap line by line so the answer's own line/paragraph breaks survive.
    return "\n".join(textwrap.fill(line, width=width) for line in text.split("\n"))


def _load_uploaded_documents(uploaded_files, loader_mapping):
    """Write uploaded files to a temporary directory and load them with LangChain loaders.

    Args:
        uploaded_files: The files returned by ``st.file_uploader``.
        loader_mapping: Maps a lowercase extension (without the dot) to a
            ``(loader_class, loader_kwargs)`` pair.

    Returns:
        A list of loaded LangChain documents. Files with an unsupported
        extension, or whose loader raises, are reported in the UI and skipped.
    """
    loaded_documents = []
    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"Uploaded: {ext}")

            if ext not in loader_mapping:
                st.warning(f"Unsupported file extension: {ext}")
                continue
            loader_class, loader_args = loader_mapping[ext]

            # Keep only the base name so a crafted upload name cannot write outside td.
            file_name = os.path.basename(uploaded_file.name)
            file_path = os.path.join(td, file_name)
            with open(file_path, 'wb') as temp_file:
                # getvalue() returns the whole upload regardless of the current read position.
                temp_file.write(uploaded_file.getvalue())

            try:
                documents = loader_class(file_path, **loader_args).load()
            except Exception as e:
                # One unreadable file should not prevent chatting with the others.
                st.error(f"Failed to load {file_name}: {e}")
                continue

            # td is deleted when this `with` exits, so report the user's file name
            # as the source rather than a dangling temporary path.
            for document in documents:
                document.metadata["source"] = uploaded_file.name
            loaded_documents.extend(documents)
    return loaded_documents


def _build_qa_chain(loaded_documents):
    """Build a "stuff" RetrievalQA chain over *loaded_documents*.

    Loads the Q4_0 GGUF build of zephyr-7B-beta through CTransformers and
    embeds 500-character chunks with BAAI/bge-large-en on CPU. The chunks are
    indexed in Chroma, and the single most similar chunk (k=1) is retrieved
    per query.

    Args:
        loaded_documents: A non-empty list of LangChain documents.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    config = {
        'max_new_tokens': 1024,
        'repetition_penalty': 1.1,
        'temperature': 0.1,
        'top_k': 50,
        'top_p': 0.9,
        'stream': True,
        # Half the cores. os.cpu_count() may return None, and a 1-core machine
        # must still get at least one thread.
        'threads': max(1, (os.cpu_count() or 2) // 2),
    }

    llm = CTransformers(
        model="TheBloke/zephyr-7B-beta-GGUF",
        model_file="zephyr-7b-beta.Q4_0.gguf",
        model_type="mistral",
        lib="avx2",  # for CPU use
        **config
    )
    st.write("LLM Initialized:")

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function=len)
    chunked_documents = text_splitter.split_documents(loaded_documents)

    # NOTE(review): every call writes into the same persisted store under 'db';
    # presumably chunks from earlier uploads/clicks accumulate there and can be
    # retrieved for later questions — confirm whether a per-session collection
    # or a non-persisted store is intended.
    db = Chroma.from_documents(documents=chunked_documents, embedding=embeddings, persist_directory='db')
    db.persist()

    retriever = db.as_retriever(search_kwargs={"k": 1})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )


def main():
    """Streamlit entry point: upload documents, then ask questions about them.

    Uploaded files are loaded with the matching LangChain loader. When the user
    clicks "Get Answer", a retrieval QA chain is built over the documents. The
    answer, elapsed time and source file names are then displayed.
    """
    # Extension (lowercase, no dot) -> (LangChain loader class, loader kwargs).
    FILE_LOADER_MAPPING = {
        "csv": (CSVLoader, {"encoding": "utf-8"}),
        "doc": (UnstructuredWordDocumentLoader, {}),
        "docx": (UnstructuredWordDocumentLoader, {}),
        "epub": (UnstructuredEPubLoader, {}),
        "html": (UnstructuredHTMLLoader, {}),
        "md": (UnstructuredMarkdownLoader, {}),
        "odt": (UnstructuredODTLoader, {}),
        "pdf": (PyPDFLoader, {}),
        "ppt": (UnstructuredPowerPointLoader, {}),
        "pptx": (UnstructuredPowerPointLoader, {}),
        "txt": (TextLoader, {"encoding": "utf8"}),
        "ipynb": (NotebookLoader, {}),
        "py": (PythonLoader, {}),
        # Add more mappings for other file extensions and loaders as needed
    }

    st.title("Docuverse")

    # Derive the accepted types from the mapping so the two cannot drift apart.
    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=list(FILE_LOADER_MAPPING),
        accept_multiple_files=True,
    )
    if not uploaded_files:
        return

    loaded_documents = _load_uploaded_documents(uploaded_files, FILE_LOADER_MAPPING)

    st.write("Chat with the Document:")
    query = st.text_input("Ask a question:")

    if not st.button("Get Answer"):
        return
    if not query:
        st.warning("Please enter a question.")
        return
    if not loaded_documents:
        # Building a vector store from zero documents would fail; say why instead.
        st.warning("No documents could be loaded from the uploaded files.")
        return

    try:
        start = timeit.default_timer()
        qa = _build_qa_chain(loaded_documents)
        bot_response = qa(query)
        wrapped_text = _wrap_text_preserve_newlines(bot_response['result'])
        # Collect every retrieved source (the old loop kept only the last one and
        # left `sources` unbound when nothing was retrieved).
        sources = [doc.metadata.get('source') for doc in bot_response["source_documents"]]
        end = timeit.default_timer()

        st.write("Elapsed time:")
        st.write(end - start)
        st.write("Bot Response:")
        st.write(wrapped_text)
        st.write(sources)
    except Exception as e:
        # Top-level UI boundary: surface any model/vector-store failure to the user.
        st.error(f"An error occurred: {str(e)}")

# Script entry point. The app is presumably launched via `streamlit run`,
# which executes this file as __main__ — confirm against the deployment.
if __name__ == "__main__":
    main()