import streamlit as st
import os
import tempfile
from typing import Dict, List, Tuple
import xml.etree.ElementTree as ET
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import chromadb
from chromadb.utils import embedding_functions
import PyPDF2
import numpy as np
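# Assumed runtime environment (not stated in this file itself): the packages imported
# above -- streamlit, sentence-transformers, scikit-learn, groq, chromadb, PyPDF2,
# numpy -- must be installed, and a GROQ_API_KEY environment variable must be set
# before any of the Groq client calls below will succeed.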
# Initialize session state for storing processed files
if 'processed_files' not in st.session_state:
    st.session_state.processed_files = {}
if 'current_collection' not in st.session_state:
    st.session_state.current_collection = None
if 'current_raw_nodes' not in st.session_state:
    st.session_state.current_raw_nodes = {}
# Original XML processing functions remain unchanged
def extract_node_details(element):
    """
    Extracts details like Description, Value, NodeId, DisplayName, and References from an XML element.
    """
    details = {
        "NodeId": element.attrib.get("NodeId", "N/A"),
        "Description": None,
        "DisplayName": None,
        "Value": None,
        "References": []
    }
    for child in element:
        tag = child.tag.split('}')[-1]
        if tag == "Description":
            details["Description"] = child.text
        elif tag == "DisplayName":
            details["DisplayName"] = child.text
        elif tag == "Value":
            details["Value"] = extract_value_content(child)
        elif tag == "References":
            for reference in child:
                if reference.tag.split('}')[-1] == "Reference":
                    details["References"].append(reference.attrib)
    return details
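# For reference, extract_node_details() returns a dict of this shape (the values
# shown here are illustrative only, not taken from a real NodeSet file):
#
#   {
#       "NodeId": "ns=1;i=1001",
#       "Description": "Motor speed setpoint",
#       "DisplayName": "SpeedSetpoint",
#       "Value": "42",
#       "References": [{"ReferenceType": "HasTypeDefinition", ...}]
#   }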
def extract_value_content(value_element):
    """
    Extracts the content of a <Value> element. If the value has child elements,
    their tags and text are flattened (one level deep) into a single string.
    """
    if not list(value_element):  # No child elements, return text directly
        return value_element.text or "No value provided."
    # Process child elements
    content = []
    for child in value_element:
        tag = child.tag.split('}')[-1]
        child_text = child.text.strip() if child.text else ""
        content.append(f"<{tag}>{child_text}</{tag}>")
    return "".join(content)
def parse_nodes_to_dict(filename):
    """
    Parses the XML file and saves node details into a dictionary.
    Each node's NodeId serves as the key, and the value is a dictionary of the node's details.
    """
    tree = ET.parse(filename)
    root = tree.getroot()
    # Retrieve the namespace from the root tag
    namespace = root.tag.split('}')[0].strip('{')
    # Node types to extract
    node_types = ["UAObject", "UAVariable", "UAObjectType"]
    nodes_dict = {}
    for node_type in node_types:
        for element in root.findall(f".//{{{namespace}}}{node_type}"):
            details = extract_node_details(element)
            node_id = details["NodeId"]
            if node_id != "N/A":
                nodes_dict[node_id] = details
    return nodes_dict
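# Example usage outside the app flow (hypothetical filename):
#
#   nodes = parse_nodes_to_dict("nodeset2.xml")
#   node_id, details = next(iter(nodes.items()))
#   print(node_id, details["DisplayName"], details["Description"])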
def format_node_content(details):
    """
    Formats raw node details into a single string for semantic comparison.
    """
    content_parts = []
    if details["Description"]:
        content_parts.append(f"Description: {details['Description']}")
    if details["DisplayName"]:
        content_parts.append(f"DisplayName: {details['DisplayName']}")
    if details["Value"]:
        content_parts.append(f"Value: {details['Value']}")
    return " | ".join(content_parts)
def convert_to_natural_language(details):
    """
    Converts node details to natural language using the Groq LLM.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    messages = [
        {
            "role": "user",
            "content": f"Convert the following node details to natural language: {details}"
        }
    ]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="llama3-8b-8192",
    )
    return chat_completion.choices[0].message.content
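# Note: process_file() calls this once per XML node, so a large NodeSet means one
# Groq API request per node; batching or caching would reduce cost and latency,
# but is not implemented here.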
# New file type detection and processing functions without the magic library
def detect_file_type(file_path):
    """
    Detects whether the input file is PDF or XML using the file extension and content analysis.
    """
    try:
        # Check the file extension
        file_extension = os.path.splitext(file_path)[1].lower()
        # Read the first few bytes of the file to check its content
        with open(file_path, 'rb') as f:
            header = f.read(8)  # Read first 8 bytes
        # Check for the PDF signature
        if file_extension == '.pdf' or header.startswith(b'%PDF'):
            # Verify it's actually a PDF by trying to open it
            try:
                with open(file_path, 'rb') as f:
                    PyPDF2.PdfReader(f)
                return 'pdf'
            except Exception:
                return 'unknown'
        # Check for XML
        elif file_extension == '.xml':
            # Try to parse as XML
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content_start = f.read(1024)  # Read first 1KB
                # Check for an XML declaration or root element
                if content_start.strip().startswith(('<?xml', '<')):
                    ET.parse(file_path)  # Verify it's valid XML
                    return 'xml'
            except Exception:
                return 'unknown'
        return 'unknown'
    except Exception as e:
        print(f"Error detecting file type: {str(e)}")
        return 'unknown'
def process_pdf(file_path):
    """
    Extracts text content from a PDF and splits it into meaningful chunks.
    """
    try:
        chunks = []
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page_num in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_num]
                text = page.extract_text() or ""  # Guard against pages with no extractable text
                # Split text into paragraphs
                paragraphs = text.split('\n\n')
                # Process each paragraph
                for para_num, paragraph in enumerate(paragraphs):
                    if len(paragraph.strip()) > 0:  # Skip empty paragraphs
                        chunk = {
                            'content': paragraph.strip(),
                            'metadata': {
                                'page_number': page_num + 1,
                                'paragraph_number': para_num + 1,
                                'source_type': 'pdf',
                                'file_name': os.path.basename(file_path)
                            }
                        }
                        chunks.append(chunk)
        return chunks
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
        return []
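# Each chunk produced by process_pdf() has this shape (example values are
# illustrative only):
#
#   {
#       'content': 'Some paragraph text...',
#       'metadata': {'page_number': 3, 'paragraph_number': 2,
#                    'source_type': 'pdf', 'file_name': 'manual.pdf'}
#   }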
def add_to_vector_db(collection, chunks, embedder):
    """
    Adds processed chunks to the vector database with proper metadata.
    Embedding is handled by the collection's embedding function, so `embedder`
    is not used directly here.
    """
    try:
        for chunk in chunks:
            # Create a unique ID for each chunk
            chunk_id = (
                f"{chunk['metadata']['file_name']}_"
                f"{chunk['metadata']['page_number']}_"
                f"{chunk['metadata']['paragraph_number']}"
            )
            collection.add(
                documents=[chunk['content']],
                metadatas=[chunk['metadata']],
                ids=[chunk_id]
            )
    except Exception as e:
        print(f"Error adding to vector database: {str(e)}")
def process_file(file_path):
    """
    Main function to process either a PDF or XML file and add it to the vector database.
    Also returns the raw node details for XML files.
    """
    try:
        # Initialize ChromaDB and the embedding function
        client = chromadb.Client()
        embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2"
        )
        # Create or get the collection, using the sentence-transformer embedder
        collection = client.create_collection(
            name="document_embeddings",
            embedding_function=embedder,
            get_or_create=True
        )
        # Store for raw node details
        raw_nodes = {}
        # Detect the file type
        file_type = detect_file_type(file_path)
        if file_type == 'pdf':
            # Process PDF
            chunks = process_pdf(file_path)
            add_to_vector_db(collection, chunks, embedder)
        elif file_type == 'xml':
            # Parse XML and store raw nodes
            raw_nodes = parse_nodes_to_dict(file_path)
            # Convert to natural language for RAG
            for node_id, details in raw_nodes.items():
                nl_description = convert_to_natural_language(details)
                # Add to the vector DB
                collection.add(
                    documents=[nl_description],
                    metadatas=[{"NodeId": node_id, "source_type": "xml"}],
                    ids=[node_id]
                )
        else:
            raise ValueError("Unsupported file type")
        return collection, raw_nodes
    except Exception as e:
        print(f"Error processing file: {str(e)}")
        return None, {}
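# Example usage outside Streamlit (hypothetical path), combining process_file()
# with query_documents() defined further below:
#
#   collection, raw_nodes = process_file("example_nodeset.xml")
#   if collection is not None:
#       results, similar_nodes = query_documents(collection, raw_nodes, "What is the motor speed?")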
def generate_rag_response(query_text, context):
    """
    Generates a RAG response using the Groq LLM based on the query and retrieved context.

    Args:
        query_text (str): The user's query
        context (str): The retrieved context from the vector database

    Returns:
        str: The generated response from the LLM
    """
    try:
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on the provided context. "
                           "If the context doesn't contain relevant information, acknowledge that."
            },
            {
                "role": "user",
                "content": f"Answer the following query based on the provided context:\n\n"
                           f"Query: {query_text}\n\n"
                           f"Context: {context}"
            }
        ]
        chat_completion = client.chat.completions.create(
            messages=messages,
            model="llama3-8b-8192",
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"Error generating RAG response: {str(e)}")
        return "Error generating response"
def find_similar_nodes(query_text, raw_nodes, top_k=5):
    """
    Finds the most semantically similar nodes to the query using raw node content.

    Args:
        query_text (str): The user's query
        raw_nodes (dict): Dictionary of node_id: node_details pairs
        top_k (int): Number of top results to return
    """
    try:
        # Initialize the sentence transformer model
        model = SentenceTransformer('all-MiniLM-L6-v2')
        # Format node contents and create a node_id -> content mapping
        node_contents = {}
        for node_id, details in raw_nodes.items():
            formatted_content = format_node_content(details)
            if formatted_content:  # Only include nodes with content
                node_contents[node_id] = formatted_content
        if not node_contents:  # Nothing to compare against
            return []
        # Generate an embedding for the query
        query_embedding = model.encode([query_text])[0]
        # Create a list of (node_id, content) tuples
        nodes = list(node_contents.items())
        contents = [content for _, content in nodes]
        # Generate embeddings for all node contents
        content_embeddings = model.encode(contents)
        # Calculate cosine similarities
        similarities = cosine_similarity([query_embedding], content_embeddings)[0]
        # Get indices of the top-k most similar nodes
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        # Format results
        results = []
        for idx in top_indices:
            node_id, content = nodes[idx]
            similarity_score = similarities[idx]
            results.append({
                'node_id': node_id,
                'raw_content': content,
                'original_details': raw_nodes[node_id],
                'similarity_score': similarity_score
            })
        return results
    except Exception as e:
        print(f"Error finding similar nodes: {str(e)}")
        return []
def query_documents(collection, raw_nodes, query_text, n_results=5):
    """
    Queries the vector database and performs a semantic similarity search on raw nodes.
    """
    try:
        # Get results from the vector database
        results = collection.query(
            query_texts=[query_text],
            n_results=n_results
        )
        # Combine the retrieved documents into context for RAG
        retrieved_context = "\n".join(results["documents"][0])
        # Generate the RAG response
        rag_response = generate_rag_response(query_text, retrieved_context)
        # Find semantically similar nodes using raw node content (XML files only)
        similar_nodes = find_similar_nodes(query_text, raw_nodes) if raw_nodes else []
        # Format vector DB results
        formatted_results = []
        for i in range(len(results["documents"][0])):
            result = {
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "score": results["distances"][0][i] if results.get("distances") else None,
                "rag_response": rag_response if i == 0 else None
            }
            formatted_results.append(result)
        return formatted_results, similar_nodes
    except Exception as e:
        print(f"Error querying documents: {str(e)}")
        return [], []
def main():
    st.title("Document Query System")
    st.write("Upload PDF or XML files and query their contents")

    # File upload section
    uploaded_files = st.file_uploader(
        "Upload PDF or XML files",
        type=['pdf', 'xml'],
        accept_multiple_files=True
    )

    # Process uploaded files
    if uploaded_files:
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.processed_files:
                with st.spinner(f'Processing {uploaded_file.name}...'):
                    # Write the upload to a temporary file so the path-based
                    # processing functions can read it from disk
                    suffix = os.path.splitext(uploaded_file.name)[1]
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                        tmp.write(uploaded_file.getbuffer())
                        tmp_path = tmp.name
                    try:
                        collection, raw_nodes = process_file(tmp_path)
                    finally:
                        os.remove(tmp_path)
                    if collection:
                        st.session_state.processed_files[uploaded_file.name] = {
                            'collection': collection,
                            'raw_nodes': raw_nodes
                        }
                        st.success(f"Successfully processed {uploaded_file.name}")
                    else:
                        st.error(f"Failed to process {uploaded_file.name}")

    # File selection and querying section
    if st.session_state.processed_files:
        selected_file = st.selectbox(
            "Select file to query",
            options=list(st.session_state.processed_files.keys())
        )
        if selected_file:
            st.session_state.current_collection = st.session_state.processed_files[selected_file]['collection']
            st.session_state.current_raw_nodes = st.session_state.processed_files[selected_file]['raw_nodes']
            query = st.text_input("Enter your query:")
            if st.button("Search"):
                if query:
                    with st.spinner('Searching...'):
                        results, similar_nodes = query_documents(
                            st.session_state.current_collection,
                            st.session_state.current_raw_nodes,
                            query
                        )
                    # Display the RAG response
                    if results and results[0]['rag_response']:
                        st.subheader("Generated Answer")
                        st.write(results[0]['rag_response'])
                    # Display vector DB results
                    st.subheader("Search Results")
                    for i, result in enumerate(results, 1):
                        with st.expander(f"Match {i}"):
                            st.write(f"Content: {result['content']}")
                            st.write(f"Source: {result['metadata']['source_type']}")
                            if result['metadata']['source_type'] == 'pdf':
                                st.write(f"Page: {result['metadata']['page_number']}")
                            elif result['metadata']['source_type'] == 'xml':
                                st.write(f"NodeId: {result['metadata']['NodeId']}")
                    # Display semantic similarity results
                    if similar_nodes:
                        st.subheader("Similar Nodes")
                        for i, node in enumerate(similar_nodes, 1):
                            with st.expander(f"Similar Node {i}"):
                                st.write(f"NodeId: {node['node_id']}")
                                st.write(f"Description: {node['original_details'].get('Description', 'N/A')}")
                                st.write(f"DisplayName: {node['original_details'].get('DisplayName', 'N/A')}")
                                st.write(f"Value: {node['original_details'].get('Value', 'N/A')}")
                                st.write(f"Similarity Score: {node['similarity_score']:.4f}")


if __name__ == "__main__":
    main()
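# To run the app locally (assuming this file is saved as app.py and the Groq key is set):
#
#   export GROQ_API_KEY=...
#   streamlit run app.py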