# generic libraries
import gradio as gr
import os
import re
# for embeddings and indexing
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# for data retrieval
from langchain.chains import RetrievalQA
# for huggingface llms
from langchain_community.llms import HuggingFaceHub
# define constants
# Embedding models
#EMB_MODEL_bge_base = 'BAAI/bge-base-en-v1.5'
#EMB_MODEL_bge_large = 'BAAI/bge-large-en-v1.5'
#EMB_MODEL_gtr_t5_base = 'sentence-transformers/gtr-t5-base'
EMB_MODEL_gtr_t5_large = 'sentence-transformers/gtr-t5-large'
#EMB_MODEL_e5_base = 'intfloat/e5-large-v2'
# Chat app model
MISTRAL_MODEL1 = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
HF_MODEL1 = 'HuggingFaceH4/zephyr-7b-beta'
# define paths
#vector_path_bge_base = 'vectorDB/faiss_index_bge_base'
#vector_path_bge_large = 'vectorDB/faiss_index_bge_large'
#vector_path_gtr_t5_base = 'vectorDB/faiss_index_gtr_t5_base'
vector_path_gtr_t5_large = 'vectorDB/faiss_index_gtr_t5_large'
#vector_path_e5_base = 'vectorDB/faiss_index_e5_base'
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
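# NOTE: FAISS.load_local below assumes the index at vector_path_gtr_t5_large
# was built offline with the *same* embedding model. A minimal sketch of how
# such an index might be created (the document path and chunking parameters
# are illustrative assumptions, not part of this app):
#from langchain_community.document_loaders import TextLoader
#from langchain.text_splitter import RecursiveCharacterTextSplitter
#docs = TextLoader('data/financial_docs.txt').load()  # hypothetical source file
#chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
#emb = HuggingFaceEmbeddings(model_name=EMB_MODEL_gtr_t5_large)
#FAISS.from_documents(chunks, emb).save_local(vector_path_gtr_t5_large)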
def respond(message, history):
    # initialize the embedding model
    #embedding_model_bge = HuggingFaceEmbeddings(model_name=EMB_MODEL_bge_large)
    embedding_model_gtr_t5 = HuggingFaceEmbeddings(model_name=EMB_MODEL_gtr_t5_large)
    #embedding_model_e5 = HuggingFaceEmbeddings(model_name=EMB_MODEL_e5_base)
    # load the FAISS index from its relative path
    #vectordb_bge = FAISS.load_local(vector_path_bge_large, embedding_model_bge, allow_dangerous_deserialization=True)
    vectordb_gtr_t5 = FAISS.load_local(vector_path_gtr_t5_large, embedding_model_gtr_t5, allow_dangerous_deserialization=True)
    #vectordb_e5 = FAISS.load_local(vector_path_e5_base, embedding_model_e5, allow_dangerous_deserialization=True)
    # define the retriever object (top-5 similarity search)
    #retriever_bge = vectordb_bge.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    retriever_gtr_t5 = vectordb_gtr_t5.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    #retriever_e5 = vectordb_e5.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    # initialize the chatbot LLM
    llm = HuggingFaceHub(
        repo_id=MISTRAL_MODEL1,
        huggingfacehub_api_token=hf_token,
        model_kwargs={"temperature": 0.7, "max_new_tokens": 512}
    )
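    # NOTE: HuggingFaceHub is deprecated in recent LangChain releases. An
    # equivalent (untested here) replacement from langchain_huggingface would be:
    #from langchain_huggingface import HuggingFaceEndpoint
    #llm = HuggingFaceEndpoint(
    #    repo_id=MISTRAL_MODEL1,
    #    huggingfacehub_api_token=hf_token,
    #    temperature=0.7,
    #    max_new_tokens=512,
    #)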
    # create the RAG pipeline
    #qa_chain_bge = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_bge)
    qa_chain_gtr_t5 = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_gtr_t5)
    #qa_chain_e5 = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_e5)
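    # NOTE: RetrievalQA's default "stuff" prompt opens with the instruction text
    # that the stripping logic below removes when the model echoes it back.
    # Passing a custom prompt could avoid the echo entirely; a hedged sketch,
    # assuming the default "stuff" chain type:
    #from langchain.prompts import PromptTemplate
    #qa_prompt = PromptTemplate(
    #    input_variables=["context", "question"],
    #    template="Answer using only this context:\n{context}\n\nQuestion: {question}\nAnswer:",
    #)
    #qa_chain_gtr_t5 = RetrievalQA.from_chain_type(
    #    llm=llm, retriever=retriever_gtr_t5, chain_type_kwargs={"prompt": qa_prompt}
    #)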
    # generate results
    #response_bge = qa_chain_bge.invoke(message)['result']
    response_gtr_t5 = qa_chain_gtr_t5.invoke(message)['result']
    #response_e5 = qa_chain_e5.invoke(message)['result']
    # strip the echoed prompt instructions, if present
    instruction_prefix = (
        "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
    )
    #if response_bge.strip().startswith(instruction_prefix):
    #    response_bge = response_bge.strip()[len(instruction_prefix):].strip()
    if response_gtr_t5.strip().startswith(instruction_prefix):
        response_gtr_t5 = response_gtr_t5.strip()[len(instruction_prefix):].strip()
    #if response_e5.strip().startswith(instruction_prefix):
    #    response_e5 = response_e5.strip()[len(instruction_prefix):].strip()
    # split the output into retrieved context, Question, and Helpful Answer
    #match_bge = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", response_bge, re.DOTALL)
    match_gtr_t5 = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", response_gtr_t5, re.DOTALL)
    #match_e5 = re.search(r"^(.*?)(?:\n+)?Question:\s*(.*?)(?:\n+)?Helpful Answer:\s*(.*)", response_e5, re.DOTALL)
    #if match_bge:
    #    #original_text_bge = match_bge.group(1).strip()
    #    question_bge = match_bge.group(2).strip()
    #    answer_bge = match_bge.group(3).strip()
    if match_gtr_t5:
        original_text_gtr_t5 = match_gtr_t5.group(1).strip()
        #question_gtr_t5 = match_gtr_t5.group(2).strip()
        answer_gtr_t5 = match_gtr_t5.group(3).strip()
    else:
        # fall back to the raw output so the variables below are always defined
        original_text_gtr_t5 = ''
        answer_gtr_t5 = response_gtr_t5.strip()
    #if match_e5:
    #    #original_text_e5 = match_e5.group(1).strip()
    #    #question_e5 = match_e5.group(2).strip()
    #    answer_e5 = match_e5.group(3).strip()
    #formatted_response = f'Question:{question_bge}\nHelpful Answer Type 1:\n{answer_bge}\nHelpful Answer Type 2:\n{answer_gtr_t5}\nHelpful Answer Type 3:\n{answer_e5}'
    #formatted_response = f'\n************* BAAI/bge-large-en-v1.5 ****************\n{response_bge}\n************** sentence-transformers/gtr-t5-large ***************\n{response_gtr_t5}\n************ intfloat/e5-large-v2 **************\n{response_e5}'
    #formatted_response = f'\n************* BAAI/bge-large-en-v1.5 ****************\n{response_bge}\n************** sentence-transformers/gtr-t5-large ***************\n{response_gtr_t5}'
    formatted_response = f'\n************* sentence-transformers/gtr-t5-large ****************\n Helpful Answer:{answer_gtr_t5}\n Reasoning:\n{original_text_gtr_t5}'
    yield formatted_response
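# NOTE: respond() re-creates the embedding model, FAISS index, and LLM on every
# message; loading them once at module scope and reusing them across calls
# would noticeably reduce per-message latency.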
# read the content of the about.md file
with open("about.md", "r") as file:
    about_lines = file.read()
with gr.Blocks() as demo:
    gr.Markdown("# Intelligent Financial Document Q&A App")
    # About the App
    with gr.Tab("About the App"):
        gr.Markdown(about_lines)
    # Document Chatbot
    with gr.Tab("Market Prediction"):
        with gr.Column(variant="panel", scale=2):
            gr.ChatInterface(
                respond,
                fill_width=True,
                fill_height=True,
                type="messages",
                autofocus=False #,
                #additional_inputs=[
                #    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                #    gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="Max new tokens"),
                #    gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
                #    gr.Slider(
                #        minimum=0.1,
                #        maximum=1.0,
                #        value=0.95,
                #        step=0.05,
                #        label="Top-p (nucleus sampling)",
                #    ),
                #],
            )

if __name__ == "__main__":
    demo.launch()