Update app.py

app.py CHANGED
@@ -8,8 +8,6 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModel
 import google.generativeai as genai
 import logging
-from PyPDF2 import PdfReader
-import io
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -66,8 +64,8 @@ except Exception as e:
     logger.error(f"Model loading failed: {e}")
     raise
 
-# Generate SciBERT embeddings
-def generate_embeddings_sci_bert(texts, batch_size=
+# Generate SciBERT embeddings
+def generate_embeddings_sci_bert(texts, batch_size=32):
     try:
         all_embeddings = []
         for i in range(0, len(texts), batch_size):
@@ -96,7 +94,7 @@ except Exception as e:
     logger.error(f"FAISS index creation failed: {e}")
     raise
 
-# Hybrid search function (
+# Hybrid search function (return indices instead of truncated strings)
 def get_relevant_papers(query):
     if not query.strip():
         return [], "Please enter a search query."
@@ -108,237 +106,147 @@ def get_relevant_papers(query):
         bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
         combined_indices = list(set(indices[0]) | set(bm25_top_indices))
         ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+        # Return formatted strings for dropdown and indices for full data
         papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
         return papers, ranked_results[:5], "Search completed."
     except Exception as e:
         logger.error(f"Search failed: {e}")
         return [], [], "Search failed. Please try again."
 
-#
-def
-        text = ""
-        for page in pdf_reader.pages:
-            text += page.extract_text() or ""
-        cleaned_text = clean_text(text)
-        chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)]
-        embeddings = generate_embeddings_sci_bert(chunks)
-        faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
-        faiss_index.add(embeddings.astype(np.float32))
-        tokenized_chunks = [chunk.split() for chunk in chunks]
-        bm25_rag = BM25Okapi(tokenized_chunks)
-        return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
-    except Exception as e:
-        logger.error(f"PDF processing failed: {e}")
-        return None, "Failed to process document"
-
-# Hybrid search for RAG
-def get_relevant_chunks(query, uploaded_doc):
-    if not query.strip():
-        return [], "Please enter a question."
-    try:
-        query_embedding = generate_embeddings_sci_bert([query])
-        distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
-        bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
-        combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
-        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
-        return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
-    except Exception as e:
-        logger.error(f"RAG retrieval failed: {e}")
-        return [], "Retrieval failed."
-
-# Unified QA function (updated for messages format)
-def answer_question(mode, selected_index, question, history, uploaded_doc=None):
+# Gemini API QA function with full context
+def answer_question(selected_index, question, history):
+    if selected_index is None:
+        return [(question, "Please select a paper first!")], history
     if not question.strip():
-        return
+        return [(question, "Please ask a question!")], history
     if question.lower() in ["exit", "done"]:
-        return
+        return [("Conversation ended.", "Select a new paper or search again!")], []
 
     try:
-            "
-            "Provide concise, accurate answers based on the following document content:\n"
-            f"Content: {context}\n\n"
+        # Get full paper data from DataFrame using index
+        paper_data = df.iloc[selected_index]
+        title = paper_data["title"]
+        abstract = paper_data["abstract"] # Full abstract, not truncated
+        authors = ", ".join(paper_data["authors"])
+        doi = paper_data["doi"]
+
+        # Build prompt with all fields
+        prompt = (
+            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
+            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
+            "When asked about tech stacks or methods, follow these guidelines:\n"
+            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
+            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
+            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
+            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
+            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
+            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
+            "Here’s the paper:\n"
+            f"Title: {title}\n"
+            f"Authors: {authors}\n"
+            f"Abstract: {abstract}\n"
+            f"DOI: {doi}\n\n"
         )
-
-
-            for msg in history[-2:]:
-                prompt += f"User: {msg['content']}\n" if msg["role"] == "user" else f"Assistant: {msg['content']}\n"
-            prompt += f"Now, answer this question: {question}"
-            model = genai.GenerativeModel("gemini-1.5-flash")
-            response = model.generate_content(prompt)
-            answer = response.text.strip()
-
-        else: # general mode
-            prompt = (
-                "You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
-            )
-            if history:
-                prompt += "Previous conversation (use for context):\n"
-                for msg in history[-2:]:
-                    prompt += f"User: {msg['content']}\n" if msg["role"] == "user" else f"Assistant: {msg['content']}\n"
-            prompt += f"Question: {question}"
-            model = genai.GenerativeModel("gemini-1.5-flash")
-            response = model.generate_content(prompt)
-            answer = response.text.strip()
-
-        history.append({"role": "user", "content": question})
-        history.append({"role": "assistant", "content": answer})
+
+        # Add history if present
+        if history:
+            prompt += "Previous conversation (use for context):\n"
+            for user_q, bot_a in history[-2:]:
+                prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+
+        prompt += f"Now, answer this question: {question}"
+
+        logger.info(f"Prompt sent to Gemini API: {prompt[:200]}...")
+
+        # Call Gemini API (Gemini 1.5 Flash)
+        model = genai.GenerativeModel("gemini-1.5-flash")
+        response = model.generate_content(prompt)
+        answer = response.text.strip()
+
+        # Fallback for poor responses
+        if not answer or len(answer) < 15:
+            answer = (
+                "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
+                "- Python: Core language for ML/DL.\n"
+                "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
+                "- Scikit-learn: For traditional ML algorithms.\n"
+                "- Pandas/NumPy: For data handling and preprocessing."
+            )
+
+        history.append((question, answer))
         return history, history
     except Exception as e:
         logger.error(f"QA failed: {e}")
