from typing import List, Dict, Union, Optional
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer


class CustomEmbeddingFunction:
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = self.model.encode(input)
        return embeddings.tolist()


class EnhancedXMLProcessor:
    def __init__(self):
        self.processed_nodes = set()
        self.reference_map = {}
        self.node_info = {}

    def build_reference_map(self, root) -> None:
        """Build a map of all node references for faster lookup"""
        for element in root.findall('.//*'):
            node_id = element.get('NodeId')
            if node_id:
                self.node_info[node_id] = {
                    'tag': element.tag,
                    'browse_name': element.get('BrowseName', ''),
                    'display_name': self._get_display_name(element),
                    'description': self._get_description(element),
                    'data_type': element.get('DataType', ''),
                    'references': []
                }
                refs = element.find('References')
                if refs is not None:
                    for ref in refs.findall('Reference'):
                        ref_type = ref.get('ReferenceType')
                        is_forward = ref.get('IsForward', 'true') == 'true'
                        target = ref.text
                        if ref_type in ['HasComponent', 'HasProperty', 'HasTypeDefinition']:
                            self.reference_map.setdefault(node_id, []).append({
                                'type': ref_type,
                                'target': target,
                                'is_forward': is_forward
                            })
                            self.node_info[node_id]['references'].append({
                                'type': ref_type,
                                'target': target,
                                'is_forward': is_forward
                            })

    def _get_display_name(self, element) -> str:
        """Extract display name from element"""
        display_name = element.find('DisplayName')
        if display_name is not None:
            # Guard against empty elements whose .text is None
            return display_name.text or ''
        return ''

    def _get_description(self, element) -> str:
        """Extract description from element"""
        desc = element.find('Description')
        if desc is not None:
            return desc.text or ''
        return ''

    def generate_natural_language(self, node_id: str, depth: int = 0, visited: Optional[set] = None) -> List[str]:
        """Generate natural language description for a node and its children"""
        if visited is None:
            visited = set()
        if node_id in visited:
            return []
        visited.add(node_id)
        descriptions = []
        node = self.node_info.get(node_id)
        if not node:
            return []
        base_desc = self._build_base_description(node, depth)
        if base_desc:
            descriptions.append(base_desc)
        if node_id in self.reference_map:
            child_descriptions = self._process_forward_references(node_id, depth + 1, visited)
            descriptions.extend(child_descriptions)
        return descriptions

    def _build_base_description(self, node: Dict, depth: int) -> str:
        """Build the base description for a node"""
        indentation = " " * depth
        desc_parts = []
        if node['browse_name']:
            browse_name = node['browse_name'].split(':')[-1]
            desc_parts.append(f"a {browse_name}")
        if node['display_name']:
            desc_parts.append(f"(displayed as '{node['display_name']}')")
        if node['data_type']:
            desc_parts.append(f"of type {node['data_type']}")
        if node['description']:
            desc_parts.append(f"which {node['description']}")
        if desc_parts:
            return f"{indentation}Contains {' '.join(desc_parts)}"
        return ""

    def _process_forward_references(self, node_id: str, depth: int, visited: set) -> List[str]:
        """Process forward references to build hierarchical descriptions"""
        descriptions = []
        for ref in self.reference_map.get(node_id, []):
            if ref['is_forward'] and ref['type'] in ['HasComponent', 'HasProperty']:
                target_descriptions = self.generate_natural_language(ref['target'], depth, visited)
                descriptions.extend(target_descriptions)
        return descriptions

    def generate_complete_description(self, root) -> str:
        """Generate a complete natural language description of the XML structure"""
        self.build_reference_map(root)
        root_descriptions = []
        for node_id in self.node_info:
            is_root = True
            for ref_list in self.reference_map.values():
                for ref in ref_list:
                    if not ref['is_forward'] and ref['type'] == 'HasComponent' and ref['target'] == node_id:
                        is_root = False
                        break
                if not is_root:
                    break
            if is_root:
                descriptions = self.generate_natural_language(node_id)
                root_descriptions.extend(descriptions)
        return "\n".join(root_descriptions)


