File size: 12,210 Bytes
e86b928
d4c27ab
 
 
 
 
cf496f0
d08a770
 
12e1b40
 
 
f5b6a0d
12e1b40
b2c1b30
d08a770
9699ac9
e86b928
d08a770
 
 
 
 
f5b6a0d
 
 
 
 
 
d08a770
12e1b40
a91f0db
12e1b40
 
 
 
 
 
 
 
d4c27ab
cf496f0
d4c27ab
12e1b40
d4c27ab
f5b6a0d
 
 
 
 
 
d4c27ab
cf496f0
12e1b40
 
 
 
 
 
 
d4c27ab
d08a770
cf496f0
12e1b40
cf496f0
12e1b40
 
 
 
 
f5b6a0d
12e1b40
 
 
cf496f0
d8a8174
 
f5b6a0d
12e1b40
 
 
 
 
 
 
 
 
 
f5b6a0d
12e1b40
 
43c1491
d4c27ab
12e1b40
 
 
f5b6a0d
 
 
12e1b40
f5b6a0d
 
 
 
 
12e1b40
 
 
d4c27ab
425d4bf
f5b6a0d
 
12e1b40
 
425d4bf
f5b6a0d
12e1b40
425d4bf
12e1b40
 
f5b6a0d
 
 
 
 
 
 
 
093fd7d
12e1b40
 
093fd7d
cf496f0
d8a8174
 
 
f5b6a0d
d8a8174
cf496f0
d8a8174
12e1b40
 
d8a8174
f5b6a0d
 
 
 
 
 
d8a8174
 
 
 
 
 
 
 
 
 
f5b6a0d
d8a8174
 
 
 
 
f5b6a0d
d8a8174
 
 
 
f5b6a0d
d8a8174
f5b6a0d
 
 
 
 
d8a8174
f5b6a0d
 
d8a8174
 
f5b6a0d
d8a8174
 
 
 
 
 
ad54e4d
f5b6a0d
d8a8174
425d4bf
12e1b40
 
d8a8174
12e1b40
cf496f0
 
 
 
d8a8174
 
 
cf496f0
d8a8174
cf496f0
d8a8174
cf496f0
d8a8174
 
 
 
 
 
 
 
 
 
 
425d4bf
d8a8174
 
 
 
 
 
 
 
 
 
cf496f0
d8a8174
 
 
 
cf496f0
 
 
d8a8174
cf496f0
093fd7d
 
d8a8174
 
 
 
f5b6a0d
 
 
 
 
 
d8a8174
cf496f0
d8a8174
 
 
 
425d4bf
 
 
d8a8174
7ef58b7
d8a8174
 
cf496f0
d8a8174
 
12e1b40
 
 
 
43c1491
cf496f0
 
 
f5b6a0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai
import logging

# Application-wide logging: timestamp, level and message on every record.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face cache location (SciBERT is the only HF model this app downloads).
# NOTE(review): transformers is imported above this line; if huggingface_hub reads
# HF_HOME at import time, this assignment may not change the default cache. The
# explicit cache_dir passed when loading SciBERT is what pins the location — confirm.
os.environ.update(HF_HOME="/tmp/huggingface")

# The Gemini key comes from the environment (a Hugging Face Spaces secret).
# Question answering is impossible without it, so fail fast at startup.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
    raise ValueError("GEMINI_API_KEY is required for Gemini API access.")

try:
    genai.configure(api_key=GEMINI_API_KEY)
except Exception as e:
    logger.error(f"Failed to configure Gemini API: {e}")
    raise
else:
    logger.info("Gemini API configured")

# Load the paper corpus. The JSON file is resolved against the current working
# directory (not this script's directory). Later code reads its `abstract`,
# `title`, `authors` and `doi` columns.
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
    # Checked up front so a missing file produces a clear FileNotFoundError
    # instead of whatever pd.read_json raises for an unusable path.
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
    df = pd.read_json(DATASET_PATH)
except Exception as load_error:
    logger.error(f"Failed to load dataset: {load_error}")
    raise
else:
    logger.info("Dataset loaded successfully")

# Clean text
def clean_text(text):
    """Return *text* with surrounding whitespace removed and lowercased.

    Any value that is not a ``str`` (for example a missing abstract) maps to
    ``""`` so every row yields an indexable string.
    """
    if not isinstance(text, str):
        return ""
    trimmed = text.strip()
    return trimmed.lower()

