ADKU committed (verified)
Commit d8a8174 · 1 parent: cf9155f

Update app.py

Files changed (1)
  app.py  +108 -200
app.py CHANGED
@@ -8,8 +8,6 @@ import gradio as gr
  from transformers import AutoTokenizer, AutoModel
  import google.generativeai as genai
  import logging
- from PyPDF2 import PdfReader
- import io
 
  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -66,8 +64,8 @@ except Exception as e:
  logger.error(f"Model loading failed: {e}")
  raise
 
- # Generate SciBERT embeddings (optimized with larger batch size)
- def generate_embeddings_sci_bert(texts, batch_size=64):
+ # Generate SciBERT embeddings
+ def generate_embeddings_sci_bert(texts, batch_size=32):
  try:
  all_embeddings = []
  for i in range(0, len(texts), batch_size):
@@ -96,7 +94,7 @@ except Exception as e:
  logger.error(f"FAISS index creation failed: {e}")
  raise
 
- # Hybrid search function (unchanged from original)
+ # Hybrid search function (return indices instead of truncated strings)
  def get_relevant_papers(query):
  if not query.strip():
  return [], "Please enter a search query."
@@ -108,237 +106,147 @@ def get_relevant_papers(query):
  bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
  combined_indices = list(set(indices[0]) | set(bm25_top_indices))
  ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+ # Return formatted strings for dropdown and indices for full data
  papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
  return papers, ranked_results[:5], "Search completed."
  except Exception as e:
  logger.error(f"Search failed: {e}")
  return [], [], "Search failed. Please try again."
 
- # Process uploaded PDF for RAG
- def process_uploaded_pdf(file):
- try:
- pdf_reader = PdfReader(file)
- text = ""
- for page in pdf_reader.pages:
- text += page.extract_text() or ""
- cleaned_text = clean_text(text)
- chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)]
- embeddings = generate_embeddings_sci_bert(chunks)
- faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
- faiss_index.add(embeddings.astype(np.float32))
- tokenized_chunks = [chunk.split() for chunk in chunks]
- bm25_rag = BM25Okapi(tokenized_chunks)
- return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
- except Exception as e:
- logger.error(f"PDF processing failed: {e}")
- return None, "Failed to process document"
-
- # Hybrid search for RAG
- def get_relevant_chunks(query, uploaded_doc):
- if not query.strip():
- return [], "Please enter a question."
- try:
- query_embedding = generate_embeddings_sci_bert([query])
- distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
- bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
- combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
- ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
- return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
- except Exception as e:
- logger.error(f"RAG retrieval failed: {e}")
- return [], "Retrieval failed."
-
- # Unified QA function (updated for messages format)
- def answer_question(mode, selected_index, question, history, uploaded_doc=None):
+ # Gemini API QA function with full context
+ def answer_question(selected_index, question, history):
+ if selected_index is None:
+ return [(question, "Please select a paper first!")], history
  if not question.strip():
- return history + [{"role": "user", "content": question}, {"role": "assistant", "content": "Please ask a question!"}], history
+ return [(question, "Please ask a question!")], history
  if question.lower() in ["exit", "done"]:
- return history + [{"role": "user", "content": "Conversation ended."}, {"role": "assistant", "content": "Start a new conversation!"}], []
+ return [("Conversation ended.", "Select a new paper or search again!")], []
 
  try:
