import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai
import logging
from PyPDF2 import PdfReader
import io
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set cache directory for Hugging Face models (SciBERT only)
os.environ["HF_HOME"] = "/tmp/huggingface"
# Get Gemini API key from environment variable (stored in Spaces secrets)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
raise ValueError("GEMINI_API_KEY is required for Gemini API access.")
genai.configure(api_key=GEMINI_API_KEY)
logger.info("Gemini API configured")
# Load dataset with error handling
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
if not os.path.exists(DATASET_PATH):
raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
df = pd.read_json(DATASET_PATH)
logger.info("Dataset loaded successfully")
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
raise
# Clean text
def clean_text(text):
return text.strip().lower() if isinstance(text, str) else ""
df["cleaned_abstract"] = df["abstract"].apply(clean_text)
# Precompute BM25 Index
try:
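    # BM25 operates on token lists: whitespace-split each cleaned abstract and index the whole corpus.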
tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
bm25 = BM25Okapi(tokenized_corpus)
logger.info("BM25 index created")
except Exception as e:
logger.error(f"BM25 index creation failed: {e}")
raise
# Load SciBERT for embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
try:
sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model.to(device)
sci_bert_model.eval()
logger.info("SciBERT loaded")
except Exception as e:
logger.error(f"Model loading failed: {e}")
raise
# Generate SciBERT embeddings (optimized with larger batch size)
def generate_embeddings_sci_bert(texts, batch_size=64): # Increased batch size for efficiency
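    # Each text is encoded in batches and mean-pooled over SciBERT's last hidden state, giving one 768-dim vector per text.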
try:
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
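            # Pad/truncate each batch; SciBERT (BERT-base) accepts at most 512 tokens per input.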
inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = {key: val.to(device) for key, val in inputs.items()}
with torch.no_grad():
outputs = sci_bert_model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
all_embeddings.append(embeddings.cpu().numpy())
torch.cuda.empty_cache()
return np.concatenate(all_embeddings, axis=0)
except Exception as e:
logger.error(f"Embedding generation failed: {e}")
return np.zeros((len(texts), 768))
# Precompute embeddings and FAISS index
try:
abstracts = df["cleaned_abstract"].tolist()
embeddings = generate_embeddings_sci_bert(abstracts)
dimension = embeddings.shape[1]
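    # IndexFlatL2 is an exact, brute-force L2 index: no training step, and fast enough at this corpus size.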
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings.astype(np.float32))
logger.info("FAISS index created")
except Exception as e:
logger.error(f"FAISS index creation failed: {e}")
raise
# Hybrid search function (unchanged from original)
def get_relevant_papers(query):
    if not query.strip():
        return [], [], "Please enter a search query."  # match the three outputs wired to this function
try:
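        # Hybrid retrieval: top-5 dense neighbours from FAISS plus top-5 sparse matches from BM25, union the candidates, then order them by BM25 score.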
query_embedding = generate_embeddings_sci_bert([query])
distances, indices = faiss_index.search(query_embedding.astype(np.float32), 5)
tokenized_query = query.split()
bm25_scores = bm25.get_scores(tokenized_query)
bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
combined_indices = list(set(indices[0]) | set(bm25_top_indices))
ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
return papers, ranked_results[:5], "Search completed."
except Exception as e:
logger.error(f"Search failed: {e}")
return [], [], "Search failed. Please try again."
# Process uploaded PDF for RAG
def process_uploaded_pdf(file):
try:
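        # Concatenate the text of every page; extract_text() may yield nothing for image-only pages, hence the "or".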
pdf_reader = PdfReader(file)
text = ""
for page in pdf_reader.pages:
text += page.extract_text() or ""
cleaned_text = clean_text(text)
chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)] # Chunk for efficiency
embeddings = generate_embeddings_sci_bert(chunks)
faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
faiss_index.add(embeddings.astype(np.float32))
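        # Build a per-document BM25 index over the same chunks so question-time retrieval can combine sparse and dense signals.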
tokenized_chunks = [chunk.split() for chunk in chunks]
bm25_rag = BM25Okapi(tokenized_chunks)
return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
except Exception as e:
logger.error(f"PDF processing failed: {e}")
return None, "Failed to process document"
# Hybrid search for RAG
def get_relevant_chunks(query, uploaded_doc):
if not query.strip():
return [], "Please enter a question."
try:
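        # Same hybrid scheme as the paper search, but over the uploaded document's chunks: top-3 dense + top-3 sparse, ranked by BM25 score.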
query_embedding = generate_embeddings_sci_bert([query])
distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
except Exception as e:
logger.error(f"RAG retrieval failed: {e}")
return [], "Retrieval failed."
# Unified QA function
def answer_question(mode, selected_index, question, history, uploaded_doc=None):
if not question.strip():
return [(question, "Please ask a question!")], history
if question.lower() in ["exit", "done"]:
return [("Conversation ended.", "Start a new conversation!")], []
try:
if mode == "research":
if selected_index is None:
return [(question, "Please select a paper first!")], history
paper_data = df.iloc[selected_index]
title = paper_data["title"]
abstract = paper_data["abstract"]
authors = ", ".join(paper_data["authors"])
doi = paper_data["doi"]
prompt = (
"You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
"Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
"When asked about tech stacks or methods, follow these guidelines:\n"
"1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
"2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
"3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
"4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
"5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
"6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
"Here’s the paper:\n"
f"Title: {title}\n"
f"Authors: {authors}\n"
f"Abstract: {abstract}\n"
f"DOI: {doi}\n\n"
)
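            # Replay only the last two exchanges so the prompt stays compact while keeping conversational context.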
if history:
prompt += "Previous conversation (use for context):\n"
for user_q, bot_a in history[-2:]:
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
prompt += f"Now, answer this question: {question}"
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content(prompt)
answer = response.text.strip()
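            # Guard against empty or near-empty Gemini responses with a generic, explicitly inferred tech-stack answer.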
if not answer or len(answer) < 15:
answer = (
"The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
"- Python: Core language for ML/DL.\n"
"- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
"- Scikit-learn: For traditional ML algorithms.\n"
"- Pandas/NumPy: For data handling and preprocessing."
)
elif mode == "rag":
if uploaded_doc is None:
return [(question, "Please upload a document first!")], history
relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
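            # Only the retrieved chunks go into the prompt, so the answer stays grounded in the uploaded document.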
context = "\n".join(relevant_chunks)
prompt = (
"You are an expert AI assistant specializing in answering questions based on uploaded documents. "
"Provide concise, accurate answers based on the following document content:\n"
f"Content: {context}\n\n"
)
if history:
prompt += "Previous conversation (use for context):\n"
for user_q, bot_a in history[-2:]:
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
prompt += f"Now, answer this question: {question}"
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content(prompt)
answer = response.text.strip()
else: # general mode
prompt = (
"You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
)
if history:
prompt += "Previous conversation (use for context):\n"
for user_q, bot_a in history[-2:]:
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
prompt += f"Question: {question}"
model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content(prompt)
answer = response.text.strip()
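        # Append the new exchange and return the history twice: once for the Chatbot display, once for the state.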
history.append((question, answer))
return history, history
except Exception as e:
logger.error(f"QA failed: {e}")
history.append((question, "Sorry, I couldn’t process that. Try again!"))
return history, history
# Gradio UI
with gr.Blocks(
css="""
.chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
.sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
#main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
.tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
.gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
.gr-button:hover {background: #0056b3;}
h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
""",
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
) as demo:
gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
with gr.Row(elem_id="main"):
# Sidebar
with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
mode_tabs = gr.Tabs()
with mode_tabs:
# Research Mode (unchanged backend)
with gr.TabItem("Research Mode"):
gr.Markdown("### Search Papers")
query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
search_btn = gr.Button("Search")
paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
search_status = gr.Textbox(label="Search Status", interactive=False)
paper_choices_state = gr.State([])
paper_indices_state = gr.State([])
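                    # paper_choices_state holds the formatted result strings; paper_indices_state holds their row indices in df.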
search_btn.click(
fn=get_relevant_papers,
inputs=query_input,
outputs=[paper_choices_state, paper_indices_state, search_status]
).then(
fn=lambda choices: gr.update(choices=choices, value=None),
inputs=paper_choices_state,
outputs=paper_dropdown
)
# RAG Mode
with gr.TabItem("RAG Mode"):
gr.Markdown("### Upload Document")
file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
upload_status = gr.Textbox(label="Upload Status", interactive=False)
uploaded_doc_state = gr.State(None)
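                    # Holds the dict returned by process_uploaded_pdf (chunks, embeddings, FAISS index, BM25).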
file_upload.change(
fn=process_uploaded_pdf,
inputs=file_upload,
outputs=[uploaded_doc_state, upload_status]
)
# General Mode
with gr.TabItem("General Chat"):
gr.Markdown("Ask anything, powered by Gemini!")
# Main chat area
with gr.Column(scale=3, elem_classes="tab-content"):
gr.Markdown("### Chat Area")
selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
chat_btn = gr.Button("Send")
history_state = gr.State([])
selected_index_state = gr.State(None)
def update_display(mode, choice, indices, uploaded_doc):
if mode == "research" and choice:
index = int(choice.split(".")[0]) - 1
selected_idx = indices[index]
paper = df.iloc[selected_idx]
return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
elif mode == "rag" and uploaded_doc:
return "Uploaded Document Ready", None
elif mode == "general":
return "General Chat Mode", None
return "Select a mode to begin!", None
    # gr.Tabs cannot be used as an event input, so track the active tab in a State
    # and update it from the Tabs.select event (SelectData carries the tab label).
    mode_state = gr.State("research")  # Research Mode is the first (default) tab

    def on_tab_select(evt: gr.SelectData):
        label = evt.value
        return "research" if label == "Research Mode" else "rag" if label == "RAG Mode" else "general"

    mode_tabs.select(
        fn=on_tab_select,
        inputs=None,
        outputs=mode_state
    ).then(
        fn=update_display,
        inputs=[mode_state, paper_dropdown, paper_indices_state, uploaded_doc_state],
        outputs=[selected_display, selected_index_state]
    ).then(
        fn=lambda: ([], []),  # clear both the Chatbot display and the stored history on mode switch
        inputs=None,
        outputs=[chatbot, history_state]
    )
    paper_dropdown.change(
        fn=update_display,
        inputs=[mode_state, paper_dropdown, paper_indices_state, uploaded_doc_state],
        outputs=[selected_display, selected_index_state]
    )
    chat_btn.click(
        fn=answer_question,
        inputs=[mode_state, selected_index_state, question_input, history_state, uploaded_doc_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: "",  # clear the question box after each send
        inputs=None,
        outputs=question_input
    )
# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860) |