import os
import logging

import faiss
import numpy as np
import pandas as pd
import torch
import gradio as gr
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set cache directory for Hugging Face models (SciBERT only)
os.environ["HF_HOME"] = "/tmp/huggingface"

# Get Gemini API key from environment variable (stored in Spaces secrets)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
    raise ValueError("GEMINI_API_KEY is required for Gemini API access.")
genai.configure(api_key=GEMINI_API_KEY)
logger.info("Gemini API configured")

# Load dataset with error handling
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
    df = pd.read_json(DATASET_PATH)
    logger.info("Dataset loaded successfully")
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise

# Clean text
def clean_text(text):
    return text.strip().lower() if isinstance(text, str) else ""

df["cleaned_abstract"] = df["abstract"].apply(clean_text)

# Precompute BM25 index
try:
    tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
    bm25 = BM25Okapi(tokenized_corpus)
    logger.info("BM25 index created")
except Exception as e:
    logger.error(f"BM25 index creation failed: {e}")
    raise

# Load SciBERT for embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
try:
    sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model.to(device)
    sci_bert_model.eval()
    logger.info("SciBERT loaded")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

# Generate SciBERT embeddings (mean-pooled last hidden states)
def generate_embeddings_sci_bert(texts, batch_size=32):
    try:
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = sci_bert_model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            all_embeddings.append(embeddings.cpu().numpy())
            torch.cuda.empty_cache()
        return np.concatenate(all_embeddings, axis=0)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return np.zeros((len(texts), 768))

# Precompute embeddings and FAISS index
try:
    abstracts = df["cleaned_abstract"].tolist()
    embeddings = generate_embeddings_sci_bert(abstracts)
    dimension = embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings.astype(np.float32))
    logger.info("FAISS index created")
except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise

# Hybrid search function (returns indices instead of truncated strings)
def get_relevant_papers(query):
    if not query.strip():
        # Return three values so the Gradio outputs stay in sync
        return [], [], "Please enter a search query."
    try:
        # Dense retrieval with FAISS over SciBERT embeddings
        query_embedding = generate_embeddings_sci_bert([query])
        distances, indices = faiss_index.search(query_embedding.astype(np.float32), 5)

        # Sparse retrieval with BM25 (query lowercased to match the cleaned corpus)
        tokenized_query = clean_text(query).split()
        bm25_scores = bm25.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]

        # Merge both result sets and rank by BM25 score
        combined_indices = list(set(indices[0]) | set(bm25_top_indices))
        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])

        # Return formatted strings for the dropdown and indices for full data
        papers = [
            f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..."
            for i, idx in enumerate(ranked_results[:5])
        ]
        return papers, ranked_results[:5], "Search completed."
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."

# Gemini API QA function with full context
def answer_question(selected_index, question, history):
    if selected_index is None:
        return [(question, "Please select a paper first!")], history
    if not question.strip():
        return [(question, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
        return [("Conversation ended.", "Select a new paper or search again!")], []

    try:
        # Get full paper data from DataFrame using index
        paper_data = df.iloc[selected_index]
        title = paper_data["title"]
        abstract = paper_data["abstract"]  # Full abstract, not truncated
        authors = ", ".join(paper_data["authors"])
        doi = paper_data["doi"]

        # Build prompt with all fields
        prompt = (
            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
            "When asked about tech stacks or methods, follow these guidelines:\n"
            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
            "Here’s the paper:\n"
            f"Title: {title}\n"
            f"Authors: {authors}\n"
            f"Abstract: {abstract}\n"
            f"DOI: {doi}\n\n"
        )

        # Add history if present
        if history:
            prompt += "Previous conversation (use for context):\n"
            for user_q, bot_a in history[-2:]:
                prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
        prompt += f"Now, answer this question: {question}"
        logger.info(f"Prompt sent to Gemini API: {prompt[:200]}...")

        # Call Gemini API (Gemini 1.5 Flash)
        model = genai.GenerativeModel("gemini-1.5-flash")
        response = model.generate_content(prompt)
        answer = response.text.strip()

        # Fallback for poor responses
        if not answer or len(answer) < 15:
            answer = (
                "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
                "- Python: Core language for ML/DL.\n"
                "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
                "- Scikit-learn: For traditional ML algorithms.\n"
                "- Pandas/NumPy: For data handling and preprocessing."
            )

        history.append((question, answer))
        return history, history
    except Exception as e:
        logger.error(f"QA failed: {e}")
        history.append((question, "Sorry, I couldn’t process that. Try again!"))
        return history, history

# Gradio UI
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
    gr.Markdown("# ResearchGPT - Paper Search & Chat")

    with gr.Row(elem_id="main"):
        # Sidebar for search
        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
            gr.Markdown("### Search Papers")
            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
            search_btn = gr.Button("Search")
            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
            search_status = gr.Textbox(label="Search Status", interactive=False)

            # States to store paper choices and indices
            paper_choices_state = gr.State([])
            paper_indices_state = gr.State([])

            search_btn.click(
                fn=get_relevant_papers,
                inputs=query_input,
                outputs=[paper_choices_state, paper_indices_state, search_status]
            ).then(
                fn=lambda choices: gr.update(choices=choices, value=None),
                inputs=paper_choices_state,
                outputs=paper_dropdown
            )

        # Main chat area
        with gr.Column(scale=3):
            gr.Markdown("### Chat with Selected Paper")
            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")

            # States to store conversation history and selected index
            history_state = gr.State([])
            selected_index_state = gr.State(None)

            # Update selected paper and index
            def update_selected_paper(choice, indices):
                if choice is None:
                    return "", None
                index = int(choice.split(".")[0]) - 1  # Extract rank (e.g., "1." -> 0)
                selected_idx = indices[index]
                return choice, selected_idx

            paper_dropdown.change(
                fn=update_selected_paper,
                inputs=[paper_dropdown, paper_indices_state],
                outputs=[selected_paper, selected_index_state]
            ).then(
                fn=lambda: [],
                inputs=None,
                outputs=chatbot
            )

            # Handle chat
            chat_btn.click(
                fn=answer_question,
                inputs=[selected_index_state, question_input, history_state],
                outputs=[chatbot, history_state]
            ).then(
                fn=lambda: "",
                inputs=None,
                outputs=question_input
            )

# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860)
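
# Optional smoke-test sketch (assumption: run locally with springer_papers_DL.json present
# and GEMINI_API_KEY set). To try the hybrid search without launching the UI, comment out
# demo.launch(...) above and uncomment the lines below; the query string is only an
# illustrative placeholder, not part of the app.
# papers, indices, status = get_relevant_papers("deep learning for crop yield prediction")
# print(status)
# for line in papers:
#     print(line)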