-        history.append(
-        history.append({"role": "assistant", "content": "Sorry, I couldn’t process that. Try again!"})
+        history.append((question, "Sorry, I couldn’t process that. Try again!"))
         return history, history
 
 # Gradio UI
 with gr.Blocks(
     css="""
-    .chatbot {height:
-    .sidebar {width:
-    #main {display: flex; flex-direction: row;
-    .tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
-    .gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
-    .gr-button:hover {background: #0056b3;}
-    h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
+    .chatbot {height: 600px; overflow-y: auto;}
+    .sidebar {width: 300px;}
+    #main {display: flex; flex-direction: row;}
     """,
-    theme=gr.themes.
+    theme=gr.themes.Default(primary_hue="blue")
 ) as demo:
-    gr.Markdown("#
-
+    gr.Markdown("# ResearchGPT - Paper Search & Chat")
     with gr.Row(elem_id="main"):
-        # Sidebar
-        with gr.Column(scale=1, min_width=
-            paper_choices_state = gr.State([])
-            paper_indices_state = gr.State([])
-
-            search_btn.click(
-                fn=get_relevant_papers,
-                inputs=query_input,
-                outputs=[paper_choices_state, paper_indices_state, search_status]
-            ).then(
-                fn=lambda choices: gr.update(choices=choices, value=None),
-                inputs=paper_choices_state,
-                outputs=paper_dropdown
-            )
-
-            # RAG Mode
-            with gr.TabItem("RAG Mode"):
-                gr.Markdown("### Upload Document")
-                file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
-                upload_status = gr.Textbox(label="Upload Status", interactive=False)
-                uploaded_doc_state = gr.State(None)
-                file_upload.change(
-                    fn=process_uploaded_pdf,
-                    inputs=file_upload,
-                    outputs=[uploaded_doc_state, upload_status]
-                )
-
-            # General Mode
-            with gr.TabItem("General Chat"):
-                gr.Markdown("Ask anything, powered by Gemini!")
+        # Sidebar for search
+        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
+            gr.Markdown("### Search Papers")
+            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
+            search_btn = gr.Button("Search")
+            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
+            search_status = gr.Textbox(label="Search Status", interactive=False)
+
+            # States to store paper choices and indices
+            paper_choices_state = gr.State([])
+            paper_indices_state = gr.State([])
 
+            search_btn.click(
+                fn=get_relevant_papers,
+                inputs=query_input,
+                outputs=[paper_choices_state, paper_indices_state, search_status]
+            ).then(
+                fn=lambda choices: gr.update(choices=choices, value=None),
+                inputs=paper_choices_state,
+                outputs=paper_dropdown
+            )
+
         # Main chat area
-        with gr.Column(scale=3
-            gr.Markdown("### Chat
-
-            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot"
+        with gr.Column(scale=3):
+            gr.Markdown("### Chat with Selected Paper")
+            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
+            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
             question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
             chat_btn = gr.Button("Send")
 
+            # State to store conversation history and selected index
             history_state = gr.State([])
             selected_index_state = gr.State(None)
 
-                    return "Uploaded Document Ready", None
-                elif selected_tab == "General Chat":
-                    return "General Chat Mode", None
-                return "Select a mode to begin!", None
+            # Update selected paper and index
+            def update_selected_paper(choice, indices):
+                if choice is None:
+                    return "", None
+                index = int(choice.split(".")[0]) - 1 # Extract rank (e.g., "1." -> 0)
+                selected_idx = indices[index]
+                return choice, selected_idx
 
-
-                fn=
-                inputs=[
-                outputs=[
+            paper_dropdown.change(
+                fn=update_selected_paper,
+                inputs=[paper_dropdown, paper_indices_state],
+                outputs=[selected_paper, selected_index_state]
             ).then(
                 fn=lambda: [],
                 inputs=None,
-                outputs=
-            )
-
-            paper_dropdown.change(
-                fn=update_display,
-                inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
-                outputs=[selected_display, selected_index_state]
+                outputs=chatbot
             )
-
+
+            # Handle chat
             chat_btn.click(
-                fn=
-                    idx, q, hist, doc
-                ),
-                inputs=[mode_tabs, selected_index_state, question_input, history_state, uploaded_doc_state],
+                fn=answer_question,
+                inputs=[selected_index_state, question_input, history_state],
                 outputs=[chatbot, history_state]
             ).then(
                 fn=lambda: "",