Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import re | |
import shutil | |
import time | |
import fitz | |
import streamlit as st | |
import nltk | |
import tempfile | |
import subprocess | |
# Pin NLTK to version 3.9.1.
# Streamlit re-executes this script on every interaction, so only shell out to
# pip when the installed version actually differs. NOTE: nltk is already
# imported at this point, so a version installed here is only picked up after
# the process restarts.
REQUIRED_NLTK_VERSION = "3.9.1"
if nltk.__version__ != REQUIRED_NLTK_VERSION:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"]
    )

# Set up a writable temporary directory for NLTK resources. The membership
# check keeps repeated script runs in the same process from growing
# nltk.data.path with duplicate entries.
nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Ensure the 'punkt_tab' tokenizer tables are available; download them only
# when they cannot already be found on nltk.data.path.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    try:
        print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
        # raise_on_error=True: by default nltk.download signals failure by
        # returning False rather than raising, which the handler would miss.
        nltk.download("punkt_tab", download_dir=nltk_data_path, raise_on_error=True)
    except Exception as e:
        print(f"Error downloading NLTK 'punkt_tab': {e}")
        raise
# Make modules in the working directory (e.g. patent_downloader) importable.
# Guarded so repeated executions of this script in one process do not keep
# appending the same entry to sys.path.
_project_root = os.path.abspath(".")
if _project_root not in sys.path:
    sys.path.append(_project_root)
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.llms import OpenAI | |
from langchain.document_loaders import UnstructuredPDFLoader | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from patent_downloader import PatentDownloader | |
# Fresh scratch directory for the Chroma index; mkdtemp creates a new one each
# time this module-level code executes.
PERSISTED_DIRECTORY = tempfile.mkdtemp()

# The OpenAI key is read from the environment only. Halt the app when it is
# missing or empty.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY in (None, ""):
    st.error(
        "Critical Error: OpenAI API key not found in the environment variables. Please configure it."
    )
    st.stop()
def check_poppler_installed():
    """Fail fast when Poppler's ``pdfinfo`` executable is not on PATH.

    Raises:
        EnvironmentError: if ``shutil.which("pdfinfo")`` finds nothing.
    """
    if shutil.which("pdfinfo") is not None:
        return
    raise EnvironmentError(
        "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
    )


