import os
import re

import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
import chromadb
from transformers import AutoTokenizer
import transformers
import torch

# Constants and configuration

# Hugging Face Hub repo ids offered in the UI. Order matters: a UI selection
# may refer to a model by its index in this list.
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Display names: the part of each repo id after the "owner/" prefix.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Models for which "load_in_8bit" is added to the generation kwargs.
_LLM_8BIT_MODELS = {"mistralai/Mixtral-8x7B-Instruct-v0.1"}


# Document loading, vector store, and QA-chain helpers


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: Target chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The list of split documents produced by RecursiveCharacterTextSplitter.
    """
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(pages)


def create_db(splits, collection_name):
    """Embed ``splits`` into a new Chroma collection held by an in-memory client."""
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb


def load_db():
    """Open a Chroma vector store with default settings and HF embeddings."""
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a conversational retrieval QA chain over ``vector_db``.

    Args:
        llm_model: Hugging Face Hub repo id of the LLM.
        temperature: Sampling temperature for generation.
        max_tokens: Value for the ``max_new_tokens`` generation kwarg.
        top_k: Value for the ``top_k`` sampling kwarg of the LLM.
        vector_db: Vector store whose retriever supplies context documents.

    Returns:
        A ConversationalRetrievalChain that returns source documents and keeps
        its own ConversationBufferMemory.
    """
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    if llm_model in _LLM_8BIT_MODELS:
        model_kwargs["load_in_8bit"] = True
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)

    # The chain returns several outputs (answer + source documents), so the
    # memory must be told which one to store.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        return_generated_question=False,
    )
    return qa_chain


def _file_paths(file_objs):
    """Return file paths from a Gradio file value (None, one file, or a list).

    Objects with a ``.name`` attribute (Gradio temp-file wrappers) contribute
    that attribute; plain path strings are used as-is.
    """
    if file_objs is None:
        return []
    if not isinstance(file_objs, (list, tuple)):
        file_objs = [file_objs]
    return [getattr(f, "name", f) for f in file_objs if f is not None]


def _make_collection_name(file_path, max_len=50):
    """Derive a Chroma-safe collection name from a file path.

    Chroma requires 3-63 characters that start and end with an alphanumeric
    character, with only alphanumerics, underscores and hyphens in between.
    """
    name = re.sub(r"[^A-Za-z0-9_-]+", "-", os.path.basename(file_path))[:max_len]
    name = name.strip("_-") or "collection"
    if len(name) < 3:
        name += "-db"
    return name


def initialize_database(list_file_obj, chunk_size, chunk_overlap):
    """Load uploaded PDFs, chunk them, and index them in a new vector store.

    Returns:
        ``(vector_db, collection_name, status_message)``.

    Raises:
        gr.Error: if no file was uploaded.
    """
    list_file_path = _file_paths(list_file_obj)
    if not list_file_path:
        raise gr.Error("Please upload at least one PDF document first.")
    collection_name = _make_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"


def _resolve_llm_name(llm_option):
    """Map a UI model selection (index, short name, or repo id) to a repo id."""
    if isinstance(llm_option, int):
        return list_llm[llm_option]
    if llm_option in list_llm:
        return llm_option
    if llm_option in list_llm_simple:
        return list_llm[list_llm_simple.index(llm_option)]
    raise gr.Error(f"Unknown LLM selection: {llm_option!r}")


def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Create the QA chain for the selected model.

    ``llm_option`` may be an index into ``list_llm``, a display name from
    ``list_llm_simple`` (what a value-type Radio sends), or a full repo id.

    Returns:
        ``(qa_chain, status_message)``.

    Raises:
        gr.Error: if the vector database has not been built, or the model
            selection is not recognised.
    """
    if vector_db is None:
        raise gr.Error("Please initialize the vector database first.")
    llm_name = _resolve_llm_name(llm_option)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "Complete!"
def format_chat_history(message, chat_history): formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history def conversation(qa_chain, message, history): formatted_chat_history = format_chat_history(message, history) response = qa_chain({"question": message, "chat_history": formatted_chat_history}) response_answer = response["answer"] response_sources = response["source_documents"] response_source1 = response_sources[0].page_content.strip() response_source2 = response_sources[1].page_content.strip() response_source1_page = response_sources[0].metadata["page"] + 1 response_source2_page = response_sources[1].metadata["page"] + 1 new_history = history + [(message, response_answer)] return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page def upload_file(file_obj): list_file_path = [file.name for file in file_obj] return list_file_path def gradio_ui(): with gr.Blocks(theme="base") as demo: # States vector_db, qa_chain, collection_name = gr.State(), gr.State(), gr.State() db_progress, llm_progress = gr.Textbox(), gr.Textbox() chatbot, doc_source1, source1_page, doc_source2, source2_page = gr.Chatbot(), gr.Textbox(), gr.Number(), gr.Textbox(), gr.Number() msg = gr.Textbox(placeholder="Type message") with gr.Tabs(): # Tab 1: Document Pre-processing with gr.Tab("Step 1 - Document Pre-processing"): with gr.Row(): document = gr.File(label="Upload your PDF document", file_types=["pdf"]) with gr.Row(): chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=50, label="Chunk size", interactive=True) chunk_overlap = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Chunk overlap", interactive=True) with gr.Row(): db_init_btn = gr.Button("Initialize Vector Database") # Tab 2: QA Chain Initialization with gr.Tab("Step 2 - QA Chain 
Initialization"): with gr.Row(): llm_selection = gr.Radio(list_llm_simple, label="Choose LLM Model", value=list_llm_simple[0]) with gr.Row(): temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", interactive=True) max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="Max Tokens", interactive=True) top_k = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K", interactive=True) with gr.Row(): qa_init_btn = gr.Button("Initialize QA Chain") # Tab 3: Conversation with Chatbot with gr.Tab("Step 3 - Conversation with Chatbot"): chat_history = gr.State() with gr.Row(): chatbot with gr.Row(): msg submit_btn = gr.Button("Submit") # Handlers db_init_btn.click(initialize_database, inputs=[document, chunk_size, chunk_overlap], outputs=[vector_db, collection_name, db_progress]) qa_init_btn.click(initialize_LLM, inputs=[llm_selection, temperature, max_tokens, top_k, vector_db], outputs=[qa_chain, llm_progress]) submit_btn.click(conversation, inputs=[qa_chain, msg, chat_history], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page]) return demo if __name__ == "__main__": gradio_ui().launch()