ChristopherMarais committed (verified)
Commit 3858e0e · 1 Parent(s): 8ceef7b

Upload 4 files

Files changed (4):
  1. README.md +6 -8
  2. app.py +199 -48
  3. gitattributes +35 -0
  4. requirements.txt +9 -1
README.md CHANGED
@@ -1,14 +1,12 @@
  ---
- title: MemoMe
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
  sdk: gradio
- sdk_version: 5.0.1
  app_file: app.py
  pinned: false
- license: mit
- short_description: This is me...
  ---

- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).

  ---
+ title: RAG PDF Chatbot
+ emoji: 📚
+ colorFrom: gray
+ colorTo: pink
  sdk: gradio
+ sdk_version: 4.31.0
  app_file: app.py
  pinned: false
  ---

+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,64 +1,215 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
          temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content

-         response += token
-         yield response


- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


  if __name__ == "__main__":
-     demo.launch()
+ import os
  import gradio as gr
  from huggingface_hub import InferenceClient

+ # --- LangChain / RAG Imports ---
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFaceEndpoint

+ # Global InferenceClient used for the plain (non-RAG) chat fallback
+ client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

+ # ============================================================================
+ # PDF Processing & RAG Chain Functions
+ # ============================================================================

+ def load_doc(list_file_path):
+     """Load and split PDF documents into chunks."""
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1024,
+         chunk_overlap=64
+     )
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits

+ def create_db(splits):
+     """Create a vector database from document splits."""
+     # Note: HuggingFaceEmbeddings in langchain_community is deprecated; consider the newer langchain-huggingface package.
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     vectordb = FAISS.from_documents(splits, embeddings)
+     return vectordb

+ def initialize_database(file_objs):
+     """
+     Process uploaded PDF files into document splits and a vector database.
+     Expects file objects from gr.Files.
+     """
+     # Each file object's .name attribute holds the file path.
+     list_file_path = [file_obj.name for file_obj in file_objs if file_obj is not None]
+     if not list_file_path:
+         return None, "No files uploaded."
+     doc_splits = load_doc(list_file_path)
+     vector_db = create_db(doc_splits)
+     return vector_db, "Database created!"

+ def initialize_qa_chain(temperature, max_tokens, top_k, vector_db):
+     """
+     Initialize the retrieval-augmented QA chain using the chat model.
+     An explicit task parameter is passed to avoid the "Task unknown" error.
+     """
+     if vector_db is None:
+         return None, "No vector database available. Please create one first."
+
+     # Explicitly set the task to "text-generation" to avoid the error.
+     llm = HuggingFaceEndpoint(
+         repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+         huggingfacehub_api_token=os.getenv("HF_TOKEN"),  # ensure HF_TOKEN is set
          temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
+         task="text-generation"
+     )
+
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key="answer",
+         return_messages=True
+     )
+
+     retriever = vector_db.as_retriever()
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         return_source_documents=True,
+         verbose=False,
+     )
+     return qa_chain, "QA chain initialized. Chatbot is ready!"
+
+ def format_chat_history(history):
+     """Format chat history into a list of strings for the QA chain."""
+     formatted = []
+     for user_msg, bot_msg in history:
+         formatted.append(f"User: {user_msg}")
+         formatted.append(f"Assistant: {bot_msg}")
+     return formatted
+
+ # ============================================================================
+ # Chat function that switches between plain chat and RAG mode
+ # ============================================================================

+ def chat_respond(message, history, system_message, max_tokens, temperature, top_p, qa_chain):
+     """
+     If a QA chain (i.e. RAG mode) is initialized, use it to generate a response
+     (including source references). Otherwise, fall back to a plain chat response
+     from the InferenceClient.
+     """
+     # --- QA Chain (RAG) Mode ---
+     if qa_chain is not None:
+         formatted_history = format_chat_history(history)
+         response = qa_chain.invoke({"question": message, "chat_history": formatted_history})
+         answer = response.get("answer", "")
+         # Optionally include up to 3 source references
+         sources = response.get("source_documents", [])
+         ref_text = ""
+         for i, doc in enumerate(sources[:3]):
+             page = doc.metadata.get("page", "?")
+             ref_text += f"\n\nReference {i+1} (Page {int(page)+1 if page != '?' else '?'}):\n{doc.page_content.strip()}"
+         full_answer = answer + ref_text if ref_text else answer
+         history = history + [(message, full_answer)]
+         return history, qa_chain

+     # --- Plain Chat Mode (fallback) ---
+     else:
+         messages = [{"role": "system", "content": system_message}]
+         for user_msg, bot_msg in history:
+             if user_msg:
+                 messages.append({"role": "user", "content": user_msg})
+             if bot_msg:
+                 messages.append({"role": "assistant", "content": bot_msg})
+         messages.append({"role": "user", "content": message})
+
+         # With stream=False, chat_completion returns a single ChatCompletionOutput,
+         # so the reply is read from choices[0].message rather than from streamed deltas.
+         result = client.chat_completion(
+             messages,
+             max_tokens=max_tokens,
+             stream=False,
+             temperature=temperature,
+             top_p=top_p,
+         )
+         response = result.choices[0].message.content
+
+         history = history + [(message, response)]
+         return history, qa_chain

+ # ============================================================================
+ # Gradio Interface Layout
+ # ============================================================================

+ with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
+
+     # States to hold the vector DB and QA chain
+     vector_db_state = gr.State()
+     qa_chain_state = gr.State(None)
+
+     gr.Markdown("<h1 align='center'>Chat with RAG-enabled PDFs</h1>")
+     gr.Markdown(
+         "Upload PDF files to allow your chatbot to answer questions using information from those documents. "
+         "If no PDFs are uploaded (or the QA chain isn't initialized), the bot will use plain chat mode."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             gr.Markdown("### Step 1: Document Upload & RAG Setup")
+             pdf_files = gr.Files(file_types=[".pdf"], label="Upload PDF documents")
+             db_status = gr.Textbox(label="Database status", interactive=False)
+             qa_status = gr.Textbox(label="QA Chain status", interactive=False)
+             with gr.Row():
+                 create_db_btn = gr.Button("Create Vector DB")
+                 init_qa_btn = gr.Button("Initialize QA Chain")
+             top_k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k (for RAG)")
+
+         with gr.Column(scale=6):
+             gr.Markdown("### Step 2: Chat Settings & Conversation")
+             system_message_input = gr.Textbox(value="You are a friendly Chatbot.", label="System Message")
+             max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
+             temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
+             top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+             chatbot = gr.Chatbot(label="Chat", elem_id="chatbot", height=400)
+             with gr.Row():
+                 user_input = gr.Textbox(placeholder="Enter your message", label="Your Message")
+                 send_btn = gr.Button("Send")
+
+     # -------------------------
+     # Set up button events
+     # -------------------------
+
+     # Create the vector database from uploaded PDFs.
+     create_db_btn.click(
+         fn=initialize_database,
+         inputs=[pdf_files],
+         outputs=[vector_db_state, db_status]
+     )
+
+     # Initialize the QA chain (RAG mode) using the vector DB.
+     init_qa_btn.click(
+         fn=initialize_qa_chain,
+         inputs=[temperature_slider, max_tokens_slider, top_k_slider, vector_db_state],
+         outputs=[qa_chain_state, qa_status]
+     )
+
+     # Chat button: process user input. This function checks whether qa_chain is set.
+     send_btn.click(
+         fn=chat_respond,
+         inputs=[user_input, chatbot, system_message_input, max_tokens_slider, temperature_slider, top_p_slider, qa_chain_state],
+         outputs=[chatbot, qa_chain_state]
+     ).then(
+         lambda: "",  # clear the user input box after sending
+         None,
+         user_input
+     )

  if __name__ == "__main__":
+     demo.queue().launch()
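
For quick local testing, the new RAG helpers can be driven without the Gradio UI. The snippet below is a minimal sketch, not part of the commit: it assumes a hypothetical sample.pdf in the working directory and an HF_TOKEN environment variable, and it reuses load_doc, create_db, initialize_qa_chain, and format_chat_history exactly as defined in app.py above.

import os
from app import load_doc, create_db, initialize_qa_chain, format_chat_history

assert os.getenv("HF_TOKEN"), "HF_TOKEN must be set for the HuggingFaceEndpoint LLM"

# Build the FAISS vector store from one local PDF (hypothetical path).
splits = load_doc(["sample.pdf"])
vector_db = create_db(splits)

# Mirror the UI slider defaults: temperature 0.7, 512 new tokens, top-k 3.
qa_chain, status = initialize_qa_chain(0.7, 512, 3, vector_db)
print(status)

# First turn: the chat history is empty, exactly as chat_respond would pass it.
result = qa_chain.invoke({"question": "What is this document about?",
                          "chat_history": format_chat_history([])})
print(result["answer"])

Running this performs, in order, what the Create Vector DB, Initialize QA Chain, and Send buttons do in the Space.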
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt CHANGED
@@ -1 +1,9 @@
- huggingface_hub==0.25.2
+ torch
+ transformers
+ sentence-transformers
+ langchain
+ langchain-community
+ tqdm
+ accelerate
+ pypdf
+ faiss-gpu