quick-spin / app.py
DrishtiSharma's picture
Update app.py
5c59c17 verified
raw
history blame
8.89 kB
import sys
import os
import re
import shutil
import time
import streamlit as st
import nltk
import tempfile
import subprocess
import base64 # For embedding PDF content
# Pin NLTK to version 3.9.1.
REQUIRED_NLTK_VERSION = "3.9.1"
# `nltk` is already imported above, so a pip install here cannot change the
# module this process is using; it only affects later interpreter starts.
# Skip the (slow) pip subprocess entirely when the running version already
# matches. The pip return code is intentionally not checked (best effort).
if nltk.__version__ != REQUIRED_NLTK_VERSION:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"]
    )

# Set up a writable temporary directory for NLTK resources and add it to
# NLTK's data search path.
nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
nltk.data.path.append(nltk_data_path)

# Download 'punkt_tab' for compatibility.
try:
    print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
    # By default nltk.download reports failures by printing and returning
    # False; raise_on_error=True makes a failed download actually reach the
    # except clause below instead of silently continuing.
    nltk.download("punkt_tab", download_dir=nltk_data_path, raise_on_error=True)
except Exception as e:
    print(f"Error downloading NLTK 'punkt_tab': {e}")
    raise

# Make modules in the current working directory (e.g. patent_downloader)
# importable.
sys.path.append(os.path.abspath("."))
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import NLTKTextSplitter
from patent_downloader import PatentDownloader
# Streamlit re-executes this whole script on every user interaction, so a bare
# module-level tempfile.mkdtemp() would create (and leak) a fresh, empty Chroma
# directory on each rerun. Create the directory once per browser session and
# reuse it on subsequent reruns.
if "PERSISTED_DIRECTORY" not in st.session_state:
    st.session_state["PERSISTED_DIRECTORY"] = tempfile.mkdtemp()
PERSISTED_DIRECTORY = st.session_state["PERSISTED_DIRECTORY"]

# Fetch API key securely from the environment; without it the app cannot
# talk to OpenAI, so halt this script run with a visible error.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
def check_poppler_installed():
    """Ensure Poppler's ``pdfinfo`` executable can be found on PATH.

    Returns:
        None when ``pdfinfo`` is available.

    Raises:
        EnvironmentError: if ``pdfinfo`` cannot be located, with a hint to
            install ``poppler-utils``.
    """
    if shutil.which("pdfinfo"):
        return
    raise EnvironmentError(
        "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
    )
# Fail fast on every script run if Poppler's `pdfinfo` binary is not on PATH
# (check_poppler_installed raises EnvironmentError in that case).
check_poppler_installed()
def load_docs(document_path):
    """Load a PDF as unstructured elements and split it into text chunks.

    The PDF is read with ``UnstructuredPDFLoader`` (``mode="elements"``,
    ``strategy="fast"``, no OCR languages), then split with LangChain's
    ``NLTKTextSplitter`` using ``chunk_size=1000``. Each chunk's metadata is
    reduced to entries whose values are ``str``, ``int``, ``float`` or
    ``bool``, so that only scalar values reach the vector store.

    Args:
        document_path: Path of the PDF file to load.

    Returns:
        The list of split LangChain documents.

    On any error, shows a Streamlit error message and stops the script run.
    """
    scalar_types = (str, int, float, bool)
    try:
        elements = UnstructuredPDFLoader(
            document_path,
            mode="elements",
            strategy="fast",
            ocr_languages=None,
        ).load()
        chunks = NLTKTextSplitter(chunk_size=1000).split_documents(elements)

        # Drop non-scalar metadata values (e.g. lists or None).
        for chunk in chunks:
            meta = getattr(chunk, "metadata", None)
            if not isinstance(meta, dict):
                continue
            chunk.metadata = {
                key: value
                for key, value in meta.items()
                if isinstance(value, scalar_types)
            }
        return chunks
    except Exception as e:
        st.error(f"Failed to load and process PDF: {e}")
        st.stop()
def already_indexed(vectordb, file_name):
    """Return True if *vectordb* holds any chunk whose ``source`` is *file_name*.

    Args:
        vectordb: A vector store exposing ``get(include=[...])`` that returns a
            dict with a ``"metadatas"`` list (as LangChain's Chroma wrapper does).
        file_name: The source path to look for.

    Returns:
        True when at least one stored metadata dict has ``source == file_name``.

    Entries whose metadata is ``None`` (documents stored without metadata) or
    lacks a ``source`` key are ignored instead of raising.
    """
    metadatas = vectordb.get(include=["metadatas"]).get("metadatas") or []
    return any(
        isinstance(metadata, dict) and metadata.get("source") == file_name
        for metadata in metadatas
    )