class UnifiedDocumentProcessor:
    def __init__(self, groq_api_key, collection_name="unified_content"):
        """Initialize the processor with necessary clients"""
        self.groq_client = Groq(api_key=groq_api_key)

        # XML-specific settings
        self.max_elements_per_chunk = 50
        self.xml_processor = EnhancedXMLProcessor()

        # PDF-specific settings
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        # Initialize NLTK
        self._initialize_nltk()

        # Initialize ChromaDB with a single collection for all document types
        self.chroma_client = chromadb.Client()
        existing_collections = self.chroma_client.list_collections()
        collection_exists = any(col.name == collection_name for col in existing_collections)
        if collection_exists:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Ensure NLTK's `punkt` tokenizer resource is available."""
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            print("Downloading NLTK 'punkt' tokenizer...")
            nltk.download('punkt')

    def flatten_xml_to_text(self, element, depth=0) -> str:
        """Convert XML to natural language using the enhanced processor"""
        try:
            return self.xml_processor.generate_complete_description(element)
        except Exception as e:
            print(f"Error in enhanced XML processing: {str(e)}")
            return self._original_flatten_xml_to_text(element, depth)

    def _original_flatten_xml_to_text(self, element, depth=0) -> str:
        """Original fallback XML flattening implementation"""
        text_parts = []
        element_info = f"Element: {element.tag}"
        if element.attrib:
            element_info += f", Attributes: {json.dumps(element.attrib)}"
        if element.text and element.text.strip():
            element_info += f", Text: {element.text.strip()}"
        text_parts.append(element_info)
        for child in element:
            child_text = self._original_flatten_xml_to_text(child, depth + 1)
            text_parts.append(child_text)
        return "\n".join(text_parts)

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None/empty for pages without extractable text
                    text += (page.extract_text() or "") + " "
            return text.strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}")

    def chunk_text(self, text: str) -> List[str]:
        """Split text into chunks while preserving sentence boundaries"""
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)
            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def chunk_xml_text(self, text: str, max_chunk_size: int = 2000) -> List[str]:
        """Split flattened XML text into manageable chunks"""
        lines = text.split('\n')
        chunks = []
        current_chunk = []
        current_size = 0
        for line in lines:
            line_size = len(line)
            if current_size + line_size > max_chunk_size and current_chunk:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
            current_chunk.append(line)
            current_size += line_size
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks

    def generate_natural_language(self, content: Union[List[Dict], str], content_type: str) -> str:
        """Generate natural language description with improved error handling and chunking"""
        try:
            if content_type == "xml":
                prompt = f"Convert this XML structure description to a natural language summary that preserves the hierarchical relationships: {content}"
            else:  # pdf
                prompt = f"Summarize this text while preserving key information: {content}"
            max_prompt_length = 4000
            if len(prompt) > max_prompt_length:
                prompt = prompt[:max_prompt_length] + "..."
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error generating natural language: {str(e)}")
            # Retry once with the first half of the content (e.g. after context-length errors)
            if len(content) > 2000:
                half_length = len(content) // 2
                first_half = content[:half_length]
                try:
                    return self.generate_natural_language(first_half, content_type)
                except Exception:
                    return None
            return None

    def store_in_vector_db(self, natural_language: str, metadata: Dict) -> str:
        """Store content in vector database"""
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.collection.add(
            documents=[natural_language],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def process_file(self, file_path: str) -> Dict:
        """Process any supported file type"""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Process XML file with improved chunking"""
        try:
            tree = ET.parse(xml_file_path)
            root = tree.getroot()
            flattened_text = self.flatten_xml_to_text(root)
            chunks = self.chunk_xml_text(flattened_text)
            print(f"Split XML into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing XML chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "xml")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(xml_file_path),
                            'content_type': 'xml',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language'
                        })
                except Exception as e:
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Process PDF file"""
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing PDF chunk {i+1}/{len(chunks)}")
                natural_language = self.generate_natural_language(chunk, "pdf")
                if natural_language:
                    metadata = {
                        'source_file': os.path.basename(pdf_file_path),
                        'content_type': 'pdf',
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                        'timestamp': str(datetime.datetime.now()),
                        'chunk_size': len(chunk.split())
                    }
                    doc_id = self.store_in_vector_db(natural_language, metadata)
                    results.append({
                        'chunk': i,
                        'success': True,
                        'doc_id': doc_id,
                        'natural_language': natural_language,
                        'original_text': chunk[:200] + "..."
                    })
                else:
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': 'Failed to generate natural language summary'
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_available_files(self) -> Dict[str, List[str]]:
        """Get list of all files in the database"""
        try:
            all_entries = self.collection.get(
                include=['metadatas']
            )
            files = {
                'pdf': set(),
                'xml': set()
            }
            for metadata in all_entries['metadatas']:
                file_type = metadata['content_type']
                file_name = metadata['source_file']
                # setdefault guards against content types other than 'pdf'/'xml'
                files.setdefault(file_type, set()).add(file_name)
            return {
                'pdf': sorted(list(files['pdf'])),
                'xml': sorted(list(files['xml']))
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Ask a question using only the selected files"""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )
            if not results['documents'][0]:
                return "No relevant content found in the selected files."
            context = "\n\n".join(results['documents'][0])
            prompt = f"""Based on the following content from the selected files, please answer this question: {question}
Content:
{context}
Please provide a direct answer based only on the information provided above."""
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"