# Smart_AAS_v2.0 / unified_document_processor.py
from typing import List, Dict
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer
class CustomEmbeddingFunction:
    """Sentence-transformers embedding function compatible with ChromaDB."""
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
    def __call__(self, input: List[str]) -> List[List[float]]:
        # Encode each input string into a dense vector (384 dimensions for
        # all-MiniLM-L6-v2) and return plain Python lists for ChromaDB.
        embeddings = self.model.encode(input)
        return embeddings.tolist()
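# Hedged usage sketch (comment only): the embedding function can be exercised
# on its own, e.g.
#   vectors = CustomEmbeddingFunction()(["hello world"])
# which yields one 384-dimensional list of floats per input string.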
class UnifiedDocumentProcessor:
def __init__(self, groq_api_key, collection_name="unified_content"):
"""Initialize the processor with necessary clients"""
self.groq_client = Groq(api_key=groq_api_key)
# XML-specific settings
self.max_elements_per_chunk = 50
# PDF-specific settings
self.pdf_chunk_size = 500
self.pdf_overlap = 50
# Initialize NLTK
self._initialize_nltk()
# Initialize ChromaDB with a single collection for all document types
        self.chroma_client = chromadb.Client()
        # get_or_create_collection is robust across chromadb versions, where
        # list_collections() may return bare names rather than objects with
        # a .name attribute.
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            embedding_function=CustomEmbeddingFunction()
        )
        print(f"Using collection: {collection_name}")
    def _initialize_nltk(self):
        """Ensure the NLTK sentence-tokenizer resources are available."""
        try:
            # sent_tokenize needs 'punkt'; newer NLTK releases also require
            # the 'punkt_tab' tables. Download each only if missing.
            for resource in ('punkt', 'punkt_tab'):
                try:
                    nltk.data.find(f'tokenizers/{resource}')
                except LookupError:
                    nltk.download(resource, quiet=True)
        except Exception as e:
            print(f"Warning: Error downloading NLTK resources: {str(e)}")
            print("Falling back to basic sentence splitting...")
def _basic_sentence_split(self, text: str) -> List[str]:
"""Fallback method for sentence tokenization"""
sentences = []
current = ""
for char in text:
current += char
if char in ['.', '!', '?'] and len(current.strip()) > 0:
sentences.append(current.strip())
current = ""
if current.strip():
sentences.append(current.strip())
return sentences
    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file"""
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for image-only pages
                    text += (page.extract_text() or "") + " "
            return text.strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e
def chunk_text(self, text: str) -> List[str]:
"""Split text into chunks while preserving sentence boundaries"""
try:
sentences = sent_tokenize(text)
except Exception as e:
print(f"Warning: Using fallback sentence splitting: {str(e)}")
sentences = self._basic_sentence_split(text)
chunks = []
current_chunk = []
current_size = 0
for sentence in sentences:
words = sentence.split()
sentence_size = len(words)
if current_size + sentence_size > self.pdf_chunk_size:
if current_chunk:
chunks.append(' '.join(current_chunk))
overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
current_chunk = overlap_words + words
current_size = len(current_chunk)
else:
current_chunk = words
current_size = sentence_size
else:
current_chunk.extend(words)
current_size += sentence_size
if current_chunk:
chunks.append(' '.join(current_chunk))
return chunks
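    # Illustrative behavior (comment only): with pdf_chunk_size=500 and
    # pdf_overlap=50, each chunk holds roughly 500 words and begins with the
    # last 50 words of the previous chunk, so sentences that straddle a
    # boundary keep their surrounding context.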
def process_xml_file(self, xml_file_path: str) -> Dict:
"""Process XML file with optimized batching and reduced database operations"""
try:
tree = ET.parse(xml_file_path)
root = tree.getroot()
# Process XML into chunks efficiently
chunks = []
paths = []
def process_element(element, current_path=""):
# Create element description
element_info = []
# Add basic information
element_info.append(f"Element: {element.tag}")
# Process namespace only if present
if '}' in element.tag:
namespace = element.tag.split('}')[0].strip('{')
element_info.append(f"Namespace: {namespace}")
# Process important attributes only
important_attrs = ['NodeId', 'BrowseName', 'DisplayName', 'Description', 'DataType']
attrs = {k: v for k, v in element.attrib.items() if k in important_attrs}
if attrs:
for key, value in attrs.items():
element_info.append(f"{key}: {value}")
# Process text content if meaningful
if element.text and element.text.strip():
element_info.append(f"Content: {element.text.strip()}")
# Create chunk text
chunk_text = " | ".join(element_info)
new_path = f"{current_path}/{element.tag}" if current_path else element.tag
chunks.append(chunk_text)
paths.append(new_path)
# Process children
for child in element:
process_element(child, new_path)
# Start processing from root
process_element(root)
print(f"Generated {len(chunks)} XML chunks")
            # Insert chunks into the vector store in batches to limit
            # round-trips to the database
            batch_size = 100
results = []
for i in range(0, len(chunks), batch_size):
batch_end = min(i + batch_size, len(chunks))
batch_chunks = chunks[i:batch_end]
batch_paths = paths[i:batch_end]
# Prepare batch metadata
batch_metadata = [{
'source_file': os.path.basename(xml_file_path),
'content_type': 'xml',
'chunk_id': idx,
'total_chunks': len(chunks),
'xml_path': path,
'timestamp': str(datetime.datetime.now())
} for idx, path in enumerate(batch_paths, start=i)]
# Generate batch IDs
batch_ids = [
f"{os.path.basename(xml_file_path)}_xml_{idx}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
for idx in range(i, batch_end)
]
# Store batch in vector database
self.collection.add(
documents=batch_chunks,
metadatas=batch_metadata,
ids=batch_ids
)
# Track results
results.extend([{
'chunk': idx,
'success': True,
'doc_id': doc_id,
'text': text
} for idx, (doc_id, text) in enumerate(zip(batch_ids, batch_chunks), start=i)])
# Print progress
print(f"Processed chunks {i} to {batch_end} of {len(chunks)}")
return {
'success': True,
'total_chunks': len(chunks),
'results': results
}
except Exception as e:
print(f"Error processing XML: {str(e)}")
return {
'success': False,
'error': str(e)
}
def process_pdf_file(self, pdf_file_path: str) -> Dict:
"""Process PDF file with direct embedding"""
try:
full_text = self.extract_text_from_pdf(pdf_file_path)
chunks = self.chunk_text(full_text)
print(f"Split PDF into {len(chunks)} chunks")
results = []
for i, chunk in enumerate(chunks):
try:
metadata = {
'source_file': os.path.basename(pdf_file_path),
'content_type': 'pdf',
'chunk_id': i,
'total_chunks': len(chunks),
'timestamp': str(datetime.datetime.now()),
'chunk_size': len(chunk.split())
}
# Store directly in vector database
doc_id = self.store_in_vector_db(chunk, metadata)
results.append({
'chunk': i,
'success': True,
'doc_id': doc_id,
'text': chunk[:200] + "..." if len(chunk) > 200 else chunk
})
except Exception as e:
results.append({
'chunk': i,
'success': False,
'error': str(e)
})
return {
'success': True,
'total_chunks': len(chunks),
'results': results
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
def store_in_vector_db(self, text: str, metadata: Dict) -> str:
"""Store content in vector database"""
doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.collection.add(
documents=[text],
metadatas=[metadata],
ids=[doc_id]
)
return doc_id
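    # Example ID with hypothetical values: "manual.pdf_pdf_3_20240101_120000"
    # (source file, content type, chunk index, second-resolution timestamp).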
def get_available_files(self) -> Dict[str, List[str]]:
"""Get list of all files in the database"""
try:
all_entries = self.collection.get(
include=['metadatas']
)
files = {
'pdf': set(),
'xml': set()
}
            for metadata in all_entries['metadatas']:
                file_type = metadata.get('content_type')
                file_name = metadata.get('source_file')
                # Skip entries with missing or unexpected metadata
                if file_type in files and file_name:
                    files[file_type].add(file_name)
return {
'pdf': sorted(list(files['pdf'])),
'xml': sorted(list(files['xml']))
}
except Exception as e:
print(f"Error getting available files: {str(e)}")
return {'pdf': [], 'xml': []}
def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
"""Ask a question using only the selected files"""
try:
filter_dict = {
'source_file': {'$in': selected_files}
}
results = self.collection.query(
query_texts=[question],
n_results=n_results,
where=filter_dict,
include=["documents", "metadatas"]
)
if not results['documents'][0]:
return "No relevant content found in the selected files."
# Format answer based on content type
formatted_answer = []
for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
if meta['content_type'] == 'xml':
formatted_answer.append(f"Found in XML path: {meta['xml_path']}\n{doc}")
else:
formatted_answer.append(doc)
# Create response using the matched content
prompt = f"""Based on these relevant sections, please answer: {question}
Relevant Content:
{' '.join(formatted_answer)}
Please provide a clear, concise answer based on the above content."""
response = self.groq_client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model="llama3-8b-8192",
temperature=0.2
)
return response.choices[0].message.content
except Exception as e:
return f"Error processing your question: {str(e)}"
def get_detailed_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
"""Get detailed context including path and metadata information"""
try:
filter_dict = {
'source_file': {'$in': selected_files}
}
results = self.collection.query(
query_texts=[question],
n_results=n_results,
where=filter_dict,
include=["documents", "metadatas", "distances"]
)
if not results['documents'][0]:
return {
'success': False,
'error': "No relevant content found"
}
detailed_results = []
for doc, meta, distance in zip(results['documents'][0], results['metadatas'][0], results['distances'][0]):
result_info = {
'content': doc,
'metadata': meta,
'similarity_score': round((1 - distance) * 100, 2), # Convert to percentage
'source_info': {
'file': meta['source_file'],
'type': meta['content_type'],
'path': meta.get('xml_path', 'N/A'),
'context': json.loads(meta['context']) if meta.get('context') else {}
}
}
detailed_results.append(result_info)
return {
'success': True,
'results': detailed_results,
'query': question
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
def get_hierarchical_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
"""Get hierarchical context for XML files including parent-child relationships"""
try:
# Get initial results
initial_results = self.get_detailed_context(question, selected_files, n_results)
if not initial_results['success']:
return initial_results
hierarchical_results = []
for result in initial_results['results']:
if result['metadata']['content_type'] == 'xml':
# Get parent elements
parent_path = '/'.join(result['source_info']['path'].split('/')[:-1])
if parent_path:
parent_filter = {
'source_file': {'$eq': result['metadata']['source_file']},
'xml_path': {'$eq': parent_path}
}
parent_results = self.collection.query(
query_texts=[""], # Empty query to get exact match
where=parent_filter,
include=["documents", "metadatas"],
n_results=1
)
if parent_results['documents'][0]:
result['parent_info'] = {
'content': parent_results['documents'][0][0],
'metadata': parent_results['metadatas'][0][0]
}
# Get all potential children
all_filter = {
'source_file': {'$eq': result['metadata']['source_file']}
}
all_results = self.collection.query(
query_texts=[""],
where=all_filter,
include=["documents", "metadatas"],
n_results=100
)
# Manually filter children
children_info = []
current_path = result['source_info']['path']
if all_results['documents'][0]:
for doc, meta in zip(all_results['documents'][0], all_results['metadatas'][0]):
child_path = meta.get('xml_path', '')
if (child_path.startswith(current_path + '/') and
len(child_path.split('/')) == len(current_path.split('/')) + 1):
children_info.append({
'content': doc,
'metadata': meta
})
if children_info:
result['children_info'] = children_info[:5]
hierarchical_results.append(result)
return {
'success': True,
'results': hierarchical_results,
'query': question
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
def get_summary_and_details(self, question: str, selected_files: List[str]) -> Dict:
"""Get both a summary answer and detailed supporting information"""
try:
# Get hierarchical context first
detailed_results = self.get_hierarchical_context(question, selected_files)
if not detailed_results['success']:
return detailed_results
# Create summary prompt
relevant_content = []
for result in detailed_results['results']:
if result['metadata']['content_type'] == 'xml':
content_info = [
f"XML Path: {result['source_info']['path']}",
f"Content: {result['content']}"
]
if 'parent_info' in result:
content_info.append(f"Parent: {result['parent_info']['content']}")
if 'children_info' in result:
children_content = [child['content'] for child in result['children_info']]
content_info.append(f"Related Elements: {', '.join(children_content)}")
else:
content_info = [f"Content: {result['content']}"]
relevant_content.append('\n'.join(content_info))
summary_prompt = (
f"Based on the following content, please provide:\n"
"1. A concise answer to the question\n"
"2. Key supporting points\n"
"3. Related context if relevant\n\n"
f"Question: {question}\n\n"
f"Content:\n{chr(10).join(relevant_content)}"
)
response = self.groq_client.chat.completions.create(
messages=[{"role": "user", "content": summary_prompt}],
model="llama3-8b-8192",
temperature=0.2
)
return {
'success': True,
'summary': response.choices[0].message.content,
'details': detailed_results['results'],
'query': question
}
except Exception as e:
return {
'success': False,
'error': str(e)
}
def process_file(self, file_path: str) -> Dict:
"""Process any supported file type"""
try:
file_extension = os.path.splitext(file_path)[1].lower()
if file_extension == '.xml':
return self.process_xml_file(file_path)
elif file_extension == '.pdf':
return self.process_pdf_file(file_path)
else:
return {
'success': False,
'error': f'Unsupported file type: {file_extension}'
}
except Exception as e:
return {
'success': False,
'error': f'Error processing file: {str(e)}'
}
def calculate_detailed_score(self, distance: float, metadata: Dict, content: str, query: str) -> Dict:
"""
Calculate a detailed, multi-faceted relevance score
Components:
1. Vector Similarity (40%): Base similarity from embeddings
2. Content Match (20%): Direct term matching
3. Structural Relevance (20%): XML structure relevance (for XML files)
4. Context Completeness (10%): Completeness of metadata/context
5. Freshness (10%): How recent the content is
"""
try:
scores = {}
# 1. Vector Similarity Score (40%)
vector_similarity = 1 - distance # Convert distance to similarity
scores['vector_similarity'] = {
'score': vector_similarity,
'weight': 0.4,
'weighted_score': vector_similarity * 0.4
}
# 2. Content Match Score (20%)
content_match_score = self._calculate_content_match(content, query)
scores['content_match'] = {
'score': content_match_score,
'weight': 0.2,
'weighted_score': content_match_score * 0.2
}
# 3. Structural Relevance Score (20%)
if metadata['content_type'] == 'xml':
structural_score = self._calculate_structural_relevance(metadata)
else:
structural_score = 0.5 # Default for non-XML
scores['structural_relevance'] = {
'score': structural_score,
'weight': 0.2,
'weighted_score': structural_score * 0.2
}
# 4. Context Completeness Score (10%)
context_score = self._calculate_context_completeness(metadata)
scores['context_completeness'] = {
'score': context_score,
'weight': 0.1,
'weighted_score': context_score * 0.1
}
# 5. Freshness Score (10%)
freshness_score = self._calculate_freshness(metadata['timestamp'])
scores['freshness'] = {
'score': freshness_score,
'weight': 0.1,
'weighted_score': freshness_score * 0.1
}
# Calculate total score
total_score = sum(s['weighted_score'] for s in scores.values())
return {
'total_score': total_score,
'component_scores': scores,
                'explanation': self._generate_score_explanation(scores, total_score)
}
except Exception as e:
print(f"Error in score calculation: {str(e)}")
return {
'total_score': 0.5,
'error': str(e)
}
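    # Worked example (illustrative numbers): vector similarity 0.90, content
    # match 0.60, structural relevance 0.80, context completeness 0.70 and
    # freshness 1.00 combine to
    #   0.90*0.4 + 0.60*0.2 + 0.80*0.2 + 0.70*0.1 + 1.00*0.1 = 0.81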
def _calculate_content_match(self, content: str, query: str) -> float:
"""Calculate direct term matching score"""
try:
# Tokenize content and query
content_terms = set(content.lower().split())
query_terms = set(query.lower().split())
# Calculate overlap
matching_terms = content_terms.intersection(query_terms)
if not query_terms:
return 0.5
# Calculate scores for exact matches and partial matches
exact_match_score = len(matching_terms) / len(query_terms)
# Check for partial matches
partial_matches = 0
for q_term in query_terms:
for c_term in content_terms:
if q_term in c_term or c_term in q_term:
partial_matches += 0.5
            # A query term can partially match several content terms, so cap
            # the partial score at 1.0
            partial_match_score = min(partial_matches / len(query_terms), 1.0)
# Combine scores (70% exact matches, 30% partial matches)
return (exact_match_score * 0.7) + (partial_match_score * 0.3)
except Exception as e:
print(f"Error in content match calculation: {str(e)}")
return 0.5
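    # Worked example (illustrative): query "temperature sensor" against
    # content "the sensor reads temperature": both query terms match exactly
    # (2/2 = 1.0) and each also earns 0.5 partial credit (1.0/2 = 0.5), so
    # the combined score is 1.0*0.7 + 0.5*0.3 = 0.85.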
def _calculate_structural_relevance(self, metadata: Dict) -> float:
"""Calculate structural relevance score for XML content"""
try:
score = 0.5 # Base score
if 'xml_path' in metadata:
path = metadata['xml_path']
# Score based on path depth (deeper paths might be more specific)
depth = len(path.split('/'))
depth_score = min(depth / 5, 1.0) # Normalize depth score
# Score based on element type
element_type = metadata.get('element_type', '')
type_scores = {
'UAObjectType': 0.9,
'UAVariableType': 0.9,
'UAObject': 0.8,
'UAVariable': 0.8,
'UAMethod': 0.7,
'UAView': 0.6,
'UAReferenceType': 0.7
}
type_score = type_scores.get(element_type, 0.5)
# Score based on context completeness
                context = json.loads(metadata.get('context', '{}'))
                # Cap so a large context dict cannot push the score past 1.0
                context_score = min(len(context) / 10, 1.0) if context else 0.5
# Combine scores
score = (depth_score * 0.3) + (type_score * 0.4) + (context_score * 0.3)
return score
except Exception as e:
print(f"Error in structural relevance calculation: {str(e)}")
return 0.5
def _calculate_context_completeness(self, metadata: Dict) -> float:
"""Calculate context completeness score"""
try:
expected_fields = {
'xml': ['xml_path', 'element_type', 'context', 'chunk_id', 'total_chunks'],
'pdf': ['chunk_id', 'total_chunks', 'chunk_size']
}
content_type = metadata.get('content_type', '')
if content_type not in expected_fields:
return 0.5
# Check for presence of expected fields
expected = expected_fields[content_type]
present_fields = sum(1 for field in expected if field in metadata)
# Calculate base completeness score
completeness = present_fields / len(expected)
# Add bonus for additional useful metadata
bonus = 0
if content_type == 'xml':
context = json.loads(metadata.get('context', '{}'))
if context:
bonus += 0.2
return min(completeness + bonus, 1.0)
except Exception as e:
print(f"Error in context completeness calculation: {str(e)}")
return 0.5
def _calculate_freshness(self, timestamp: str) -> float:
"""Calculate freshness score based on timestamp"""
try:
            # str(datetime.now()) omits microseconds when they are zero, so
            # fromisoformat is more robust than a fixed strptime format
            doc_time = datetime.datetime.fromisoformat(timestamp)
now = datetime.datetime.now()
# Calculate age in hours
age_hours = (now - doc_time).total_seconds() / 3600
# Score decreases with age (24 hours = 1 day)
if age_hours < 24:
return 1.0
elif age_hours < 168: # 1 week
return 0.8
elif age_hours < 720: # 1 month
return 0.6
else:
return 0.4
except Exception as e:
print(f"Error in freshness calculation: {str(e)}")
return 0.5
    def _generate_score_explanation(self, scores: Dict, total_score: float) -> str:
        """Generate a human-readable explanation of the component scores"""
        # Note: the total is passed in separately because `scores` holds only
        # the per-component entries, not a 'total_score' key.
        try:
            explanations = [
                f"Total Score: {total_score:.2f}",
"\nComponent Scores:",
f"• Vector Similarity: {scores['vector_similarity']['score']:.2f} (40% weight)",
f"• Content Match: {scores['content_match']['score']:.2f} (20% weight)",
f"• Structural Relevance: {scores['structural_relevance']['score']:.2f} (20% weight)",
f"• Context Completeness: {scores['context_completeness']['score']:.2f} (10% weight)",
f"• Freshness: {scores['freshness']['score']:.2f} (10% weight)"
]
return "\n".join(explanations)
except Exception as e:
print(f"Error generating score explanation: {str(e)}")
return "Score explanation unavailable"