dryouviavant commited on
Commit
0d951c1
·
verified ·
1 Parent(s): 410bc42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
  import os
3
- from langchain.vectorstores import FAISS
4
- from langchain.document_loaders import PyPDFLoader
5
- from langchain.embeddings import HuggingFaceEmbeddings
6
- from langchain.llms import HuggingFaceEndpoint
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.memory import ConversationBufferMemory
@@ -75,6 +75,13 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
75
  )
76
  return qa_chain
77
 
 
 
 
 
 
 
 
78
  # Initialize database
79
  def initialize_database(list_file_obj, progress=gr.Progress()):
80
  list_file_path = [x.name for x in list_file_obj if x is not None]
@@ -120,12 +127,12 @@ def conversation(qa_chain, message, history, persona_text):
120
  response_source2_page = response_sources[1].metadata["page"] + 1
121
  response_source3_page = response_sources[2].metadata["page"] + 1
122
  new_history = history + [(message, response_answer)]
123
- return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
124
 
125
  def upload_file(file_obj):
126
  list_file_path = []
127
  for idx, file in enumerate(file_obj):
128
- file_path = file_obj.name
129
  list_file_path.append(file_path)
130
  return list_file_path
131
 
@@ -133,21 +140,27 @@ def demo():
133
  persona_text = load_persona('persona.md')
134
 
135
  with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
136
- vector_db = gr.State()
137
  qa_chain = gr.State()
138
- gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
139
- gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. <b>Please do not upload confidential documents.</b>""")
 
140
 
141
  # Interface for static pre-selected documents
142
  gr.Markdown("<b>Pre-Selected Documents</b>")
143
- gr.Textbox(value="Document 1: ...", show_label=False, interactive=False)
144
- gr.Textbox(value="Document 2: ...", show_label=False, interactive=False)
145
 
146
- gr.Markdown("<b>Select Large Language Model (LLM) and Input Parameters</b>")
 
 
 
 
 
147
  llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
148
  slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
149
  slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
150
- slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
151
  qachain_btn = gr.Button("Initialize Question Answering Chatbot")
152
  llm_progress = gr.Textbox(value="Not initialized", show_label=False)
153
 
@@ -168,8 +181,6 @@ def demo():
168
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
169
 
170
  # Preprocessing events
171
- db_btn = gr.Button("Create vector database")
172
- db_progress = gr.Textbox(value="Not initialized", show_label=False)
173
  db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
174
 
175
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
@@ -178,8 +189,8 @@ def demo():
178
  queue=False)
179
 
180
  # Chatbot events
181
- msg.submit(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
182
- submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, persona_text], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
183
  clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
184
 
185
  demo.queue().launch(debug=True)
 
1
  import gradio as gr
2
  import os
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_community.document_loaders import PyPDFLoader
5
+ from langchain_community.embeddings import HuggingFaceEmbeddings
6
+ from langchain_community.llms import HuggingFaceEndpoint
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.memory import ConversationBufferMemory
 
75
  )
76
  return qa_chain
77
 
78
+ # Pre-process and vectorize local PDFs
79
+ def pre_process_pdfs(directory="pdfs"):
80
+ file_paths = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.pdf')]
81
+ doc_splits = load_doc(file_paths)
82
+ vector_db = create_db(doc_splits)
83
+ return vector_db
84
+
85
  # Initialize database
86
  def initialize_database(list_file_obj, progress=gr.Progress()):
87
  list_file_path = [x.name for x in list_file_obj if x is not None]
 
127
  response_source2_page = response_sources[1].metadata["page"] + 1
128
  response_source3_page = response_sources[2].metadata["page"] + 1
129
  new_history = history + [(message, response_answer)]
130
+ return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response2_page, response_source3, source3_page
131
 
132
  def upload_file(file_obj):
133
  list_file_path = []
134
  for idx, file in enumerate(file_obj):
135
+ file_path = file.name
136
  list_file_path.append(file_path)
137
  return list_file_path
138
 
 
140
  persona_text = load_persona('persona.md')
141
 
142
  with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
143
+ vector_db = gr.State(pre_process_pdfs("ILYA/pdfs")) # Pre-process PDFs on initialization with correct path
144
  qa_chain = gr.State()
145
+ gr.HTML("<center><h1>RAG PDF Chatbot</h1><center>")
146
+ gr.Markdown("""<b>Interact with Your PDF Documents!</b> This AI agent performs retrieval-augmented generation (RAG) on PDF documents. Hosted on Hugging Face Hub for demonstration purposes. \
147
+ <b>Do not upload confidential documents.</b>""")
148
 
149
  # Interface for static pre-selected documents
150
  gr.Markdown("<b>Pre-Selected Documents</b>")
151
+ gr.Textbox(value="Document 1: Introduction to AI.pdf", show_label=False, interactive=False)
152
+ gr.Textbox(value="Document 2: Advanced Machine Learning.pdf", show_label=False, interactive=False)
153
 
154
+ gr.Markdown("<b>Upload Your PDF Documents</b>")
155
+ document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
156
+ db_btn = gr.Button("Create vector database")
157
+ db_progress = gr.Textbox(value="Not initialized", show_label=False)
158
+
159
+ gr.Markdown("<b>Select Large Language Model (LLM) and Configure Parameters</b>")
160
  llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
161
  slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
162
  slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
163
+ slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K", info="Number of tokens to select the next token from", interactive=True)
164
  qachain_btn = gr.Button("Initialize Question Answering Chatbot")
165
  llm_progress = gr.Textbox(value="Not initialized", show_label=False)
166
 
 
181
  clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
182
 
183
  # Preprocessing events
 
 
184
  db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
185
 
186
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
 
189
  queue=False)
190
 
191
  # Chatbot events
192
+ msg.submit(conversation, inputs=[qa_chain, msg, chatbot, gr.State(value=persona_text)], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
193
+ submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, gr.State(value=persona_text)], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
194
  clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
195
 
196
  demo.queue().launch(debug=True)