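"""Gradio app for question answering over uploaded PDF documents (RAG).

Pipeline: load PDFs -> split into chunks -> embed into a Chroma vector store
-> answer questions through a LangChain ConversationalRetrievalChain backed
by a Hugging Face Hub LLM.
"""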
import os

import chromadb
import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

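# LLMs available on the Hugging Face Hub; the UI lists them by base name only.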
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1", "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

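# Load the PDF documents and split them into overlapping text chunks.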
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

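# Embed the chunks and index them in an in-memory (ephemeral) Chroma collection.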
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

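# Re-open a default Chroma store; not called by the UI below, kept as a helper.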
def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb

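# Build the ConversationalRetrievalChain: a HuggingFaceHub LLM, a buffer memory
# for the chat history, and the vector store as retriever. Mixtral alone is
# requested in 8-bit mode to reduce memory use.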
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
        )
    else:
        llm = HuggingFaceHub(
            repo_id=llm_model,
            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
        )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        return_generated_question=False,
    )
    return qa_chain

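# Step 1 callback: build the vector database from the uploaded PDFs. The
# collection name is derived from the first file name (spaces replaced,
# truncated to 50 characters to stay within Chroma's naming limits).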
def initialize_database(list_file_obj, chunk_size, chunk_overlap):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = os.path.basename(list_file_path[0]).replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

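# Step 2 callback: create the QA chain for the model picked in the UI.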
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    # gr.Radio returns the selected label (not an index), so map it back to
    # the full Hugging Face repo id.
    llm_name = list_llm[list_llm_simple.index(llm_option)]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "Complete!"

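# Flatten the (user, assistant) pairs into the plain-text history format the
# chain expects.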
def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

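# Step 3 callback: run one conversational turn and surface the two
# highest-ranked source chunks with their page numbers.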
def conversation(qa_chain, message, history):
    history = history or []
    formatted_chat_history = format_chat_history(history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    # PyPDFLoader numbers pages from 0; shift to 1-indexed for display
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1

    # Clear the input box and append the new turn to the chat history
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page

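# Optional helper for a file-upload callback; not wired into the UI below.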
def upload_file(file_obj):
    list_file_path = [file.name for file in file_obj]
    return list_file_path

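# Assemble the three-step UI: (1) upload and index documents, (2) configure
# and initialize the LLM chain, (3) chat with source citations.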
def gradio_ui():
    with gr.Blocks(theme="base") as demo:
        # Invisible states shared between callbacks
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        with gr.Tabs():

            with gr.Tab("Step 1 - Document Pre-processing"):
                with gr.Row():
                    document = gr.File(label="Upload your PDF documents", file_count="multiple", file_types=["pdf"])
                with gr.Row():
                    chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=50, label="Chunk size", interactive=True)
                    chunk_overlap = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Chunk overlap", interactive=True)
                with gr.Row():
                    db_progress = gr.Textbox(label="Vector database status", interactive=False)
                with gr.Row():
                    db_init_btn = gr.Button("Initialize Vector Database")

            with gr.Tab("Step 2 - QA Chain Initialization"):
                with gr.Row():
                    llm_selection = gr.Radio(list_llm_simple, label="Choose LLM Model", value=list_llm_simple[0])
                with gr.Row():
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", interactive=True)
                    max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="Max Tokens", interactive=True)
                    top_k = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K", interactive=True)
                with gr.Row():
                    llm_progress = gr.Textbox(label="QA chain status", interactive=False)
                with gr.Row():
                    qa_init_btn = gr.Button("Initialize QA Chain")

            with gr.Tab("Step 3 - Conversation with Chatbot"):
                chatbot = gr.Chatbot()
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2)
                    source1_page = gr.Number(label="Page")
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2)
                    source2_page = gr.Number(label="Page")
                with gr.Row():
                    msg = gr.Textbox(placeholder="Type message")
                    submit_btn = gr.Button("Submit")

        db_init_btn.click(initialize_database,
                          inputs=[document, chunk_size, chunk_overlap],
                          outputs=[vector_db, collection_name, db_progress])
        qa_init_btn.click(initialize_LLM,
                          inputs=[llm_selection, temperature, max_tokens, top_k, vector_db],
                          outputs=[qa_chain, llm_progress])
        # The chatbot component itself holds the (user, assistant) history, so
        # it serves as both the history input and the output of each turn.
        submit_btn.click(conversation,
                         inputs=[qa_chain, msg, chatbot],
                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page])

    return demo

if __name__ == "__main__":
    gradio_ui().launch()