mariagrandury committed on
Commit 9124976 · 1 Parent(s): e15765c
Files changed (2)
  1. app.py +532 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,532 @@
+ import gradio as gr
+ import os
+
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.chains import ConversationChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain_community.llms import HuggingFaceEndpoint
+
+ from pathlib import Path
+ import chromadb
+ from unidecode import unidecode
+
+ from transformers import AutoTokenizer
+ import transformers
+ import torch
+ import tqdm
+ import accelerate
+ import re
+
+
+ # default_persist_directory = './chroma_HF/'
+ list_llm = [
+     "mistralai/Mistral-7B-Instruct-v0.2",
+     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "mistralai/Mistral-7B-Instruct-v0.1",
+     "google/gemma-7b-it",
+     "google/gemma-2b-it",
+     "HuggingFaceH4/zephyr-7b-beta",
+     "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+     "meta-llama/Llama-2-7b-chat-hf",
+     "microsoft/phi-2",
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+     "mosaicml/mpt-7b-instruct",
+     "tiiuae/falcon-7b-instruct",
+     "google/flan-t5-xxl",
+ ]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+
+ # Load PDF documents and create doc splits
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
+     # Processing for one document only
+     # loader = PyPDFLoader(file_path)
+     # pages = loader.load()
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+     # text_splitter = RecursiveCharacterTextSplitter(chunk_size = 600, chunk_overlap = 50)
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
+
+
+ # Create vector database
+ def create_db(splits, collection_name):
+     embedding = HuggingFaceEmbeddings()
+     new_client = chromadb.EphemeralClient()
+     vectordb = Chroma.from_documents(
+         documents=splits,
+         embedding=embedding,
+         client=new_client,
+         collection_name=collection_name,
+         # persist_directory=default_persist_directory
+     )
+     return vectordb
+
+
+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         # persist_directory=default_persist_directory,
+         embedding_function=embedding
+     )
+     return vectordb
+
+
+ # Initialize langchain LLM chain
+ def initialize_llmchain(
+     llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
+ ):
+     progress(0.1, desc="Initializing HF tokenizer...")
+     # HuggingFacePipeline uses local model
+     # Note: it will download model locally...
+     # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+     # progress(0.5, desc="Initializing HF pipeline...")
+     # pipeline=transformers.pipeline(
+     #     "text-generation",
+     #     model=llm_model,
+     #     tokenizer=tokenizer,
+     #     torch_dtype=torch.bfloat16,
+     #     trust_remote_code=True,
+     #     device_map="auto",
+     #     # max_length=1024,
+     #     max_new_tokens=max_tokens,
+     #     do_sample=True,
+     #     top_k=top_k,
+     #     num_return_sequences=1,
+     #     eos_token_id=tokenizer.eos_token_id
+     # )
+     # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
+
+     # HuggingFaceHub uses HF inference endpoints
+     progress(0.5, desc="Initializing HF Hub...")
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
+     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k,
+             load_in_8bit=True,
+         )
+     elif llm_model in [
+         "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+         "mosaicml/mpt-7b-instruct",
+     ]:
+         raise gr.Error(
+             "LLM model is too large to be loaded automatically on free inference endpoint"
+         )
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k,
+         )
+     elif llm_model == "microsoft/phi-2":
+         # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k,
+             trust_remote_code=True,
+             torch_dtype="auto",
+         )
+     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+             temperature=temperature,
+             max_new_tokens=250,
+             top_k=top_k,
+         )
+     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+         raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+             temperature=temperature,
+             max_new_tokens=max_tokens,
+             top_k=top_k,
+         )
+
+     progress(0.75, desc="Defining buffer memory...")
+     memory = ConversationBufferMemory(
+         memory_key="chat_history", output_key="answer", return_messages=True
+     )
+     # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+     retriever = vector_db.as_retriever()
+     progress(0.8, desc="Defining retrieval chain...")
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         # combine_docs_chain_kwargs={"prompt": your_prompt})
+         return_source_documents=True,
+         # return_generated_question=False,
+         verbose=False,
+     )
+     progress(0.9, desc="Done!")
+     return qa_chain
+
+
+ # Generate collection name for vector database
+ # - Use filepath as input, ensuring unicode text
+ def create_collection_name(filepath):
+     # Extract filename without extension
+     collection_name = Path(filepath).stem
+     # Fix potential issues from naming convention
+     ## Remove space
+     collection_name = collection_name.replace(" ", "-")
+     ## ASCII transliterations of Unicode text
+     collection_name = unidecode(collection_name)
+     ## Remove special characters
+     # collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+     collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name)
+     ## Limit length to 50 characters
+     collection_name = collection_name[:50]
+     ## Minimum length of 3 characters
+     if len(collection_name) < 3:
+         collection_name = collection_name + "xyz"
+     ## Enforce start and end as alphanumeric character
+     if not collection_name[0].isalnum():
+         collection_name = "A" + collection_name[1:]
+     if not collection_name[-1].isalnum():
+         collection_name = collection_name[:-1] + "Z"
+     print("Filepath: ", filepath)
+     print("Collection name: ", collection_name)
+     return collection_name
+
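+ # For example, create_collection_name("My Report (v2).pdf") returns
+ # "My-Report-v2Z": spaces become hyphens, "(v2)" collapses to "-v2-",
+ # and the trailing "-" is swapped for "Z" to keep the ending alphanumeric.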
+
+ # Initialize database
+ def initialize_database(
+     list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
+ ):
+     # Create list of documents (when valid)
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+     # Create collection_name for vector database
+     progress(0.1, desc="Creating collection name...")
+     collection_name = create_collection_name(list_file_path[0])
+     progress(0.25, desc="Loading document...")
+     # Load document and create splits
+     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     # Create or load vector database
+     progress(0.5, desc="Generating vector database...")
+     # global vector_db
+     vector_db = create_db(doc_splits, collection_name)
+     progress(0.9, desc="Done!")
+     return vector_db, collection_name, "Complete!"
+
+
+ def initialize_LLM(
+     llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
+ ):
+     # print("llm_option",llm_option)
+     llm_name = list_llm[llm_option]
+     print("llm_name: ", llm_name)
+     qa_chain = initialize_llmchain(
+         llm_name, llm_temperature, max_tokens, top_k, vector_db, progress
+     )
+     return qa_chain, "Complete!"
+
+
+ def format_chat_history(message, chat_history):
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
+
+
+ def conversation(qa_chain, message, history):
+     formatted_chat_history = format_chat_history(message, history)
+     # print("formatted_chat_history",formatted_chat_history)
+
+     # Generate response using QA chain
+     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+     response_answer = response["answer"]
+     if response_answer.find("Helpful Answer:") != -1:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
+     response_sources = response["source_documents"]
+     response_source1 = response_sources[0].page_content.strip()
+     response_source2 = response_sources[1].page_content.strip()
+     response_source3 = response_sources[2].page_content.strip()
+     # Langchain sources are zero-based
+     response_source1_page = response_sources[0].metadata["page"] + 1
+     response_source2_page = response_sources[1].metadata["page"] + 1
+     response_source3_page = response_sources[2].metadata["page"] + 1
+     # print ('chat response: ', response_answer)
+     # print('DB source', response_sources)
+
+     # Append user message and response to chat history
+     new_history = history + [(message, response_answer)]
+     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+     return (
+         qa_chain,
+         gr.update(value=""),
+         new_history,
+         response_source1,
+         response_source1_page,
+         response_source2,
+         response_source2_page,
+         response_source3,
+         response_source3_page,
+     )
+
+
+ def upload_file(file_obj):
+     list_file_path = []
+     for file in file_obj:
+         file_path = file.name
+         list_file_path.append(file_path)
+     # print(file_path)
+     # initialize_database(file_path, progress)
+     return list_file_path
+
+
+ def demo():
+     with gr.Blocks(theme="base") as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+         collection_name = gr.State()
+
+         gr.Markdown(
+             """<center><h2>PDF-based chatbot</h2></center>
+             <h3>Ask any questions about your PDF documents</h3>"""
+         )
+         gr.Markdown(
+             """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
+             The user interface explicitly shows multiple steps to help understand the RAG workflow.
+             This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity.<br>
+             <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
+             """
+         )
+
+         with gr.Tab("Step 1 - Upload PDF"):
+             with gr.Row():
+                 document = gr.Files(
+                     height=100,
+                     file_count="multiple",
+                     file_types=["pdf"],
+                     interactive=True,
+                     label="Upload your PDF documents (single or multiple)",
+                 )
+                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+
+         with gr.Tab("Step 2 - Process document"):
+             with gr.Row():
+                 db_choice = gr.Radio(
+                     ["ChromaDB"],
+                     label="Vector database type",
+                     value="ChromaDB",
+                     type="index",
+                     info="Choose your vector database",
+                 )
+             with gr.Accordion("Advanced options - Document text splitter", open=False):
+                 with gr.Row():
+                     slider_chunk_size = gr.Slider(
+                         minimum=100,
+                         maximum=1000,
+                         value=600,
+                         step=20,
+                         label="Chunk size",
+                         info="Chunk size",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_chunk_overlap = gr.Slider(
+                         minimum=10,
+                         maximum=200,
+                         value=40,
+                         step=10,
+                         label="Chunk overlap",
+                         info="Chunk overlap",
+                         interactive=True,
+                     )
+             with gr.Row():
+                 db_progress = gr.Textbox(
+                     label="Vector database initialization", value="None"
+                 )
+             with gr.Row():
+                 db_btn = gr.Button("Generate vector database")
+
+         with gr.Tab("Step 3 - Initialize QA chain"):
+             with gr.Row():
+                 llm_btn = gr.Radio(
+                     list_llm_simple,
+                     label="LLM models",
+                     value=list_llm_simple[0],
+                     type="index",
+                     info="Choose your LLM model",
+                 )
+             with gr.Accordion("Advanced options - LLM model", open=False):
+                 with gr.Row():
+                     slider_temperature = gr.Slider(
+                         minimum=0.01,
+                         maximum=1.0,
+                         value=0.7,
+                         step=0.1,
+                         label="Temperature",
+                         info="Model temperature",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_maxtokens = gr.Slider(
+                         minimum=224,
+                         maximum=4096,
+                         value=1024,
+                         step=32,
+                         label="Max Tokens",
+                         info="Model max tokens",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_topk = gr.Slider(
+                         minimum=1,
+                         maximum=10,
+                         value=3,
+                         step=1,
+                         label="top-k samples",
+                         info="Model top-k samples",
+                         interactive=True,
+                     )
+             with gr.Row():
+                 llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+             with gr.Row():
+                 qachain_btn = gr.Button("Initialize Question Answering chain")
+
+         with gr.Tab("Step 4 - Chatbot"):
+             chatbot = gr.Chatbot(height=300)
+             with gr.Accordion("Advanced - Document references", open=False):
+                 with gr.Row():
+                     doc_source1 = gr.Textbox(
+                         label="Reference 1", lines=2, container=True, scale=20
+                     )
+                     source1_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source2 = gr.Textbox(
+                         label="Reference 2", lines=2, container=True, scale=20
+                     )
+                     source2_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source3 = gr.Textbox(
+                         label="Reference 3", lines=2, container=True, scale=20
+                     )
+                     source3_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type message (e.g. 'What is this document about?')",
+                     container=True,
+                 )
+             with gr.Row():
+                 submit_btn = gr.Button("Submit message")
+                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
+         # Preprocessing events
+         # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+         db_btn.click(
+             initialize_database,
+             inputs=[document, slider_chunk_size, slider_chunk_overlap],
+             outputs=[vector_db, collection_name, db_progress],
+         )
+         qachain_btn.click(
+             initialize_LLM,
+             inputs=[
+                 llm_btn,
+                 slider_temperature,
+                 slider_maxtokens,
+                 slider_topk,
+                 vector_db,
+             ],
+             outputs=[qa_chain, llm_progress],
+         ).then(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+
+         # Chatbot events
+         msg.submit(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[
+                 qa_chain,
+                 msg,
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+         submit_btn.click(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[
+                 qa_chain,
+                 msg,
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+         clear_btn.click(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+     demo.queue().launch(debug=True)
+
+
+ if __name__ == "__main__":
+     demo()
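For reference, the pipeline above can also be exercised without the Gradio UI. A minimal headless sketch, assuming a local sample.pdf (a placeholder file name), a HUGGINGFACEHUB_API_TOKEN set in the environment for the inference endpoint, and that the default gr.Progress() callbacks are harmless outside an event handler:

# Headless sketch of the same RAG pipeline (illustrative, not part of the commit)
splits = load_doc(["sample.pdf"], chunk_size=600, chunk_overlap=40)
vector_db = create_db(splits, create_collection_name("sample.pdf"))
qa_chain = initialize_llmchain(
    llm_model=list_llm[0], temperature=0.7, max_tokens=1024, top_k=3, vector_db=vector_db
)
response = qa_chain({"question": "What is this document about?", "chat_history": []})
print(response["answer"])

This mirrors the sequence wired to the Step 2-4 buttons: initialize_database, initialize_LLM, then conversation.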
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ transformers
+ sentence-transformers
+ langchain
+ tqdm
+ accelerate
+ pypdf
+ chromadb
+ unidecode