# Normalised copy of every abstract; both the BM25 index and the FAISS index
# are built from this column.
try:
    df["cleaned_abstract"] = df["abstract"].map(clean_text)
except Exception as cleaning_error:
    logger.error(f"Error during cleaning abstracts: {cleaning_error}")
    raise
else:
    logger.info("Text cleaning completed")

# Precompute BM25 Index
# Sparse keyword index over the cleaned abstracts. Tokens are plain whitespace
# splits, so queries must be lowercased and split the same way to match.
try:
    tokenized_corpus = list(map(str.split, df["cleaned_abstract"]))
    bm25 = BM25Okapi(tokenized_corpus)
except Exception as index_error:
    logger.error(f"BM25 index creation failed: {index_error}")
    raise
else:
    logger.info("BM25 index created")

# Load SciBERT for embeddings
# Prefer the GPU when torch can see one; the model and every input batch are
# moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

try:
    sci_bert_tokenizer = AutoTokenizer.from_pretrained(
        "allenai/scibert_scivocab_uncased",
        cache_dir="/tmp/huggingface",
    )
    # nn.Module.to() moves parameters in place and returns the same module;
    # eval() disables dropout so embeddings are deterministic.
    sci_bert_model = AutoModel.from_pretrained(
        "allenai/scibert_scivocab_uncased",
        cache_dir="/tmp/huggingface",
    ).to(device)
    sci_bert_model.eval()
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise
else:
    logger.info("SciBERT loaded successfully")

# Generate SciBERT embeddings
def generate_embeddings_sci_bert(texts, batch_size=32):
    """Embed texts with SciBERT using attention-mask-weighted mean pooling.

    Args:
        texts: A single string or an iterable of strings (list, tuple, pandas
            Series, ...). A bare string is treated as one text.
        batch_size: Number of texts tokenized and encoded per forward pass.

    Returns:
        A NumPy array with one row per input text. The row width is the model's
        hidden size (768 for SciBERT). Empty input returns
        ``np.zeros((0, 768))``. If encoding fails, the error is logged with a
        traceback and an all-zero ``(len(texts), 768)`` array is returned, so
        callers keep working (best effort).
    """
    # A bare str would otherwise be iterated character by character.
    texts = [texts] if isinstance(texts, str) else list(texts)
    all_embeddings = []
    try:
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = sci_bert_model(**inputs)
            # Average only over real tokens. A plain mean over dim=1 also averages
            # the padding positions, so a short abstract's vector depended on the
            # longest text in its batch (and differed from a single-text query).
            hidden = outputs.last_hidden_state
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-pad rows
            all_embeddings.append((summed / counts).cpu().numpy())
            if device.type == "cuda":
                torch.cuda.empty_cache()
        return np.concatenate(all_embeddings, axis=0) if all_embeddings else np.zeros((0, 768))
    except Exception as e:
        logger.exception(f"Embedding generation failed: {e}")
        return np.zeros((len(texts), 768))

# Dense index over the cleaned abstracts. Flat FAISS indexes assign ids by
# insertion order, so id i corresponds to df.iloc[i].
try:
    abstracts = df["cleaned_abstract"].tolist()
    embeddings = generate_embeddings_sci_bert(abstracts)
    if len(abstracts) != embeddings.shape[0]:
        logger.warning("Embeddings count does not match abstracts count")
    # With nothing embedded, fall back to the same 768 width that
    # generate_embeddings_sci_bert uses for its empty/zero results.
    dimension = 768 if embeddings.size == 0 else embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(dimension)
    if embeddings.size == 0:
        logger.warning("No embeddings to index")
    else:
        faiss_index.add(embeddings.astype(np.float32))
        logger.info("FAISS index created")
except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise

