# Smart_AAS_v2.0 / unified_document_processor.py
from typing import List, Dict, Union
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer

class CustomEmbeddingFunction:
    """Embedding function backed by a SentenceTransformer model, usable by ChromaDB."""

    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = self.model.encode(input)
        return embeddings.tolist()
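
# Note (illustrative, not part of the original module): ChromaDB calls the embedding
# function with a list of strings and expects a list of float vectors back. As a rough
# sanity check, CustomEmbeddingFunction()(["hello world"]) should return a list with one
# 384-dimensional vector, since all-MiniLM-L6-v2 embeds text into 384 dimensions.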

class UnifiedDocumentProcessor:
    def __init__(self, groq_api_key, collection_name="unified_content"):
        """Initialize the processor with necessary clients"""
        self.groq_client = Groq(api_key=groq_api_key)

        # XML-specific settings
        self.max_elements_per_chunk = 50

        # PDF-specific settings
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        # Initialize NLTK
        self._initialize_nltk()

        # Initialize ChromaDB with a single collection for all document types
        self.chroma_client = chromadb.Client()
        existing_collections = self.chroma_client.list_collections()
        collection_exists = any(col.name == collection_name for col in existing_collections)

        if collection_exists:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Ensure both NLTK resources are available."""
        try:
            nltk.download('punkt')
            try:
                nltk.data.find('tokenizers/punkt_tab')
            except LookupError:
                nltk.download('punkt_tab')
        except Exception as e:
            print(f"Warning: Error downloading NLTK resources: {str(e)}")
            print("Falling back to basic sentence splitting...")

    def _basic_sentence_split(self, text: str) -> List[str]:
        """Fallback method for sentence tokenization"""
        sentences = []
        current = ""
        for char in text:
            current += char
            if char in ['.', '!', '?'] and len(current.strip()) > 0:
                sentences.append(current.strip())
                current = ""
        if current.strip():
            sentences.append(current.strip())
        return sentences

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for pages with no extractable text
                    text += (page.extract_text() or "") + " "
            return text.strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def chunk_text(self, text: str) -> List[str]:
        """Split text into chunks while preserving sentence boundaries"""
        try:
            sentences = sent_tokenize(text)
        except Exception as e:
            print(f"Warning: Using fallback sentence splitting: {str(e)}")
            sentences = self._basic_sentence_split(text)

        chunks = []
        current_chunk = []
        current_size = 0

        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)

            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size

        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks
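
    # Worked example of the chunking above (illustrative document size, not from the
    # original code): with the default pdf_chunk_size=500 and pdf_overlap=50, a
    # 1,200-word PDF produces chunks of at most about 500 words, and each new chunk
    # starts with the last 50 words of the previous one so text near a chunk boundary
    # keeps some surrounding context.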

    def store_in_vector_db(self, text: str, metadata: Dict) -> str:
        """Store content in vector database"""
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.collection.add(
            documents=[text],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def process_file(self, file_path: str) -> Dict:
        """Process any supported file type"""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Process XML file with direct embedding"""
        try:
            tree = ET.parse(xml_file_path)
            root = tree.getroot()

            # Process XML into semantic chunks with context
            chunks = []
            current_path = []

            def process_element(element, context=None):
                if context is None:
                    context = {}

                # Create element description
                current_path.append(element.tag)
                element_info = []

                # Add tag information
                element_info.append(f"Element: {element.tag}")
                element_info.append(f"Path: {'/' + '/'.join(current_path)}")

                # Process namespace if present
                if '}' in element.tag:
                    namespace = element.tag.split('}')[0].strip('{')
                    element_info.append(f"Namespace: {namespace}")

                # Process attributes with improved structure
                if element.attrib:
                    for key, value in element.attrib.items():
                        element_info.append(f"Attribute - {key}: {value}")

                # Process text content
                if element.text and element.text.strip():
                    element_info.append(f"Content: {element.text.strip()}")

                # Create chunk text
                chunk_text = " | ".join(element_info)

                # Store chunk with metadata
                chunks.append({
                    'text': chunk_text,
                    'path': '/' + '/'.join(current_path),
                    'context': context.copy(),
                    'element_type': element.tag
                })

                # Process children
                child_context = context.copy()
                if element.attrib:
                    child_context[element.tag] = element.attrib
                for child in element:
                    process_element(child, child_context)

                current_path.pop()

            # Start processing from root
            process_element(root)
            print(f"Generated {len(chunks)} XML chunks")

            results = []
            for i, chunk in enumerate(chunks):
                try:
                    metadata = {
                        'source_file': os.path.basename(xml_file_path),
                        'content_type': 'xml',
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                        'xml_path': chunk['path'],
                        'element_type': chunk['element_type'],
                        'context': json.dumps(chunk['context']),
                        'timestamp': str(datetime.datetime.now())
                    }

                    # Store directly in vector database
                    doc_id = self.store_in_vector_db(chunk['text'], metadata)
                    results.append({
                        'chunk': i,
                        'success': True,
                        'doc_id': doc_id,
                        'text': chunk['text']
                    })
                except Exception as e:
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })

            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Process PDF file with direct embedding"""
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")

            results = []
            for i, chunk in enumerate(chunks):
                try:
                    metadata = {
                        'source_file': os.path.basename(pdf_file_path),
                        'content_type': 'pdf',
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                        'timestamp': str(datetime.datetime.now()),
                        'chunk_size': len(chunk.split())
                    }

                    # Store directly in vector database
                    doc_id = self.store_in_vector_db(chunk, metadata)
                    results.append({
                        'chunk': i,
                        'success': True,
                        'doc_id': doc_id,
                        'text': chunk[:200] + "..." if len(chunk) > 200 else chunk
                    })
                except Exception as e:
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })

            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_available_files(self) -> Dict[str, List[str]]:
        """Get list of all files in the database"""
        try:
            all_entries = self.collection.get(
                include=['metadatas']
            )

            files = {
                'pdf': set(),
                'xml': set()
            }
            for metadata in all_entries['metadatas']:
                file_type = metadata['content_type']
                file_name = metadata['source_file']
                files[file_type].add(file_name)

            return {
                'pdf': sorted(list(files['pdf'])),
                'xml': sorted(list(files['xml']))
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Ask a question using only the selected files"""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }

            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )

            if not results['documents'][0]:
                return "No relevant content found in the selected files."

            # Format answer based on content type
            formatted_answer = []
            for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
                if meta['content_type'] == 'xml':
                    formatted_answer.append(f"Found in XML path: {meta['xml_path']}\n{doc}")
                else:
                    formatted_answer.append(doc)

            # Create response using the matched content
            prompt = f"""Based on these relevant sections, please answer: {question}

Relevant Content:
{' '.join(formatted_answer)}

Please provide a clear, concise answer based on the above content."""

            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"

    def get_detailed_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
        """Get detailed context including path and metadata information"""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }

            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas", "distances"]
            )

            if not results['documents'][0]:
                return {
                    'success': False,
                    'error': "No relevant content found"
                }

            detailed_results = []
            for doc, meta, distance in zip(results['documents'][0], results['metadatas'][0], results['distances'][0]):
                result_info = {
                    'content': doc,
                    'metadata': meta,
                    'relevance_score': 1 - distance,  # Convert distance to similarity score
                    'source_info': {
                        'file': meta['source_file'],
                        'type': meta['content_type'],
                        'path': meta.get('xml_path', 'N/A'),  # Only for XML files
                        'context': json.loads(meta['context']) if meta.get('context') else {}
                    }
                }
                detailed_results.append(result_info)

            return {
                'success': True,
                'results': detailed_results,
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_hierarchical_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
        """Get hierarchical context for XML files including parent-child relationships"""
        try:
            # Get initial results
            initial_results = self.get_detailed_context(question, selected_files, n_results)
            if not initial_results['success']:
                return initial_results

            hierarchical_results = []
            for result in initial_results['results']:
                if result['metadata']['content_type'] == 'xml':
                    # Get parent elements
                    parent_path = '/'.join(result['source_info']['path'].split('/')[:-1])
                    if parent_path:
                        # Note: depending on the ChromaDB version, combining several metadata
                        # fields in one `where` dict may require an explicit {'$and': [...]}.
                        parent_filter = {
                            'source_file': result['metadata']['source_file'],
                            'xml_path': parent_path
                        }
                        parent_results = self.collection.query(
                            query_texts=[""],  # Empty query to get exact match
                            where=parent_filter,
                            include=["documents", "metadatas"],
                            n_results=1
                        )
                        if parent_results['documents'][0]:
                            result['parent_info'] = {
                                'content': parent_results['documents'][0][0],
                                'metadata': parent_results['metadatas'][0][0]
                            }

                    # Get immediate children
                    child_path_prefix = result['source_info']['path'] + '/'
                    # Note: '$contains' is normally a document-content operator in ChromaDB
                    # (used with `where_document`) rather than a metadata filter, so prefix
                    # matching on xml_path may need to be done client-side instead.
                    child_filter = {
                        'source_file': result['metadata']['source_file'],
                        'xml_path': {'$contains': child_path_prefix}
                    }
                    child_results = self.collection.query(
                        query_texts=[""],  # Empty query to get exact matches
                        where=child_filter,
                        include=["documents", "metadatas"],
                        n_results=5
                    )
                    if child_results['documents'][0]:
                        result['children_info'] = [{
                            'content': doc,
                            'metadata': meta
                        } for doc, meta in zip(child_results['documents'][0], child_results['metadatas'][0])]

                hierarchical_results.append(result)

            return {
                'success': True,
                'results': hierarchical_results,
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_summary_and_details(self, question: str, selected_files: List[str]) -> Dict:
        """Get both a summary answer and detailed supporting information"""
        try:
            # Get hierarchical context first
            detailed_results = self.get_hierarchical_context(question, selected_files)
            if not detailed_results['success']:
                return detailed_results

            # Create summary prompt
            relevant_content = []
            for result in detailed_results['results']:
                if result['metadata']['content_type'] == 'xml':
                    content_info = [
                        f"XML Path: {result['source_info']['path']}",
                        f"Content: {result['content']}"
                    ]
                    if 'parent_info' in result:
                        content_info.append(f"Parent: {result['parent_info']['content']}")
                    if 'children_info' in result:
                        children_content = [child['content'] for child in result['children_info']]
                        content_info.append(f"Related Elements: {', '.join(children_content)}")
                else:
                    content_info = [f"Content: {result['content']}"]
                relevant_content.append('\n'.join(content_info))

            # Join outside the f-string: backslashes are not allowed inside f-string
            # expressions before Python 3.12.
            joined_content = '\n\n'.join(relevant_content)
            summary_prompt = f"""Based on the following content, please provide:
1. A concise answer to the question
2. Key supporting points
3. Related context if relevant

Question: {question}

Content:
{joined_content}
"""

            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": summary_prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )

            return {
                'success': True,
                'summary': response.choices[0].message.content,
                'details': detailed_results['results'],
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }
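

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the processor might be driven end to end. The GROQ_API_KEY
# environment variable and the "example.pdf" path are placeholders assumed for this
# sketch, not files or settings defined by the project.
if __name__ == "__main__":
    processor = UnifiedDocumentProcessor(groq_api_key=os.environ["GROQ_API_KEY"])

    # Ingest a document; the result reports per-chunk success or failure.
    result = processor.process_file("example.pdf")  # placeholder path
    print(f"Processed {result.get('total_chunks', 0)} chunks, success={result['success']}")

    # List what is stored, then query against the selected files only.
    files = processor.get_available_files()
    print("Available files:", files)

    if files['pdf']:
        answer = processor.ask_question_selective(
            question="What is this document about?",
            selected_files=files['pdf'][:1],
        )
        print(answer)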