Upload 4 files
- README.md +6 -8
- app.py +199 -48
- gitattributes +35 -0
- requirements.txt +9 -1
README.md
CHANGED
@@ -1,14 +1,12 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: RAG PDF Chatbot
+emoji: 📚
+colorFrom: gray
+colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 4.31.0
 app_file: app.py
 pinned: false
-license: mit
-short_description: This is me...
 ---

-
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -1,64 +1,215 @@
+import os
 import gradio as gr
 from huggingface_hub import InferenceClient
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+
+# --- LangChain / RAG Imports ---
+from langchain_community.vectorstores import FAISS
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ConversationBufferMemory
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import HuggingFaceEndpoint
+
+# Global InferenceClient for plain chat (streaming)
+client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
+
+# ============================================================================
+# PDF Processing & RAG Chain Functions
+# ============================================================================
+
+def load_doc(list_file_path):
+    """Load and split PDF documents into chunks."""
+    loaders = [PyPDFLoader(x) for x in list_file_path]
+    pages = []
+    for loader in loaders:
+        pages.extend(loader.load())
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1024,
+        chunk_overlap=64
+    )
+    doc_splits = text_splitter.split_documents(pages)
+    return doc_splits
+
+def create_db(splits):
+    """Create a vector database from document splits."""
+    # Note: HuggingFaceEmbeddings is deprecated. You may consider using the new package.
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    vectordb = FAISS.from_documents(splits, embeddings)
+    return vectordb
+
+def initialize_database(file_objs):
+    """
+    Process uploaded PDF files, create document splits and a vector database.
+    Expects file objects from gr.Files.
+    """
+    # Each file object's .name attribute holds the file path.
+    list_file_path = [file_obj.name for file_obj in file_objs if file_obj is not None]
+    if not list_file_path:
+        return None, "No files uploaded."
+    doc_splits = load_doc(list_file_path)
+    vector_db = create_db(doc_splits)
+    return vector_db, "Database created!"
+
+def initialize_qa_chain(temperature, max_tokens, top_k, vector_db):
+    """
+    Initialize the retrieval-augmented QA chain using your chat model.
+    An explicit task parameter is passed to avoid the "Task unknown" error.
+    """
+    if vector_db is None:
+        return None, "No vector database available. Please create one first."
+
+    # Explicitly set the task to "text-generation" to avoid the error.
+    llm = HuggingFaceEndpoint(
+        repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        huggingfacehub_api_token=os.getenv("HF_TOKEN"),  # ensure HF_TOKEN is set
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
+        task="text-generation"
+    )
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+
+    retriever = vector_db.as_retriever()
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        chain_type="stuff",
+        memory=memory,
+        return_source_documents=True,
+        verbose=False,
+    )
+    return qa_chain, "QA chain initialized. Chatbot is ready!"
+
+def format_chat_history(history):
+    """Format chat history into a list of strings for the QA chain."""
+    formatted = []
+    for user_msg, bot_msg in history:
+        formatted.append(f"User: {user_msg}")
+        formatted.append(f"Assistant: {bot_msg}")
+    return formatted
+
+# ============================================================================
+# Chat Function that switches between plain chat and RAG mode
+# ============================================================================
+
+def chat_respond(message, history, system_message, max_tokens, temperature, top_p, qa_chain):
+    """
+    If a QA chain (i.e. RAG mode) is initialized, use it to generate a response
+    (including source references). Otherwise, fall back to a plain chat response
+    using a streaming InferenceClient.
+    """
+    # --- QA Chain (RAG) Mode ---
+    if qa_chain is not None:
+        formatted_history = format_chat_history(history)
+        response = qa_chain.invoke({"question": message, "chat_history": formatted_history})
+        answer = response.get("answer", "")
+        # Optionally include up to 3 source references
+        sources = response.get("source_documents", [])
+        ref_text = ""
+        for i, doc in enumerate(sources[:3]):
+            page = doc.metadata.get("page", "?")
+            ref_text += f"\n\nReference {i+1} (Page {int(page)+1 if page != '?' else '?'}):\n{doc.page_content.strip()}"
+        full_answer = answer + ref_text if ref_text else answer
+        history = history + [(message, full_answer)]
+        return history, qa_chain
+
+    # --- Plain Chat Mode (fallback) ---
+    else:
+        messages = [{"role": "system", "content": system_message}]
+        for user_msg, bot_msg in history:
+            if user_msg:
+                messages.append({"role": "user", "content": user_msg})
+            if bot_msg:
+                messages.append({"role": "assistant", "content": bot_msg})
+        messages.append({"role": "user", "content": message})
+
+        response = ""
+        result = client.chat_completion(
+            messages,
+            max_tokens=max_tokens,
+            stream=True,
+            temperature=temperature,
+            top_p=top_p,
+        )
+        for token_message in result:
+            token = token_message.choices[0].delta.content
+            response += token or ""
+
+        history = history + [(message, response)]
+        return history, qa_chain
+
+# ============================================================================
+# Gradio Interface Layout
+# ============================================================================
+
+with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
+
+    # States to hold the vector DB and QA chain
+    vector_db_state = gr.State()
+    qa_chain_state = gr.State(None)
+
+    gr.Markdown("<h1 align='center'>Chat with RAG-enabled PDFs</h1>")
+    gr.Markdown(
+        "Upload PDF files to allow your chatbot to answer questions using information from those documents. "
+        "If no PDFs are uploaded (or the QA chain isn’t initialized), the bot will use plain chat mode."
+    )
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            gr.Markdown("### Step 1: Document Upload & RAG Setup")
+            pdf_files = gr.Files(file_types=[".pdf"], label="Upload PDF documents")
+            db_status = gr.Textbox(label="Database status", interactive=False)
+            qa_status = gr.Textbox(label="QA Chain status", interactive=False)
+            with gr.Row():
+                create_db_btn = gr.Button("Create Vector DB")
+                init_qa_btn = gr.Button("Initialize QA Chain")
+            top_k_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k (for RAG)")
+
+        with gr.Column(scale=6):
+            gr.Markdown("### Step 2: Chat Settings & Conversation")
+            system_message_input = gr.Textbox(value="You are a friendly Chatbot.", label="System Message")
+            max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
+            temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
+            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+            chatbot = gr.Chatbot(label="Chat", elem_id="chatbot", height=400)
+            with gr.Row():
+                user_input = gr.Textbox(placeholder="Enter your message", label="Your Message")
+                send_btn = gr.Button("Send")
+
+    # -------------------------
+    # Set up button events
+    # -------------------------
+
+    # Create the vector database from uploaded PDFs.
+    create_db_btn.click(
+        fn=initialize_database,
+        inputs=[pdf_files],
+        outputs=[vector_db_state, db_status]
+    )
+
+    # Initialize the QA chain (RAG mode) using the vector DB.
+    init_qa_btn.click(
+        fn=initialize_qa_chain,
+        inputs=[temperature_slider, max_tokens_slider, top_k_slider, vector_db_state],
+        outputs=[qa_chain_state, qa_status]
+    )
+
+    # Chat button: process user input. This function checks if qa_chain is set.
+    send_btn.click(
+        fn=chat_respond,
+        inputs=[user_input, chatbot, system_message_input, max_tokens_slider, temperature_slider, top_p_slider, qa_chain_state],
+        outputs=[chatbot, qa_chain_state]
+    ).then(
+        lambda: "",  # clear the user input box after sending
+        None,
+        user_input
+    )
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue().launch()
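The comment inside `create_db()` flags that `HuggingFaceEmbeddings` is deprecated in `langchain_community` without naming the replacement; the same applies to `HuggingFaceEndpoint`. A minimal sketch of the newer imports, assuming the separate `langchain-huggingface` package were added as a dependency (it is not part of this commit):

import os
# Assumed extra dependency: pip install langchain-huggingface
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint

# Same constructor arguments as the deprecated langchain_community classes.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
llm = HuggingFaceEndpoint(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    task="text-generation",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
)

The rest of app.py would be unchanged, since both classes keep the same call signatures.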
gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt
CHANGED
@@ -1 +1,9 @@
-
+torch
+transformers
+sentence-transformers
+langchain
+langchain-community
+tqdm
+accelerate
+pypdf
+faiss-gpu
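One dependency may be worth double-checking: the `faiss-gpu` wheels target CUDA machines, and if this Space runs on the default CPU hardware, `faiss-cpu` is the usual choice. A hedged alternative pin (an assumption, not what this commit ships):

faiss-cpu

The Python API exposed by the `faiss` module is the same for CPU usage in either package, so app.py would not need to change.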