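"""Streamlit AI Research Assistant.

Combines arXiv/Wikipedia research, PDF document Q&A backed by a Pinecone
RAG chain (Gemini embeddings + Groq LLMs), and Gemini-routed web search.
"""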
import os
import time
import random

import requests
import numpy as np
import pdfplumber
import streamlit as st
import google.generativeai as genai
import arxiv
import wikipedia

from bs4 import BeautifulSoup
from dotenv import load_dotenv
from lxml import html
from pinecone import Pinecone, ServerlessSpec
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.action_chains import ActionChains

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.embeddings import Embeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_pinecone import PineconeVectorStore
from langchain_groq import ChatGroq


load_dotenv()

groq_key = os.getenv("GROQ_API_KEY")
pinecone_key = os.getenv("PINECONE_API_KEY")
gemini_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

if 'theme' not in st.session_state:
    st.session_state.theme = 'light'

# st.set_page_config must be the first Streamlit command in the script,
# so it runs before any st.error call below.
st.set_page_config(
    page_title="AI Research Assistant",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

if gemini_key:
    genai.configure(api_key=gemini_key)
else:
    st.error("Gemini API key is missing. Please set either the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")

if st.session_state.theme == 'dark':
    theme_bg_color = "#0E1117"
    theme_secondary_bg_color = "#262730"
    theme_text_color = "#FAFAFA"
    theme_primary_color = "#FF4B4B"
else:
    theme_bg_color = "#FFFFFF"
    theme_secondary_bg_color = "#F0F2F6"
    theme_text_color = "#31333F"
    theme_primary_color = "#FF4B4B"

st.markdown(f"""
<style>
/* Main container styling */
.main {{
    padding: 1.5rem;
    background-color: {theme_bg_color};
    color: {theme_text_color};
}}

/* Header styling */
h1, h2, h3 {{
    color: {theme_text_color};
    font-weight: 600;
    margin-bottom: 1rem;
}}

/* Card-like containers */
.stExpander, div[data-testid="stForm"] {{
    border-radius: 10px;
    border: 1px solid {theme_secondary_bg_color};
    padding: 1rem;
    box-shadow: 0 2px 5px rgba(0,0,0,0.05);
    margin-bottom: 1rem;
    background-color: {theme_secondary_bg_color};
}}

/* Button styling */
button[kind="primaryFormSubmit"] {{
    border-radius: 8px;
    background-color: {theme_primary_color};
    transition: all 0.3s ease;
}}
button[kind="primaryFormSubmit"]:hover {{
    background-color: {theme_primary_color};
    opacity: 0.8;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}}

/* Chat message styling */
[data-testid="stChatMessage"] {{
    border-radius: 10px;
    margin-bottom: 0.5rem;
    padding: 0.5rem;
}}

/* Sidebar styling */
[data-testid="stSidebar"] {{
    background-color: {theme_secondary_bg_color};
    border-right: 1px solid {theme_secondary_bg_color};
}}

/* Success/info/error message styling */
[data-testid="stSuccessMessage"], [data-testid="stInfoMessage"], [data-testid="stErrorMessage"] {{
    border-radius: 8px;
}}

/* Input field styling */
[data-testid="stTextInput"], [data-testid="stTextArea"] {{
    border-radius: 8px;
}}

/* File uploader styling */
[data-testid="stFileUploader"] {{
    border-radius: 8px;
    border: 2px dashed {theme_secondary_bg_color};
    padding: 1rem;
}}

/* Tabs styling */
.stTabs [data-baseweb="tab-list"] {{
    gap: 8px;
}}
.stTabs [data-baseweb="tab"] {{
    border-radius: 8px 8px 0 0;
    padding: 10px 16px;
    background-color: {theme_secondary_bg_color};
}}
.stTabs [aria-selected="true"] {{
    background-color: {theme_bg_color};
    border-bottom: 2px solid {theme_primary_color};
}}
</style>
""", unsafe_allow_html=True)


class GeminiEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by the Gemini embedding API."""

    def __init__(self, api_key):
        genai.configure(api_key=api_key)
        self.model_name = "models/embedding-001"

    def embed_documents(self, texts):
        return [self._convert_to_float32(genai.embed_content(
            model=self.model_name, content=text, task_type="retrieval_document"
        )["embedding"]) for text in texts]

    def embed_query(self, text):
        response = genai.embed_content(
            model=self.model_name, content=text, task_type="retrieval_query"
        )
        return self._convert_to_float32(response["embedding"])

    @staticmethod
    def _convert_to_float32(embedding):
        # Pinecone expects plain Python floats; normalize via float32 first.
        return np.array(embedding, dtype=np.float32).tolist()
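
# Minimal usage sketch (assumes a valid Gemini key in the environment):
#   emb = GeminiEmbeddings(api_key=gemini_key)
#   vec = emb.embed_query("what is retrieval-augmented generation?")
#   len(vec)  # 768 dims for models/embedding-001, matching the index below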


def extract_text_from_pdf(pdf_path):
    text = ""
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                extracted_text = page.extract_text()
                if extracted_text:
                    text += extracted_text + "\n"
        return text.strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {e}")
        return ""


def read_data_from_doc(uploaded_file):
    docs = []
    with pdfplumber.open(uploaded_file) as pdf:
        for i, page in enumerate(pdf.pages):
            text = page.extract_text() or ""
            # Flatten any extracted tables into tab-separated text.
            tables = page.extract_tables()
            table_text = "\n".join([
                "\n".join(["\t".join(cell if cell is not None else "" for cell in row) for row in table])
                for table in tables if table
            ]) if tables else ""
            # Images are only counted, not OCR'd.
            images = page.images
            image_text = f"[{len(images)} image(s) detected]" if images else ""
            content = f"{text}\n\n{table_text}\n\n{image_text}".strip()
            if content:
                docs.append(Document(page_content=content, metadata={"page": i + 1}))
    return docs


def make_chunks(docs, chunk_len=1000, chunk_overlap=200):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_len, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(docs)
    return [Document(page_content=chunk.page_content, metadata=chunk.metadata) for chunk in chunks]
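
# Ingestion pipeline sketch (as wired up in get_retrieval_chain below):
#   docs = read_data_from_doc(uploaded_file)   # one Document per PDF page
#   splits = make_chunks(docs)                 # 1000-char chunks, 200 overlap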


def get_gemini_model(model_name="gemini-1.5-pro"):
    # Temperature and other sampling settings are supplied per-call via
    # get_generation_config(), not at model construction time.
    return genai.GenerativeModel(model_name)


def get_generation_config(temperature=0.4):
    return {
        "temperature": temperature,
        "top_p": 1,
        "top_k": 1,
        "max_output_tokens": 2048,
    }


def get_safety_settings():
    return [
        {"category": category, "threshold": "BLOCK_NONE"}
        for category in [
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        ]
    ]


def generate_gemini_response(model, prompt):
    response = model.generate_content(
        prompt,
        generation_config=get_generation_config(),
        safety_settings=get_safety_settings()
    )
    if response.candidates:
        return response.candidates[0].content.parts[0].text
    return ''


def summarize_text(text):
    model = get_gemini_model()
    # Only the first 5000 characters are sent, to keep the prompt small.
    prompt_text = f"Summarize the following research paper very concisely:\n{text[:5000]}"
    summary = generate_gemini_response(model, prompt_text)
    return summary
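
# Gemini helper flow, end to end:
#   model = get_gemini_model()
#   reply = generate_gemini_response(model, "One-line summary of transformers?")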


def download_pdf(pdf_url, save_path="temp_paper.pdf"):
    try:
        response = requests.get(pdf_url, timeout=30)
        if response.status_code == 200:
            with open(save_path, "wb") as file:
                file.write(response.content)
            return save_path
    except Exception as e:
        st.error(f"Error downloading PDF: {e}")
    return None


def search_arxiv(query, max_results=2):
    client = arxiv.Client()
    search = arxiv.Search(query=query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance)

    arxiv_docs = []

    for result in client.results(search):
        pdf_link = next((link.href for link in result.links if 'pdf' in link.href), None)

        if pdf_link:
            with st.spinner(f"Processing arXiv paper: {result.title}"):
                pdf_path = download_pdf(pdf_link)
                if pdf_path:
                    text = extract_text_from_pdf(pdf_path)
                    summary = summarize_text(text)
                    if os.path.exists(pdf_path):
                        os.remove(pdf_path)
                else:
                    summary = "PDF could not be downloaded."
        else:
            summary = "No PDF available."

        content = f"""
**Title:** {result.title}
**Authors:** {', '.join(author.name for author in result.authors)}
**Published:** {result.published.strftime('%Y-%m-%d')}
**Abstract:** {result.summary}
**PDF Summary:** {summary}
**PDF Link:** {pdf_link if pdf_link else 'Not available'}
"""

        arxiv_docs.append(Document(page_content=content, metadata={"source": "arXiv", "title": result.title}))

    return arxiv_docs
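
# Illustrative call (network access assumed):
#   docs = search_arxiv("diffusion models", max_results=1)
#   docs[0].page_content  # title, authors, abstract, and a PDF summary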


def search_wikipedia(query, max_results=2):
    try:
        page_titles = wikipedia.search(query, results=max_results)
        wiki_docs = []
        for title in page_titles:
            try:
                with st.spinner(f"Processing Wikipedia article: {title}"):
                    page = wikipedia.page(title)
                    wiki_docs.append(Document(
                        page_content=page.content[:2000],
                        metadata={"source": "Wikipedia", "title": title}
                    ))
            except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError) as e:
                st.warning(f"Error retrieving Wikipedia page {title}: {e}")
        return wiki_docs
    except Exception as e:
        st.error(f"Error searching Wikipedia: {e}")
        return []


class ResearchAssistant:
    def __init__(self):
        self.llm = ChatGroq(
            api_key=groq_key,
            model='llama-3.3-70b-versatile',
            temperature=0.4
        )

        self.prompt = ChatPromptTemplate.from_template("""
        You are an expert research assistant. Use the following context to answer the question.
        If you don't know the answer, say so, but try your best to find relevant information
        from the provided context and additional context.

        Context from user documents:
        {context}

        Additional context from research sources:
        {additional_context}

        Question: {input}

        Answer:
        """)

        # Note: chat() below prompts the LLM directly with the web-research
        # context, so this prompt/chain pair is currently unused.
        self.question_answer_chain = create_stuff_documents_chain(
            self.llm, self.prompt
        )

    def retrieve_documents(self, query):
        # This mode has no user-uploaded documents, so user_context stays empty.
        user_context = []

        arxiv_docs = search_arxiv(query)
        wiki_docs = search_wikipedia(query)

        summarized_context = []
        for doc in arxiv_docs:
            summarized_context.append(f"**ArXiv - {doc.metadata.get('title', 'Unknown Title')}**:\n{doc.page_content}")

        for doc in wiki_docs:
            summarized_context.append(f"**Wikipedia - {doc.metadata.get('title', 'Unknown Title')}**:\n{doc.page_content}")

        return user_context, summarized_context

    def chat(self, question):
        user_context, summarized_context = self.retrieve_documents(question)

        input_data = {
            "input": question,
            "context": "\n\n".join(user_context),
            "additional_context": "\n\n".join(summarized_context)
        }

        with st.spinner("Generating answer..."):
            prompt_text = f"""
            Question: {question}

            Additional context:
            {input_data['additional_context']}

            Please provide a comprehensive answer based on the above information.
            """
            response = self.llm.invoke(prompt_text)
        return response.content, summarized_context
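
# Usage sketch:
#   assistant = ResearchAssistant()
#   answer, sources = assistant.chat("What is retrieval-augmented generation?")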


@st.cache_resource(show_spinner=False)
def get_retrieval_chain(uploaded_file, model):
    with st.spinner("Processing document... This may take a minute."):
        genai.configure(api_key=gemini_key)
        embeddings = GeminiEmbeddings(api_key=gemini_key)

        docs = read_data_from_doc(uploaded_file)
        splits = make_chunks(docs)

        pc = Pinecone(api_key=pinecone_key)

        indexes = pc.list_indexes()
        index_name = "research-rag"
        if index_name not in [idx.name for idx in indexes]:
            # Recent Pinecone clients require an index spec; a serverless
            # index on AWS us-east-1 is assumed here -- adjust to your plan.
            pc.create_index(
                name=index_name,
                dimension=768,  # matches models/embedding-001
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1")
            )

        vectorstore = PineconeVectorStore.from_documents(
            splits,
            embeddings,
            index_name=index_name,
        )
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})

        llm = ChatGroq(model_name=model, temperature=0.75, api_key=groq_key)

        system_prompt = """
        You are an AI assistant answering questions based on retrieved documents and additional context.
        Use the provided context from both database retrieval and additional sources to answer the question.

        - **Discard irrelevant context:** If one of the contexts (retrieved or additional) does not match the question, ignore it.
        - **Highlight conflicting information:** If multiple sources provide conflicting information, explicitly mention it by saying:
          "According to the retrieved context, ... but as per internet sources, ..."
        - **Prioritize accuracy:** If neither context provides a relevant answer, say "I don't know" instead of guessing.

        Provide concise yet informative answers, ensuring clarity and completeness.

        Retrieved Context: {context}
        Additional Context: {additional_context}
        """

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}\n\nRetrieved Context: {context}\n\nAdditional Context: {additional_context}"),
        ])

        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        chain = create_retrieval_chain(retriever, question_answer_chain)

    return chain
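
# Usage sketch (as wired up in display_document_qa below):
#   chain = get_retrieval_chain(uploaded_file, "llama3-70b-8192")
#   result = chain.invoke({"input": question, "additional_context": ""})
#   answer = result["answer"]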


def create_search_prompt(query, context=""):
    system_prompt = """You are a smart assistant designed to determine whether a query needs data from a web search or can be answered using a document database.
    Consider the provided context if available.
    Output the special token <SEARCH> followed by concise, relevant keywords extracted from the query (optimized for search engines) when the query requires external information, no context is provided, the context is irrelevant, or up-to-date information is required.
    If the document data is sufficient, simply return a blank response."""

    if context:
        return f"{system_prompt}\n\nContext: {context}\n\nQuery: {query}"

    return f"{system_prompt}\n\nQuery: {query}"


def create_summary_prompt(content):
    return f"""Please provide a comprehensive yet concise summary of the following content, highlighting the most important points and maintaining factual accuracy. Organize the information in a clear and coherent manner:

Content to summarize:
{content}

Summary:"""


def init_selenium_driver():
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--disable-gpu")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")

    driver = webdriver.Chrome(options=chrome_options)
    return driver


def extract_static_page(url):
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'lxml')

        text = soup.get_text(separator=" ", strip=True)
        return text[:5000]

    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching page: {e}")
        return None


def extract_dynamic_page(url, driver):
    try:
        driver.get(url)
        # Randomized waits give client-side JavaScript time to render and
        # make the access pattern look less bot-like.
        time.sleep(random.uniform(2, 5))

        body = driver.find_element(By.TAG_NAME, "body")
        ActionChains(driver).move_to_element(body).perform()
        time.sleep(random.uniform(2, 5))

        page_source = driver.page_source
        tree = html.fromstring(page_source)

        text = tree.xpath('//body//text()')
        text_content = ' '.join(text).strip()
        return text_content[:1000]

    except Exception as e:
        st.error(f"Error fetching dynamic page: {e}")
        return None


def scrape_page(url):
    # Crude heuristic: fall back to Selenium only when the URL itself hints
    # at JavaScript-rendered content; everything else is fetched statically.
    if "javascript" in url or "dynamic" in url:
        driver = init_selenium_driver()
        text = extract_dynamic_page(url, driver)
        driver.quit()
    else:
        text = extract_static_page(url)

    return text
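
# Illustrative call (assumes the URL is reachable):
#   text = scrape_page("https://example.com")  # takes the static fetch path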


def scrape_web(urls, max_urls=5):
    texts = []

    for url in urls[:max_urls]:
        text = scrape_page(url)

        if text:
            texts.append(text)
        else:
            st.warning(f"Failed to retrieve content from {url}")

    return texts


def check_search_needed(model, query, context):
    prompt = create_search_prompt(query, context)
    response = generate_gemini_response(model, prompt)

    if "<SEARCH>" in response:
        search_terms = response.split("<SEARCH>")[1].strip()
        return True, search_terms
    return False, None
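
# Router protocol sketch: the model answers "<SEARCH> <keywords>" when a web
# lookup is warranted, otherwise a blank string. Illustrative, model-dependent:
#   check_search_needed(model, "today's weather in Paris", "")
#   # -> (True, "Paris weather today")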


def summarize_content(model, content):
    prompt = create_summary_prompt(content)
    return generate_gemini_response(model, prompt)


def process_query(query, context=''):
    with st.spinner("Processing query..."):
        model = get_gemini_model()
        search_tool = DuckDuckGoSearchRun()

        needs_search, search_terms = check_search_needed(model, query, context)

        result = {
            "original_query": query,
            "needs_search": needs_search,
            "search_terms": search_terms,
            "web_content": None,
            "summary": None
        }

        if needs_search:
            with st.spinner(f"Searching the web for: {search_terms}"):
                search_results = search_tool.run(search_terms)
                result["web_content"] = search_results

            with st.spinner("Summarizing search results..."):
                summary = summarize_content(model, search_results)
                result["summary"] = summary

        return result
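
# The returned dict always carries these keys:
#   original_query, needs_search, search_terms, web_content, summary
# web_content and summary stay None when no search was needed.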


def create_progress_bar(message="Processing..."):
    progress_container = st.empty()
    # The text keyword needs Streamlit >= 1.18; it surfaces the message
    # that was previously accepted but never shown.
    progress_bar = progress_container.progress(0, text=message)

    for i in range(100):
        time.sleep(0.01)
        progress_bar.progress(i + 1, text=message)

    progress_container.empty()


def display_search_history(history_key, input_key):
    if history_key in st.session_state and st.session_state[history_key]:
        with st.expander("🕑 Search History", expanded=False):
            for i, query in enumerate(st.session_state[history_key]):
                col1, col2 = st.columns([4, 1])
                with col1:
                    st.write(f"**{i+1}.** {query}")
                with col2:
                    if st.button("Use", key=f"history_{i}", help="Use this query again"):
                        st.session_state[input_key] = query
                        st.rerun()
                st.divider()


def main():
    if "research_history_queries" not in st.session_state:
        st.session_state.research_history_queries = []

    if "web_search_history" not in st.session_state:
        st.session_state.web_search_history = []

    with st.sidebar:
        st.image("https://img.icons8.com/fluency/96/000000/artificial-intelligence.png", width=80)
        st.title("AI Research Hub")
        st.markdown("---")

        st.subheader("📍 Navigation")
        app_mode = st.radio(
            "Choose a mode",
            [
                "💬 Research Assistant",
                "📄 Document Q&A",
                "🌐 Web Search"
            ],
            label_visibility="collapsed"
        )

        st.markdown("---")
        st.subheader("🎨 Appearance")

        current_theme = st.session_state.theme
        target_theme = "dark" if current_theme == "light" else "light"
        theme_icon = "🌙" if target_theme == "dark" else "☀️"
        theme_label = f"{theme_icon} Switch to {target_theme.capitalize()} Mode"

        if st.button(theme_label):
            # st._config is a private Streamlit API; it works today but may
            # change in future releases.
            if st.session_state.theme == 'light':
                st.session_state.theme = 'dark'
                st._config.set_option('theme.base', 'dark')
                st._config.set_option('theme.backgroundColor', '#0E1117')
                st._config.set_option('theme.secondaryBackgroundColor', '#262730')
                st._config.set_option('theme.textColor', '#FAFAFA')
            else:
                st.session_state.theme = 'light'
                st._config.set_option('theme.base', 'light')
                st._config.set_option('theme.backgroundColor', '#FFFFFF')
                st._config.set_option('theme.secondaryBackgroundColor', '#F0F2F6')
                st._config.set_option('theme.textColor', '#31333F')

            st.rerun()

        st.markdown("---")
        st.subheader("🔑 API Status")

        api_col1, api_col2 = st.columns(2)

        with api_col1:
            st.markdown("**Groq API**")
            st.markdown("**Gemini API**")
            st.markdown("**Pinecone API**")

        with api_col2:
            st.markdown("✅ Connected" if groq_key else "❌ Missing")
            st.markdown("✅ Connected" if gemini_key else "❌ Missing")
            st.markdown("✅ Connected" if pinecone_key else "❌ Missing")

        st.markdown("---")
        st.subheader("ℹ️ About")
        st.markdown("""
        This AI Research Assistant helps you find and analyze information from various sources, including arXiv papers, Wikipedia articles, your own documents, and web search results.
        """)

        st.markdown("---")
        st.caption("Version 2.0 | Updated April 2025")

    if "Research Assistant" in app_mode:
        display_research_assistant()
    elif "Document Q&A" in app_mode:
        display_document_qa()
    else:
        display_web_search()


def display_research_assistant():
    st.markdown("""
    <div style="display: flex; align-items: center; margin-bottom: 1rem;">
        <img src="https://img.icons8.com/fluency/48/000000/microscope.png" style="margin-right: 1rem;">
        <div>
            <h1 style="margin: 0;">Research Assistant</h1>
            <p style="margin: 0; color: #6c757d;">Get insights from arXiv papers and Wikipedia articles</p>
        </div>
    </div>
    """, unsafe_allow_html=True)

    if "research_history" not in st.session_state:
        st.session_state.research_history = []

    if "research_assistant" not in st.session_state:
        with st.spinner("Initializing Research Assistant..."):
            st.session_state.research_assistant = ResearchAssistant()

    display_search_history("research_history_queries", "research_question")

    st.markdown("""
    <div style="background-color: white; padding: 1.5rem; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); margin-bottom: 2rem;">
        <h3 style="margin-top: 0; margin-bottom: 1rem;">Ask a Research Question</h3>
    </div>
    """, unsafe_allow_html=True)

    with st.form(key="research_form", clear_on_submit=False):
        question = st.text_area(
            "Your research question:",
            key="research_question",
            height=100,
            placeholder="E.g., What are the latest developments in quantum computing?"
        )

        col1, col2 = st.columns([1, 4])
        with col1:
            submit_button = st.form_submit_button("🔍 Research")
        with col2:
            if st.form_submit_button("🗑️ Clear Chat"):
                st.session_state.research_history = []
                st.session_state.research_history_queries = []
                st.rerun()

    if submit_button and question:
        if question not in st.session_state.research_history_queries:
            st.session_state.research_history_queries.insert(0, question)
            if len(st.session_state.research_history_queries) > 10:
                st.session_state.research_history_queries.pop()

        st.session_state.research_history.append({"role": "user", "content": question})

        answer, sources = st.session_state.research_assistant.chat(question)

        st.session_state.research_history.append({
            "role": "assistant",
            "content": answer,
            "sources": sources
        })

    if st.session_state.research_history:
        st.markdown("### Conversation")

        for message in st.session_state.research_history:
            if message["role"] == "user":
                with st.chat_message("user", avatar="👤"):
                    st.write(message['content'])
            else:
                with st.chat_message("assistant", avatar="🤖"):
                    st.markdown(message["content"])

                    if message.get("sources"):
                        with st.expander("📚 View Sources"):
                            tabs = st.tabs([f"Source {i+1}" for i in range(len(message["sources"]))])
                            for tab, source in zip(tabs, message["sources"]):
                                with tab:
                                    st.markdown(source)


def display_document_qa():
    st.markdown("""
    <div style="display: flex; align-items: center; margin-bottom: 1rem;">
        <img src="https://img.icons8.com/fluency/48/000000/document.png" style="margin-right: 1rem;">
        <div>
            <h1 style="margin: 0;">Document Q&A</h1>
            <p style="margin: 0; color: #6c757d;">Upload a PDF and ask questions about its content</p>
        </div>
    </div>
    """, unsafe_allow_html=True)

    if 'document_conversation' not in st.session_state:
        st.session_state.document_conversation = []

    st.markdown("""
    <div style="background-color: white; padding: 1.5rem; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); margin-bottom: 2rem;">
        <h3 style="margin-top: 0; margin-bottom: 1rem;">Upload Document</h3>
    </div>
    """, unsafe_allow_html=True)

    col1, col2 = st.columns([1, 2])

    with col1:
        st.markdown("#### Model Selection")
        model_name = st.selectbox(
            "Select AI Model",
            [
                "llama3-70b-8192",
                "gemma2-9b-it",
                "llama-3.3-70b-versatile",
                "llama-3.1-8b-instant",
                "llama-guard-3-8b",
                "mixtral-8x7b-32768",
                "deepseek-r1-distill-llama-70b",
                "llama-3.2-1b-preview"
            ],
            index=0
        )

    with col2:
        st.markdown("#### Document Upload")
        uploaded_file = st.file_uploader(
            "Drag and drop your PDF here",
            type="pdf",
            help="Upload a PDF document to analyze"
        )

    if uploaded_file:
        try:
            with st.spinner("Processing your document..."):
                chain = get_retrieval_chain(
                    uploaded_file,
                    model_name
                )

            st.success(f"✅ Document '{uploaded_file.name}' processed successfully!")

            st.markdown("""
            <div style="background-color: #f8f9fa; padding: 1rem; border-radius: 8px; border-left: 4px solid #1E3A8A; margin-bottom: 1rem;">
                <h4 style="margin-top: 0;">Document Ready for Questions</h4>
                <p>You can now ask questions about the content of your document.</p>
            </div>
            """, unsafe_allow_html=True)

            st.markdown("### Chat with your Document")

            for q, a in st.session_state.document_conversation:
                with st.chat_message("user", avatar="👤"):
                    st.write(q)
                with st.chat_message("assistant", avatar="🤖"):
                    st.write(a)

            question = st.chat_input("Ask a question about your document...")

            if question:
                with st.chat_message("user", avatar="👤"):
                    st.write(question)

                with st.chat_message("assistant", avatar="🤖"):
                    with st.spinner("Analyzing document..."):
                        additional_context = ""
                        result = chain.invoke({
                            "input": question,
                            "additional_context": additional_context
                        })
                        answer = result['answer']
                        st.write(answer)

                st.session_state.document_conversation.append((question, answer))

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")

    elif not (groq_key and gemini_key and pinecone_key):
        st.warning("⚠️ Please make sure all API keys are properly configured in your environment variables.")


def display_web_search():
    st.markdown("""
    <div style="display: flex; align-items: center; margin-bottom: 1rem;">
        <img src="https://img.icons8.com/fluency/48/000000/internet.png" style="margin-right: 1rem;">
        <div>
            <h1 style="margin: 0;">Web Search</h1>
            <p style="margin: 0; color: #6c757d;">Search the web for answers to your questions</p>
        </div>
    </div>
    """, unsafe_allow_html=True)

    display_search_history("web_search_history", "web_query")

    st.markdown("""
    <div style="background-color: white; padding: 1.5rem; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); margin-bottom: 2rem;">
        <h3 style="margin-top: 0; margin-bottom: 1rem;">Web Research</h3>
    </div>
    """, unsafe_allow_html=True)

    with st.form("web_query_form"):
        query = st.text_area(
            "Enter your research question",
            key="web_query",
            height=100,
            placeholder="E.g., What are the latest developments in quantum computing?"
        )

        with st.expander("Advanced Options", expanded=False):
            context = st.text_area(
                "Additional context (optional)",
                height=100,
                placeholder="Add any additional context that might help with the research"
            )

        submit_col1, submit_col2 = st.columns([1, 4])
        with submit_col1:
            submit_button = st.form_submit_button("🔍 Research")
        with submit_col2:
            st.write("")

    if submit_button and query:
        if query not in st.session_state.web_search_history:
            st.session_state.web_search_history.insert(0, query)
            if len(st.session_state.web_search_history) > 10:
                st.session_state.web_search_history.pop()

        result = process_query(query, context)

        if result["needs_search"]:
            st.markdown("""
            <div style="background-color: #f0f8ff; padding: 1rem; border-radius: 8px; border-left: 4px solid #4CAF50; margin-bottom: 1rem;">
                <h4 style="margin-top: 0; color: #4CAF50;">✅ Research Complete</h4>
                <p>Web search completed successfully. Results are shown below.</p>
            </div>
            """, unsafe_allow_html=True)

            search_tab, summary_tab = st.tabs(["🔍 Search Details", "📝 Summary"])

            with search_tab:
                st.subheader("Search Terms Used")
                st.info(result["search_terms"])

                st.subheader("Raw Web Content")
                st.text_area("Web Content", result["web_content"], height=200)

            with summary_tab:
                st.subheader("Summary of Findings")
                st.markdown(result["summary"])
        else:
            st.info("Based on the analysis, no web search was needed for this query.")


if __name__ == "__main__":
    main()