Clement Vachet committed
Commit b4bdfee · 1 parent: fd5ccbe

Code refactoring

Files changed (4)
  1. app.py +259 -238
  2. indexing.py +83 -0
  3. prompt_template.json +5 -0
  4. retrieval.py +114 -0
app.py CHANGED
@@ -1,211 +1,93 @@
1
- import gradio as gr
2
- import os
 
3
 
4
- from langchain_community.document_loaders import PyPDFLoader
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_chroma import Chroma
7
- from langchain.chains import ConversationalRetrievalChain
8
- from langchain_huggingface import HuggingFaceEmbeddings
9
- from langchain.chains import ConversationChain
10
- from langchain.memory import ConversationBufferMemory
11
- from langchain_huggingface import HuggingFaceEndpoint
12
- from langchain_core.prompts import PromptTemplate
13
-
14
- from pathlib import Path
15
- import chromadb
16
- from unidecode import unidecode
17
-
18
- from transformers import AutoTokenizer
19
- import transformers
20
- import torch
21
- import tqdm
22
- import accelerate
23
- import re
24
 
25
  from dotenv import load_dotenv
26
 
27
-
28
- # Load environment file - HuggingFace API key
29
- _ = load_dotenv()
30
- huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
31
-
32
-
33
- # Add system template for RAG application
34
- prompt_template = """
35
- You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
36
- If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
37
- Question: {question}
38
- Context: {context}
39
- Helpful Answer:
40
- """
41
 
42
 
43
  # default_persist_directory = './chroma_HF/'
44
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.3", "microsoft/Phi-3.5-mini-instruct", \
45
- "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct", \
46
- "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \
47
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "google/gemma-2-2b-it", "google/gemma-2-9b-it", \
48
- "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct",
49
  ]
50
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
51
 
52
 
53
- # Load PDF document and create doc splits
54
- def load_doc(list_file_path, chunk_size, chunk_overlap):
55
- """Load PDF document and create doc splits"""
56
-
57
- loaders = [PyPDFLoader(x) for x in list_file_path]
58
- pages = []
59
- for loader in loaders:
60
- pages.extend(loader.load())
61
- text_splitter = RecursiveCharacterTextSplitter(
62
- chunk_size = chunk_size,
63
- chunk_overlap = chunk_overlap)
64
- doc_splits = text_splitter.split_documents(pages)
65
- return doc_splits
66
-
67
-
68
- # Create vector database
69
- def create_db(splits, collection_name):
70
- """Create embeddings and vector database"""
71
-
72
- embedding = HuggingFaceEmbeddings(
73
- model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
74
- # model_name="sentence-transformers/all-MiniLM-L6-v2",
75
- model_kwargs={'device': 'cpu'},
76
- # encode_kwargs={'normalize_embeddings': False}
77
- )
78
- chromadb.api.client.SharedSystemClient.clear_system_cache()
79
- new_client = chromadb.EphemeralClient()
80
- vectordb = Chroma.from_documents(
81
- documents=splits,
82
- embedding=embedding,
83
- client=new_client,
84
- collection_name=collection_name,
85
- # persist_directory=default_persist_directory
86
- )
87
- return vectordb
88
-
89
-
90
-
91
- # Initialize langchain LLM chain
92
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
93
- """Initialize Langchain LLM chain"""
94
-
95
- progress(0.1, desc="Initializing HF tokenizer...")
96
- # HuggingFaceHub uses HF inference endpoints
97
- progress(0.5, desc="Initializing HF Hub...")
98
- # Use of trust_remote_code as model_kwargs
99
- # Warning: langchain issue
100
- # URL: https://github.com/langchain-ai/langchain/issues/6080
101
-
102
-
103
- llm = HuggingFaceEndpoint(
104
- repo_id=llm_model,
105
- task = "text-generation",
106
- temperature = temperature,
107
- max_new_tokens = max_tokens,
108
- top_k = top_k,
109
- huggingfacehub_api_token=huggingfacehub_api_token,
110
- )
111
-
112
- progress(0.75, desc="Defining buffer memory...")
113
- memory = ConversationBufferMemory(
114
- memory_key="chat_history",
115
- output_key='answer',
116
- return_messages=True
117
- )
118
- # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
119
- retriever=vector_db.as_retriever()
120
- progress(0.8, desc="Defining retrieval chain...")
121
- rag_prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
122
- qa_chain = ConversationalRetrievalChain.from_llm(
123
- llm,
124
- retriever=retriever,
125
- chain_type="stuff",
126
- memory=memory,
127
- combine_docs_chain_kwargs={"prompt": rag_prompt},
128
- return_source_documents=True,
129
- #return_generated_question=False,
130
- verbose=False,
131
- )
132
- progress(0.9, desc="Done!")
133
-
134
- return qa_chain
135
-
136
-
137
- # Generate collection name for vector database
138
- # - Use filepath as input, ensuring unicode text
139
- def create_collection_name(filepath):
140
- # Extract filename without extension
141
- collection_name = Path(filepath).stem
142
- # Fix potential issues from naming convention
143
- ## Remove space
144
- collection_name = collection_name.replace(" ","-")
145
- ## ASCII transliterations of Unicode text
146
- collection_name = unidecode(collection_name)
147
- ## Remove special characters
148
- #collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
149
- collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
150
- ## Limit length to 50 characters
151
- collection_name = collection_name[:50]
152
- ## Minimum length of 3 characters
153
- if len(collection_name) < 3:
154
- collection_name = collection_name + 'xyz'
155
- ## Enforce start and end as alphanumeric character
156
- if not collection_name[0].isalnum():
157
- collection_name = 'A' + collection_name[1:]
158
- if not collection_name[-1].isalnum():
159
- collection_name = collection_name[:-1] + 'Z'
160
- print('\n\nFilepath: ', filepath)
161
- print('Collection name: ', collection_name)
162
- return collection_name
163
 
164
 
165
  # Initialize database
166
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
167
  # Create list of documents (when valid)
168
  list_file_path = [x.name for x in list_file_obj if x is not None]
 
169
  # Create collection_name for vector database
170
  progress(0.1, desc="Creating collection name...")
171
- collection_name = create_collection_name(list_file_path[0])
 
172
  progress(0.25, desc="Loading document...")
173
  # Load document and create splits
174
- doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
 
175
  # Create or load vector database
176
  progress(0.5, desc="Generating vector database...")
 
177
  # global vector_db
178
- vector_db = create_db(doc_splits, collection_name)
179
- progress(0.9, desc="Done!")
180
  return vector_db, collection_name, "Complete!"
181
 
182
 
183
- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
184
  # print("llm_option",llm_option)
185
  llm_name = list_llm[llm_option]
186
- print("llm_name: ",llm_name)
187
- qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
 
 
188
  return qa_chain, "Complete!"
189
 
190
 
191
- def format_chat_history(message, chat_history):
192
- formatted_chat_history = []
193
- for user_message, bot_message in chat_history:
194
- formatted_chat_history.append(f"User: {user_message}")
195
- formatted_chat_history.append(f"Assistant: {bot_message}")
196
- return formatted_chat_history
197
-
198
-
199
  def conversation(qa_chain, message, history):
200
- formatted_chat_history = format_chat_history(message, history)
201
- #print("formatted_chat_history",formatted_chat_history)
202
-
203
- # Generate response using QA chain
204
- response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
205
- response_answer = response["answer"]
206
- if response_answer.find("Helpful Answer:") != -1:
207
- response_answer = response_answer.split("Helpful Answer:")[-1]
208
- response_sources = response["source_documents"]
209
  response_source1 = response_sources[0].page_content.strip()
210
  response_source2 = response_sources[1].page_content.strip()
211
  response_source3 = response_sources[2].page_content.strip()
@@ -213,61 +95,134 @@ def conversation(qa_chain, message, history):
213
  response_source1_page = response_sources[0].metadata["page"] + 1
214
  response_source2_page = response_sources[1].metadata["page"] + 1
215
  response_source3_page = response_sources[2].metadata["page"] + 1
216
- # print ('chat response: ', response_answer)
217
- # print('DB source', response_sources)
218
-
219
- # Append user message and response to chat history
220
- new_history = history + [(message, response_answer)]
221
- # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
222
- return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
223
-
224
-
225
- def demo():
226
  with gr.Blocks(theme="base") as demo:
227
  vector_db = gr.State()
228
  qa_chain = gr.State()
229
  collection_name = gr.State()
230
-
231
- gr.Markdown(
232
- """<center><h2>PDF-based chatbot</center></h2>
233
- <h3>Ask any questions about your PDF documents</h3>""")
234
- gr.Markdown(
235
- """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
236
- The user interface explicitely shows multiple steps to help understand the RAG workflow.
237
- This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
238
- <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
239
- """)
240
-
241
  with gr.Tab("Step 1 - Upload PDF"):
242
  with gr.Row():
243
- document = gr.File(height=200, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
244
-
245
  with gr.Tab("Step 2 - Process document"):
246
  with gr.Row():
247
- db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
248
  with gr.Accordion("Advanced options - Document text splitter", open=False):
249
  with gr.Row():
250
- slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
251
  with gr.Row():
252
- slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
253
  with gr.Row():
254
- db_progress = gr.Textbox(label="Vector database initialization", value="None")
 
 
255
  with gr.Row():
256
  db_btn = gr.Button("Generate vector database")
257
-
258
  with gr.Tab("Step 3 - Initialize QA chain"):
259
  with gr.Row():
260
- llm_btn = gr.Radio(list_llm_simple, \
261
- label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
262
  with gr.Accordion("Advanced options - LLM model", open=False):
263
  with gr.Row():
264
- slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
265
  with gr.Row():
266
- slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
267
  with gr.Row():
268
- slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
269
  with gr.Row():
270
- llm_progress = gr.Textbox(value="None",label="QA chain initialization")
271
  with gr.Row():
272
  qachain_btn = gr.Button("Initialize Question Answering chain")
273
 
@@ -275,46 +230,112 @@ def demo():
275
  chatbot = gr.Chatbot(height=300)
276
  with gr.Accordion("Advanced - Document references", open=False):
277
  with gr.Row():
278
- doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
 
 
279
  source1_page = gr.Number(label="Page", scale=1)
280
  with gr.Row():
281
- doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
 
 
282
  source2_page = gr.Number(label="Page", scale=1)
283
  with gr.Row():
284
- doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
 
 
285
  source3_page = gr.Number(label="Page", scale=1)
286
  with gr.Row():
287
- msg = gr.Textbox(placeholder="Type message (e.g. 'Can you summarize this document in one paragraph?')", container=True)
288
  with gr.Row():
289
  submit_btn = gr.Button("Submit message")
290
- clear_btn = gr.ClearButton(components=[msg, chatbot], value="Clear conversation")
291
-
 
 
292
  # Preprocessing events
293
- db_btn.click(initialize_database, \
294
- inputs=[document, slider_chunk_size, slider_chunk_overlap], \
295
- outputs=[vector_db, collection_name, db_progress])
296
- qachain_btn.click(initialize_LLM, \
297
- inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
298
- outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
299
- inputs=None, \
300
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
301
- queue=False)
302
 
303
  # Chatbot events
304
- msg.submit(conversation, \
305
- inputs=[qa_chain, msg, chatbot], \
306
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
307
- queue=False)
308
- submit_btn.click(conversation, \
309
- inputs=[qa_chain, msg, chatbot], \
310
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
311
- queue=False)
312
- clear_btn.click(lambda:[None,"",0,"",0,"",0], \
313
- inputs=None, \
314
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
315
- queue=False)
316
  demo.queue().launch(debug=True)
317
 
318
 
319
  if __name__ == "__main__":
320
- demo()
 
 
1
+ """
2
+ PDF-based chatbot with Retrieval-Augmented Generation
3
+ """
4
 
5
+ import os
6
+ import gradio as gr
7
 
8
  from dotenv import load_dotenv
9
 
10
+ import indexing
11
+ import retrieval
12
 
13
 
14
  # default_persist_directory = './chroma_HF/'
15
+ list_llm = [
16
+ "mistralai/Mistral-7B-Instruct-v0.3",
17
+ "microsoft/Phi-3.5-mini-instruct",
18
+ "meta-llama/Llama-3.2-3B-Instruct",
19
+ "meta-llama/Llama-3.2-1B-Instruct",
20
+ "meta-llama/Meta-Llama-3-8B-Instruct",
21
+ "HuggingFaceH4/zephyr-7b-beta",
22
+ "HuggingFaceH4/zephyr-7b-gemma-v0.1",
23
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
24
+ "google/gemma-2-2b-it",
25
+ "google/gemma-2-9b-it",
26
+ "Qwen/Qwen2.5-1.5B-Instruct",
27
+ "Qwen/Qwen2.5-3B-Instruct",
28
+ "Qwen/Qwen2.5-7B-Instruct",
29
  ]
30
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
31
 
32
 
33
+ # Load environment file - HuggingFace API key
34
+ def retrieve_api():
35
+ """Retrieve HuggingFace API Key"""
36
+ _ = load_dotenv()
37
+ global huggingfacehub_api_token
38
+ huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
39
 
40
 
41
  # Initialize database
42
+ def initialize_database(
43
+ list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
44
+ ):
45
+ """Initialize database"""
46
+
47
  # Create list of documents (when valid)
48
  list_file_path = [x.name for x in list_file_obj if x is not None]
49
+
50
  # Create collection_name for vector database
51
  progress(0.1, desc="Creating collection name...")
52
+ collection_name = indexing.create_collection_name(list_file_path[0])
53
+
54
  progress(0.25, desc="Loading document...")
55
  # Load document and create splits
56
+ doc_splits = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
57
+
58
  # Create or load vector database
59
  progress(0.5, desc="Generating vector database...")
60
+
61
  # global vector_db
62
+ vector_db = indexing.create_db(doc_splits, collection_name)
63
+
64
  return vector_db, collection_name, "Complete!"
65
 
66
 
67
+ # Initialize LLM
68
+ def initialize_llm(
69
+ llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
70
+ ):
71
+ """Initialize LLM"""
72
+
73
  # print("llm_option",llm_option)
74
  llm_name = list_llm[llm_option]
75
+ print("llm_name: ", llm_name)
76
+ qa_chain = retrieval.initialize_llmchain(
77
+ llm_name, huggingfacehub_api_token, llm_temperature, max_tokens, top_k, vector_db, progress
78
+ )
79
  return qa_chain, "Complete!"
80
 
81
 
82
+ # Chatbot conversation
83
  def conversation(qa_chain, message, history):
84
+ """Chatbot conversation"""
85
+
86
+ qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(
87
+ qa_chain, message, history
88
+ )
89
+
90
+ # Format output gradio components
 
 
91
  response_source1 = response_sources[0].page_content.strip()
92
  response_source2 = response_sources[1].page_content.strip()
93
  response_source3 = response_sources[2].page_content.strip()
 
95
  response_source1_page = response_sources[0].metadata["page"] + 1
96
  response_source2_page = response_sources[1].metadata["page"] + 1
97
  response_source3_page = response_sources[2].metadata["page"] + 1
98
+
99
+ return (
100
+ qa_chain,
101
+ gr.update(value=""),
102
+ new_history,
103
+ response_source1,
104
+ response_source1_page,
105
+ response_source2,
106
+ response_source2_page,
107
+ response_source3,
108
+ response_source3_page,
109
+ )
110
+
111
+
112
+ SPACE_TITLE = """
113
+ <center><h2>PDF-based chatbot</center></h2>
114
+ <h3>Ask any questions about your PDF documents</h3>
115
+ """
116
+
117
+ SPACE_INFO = """
118
+ <b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
119
+ The user interface explicitly shows multiple steps to help understand the RAG workflow.
120
+ This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
121
+ <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
122
+ """
123
+
124
+
125
+ # Gradio User Interface
126
+ def gradio_ui():
127
+ """Gradio User Interface"""
128
+
129
  with gr.Blocks(theme="base") as demo:
130
  vector_db = gr.State()
131
  qa_chain = gr.State()
132
  collection_name = gr.State()
133
+
134
+ gr.Markdown(SPACE_TITLE)
135
+ gr.Markdown(SPACE_INFO)
136
+
137
  with gr.Tab("Step 1 - Upload PDF"):
138
  with gr.Row():
139
+ document = gr.File(
140
+ height=200,
141
+ file_count="multiple",
142
+ file_types=[".pdf"],
143
+ interactive=True,
144
+ label="Upload your PDF documents (single or multiple)",
145
+ )
146
+
147
  with gr.Tab("Step 2 - Process document"):
148
  with gr.Row():
149
+ db_btn = gr.Radio(
150
+ ["ChromaDB"],
151
+ label="Vector database type",
152
+ value="ChromaDB",
153
+ type="index",
154
+ info="Choose your vector database",
155
+ )
156
  with gr.Accordion("Advanced options - Document text splitter", open=False):
157
  with gr.Row():
158
+ slider_chunk_size = gr.Slider(
159
+ minimum=100,
160
+ maximum=1000,
161
+ value=600,
162
+ step=20,
163
+ label="Chunk size",
164
+ info="Chunk size",
165
+ interactive=True,
166
+ )
167
  with gr.Row():
168
+ slider_chunk_overlap = gr.Slider(
169
+ minimum=10,
170
+ maximum=200,
171
+ value=40,
172
+ step=10,
173
+ label="Chunk overlap",
174
+ info="Chunk overlap",
175
+ interactive=True,
176
+ )
177
  with gr.Row():
178
+ db_progress = gr.Textbox(
179
+ label="Vector database initialization", value="None"
180
+ )
181
  with gr.Row():
182
  db_btn = gr.Button("Generate vector database")
183
+
184
  with gr.Tab("Step 3 - Initialize QA chain"):
185
  with gr.Row():
186
+ llm_btn = gr.Radio(
187
+ list_llm_simple,
188
+ label="LLM models",
189
+ value=list_llm_simple[0],
190
+ type="index",
191
+ info="Choose your LLM model",
192
+ )
193
  with gr.Accordion("Advanced options - LLM model", open=False):
194
  with gr.Row():
195
+ slider_temperature = gr.Slider(
196
+ minimum=0.01,
197
+ maximum=1.0,
198
+ value=0.7,
199
+ step=0.1,
200
+ label="Temperature",
201
+ info="Model temperature",
202
+ interactive=True,
203
+ )
204
  with gr.Row():
205
+ slider_maxtokens = gr.Slider(
206
+ minimum=224,
207
+ maximum=4096,
208
+ value=1024,
209
+ step=32,
210
+ label="Max Tokens",
211
+ info="Model max tokens",
212
+ interactive=True,
213
+ )
214
  with gr.Row():
215
+ slider_topk = gr.Slider(
216
+ minimum=1,
217
+ maximum=10,
218
+ value=3,
219
+ step=1,
220
+ label="top-k samples",
221
+ info="Model top-k samples",
222
+ interactive=True,
223
+ )
224
  with gr.Row():
225
+ llm_progress = gr.Textbox(value="None", label="QA chain initialization")
226
  with gr.Row():
227
  qachain_btn = gr.Button("Initialize Question Answering chain")
228
 
 
230
  chatbot = gr.Chatbot(height=300)
231
  with gr.Accordion("Advanced - Document references", open=False):
232
  with gr.Row():
233
+ doc_source1 = gr.Textbox(
234
+ label="Reference 1", lines=2, container=True, scale=20
235
+ )
236
  source1_page = gr.Number(label="Page", scale=1)
237
  with gr.Row():
238
+ doc_source2 = gr.Textbox(
239
+ label="Reference 2", lines=2, container=True, scale=20
240
+ )
241
  source2_page = gr.Number(label="Page", scale=1)
242
  with gr.Row():
243
+ doc_source3 = gr.Textbox(
244
+ label="Reference 3", lines=2, container=True, scale=20
245
+ )
246
  source3_page = gr.Number(label="Page", scale=1)
247
  with gr.Row():
248
+ msg = gr.Textbox(
249
+ placeholder="Type message (e.g. 'Can you summarize this document in one paragraph?')",
250
+ container=True,
251
+ )
252
  with gr.Row():
253
  submit_btn = gr.Button("Submit message")
254
+ clear_btn = gr.ClearButton(
255
+ components=[msg, chatbot], value="Clear conversation"
256
+ )
257
+
258
  # Preprocessing events
259
+ db_btn.click(
260
+ initialize_database,
261
+ inputs=[document, slider_chunk_size, slider_chunk_overlap],
262
+ outputs=[vector_db, collection_name, db_progress],
263
+ )
264
+ qachain_btn.click(
265
+ initialize_llm,
266
+ inputs=[
267
+ llm_btn,
268
+ slider_temperature,
269
+ slider_maxtokens,
270
+ slider_topk,
271
+ vector_db,
272
+ ],
273
+ outputs=[qa_chain, llm_progress],
274
+ ).then(
275
+ lambda: [None, "", 0, "", 0, "", 0],
276
+ inputs=None,
277
+ outputs=[
278
+ chatbot,
279
+ doc_source1,
280
+ source1_page,
281
+ doc_source2,
282
+ source2_page,
283
+ doc_source3,
284
+ source3_page,
285
+ ],
286
+ queue=False,
287
+ )
288
 
289
  # Chatbot events
290
+ msg.submit(
291
+ conversation,
292
+ inputs=[qa_chain, msg, chatbot],
293
+ outputs=[
294
+ qa_chain,
295
+ msg,
296
+ chatbot,
297
+ doc_source1,
298
+ source1_page,
299
+ doc_source2,
300
+ source2_page,
301
+ doc_source3,
302
+ source3_page,
303
+ ],
304
+ queue=False,
305
+ )
306
+ submit_btn.click(
307
+ conversation,
308
+ inputs=[qa_chain, msg, chatbot],
309
+ outputs=[
310
+ qa_chain,
311
+ msg,
312
+ chatbot,
313
+ doc_source1,
314
+ source1_page,
315
+ doc_source2,
316
+ source2_page,
317
+ doc_source3,
318
+ source3_page,
319
+ ],
320
+ queue=False,
321
+ )
322
+ clear_btn.click(
323
+ lambda: [None, "", 0, "", 0, "", 0],
324
+ inputs=None,
325
+ outputs=[
326
+ chatbot,
327
+ doc_source1,
328
+ source1_page,
329
+ doc_source2,
330
+ source2_page,
331
+ doc_source3,
332
+ source3_page,
333
+ ],
334
+ queue=False,
335
+ )
336
  demo.queue().launch(debug=True)
337
 
338
 
339
  if __name__ == "__main__":
340
+ retrieve_api()
341
+ gradio_ui()
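
Not part of the commit, but for orientation: a minimal sketch of how the refactored modules could be driven end-to-end outside the Gradio UI. The PDF path is hypothetical, and a no-op callable stands in for gr.Progress(); everything else mirrors the functions added in this commit.

# Editor's sketch (assumes a valid HUGGINGFACE_API_KEY in .env and a local PDF)
import os
from dotenv import load_dotenv

import indexing
import retrieval

load_dotenv()
token = os.environ.get("HUGGINGFACE_API_KEY")

file_path = "docs/sample.pdf"  # hypothetical input document
doc_splits = indexing.load_doc([file_path], chunk_size=600, chunk_overlap=40)
collection_name = indexing.create_collection_name(file_path)
vector_db = indexing.create_db(doc_splits, collection_name)

qa_chain = retrieval.initialize_llmchain(
    "mistralai/Mistral-7B-Instruct-v0.3",  # first entry of list_llm
    token,
    temperature=0.7,
    max_tokens=1024,
    top_k=3,
    vector_db=vector_db,
    progress=lambda *args, **kwargs: None,  # no-op stand-in for gr.Progress()
)

qa_chain, history, sources = retrieval.invoke_qa_chain(
    qa_chain, "Can you summarize this document in one paragraph?", []
)
print(history[-1][1])       # model answer
print(sources[0].metadata)  # metadata of the first retrieved chunk (source, page)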
indexing.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Indexing with vector database
3
+ """
4
+
5
+ from pathlib import Path
6
+ import re
7
+
8
+ import chromadb
9
+
10
+ from unidecode import unidecode
11
+
12
+ from langchain_community.document_loaders import PyPDFLoader
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from langchain_chroma import Chroma
15
+ from langchain_huggingface import HuggingFaceEmbeddings
16
+
17
+
18
+
19
+ # Load PDF document and create doc splits
20
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
21
+ """Load PDF document and create doc splits"""
22
+
23
+ loaders = [PyPDFLoader(x) for x in list_file_path]
24
+ pages = []
25
+ for loader in loaders:
26
+ pages.extend(loader.load())
27
+ text_splitter = RecursiveCharacterTextSplitter(
28
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap
29
+ )
30
+ doc_splits = text_splitter.split_documents(pages)
31
+ return doc_splits
32
+
33
+
34
+ # Generate collection name for vector database
35
+ # - Use filepath as input, ensuring unicode text
36
+ # - Handle multiple languages (arabic, chinese)
37
+ def create_collection_name(filepath):
38
+ """Create collection name for vector database"""
39
+
40
+ # Extract filename without extension
41
+ collection_name = Path(filepath).stem
42
+ # Fix potential issues from naming convention
43
+ ## Remove space
44
+ collection_name = collection_name.replace(" ", "-")
45
+ ## ASCII transliterations of Unicode text
46
+ collection_name = unidecode(collection_name)
47
+ ## Remove special characters
48
+ collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name)
49
+ ## Limit length to 50 characters
50
+ collection_name = collection_name[:50]
51
+ ## Minimum length of 3 characters
52
+ if len(collection_name) < 3:
53
+ collection_name = collection_name + "xyz"
54
+ ## Enforce start and end as alphanumeric character
55
+ if not collection_name[0].isalnum():
56
+ collection_name = "A" + collection_name[1:]
57
+ if not collection_name[-1].isalnum():
58
+ collection_name = collection_name[:-1] + "Z"
59
+ print("\n\nFilepath: ", filepath)
60
+ print("Collection name: ", collection_name)
61
+ return collection_name
62
+
63
+
64
+ # Create vector database
65
+ def create_db(splits, collection_name):
66
+ """Create embeddings and vector database"""
67
+
68
+ embedding = HuggingFaceEmbeddings(
69
+ model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
70
+ # model_name="sentence-transformers/all-MiniLM-L6-v2",
71
+ model_kwargs={"device": "cpu"},
72
+ # encode_kwargs={'normalize_embeddings': False}
73
+ )
74
+ chromadb.api.client.SharedSystemClient.clear_system_cache()
75
+ new_client = chromadb.EphemeralClient()
76
+ vectordb = Chroma.from_documents(
77
+ documents=splits,
78
+ embedding=embedding,
79
+ client=new_client,
80
+ collection_name=collection_name,
81
+ # persist_directory=default_persist_directory
82
+ )
83
+ return vectordb
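
A small illustration of the name sanitization above (the filename is made up): spaces are replaced, accents are transliterated with unidecode, and the length and start/end-character checks used for ChromaDB collection names are enforced.

from indexing import create_collection_name

# Hypothetical file name with spaces and accented characters
name = create_collection_name("/data/Résumé annuel 2024.pdf")
# stem "Résumé annuel 2024" -> "Résumé-annuel-2024" -> unidecode -> "Resume-annuel-2024"
print(name)  # Resume-annuel-2024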
prompt_template.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "title": "System prompt",
3
+ "prompt": "You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise. Question: {question} \\n Context: {context} \\n Helpful Answer:"
4
+ }
5
+
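
retrieval.py (below) loads this file and wraps the prompt string in a LangChain PromptTemplate. A minimal sketch of that consumption, assuming prompt_template.json sits in the working directory:

import json
from langchain_core.prompts import PromptTemplate

with open("prompt_template.json", "r") as file:
    system_prompt = json.load(file)

rag_prompt = PromptTemplate(
    template=system_prompt["prompt"],
    input_variables=["context", "question"],
)
# rag_prompt.format(question="...", context="...") fills the two placeholders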
retrieval.py ADDED
@@ -0,0 +1,114 @@
1
+ """
2
+ LLM chain retrieval
3
+ """
4
+
5
+ import json
6
+ import gradio as gr
7
+
8
+ from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
9
+ from langchain.memory import ConversationBufferMemory
10
+ from langchain_huggingface import HuggingFaceEndpoint
11
+ from langchain_core.prompts import PromptTemplate
12
+
13
+
14
+ # Add system template for RAG application
15
+ PROMPT_TEMPLATE = """
16
+ You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
17
+ If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
18
+ Question: {question}
19
+ Context: {context}
20
+ Helpful Answer:
21
+ """
22
+
23
+
24
+ # Initialize langchain LLM chain
25
+ def initialize_llmchain(
26
+ llm_model,
27
+ huggingfacehub_api_token,
28
+ temperature,
29
+ max_tokens,
30
+ top_k,
31
+ vector_db,
32
+ progress=gr.Progress(),
33
+ ):
34
+ """Initialize Langchain LLM chain"""
35
+
36
+ progress(0.1, desc="Initializing HF tokenizer...")
37
+ # HuggingFaceHub uses HF inference endpoints
38
+ progress(0.5, desc="Initializing HF Hub...")
39
+ # Use of trust_remote_code as model_kwargs
40
+ # Warning: langchain issue
41
+ # URL: https://github.com/langchain-ai/langchain/issues/6080
42
+
43
+ llm = HuggingFaceEndpoint(
44
+ repo_id=llm_model,
45
+ task="text-generation",
46
+ temperature=temperature,
47
+ max_new_tokens=max_tokens,
48
+ top_k=top_k,
49
+ huggingfacehub_api_token=huggingfacehub_api_token,
50
+ )
51
+
52
+ progress(0.75, desc="Defining buffer memory...")
53
+ memory = ConversationBufferMemory(
54
+ memory_key="chat_history", output_key="answer", return_messages=True
55
+ )
56
+ # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
57
+ retriever = vector_db.as_retriever()
58
+
59
+ progress(0.8, desc="Defining retrieval chain...")
60
+ with open('prompt_template.json', 'r') as file:
61
+ system_prompt = json.load(file)
62
+ prompt_template = system_prompt["prompt"]
63
+ rag_prompt = PromptTemplate(
64
+ template=prompt_template, input_variables=["context", "question"]
65
+ )
66
+ qa_chain = ConversationalRetrievalChain.from_llm(
67
+ llm,
68
+ retriever=retriever,
69
+ chain_type="stuff",
70
+ memory=memory,
71
+ combine_docs_chain_kwargs={"prompt": rag_prompt},
72
+ return_source_documents=True,
73
+ # return_generated_question=False,
74
+ verbose=False,
75
+ )
76
+ progress(0.9, desc="Done!")
77
+
78
+ return qa_chain
79
+
80
+
81
+ def format_chat_history(message, chat_history):
82
+ """Format chat history for llm chain"""
83
+
84
+ formatted_chat_history = []
85
+ for user_message, bot_message in chat_history:
86
+ formatted_chat_history.append(f"User: {user_message}")
87
+ formatted_chat_history.append(f"Assistant: {bot_message}")
88
+ return formatted_chat_history
89
+
90
+
91
+ def invoke_qa_chain(qa_chain, message, history):
92
+ """Invoke question-answering chain"""
93
+
94
+ formatted_chat_history = format_chat_history(message, history)
95
+ # print("formatted_chat_history",formatted_chat_history)
96
+
97
+ # Generate response using QA chain
98
+ response = qa_chain.invoke(
99
+ {"question": message, "chat_history": formatted_chat_history}
100
+ )
101
+
102
+ response_sources = response["source_documents"]
103
+
104
+ response_answer = response["answer"]
105
+ if response_answer.find("Helpful Answer:") != -1:
106
+ response_answer = response_answer.split("Helpful Answer:")[-1]
107
+
108
+ # Append user message and response to chat history
109
+ new_history = history + [(message, response_answer)]
110
+
111
+ # print ('chat response: ', response_answer)
112
+ # print('DB source', response_sources)
113
+
114
+ return qa_chain, new_history, response_sources
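
For reference, a small illustration (made-up messages) of how format_chat_history flattens Gradio's (user, assistant) tuples before they are passed to the conversational chain:

from retrieval import format_chat_history

history = [("What is this document about?", "It describes a RAG pipeline.")]
print(format_chat_history("And the main limitation?", history))
# ['User: What is this document about?', 'Assistant: It describes a RAG pipeline.']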