# Smart_AAS_v2.0 / app.py
import streamlit as st
import os
import tempfile
from typing import Dict, List, Tuple
import xml.etree.ElementTree as ET
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import chromadb
from chromadb.utils import embedding_functions
import PyPDF2
import numpy as np
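# NOTE: the Groq-backed helpers below (convert_to_natural_language and
# generate_rag_response) read the API key from the GROQ_API_KEY environment
# variable, so the app assumes that variable is set before launch.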
# Initialize session state for storing processed files
if 'processed_files' not in st.session_state:
    st.session_state.processed_files = {}
if 'current_collection' not in st.session_state:
    st.session_state.current_collection = None
if 'current_raw_nodes' not in st.session_state:
    st.session_state.current_raw_nodes = {}
# Original XML processing functions remain unchanged
def extract_node_details(element):
    """
    Extracts details like description, value, NodeId, DisplayName, and references from an XML element.
    """
    details = {
        "NodeId": element.attrib.get("NodeId", "N/A"),
        "Description": None,
        "DisplayName": None,
        "Value": None,
        "References": []
    }
    for child in element:
        tag = child.tag.split('}')[-1]
        if tag == "Description":
            details["Description"] = child.text
        elif tag == "DisplayName":
            details["DisplayName"] = child.text
        elif tag == "Value":
            details["Value"] = extract_value_content(child)
        elif tag == "References":
            for reference in child:
                if reference.tag.split('}')[-1] == "Reference":
                    details["References"].append(reference.attrib)
    return details
def extract_value_content(value_element):
    """
    Extracts the content of a <Value> element, flattening any direct child elements
    into a simple <Tag>text</Tag> string.
    """
    if not list(value_element):  # No child elements, return text directly
        return value_element.text or "No value provided."
    # Process child elements
    content = []
    for child in value_element:
        tag = child.tag.split('}')[-1]
        child_text = child.text.strip() if child.text else ""
        content.append(f"<{tag}>{child_text}</{tag}>")
    return "".join(content)
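# Illustrative example (hypothetical input): a value element such as
#   <Value><uax:String>Pump Station 1</uax:String></Value>
# is flattened by extract_value_content to the string
#   "<String>Pump Station 1</String>"
# since only the local tag name and text of each direct child are kept.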
def parse_nodes_to_dict(filename):
    """
    Parses the XML file and saves node details into a dictionary.
    Each node's NodeId serves as the key, and the value is a dictionary of the node's details.
    """
    tree = ET.parse(filename)
    root = tree.getroot()
    # Retrieve namespace from the root
    namespace = root.tag.split('}')[0].strip('{')
    # Node types to extract
    node_types = ["UAObject", "UAVariable", "UAObjectType"]
    nodes_dict = {}
    for node_type in node_types:
        for element in root.findall(f".//{{{namespace}}}{node_type}"):
            details = extract_node_details(element)
            node_id = details["NodeId"]
            if node_id != "N/A":
                nodes_dict[node_id] = details
    return nodes_dict
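# Illustrative shape of the returned dictionary (NodeIds and values are hypothetical):
#   {
#       "ns=1;i=5001": {
#           "NodeId": "ns=1;i=5001",
#           "Description": "Motor temperature sensor",
#           "DisplayName": "Temperature",
#           "Value": "23.5",
#           "References": [{"ReferenceType": "HasTypeDefinition", "IsForward": "false"}]
#       },
#       ...
#   }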
def format_node_content(details):
    """
    Formats raw node details into a single string for semantic comparison.
    """
    content_parts = []
    if details["Description"]:
        content_parts.append(f"Description: {details['Description']}")
    if details["DisplayName"]:
        content_parts.append(f"DisplayName: {details['DisplayName']}")
    if details["Value"]:
        content_parts.append(f"Value: {details['Value']}")
    return " | ".join(content_parts)
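# Illustrative example: for the hypothetical node above, format_node_content returns
#   "Description: Motor temperature sensor | DisplayName: Temperature | Value: 23.5"
# Fields that are None are simply omitted from the joined string.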
def convert_to_natural_language(details):
    """
    Converts node details to natural language using Groq LLM.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    messages = [
        {
            "role": "user",
            "content": f"Convert the following node details to natural language: {details}"
        }
    ]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="llama3-8b-8192",
    )
    return chat_completion.choices[0].message.content
# New file-type detection and processing functions (implemented without the magic library)
def detect_file_type(file_path):
    """
    Detects if the input file is PDF or XML using file extension and content analysis.
    """
    try:
        # Check file extension
        file_extension = os.path.splitext(file_path)[1].lower()
        # Read the first few bytes of the file to check its content
        with open(file_path, 'rb') as f:
            header = f.read(8)  # Read first 8 bytes
        # Check for PDF signature
        if file_extension == '.pdf' or header.startswith(b'%PDF'):
            # Verify it's actually a PDF by trying to open it
            try:
                with open(file_path, 'rb') as f:
                    PyPDF2.PdfReader(f)
                return 'pdf'
            except:
                return 'unknown'
        # Check for XML
        elif file_extension == '.xml':
            # Try to parse as XML
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content_start = f.read(1024)  # Read first 1KB
                # Check for XML declaration or root element
                if content_start.strip().startswith(('<?xml', '<')):
                    ET.parse(file_path)  # Verify it's valid XML
                    return 'xml'
            except:
                return 'unknown'
        return 'unknown'
    except Exception as e:
        print(f"Error detecting file type: {str(e)}")
        return 'unknown'
def process_pdf(file_path):
    """
    Extracts text content from PDF and splits it into meaningful chunks.
    """
    try:
        chunks = []
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page_num in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_num]
                text = page.extract_text()
                # Split text into paragraphs
                paragraphs = text.split('\n\n')
                # Process each paragraph
                for para_num, paragraph in enumerate(paragraphs):
                    if len(paragraph.strip()) > 0:  # Skip empty paragraphs
                        chunk = {
                            'content': paragraph.strip(),
                            'metadata': {
                                'page_number': page_num + 1,
                                'paragraph_number': para_num + 1,
                                'source_type': 'pdf',
                                'file_name': os.path.basename(file_path)
                            }
                        }
                        chunks.append(chunk)
        return chunks
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
        return []
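# Each chunk produced by process_pdf has the following form (values are illustrative):
#   {
#       "content": "First paragraph of page 1 ...",
#       "metadata": {
#           "page_number": 1,
#           "paragraph_number": 1,
#           "source_type": "pdf",
#           "file_name": "manual.pdf"
#       }
#   }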
def add_to_vector_db(collection, chunks, embedder):
    """
    Adds processed chunks to the vector database with proper metadata.
    """
    try:
        for i, chunk in enumerate(chunks):
            # Create unique ID for each chunk
            chunk_id = f"{chunk['metadata']['file_name']}_{chunk['metadata']['page_number']}_{chunk['metadata']['paragraph_number']}"
            collection.add(
                documents=[chunk['content']],
                metadatas=[chunk['metadata']],
                ids=[chunk_id]
            )
    except Exception as e:
        print(f"Error adding to vector database: {str(e)}")
def process_file(file_path):
    """
    Main function to process either PDF or XML file and add to vector database.
    Also returns the raw node details for XML files.
    """
    try:
        # Initialize ChromaDB and embedding function
        client = chromadb.Client()
        embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2"
        )
        # Create or get collection, wiring in the sentence-transformer embedder so
        # stored documents and queries are embedded with the same model
        collection = client.create_collection(
            name="document_embeddings",
            embedding_function=embedder,
            get_or_create=True
        )
        # Store for raw node details
        raw_nodes = {}
        # Detect file type
        file_type = detect_file_type(file_path)
        if file_type == 'pdf':
            # Process PDF
            chunks = process_pdf(file_path)
            add_to_vector_db(collection, chunks, embedder)
        elif file_type == 'xml':
            # Parse XML and store raw nodes
            raw_nodes = parse_nodes_to_dict(file_path)
            # Convert to natural language for RAG
            for node_id, details in raw_nodes.items():
                nl_description = convert_to_natural_language(details)
                # Add to vector DB
                collection.add(
                    documents=[nl_description],
                    metadatas=[{"NodeId": node_id, "source_type": "xml"}],
                    ids=[node_id]
                )
        else:
            raise ValueError("Unsupported file type")
        return collection, raw_nodes
    except Exception as e:
        print(f"Error processing file: {str(e)}")
        return None, {}
def generate_rag_response(query_text, context):
    """
    Generates a RAG response using the Groq LLM based on the query and retrieved context.
    Args:
        query_text (str): The user's query
        context (str): The retrieved context from the vector database
    Returns:
        str: The generated response from the LLM
    """
    try:
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on the provided context. "
                           "If the context doesn't contain relevant information, acknowledge that."
            },
            {
                "role": "user",
                "content": f"Answer the following query based on the provided context:\n\n"
                           f"Query: {query_text}\n\n"
                           f"Context: {context}"
            }
        ]
        chat_completion = client.chat.completions.create(
            messages=messages,
            model="llama3-8b-8192",
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"Error generating RAG response: {str(e)}")
        return "Error generating response"
def find_similar_nodes(query_text, raw_nodes, top_k=5):
    """
    Finds the most semantically similar nodes to the query using raw node content.
    Args:
        query_text (str): The user's query
        raw_nodes (dict): Dictionary of node_id: node_details pairs
        top_k (int): Number of top results to return
    """
    try:
        # Initialize the sentence transformer model
        model = SentenceTransformer('all-MiniLM-L6-v2')
        # Format node contents and create mapping
        node_contents = {}
        for node_id, details in raw_nodes.items():
            formatted_content = format_node_content(details)
            if formatted_content:  # Only include nodes with content
                node_contents[node_id] = formatted_content
        if not node_contents:  # Nothing to compare against
            return []
        # Generate embeddings for the query
        query_embedding = model.encode([query_text])[0]
        # Create a list of (node_id, content) tuples
        nodes = list(node_contents.items())
        contents = [content for _, content in nodes]
        # Generate embeddings for all node contents
        content_embeddings = model.encode(contents)
        # Calculate cosine similarities
        similarities = cosine_similarity([query_embedding], content_embeddings)[0]
        # Get indices of top-k similar nodes
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        # Format results
        results = []
        for idx in top_indices:
            node_id, content = nodes[idx]
            similarity_score = similarities[idx]
            results.append({
                'node_id': node_id,
                'raw_content': content,
                'original_details': raw_nodes[node_id],
                'similarity_score': similarity_score
            })
        return results
    except Exception as e:
        print(f"Error finding similar nodes: {str(e)}")
        return []
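# Each entry returned by find_similar_nodes looks like this (values are illustrative):
#   {
#       "node_id": "ns=1;i=5001",
#       "raw_content": "Description: ... | DisplayName: ... | Value: ...",
#       "original_details": {...},   # the raw dict from parse_nodes_to_dict
#       "similarity_score": 0.87     # cosine similarity with the query
#   }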
def query_documents(collection, raw_nodes, query_text, n_results=5):
    """
    Query the vector database and perform semantic similarity search on raw nodes.
    """
    try:
        # Get results from vector database
        results = collection.query(
            query_texts=[query_text],
            n_results=n_results
        )
        # Combine the retrieved results into context for RAG
        retrieved_context = "\n".join(results["documents"][0])
        # Generate RAG response
        rag_response = generate_rag_response(query_text, retrieved_context)
        # Find semantically similar nodes using raw node content
        similar_nodes = find_similar_nodes(query_text, raw_nodes) if raw_nodes else []
        # Format vector DB results
        formatted_results = []
        for i in range(len(results["documents"][0])):
            result = {
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i],
                "score": results["distances"][0][i] if "distances" in results else None,
                "rag_response": rag_response if i == 0 else None
            }
            formatted_results.append(result)
        return formatted_results, similar_nodes
    except Exception as e:
        print(f"Error querying documents: {str(e)}")
        return [], []
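# Sketch of how the two result lists are consumed (mirrors the Streamlit UI below;
# the file name and query are hypothetical):
#   collection, raw_nodes = process_file("example_nodeset.xml")
#   results, similar_nodes = query_documents(collection, raw_nodes, "What is the motor temperature?")
#   results[0]["rag_response"]   # LLM answer grounded in the retrieved context
#   results[0]["content"]        # best-matching document chunk from ChromaDB
#   similar_nodes[0]["node_id"]  # most similar raw XML node, if any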
def main():
    st.title("Document Query System")
    st.write("Upload PDF or XML files and query their contents")
    # File upload section
    uploaded_files = st.file_uploader(
        "Upload PDF or XML files",
        type=['pdf', 'xml'],
        accept_multiple_files=True
    )
    # Process uploaded files
    if uploaded_files:
        for uploaded_file in uploaded_files:
            if uploaded_file.name not in st.session_state.processed_files:
                with st.spinner(f'Processing {uploaded_file.name}...'):
                    # Streamlit uploads are in-memory objects, but the helpers above
                    # (detect_file_type, process_pdf, parse_nodes_to_dict) expect a
                    # path on disk, so write the upload to a temporary file first
                    suffix = os.path.splitext(uploaded_file.name)[1]
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                        tmp_file.write(uploaded_file.getbuffer())
                        tmp_path = tmp_file.name
                    try:
                        collection, raw_nodes = process_file(tmp_path)
                    finally:
                        os.unlink(tmp_path)
                    if collection:
                        st.session_state.processed_files[uploaded_file.name] = {
                            'collection': collection,
                            'raw_nodes': raw_nodes
                        }
                        st.success(f"Successfully processed {uploaded_file.name}")
                    else:
                        st.error(f"Failed to process {uploaded_file.name}")
    # File selection and querying section
    if st.session_state.processed_files:
        selected_file = st.selectbox(
            "Select file to query",
            options=list(st.session_state.processed_files.keys())
        )
        if selected_file:
            st.session_state.current_collection = st.session_state.processed_files[selected_file]['collection']
            st.session_state.current_raw_nodes = st.session_state.processed_files[selected_file]['raw_nodes']
            query = st.text_input("Enter your query:")
            if st.button("Search"):
                if query:
                    with st.spinner('Searching...'):
                        results, similar_nodes = query_documents(
                            st.session_state.current_collection,
                            st.session_state.current_raw_nodes,
                            query
                        )
                    # Display RAG response
                    if results and results[0]['rag_response']:
                        st.subheader("Generated Answer")
                        st.write(results[0]['rag_response'])
                    # Display vector DB results
                    st.subheader("Search Results")
                    for i, result in enumerate(results, 1):
                        with st.expander(f"Match {i}"):
                            st.write(f"Content: {result['content']}")
                            st.write(f"Source: {result['metadata']['source_type']}")
                            if result['metadata']['source_type'] == 'pdf':
                                st.write(f"Page: {result['metadata']['page_number']}")
                            elif result['metadata']['source_type'] == 'xml':
                                st.write(f"NodeId: {result['metadata']['NodeId']}")
                    # Display semantic similarity results
                    if similar_nodes:
                        st.subheader("Similar Nodes")
                        for i, node in enumerate(similar_nodes, 1):
                            with st.expander(f"Similar Node {i}"):
                                st.write(f"NodeId: {node['node_id']}")
                                # These keys are always present but may hold None, so fall back to 'N/A'
                                st.write(f"Description: {node['original_details'].get('Description') or 'N/A'}")
                                st.write(f"DisplayName: {node['original_details'].get('DisplayName') or 'N/A'}")
                                st.write(f"Value: {node['original_details'].get('Value') or 'N/A'}")
                                st.write(f"Similarity Score: {node['similarity_score']:.4f}")
if __name__ == "__main__":
    main()