def get_relevant_papers(query):
    """Hybrid (dense + BM25) search over the loaded papers.

    Dense hits come from the SciBERT/FAISS index and sparse hits from BM25.
    The two ranked lists are merged with reciprocal rank fusion (RRF), so both
    retrievers influence the final order. Ties are broken by BM25 score, then
    by row position, which keeps the result deterministic.

    Args:
        query: Free-text search string from the UI.

    Returns:
        ``(papers, indices, status)``:
        - ``papers``: display strings ``"<rank>. <title> - Abstract: <snippet>"``.
        - ``indices``: the matching ``df`` row positions, in the same order
          (at most five).
        - ``status``: a message for the UI.
        The numeric prefix of each label equals its position in ``indices``
        plus one, even if a row had to be skipped while formatting.
    """
    if not isinstance(query, str) or not query.strip():
        return [], [], "Please enter a valid search query."

    top_k = 5   # number of results offered in the dropdown
    rrf_k = 60  # RRF damping constant (value from Cormack et al., 2009)

    try:
        n_docs = len(df)

        # Dense retrieval. FAISS pads with -1 when the index holds fewer than
        # top_k vectors; those ids must be dropped, or df.iloc[-1] would
        # silently return the last paper.
        query_embedding = generate_embeddings_sci_bert([query])
        _, dense_hits = faiss_index.search(query_embedding.astype(np.float32), top_k)
        dense_ranked = [int(i) for i in dense_hits[0] if 0 <= i < n_docs]

        # Sparse retrieval, tokenized like the indexed abstracts (lowercase +
        # whitespace split). A zero score means no query term occurs in the
        # abstract, so such rows are not counted as hits.
        bm25_scores = bm25.get_scores(query.lower().split())
        bm25_ranked = [int(i) for i in np.argsort(bm25_scores)[::-1][:top_k] if bm25_scores[i] > 0]

        # Reciprocal rank fusion: each list adds 1 / (rrf_k + rank). Re-sorting
        # the union by BM25 alone would always reproduce the BM25 top-k and
        # make the dense results irrelevant.
        fused = {}
        for ranking in (dense_ranked, bm25_ranked):
            for rank, idx in enumerate(ranking, start=1):
                fused[idx] = fused.get(idx, 0.0) + 1.0 / (rrf_k + rank)
        ranked_results = sorted(fused, key=lambda idx: (-fused[idx], -bm25_scores[idx], idx))[:top_k]

        papers = []
        for rank, idx in enumerate(ranked_results, start=1):
            try:
                row = df.iloc[idx]
                title = row['title']
                abstract = row['abstract'] if isinstance(row['abstract'], str) else ""
                abstract_snip = abstract[:200] + "..." if len(abstract) > 200 else abstract
                papers.append(f"{rank}. {title} - Abstract: {abstract_snip}")
            except Exception as e:
                logger.error(f"Error accessing paper at index {idx}: {e}")
        return papers, ranked_results, "Search completed."
    except Exception as e:
        logger.exception(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."

def answer_question(selected_index, question, history):
    """Answer a question about the selected paper with Gemini.

    Args:
        selected_index: Row position in ``df`` of the selected paper, or
            ``None`` when no paper has been chosen yet.
        question: The user's question.
        history: List of ``(question, answer)`` tuples for the current
            conversation (may be empty or ``None``). It is never mutated; an
            updated copy is returned instead.

    Returns:
        ``(chatbot_messages, history)``, depending on the input:
        - Validation failure: a single-message list plus the unchanged
          ``history``.
        - ``exit`` or ``done`` (case-insensitive, surrounding whitespace
          ignored): the conversation ends and an empty history is returned.
        - Otherwise: both elements are the updated history, including the
          new turn.
    """
    if selected_index is None:
        return [(question, "Please select a paper first!")], history
    if not isinstance(question, str) or not question.strip():
        return [(question, "Please ask a question!")], history
    # Strip first so " exit" or "done\n" also end the conversation.
    if question.strip().lower() in ("exit", "done"):
        return [("Conversation ended.", "Select a new paper or search again!")], []

    # Work on a copy so the caller's state object is not mutated as a side effect.
    history = list(history or [])

    try:
        paper_data = df.iloc[selected_index]
        title = paper_data.get("title", "Unknown Title")
        abstract = paper_data.get("abstract", "Abstract not available.")
        doi = paper_data.get("doi", "No DOI")
        # Missing DataFrame cells are typically NaN (a float); keep "nan" out
        # of the prompt by substituting the same defaults used for absent keys.
        if not isinstance(title, str) or not title.strip():
            title = "Unknown Title"
        if not isinstance(abstract, str) or not abstract.strip():
            abstract = "Abstract not available."
        if not isinstance(doi, str) or not doi.strip():
            doi = "No DOI"
        authors_list = paper_data.get("authors", [])
        # map(str, ...) so non-string author entries don't make join() raise.
        authors = ", ".join(map(str, authors_list)) if isinstance(authors_list, list) else str(authors_list)

        prompt = (
            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
            "When asked about tech stacks or methods, follow these guidelines:\n"
            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
            f"Here’s the paper:\n"
            f"Title: {title}\n"
            f"Authors: {authors}\n"
            f"Abstract: {abstract}\n"
            f"DOI: {doi}\n\n"
        )

        # Only the last two turns are replayed, to keep the prompt short.
        if history:
            prompt += "Previous conversation (use for context):\n"
            for user_q, bot_a in history[-2:]:
                prompt += f"User: {user_q}\nAssistant: {bot_a}\n"

        prompt += f"Now, answer this question: {question}"

        logger.info(f"Prompt sent to Gemini API (truncated): {prompt[:500]}...")

        model = genai.GenerativeModel("models/gemini-2.5-flash")
        response = model.generate_content(prompt)

        # `or ""` guards against a None text attribute before strip().
        answer = (getattr(response, 'text', '') or '').strip() if response else ""

        # Substitute the canned reply only when Gemini returned nothing. Short
        # but genuine answers ("Yes.", a year, a DOI) must be kept as-is.
        if not answer:
            logger.warning("Received empty answer from Gemini API, applying fallback.")
            answer = (
                "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
                "- Python: Core language for ML/DL.\n"
                "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
                "- Scikit-learn: For traditional ML algorithms.\n"
                "- Pandas/NumPy: For data handling and preprocessing."
            )

        history.append((question, answer))
        return history, history
    except Exception as e:
        logger.exception(f"QA failed: {e}")
        history.append((question, "Sorry, I couldn’t process that. Try again!"))
        return history, history

# Gradio UI
# Two-pane layout: a search sidebar that fills a dropdown with ranked papers,
# and a chat pane that asks Gemini about the selected paper. Per-session values
# (search results, selected row, chat history) live in gr.State components.
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
    gr.Markdown("# ResearchGPT - Paper Search & Chat")
    with gr.Row(elem_id="main"):
        # Sidebar for search
        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
            gr.Markdown("### Search Papers")
            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
            search_btn = gr.Button("Search")
            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
            search_status = gr.Textbox(label="Search Status", interactive=False)

            # States to store the dropdown labels and the matching df row positions
            paper_choices_state = gr.State([])
            paper_indices_state = gr.State([])

            # Search on button click or Enter in the query box, then repopulate
            # the dropdown and clear any previous selection.
            for search_event in (search_btn.click, query_input.submit):
                search_event(
                    fn=get_relevant_papers,
                    inputs=query_input,
                    outputs=[paper_choices_state, paper_indices_state, search_status]
                ).then(
                    fn=lambda choices: gr.update(choices=choices, value=None),
                    inputs=paper_choices_state,
                    outputs=paper_dropdown
                )

        # Main chat area
        with gr.Column(scale=3):
            gr.Markdown("### Chat with Selected Paper")
            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")

            # State to store conversation history and selected index
            history_state = gr.State([])
            selected_index_state = gr.State(None)

            # Update selected paper and index
            def update_selected_paper(choice, indices):
                """Map a dropdown label back to its df row position.

                Labels start with a 1-based rank ("1. Title - ..."), and
                ``indices`` holds row positions in the same order, so rank r
                maps to ``indices[r - 1]``.

                Returns:
                    ``(label_for_display, row_position)``, or ``("", None)``
                    when nothing is selected or the label cannot be parsed.
                """
                if choice is None:
                    return "", None
                try:
                    index = int(choice.split(".")[0]) - 1  # Extract rank (e.g., "1." -> 0)
                    selected_idx = indices[index]
                except Exception as e:
                    logger.error(f"Error updating selected paper: {e}")
                    return "", None
                return choice, selected_idx

            # Switching papers starts a fresh conversation. Clear BOTH the
            # visible chat and the stored history. Clearing only the chatbot
            # would leak the previous paper's Q&A into the next prompt, and it
            # would reappear after the first new answer.
            paper_dropdown.change(
                fn=update_selected_paper,
                inputs=[paper_dropdown, paper_indices_state],
                outputs=[selected_paper, selected_index_state]
            ).then(
                fn=lambda: ([], []),
                inputs=None,
                outputs=[chatbot, history_state]
            )

            # Handle chat (Send button or Enter), then clear the question box
            for chat_event in (chat_btn.click, question_input.submit):
                chat_event(
                    fn=answer_question,
                    inputs=[selected_index_state, question_input, history_state],
                    outputs=[chatbot, history_state]
                ).then(
                    fn=lambda: "",
                    inputs=None,
                    outputs=question_input
                )

# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860)