check_poppler_installed()
def extract_patent_number(url):
    """Return the publication number embedded in a Google Patents URL.

    Matches ``/patent/`` followed by two uppercase letters and digits, e.g.
    ``.../patent/US8676427B1/en`` -> ``"US8676427"``. Any trailing kind code
    (such as ``B1``) is not captured.

    Returns:
        The matched number as a string, or ``None`` if the URL does not match.
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)
def download_pdf(patent_number):
    """Download the PDF for ``patent_number`` into the system temp directory.

    Returns:
        Element 0 of whatever ``PatentDownloader.download`` returns
        (presumably a list of saved file paths; verify against patent_downloader).

    On any exception, shows an error on the page and halts the run via ``st.stop()``.
    """
    try:
        downloader = PatentDownloader(verbose=True)
        saved = downloader.download(
            patents=patent_number,
            output_path=tempfile.gettempdir(),
        )
        return saved[0]
    except Exception as exc:
        st.error(f"Failed to download patent PDF: {exc}")
        st.stop()
# Line prefixes that mark patent front-page / header metadata rather than body text.
_METADATA_PREFIX_RE = re.compile(
    r"^(U\.S\.|United States|Sheet|Figure|References|Patent No|Date of Patent)"
)
# A line consisting only of a number, optionally parenthesised (page numbers).
_PAGE_NUMBER_RE = re.compile(r"^\(?\d+\)?$")
# Substrings that mark examiner/attorney bibliographic lines.
_METADATA_KEYWORDS = ("Examiner", "Attorney")


def clean_extracted_text(text, min_line_length=30):
    """Remove metadata, headers, page numbers, and short fragments from text.

    Each line is whitespace-stripped and dropped if it starts with a known
    metadata prefix, is a bare (optionally parenthesised) number, contains
    "Examiner" or "Attorney", or is shorter than ``min_line_length``.

    Args:
        text: Raw text, lines separated by ``"\\n"``.
        min_line_length: Lines shorter than this (after stripping) are dropped.
            Defaults to 30, the original hard-coded threshold.

    Returns:
        The surviving stripped lines joined with ``"\\n"`` ("" if none survive).
    """
    def _is_noise(line):
        return (
            _METADATA_PREFIX_RE.match(line) is not None
            or _PAGE_NUMBER_RE.match(line) is not None
            or any(keyword in line for keyword in _METADATA_KEYWORDS)
            or len(line) < min_line_length
        )

    stripped = (line.strip() for line in text.split("\n"))
    return "\n".join(line for line in stripped if not _is_noise(line))
def _extract_clean_text(document_path):
    """Return the cleaned text of every non-empty page of the PDF, joined by newlines."""
    # The context manager closes the document even if page extraction raises.
    with fitz.open(document_path) as pdf:
        pages = (clean_extracted_text(page.get_text("text")) for page in pdf)
        return "\n".join(text for text in pages if text)


def load_docs(document_path, chunk_size=1000, chunk_overlap=100):
    """Load and clean the PDF content, then split it into chunks.

    Writes the cleaned-text length, the chunk count, and a 300-character
    preview of the first three chunks to the page.

    Args:
        document_path: Path to the PDF file.
        chunk_size: Maximum characters per chunk (default 1000).
        chunk_overlap: Characters shared between adjacent chunks (default 100).

    Returns:
        The list of documents produced by ``RecursiveCharacterTextSplitter``.
        On any exception, shows an error and halts the run via ``st.stop()``.
    """
    try:
        # Step 1: extract and clean plain text from every page.
        full_text = _extract_clean_text(document_path)
        st.write(f"Total Cleaned Text Length: {len(full_text)} characters")

        # Step 2: chunk the cleaned text.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""],
        )
        split_docs = text_splitter.create_documents([full_text])
        st.write(f"Total Chunks After Splitting: {len(split_docs)}")
        for i, chunk in enumerate(split_docs[:3]):  # Show first 3 chunks only
            st.write(f"Chunk {i + 1}: {chunk.page_content[:300]}...")
        return split_docs
    except Exception as e:
        st.error(f"Failed to load and process PDF: {e}")
        st.stop()
def initialize_vector_store(documents, persist_dir):
    """Embed ``documents`` into a Chroma store rooted at ``persist_dir``.

    Uses default-configured ``HuggingFaceEmbeddings`` and calls ``persist()``
    on the resulting store before returning it. On any exception, shows an
    error and halts the run via ``st.stop()``.
    """
    try:
        embedder = HuggingFaceEmbeddings()
        store = Chroma.from_documents(
            documents=documents,
            embedding=embedder,
            persist_directory=persist_dir,
        )
        store.persist()
        return store
    except Exception as exc:
        st.error(f"Failed to initialize the vector store: {exc}")
        st.stop()
def create_retriever(vectordb):
    """Return a retriever over ``vectordb`` that fetches the top 3 matches per query."""
    top_k = {"k": 3}
    return vectordb.as_retriever(search_kwargs=top_k)
def create_retrieval_chain(vectordb, api_key):
    """Build a ``ConversationalRetrievalChain`` over ``vectordb`` with chat memory.

    The LLM is langchain's ``OpenAI`` wrapper at temperature 0. History is kept
    in a ``ConversationBufferMemory`` under the ``"chat_history"`` key, stored
    as message objects (``return_messages=True``).
    """
    doc_retriever = create_retriever(vectordb)
    history = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    language_model = OpenAI(temperature=0, openai_api_key=api_key)
    return ConversationalRetrievalChain.from_llm(
        llm=language_model,
        retriever=doc_retriever,
        memory=history,
    )
def setup_retrieval_pipeline(file_path, persist_dir, api_key):
    """Chunk ``file_path``, index it in a vector store, and return a retrieval chain.

    Returns ``None``, after showing an error, when chunking yields no documents.
    """
    st.write(f"Processing file: {file_path}")

    chunks = load_docs(file_path)
    if not chunks:
        st.error("Failed to process documents. Please check the input file.")
        return None

    store = initialize_vector_store(chunks, persist_dir)
    return create_retrieval_chain(store, api_key)
# extract_patent_number is defined once, earlier in this module (just before
# download_pdf). A byte-identical duplicate definition that silently shadowed
# it was removed from this spot.
def preview_pdf(pdf_path, scale_factor=0.5):
    """Render the first page of a PDF to a scaled PNG and return its path.

    Args:
        pdf_path (str): Path to the PDF file.
        scale_factor (float): Zoom applied on both axes (default 0.5, half size).

    Returns:
        str | None: Path to the PNG preview, or None if rendering failed. In
        that case an error is shown on the page.
    """
    try:
        # Name the image after the PDF. A single fixed file name in the shared
        # temp directory would be overwritten whenever another run or session
        # previews a different patent.
        stem = os.path.splitext(os.path.basename(pdf_path))[0]
        temp_image_path = os.path.join(tempfile.gettempdir(), f"{stem}_preview.png")
        # The context manager closes the document even if rendering raises
        # (e.g. a PDF with no pages makes doc[0] fail).
        with fitz.open(pdf_path) as doc:
            matrix = fitz.Matrix(scale_factor, scale_factor)
            pix = doc[0].get_pixmap(matrix=matrix)
            pix.save(temp_image_path)
        return temp_image_path
    except Exception as e:
        st.error(f"Error generating PDF preview: {e}")
        return None
def _init_session_state():
    """Ensure every session-state key this app reads exists (default None)."""
    for key in ["LOADED_PATENT", "pdf_preview", "loaded_pdf_path", "chain", "messages", "loading_complete"]:
        if key not in st.session_state:
            st.session_state[key] = None


def _load_patent(patent_link):
    """Download, preview, and index the patent behind ``patent_link``.

    On success, stores the chain, patent number, PDF path, and a greeting in
    session state and sets ``loading_complete`` to True. Halts the run via
    ``st.stop()`` on invalid input or any failure.
    """
    patent_link = patent_link.strip()
    if not patent_link:
        st.warning("Please enter a valid Google patent link.")
        st.stop()

    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format.")
        st.stop()
    st.write(f"Patent number: **{patent_number}**")

    # Reuse a previously downloaded PDF for this patent when one exists.
    pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
    if os.path.isfile(pdf_path):
        st.write("\u2705 File already downloaded.")
    else:
        with st.spinner(" Downloading patent file..."):
            try:
                pdf_path = download_pdf(patent_number)
                st.write(f"\u2705 File downloaded: {pdf_path}")
            except Exception as e:
                st.error(f"Failed to download patent: {e}")
                st.stop()

    # Render the preview on every load. The image is drawn only during this
    # run, so a once-per-session guard would leave later loads without one.
    with st.spinner("Generating PDF preview..."):
        preview_image_path = preview_pdf(pdf_path, scale_factor=0.5)
        if preview_image_path:
            st.session_state.pdf_preview = preview_image_path
            st.image(preview_image_path, caption="First Page Preview", use_container_width=False)
        else:
            st.warning("Failed to generate PDF preview.")
            st.session_state.pdf_preview = None

    # Load the document into the system.
    st.session_state["loading_complete"] = False
    with st.spinner("Loading document into the system..."):
        try:
            chain = setup_retrieval_pipeline(pdf_path, PERSISTED_DIRECTORY, OPENAI_API_KEY)
        except Exception as e:
            st.error(f"Failed to load the document: {e}")
            st.stop()

    # setup_retrieval_pipeline returns None (having already shown an error)
    # when the PDF yields no chunks. Do not report success in that case.
    if chain is None:
        st.session_state.chain = None
        st.stop()

    st.session_state.chain = chain
    st.session_state.LOADED_PATENT = patent_number
    st.session_state.loaded_pdf_path = pdf_path
    st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
    st.session_state["loading_complete"] = True


def _render_chat():
    """Replay the stored conversation, then answer a new question if a chain is loaded."""
    for message in st.session_state.messages or []:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if not st.session_state.chain:
        return
    user_input = st.chat_input("What is your question?")
    if not user_input:
        return

    if st.session_state.messages is None:
        st.session_state.messages = []
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        with st.spinner("Generating response..."):
            try:
                # Ask the conversational chain; it tracks history in its own memory.
                assistant_response = st.session_state.chain({"question": user_input})
                full_response = assistant_response.get("answer", "I'm sorry, I couldn't process that question.")
            except Exception as e:
                full_response = f"An error occurred: {e}"
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})


if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header(" Patent Chat: Google Patents Chat Demo")

    # Input for Google Patent Link
    patent_link = st.text_area(
        "Enter Google Patent Link:",
        value="https://patents.google.com/patent/US8676427B1/en",
        height=90,
    )

    _init_session_state()

    if st.button("Load and Process Patent"):
        _load_patent(patent_link)

    if st.session_state["loading_complete"]:
        st.success("Document successfully loaded! You can now start asking questions.")
        _render_chat()
    else:
        st.info("Press the 'Load and Process Patent' button to start processing.")