import streamlit as st import os import tempfile from typing import Dict, List, Tuple import xml.etree.ElementTree as ET from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity from groq import Groq import chromadb from chromadb.utils import embedding_functions import PyPDF2 import numpy as np # Initialize session state for storing processed files if 'processed_files' not in st.session_state: st.session_state.processed_files = {} if 'current_collection' not in st.session_state: st.session_state.current_collection = None if 'current_raw_nodes' not in st.session_state: st.session_state.current_raw_nodes = {} # Original XML processing functions remain unchanged def extract_node_details(element): """ Extracts details like description, value, NodeId, DisplayName, and references from an XML element. """ details = { "NodeId": element.attrib.get("NodeId", "N/A"), "Description": None, "DisplayName": None, "Value": None, "References": [] } for child in element: tag = child.tag.split('}')[-1] if tag == "Description": details["Description"] = child.text elif tag == "DisplayName": details["DisplayName"] = child.text elif tag == "Value": details["Value"] = extract_value_content(child) elif tag == "References": for reference in child: if reference.tag.split('}')[-1] == "Reference": details["References"].append(reference.attrib) return details def extract_value_content(value_element): """ Recursively extracts the content of a element, handling any embedded child elements. """ if not list(value_element): # No child elements, return text directly return value_element.text or "No value provided." # Process child elements content = [] for child in value_element: tag = child.tag.split('}')[-1] child_text = child.text.strip() if child.text else "" content.append(f"<{tag}>{child_text}") return "".join(content) def parse_nodes_to_dict(filename): """ Parses the XML file and saves node details into a dictionary. Each node's NodeId serves as the key, and the value is a dictionary of the node's details. """ tree = ET.parse(filename) root = tree.getroot() # Retrieve namespace from the root namespace = root.tag.split('}')[0].strip('{') # Node types to extract node_types = ["UAObject", "UAVariable", "UAObjectType"] nodes_dict = {} for node_type in node_types: for element in root.findall(f".//{{{namespace}}}{node_type}"): details = extract_node_details(element) node_id = details["NodeId"] if node_id != "N/A": nodes_dict[node_id] = details return nodes_dict def format_node_content(details): """ Formats raw node details into a single string for semantic comparison. """ content_parts = [] if details["Description"]: content_parts.append(f"Description: {details['Description']}") if details["DisplayName"]: content_parts.append(f"DisplayName: {details['DisplayName']}") if details["Value"]: content_parts.append(f"Value: {details['Value']}") return " | ".join(content_parts) def convert_to_natural_language(details): """ Converts node details to natural language using Groq LLM. """ client = Groq(api_key=os.getenv("GROQ_API_KEY")) messages = [ { "role": "user", "content": f"Convert the following node details to natural language: {details}" } ] chat_completion = client.chat.completions.create( messages=messages, model="llama3-8b-8192", ) return chat_completion.choices[0].message.content # New file type detection and processing functions without magic library def detect_file_type(file_path): """ Detects if the input file is PDF or XML using file extension and content analysis. """ try: # Check file extension file_extension = os.path.splitext(file_path)[1].lower() # Read the first few bytes of the file to check its content with open(file_path, 'rb') as f: header = f.read(8) # Read first 8 bytes # Check for PDF signature if file_extension == '.pdf' or header.startswith(b'%PDF'): # Verify it's actually a PDF by trying to open it try: with open(file_path, 'rb') as f: PyPDF2.PdfReader(f) return 'pdf' except: return 'unknown' # Check for XML elif file_extension == '.xml': # Try to parse as XML try: with open(file_path, 'r', encoding='utf-8') as f: content_start = f.read(1024) # Read first 1KB # Check for XML declaration or root element if content_start.strip().startswith((' 0: # Skip empty paragraphs chunk = { 'content': paragraph.strip(), 'metadata': { 'page_number': page_num + 1, 'paragraph_number': para_num + 1, 'source_type': 'pdf', 'file_name': os.path.basename(file_path) } } chunks.append(chunk) return chunks except Exception as e: print(f"Error processing PDF: {str(e)}") return [] def add_to_vector_db(collection, chunks, embedder): """ Adds processed chunks to the vector database with proper metadata. """ try: for i, chunk in enumerate(chunks): # Create unique ID for each chunk chunk_id = f"{chunk['metadata']['file_name']}_{chunk['metadata']['page_number']}_{chunk['metadata']['paragraph_number']}" collection.add( documents=[chunk['content']], metadatas=[chunk['metadata']], ids=[chunk_id] ) except Exception as e: print(f"Error adding to vector database: {str(e)}") def process_file(file_path): """ Main function to process either PDF or XML file and add to vector database. Also returns the raw node details for XML files. """ try: # Initialize ChromaDB and embedding function client = chromadb.Client() embedder = embedding_functions.SentenceTransformerEmbeddingFunction( model_name="all-MiniLM-L6-v2" ) # Create or get collection collection = client.create_collection( name="document_embeddings", get_or_create=True ) # Store for raw node details raw_nodes = {} # Detect file type file_type = detect_file_type(file_path) if file_type == 'pdf': # Process PDF chunks = process_pdf(file_path) add_to_vector_db(collection, chunks, embedder) elif file_type == 'xml': # Parse XML and store raw nodes raw_nodes = parse_nodes_to_dict(file_path) # Convert to natural language for RAG for node_id, details in raw_nodes.items(): nl_description = convert_to_natural_language(details) # Add to vector DB collection.add( documents=[nl_description], metadatas=[{"NodeId": node_id, "source_type": "xml"}], ids=[node_id] ) else: raise ValueError("Unsupported file type") return collection, raw_nodes except Exception as e: print(f"Error processing file: {str(e)}") return None, {} def generate_rag_response(query_text, context): """ Generates a RAG response using the Groq LLM based on the query and retrieved context. Args: query_text (str): The user's query context (str): The retrieved context from the vector database Returns: str: The generated response from the LLM """ try: client = Groq(api_key=os.getenv("GROQ_API_KEY")) messages = [ { "role": "system", "content": "You are a helpful assistant that answers questions based on the provided context. " "If the context doesn't contain relevant information, acknowledge that." }, { "role": "user", "content": f"Answer the following query based on the provided context:\n\n" f"Query: {query_text}\n\n" f"Context: {context}" } ] chat_completion = client.chat.completions.create( messages=messages, model="llama3-8b-8192", ) return chat_completion.choices[0].message.content except Exception as e: print(f"Error generating RAG response: {str(e)}") return "Error generating response" def find_similar_nodes(query_text, raw_nodes, top_k=5): """ Finds the most semantically similar nodes to the query using raw node content. Args: query_text (str): The user's query raw_nodes (dict): Dictionary of node_id: node_details pairs top_k (int): Number of top results to return """ try: # Initialize the sentence transformer model model = SentenceTransformer('all-MiniLM-L6-v2') # Format node contents and create mapping node_contents = {} for node_id, details in raw_nodes.items(): formatted_content = format_node_content(details) if formatted_content: # Only include nodes with content node_contents[node_id] = formatted_content # Generate embeddings for the query query_embedding = model.encode([query_text])[0] # Create a list of (node_id, content) tuples nodes = list(node_contents.items()) contents = [content for _, content in nodes] # Generate embeddings for all node contents content_embeddings = model.encode(contents) # Calculate cosine similarities similarities = cosine_similarity([query_embedding], content_embeddings)[0] # Get indices of top-k similar nodes top_indices = np.argsort(similarities)[-top_k:][::-1] # Format results results = [] for idx in top_indices: node_id, content = nodes[idx] similarity_score = similarities[idx] results.append({ 'node_id': node_id, 'raw_content': content, 'original_details': raw_nodes[node_id], 'similarity_score': similarity_score }) return results except Exception as e: print(f"Error finding similar nodes: {str(e)}") return [] def query_documents(collection, raw_nodes, query_text, n_results=5): """ Query the vector database and perform semantic similarity search on raw nodes. """ try: # Get results from vector database results = collection.query( query_texts=[query_text], n_results=n_results ) # Combine the retrieved results into context for RAG retrieved_context = "\n".join(results["documents"][0]) # Generate RAG response rag_response = generate_rag_response(query_text, retrieved_context) # Find semantically similar nodes using raw node content similar_nodes = find_similar_nodes(query_text, raw_nodes) if raw_nodes else [] # Format vector DB results formatted_results = [] for i in range(len(results["documents"][0])): result = { "content": results["documents"][0][i], "metadata": results["metadatas"][0][i], "score": results["distances"][0][i] if "distances" in results else None, "rag_response": rag_response if i == 0 else None } formatted_results.append(result) return formatted_results, similar_nodes except Exception as e: print(f"Error querying documents: {str(e)}") return [], [] def main(): st.title("Document Query System") st.write("Upload PDF or XML files and query their contents") # File upload section uploaded_files = st.file_uploader( "Upload PDF or XML files", type=['pdf', 'xml'], accept_multiple_files=True ) # Process uploaded files if uploaded_files: for uploaded_file in uploaded_files: if uploaded_file.name not in st.session_state.processed_files: with st.spinner(f'Processing {uploaded_file.name}...'): collection, raw_nodes = process_file(uploaded_file) if collection: st.session_state.processed_files[uploaded_file.name] = { 'collection': collection, 'raw_nodes': raw_nodes } st.success(f"Successfully processed {uploaded_file.name}") else: st.error(f"Failed to process {uploaded_file.name}") # File selection and querying section if st.session_state.processed_files: selected_file = st.selectbox( "Select file to query", options=list(st.session_state.processed_files.keys()) ) if selected_file: st.session_state.current_collection = st.session_state.processed_files[selected_file]['collection'] st.session_state.current_raw_nodes = st.session_state.processed_files[selected_file]['raw_nodes'] query = st.text_input("Enter your query:") if st.button("Search"): if query: with st.spinner('Searching...'): results, similar_nodes = query_documents( st.session_state.current_collection, st.session_state.current_raw_nodes, query ) # Display RAG response if results and results[0]['rag_response']: st.subheader("Generated Answer") st.write(results[0]['rag_response']) # Display vector DB results st.subheader("Search Results") for i, result in enumerate(results, 1): with st.expander(f"Match {i}"): st.write(f"Content: {result['content']}") st.write(f"Source: {result['metadata']['source_type']}") if result['metadata']['source_type'] == 'pdf': st.write(f"Page: {result['metadata']['page_number']}") elif result['metadata']['source_type'] == 'xml': st.write(f"NodeId: {result['metadata']['NodeId']}") # Display semantic similarity results if similar_nodes: st.subheader("Similar Nodes") for i, node in enumerate(similar_nodes, 1): with st.expander(f"Similar Node {i}"): st.write(f"NodeId: {node['node_id']}") st.write(f"Description: {node['original_details'].get('Description', 'N/A')}") st.write(f"DisplayName: {node['original_details'].get('DisplayName', 'N/A')}") st.write(f"Value: {node['original_details'].get('Value', 'N/A')}") st.write(f"Similarity Score: {node['similarity_score']:.4f}") if __name__ == "__main__": main()