Update app.py
app.py CHANGED
@@ -8,6 +8,8 @@ import gradio as gr
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai
import logging
+from PyPDF2 import PdfReader
+import io

# Set up logging
logging.basicConfig(level=logging.INFO)
@@ -64,8 +66,8 @@ except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

-# Generate SciBERT embeddings
-def generate_embeddings_sci_bert(texts, batch_size=
+# Generate SciBERT embeddings (optimized with larger batch size)
+def generate_embeddings_sci_bert(texts, batch_size=64):  # Increased batch size for efficiency
    try:
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
@@ -94,7 +96,7 @@ except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise

-# Hybrid search function (
+# Hybrid search function (unchanged from original)
def get_relevant_papers(query):
    if not query.strip():
        return [], "Please enter a search query."
@@ -106,73 +108,127 @@ def get_relevant_papers(query):
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
        combined_indices = list(set(indices[0]) | set(bm25_top_indices))
        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
-        # Return formatted strings for dropdown and indices for full data
        papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
        return papers, ranked_results[:5], "Search completed."
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."

-#
-def answer_question(selected_index, question, history):
+# Process uploaded PDF for RAG
+def process_uploaded_pdf(file):
+    try:
+        pdf_reader = PdfReader(file)
+        text = ""
+        for page in pdf_reader.pages:
+            text += page.extract_text() or ""
+        cleaned_text = clean_text(text)
+        chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)]  # Chunk for efficiency
+        embeddings = generate_embeddings_sci_bert(chunks)
+        faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
+        faiss_index.add(embeddings.astype(np.float32))
+        tokenized_chunks = [chunk.split() for chunk in chunks]
+        bm25_rag = BM25Okapi(tokenized_chunks)
+        return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
+    except Exception as e:
+        logger.error(f"PDF processing failed: {e}")
+        return None, "Failed to process document"
+
+# Hybrid search for RAG
+def get_relevant_chunks(query, uploaded_doc):
+    if not query.strip():
+        return [], "Please enter a question."
+    try:
+        query_embedding = generate_embeddings_sci_bert([query])
+        distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
+        bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
+        combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
+        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+        return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
+    except Exception as e:
+        logger.error(f"RAG retrieval failed: {e}")
+        return [], "Retrieval failed."
+
+# Unified QA function
+def answer_question(mode, selected_index, question, history, uploaded_doc=None):
    if not question.strip():
        return [(question, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
-        return [("Conversation ended.", "
+        return [("Conversation ended.", "Start a new conversation!")], []

    try:
+        if mode == "research":
+            if selected_index is None:
+                return [(question, "Please select a paper first!")], history
+            paper_data = df.iloc[selected_index]
+            title = paper_data["title"]
+            abstract = paper_data["abstract"]
+            authors = ", ".join(paper_data["authors"])
+            doi = paper_data["doi"]
+            prompt = (
+                "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
+                "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
+                "When asked about tech stacks or methods, follow these guidelines:\n"
+                "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
+                "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
+                "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
+                "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
+                "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
+                "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
+                "Here’s the paper:\n"
+                f"Title: {title}\n"
+                f"Authors: {authors}\n"
+                f"Abstract: {abstract}\n"
+                f"DOI: {doi}\n\n"
+            )
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Now, answer this question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()
+            if not answer or len(answer) < 15:
+                answer = (
+                    "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
+                    "- Python: Core language for ML/DL.\n"
+                    "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
+                    "- Scikit-learn: For traditional ML algorithms.\n"
+                    "- Pandas/NumPy: For data handling and preprocessing."
+                )
+
+        elif mode == "rag":
+            if uploaded_doc is None:
+                return [(question, "Please upload a document first!")], history
+            relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
+            context = "\n".join(relevant_chunks)
+            prompt = (
+                "You are an expert AI assistant specializing in answering questions based on uploaded documents. "
+                "Provide concise, accurate answers based on the following document content:\n"
+                f"Content: {context}\n\n"
            )
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Now, answer this question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()
+
+        else:  # general mode
+            prompt = (
+                "You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
+            )
+            if history:
+                prompt += "Previous conversation (use for context):\n"
+                for user_q, bot_a in history[-2:]:
+                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+            prompt += f"Question: {question}"
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content(prompt)
+            answer = response.text.strip()

        history.append((question, answer))
        return history, history
    except Exception as e:
@@ -183,70 +239,106 @@ def answer_question(selected_index, question, history):
# Gradio UI
with gr.Blocks(
    css="""
-    .chatbot {height:
-    .sidebar {width:
-    #main {display: flex; flex-direction: row;}
+    .chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
+    .sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
+    #main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
+    .tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
+    .gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
+    .gr-button:hover {background: #0056b3;}
+    h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
    """,
-    theme=gr.themes.
+    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
) as demo:
-    gr.Markdown("# ResearchGPT
+    gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
+
    with gr.Row(elem_id="main"):
-        # Sidebar
-        with gr.Column(scale=1, min_width=
-        gr.
+        # Sidebar
+        with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
+            mode_tabs = gr.Tabs()
+            with mode_tabs:
+                # Research Mode (unchanged backend)
+                with gr.TabItem("Research Mode"):
+                    gr.Markdown("### Search Papers")
+                    query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
+                    search_btn = gr.Button("Search")
+                    paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
+                    search_status = gr.Textbox(label="Search Status", interactive=False)
+                    paper_choices_state = gr.State([])
+                    paper_indices_state = gr.State([])
+
-            search_btn.click(
-                fn=get_relevant_papers,
-                inputs=query_input,
-                outputs=[paper_choices_state, paper_indices_state, search_status]
-            ).then(
-                fn=lambda choices: gr.update(choices=choices, value=None),
-                inputs=paper_choices_state,
-                outputs=paper_dropdown
-            )
+                    search_btn.click(
+                        fn=get_relevant_papers,
+                        inputs=query_input,
+                        outputs=[paper_choices_state, paper_indices_state, search_status]
+                    ).then(
+                        fn=lambda choices: gr.update(choices=choices, value=None),
+                        inputs=paper_choices_state,
+                        outputs=paper_dropdown
+                    )
+
+                # RAG Mode
+                with gr.TabItem("RAG Mode"):
+                    gr.Markdown("### Upload Document")
+                    file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
+                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
+                    uploaded_doc_state = gr.State(None)
+                    file_upload.change(
+                        fn=process_uploaded_pdf,
+                        inputs=file_upload,
+                        outputs=[uploaded_doc_state, upload_status]
+                    )
+
+                # General Mode
+                with gr.TabItem("General Chat"):
+                    gr.Markdown("Ask anything, powered by Gemini!")

        # Main chat area
-        with gr.Column(scale=3):
-            gr.Markdown("### Chat
+        with gr.Column(scale=3, elem_classes="tab-content"):
+            gr.Markdown("### Chat Area")
+            selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")

-            # State to store conversation history and selected index
            history_state = gr.State([])
            selected_index_state = gr.State(None)

+            def update_display(mode, choice, indices, uploaded_doc):
+                if mode == "research" and choice:
+                    index = int(choice.split(".")[0]) - 1
+                    selected_idx = indices[index]
+                    paper = df.iloc[selected_idx]
+                    return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
+                elif mode == "rag" and uploaded_doc:
+                    return "Uploaded Document Ready", None
+                elif mode == "general":
+                    return "General Chat Mode", None
+                return "Select a mode to begin!", None

-            fn=
-            inputs=
-            outputs=
+            mode_tabs.select(
+                fn=lambda tab: ("research" if tab == "Research Mode" else "rag" if tab == "RAG Mode" else "general"),
+                inputs=None,
+                outputs=None,
+                _js="tab => tab"
+            ).then(
+                fn=update_display,
+                inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
+                outputs=[selected_display, selected_index_state]
            ).then(
                fn=lambda: [],
                inputs=None,
-                outputs=chatbot
+                outputs=[chatbot, history_state]
            )
+
+            paper_dropdown.change(
+                fn=update_display,
+                inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
+                outputs=[selected_display, selected_index_state]
+            )
+
            chat_btn.click(
                fn=answer_question,
-                inputs=[selected_index_state, question_input, history_state],
+                inputs=[mode_tabs, selected_index_state, question_input, history_state, uploaded_doc_state],
                outputs=[chatbot, history_state]
            ).then(
                fn=lambda: "",
|
|
8 |
from transformers import AutoTokenizer, AutoModel
|
9 |
import google.generativeai as genai
|
10 |
import logging
|
11 |
+
from PyPDF2 import PdfReader
|
12 |
+
import io
|
13 |
|
14 |
# Set up logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
|
|
66 |
logger.error(f"Model loading failed: {e}")
|
67 |
raise
|
68 |
|
69 |
+
# Generate SciBERT embeddings (optimized with larger batch size)
|
70 |
+
def generate_embeddings_sci_bert(texts, batch_size=64): # Increased batch size for efficiency
|
71 |
try:
|
72 |
all_embeddings = []
|
73 |
for i in range(0, len(texts), batch_size):
|
|
|
96 |
logger.error(f"FAISS index creation failed: {e}")
|
97 |
raise
|
98 |
|
99 |
+
# Hybrid search function (unchanged from original)
|
100 |
def get_relevant_papers(query):
|
101 |
if not query.strip():
|
102 |
return [], "Please enter a search query."
|
|
|
108 |
bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
|
109 |
combined_indices = list(set(indices[0]) | set(bm25_top_indices))
|
110 |
ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
|
|
|
111 |
papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
|
112 |
return papers, ranked_results[:5], "Search completed."
|
113 |
except Exception as e:
|
114 |
logger.error(f"Search failed: {e}")
|
115 |
return [], [], "Search failed. Please try again."
|
116 |
|
117 |
+
# Process uploaded PDF for RAG
|
118 |
+
def process_uploaded_pdf(file):
|
119 |
+
try:
|
120 |
+
pdf_reader = PdfReader(file)
|
121 |
+
text = ""
|
122 |
+
for page in pdf_reader.pages:
|
123 |
+
text += page.extract_text() or ""
|
124 |
+
cleaned_text = clean_text(text)
|
125 |
+
chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)] # Chunk for efficiency
|
126 |
+
embeddings = generate_embeddings_sci_bert(chunks)
|
127 |
+
faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
|
128 |
+
faiss_index.add(embeddings.astype(np.float32))
|
129 |
+
tokenized_chunks = [chunk.split() for chunk in chunks]
|
130 |
+
bm25_rag = BM25Okapi(tokenized_chunks)
|
131 |
+
return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
|
132 |
+
except Exception as e:
|
133 |
+
logger.error(f"PDF processing failed: {e}")
|
134 |
+
return None, "Failed to process document"
|
135 |
+
|
136 |
+
# Hybrid search for RAG
|
137 |
+
def get_relevant_chunks(query, uploaded_doc):
|
138 |
+
if not query.strip():
|
139 |
+
return [], "Please enter a question."
|
140 |
+
try:
|
141 |
+
query_embedding = generate_embeddings_sci_bert([query])
|
142 |
+
distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
|
143 |
+
bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
|
144 |
+
combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
|
145 |
+
ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
|
146 |
+
return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
|
147 |
+
except Exception as e:
|
148 |
+
logger.error(f"RAG retrieval failed: {e}")
|
149 |
+
return [], "Retrieval failed."
|
150 |
+
|
151 |
+
# Unified QA function
|
152 |
+
def answer_question(mode, selected_index, question, history, uploaded_doc=None):
|
153 |
if not question.strip():
|
154 |
return [(question, "Please ask a question!")], history
|
155 |
if question.lower() in ["exit", "done"]:
|
156 |
+
return [("Conversation ended.", "Start a new conversation!")], []
|
157 |
|
158 |
try:
|
159 |
+
if mode == "research":
|
160 |
+
if selected_index is None:
|
161 |
+
return [(question, "Please select a paper first!")], history
|
162 |
+
paper_data = df.iloc[selected_index]
|
163 |
+
title = paper_data["title"]
|
164 |
+
abstract = paper_data["abstract"]
|
165 |
+
authors = ", ".join(paper_data["authors"])
|
166 |
+
doi = paper_data["doi"]
|
167 |
+
prompt = (
|
168 |
+
"You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
|
169 |
+
"Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
|
170 |
+
"When asked about tech stacks or methods, follow these guidelines:\n"
|
171 |
+
"1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
|
172 |
+
"2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
|
173 |
+
"3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
|
174 |
+
"4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
|
175 |
+
"5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
|
176 |
+
"6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
|
177 |
+
"Here’s the paper:\n"
|
178 |
+
f"Title: {title}\n"
|
179 |
+
f"Authors: {authors}\n"
|
180 |
+
f"Abstract: {abstract}\n"
|
181 |
+
f"DOI: {doi}\n\n"
|
182 |
+
)
|
183 |
+
if history:
|
184 |
+
prompt += "Previous conversation (use for context):\n"
|
185 |
+
for user_q, bot_a in history[-2:]:
|
186 |
+
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
|
187 |
+
prompt += f"Now, answer this question: {question}"
|
188 |
+
model = genai.GenerativeModel("gemini-1.5-flash")
|
189 |
+
response = model.generate_content(prompt)
|
190 |
+
answer = response.text.strip()
|
191 |
+
if not answer or len(answer) < 15:
|
192 |
+
answer = (
|
193 |
+
"The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
|
194 |
+
"- Python: Core language for ML/DL.\n"
|
195 |
+
"- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
|
196 |
+
"- Scikit-learn: For traditional ML algorithms.\n"
|
197 |
+
"- Pandas/NumPy: For data handling and preprocessing."
|
198 |
+
)
|
199 |
+
|
200 |
+
elif mode == "rag":
|
201 |
+
if uploaded_doc is None:
|
202 |
+
return [(question, "Please upload a document first!")], history
|
203 |
+
relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
|
204 |
+
context = "\n".join(relevant_chunks)
|
205 |
+
prompt = (
|
206 |
+
"You are an expert AI assistant specializing in answering questions based on uploaded documents. "
|
207 |
+
"Provide concise, accurate answers based on the following document content:\n"
|
208 |
+
f"Content: {context}\n\n"
|
209 |
)
|
210 |
+
if history:
|
211 |
+
prompt += "Previous conversation (use for context):\n"
|
212 |
+
for user_q, bot_a in history[-2:]:
|
213 |
+
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
|
214 |
+
prompt += f"Now, answer this question: {question}"
|
215 |
+
model = genai.GenerativeModel("gemini-1.5-flash")
|
216 |
+
response = model.generate_content(prompt)
|
217 |
+
answer = response.text.strip()
|
218 |
+
|
219 |
+
else: # general mode
|
220 |
+
prompt = (
|
221 |
+
"You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
|
222 |
+
)
|
223 |
+
if history:
|
224 |
+
prompt += "Previous conversation (use for context):\n"
|
225 |
+
for user_q, bot_a in history[-2:]:
|
226 |
+
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
|
227 |
+
prompt += f"Question: {question}"
|
228 |
+
model = genai.GenerativeModel("gemini-1.5-flash")
|
229 |
+
response = model.generate_content(prompt)
|
230 |
+
answer = response.text.strip()
|
231 |
+
|
232 |
history.append((question, answer))
|
233 |
return history, history
|
234 |
except Exception as e:
|
|
|
239 |
# Gradio UI
|
240 |
with gr.Blocks(
|
241 |
css="""
|
242 |
+
.chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
|
243 |
+
.sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
|
244 |
+
#main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
|
245 |
+
.tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
|
246 |
+
.gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
|
247 |
+
.gr-button:hover {background: #0056b3;}
|
248 |
+
h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
|
249 |
""",
|
250 |
+
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
|
251 |
) as demo:
|
252 |
+
gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
|
253 |
+
|
254 |
with gr.Row(elem_id="main"):
|
255 |
+
# Sidebar
|
256 |
+
with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
|
257 |
+
mode_tabs = gr.Tabs()
|
258 |
+
with mode_tabs:
|
259 |
+
# Research Mode (unchanged backend)
|
260 |
+
with gr.TabItem("Research Mode"):
|
261 |
+
gr.Markdown("### Search Papers")
|
262 |
+
query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
|
263 |
+
search_btn = gr.Button("Search")
|
264 |
+
paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
|
265 |
+
search_status = gr.Textbox(label="Search Status", interactive=False)
|
266 |
+
paper_choices_state = gr.State([])
|
267 |
+
paper_indices_state = gr.State([])
|
268 |
+
|
269 |
+
search_btn.click(
|
270 |
+
fn=get_relevant_papers,
|
271 |
+
inputs=query_input,
|
272 |
+
outputs=[paper_choices_state, paper_indices_state, search_status]
|
273 |
+
).then(
|
274 |
+
fn=lambda choices: gr.update(choices=choices, value=None),
|
275 |
+
inputs=paper_choices_state,
|
276 |
+
outputs=paper_dropdown
|
277 |
+
)
|
278 |
+
|
279 |
+
# RAG Mode
|
280 |
+
with gr.TabItem("RAG Mode"):
|
281 |
+
gr.Markdown("### Upload Document")
|
282 |
+
file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
|
283 |
+
upload_status = gr.Textbox(label="Upload Status", interactive=False)
|
284 |
+
uploaded_doc_state = gr.State(None)
|
285 |
+
file_upload.change(
|
286 |
+
fn=process_uploaded_pdf,
|
287 |
+
inputs=file_upload,
|
288 |
+
outputs=[uploaded_doc_state, upload_status]
|
289 |
+
)
|
290 |
+
|
291 |
+
# General Mode
|
292 |
+
with gr.TabItem("General Chat"):
|
293 |
+
gr.Markdown("Ask anything, powered by Gemini!")
|
294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
# Main chat area
|
296 |
+
with gr.Column(scale=3, elem_classes="tab-content"):
|
297 |
+
gr.Markdown("### Chat Area")
|
298 |
+
selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
|
299 |
chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
|
300 |
question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
|
301 |
chat_btn = gr.Button("Send")
|
302 |
|
|
|
303 |
history_state = gr.State([])
|
304 |
selected_index_state = gr.State(None)
|
305 |
|
306 |
+
def update_display(mode, choice, indices, uploaded_doc):
|
307 |
+
if mode == "research" and choice:
|
308 |
+
index = int(choice.split(".")[0]) - 1
|
309 |
+
selected_idx = indices[index]
|
310 |
+
paper = df.iloc[selected_idx]
|
311 |
+
return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
|
312 |
+
elif mode == "rag" and uploaded_doc:
|
313 |
+
return "Uploaded Document Ready", None
|
314 |
+
elif mode == "general":
|
315 |
+
return "General Chat Mode", None
|
316 |
+
return "Select a mode to begin!", None
|
317 |
|
318 |
+
mode_tabs.select(
|
319 |
+
fn=lambda tab: ("research" if tab == "Research Mode" else "rag" if tab == "RAG Mode" else "general"),
|
320 |
+
inputs=None,
|
321 |
+
outputs=None,
|
322 |
+
_js="tab => tab"
|
323 |
+
).then(
|
324 |
+
fn=update_display,
|
325 |
+
inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
|
326 |
+
outputs=[selected_display, selected_index_state]
|
327 |
).then(
|
328 |
fn=lambda: [],
|
329 |
inputs=None,
|
330 |
+
outputs=[chatbot, history_state]
|
331 |
)
|
332 |
+
|
333 |
+
paper_dropdown.change(
|
334 |
+
fn=update_display,
|
335 |
+
inputs=[mode_tabs, paper_dropdown, paper_indices_state, uploaded_doc_state],
|
336 |
+
outputs=[selected_display, selected_index_state]
|
337 |
+
)
|
338 |
+
|
339 |
chat_btn.click(
|
340 |
fn=answer_question,
|
341 |
+
inputs=[mode_tabs, selected_index_state, question_input, history_state, uploaded_doc_state],
|
342 |
outputs=[chatbot, history_state]
|
343 |
).then(
|
344 |
fn=lambda: "",
|