dryouviavant committed on
Commit
410bc42
·
verified ·
1 Parent(s): a3de672

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -32
app.py CHANGED
@@ -1,25 +1,24 @@
1
  import gradio as gr
2
  import os
3
- from langchain_community.vectorstores import FAISS
4
- from langchain_community.document_loaders import PyPDFLoader
5
- from langchain_community.embeddings import HuggingFaceEmbeddings
6
- from langchain_community.llms import HuggingFaceEndpoint
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.memory import ConversationBufferMemory
10
  from dotenv import load_dotenv
11
  import torch
12
 
13
- # Load environment variables
14
  load_dotenv()
 
15
  api_token = os.getenv("HF_TOKEN")
16
 
17
- # List of available LLMs
18
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
19
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
20
 
21
  # Load and split PDF document
22
- def load_doc(list_file_path, chunk_size=1024, chunk_overlap=64):
23
  loaders = [PyPDFLoader(x) for x in list_file_path]
24
  pages = []
25
  for loader in loaders:
@@ -126,7 +125,7 @@ def conversation(qa_chain, message, history, persona_text):
126
  def upload_file(file_obj):
127
  list_file_path = []
128
  for idx, file in enumerate(file_obj):
129
- file_path = file.name
130
  list_file_path.append(file_path)
131
  return list_file_path
132
 
@@ -136,52 +135,54 @@ def demo():
136
  with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
137
  vector_db = gr.State()
138
  qa_chain = gr.State()
139
- gr.HTML("<center><h1>RAG PDF Chatbot</h1><center>")
140
- gr.Markdown("""<b>Interact with Your PDF Documents!</b> This AI agent performs retrieval-augmented generation (RAG) on PDF documents. Hosted on Hugging Face Hub for demonstration purposes. \
141
- <b>Do not upload confidential documents.</b>""")
142
 
143
  # Interface for static pre-selected documents
144
  gr.Markdown("<b>Pre-Selected Documents</b>")
145
- gr.Textbox(value="Document 1: Introduction to AI.pdf", show_label=False, interactive=False)
146
- gr.Textbox(value="Document 2: Advanced Machine Learning.pdf", show_label=False, interactive=False)
147
 
148
- gr.Markdown("<b>Upload Your PDF Documents</b>")
149
- document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
150
- db_btn = gr.Button("Create vector database")
151
- db_progress = gr.Textbox(value="Not initialized", show_label=False)
152
-
153
- gr.Markdown("<b>Select Large Language Model (LLM) and Configure Parameters</b>")
154
  llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
155
  slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
156
  slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
157
- slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K", info="Number of tokens to select the next token from", interactive=True)
158
  qachain_btn = gr.Button("Initialize Question Answering Chatbot")
159
  llm_progress = gr.Textbox(value="Not initialized", show_label=False)
160
 
161
- gr.Markdown("<b>Chat with Your Document</b>")
162
  chatbot = gr.Chatbot(height=505)
163
- doc_source1 = gr.Textbox(label="Reference 1", lines=2, interactive=False)
164
- source1_page = gr.Number(label="Page", interactive=False)
165
- doc_source2 = gr.Textbox(label="Reference 2", lines=2, interactive=False)
166
- source2_page = gr.Number(label="Page", interactive=False)
167
- doc_source3 = gr.Textbox(label="Reference 3", lines=2, interactive=False)
168
- source3_page = gr.Number(label="Page", interactive=False)
 
 
 
 
169
  msg = gr.Textbox(placeholder="Ask a question", container=True)
170
  submit_btn = gr.Button("Submit")
171
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
172
-
173
- # Bind the events
 
 
174
  db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
 
175
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
176
  inputs=None,
177
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
178
  queue=False)
179
 
 
180
  msg.submit(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
181
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
182
- clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
183
-
184
  demo.queue().launch(debug=True)
185
 
186
  if __name__ == "__main__":
187
- demo()
 
1
  import gradio as gr
2
  import os
3
+ from langchain.vectorstores import FAISS
4
+ from langchain.document_loaders import PyPDFLoader
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.llms import HuggingFaceEndpoint
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.memory import ConversationBufferMemory
10
  from dotenv import load_dotenv
11
  import torch
12
 
 
13
  load_dotenv()
14
+
15
  api_token = os.getenv("HF_TOKEN")
16
 
 
17
  list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
18
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
19
 
20
  # Load and split PDF document
21
+ def load_doc(list_file_path, chunk_size=512, chunk_overlap=64):
22
  loaders = [PyPDFLoader(x) for x in list_file_path]
23
  pages = []
24
  for loader in loaders:
 
125
  def upload_file(file_obj):
126
  list_file_path = []
127
  for idx, file in enumerate(file_obj):
128
+ file_path = file_obj.name
129
  list_file_path.append(file_path)
130
  return list_file_path
131
 
 
135
  with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
136
  vector_db = gr.State()
137
  qa_chain = gr.State()
138
+ gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
139
+ gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. <b>Please do not upload confidential documents.</b>""")
 
140
 
141
  # Interface for static pre-selected documents
142
  gr.Markdown("<b>Pre-Selected Documents</b>")
143
+ gr.Textbox(value="Document 1: ...", show_label=False, interactive=False)
144
+ gr.Textbox(value="Document 2: ...", show_label=False, interactive=False)
145
 
146
+ gr.Markdown("<b>Select Large Language Model (LLM) and Input Parameters</b>")
 
 
 
 
 
147
  llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
148
  slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
149
  slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
150
+ slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
151
  qachain_btn = gr.Button("Initialize Question Answering Chatbot")
152
  llm_progress = gr.Textbox(value="Not initialized", show_label=False)
153
 
154
+ gr.Markdown("<b>Chat with your Document</b>")
155
  chatbot = gr.Chatbot(height=505)
156
+ with gr.Accordion("Relevant context from the source document", open=False):
157
+ with gr.Row():
158
+ doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
159
+ source1_page = gr.Number(label="Page", scale=1)
160
+ with gr.Row():
161
+ doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
162
+ source2_page = gr.Number(label="Page", scale=1)
163
+ with gr.Row():
164
+ doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
165
+ source3_page = gr.Number(label="Page", scale=1)
166
  msg = gr.Textbox(placeholder="Ask a question", container=True)
167
  submit_btn = gr.Button("Submit")
168
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
169
+
170
+ # Preprocessing events
171
+ db_btn = gr.Button("Create vector database")
172
+ db_progress = gr.Textbox(value="Not initialized", show_label=False)
173
  db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
174
+
175
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
176
  inputs=None,
177
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
178
  queue=False)
179
 
180
+ # Chatbot events
181
  msg.submit(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
182
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
183
+ clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
184
+
185
  demo.queue().launch(debug=True)
186
 
187
  if __name__ == "__main__":
188
+ demo()