import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set cache directory for Hugging Face models (SciBERT only)
os.environ["HF_HOME"] = "/tmp/huggingface"

# Get Gemini API key from environment variable (stored in Spaces secrets)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
    raise ValueError("GEMINI_API_KEY is required for Gemini API access.")
genai.configure(api_key=GEMINI_API_KEY)
logger.info("Gemini API configured")

# Load dataset with error handling
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
    df = pd.read_json(DATASET_PATH)
    logger.info("Dataset loaded successfully")
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise
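
# Schema note: the columns accessed below ("title", "abstract", "authors", "doi") are
# assumed to exist in springer_papers_DL.json, with "authors" as a list of strings.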

# Clean text
def clean_text(text):
    return text.strip().lower() if isinstance(text, str) else ""

df["cleaned_abstract"] = df["abstract"].apply(clean_text)

# Precompute BM25 Index
try:
    tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
    bm25 = BM25Okapi(tokenized_corpus)
    logger.info("BM25 index created")
except Exception as e:
    logger.error(f"BM25 index creation failed: {e}")
    raise
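
# Illustrative scoring sketch (not executed here): rank_bm25 scores every document at
# once, so bm25.get_scores("crop yield prediction".split()) returns one score per abstract.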

# Load SciBERT for embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

try:
    sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model.to(device)
    sci_bert_model.eval()
    logger.info("SciBERT loaded")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

# Generate SciBERT embeddings
def generate_embeddings_sci_bert(texts, batch_size=32):
    try:
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = sci_bert_model(**inputs)
            # Mean-pool token embeddings over the sequence dimension: one 768-dim vector per text
            embeddings = outputs.last_hidden_state.mean(dim=1)
            all_embeddings.append(embeddings.cpu().numpy())
            if device.type == "cuda":
                torch.cuda.empty_cache()  # release cached GPU memory between batches
        return np.concatenate(all_embeddings, axis=0)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        # Fall back to zero vectors (768 = SciBERT's hidden size) so indexing can still proceed
        return np.zeros((len(texts), 768), dtype=np.float32)
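
# Shape sketch (illustrative): for n input texts this returns an (n, 768) float array,
# e.g. generate_embeddings_sci_bert(["sample abstract"]).shape == (1, 768).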

# Precompute embeddings and FAISS index
try:
    abstracts = df["cleaned_abstract"].tolist()
    embeddings = generate_embeddings_sci_bert(abstracts)
    dimension = embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings.astype(np.float32))
    logger.info("FAISS index created")
except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise
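
# Optional sanity check (illustrative, assumes abstracts are unique): a stored vector
# should retrieve itself as its own nearest neighbour at distance ~0, e.g.
#   _, nn = faiss_index.search(embeddings[:1].astype(np.float32), 1)
#   assert nn[0][0] == 0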

# Hybrid search: union FAISS (semantic) and BM25 (lexical) candidates, rank by BM25 score,
# and return dropdown display strings plus DataFrame indices (not truncated strings)
def get_relevant_papers(query):
    if not query.strip():
        return [], [], "Please enter a search query."
    try:
        query_embedding = generate_embeddings_sci_bert([query])
        distances, indices = faiss_index.search(query_embedding.astype(np.float32), 5)
        tokenized_query = query.split()
        bm25_scores = bm25.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
        # Union the semantic and lexical candidate sets, then rank the pool by BM25 score
        combined_indices = list(set(indices[0]) | set(bm25_top_indices))
        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
        # Return formatted strings for dropdown and indices for full data
        papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
        return papers, ranked_results[:5], "Search completed."
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."

# Gemini API QA function with full context
def answer_question(selected_index, question, history):
    if selected_index is None:
        return [(question, "Please select a paper first!")], history
    if not question.strip():
        return [(question, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
        return [("Conversation ended.", "Select a new paper or search again!")], []

    try:
        # Get full paper data from DataFrame using index
        paper_data = df.iloc[selected_index]
        title = paper_data["title"]
        abstract = paper_data["abstract"]  # Full abstract, not truncated
        authors = ", ".join(paper_data["authors"])
        doi = paper_data["doi"]
        
        # Build prompt with all fields
        prompt = (
            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
            "When asked about tech stacks or methods, follow these guidelines:\n"
            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
            "Here’s the paper:\n"
            f"Title: {title}\n"
            f"Authors: {authors}\n"
            f"Abstract: {abstract}\n"
            f"DOI: {doi}\n\n"
        )
        
        # Add history if present
        if history:
            prompt += "Previous conversation (use for context):\n"
            for user_q, bot_a in history[-2:]:
                prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
        
        prompt += f"Now, answer this question: {question}"
        
        logger.info(f"Prompt sent to Gemini API: {prompt[:200]}...")
        
        # Call Gemini API (Gemini 1.5 Flash)
        model = genai.GenerativeModel("gemini-1.5-flash")
        response = model.generate_content(prompt)
        answer = response.text.strip()
        
        # Fallback for poor responses
        if not answer or len(answer) < 15:
            answer = (
                "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
                "- Python: Core language for ML/DL.\n"
                "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
                "- Scikit-learn: For traditional ML algorithms.\n"
                "- Pandas/NumPy: For data handling and preprocessing."
            )
        
        history.append((question, answer))
        return history, history
    except Exception as e:
        logger.error(f"QA failed: {e}")
        history.append((question, "Sorry, I couldn’t process that. Try again!"))
        return history, history
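
# `history` is a list of (question, answer) tuples; gr.Chatbot renders it directly,
# and only the last two turns are replayed into the prompt above to bound its length.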

# Gradio UI
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
    gr.Markdown("# ResearchGPT - Paper Search & Chat")
    with gr.Row(elem_id="main"):
        # Sidebar for search
        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
            gr.Markdown("### Search Papers")
            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
            search_btn = gr.Button("Search")
            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
            search_status = gr.Textbox(label="Search Status", interactive=False)
            
            # States to store paper choices and indices
            paper_choices_state = gr.State([])
            paper_indices_state = gr.State([])

            search_btn.click(
                fn=get_relevant_papers,
                inputs=query_input,
                outputs=[paper_choices_state, paper_indices_state, search_status]
            ).then(
                fn=lambda choices: gr.update(choices=choices, value=None),
                inputs=paper_choices_state,
                outputs=paper_dropdown
            )
        
        # Main chat area
        with gr.Column(scale=3):
            gr.Markdown("### Chat with Selected Paper")
            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")
            
            # State to store conversation history and selected index
            history_state = gr.State([])
            selected_index_state = gr.State(None)

            # Update selected paper and index
            def update_selected_paper(choice, indices):
                if choice is None:
                    return "", None
                index = int(choice.split(".")[0]) - 1  # Extract rank (e.g., "1." -> 0)
                selected_idx = indices[index]
                return choice, selected_idx

            paper_dropdown.change(
                fn=update_selected_paper,
                inputs=[paper_dropdown, paper_indices_state],
                outputs=[selected_paper, selected_index_state]
            ).then(
                fn=lambda: [],
                inputs=None,
                outputs=chatbot
            )
            
            # Handle chat
            chat_btn.click(
                fn=answer_question,
                inputs=[selected_index_state, question_input, history_state],
                outputs=[chatbot, history_state]
            ).then(
                fn=lambda: "",
                inputs=None,
                outputs=question_input
            )

# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860)