majorSeaweed committed on
Commit
f1592a4
·
verified ·
1 Parent(s): 7e145d6

Upload app.py

Files changed (1)
  1. app.py +698 -0
app.py ADDED
@@ -0,0 +1,698 @@
+ import os
+ import streamlit as st
+ import pdfplumber
+ import requests
+ import google.generativeai as genai
+ from bs4 import BeautifulSoup
+ from langchain.schema import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_pinecone import PineconeVectorStore
+ from langchain_groq import ChatGroq
+ from langchain.chains import create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.embeddings import Embeddings
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from pinecone import Pinecone
+ from dotenv import load_dotenv
+ import numpy as np
+ import time
+ import random
+ from typing import List
+ import arxiv
+ import wikipedia
+ from selenium import webdriver
+ from selenium.webdriver.common.by import By
+ from selenium.webdriver.chrome.options import Options
+ from selenium.webdriver.common.action_chains import ActionChains
+ from lxml import html
+ import base64
+ 
+ # Load environment variables
+ load_dotenv()
+ 
+ # st.set_page_config must be the first Streamlit call in the script
+ st.set_page_config(
+     page_title="AI Research Assistant",
+     page_icon="🔍",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+ 
+ # Get API keys from environment variables
+ groq_key = os.getenv("GROQ_API_KEY")
+ pinecone_key = os.getenv("PINECONE_API_KEY")
+ gemini_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+ 
+ # Check that the Gemini key is available before configuring the client
+ if not gemini_key:
+     st.error("Gemini API key is missing. Please set either GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")
+ else:
+     genai.configure(api_key=gemini_key)
+ 
+ #-------------------------------------------------------------
+ # UTILITY FUNCTIONS
+ #-------------------------------------------------------------
+ 
+ # Gemini embeddings wrapper implementing the LangChain Embeddings interface
+ class GeminiEmbeddings(Embeddings):
+     def __init__(self, api_key):
+         genai.configure(api_key=api_key)
+         self.model_name = "models/embedding-001"
+ 
+     def embed_documents(self, texts):
+         # Embed each chunk with the document-retrieval task type
+         return [self._convert_to_float32(genai.embed_content(
+             model=self.model_name, content=text, task_type="retrieval_document"
+         )["embedding"]) for text in texts]
+ 
+     def embed_query(self, text):
+         # Queries use a separate task type so the embedding is optimized for search
+         response = genai.embed_content(
+             model=self.model_name, content=text, task_type="retrieval_query"
+         )
+         return self._convert_to_float32(response["embedding"])
+ 
+     @staticmethod
+     def _convert_to_float32(embedding):
+         # Pinecone expects plain float lists; normalize the dtype first
+         return np.array(embedding, dtype=np.float32).tolist()
+ 
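+ # Usage sketch for the wrapper above (assuming a valid Gemini API key):
+ # embedding-001 returns 768-dimensional vectors, which is why the Pinecone
+ # index created further down uses dimension=768.
+ #
+ #     embeddings = GeminiEmbeddings(api_key=gemini_key)
+ #     vector = embeddings.embed_query("retrieval-augmented generation")
+ #     assert len(vector) == 768
+ 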
+ # PDF handling functions
+ def extract_text_from_pdf(pdf_path):
+     text = ""
+     try:
+         with pdfplumber.open(pdf_path) as pdf:
+             for page in pdf.pages:
+                 extracted_text = page.extract_text()
+                 if extracted_text:
+                     text += extracted_text + "\n"
+         return text.strip()
+     except Exception as e:
+         st.error(f"Error extracting text from PDF: {e}")
+         return ""
+ 
+ def read_data_from_doc(uploaded_file):
+     # Build one Document per page, combining text, flattened tables,
+     # and a placeholder noting any images detected on the page
+     docs = []
+     with pdfplumber.open(uploaded_file) as pdf:
+         for i, page in enumerate(pdf.pages):
+             text = page.extract_text() or ""
+             tables = page.extract_tables()
+             table_text = "\n".join([
+                 "\n".join(["\t".join(cell if cell is not None else "" for cell in row) for row in table])
+                 for table in tables if table
+             ]) if tables else ""
+             images = page.images
+             image_text = f"[{len(images)} image(s) detected]" if images else ""
+             content = f"{text}\n\n{table_text}\n\n{image_text}".strip()
+             if content:
+                 docs.append(Document(page_content=content, metadata={"page": i + 1}))
+     return docs
+ 
+ def make_chunks(docs, chunk_len=1000, chunk_overlap=200):
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_len, chunk_overlap=chunk_overlap
+     )
+     chunks = text_splitter.split_documents(docs)
+     return [Document(page_content=chunk.page_content, metadata=chunk.metadata) for chunk in chunks]
+ 
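+ # Pipeline sketch for the helpers above (file name hypothetical):
+ #
+ #     docs = read_data_from_doc("paper.pdf")   # one Document per page
+ #     chunks = make_chunks(docs)               # <=1000-char pieces, 200-char overlap
+ #     # The overlap keeps sentences that straddle a chunk boundary retrievable.
+ 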
+ # Gemini model functions
+ def get_gemini_model(model_name="gemini-1.5-pro"):
+     # Temperature is applied per-request via get_generation_config()
+     return genai.GenerativeModel(model_name)
+ 
+ def get_generation_config(temperature=0.4):
+     return {
+         "temperature": temperature,
+         "top_p": 1,
+         "top_k": 1,
+         "max_output_tokens": 2048,
+     }
+ 
+ def get_safety_settings():
+     # Disable Gemini's content filters for research-oriented queries
+     return [
+         {"category": category, "threshold": "BLOCK_NONE"}
+         for category in [
+             "HARM_CATEGORY_HARASSMENT",
+             "HARM_CATEGORY_HATE_SPEECH",
+             "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+             "HARM_CATEGORY_DANGEROUS_CONTENT",
+         ]
+     ]
+ 
+ def generate_gemini_response(model, prompt):
+     response = model.generate_content(
+         prompt,
+         generation_config=get_generation_config(),
+         safety_settings=get_safety_settings()
+     )
+     if response.candidates and len(response.candidates) > 0:
+         return response.candidates[0].content.parts[0].text
+     return ''
+ 
+ def summarize_text(text):
+     model = get_gemini_model()
+     prompt_text = f"Summarize the following research paper very concisely:\n{text[:5000]}"  # Truncate to 5000 chars
+     summary = generate_gemini_response(model, prompt_text)
+     return summary
+ 
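+ # Generation sketch (prompt text hypothetical): the same helper drives both
+ # summarization here and the web-search routing decision later in the file.
+ #
+ #     model = get_gemini_model()
+ #     reply = generate_gemini_response(model, "Define RAG in one sentence.")
+ 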
+ #-------------------------------------------------------------
+ # RESEARCH ASSISTANT MODULE
+ #-------------------------------------------------------------
+ 
+ def download_pdf(pdf_url, save_path="temp_paper.pdf"):
+     try:
+         response = requests.get(pdf_url)
+         if response.status_code == 200:
+             with open(save_path, "wb") as file:
+                 file.write(response.content)
+             return save_path
+     except Exception as e:
+         st.error(f"Error downloading PDF: {e}")
+     return None
+ 
+ def search_arxiv(query, max_results=2):
+     client = arxiv.Client()
+     search = arxiv.Search(query=query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance)
+ 
+     arxiv_docs = []
+ 
+     for result in client.results(search):
+         pdf_link = next((link.href for link in result.links if 'pdf' in link.href), None)
+ 
+         # Download, extract, and summarize PDF if link exists
+         if pdf_link:
+             with st.spinner(f"Processing arXiv paper: {result.title}"):
+                 pdf_path = download_pdf(pdf_link)
+                 if pdf_path:
+                     text = extract_text_from_pdf(pdf_path)
+                     summary = summarize_text(text)
+                     # Clean up downloaded file
+                     if os.path.exists(pdf_path):
+                         os.remove(pdf_path)
+                 else:
+                     summary = "PDF could not be downloaded."
+         else:
+             summary = "No PDF available."
+ 
+         content = f"""
+         **Title:** {result.title}
+         **Authors:** {', '.join(author.name for author in result.authors)}
+         **Published:** {result.published.strftime('%Y-%m-%d')}
+         **Abstract:** {result.summary}
+         **PDF Summary:** {summary}
+         **PDF Link:** {pdf_link if pdf_link else 'Not available'}
+         """
+ 
+         arxiv_docs.append(Document(page_content=content, metadata={"source": "arXiv", "title": result.title}))
+ 
+     return arxiv_docs
+ 
+ def search_wikipedia(query, max_results=2):
+     try:
+         page_titles = wikipedia.search(query, results=max_results)
+         wiki_docs = []
+         for title in page_titles:
+             try:
+                 with st.spinner(f"Processing Wikipedia article: {title}"):
+                     page = wikipedia.page(title)
+                     wiki_docs.append(Document(
+                         page_content=page.content[:2000],
+                         metadata={"source": "Wikipedia", "title": title}
+                     ))
+             except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError) as e:
+                 st.warning(f"Error retrieving Wikipedia page {title}: {e}")
+         return wiki_docs
+     except Exception as e:
+         st.error(f"Error searching Wikipedia: {e}")
+         return []
+ 
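+ # Both search helpers return LangChain Document objects tagged with their
+ # source, so downstream code can cite them (query hypothetical):
+ #
+ #     docs = search_arxiv("graph neural networks") + search_wikipedia("graph neural networks")
+ #     for doc in docs:
+ #         print(doc.metadata["source"], "-", doc.metadata["title"])
+ 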
+ class ResearchAssistant:
+     def __init__(self):
+         # Initialize LLM
+         self.llm = ChatGroq(
+             api_key=groq_key,
+             model="llama3-70b-8192",
+             temperature=0.2
+         )
+ 
+         # Set up the prompt template
+         self.prompt = ChatPromptTemplate.from_template("""
+         You are an expert research assistant. Use the following context to answer the question.
+         If you don't know the answer, say so, but try your best to find relevant information
+         from the provided context and additional context.
+ 
+         Context from user documents:
+         {context}
+ 
+         Additional context from research sources:
+         {additional_context}
+ 
+         Question: {input}
+ 
+         Answer:
+         """)
+ 
+         # Set up the question-answer chain (currently unused: chat() prompts the LLM directly)
+         self.question_answer_chain = create_stuff_documents_chain(
+             self.llm, self.prompt
+         )
+ 
+     def retrieve_documents(self, query):
+         # No user documents in this mode, so user_context stays empty
+         user_context = []
+ 
+         # Get documents from arXiv and Wikipedia
+         arxiv_docs = search_arxiv(query)
+         wiki_docs = search_wikipedia(query)
+ 
+         summarized_context = []
+         for doc in arxiv_docs:
+             summarized_context.append(f"**ArXiv - {doc.metadata.get('title', 'Unknown Title')}**:\n{doc.page_content}...")
+ 
+         for doc in wiki_docs:
+             summarized_context.append(f"**Wikipedia - {doc.metadata.get('title', 'Unknown Title')}**:\n{doc.page_content}...")
+ 
+         return user_context, summarized_context
+ 
+     def chat(self, question):
+         user_context, summarized_context = self.retrieve_documents(question)
+ 
+         input_data = {
+             "input": question,
+             "context": "\n\n".join(user_context),
+             "additional_context": "\n\n".join(summarized_context)
+         }
+ 
+         with st.spinner("Generating answer..."):
+             # Use the LLM directly rather than the retrieval chain
+             prompt_text = f"""
+             Question: {question}
+ 
+             Additional context:
+             {input_data['additional_context']}
+ 
+             Please provide a comprehensive answer based on the above information.
+             """
+             response = self.llm.invoke(prompt_text)
+         return response.content, summarized_context
+ 
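+ # Chat sketch (question hypothetical):
+ #
+ #     assistant = ResearchAssistant()
+ #     answer, sources = assistant.chat("What is retrieval-augmented generation?")
+ #     # `answer` is the Groq model's response; `sources` holds the arXiv and
+ #     # Wikipedia excerpts that were injected as additional context.
+ 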
+ #-------------------------------------------------------------
+ # DOCUMENT QA MODULE
+ #-------------------------------------------------------------
+ 
+ # Initialize retrieval chain (cached so the document is only processed once)
+ @st.cache_resource(show_spinner=False)
+ def get_retrieval_chain(uploaded_file, model):
+     with st.spinner("Processing document... This may take a minute."):
+         # Configure embeddings
+         genai.configure(api_key=gemini_key)
+         embeddings = GeminiEmbeddings(api_key=gemini_key)
+ 
+         # Read and process document
+         docs = read_data_from_doc(uploaded_file)
+         splits = make_chunks(docs)
+ 
+         # Set up vector store
+         pc = Pinecone(api_key=pinecone_key)
+ 
+         # Check if index exists, create it if not
+         indexes = pc.list_indexes()
+         index_name = "research-rag"
+         if index_name not in [idx.name for idx in indexes]:
+             # Note: recent pinecone clients also require a `spec` argument
+             # (e.g. ServerlessSpec) when creating an index
+             pc.create_index(
+                 name=index_name,
+                 dimension=768,  # Dimension of the Gemini embedding-001 vectors
+                 metric="cosine"
+             )
+ 
+         vectorstore = PineconeVectorStore.from_documents(
+             splits,
+             embeddings,
+             index_name=index_name,
+         )
+         retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
+ 
+         # Set up LLM and chain
+         llm = ChatGroq(model_name=model, temperature=0.75, api_key=groq_key)
+ 
+         system_prompt = """
+         You are an AI assistant answering questions based on retrieved documents and additional context.
+         Use the provided context from both database retrieval and additional sources to answer the question.
+ 
+         - **Discard irrelevant context:** If one of the contexts (retrieved or additional) does not match the question, ignore it.
+         - **Highlight conflicting information:** If multiple sources provide conflicting information, explicitly mention it by saying:
+           "According to the retrieved context, ... but as per internet sources, ..."
+         - **Prioritize accuracy:** If neither context provides a relevant answer, say "I don't know" instead of guessing.
+ 
+         Provide concise yet informative answers, ensuring clarity and completeness.
+ 
+         Retrieved Context: {context}
+         Additional Context: {additional_context}
+         """
+ 
+         prompt = ChatPromptTemplate.from_messages([
+             ("system", system_prompt),
+             ("human", "{input}\n\nRetrieved Context: {context}\n\nAdditional Context: {additional_context}"),
+         ])
+ 
+         question_answer_chain = create_stuff_documents_chain(llm, prompt)
+         chain = create_retrieval_chain(retriever, question_answer_chain)
+ 
+     return chain
+ 
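+ # Invocation sketch: create_retrieval_chain fills {context} from the Pinecone
+ # retriever, while extra keys such as additional_context pass through to the
+ # prompt unchanged (question hypothetical):
+ #
+ #     chain = get_retrieval_chain(uploaded_file, "llama3-70b-8192")
+ #     result = chain.invoke({"input": "What does the paper claim?", "additional_context": ""})
+ #     print(result["answer"])
+ 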
+ #-------------------------------------------------------------
+ # WEB SEARCH MODULE
+ #-------------------------------------------------------------
+ 
+ # Prompt creation functions
+ def create_search_prompt(query, context=""):
+     system_prompt = """You are a smart assistant designed to determine whether a query needs data from a web search or can be answered using a document database.
+     Consider the provided context if available.
+     If the query requires external information, no context is provided, the context is irrelevant, or the latest information is required, then output the special token <SEARCH>
+     followed by relevant keywords extracted from the query to optimize for search engine results.
+     Ensure the keywords are concise and relevant. If the document data is sufficient, simply return nothing."""
+ 
+     if context:
+         return f"{system_prompt}\n\nContext: {context}\n\nQuery: {query}"
+ 
+     return f"{system_prompt}\n\nQuery: {query}"
+ 
+ def create_summary_prompt(content):
+     return f"""Please provide a comprehensive yet concise summary of the following content, highlighting the most important points and maintaining factual accuracy. Organize the information in a clear and coherent manner:
+ 
+ Content to summarize:
+ {content}
+ 
+ Summary:"""
+ 
+ # Web scraping functions
+ def init_selenium_driver():
+     # Headless Chrome configured for containerized environments
+     chrome_options = Options()
+     chrome_options.add_argument("--headless")
+     chrome_options.add_argument("--disable-gpu")
+     chrome_options.add_argument("--no-sandbox")
+     chrome_options.add_argument("--disable-dev-shm-usage")
+ 
+     driver = webdriver.Chrome(options=chrome_options)
+     return driver
+ 
+ def extract_static_page(url):
+     try:
+         response = requests.get(url, timeout=5)
+         response.raise_for_status()
+         soup = BeautifulSoup(response.text, 'lxml')
+ 
+         text = soup.get_text(separator=" ", strip=True)
+         return text[:5000]
+ 
+     except requests.exceptions.RequestException as e:
+         st.error(f"Error fetching page: {e}")
+         return None
+ 
+ def extract_dynamic_page(url, driver):
+     try:
+         driver.get(url)
+         # Randomized waits give JavaScript-heavy pages time to render
+         time.sleep(random.uniform(2, 5))
+ 
+         body = driver.find_element(By.TAG_NAME, "body")
+         ActionChains(driver).move_to_element(body).perform()
+         time.sleep(random.uniform(2, 5))
+ 
+         page_source = driver.page_source
+         tree = html.fromstring(page_source)
+ 
+         text = tree.xpath('//body//text()')
+         text_content = ' '.join(text).strip()
+         return text_content[:1000]
+ 
+     except Exception as e:
+         st.error(f"Error fetching dynamic page: {e}")
+         return None
+ 
+ def scrape_page(url):
+     # Naive heuristic: only URLs that mention "javascript" or "dynamic" are
+     # routed through Selenium; everything else is fetched statically
+     if "javascript" in url or "dynamic" in url:
+         driver = init_selenium_driver()
+         text = extract_dynamic_page(url, driver)
+         driver.quit()
+     else:
+         text = extract_static_page(url)
+ 
+     return text
+ 
+ def scrape_web(urls, max_urls=5):
+     texts = []
+ 
+     for url in urls[:max_urls]:
+         text = scrape_page(url)
+ 
+         if text:
+             texts.append(text)
+         else:
+             st.warning(f"Failed to retrieve content from {url}")
+ 
+     return texts
+ 
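+ # Scraping sketch (URLs hypothetical): static pages go through requests and
+ # BeautifulSoup; URLs mentioning "javascript" or "dynamic" go through Selenium.
+ #
+ #     texts = scrape_web(["https://example.com", "https://example.com/dynamic-page"])
+ 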
+ # Main web search functions
+ def check_search_needed(model, query, context):
+     # Ask Gemini whether the query needs a web search; the model signals this
+     # by emitting <SEARCH> followed by suggested search keywords
+     prompt = create_search_prompt(query, context)
+     response = generate_gemini_response(model, prompt)
+ 
+     if "<SEARCH>" in response:
+         search_terms = response.split("<SEARCH>")[1].strip()
+         return True, search_terms
+     return False, None
+ 
+ def summarize_content(model, content):
+     prompt = create_summary_prompt(content)
+     return generate_gemini_response(model, prompt)
+ 
+ def process_query(query, context=''):
+     with st.spinner("Processing query..."):
+         model = get_gemini_model()
+         search_tool = DuckDuckGoSearchRun()
+ 
+         needs_search, search_terms = check_search_needed(model, query, context)
+ 
+         result = {
+             "original_query": query,
+             "needs_search": needs_search,
+             "search_terms": search_terms,
+             "web_content": None,
+             "summary": None
+         }
+ 
+         if needs_search:
+             with st.spinner(f"Searching the web for: {search_terms}"):
+                 search_results = search_tool.run(search_terms)
+                 result["web_content"] = search_results
+ 
+             with st.spinner("Summarizing search results..."):
+                 summary = summarize_content(model, search_results)
+                 result["summary"] = summary
+ 
+     return result
+ 
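+ # Result-shape sketch for a query that triggers a search (values hypothetical):
+ #
+ #     result = process_query("latest developments in quantum computing")
+ #     # {"original_query": "...", "needs_search": True,
+ #     #  "search_terms": "quantum computing breakthroughs",
+ #     #  "web_content": "<raw DuckDuckGo results>", "summary": "<Gemini summary>"}
+ 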
+ #-------------------------------------------------------------
+ # MAIN APP
+ #-------------------------------------------------------------
+ 
+ def display_header():
+     st.title("🔍 AI Research Assistant")
+     st.markdown("Your all-in-one tool for research, document analysis, and web search")
+ 
+ def main():
+     # App header
+     display_header()
+ 
+     # Sidebar navigation
+     with st.sidebar:
+         st.title("Navigation")
+         app_mode = st.radio("Choose a mode:",
+                             ["Research Assistant", "Document Q&A", "Web Search"])
+ 
+         st.markdown("---")
+         st.subheader("About")
+         st.markdown("""
+         This AI Research Assistant helps you find and analyze information from various sources:
+         - arXiv papers
+         - Wikipedia articles
+         - Your own uploaded documents
+         - Web search results
+         """)
+ 
+         # API keys status
+         st.markdown("---")
+         st.subheader("API Status")
+ 
+         if groq_key:
+             st.success("✅ Groq API connected")
+         else:
+             st.error("❌ Groq API key missing")
+ 
+         if gemini_key:
+             st.success("✅ Gemini API connected")
+         else:
+             st.error("❌ Gemini API key missing")
+ 
+         if pinecone_key:
+             st.success("✅ Pinecone API connected")
+         else:
+             st.error("❌ Pinecone API key missing")
+ 
+     # Research Assistant Mode
+     if app_mode == "Research Assistant":
+         st.header("Research Assistant")
+         st.markdown("Ask research questions and get answers from arXiv papers and Wikipedia.")
+ 
+         # Initialize session state for chat history
+         if "research_history" not in st.session_state:
+             st.session_state.research_history = []
+ 
+         # Initialize Research Assistant
+         if "research_assistant" not in st.session_state:
+             with st.spinner("Initializing Research Assistant..."):
+                 st.session_state.research_assistant = ResearchAssistant()
+ 
+         # Input area
+         with st.form(key="research_form"):
+             question = st.text_input("Your research question:", key="research_question")
+             submit_button = st.form_submit_button("Search")
+ 
+         # Clear chat button
+         if st.button("Clear Chat"):
+             st.session_state.research_history = []
+             st.rerun()
+ 
+         # Process query when submitted
+         if submit_button and question:
+             # Add user query to chat history
+             st.session_state.research_history.append({"role": "user", "content": question})
+ 
+             # Get response from assistant
+             answer, sources = st.session_state.research_assistant.chat(question)
+ 
+             # Add assistant response to chat history
+             st.session_state.research_history.append({
+                 "role": "assistant",
+                 "content": answer,
+                 "sources": sources
+             })
+ 
+         # Display chat history
+         for message in st.session_state.research_history:
+             if message["role"] == "user":
+                 st.write(f"👀 **You:** {message['content']}")
+             else:
+                 st.write("🤖 **AI Assistant:**")
+                 st.markdown(message["content"])
+ 
+                 # Display sources in expandable section
+                 if message.get("sources"):
+                     with st.expander("View Sources"):
+                         for i, source in enumerate(message["sources"], 1):
+                             st.markdown(f"**Source {i}:**")
+                             st.markdown(source)
+                             st.markdown("---")
+ 
+     # Document Q&A Mode
+     elif app_mode == "Document Q&A":
+         st.header("Document Q&A")
+         st.markdown("Upload a PDF document and ask questions about it.")
+ 
+         # Model selection
+         model_name = st.selectbox(
+             "Select Groq Model",
+             [
+                 "llama3-70b-8192",
+                 "gemma2-9b-it",
+                 "llama-3.3-70b-versatile",
+                 "llama-3.1-8b-instant",
+                 "llama-guard-3-8b",
+                 "mixtral-8x7b-32768",
+                 "deepseek-r1-distill-llama-70b",
+                 "llama-3.2-1b-preview"
+             ],
+             index=0
+         )
+ 
+         # Initialize session state for conversation history
+         if 'document_conversation' not in st.session_state:
+             st.session_state.document_conversation = []
+ 
+         # File upload
+         uploaded_file = st.file_uploader("Upload a PDF document", type="pdf")
+ 
+         if uploaded_file:
+             try:
+                 chain = get_retrieval_chain(
+                     uploaded_file,
+                     model_name
+                 )
+ 
+                 # Show success message
+                 st.success("Document processed successfully! You can now ask questions.")
+ 
+                 # Display conversation history
+                 for q, a in st.session_state.document_conversation:
+                     with st.chat_message("user"):
+                         st.write(q)
+                     with st.chat_message("assistant"):
+                         st.write(a)
+ 
+                 # Question input
+                 question = st.chat_input("Ask a question about your document...")
+ 
+                 if question:
+                     with st.chat_message("user"):
+                         st.write(question)
+ 
+                     with st.chat_message("assistant"):
+                         with st.spinner("Thinking..."):
+                             additional_context = ""  # Can be modified to add external context if needed
+                             result = chain.invoke({
+                                 "input": question,
+                                 "additional_context": additional_context
+                             })
+                             answer = result['answer']
+                             st.write(answer)
+ 
+                     # Store in conversation history
+                     st.session_state.document_conversation.append((question, answer))
+ 
+             except Exception as e:
+                 st.error(f"An error occurred: {str(e)}")
+ 
+         # Shown only when no document is uploaded and a key is missing
+         elif not (groq_key and gemini_key and pinecone_key):
+             st.warning("Please make sure all API keys are properly configured.")
+ 
+     # Web Search Mode
+     else:
+         st.header("Web Search")
+         st.markdown("Search the web for answers to your questions.")
+ 
+         # Input area
+         with st.form("web_query_form"):
+             query = st.text_area("Enter your research question", height=100,
+                                  placeholder="E.g., What are the latest developments in quantum computing?")
+             context = st.text_area("Optional: Add any context", height=100,
+                                    placeholder="Add any additional context that might help with the research")
+             submit_button = st.form_submit_button("🔍 Research")
+ 
+         if submit_button and query:
+             result = process_query(query, context)
+ 
+             if result["needs_search"]:
+                 st.success("Research completed!")
+ 
+                 with st.expander("Search Details", expanded=False):
+                     st.subheader("Search Terms Used")
+                     st.info(result["search_terms"])
+ 
+                     st.subheader("Raw Web Content")
+                     st.text_area("Web Content", result["web_content"], height=200)
+ 
+                 st.subheader("Summary of Findings")
+                 st.markdown(result["summary"])
+             else:
+                 st.info("Based on the analysis, no web search was needed for this query.")
+ 
+ if __name__ == "__main__":
+     main()