TahaRasouli committed on
Commit
a0be55a
·
verified ·
1 Parent(s): dd82072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -60
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import streamlit as st
2
  import chromadb
3
  from chromadb.utils import embedding_functions
 
4
  from groq import Groq
5
  import xml.etree.ElementTree as ET
6
  from datetime import datetime
 
7
 
8
  # Reuse the helper functions from the original script
9
  def extract_node_details(element):
@@ -70,9 +72,18 @@ def convert_to_natural_language(details):
70
  def main():
71
  st.title("OPC UA Node Query System")
72
 
 
 
 
 
 
 
 
 
 
73
  # Initialize session state
74
- if 'collection' not in st.session_state:
75
- st.session_state.collection = None
76
  if 'initialized' not in st.session_state:
77
  st.session_state.initialized = False
78
 
@@ -81,72 +92,88 @@ def main():
81
 
82
  if uploaded_file and not st.session_state.initialized:
83
  with st.spinner("Processing XML file and initializing database..."):
84
- # Parse nodes
85
- nodes_dict = parse_nodes_to_dict(uploaded_file)
86
-
87
- # Convert to natural language
88
- node_NL = {}
89
- for node_id, details in nodes_dict.items():
90
- nl_description = convert_to_natural_language(details)
91
- node_NL[node_id] = nl_description
92
-
93
- # Initialize ChromaDB
94
- client = chromadb.Client()
95
-
96
- # Create collection
97
- collection = client.create_collection(
98
- name=f"node_embeddings_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
99
- embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
100
- model_name="all-MiniLM-L6-v2"
101
  )
102
- )
103
-
104
- # Add nodes to ChromaDB
105
- collection.add(
106
- documents=[desc for desc in node_NL.values()],
107
- metadatas=[{"NodeId": node_id} for node_id in node_NL.keys()],
108
- ids=[node_id for node_id in node_NL.keys()]
109
- )
110
-
111
- st.session_state.collection = collection
112
- st.session_state.initialized = True
113
- st.success("Database initialized successfully!")
 
 
 
114
 
115
  # Query section
116
- if st.session_state.initialized:
117
  st.header("Query Nodes")
 
 
 
 
 
 
 
 
 
118
  user_query = st.text_input("Enter your query:")
119
 
120
  if user_query:
121
  with st.spinner("Searching and generating response..."):
122
- # Retrieve matches
123
- results = st.session_state.collection.query(
124
- query_texts=[user_query],
125
- n_results=5
126
- )
127
-
128
- # Display results
129
- st.subheader("Top Matches")
130
- for i, (doc, metadata) in enumerate(zip(results["documents"][0], results["metadatas"][0]), 1):
131
- with st.expander(f"Match {i}: NodeId = {metadata['NodeId']}"):
132
- st.write(doc)
133
-
134
- # Generate LLM response
135
- retrieved_context = "\n".join(results["documents"][0])
136
- client = Groq(api_key=st.secrets["GROQ_API_KEY"])
137
- messages = [
138
- {
139
- "role": "user",
140
- "content": f"Answer the following query based on the provided context:\n\nQuery: {user_query}\n\nContext: {retrieved_context}"
141
- }
142
- ]
143
- chat_completion = client.chat.completions.create(
144
- messages=messages,
145
- model="llama3-8b-8192",
146
- )
147
-
148
- st.subheader("Generated Answer")
149
- st.write(chat_completion.choices[0].message.content)
 
 
 
 
150
 
151
  if __name__ == "__main__":
152
  main()
 
1
  import streamlit as st
2
  import chromadb
3
  from chromadb.utils import embedding_functions
4
+ from chromadb.config import Settings
5
  from groq import Groq
6
  import xml.etree.ElementTree as ET
7
  from datetime import datetime
8
+ import os
9
 
10
  # Reuse the helper functions from the original script
11
  def extract_node_details(element):
 
72
  def main():
73
  st.title("OPC UA Node Query System")
74
 
75
+ # Create persistent storage directory
76
+ os.makedirs("chroma_db", exist_ok=True)
77
+
78
+ # Initialize ChromaDB with persistent storage
79
+ chroma_client = chromadb.Client(Settings(
80
+ chroma_db_impl="duckdb+parquet",
81
+ persist_directory="chroma_db"
82
+ ))
83
+
84
  # Initialize session state
85
+ if 'collection_name' not in st.session_state:
86
+ st.session_state.collection_name = None
87
  if 'initialized' not in st.session_state:
88
  st.session_state.initialized = False
89
 
 
92
 
93
  if uploaded_file and not st.session_state.initialized:
94
  with st.spinner("Processing XML file and initializing database..."):
95
+ try:
96
+ # Parse nodes
97
+ nodes_dict = parse_nodes_to_dict(uploaded_file)
98
+
99
+ # Convert to natural language
100
+ node_NL = {}
101
+ for node_id, details in nodes_dict.items():
102
+ nl_description = convert_to_natural_language(details)
103
+ node_NL[node_id] = nl_description
104
+
105
+ # Create collection with unique name
106
+ collection_name = f"node_embeddings_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
107
+ collection = chroma_client.create_collection(
108
+ name=collection_name,
109
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
110
+ model_name="all-MiniLM-L6-v2"
111
+ )
112
  )
113
+
114
+ # Add nodes to ChromaDB
115
+ collection.add(
116
+ documents=[desc for desc in node_NL.values()],
117
+ metadatas=[{"NodeId": node_id} for node_id in node_NL.keys()],
118
+ ids=[node_id for node_id in node_NL.keys()]
119
+ )
120
+
121
+ # Record the collection name in session state (no explicit persist call is made here)
122
+ st.session_state.collection_name = collection_name
123
+ st.session_state.initialized = True
124
+ st.success("Database initialized successfully!")
125
+
126
+ except Exception as e:
127
+ st.error(f"An error occurred: {str(e)}")
128
 
129
  # Query section
130
+ if st.session_state.initialized and st.session_state.collection_name:
131
  st.header("Query Nodes")
132
+
133
+ # Get the existing collection
134
+ collection = chroma_client.get_collection(
135
+ name=st.session_state.collection_name,
136
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
137
+ model_name="all-MiniLM-L6-v2"
138
+ )
139
+ )
140
+
141
  user_query = st.text_input("Enter your query:")
142
 
143
  if user_query:
144
  with st.spinner("Searching and generating response..."):
145
+ try:
146
+ # Retrieve matches
147
+ results = collection.query(
148
+ query_texts=[user_query],
149
+ n_results=5
150
+ )
151
+
152
+ # Display results
153
+ st.subheader("Top Matches")
154
+ for i, (doc, metadata) in enumerate(zip(results["documents"][0], results["metadatas"][0]), 1):
155
+ with st.expander(f"Match {i}: NodeId = {metadata['NodeId']}"):
156
+ st.write(doc)
157
+
158
+ # Generate LLM response
159
+ retrieved_context = "\n".join(results["documents"][0])
160
+ client = Groq(api_key=st.secrets["GROQ_API_KEY"])
161
+ messages = [
162
+ {
163
+ "role": "user",
164
+ "content": f"Answer the following query based on the provided context:\n\nQuery: {user_query}\n\nContext: {retrieved_context}"
165
+ }
166
+ ]
167
+ chat_completion = client.chat.completions.create(
168
+ messages=messages,
169
+ model="llama3-8b-8192",
170
+ )
171
+
172
+ st.subheader("Generated Answer")
173
+ st.write(chat_completion.choices[0].message.content)
174
+
175
+ except Exception as e:
176
+ st.error(f"An error occurred during query: {str(e)}")
177
 
178
  if __name__ == "__main__":
179
  main()