import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import google.generativeai as genai
import logging
from PyPDF2 import PdfReader
import io

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set cache directory for Hugging Face models (SciBERT only)
os.environ["HF_HOME"] = "/tmp/huggingface"

# Get Gemini API key from environment variable (stored in Spaces secrets)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
    raise ValueError("GEMINI_API_KEY is required for Gemini API access.")
genai.configure(api_key=GEMINI_API_KEY)
logger.info("Gemini API configured")

# Load dataset with error handling
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
    df = pd.read_json(DATASET_PATH)
    logger.info("Dataset loaded successfully")
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise

# Clean text
def clean_text(text):
    return text.strip().lower() if isinstance(text, str) else ""

df["cleaned_abstract"] = df["abstract"].apply(clean_text)

# Precompute BM25 Index
try:
    tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
    bm25 = BM25Okapi(tokenized_corpus)
    logger.info("BM25 index created")
except Exception as e:
    logger.error(f"BM25 index creation failed: {e}")
    raise

# Load SciBERT for embeddings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

try:
    sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
    sci_bert_model.to(device)
    sci_bert_model.eval()
    logger.info("SciBERT loaded")
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

# Generate SciBERT embeddings (optimized with larger batch size)
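# Each batch is tokenized (truncated to 512 tokens) and mean-pooled over SciBERT's
# last hidden state, yielding one 768-dimensional vector per input text.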
def generate_embeddings_sci_bert(texts, batch_size=64):  # Increased batch size for efficiency
    try:
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = sci_bert_model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            all_embeddings.append(embeddings.cpu().numpy())
            torch.cuda.empty_cache()
        return np.concatenate(all_embeddings, axis=0)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return np.zeros((len(texts), 768))

# Precompute embeddings and FAISS index
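# IndexFlatL2 does exact (brute-force) L2 search and needs no training step,
# which is adequate for a modestly sized abstract collection like this one.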
try:
    abstracts = df["cleaned_abstract"].tolist()
    embeddings = generate_embeddings_sci_bert(abstracts)
    dimension = embeddings.shape[1]
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings.astype(np.float32))
    logger.info("FAISS index created")
except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise

# Hybrid search over the paper abstracts
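# Dense retrieval (FAISS over SciBERT embeddings) and sparse retrieval (BM25 over
# tokenized abstracts) each contribute their top-5 candidates; the union is
# re-ranked by BM25 score and the top 5 are returned.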
def get_relevant_papers(query):
    if not query.strip():
        return [], "Please enter a search query."
    try:
        query_embedding = generate_embeddings_sci_bert([query])
        distances, indices = faiss_index.search(query_embedding.astype(np.float32), 5)
        tokenized_query = query.split()
        bm25_scores = bm25.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
        combined_indices = list(set(indices[0]) | set(bm25_top_indices))
        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
        papers = [f"{i+1}. {df.iloc[idx]['title']} - Abstract: {df.iloc[idx]['abstract'][:200]}..." for i, idx in enumerate(ranked_results[:5])]
        return papers, ranked_results[:5], "Search completed."
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."

# Process uploaded PDF for RAG
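# Extracts text from every page, splits it into fixed 1,000-character chunks, and
# builds per-document FAISS and BM25 indexes so questions can be answered from the upload.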
def process_uploaded_pdf(file):
    try:
        pdf_reader = PdfReader(file)
        text = ""
        for page in pdf_reader.pages:
            text += page.extract_text() or ""
        cleaned_text = clean_text(text)
        chunks = [cleaned_text[i:i+1000] for i in range(0, len(cleaned_text), 1000)]  # Chunk for efficiency
        embeddings = generate_embeddings_sci_bert(chunks)
        faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        faiss_index.add(embeddings.astype(np.float32))
        tokenized_chunks = [chunk.split() for chunk in chunks]
        bm25_rag = BM25Okapi(tokenized_chunks)
        return {"chunks": chunks, "embeddings": embeddings, "faiss_index": faiss_index, "bm25": bm25_rag}, "Document processed successfully"
    except Exception as e:
        logger.error(f"PDF processing failed: {e}")
        return None, "Failed to process document"

# Hybrid search for RAG
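# Mirrors get_relevant_papers, but searches the uploaded document's chunks and
# returns the top 3 after BM25 re-ranking.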
def get_relevant_chunks(query, uploaded_doc):
    if not query.strip():
        return [], "Please enter a question."
    try:
        query_embedding = generate_embeddings_sci_bert([query])
        distances, indices = uploaded_doc["faiss_index"].search(query_embedding.astype(np.float32), 3)
        bm25_scores = uploaded_doc["bm25"].get_scores(query.split())
        combined_indices = list(set(indices[0]) | set(np.argsort(bm25_scores)[::-1][:3]))
        ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
        return [uploaded_doc["chunks"][idx] for idx in ranked_results[:3]], "Retrieval completed."
    except Exception as e:
        logger.error(f"RAG retrieval failed: {e}")
        return [], "Retrieval failed."

# Unified QA function
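# Routes the question to one of three prompts (research paper, uploaded document,
# or general chat) and folds the last two conversation turns into the Gemini
# prompt for context.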
def answer_question(mode, selected_index, question, history, uploaded_doc=None):
    if not question.strip():
        return [(question, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
        return [("Conversation ended.", "Start a new conversation!")], []

    try:
        if mode == "research":
            if selected_index is None:
                return [(question, "Please select a paper first!")], history
            paper_data = df.iloc[selected_index]
            title = paper_data["title"]
            abstract = paper_data["abstract"]
            authors = ", ".join(paper_data["authors"])
            doi = paper_data["doi"]
            prompt = (
                "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
                "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
                "When asked about tech stacks or methods, follow these guidelines:\n"
                "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
                "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
                "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
                "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
                "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
                "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
                "Here’s the paper:\n"
                f"Title: {title}\n"
                f"Authors: {authors}\n"
                f"Abstract: {abstract}\n"
                f"DOI: {doi}\n\n"
            )
            if history:
                prompt += "Previous conversation (use for context):\n"
                for user_q, bot_a in history[-2:]:
                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
            prompt += f"Now, answer this question: {question}"
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(prompt)
            answer = response.text.strip()
            if not answer or len(answer) < 15:
                answer = (
                    "The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
                    "- Python: Core language for ML/DL.\n"
                    "- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
                    "- Scikit-learn: For traditional ML algorithms.\n"
                    "- Pandas/NumPy: For data handling and preprocessing."
                )

        elif mode == "rag":
            if uploaded_doc is None:
                return [(question, "Please upload a document first!")], history
            relevant_chunks, _ = get_relevant_chunks(question, uploaded_doc)
            context = "\n".join(relevant_chunks)
            prompt = (
                "You are an expert AI assistant specializing in answering questions based on uploaded documents. "
                "Provide concise, accurate answers based on the following document content:\n"
                f"Content: {context}\n\n"
            )
            if history:
                prompt += "Previous conversation (use for context):\n"
                for user_q, bot_a in history[-2:]:
                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
            prompt += f"Now, answer this question: {question}"
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(prompt)
            answer = response.text.strip()

        else:  # general mode
            prompt = (
                "You are a highly knowledgeable AI assistant. Answer the following question concisely and accurately:\n"
            )
            if history:
                prompt += "Previous conversation (use for context):\n"
                for user_q, bot_a in history[-2:]:
                    prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
            prompt += f"Question: {question}"
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(prompt)
            answer = response.text.strip()

        history.append((question, answer))
        return history, history
    except Exception as e:
        logger.error(f"QA failed: {e}")
        history.append((question, "Sorry, I couldn’t process that. Try again!"))
        return history, history

# Gradio UI
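# Layout: a tabbed sidebar (Research / RAG / General Chat) on the left and the chat
# area on the right; the active tab is mirrored into a gr.State so event handlers
# know which mode to answer in.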
with gr.Blocks(
    css="""
    .chatbot {height: 500px; overflow-y: auto; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
    .sidebar {width: 350px; padding: 15px; background: #f8f9fa; border-radius: 10px;}
    #main {display: flex; flex-direction: row; gap: 20px; padding: 20px;}
    .tab-content {padding: 20px; background: #ffffff; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
    .gr-button {background: #007bff; color: white; border-radius: 5px; transition: background 0.3s;}
    .gr-button:hover {background: #0056b3;}
    h1 {color: #007bff; text-align: center; margin-bottom: 20px;}
    """,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
) as demo:
    gr.Markdown("# Triad: ResearchGPT, RAG, & General Chat")
    
    with gr.Row(elem_id="main"):
        # Sidebar
        with gr.Column(scale=1, min_width=350, elem_classes="sidebar"):
            mode_tabs = gr.Tabs()
            with mode_tabs:
                # Research Mode: search the Springer papers dataset
                with gr.TabItem("Research Mode"):
                    gr.Markdown("### Search Papers")
                    query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
                    search_btn = gr.Button("Search")
                    paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
                    search_status = gr.Textbox(label="Search Status", interactive=False)
                    paper_choices_state = gr.State([])
                    paper_indices_state = gr.State([])

                    search_btn.click(
                        fn=get_relevant_papers,
                        inputs=query_input,
                        outputs=[paper_choices_state, paper_indices_state, search_status]
                    ).then(
                        fn=lambda choices: gr.update(choices=choices, value=None),
                        inputs=paper_choices_state,
                        outputs=paper_dropdown
                    )

                # RAG Mode
                with gr.TabItem("RAG Mode"):
                    gr.Markdown("### Upload Document")
                    file_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                    uploaded_doc_state = gr.State(None)
                    file_upload.change(
                        fn=process_uploaded_pdf,
                        inputs=file_upload,
                        outputs=[uploaded_doc_state, upload_status]
                    )

                # General Mode
                with gr.TabItem("General Chat"):
                    gr.Markdown("Ask anything, powered by Gemini!")

        # Main chat area
        with gr.Column(scale=3, elem_classes="tab-content"):
            gr.Markdown("### Chat Area")
            selected_display = gr.Markdown(label="Selected Context", value="Select a mode to begin!")
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")
            
            history_state = gr.State([])
            selected_index_state = gr.State(None)
            # Track the active mode here because gr.Tabs cannot be passed as a function input.
            mode_state = gr.State("research")

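            # Refresh the context banner and the selected paper index whenever the
            # active mode, the paper selection, or the uploaded document changes.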
            def update_display(mode, choice, indices, uploaded_doc):
                if mode == "research" and choice:
                    index = int(choice.split(".")[0]) - 1
                    selected_idx = indices[index]
                    paper = df.iloc[selected_idx]
                    return f"**{paper['title']}**<br>DOI: [{paper['doi']}](https://doi.org/{paper['doi']})", selected_idx
                elif mode == "rag" and uploaded_doc:
                    return "Uploaded Document Ready", None
                elif mode == "general":
                    return "General Chat Mode", None
                return "Select a mode to begin!", None

            def set_mode(evt: gr.SelectData):
                # Tab order in the sidebar: 0 = Research Mode, 1 = RAG Mode, 2 = General Chat.
                return ["research", "rag", "general"][evt.index]

            mode_tabs.select(
                fn=set_mode,
                inputs=None,
                outputs=mode_state
            ).then(
                fn=update_display,
                inputs=[mode_state, paper_dropdown, paper_indices_state, uploaded_doc_state],
                outputs=[selected_display, selected_index_state]
            ).then(
                fn=lambda: ([], []),  # clear the chat window and history when the mode changes
                inputs=None,
                outputs=[chatbot, history_state]
            )

            paper_dropdown.change(
                fn=update_display,
                inputs=[mode_state, paper_dropdown, paper_indices_state, uploaded_doc_state],
                outputs=[selected_display, selected_index_state]
            )

            chat_btn.click(
                fn=answer_question,
                inputs=[mode_state, selected_index_state, question_input, history_state, uploaded_doc_state],
                outputs=[chatbot, history_state]
            ).then(
                fn=lambda: "",
                inputs=None,
                outputs=question_input
            )

# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860)