def load_chain(file_name=None):
    """Build a conversational retrieval chain over the PDF at *file_name*.

    The PDF is indexed into the Chroma store under ``PERSISTED_DIRECTORY``
    unless that store already contains chunks whose ``source`` metadata equals
    *file_name*. Before re-indexing, the existing collection is deleted, so the
    store holds only the most recently indexed PDF.

    Args:
        file_name: Path of the PDF to chat about.

    Returns:
        A ``ConversationalRetrievalChain`` backed by an OpenAI LLM
        (temperature 0), a retriever returning the top 3 chunks, and a fresh
        ``ConversationBufferMemory``.

    Side effects:
        Writes progress to the Streamlit page and, after indexing, records
        *file_name* in ``st.session_state["LOADED_PATENT"]``. Stops the script
        run if the PDF yields no chunks.
    """
    # A single embeddings instance serves both the lookup store and the
    # rebuilt one; constructing it twice would load the model twice.
    embeddings = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )

    # Decide from the store's actual contents rather than the session flag
    # LOADED_PATENT alone: the flag can outlive the data (e.g. when a later
    # run points at a different, empty directory), which would otherwise
    # build a chain over an empty store.
    if already_indexed(vectordb, file_name):
        st.write("✅ Already indexed.")
    else:
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("🔍 Number of Documents: ", len(docs))
        if not docs:
            # Nothing to embed; building the store from an empty list fails.
            st.error("No text could be extracted from the PDF.")
            st.stop()
        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
        # Still recorded for any code that inspects it.
        st.session_state["LOADED_PATENT"] = file_name

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
def extract_patent_number(url):
    """Pull the patent number out of a Google Patents URL.

    Matches ``/patent/`` followed by a two-letter uppercase country code and
    digits; any trailing kind code (e.g. ``B1``) is not included.

    Args:
        url: A URL such as ``https://patents.google.com/patent/US8676427B1/en``.

    Returns:
        The matched number (e.g. ``"US8676427"``), or None if nothing matches.
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)
def download_pdf(patent_number):
    """Download the PDF for *patent_number* into the system temp directory.

    Args:
        patent_number: Patent identifier passed to ``PatentDownloader``.

    Returns:
        The first path in the list returned by ``PatentDownloader.download``.

    On any failure (including an empty result), shows a Streamlit error and
    stops the script run.
    """
    try:
        saved_paths = PatentDownloader(verbose=True).download(
            patents=patent_number, output_path=tempfile.gettempdir()
        )
        first_path = saved_paths[0]
    except Exception as e:
        st.error(f"Failed to download patent PDF: {e}")
        st.stop()
    else:
        return first_path
def embed_pdf(file_path, width=700, height=1000):
    """Return an HTML ``<iframe>`` snippet that displays a PDF inline.

    The file's bytes are base64-encoded into a ``data:application/pdf`` URI,
    so the snippet is self-contained.

    Args:
        file_path: Path to the PDF on disk.
        width: iframe width in pixels (default 700, as before).
        height: iframe height in pixels (default 1000, as before).

    Returns:
        The HTML snippet as a string.
    """
    with open(file_path, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    return (
        "\n"
        f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
        f'width="{width}" height="{height}" style="border: none;"></iframe>'
        "\n"
    )
def _fresh_chat_history():
    """Return a new message list seeded with the assistant's greeting."""
    return [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]


def _process_patent_link(patent_link):
    """Validate *patent_link*, fetch and preview its PDF, and build the chat chain.

    Calls ``st.stop()`` (ending the current script run) when the link is empty
    or no patent number can be parsed from it.
    """
    if not patent_link:
        st.warning("Please enter a Google patent link to proceed.")
        st.stop()

    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format. Please provide a valid Google patent link.")
        st.stop()
    st.write(f"Patent number: **{patent_number}**")

    # Reuse a copy already sitting in the temp directory when present.
    pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
    if os.path.isfile(pdf_path):
        st.write("✅ File already downloaded.")
    else:
        st.write("📥 Downloading patent file...")
        pdf_path = download_pdf(patent_number)
        st.write(f"✅ File downloaded: {pdf_path}")

    # Offer the file for download and show it inline.
    st.write("📄 Preview of the downloaded patent PDF:")
    if os.path.isfile(pdf_path):
        with open(pdf_path, "rb") as pdf_file:
            st.download_button(
                label="Download PDF",
                data=pdf_file,
                file_name=f"{patent_number}.pdf",
                mime="application/pdf",
            )
        st.write("📋 PDF Content:")
        st.components.v1.html(embed_pdf(pdf_path), height=1000)

    st.write("🔄 Loading document into the system...")
    # Keep the chain in session state; rebuild it (and reset the chat) only
    # when none exists yet or a different file was chosen.
    needs_new_chain = (
        "chain" not in st.session_state
        or st.session_state.get("loaded_file") != pdf_path
    )
    if needs_new_chain:
        st.session_state["chain"] = load_chain(pdf_path)
        st.session_state["loaded_file"] = pdf_path
        st.session_state["messages"] = _fresh_chat_history()
    st.success("🚀 Document successfully loaded! You can now start asking questions.")


def _answer(question):
    """Run the stored chain on *question*; return its answer or an error string."""
    try:
        return st.session_state.chain({"question": question})["answer"]
    except Exception as e:
        return f"An error occurred: {e}"


def _render_chat():
    """Replay the stored conversation and, once a chain exists, take a new question."""
    if "messages" not in st.session_state:
        st.session_state["messages"] = _fresh_chat_history()

    for msg in st.session_state["messages"]:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if "chain" not in st.session_state:
        st.info("Press the 'Load and Process Patent' button to start processing.")
        return

    question = st.chat_input("What is your question?")
    if not question:
        return

    st.session_state["messages"].append({"role": "user", "content": question})
    with st.chat_message("user"):
        st.markdown(question)

    with st.chat_message("assistant"):
        # Reserve the slot first so the reply renders where the spinner was.
        placeholder = st.empty()
        with st.spinner("Generating response..."):
            reply = _answer(question)
        placeholder.markdown(reply)
    st.session_state["messages"].append({"role": "assistant", "content": reply})


if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📖 Patent Chat: Google Patents Chat Demo")

    # A ?patent_link=... query parameter pre-fills the input box.
    default_patent_link = st.query_params.get(
        "patent_link", "https://patents.google.com/patent/US8676427B1/en"
    )
    patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)

    if st.button("Load and Process Patent"):
        _process_patent_link(patent_link)

    _render_chat()