from typing import List, Dict, Union
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer


class CustomEmbeddingFunction:
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = self.model.encode(input)
        return embeddings.tolist()


class UnifiedDocumentProcessor:
    def __init__(self, groq_api_key, collection_name="unified_content"):
        """Initialize the processor with necessary clients"""
        self.groq_client = Groq(api_key=groq_api_key)

        # XML-specific settings
        self.max_elements_per_chunk = 50

        # PDF-specific settings
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        # Initialize NLTK
        self._initialize_nltk()

        # Initialize ChromaDB with a single collection for all document types
        self.chroma_client = chromadb.Client()
        existing_collections = self.chroma_client.list_collections()
        collection_exists = any(col.name == collection_name for col in existing_collections)
        if collection_exists:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Ensure the NLTK tokenizer resources (punkt and punkt_tab) are available."""
        try:
            nltk.download('punkt')
            try:
                nltk.data.find('tokenizers/punkt_tab')
            except LookupError:
                nltk.download('punkt_tab')
        except Exception as e:
            print(f"Warning: Error downloading NLTK resources: {str(e)}")
            print("Falling back to basic sentence splitting...")

    def _basic_sentence_split(self, text: str) -> List[str]:
        """Fallback method for sentence tokenization"""
        sentences = []
        current = ""
        for char in text:
            current += char
            if char in ['.', '!', '?'] and len(current.strip()) > 0:
                sentences.append(current.strip())
                current = ""
        if current.strip():
            sentences.append(current.strip())
        return sentences

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for pages with no extractable text
                    text += (page.extract_text() or "") + " "
            return text.strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}")

    def chunk_text(self, text: str) -> List[str]:
        """Split text into chunks while preserving sentence boundaries"""
        try:
            sentences = sent_tokenize(text)
        except Exception as e:
            print(f"Warning: Using fallback sentence splitting: {str(e)}")
            sentences = self._basic_sentence_split(text)

        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)
            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Process XML file with optimized batching and reduced database operations"""
        try:
            tree = ET.parse(xml_file_path)
            root = tree.getroot()

            # Process XML into chunks efficiently
            chunks = []
            paths = []

            def process_element(element, current_path=""):
                # Create element description
                element_info = []

                # Add basic information
                element_info.append(f"Element: {element.tag}")

                # Process namespace only if present
                if '}' in element.tag:
                    namespace = element.tag.split('}')[0].strip('{')
                    element_info.append(f"Namespace: {namespace}")

                # Process important attributes only
                important_attrs = ['NodeId', 'BrowseName', 'DisplayName', 'Description', 'DataType']
                attrs = {k: v for k, v in element.attrib.items() if k in important_attrs}
                if attrs:
                    for key, value in attrs.items():
                        element_info.append(f"{key}: {value}")

                # Process text content if meaningful
                if element.text and element.text.strip():
                    element_info.append(f"Content: {element.text.strip()}")

                # Create chunk text
                chunk_text = " | ".join(element_info)
                new_path = f"{current_path}/{element.tag}" if current_path else element.tag
                chunks.append(chunk_text)
                paths.append(new_path)

                # Process children
                for child in element:
                    process_element(child, new_path)

            # Start processing from root
            process_element(root)
            print(f"Generated {len(chunks)} XML chunks")

            # Batch process into database
            batch_size = 100
            results = []
            for i in range(0, len(chunks), batch_size):
                batch_end = min(i + batch_size, len(chunks))
                batch_chunks = chunks[i:batch_end]
                batch_paths = paths[i:batch_end]

                # Prepare batch metadata
                batch_metadata = [{
                    'source_file': os.path.basename(xml_file_path),
                    'content_type': 'xml',
                    'chunk_id': idx,
                    'total_chunks': len(chunks),
                    'xml_path': path,
                    'timestamp': str(datetime.datetime.now())
                } for idx, path in enumerate(batch_paths, start=i)]

                # Generate batch IDs
                batch_ids = [
                    f"{os.path.basename(xml_file_path)}_xml_{idx}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    for idx in range(i, batch_end)
                ]

                # Store batch in vector database
                self.collection.add(
                    documents=batch_chunks,
                    metadatas=batch_metadata,
                    ids=batch_ids
                )

                # Track results
                results.extend([{
                    'chunk': idx,
                    'success': True,
                    'doc_id': doc_id,
                    'text': text
                } for idx, (doc_id, text) in enumerate(zip(batch_ids, batch_chunks), start=i)])

                # Print progress
                print(f"Processed chunks {i} to {batch_end} of {len(chunks)}")

            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            print(f"Error processing XML: {str(e)}")
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Process PDF file with direct embedding"""
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                try:
                    metadata = {
                        'source_file': os.path.basename(pdf_file_path),
                        'content_type': 'pdf',
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                        'timestamp': str(datetime.datetime.now()),
                        'chunk_size': len(chunk.split())
                    }
                    # Store directly in vector database
                    doc_id = self.store_in_vector_db(chunk, metadata)
                    results.append({
                        'chunk': i,
                        'success': True,
                        'doc_id': doc_id,
                        'text': chunk[:200] + "..." if len(chunk) > 200 else chunk
                    })
                except Exception as e:
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def store_in_vector_db(self, text: str, metadata: Dict) -> str:
        """Store content in vector database"""
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.collection.add(
            documents=[text],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def get_available_files(self) -> Dict[str, List[str]]:
        """Get list of all files in the database"""
        try:
            all_entries = self.collection.get(
                include=['metadatas']
            )
            files = {
                'pdf': set(),
                'xml': set()
            }
            for metadata in all_entries['metadatas']:
                file_type = metadata['content_type']
                file_name = metadata['source_file']
                files[file_type].add(file_name)
            return {
                'pdf': sorted(list(files['pdf'])),
                'xml': sorted(list(files['xml']))
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Ask a question using only the selected files"""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )
            if not results['documents'][0]:
                return "No relevant content found in the selected files."

            # Format answer based on content type
            formatted_answer = []
            for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
                if meta['content_type'] == 'xml':
                    formatted_answer.append(f"Found in XML path: {meta['xml_path']}\n{doc}")
                else:
                    formatted_answer.append(doc)

            # Create response using the matched content
            prompt = f"""Based on these relevant sections, please answer: {question}

Relevant Content:
{' '.join(formatted_answer)}

Please provide a clear, concise answer based on the above content."""
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"

    def get_detailed_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
        """Get detailed context including path and metadata information"""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas", "distances"]
            )
            if not results['documents'][0]:
                return {
                    'success': False,
                    'error': "No relevant content found"
                }

            detailed_results = []
            for doc, meta, distance in zip(results['documents'][0], results['metadatas'][0], results['distances'][0]):
                result_info = {
                    'content': doc,
                    'metadata': meta,
                    'similarity_score': round((1 - distance) * 100, 2),  # Convert to percentage
                    'source_info': {
                        'file': meta['source_file'],
                        'type': meta['content_type'],
                        'path': meta.get('xml_path', 'N/A'),
                        'context': json.loads(meta['context']) if meta.get('context') else {}
                    }
                }
                detailed_results.append(result_info)

            return {
                'success': True,
                'results': detailed_results,
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_hierarchical_context(self, question: str, selected_files: List[str], n_results: int = 5) -> Dict:
        """Get hierarchical context for XML files including parent-child relationships"""
        try:
            # Get initial results
            initial_results = self.get_detailed_context(question, selected_files, n_results)
            if not initial_results['success']:
                return initial_results

            hierarchical_results = []
            for result in initial_results['results']:
                if result['metadata']['content_type'] == 'xml':
                    # Get parent elements
                    parent_path = '/'.join(result['source_info']['path'].split('/')[:-1])
                    if parent_path:
                        parent_filter = {
                            'source_file': {'$eq': result['metadata']['source_file']},
                            'xml_path': {'$eq': parent_path}
                        }
                        parent_results = self.collection.query(
                            query_texts=[""],  # Empty query to get exact match
                            where=parent_filter,
                            include=["documents", "metadatas"],
                            n_results=1
                        )
                        if parent_results['documents'][0]:
                            result['parent_info'] = {
                                'content': parent_results['documents'][0][0],
                                'metadata': parent_results['metadatas'][0][0]
                            }

                    # Get all potential children
                    all_filter = {
                        'source_file': {'$eq': result['metadata']['source_file']}
                    }
                    all_results = self.collection.query(
                        query_texts=[""],
                        where=all_filter,
                        include=["documents", "metadatas"],
                        n_results=100
                    )

                    # Manually filter children
                    children_info = []
                    current_path = result['source_info']['path']
                    if all_results['documents'][0]:
                        for doc, meta in zip(all_results['documents'][0], all_results['metadatas'][0]):
                            child_path = meta.get('xml_path', '')
                            if (child_path.startswith(current_path + '/') and
                                    len(child_path.split('/')) == len(current_path.split('/')) + 1):
                                children_info.append({
                                    'content': doc,
                                    'metadata': meta
                                })
                    if children_info:
                        result['children_info'] = children_info[:5]

                hierarchical_results.append(result)

            return {
                'success': True,
                'results': hierarchical_results,
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_summary_and_details(self, question: str, selected_files: List[str]) -> Dict:
        """Get both a summary answer and detailed supporting information"""
        try:
            # Get hierarchical context first
            detailed_results = self.get_hierarchical_context(question, selected_files)
            if not detailed_results['success']:
                return detailed_results

            # Create summary prompt
            relevant_content = []
            for result in detailed_results['results']:
                if result['metadata']['content_type'] == 'xml':
                    content_info = [
                        f"XML Path: {result['source_info']['path']}",
                        f"Content: {result['content']}"
                    ]
                    if 'parent_info' in result:
                        content_info.append(f"Parent: {result['parent_info']['content']}")
                    if 'children_info' in result:
                        children_content = [child['content'] for child in result['children_info']]
                        content_info.append(f"Related Elements: {', '.join(children_content)}")
                else:
                    content_info = [f"Content: {result['content']}"]
                relevant_content.append('\n'.join(content_info))

            summary_prompt = (
                f"Based on the following content, please provide:\n"
                "1. A concise answer to the question\n"
                "2. Key supporting points\n"
                "3. Related context if relevant\n\n"
                f"Question: {question}\n\n"
                f"Content:\n{chr(10).join(relevant_content)}"
            )
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": summary_prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )
            return {
                'success': True,
                'summary': response.choices[0].message.content,
                'details': detailed_results['results'],
                'query': question
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_file(self, file_path: str) -> Dict:
        """Process any supported file type"""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def calculate_detailed_score(self, distance: float, metadata: Dict, content: str, query: str) -> Dict:
        """
        Calculate a detailed, multi-faceted relevance score.

        Components:
        1. Vector Similarity (40%): Base similarity from embeddings
        2. Content Match (20%): Direct term matching
        3. Structural Relevance (20%): XML structure relevance (for XML files)
        4. Context Completeness (10%): Completeness of metadata/context
        5. Freshness (10%): How recent the content is
        """
        try:
            scores = {}

            # 1. Vector Similarity Score (40%)
            vector_similarity = 1 - distance  # Convert distance to similarity
            scores['vector_similarity'] = {
                'score': vector_similarity,
                'weight': 0.4,
                'weighted_score': vector_similarity * 0.4
            }

            # 2. Content Match Score (20%)
            content_match_score = self._calculate_content_match(content, query)
            scores['content_match'] = {
                'score': content_match_score,
                'weight': 0.2,
                'weighted_score': content_match_score * 0.2
            }

            # 3. Structural Relevance Score (20%)
            if metadata['content_type'] == 'xml':
                structural_score = self._calculate_structural_relevance(metadata)
            else:
                structural_score = 0.5  # Default for non-XML
            scores['structural_relevance'] = {
                'score': structural_score,
                'weight': 0.2,
                'weighted_score': structural_score * 0.2
            }

            # 4. Context Completeness Score (10%)
            context_score = self._calculate_context_completeness(metadata)
            scores['context_completeness'] = {
                'score': context_score,
                'weight': 0.1,
                'weighted_score': context_score * 0.1
            }

            # 5. Freshness Score (10%)
            freshness_score = self._calculate_freshness(metadata['timestamp'])
            scores['freshness'] = {
                'score': freshness_score,
                'weight': 0.1,
                'weighted_score': freshness_score * 0.1
            }

            # Calculate total score
            total_score = sum(s['weighted_score'] for s in scores.values())

            return {
                'total_score': total_score,
                'component_scores': scores,
                'explanation': self._generate_score_explanation(scores)
            }
        except Exception as e:
            print(f"Error in score calculation: {str(e)}")
            return {
                'total_score': 0.5,
                'error': str(e)
            }

    def _calculate_content_match(self, content: str, query: str) -> float:
        """Calculate direct term matching score"""
        try:
            # Tokenize content and query
            content_terms = set(content.lower().split())
            query_terms = set(query.lower().split())

            if not query_terms:
                return 0.5

            # Calculate overlap for exact matches
            matching_terms = content_terms.intersection(query_terms)
            exact_match_score = len(matching_terms) / len(query_terms)

            # Check for partial matches
            partial_matches = 0
            for q_term in query_terms:
                for c_term in content_terms:
                    if q_term in c_term or c_term in q_term:
                        partial_matches += 0.5
            partial_match_score = min(partial_matches / len(query_terms), 1.0)

            # Combine scores (70% exact matches, 30% partial matches), capped at 1.0
            return min((exact_match_score * 0.7) + (partial_match_score * 0.3), 1.0)
        except Exception as e:
            print(f"Error in content match calculation: {str(e)}")
            return 0.5

    def _calculate_structural_relevance(self, metadata: Dict) -> float:
        """Calculate structural relevance score for XML content"""
        try:
            score = 0.5  # Base score
            if 'xml_path' in metadata:
                path = metadata['xml_path']

                # Score based on path depth (deeper paths tend to be more specific)
                depth = len(path.split('/'))
                depth_score = min(depth / 5, 1.0)  # Normalize depth score

                # Score based on element type
                element_type = metadata.get('element_type', '')
                type_scores = {
                    'UAObjectType': 0.9,
                    'UAVariableType': 0.9,
                    'UAObject': 0.8,
                    'UAVariable': 0.8,
                    'UAMethod': 0.7,
                    'UAView': 0.6,
                    'UAReferenceType': 0.7
                }
                type_score = type_scores.get(element_type, 0.5)

                # Score based on context completeness (capped so it stays in [0, 1])
                context = json.loads(metadata.get('context', '{}'))
                context_score = min(len(context) / 10, 1.0) if context else 0.5

                # Combine scores
                score = (depth_score * 0.3) + (type_score * 0.4) + (context_score * 0.3)
            return score
        except Exception as e:
            print(f"Error in structural relevance calculation: {str(e)}")
            return 0.5

    def _calculate_context_completeness(self, metadata: Dict) -> float:
        """Calculate context completeness score"""
        try:
            expected_fields = {
                'xml': ['xml_path', 'element_type', 'context', 'chunk_id', 'total_chunks'],
                'pdf': ['chunk_id', 'total_chunks', 'chunk_size']
            }
            content_type = metadata.get('content_type', '')
            if content_type not in expected_fields:
                return 0.5

            # Check for presence of expected fields
            expected = expected_fields[content_type]
            present_fields = sum(1 for field in expected if field in metadata)

            # Calculate base completeness score
            completeness = present_fields / len(expected)

            # Add bonus for additional useful metadata
            bonus = 0
            if content_type == 'xml':
                context = json.loads(metadata.get('context', '{}'))
                if context:
                    bonus += 0.2
            return min(completeness + bonus, 1.0)
        except Exception as e:
            print(f"Error in context completeness calculation: {str(e)}")
            return 0.5

    def _calculate_freshness(self, timestamp: str) -> float:
        """Calculate freshness score based on timestamp"""
        try:
            # Parse timestamp
            doc_time = datetime.datetime.strptime(timestamp, '%Y-%m-%d %H:%M:%S.%f')
            now = datetime.datetime.now()

            # Calculate age in hours
            age_hours = (now - doc_time).total_seconds() / 3600

            # Score decreases with age
            if age_hours < 24:  # 1 day
                return 1.0
            elif age_hours < 168:  # 1 week
                return 0.8
            elif age_hours < 720:  # 1 month
                return 0.6
            else:
                return 0.4
        except Exception as e:
            print(f"Error in freshness calculation: {str(e)}")
            return 0.5

    def _generate_score_explanation(self, scores: Dict) -> str:
        """Generate human-readable explanation of scores"""
        try:
            # The dict passed in holds only the component scores, so recompute the total here
            total_score = sum(s['weighted_score'] for s in scores.values())
            explanations = [
                f"Total Score: {total_score:.2f}",
                "\nComponent Scores:",
                f"• Vector Similarity: {scores['vector_similarity']['score']:.2f} (40% weight)",
                f"• Content Match: {scores['content_match']['score']:.2f} (20% weight)",
                f"• Structural Relevance: {scores['structural_relevance']['score']:.2f} (20% weight)",
                f"• Context Completeness: {scores['context_completeness']['score']:.2f} (10% weight)",
                f"• Freshness: {scores['freshness']['score']:.2f} (10% weight)"
            ]
            return "\n".join(explanations)
        except Exception as e:
            print(f"Error generating score explanation: {str(e)}")
            return "Score explanation unavailable"