from pathlib import Path
from typing import List, Optional, Any
import logging
from enum import Enum
from dataclasses import dataclass

import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
import PyPDF2
from huggingface_hub import InferenceClient
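
# Required packages (package names assumed from the imports above):
#   pip install gradio langchain langchain-community chromadb PyPDF2 huggingface_hub sentence-transformers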

# Embedding model used to vectorize document chunks; embeddings are normalized
# so that similarity search behaves like cosine similarity
embed_model = HuggingFaceBgeEmbeddings(
    model_name="all-MiniLM-L6-v2",  # alternative: "dunzhang/stella_en_1.5B_v5"
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# LLM served through the Hugging Face Inference API
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # alternatives: "google/gemma-2-2b-it", "prithivMLmods/Llama-3.2-3B-GGUF"
client = InferenceClient(model_name)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DocumentFormat(Enum):
    """Supported document formats; can be extended for other types."""
    PDF = ".pdf"


@dataclass
class RAGConfig:
    """Configuration for RAG system parameters"""
    chunk_size: int = 100
    chunk_overlap: int = 10
    retriever_k: int = 3
    persist_directory: str = "./chroma_db"
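
# Example (hypothetical values): override the defaults before building the system
#   config = RAGConfig(chunk_size=1000, chunk_overlap=100, retriever_k=5)
#   rag_system = AdvancedRAGSystem(embed_model, client, config=config)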


class AdvancedRAGSystem:
    """Advanced RAG system with improved error handling and type safety"""

    def __init__(
        self,
        embed_model: Embeddings,
        llm: InferenceClient,
        config: Optional[RAGConfig] = None,
    ):
        """Initialize the RAG system with required models and optional configuration"""
        self.embed_model = embed_model
        self.llm = llm
        self.config = config or RAGConfig()
        self.vector_store: Optional[Chroma] = None
        self.context: Optional[str] = None  # context used for the most recent query
        self.source_documents: int = 0      # number of documents retrieved for it

    def _validate_file(self, file_path: Path) -> bool:
        """Check that the file exists and is a supported format"""
        return file_path.suffix.lower() == DocumentFormat.PDF.value and file_path.exists()

    def _extract_text_from_pdf(self, pdf_path: Path) -> str:
        """Extract text from a PDF file with proper error handling"""
        try:
            with open(pdf_path, "rb") as file:
                pdf_reader = PyPDF2.PdfReader(file)
                return "\n".join(
                    page.extract_text()
                    for page in pdf_reader.pages
                )
        except Exception as e:
            logger.error(f"Error processing PDF {pdf_path}: {e}")
            raise ValueError(f"Failed to process PDF {pdf_path}: {e}") from e

    def _create_document_chunks(self, texts: List[str]) -> List[Any]:
        """Split raw document texts into chunks using the configured parameters"""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            length_function=len,
            add_start_index=True,
        )
        return text_splitter.create_documents(texts)
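
    # With the defaults (chunk_size=100, chunk_overlap=10), a 1,000-character
    # text yields roughly a dozen overlapping chunks; the exact count depends
    # on where the splitter finds natural separators.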

    def process_pdfs(self, pdf_files: List[Any]) -> str:
        """Process and index PDF documents with improved error handling"""
        try:
            # Convert uploaded file objects (which expose .name) to Paths and validate
            pdf_paths = [Path(pdf.name) for pdf in pdf_files]
            invalid_files = [f for f in pdf_paths if not self._validate_file(f)]
            if invalid_files:
                raise ValueError(f"Invalid or missing files: {invalid_files}")

            # Extract text from the validated PDFs
            documents = [
                self._extract_text_from_pdf(pdf_path)
                for pdf_path in pdf_paths
            ]

            # Split into chunks and (re)build the vector store
            doc_chunks = self._create_document_chunks(documents)
            self.vector_store = Chroma.from_documents(
                documents=doc_chunks,
                embedding=self.embed_model,
                persist_directory=self.config.persist_directory,
            )

            status = f"Successfully processed {len(doc_chunks)} chunks from {len(pdf_files)} PDF files"
            logger.info(status)
            return status
        except Exception as e:
            error_msg = f"Error during PDF processing: {e}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def get_retriever(self):
        """Get a document retriever over the current vector store"""
        if not self.vector_store:
            raise RuntimeError("Vector store not initialized. Please process documents first.")
        return self.vector_store.as_retriever(search_kwargs={"k": self.config.retriever_k})

    def _format_context(self, documents: List[Any]) -> str:
        """Format retrieved documents into a single context string"""
        return "\n\n".join(doc.page_content for doc in documents)

    def query(self, question: str) -> str:
        """Query the RAG system with improved error handling and response formatting"""
        try:
            if not self.vector_store:
                raise RuntimeError("Please process PDF documents first before querying")

            # Retrieve the most relevant chunks and record them for display
            retriever = self.get_retriever()
            retrieved_docs = retriever.get_relevant_documents(question)
            context = self._format_context(retrieved_docs)
            self.context = context
            self.source_documents = len(retrieved_docs)

            messages = [
                {
                    "role": "system",
                    "content": f"""You are a helpful assistant. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context:
{context}
""",
                },
                {
                    "role": "user",
                    "content": question,
                },
            ]

            # InferenceClient exposes an OpenAI-compatible chat completion API
            response = self.llm.chat.completions.create(
                model=model_name,
                messages=messages,
                max_tokens=500,
                # stream=True
            )
            return response.choices[0].message.content
        except Exception as e:
            error_msg = f"Error during query processing: {e}"
            logger.error(error_msg)
            return error_msg


def create_gradio_interface(rag_system: AdvancedRAGSystem) -> gr.Blocks:
    """Create an improved Gradio interface for the RAG system"""

    def process_files(files: List[Any], chunk_size: int, overlap: int) -> str:
        """Process uploaded files with the updated chunking configuration"""
        if not files:
            return "Please upload PDF files"

        # Apply the slider values before indexing
        rag_system.config.chunk_size = chunk_size
        rag_system.config.chunk_overlap = overlap
        try:
            return rag_system.process_pdfs(files)
        except Exception as e:
            return f"Error: {str(e)}"

    def answer_question(question: str) -> str:
        """Run a query and return the generated answer"""
        try:
            return rag_system.query(question)
        except Exception as e:
            return f"Error: {str(e)}"

    def update_history(question: str) -> str:
        """Show the context that was retrieved for the last query"""
        try:
            return f"Last context used ({rag_system.source_documents} documents):\n\n{rag_system.context}"
        except Exception as e:
            return f"Error retrieving context: {str(e)}"

    with gr.Blocks(title="Advanced RAG System") as demo:
        gr.Markdown("# Advanced RAG System with PDF Processing")

        with gr.Tab("Upload & Process PDFs"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.File(
                        file_count="multiple",
                        label="Upload PDF Documents",
                        file_types=[".pdf"]
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=10000,
                        value=100,
                        step=100,
                        label="Chunk Size"
                    )
                    overlap = gr.Slider(
                        minimum=10,
                        maximum=5000,
                        value=10,
                        step=10,
                        label="Chunk Overlap"
                    )
                    process_button = gr.Button("Process PDFs", variant="primary")
                    process_output = gr.Textbox(label="Processing Status")

        with gr.Tab("Query System"):
            with gr.Row():
                with gr.Column(scale=2):
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="Enter your question here...",
                        lines=3
                    )
                    query_button = gr.Button("Get Answer", variant="primary")
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=10
                    )
                with gr.Column(scale=1):
                    history_output = gr.Textbox(
                        label="Retrieved Context",
                        lines=15
                    )

        # Set up event handlers: process PDFs, then answer queries and show
        # the retrieved context alongside each answer
        process_button.click(
            fn=process_files,
            inputs=[file_input, chunk_size, overlap],
            outputs=[process_output]
        )
        query_button.click(
            fn=answer_question,
            inputs=[question_input],
            outputs=[answer_output],
        ).then(
            fn=update_history,
            inputs=[question_input],
            outputs=[history_output]
        )

    return demo


rag_system = AdvancedRAGSystem(embed_model, client)
demo = create_gradio_interface(rag_system)

if __name__ == "__main__":
    demo.launch()
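
# Headless usage sketch (hypothetical file path; process_pdfs only needs objects
# that expose a .name attribute holding the path, as Gradio uploads do):
#   from types import SimpleNamespace
#   rag_system.process_pdfs([SimpleNamespace(name="docs/report.pdf")])
#   print(rag_system.query("What does the report conclude?"))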