- if mode == "research":
- if selected_index is None:
- return history + [{"role": "user", "content": question}, {"role": "assistant", "content": "Please select a paper first!"}], history
- paper_data = df.iloc[selected_index]
- title = paper_data["title"]
- abstract = paper_data["abstract"]
- authors = ", ".join(paper_data["authors"])
- doi = paper_data["doi"]
- prompt = (
- "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
- "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
- "When asked about tech stacks or methods, follow these guidelines:\n"
- "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
- "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
- "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
- "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
- "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
- "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
- "Here’s the paper:\n"
- f"Title: {title}\n"
- f"Authors: {authors}\n"
- f"Abstract: {abstract}\n"
- f"DOI: {doi}\n\n"
- )
- if history:
- prompt += "Previous conversation (use for context):\n"
- for msg in history[-2:]:
- prompt += f"User: {msg['content']}\n" if msg["role"] == "user" else f"Assistant: {msg['content']}\n"
- prompt += f"Now, answer this question: {question}"
- model = genai.GenerativeModel("gemini-1.5-flash")
- response = model.generate_content(prompt)
- answer = response.text.strip()
- if not answer or len(answer) < 15:
- answer = (
- "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
- "- Python: Core language for ML/DL.\n"
- "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
- "- Scikit-learn: For traditional ML algorithms.\n"
- "- Pandas/NumPy: For data handling and preprocessing."
- )
-
- elif mode == "rag":
- if uploaded_doc is None:
- return history + [{"role": "user", "content": question}, {"role": "assistant", "content": "Please upload a document first!"}], history
- relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
- context = "\n".join(relevant_chunks)
- prompt = (
- "You are an expert AI assistant specializing in answering questions based on uploaded documents. "
- "Provide concise, accurate answers based on the following document content:\n"
- f"Content: {context}\n\n"
+ # Get full paper data from DataFrame using index
+ paper_data = df.iloc[selected_index]
+ title = paper_data["title"]
+ abstract = paper_data["abstract"] # Full abstract, not truncated
+ authors = ", ".join(paper_data["authors"])
+ doi = paper_data["doi"]
+
+ # Build prompt with all fields
+ prompt = (
+ "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
+ "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
+ "When asked about tech stacks or methods, follow these guidelines:\n"
+ "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
+ "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
+ "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
+ "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
+ "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
+ "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
+ "Here’s the paper:\n"
+ f"Title: {title}\n"
+ f"Authors: {authors}\n"
+ f"Abstract: {abstract}\n"
+ f"DOI: {doi}\n\n"
+ )
+
+ # Add history if present
+ if history:
+ prompt += "Previous conversation (use for context):\n"
+ for user_q, bot_a in history[-2:]:
+ prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+
+ prompt += f"Now, answer this question: {question}"
+
+ logger.info(f"Prompt sent to Gemini API: {prompt[:200]}...")
+
+ # Call Gemini API (Gemini 1.5 Flash)
+ model = genai.GenerativeModel("gemini-1.5-flash")
+ response = model.generate_content(prompt)
+ answer = response.text.strip()
+
+ # Fallback for poor responses
+ if not answer or len(answer) < 15:
+ answer = (
+ "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
+ "- Python: Core language for ML/DL.\n"
+ "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
+ "- Scikit-learn: For traditional ML algorithms.\n"
+ "- Pandas/NumPy: For data handling and preprocessing."
  )
- if history:
- prompt += "Previous conversation (use for context):\n"
- for msg in history[-2:]:
- prompt += f"User: {msg['content']}\n" if msg["role"] == "user" else f"Assistant: {msg['content']}\n"
- prompt += f"Now, answer this question: {question}"
- model = genai.GenerativeModel("gemini-1.5-flash")
- response = model.generate_content(prompt)
- answer = response.text.strip()
-
- else: # general mode
- prompt = (
- "You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
- )
- if history:
- prompt += "Previous conversation (use for context):\n"
- for msg in history[-2:]:
- prompt += f"User: {msg['content']}\n" if msg["role"] == "user" else f"Assistant: {msg['content']}\n"
- prompt += f"Question: {question}"
- model = genai.GenerativeModel("gemini-1.5-flash")
- response = model.generate_content(prompt)
- answer = response.text.strip()
-
- history.append({"role": "user", "content": question})
- history.append({"role": "assistant", "content": answer})
+
+ history.append((question, answer))
  return history, history
  except Exception as e:
  logger.error(f"QA failed: {e}")
- history.append({"role": "user", "content": question})
- history.append({"role": "assistant", "content": "Sorry, I couldn’t process that. Try again!"})
+ history.append((question, "Sorry, I couldn’t process that. Try again!"))
  return history, history
 
  # Gradio UI
  with gr.Blocks(
  css="""
- .chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
- .sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
- #main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
- .tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
- .gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
- .gr-button:hover {background: #0056b3;}
- h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
+ .chatbot {height: 600px; overflow-y: auto;}
+ .sidebar {width: 300px;}
+ #main {display: flex; flex-direction: row;}
  """,
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
+ theme=gr.themes.Default(primary_hue="blue")
  ) as demo:
- gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
-
+ gr.Markdown("# ResearchGPT - Paper Search & Chat")
  with gr.Row(elem_id="main"):
- # Sidebar
- with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
- mode_tabs = gr.Tabs()
- with mode_tabs:
- # Research Mode
- with gr.TabItem("Research Mode"):
- gr.Markdown("### Search Papers")
- query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
- search_btn = gr.Button("Search")
- paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
- search_status = gr.Textbox(label="Search Status", interactive=False)
- paper_choices_state = gr.State([])
- paper_indices_state = gr.State([])
-
- search_btn.click(
- fn=get_relevant_papers,
- inputs=query_input,
- outputs=[paper_choices_state, paper_indices_state, search_status]
- ).then(
- fn=lambda choices: gr.update(choices=choices, value=None),
- inputs=paper_choices_state,
- outputs=paper_dropdown
- )
-
- # RAG Mode
- with gr.TabItem("RAG Mode"):
- gr.Markdown("### Upload Document")
- file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
- upload_status = gr.Textbox(label="Upload Status", interactive=False)
- uploaded_doc_state = gr.State(None)
- file_upload.change(
- fn=process_uploaded_pdf,
- inputs=file_upload,
- outputs=[uploaded_doc_state, upload_status]
- )
-
- # General Mode
- with gr.TabItem("General Chat"):
- gr.Markdown("Ask anything, powered by Gemini!")
+ # Sidebar for search
+ with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
+ gr.Markdown("### Search Papers")
+ query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
+ search_btn = gr.Button("Search")
+ paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
+ search_status = gr.Textbox(label="Search Status", interactive=False)
+
+ # States to store paper choices and indices
+ paper_choices_state = gr.State([])
+ paper_indices_state = gr.State([])
 
+ search_btn.click(
+ fn=get_relevant_papers,
+ inputs=query_input,
+ outputs=[paper_choices_state, paper_indices_state, search_status]
+ ).then(
+ fn=lambda choices: gr.update(choices=choices, value=None),
+ inputs=paper_choices_state,
+ outputs=paper_dropdown
+ )
+
  # Main chat area
- with gr.Column(scale=3, elem_classes="tab-content"):
- gr.Markdown("### Chat Area")
- selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
- chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot", type="messages") # Updated to messages format
+ with gr.Column(scale=3):
+ gr.Markdown("### Chat with Selected Paper")
+ selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
+ chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
  question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
  chat_btn = gr.Button("Send")
 
+ # State to store conversation history and selected index
  history_state = gr.State([])
  selected_index_state = gr.State(None)
 
- def update_display(selected_tab, choice, indices, uploaded_doc):
- if selected_tab == "Research Mode" and choice:
- index = int(choice.split(".")[0]) - 1
- selected_idx = indices[index]
- paper = df.iloc[selected_idx]
- return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
- elif selected_tab == "RAG Mode" and uploaded_doc:
- return "Uploaded Document Ready", None
- elif selected_tab == "General Chat":
- return "General Chat Mode", None
- return "Select a mode to begin!", None
+ # Update selected paper and index
+ def update_selected_paper(choice, indices):
+ if choice is None:
+ return "", None
+ index = int(choice.split(".")[0]) - 1 # Extract rank (e.g., "1." -> 0)
+ selected_idx = indices[index]
+ return choice, selected_idx
 
- mode_tabs.select(
- fn=lambda selected_tab: update_display(selected_tab, paper_dropdown.value, paper_indices_state.value, uploaded_doc_state.value),
- inputs=[mode_tabs],
- outputs=[selected_display, selected_index_state]
+ paper_dropdown.change(
+ fn=update_selected_paper,
+ inputs=[paper_dropdown, paper_indices_state],
+ outputs=[selected_paper, selected_index_state]
  ).then(
  fn=lambda: [],
  inputs=None,
- outputs=[chatbot, history_state]
- )
-
- paper_dropdown.change(
- fn=update_display,
- inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
- outputs=[selected_display, selected_index_state]
+ outputs=chatbot
  )
-
+
+ # Handle chat
  chat_btn.click(
- fn=lambda mode, idx, q, hist, doc: answer_question(
- "research" if mode == "Research Mode" else "rag" if mode == "RAG Mode" else "general",
- idx, q, hist, doc
- ),
- inputs=[mode_tabs, selected_index_state, question_input, history_state, uploaded_doc_state],
+ fn=answer_question,
+ inputs=[selected_index_state, question_input, history_state],
  outputs=[chatbot, history_state]
  ).then(
  fn=lambda: "",