TahaRasouli committed
Commit 660dd5e · verified · 1 Parent(s): a0be55a

Update app.py

Files changed (1):
  1. app.py +400 -110
app.py CHANGED
@@ -1,14 +1,29 @@
  import streamlit as st
  import chromadb
  from chromadb.utils import embedding_functions
- from chromadb.config import Settings
- from groq import Groq
- import xml.etree.ElementTree as ET
- from datetime import datetime
- import os

- # Reuse the helper functions from the original script
  def extract_node_details(element):
      details = {
          "NodeId": element.attrib.get("NodeId", "N/A"),
          "Description": None,
@@ -31,8 +46,12 @@ def extract_node_details(element):
      return details

  def extract_value_content(value_element):
-     if not list(value_element):
          return value_element.text or "No value provided."
      content = []
      for child in value_element:
          tag = child.tag.split('}')[-1]
@@ -40,10 +59,16 @@ def extract_value_content(value_element):
          content.append(f"<{tag}>{child_text}</{tag}>")
      return "".join(content)

- def parse_nodes_to_dict(uploaded_file):
-     tree = ET.parse(uploaded_file)
      root = tree.getroot()
      namespace = root.tag.split('}')[0].strip('{')
      node_types = ["UAObject", "UAVariable", "UAObjectType"]
      nodes_dict = {}
      for node_type in node_types:
@@ -54,8 +79,28 @@ def parse_nodes_to_dict(uploaded_file):
              nodes_dict[node_id] = details
      return nodes_dict

  def convert_to_natural_language(details):
-     client = Groq(api_key=st.secrets["GROQ_API_KEY"])
      messages = [
          {
              "role": "user",
@@ -68,112 +113,357 @@ def convert_to_natural_language(details):
      )
      return chat_completion.choices[0].message.content

- # Streamlit app
- def main():
-     st.title("OPC UA Node Query System")
-
-     # Create persistent storage directory
-     os.makedirs("chroma_db", exist_ok=True)
-
-     # Initialize ChromaDB with persistent storage
-     chroma_client = chromadb.Client(Settings(
-         chroma_db_impl="duckdb+parquet",
-         persist_directory="chroma_db"
-     ))
-
-     # Initialize session state
-     if 'collection_name' not in st.session_state:
-         st.session_state.collection_name = None
-     if 'initialized' not in st.session_state:
-         st.session_state.initialized = False
-
-     # File upload
-     uploaded_file = st.file_uploader("Upload OPC UA XML file", type=['xml'])
-
-     if uploaded_file and not st.session_state.initialized:
-         with st.spinner("Processing XML file and initializing database..."):
              try:
-                 # Parse nodes
-                 nodes_dict = parse_nodes_to_dict(uploaded_file)
-
-                 # Convert to natural language
-                 node_NL = {}
-                 for node_id, details in nodes_dict.items():
-                     nl_description = convert_to_natural_language(details)
-                     node_NL[node_id] = nl_description
-
-                 # Create collection with unique name
-                 collection_name = f"node_embeddings_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-                 collection = chroma_client.create_collection(
-                     name=collection_name,
-                     embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
-                         model_name="all-MiniLM-L6-v2"
-                     )
-                 )
-
-                 # Add nodes to ChromaDB
                  collection.add(
-                     documents=[desc for desc in node_NL.values()],
-                     metadatas=[{"NodeId": node_id} for node_id in node_NL.keys()],
-                     ids=[node_id for node_id in node_NL.keys()]
                  )
-
-                 # Persist the database
-                 st.session_state.collection_name = collection_name
-                 st.session_state.initialized = True
-                 st.success("Database initialized successfully!")
-
-             except Exception as e:
-                 st.error(f"An error occurred: {str(e)}")

-     # Query section
-     if st.session_state.initialized and st.session_state.collection_name:
-         st.header("Query Nodes")
-
-         # Get the existing collection
-         collection = chroma_client.get_collection(
-             name=st.session_state.collection_name,
-             embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
-                 model_name="all-MiniLM-L6-v2"
-             )
          )
-
-         user_query = st.text_input("Enter your query:")
-
-         if user_query:
-             with st.spinner("Searching and generating response..."):
-                 try:
-                     # Retrieve matches
-                     results = collection.query(
-                         query_texts=[user_query],
-                         n_results=5
-                     )
-
-                     # Display results
-                     st.subheader("Top Matches")
-                     for i, (doc, metadata) in enumerate(zip(results["documents"][0], results["metadatas"][0]), 1):
-                         with st.expander(f"Match {i}: NodeId = {metadata['NodeId']}"):
-                             st.write(doc)
-
-                     # Generate LLM response
-                     retrieved_context = "\n".join(results["documents"][0])
-                     client = Groq(api_key=st.secrets["GROQ_API_KEY"])
-                     messages = [
-                         {
-                             "role": "user",
-                             "content": f"Answer the following query based on the provided context:\n\nQuery: {user_query}\n\nContext: {retrieved_context}"
                          }
-                     ]
-                     chat_completion = client.chat.completions.create(
-                         messages=messages,
-                         model="llama3-8b-8192",
-                     )
-
-                     st.subheader("Generated Answer")
-                     st.write(chat_completion.choices[0].message.content)
-
-                 except Exception as e:
-                     st.error(f"An error occurred during query: {str(e)}")

  if __name__ == "__main__":
      main()
 
  import streamlit as st
+ import os
+ import tempfile
+ from typing import Dict, List, Tuple
+ import xml.etree.ElementTree as ET
+ from sentence_transformers import SentenceTransformer
+ from sklearn.metrics.pairwise import cosine_similarity
+ from groq import Groq
  import chromadb
  from chromadb.utils import embedding_functions
+ import PyPDF2
+ import numpy as np

+ # Initialize session state for storing processed files
+ if 'processed_files' not in st.session_state:
+     st.session_state.processed_files = {}
+ if 'current_collection' not in st.session_state:
+     st.session_state.current_collection = None
+ if 'current_raw_nodes' not in st.session_state:
+     st.session_state.current_raw_nodes = {}
+
+ # Original XML processing functions remain unchanged
  def extract_node_details(element):
+     """
+     Extracts details like description, value, NodeId, DisplayName, and references from an XML element.
+     """
      details = {
          "NodeId": element.attrib.get("NodeId", "N/A"),
          "Description": None,

      return details

  def extract_value_content(value_element):
+     """
+     Recursively extracts the content of a <Value> element, handling any embedded child elements.
+     """
+     if not list(value_element):  # No child elements, return text directly
          return value_element.text or "No value provided."
+     # Process child elements
      content = []
      for child in value_element:
          tag = child.tag.split('}')[-1]

          content.append(f"<{tag}>{child_text}</{tag}>")
      return "".join(content)

+ def parse_nodes_to_dict(filename):
+     """
+     Parses the XML file and saves node details into a dictionary.
+     Each node's NodeId serves as the key, and the value is a dictionary of the node's details.
+     """
+     tree = ET.parse(filename)
      root = tree.getroot()
+     # Retrieve namespace from the root
      namespace = root.tag.split('}')[0].strip('{')
+     # Node types to extract
      node_types = ["UAObject", "UAVariable", "UAObjectType"]
      nodes_dict = {}
      for node_type in node_types:

              nodes_dict[node_id] = details
      return nodes_dict

+
+ def format_node_content(details):
+     """
+     Formats raw node details into a single string for semantic comparison.
+     """
+     content_parts = []
+
+     if details["Description"]:
+         content_parts.append(f"Description: {details['Description']}")
+     if details["DisplayName"]:
+         content_parts.append(f"DisplayName: {details['DisplayName']}")
+     if details["Value"]:
+         content_parts.append(f"Value: {details['Value']}")
+
+     return " | ".join(content_parts)
+
+
  def convert_to_natural_language(details):
+     """
+     Converts node details to natural language using Groq LLM.
+     """
+     client = Groq(api_key=os.getenv("GROQ_API_KEY"))
      messages = [
          {
              "role": "user",

      )
      return chat_completion.choices[0].message.content

+ # New file type detection and processing functions without magic library
+ def detect_file_type(file_path):
+     """
+     Detects if the input file is PDF or XML using file extension and content analysis.
+     """
+     try:
+         # Check file extension
+         file_extension = os.path.splitext(file_path)[1].lower()
+
+         # Read the first few bytes of the file to check its content
+         with open(file_path, 'rb') as f:
+             header = f.read(8)  # Read first 8 bytes
+
+         # Check for PDF signature
+         if file_extension == '.pdf' or header.startswith(b'%PDF'):
+             # Verify it's actually a PDF by trying to open it
              try:
+                 with open(file_path, 'rb') as f:
+                     PyPDF2.PdfReader(f)
+                     return 'pdf'
+             except:
+                 return 'unknown'
+
+         # Check for XML
+         elif file_extension == '.xml':
+             # Try to parse as XML
+             try:
+                 with open(file_path, 'r', encoding='utf-8') as f:
+                     content_start = f.read(1024)  # Read first 1KB
+                     # Check for XML declaration or root element
+                     if content_start.strip().startswith(('<?xml', '<')):
+                         ET.parse(file_path)  # Verify it's valid XML
+                         return 'xml'
+             except:
+                 return 'unknown'
+
+         return 'unknown'
+
+     except Exception as e:
+         print(f"Error detecting file type: {str(e)}")
+         return 'unknown'
+
+ def process_pdf(file_path):
+     """
+     Extracts text content from PDF and splits it into meaningful chunks.
+     """
+     try:
+         chunks = []
+         with open(file_path, 'rb') as file:
+             pdf_reader = PyPDF2.PdfReader(file)
+
+             for page_num in range(len(pdf_reader.pages)):
+                 page = pdf_reader.pages[page_num]
+                 text = page.extract_text()
+
+                 # Split text into paragraphs
+                 paragraphs = text.split('\n\n')
+
+                 # Process each paragraph
+                 for para_num, paragraph in enumerate(paragraphs):
+                     if len(paragraph.strip()) > 0:  # Skip empty paragraphs
+                         chunk = {
+                             'content': paragraph.strip(),
+                             'metadata': {
+                                 'page_number': page_num + 1,
+                                 'paragraph_number': para_num + 1,
+                                 'source_type': 'pdf',
+                                 'file_name': os.path.basename(file_path)
+                             }
+                         }
+                         chunks.append(chunk)
+
+         return chunks
+
+     except Exception as e:
+         print(f"Error processing PDF: {str(e)}")
+         return []
+
+ def add_to_vector_db(collection, chunks, embedder):
+     """
+     Adds processed chunks to the vector database with proper metadata.
+     """
+     try:
+         for i, chunk in enumerate(chunks):
+             # Create unique ID for each chunk
+             chunk_id = f"{chunk['metadata']['file_name']}_{chunk['metadata']['page_number']}_{chunk['metadata']['paragraph_number']}"
+
+             collection.add(
+                 documents=[chunk['content']],
+                 metadatas=[chunk['metadata']],
+                 ids=[chunk_id]
+             )
+
+     except Exception as e:
+         print(f"Error adding to vector database: {str(e)}")
+
+ def process_file(file_path):
+     """
+     Main function to process either PDF or XML file and add to vector database.
+     Also returns the raw node details for XML files.
+     """
+     try:
+         # Initialize ChromaDB and embedding function
+         client = chromadb.Client()
+         embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
+             model_name="all-MiniLM-L6-v2"
+         )
+
+         # Create or get collection, using the embedding function defined above
+         collection = client.create_collection(
+             name="document_embeddings",
+             embedding_function=embedder,
+             get_or_create=True
+         )
+
+         # Store for raw node details
+         raw_nodes = {}
+
+         # Detect file type
+         file_type = detect_file_type(file_path)
+
+         if file_type == 'pdf':
+             # Process PDF
+             chunks = process_pdf(file_path)
+             add_to_vector_db(collection, chunks, embedder)
+
+         elif file_type == 'xml':
+             # Parse XML and store raw nodes
+             raw_nodes = parse_nodes_to_dict(file_path)
+
+             # Convert to natural language for RAG
+             for node_id, details in raw_nodes.items():
+                 nl_description = convert_to_natural_language(details)
+
+                 # Add to vector DB
                  collection.add(
+                     documents=[nl_description],
+                     metadatas=[{"NodeId": node_id, "source_type": "xml"}],
+                     ids=[node_id]
                  )
+         else:
+             raise ValueError("Unsupported file type")
+
+         return collection, raw_nodes
+
+     except Exception as e:
+         print(f"Error processing file: {str(e)}")
+         return None, {}
+
+ def generate_rag_response(query_text, context):
+     """
+     Generates a RAG response using the Groq LLM based on the query and retrieved context.

+     Args:
+         query_text (str): The user's query
+         context (str): The retrieved context from the vector database
+
+     Returns:
+         str: The generated response from the LLM
+     """
+     try:
+         client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+         messages = [
+             {
+                 "role": "system",
+                 "content": "You are a helpful assistant that answers questions based on the provided context. "
+                            "If the context doesn't contain relevant information, acknowledge that."
+             },
+             {
+                 "role": "user",
+                 "content": f"Answer the following query based on the provided context:\n\n"
+                            f"Query: {query_text}\n\n"
+                            f"Context: {context}"
+             }
+         ]
+
+         chat_completion = client.chat.completions.create(
+             messages=messages,
+             model="llama3-8b-8192",
          )
+
+         return chat_completion.choices[0].message.content
+
+     except Exception as e:
+         print(f"Error generating RAG response: {str(e)}")
+         return "Error generating response"
+
+
+ def find_similar_nodes(query_text, raw_nodes, top_k=5):
+     """
+     Finds the most semantically similar nodes to the query using raw node content.
+
+     Args:
+         query_text (str): The user's query
+         raw_nodes (dict): Dictionary of node_id: node_details pairs
+         top_k (int): Number of top results to return
+     """
+     try:
+         # Initialize the sentence transformer model
+         model = SentenceTransformer('all-MiniLM-L6-v2')
+
+         # Format node contents and create mapping
+         node_contents = {}
+         for node_id, details in raw_nodes.items():
+             formatted_content = format_node_content(details)
+             if formatted_content:  # Only include nodes with content
+                 node_contents[node_id] = formatted_content
+
+         # Generate embeddings for the query
+         query_embedding = model.encode([query_text])[0]
+
+         # Create a list of (node_id, content) tuples
+         nodes = list(node_contents.items())
+         contents = [content for _, content in nodes]
+
+         # Generate embeddings for all node contents
+         content_embeddings = model.encode(contents)
+
+         # Calculate cosine similarities
+         similarities = cosine_similarity([query_embedding], content_embeddings)[0]
+
+         # Get indices of top-k similar nodes
+         top_indices = np.argsort(similarities)[-top_k:][::-1]
+
+         # Format results
+         results = []
+         for idx in top_indices:
+             node_id, content = nodes[idx]
+             similarity_score = similarities[idx]
+             results.append({
+                 'node_id': node_id,
+                 'raw_content': content,
+                 'original_details': raw_nodes[node_id],
+                 'similarity_score': similarity_score
+             })
+
+         return results
+
+     except Exception as e:
+         print(f"Error finding similar nodes: {str(e)}")
+         return []
+
+ def query_documents(collection, raw_nodes, query_text, n_results=5):
+     """
+     Query the vector database and perform semantic similarity search on raw nodes.
+     """
+     try:
+         # Get results from vector database
+         results = collection.query(
+             query_texts=[query_text],
+             n_results=n_results
+         )
+
+         # Combine the retrieved results into context for RAG
+         retrieved_context = "\n".join(results["documents"][0])
+
+         # Generate RAG response
+         rag_response = generate_rag_response(query_text, retrieved_context)
+
+         # Find semantically similar nodes using raw node content
+         similar_nodes = find_similar_nodes(query_text, raw_nodes) if raw_nodes else []
+
+         # Format vector DB results
+         formatted_results = []
+         for i in range(len(results["documents"][0])):
+             result = {
+                 "content": results["documents"][0][i],
+                 "metadata": results["metadatas"][0][i],
+                 "score": results["distances"][0][i] if "distances" in results else None,
+                 "rag_response": rag_response if i == 0 else None
+             }
+             formatted_results.append(result)
+
+         return formatted_results, similar_nodes
+
+     except Exception as e:
+         print(f"Error querying documents: {str(e)}")
+         return [], []
+
+ def main():
+     st.title("Document Query System")
+     st.write("Upload PDF or XML files and query their contents")
+
+     # File upload section
+     uploaded_files = st.file_uploader(
+         "Upload PDF or XML files",
+         type=['pdf', 'xml'],
+         accept_multiple_files=True
+     )
+
+     # Process uploaded files
+     if uploaded_files:
+         for uploaded_file in uploaded_files:
407
+ for uploaded_file in uploaded_files:
408
+ if uploaded_file.name not in st.session_state.processed_files:
409
+ with st.spinner(f'Processing {uploaded_file.name}...'):
410
+ collection, raw_nodes = process_file(uploaded_file)
411
+ if collection:
412
+ st.session_state.processed_files[uploaded_file.name] = {
413
+ 'collection': collection,
414
+ 'raw_nodes': raw_nodes
415
  }
416
+ st.success(f"Successfully processed {uploaded_file.name}")
417
+ else:
418
+ st.error(f"Failed to process {uploaded_file.name}")
419
+
420
+ # File selection and querying section
421
+ if st.session_state.processed_files:
422
+ selected_file = st.selectbox(
423
+ "Select file to query",
424
+ options=list(st.session_state.processed_files.keys())
425
+ )
426
+
427
+ if selected_file:
428
+ st.session_state.current_collection = st.session_state.processed_files[selected_file]['collection']
429
+ st.session_state.current_raw_nodes = st.session_state.processed_files[selected_file]['raw_nodes']
430
+
431
+ query = st.text_input("Enter your query:")
432
+ if st.button("Search"):
433
+ if query:
434
+ with st.spinner('Searching...'):
435
+ results, similar_nodes = query_documents(
436
+ st.session_state.current_collection,
437
+ st.session_state.current_raw_nodes,
438
+ query
439
+ )
440
+
441
+ # Display RAG response
442
+ if results and results[0]['rag_response']:
443
+ st.subheader("Generated Answer")
444
+ st.write(results[0]['rag_response'])
445
+
446
+ # Display vector DB results
447
+ st.subheader("Search Results")
448
+ for i, result in enumerate(results, 1):
449
+ with st.expander(f"Match {i}"):
450
+ st.write(f"Content: {result['content']}")
451
+ st.write(f"Source: {result['metadata']['source_type']}")
452
+ if result['metadata']['source_type'] == 'pdf':
453
+ st.write(f"Page: {result['metadata']['page_number']}")
454
+ elif result['metadata']['source_type'] == 'xml':
455
+ st.write(f"NodeId: {result['metadata']['NodeId']}")
456
+
457
+ # Display semantic similarity results
458
+ if similar_nodes:
459
+ st.subheader("Similar Nodes")
460
+ for i, node in enumerate(similar_nodes, 1):
461
+ with st.expander(f"Similar Node {i}"):
462
+ st.write(f"NodeId: {node['node_id']}")
463
+ st.write(f"Description: {node['original_details'].get('Description', 'N/A')}")
464
+ st.write(f"DisplayName: {node['original_details'].get('DisplayName', 'N/A')}")
465
+ st.write(f"Value: {node['original_details'].get('Value', 'N/A')}")
466
+ st.write(f"Similarity Score: {node['similarity_score']:.4f}")
467
 
468
  if __name__ == "__main__":
469
  main()
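
A minimal driver sketch of how the new pipeline fits together outside the Streamlit UI, assuming app.py is importable from the working directory, GROQ_API_KEY is set in the environment, and "nodeset.xml" is an illustrative path to an OPC UA NodeSet XML file (the module-level st.session_state initialisation may emit warnings when app.py is imported without streamlit run):

# Hypothetical usage sketch; file path and query are placeholders, not part of the commit.
from app import process_file, query_documents

collection, raw_nodes = process_file("nodeset.xml")   # parse nodes, describe them via Groq, embed into ChromaDB
results, similar_nodes = query_documents(
    collection, raw_nodes, "Which variables report temperature?"
)
print(results[0]["rag_response"])                     # RAG answer generated from the retrieved node descriptions
for node in similar_nodes:                            # cosine-similarity ranking over raw node content
    print(node["node_id"], f"{node['similarity_score']:.3f}")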