poemsforaphrodite commited on
Commit
28546e7
·
verified ·
1 Parent(s): 803c397

Upload 2 files

Browse files
Files changed (2) hide show
  1. requirements.txt +6 -0
  2. trainer.py +628 -0
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ streamlit
3
+ openai
4
+ pymongo
5
+ pinecone-client
6
+ uuid
trainer.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import streamlit as st
4
+ from streamlit.runtime.scriptrunner import RerunException, StopException
5
+ from openai import OpenAI
6
+ from pymongo import MongoClient
7
+ from pinecone import Pinecone
8
+ import uuid
9
+ from datetime import datetime
10
+ import time
11
+ from streamlit.runtime.caching import cache_data
12
+ from streamlit_autorefresh import st_autorefresh
13
+
14
+ # Load environment variables
15
+ load_dotenv()
16
+
17
+ # Configuration
18
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19
+ MONGODB_URI = os.getenv("MONGODB_URI")
20
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
21
+ PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
22
+ PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
23
+ GLOBAL_MEMORY_ID = "global_common_memory_id" # Added GLOBAL_MEMORY_ID
24
+
25
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
26
+ mongo_client = MongoClient(MONGODB_URI)
27
+ db = mongo_client["Wall_Street"]
28
+ conversation_history = db["conversation_history"]
29
+ global_common_memory = db["global_common_memory"] # New global common memory collection
30
+
31
+ # Initialize GLOBAL_MEMORY_ID if it doesn't exist
32
+ if not global_common_memory.find_one({"memory_id": GLOBAL_MEMORY_ID}):
33
+ global_common_memory.insert_one({
34
+ "memory_id": GLOBAL_MEMORY_ID,
35
+ "memory": []
36
+ })
37
+
38
+ # Initialize Pinecone
39
+ pc = Pinecone(api_key=PINECONE_API_KEY)
40
+ pinecone_index = pc.Index(PINECONE_INDEX_NAME)
41
+
42
+ # Set up Streamlit page configuration
43
+ st.set_page_config(page_title="GPT-Driven Chat System - Tester", page_icon="🔬", layout="wide")
44
+
45
+ # Custom CSS to improve the UI
46
+ st.markdown("""
47
+ <style>
48
+ /* Your custom CSS styles */
49
+ </style>
50
+ """, unsafe_allow_html=True)
51
+
52
+ # Initialize Streamlit session state
53
+ if 'chat_history' not in st.session_state:
54
+ st.session_state['chat_history'] = []
55
+ if 'user_type' not in st.session_state:
56
+ st.session_state['user_type'] = None
57
+ if 'session_id' not in st.session_state:
58
+ st.session_state['session_id'] = str(uuid.uuid4())
59
+
60
+ # --- Common Memory Functions ---
61
+
62
+ @cache_data(ttl=300) # Cache for 5 minutes
63
+ def get_global_common_memory():
64
+ """Retrieve the global common memory."""
65
+ memory_doc = global_common_memory.find_one({"memory_id": GLOBAL_MEMORY_ID})
66
+ return memory_doc.get('memory', []) if memory_doc else []
67
+
68
+ def append_to_global_common_memory(content):
69
+ """Append content to the global common memory."""
70
+ try:
71
+ # First, ensure the document exists with an initialized memory array
72
+ global_common_memory.update_one(
73
+ {"memory_id": GLOBAL_MEMORY_ID},
74
+ {"$setOnInsert": {"memory": []}},
75
+ upsert=True
76
+ )
77
+
78
+ # Then, add the new content to the memory array
79
+ result = global_common_memory.update_one(
80
+ {"memory_id": GLOBAL_MEMORY_ID},
81
+ {"$push": {"memory": content}}
82
+ )
83
+
84
+ # Invalidate the cache after updating
85
+ get_global_common_memory.clear()
86
+
87
+ st.success("Memory appended successfully!")
88
+
89
+ # Instead of using st.rerun(), we'll set a flag in session state
90
+ st.session_state['memory_updated'] = True
91
+
92
+ except Exception as e:
93
+ st.error(f"Failed to append to global common memory: {str(e)}")
94
+
95
+ def clear_global_common_memory():
96
+ """Clear all items from the global common memory."""
97
+ try:
98
+ global_common_memory.update_one(
99
+ {"memory_id": GLOBAL_MEMORY_ID},
100
+ {"$set": {"memory": []}},
101
+ upsert=True
102
+ )
103
+ # Invalidate the cache after clearing
104
+ get_global_common_memory.clear()
105
+ st.success("Global common memory cleared successfully!")
106
+ except Exception as e:
107
+ st.error(f"Failed to clear global common memory: {str(e)}")
108
+
109
+ # --- Relevant Context Retrieval ---
110
+
111
+ @cache_data(ttl=60) # Cache for 1 minute
112
+ def get_relevant_context(query, top_k=3):
113
+ """
114
+ Retrieve relevant context from Pinecone based on the user query.
115
+ """
116
+ try:
117
+ query_embedding = openai_client.embeddings.create(
118
+ model="text-embedding-3-large", # Updated to use the larger model
119
+ input=query
120
+ ).data[0].embedding
121
+
122
+ results = pinecone_index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
123
+ contexts = [item['metadata']['text'] for item in results['matches']]
124
+ return " ".join(contexts)
125
+ except Exception as e:
126
+ print(f"Error retrieving context: {str(e)}")
127
+ return ""
128
+
129
+ # --- GPT Response Function ---
130
+
131
+ def get_gpt_response(prompt, context="", system_message=None):
132
+ try:
133
+ common_memory = get_global_common_memory()
134
+ system_msg = (
135
+ "You are a helpful assistant. Use the following context and global common memory "
136
+ "to inform your responses, but don't mention them explicitly unless directly relevant to the user's question."
137
+ )
138
+
139
+ if system_message:
140
+ system_msg += f"\n\nTrainer Instructions:\n{system_message}"
141
+
142
+ if common_memory:
143
+ memory_str = "\n".join(common_memory)
144
+ system_msg += f"\n\nGlobal Common Memory:\n{memory_str}"
145
+
146
+ messages = [
147
+ {"role": "system", "content": system_msg},
148
+ {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"}
149
+ ]
150
+
151
+ completion = openai_client.chat.completions.create(
152
+ model="gpt-4o-mini",
153
+ messages=messages
154
+ )
155
+ response = completion.choices[0].message.content.strip()
156
+
157
+ return response
158
+ except Exception as e:
159
+ st.error(f"Error generating response: {str(e)}")
160
+ return None
161
+
162
+ # --- Send User Message ---
163
+ def send_message(message):
164
+ """
165
+ Sends a user message. If admin takeover is active, messages are sent to admin instead of GPT.
166
+ """
167
+ context = get_relevant_context(message)
168
+
169
+ user_message = {
170
+ "role": "user",
171
+ "content": message,
172
+ "timestamp": datetime.utcnow(),
173
+ "status": "approved" # User messages are always approved
174
+ }
175
+
176
+ # Upsert the user message immediately
177
+ result = conversation_history.update_one(
178
+ {"session_id": st.session_state['session_id']},
179
+ {
180
+ "$push": {"messages": user_message},
181
+ "$set": {"last_updated": datetime.utcnow()},
182
+ "$setOnInsert": {"created_at": datetime.utcnow()}
183
+ },
184
+ upsert=True
185
+ )
186
+
187
+ # Update the session state with the user message
188
+ st.session_state['chat_history'].append(user_message)
189
+
190
+ if not st.session_state.get('admin_takeover_active'):
191
+ # Generate GPT response if takeover is not active
192
+ gpt_response = get_gpt_response(message, context)
193
+
194
+ assistant_message = {
195
+ "role": "assistant",
196
+ "content": gpt_response,
197
+ "timestamp": datetime.utcnow(),
198
+ "status": "pending" # Set status to pending for admin approval
199
+ }
200
+
201
+ # Upsert the assistant message
202
+ result = conversation_history.update_one(
203
+ {"session_id": st.session_state['session_id']},
204
+ {
205
+ "$push": {"messages": assistant_message},
206
+ "$set": {"last_updated": datetime.utcnow()}
207
+ }
208
+ )
209
+
210
+ # Update the session state with the assistant message
211
+ st.session_state['chat_history'].append(assistant_message)
212
+
213
+ # --- Send Admin Message ---
214
+
215
+ def send_admin_message(message):
216
+ """
217
+ Sends an admin message directly to the user during a takeover.
218
+ """
219
+ admin_message = {
220
+ "role": "admin",
221
+ "content": message,
222
+ "timestamp": datetime.utcnow(),
223
+ "status": "approved"
224
+ }
225
+
226
+ # Upsert the admin message
227
+ result = conversation_history.update_one(
228
+ {"session_id": st.session_state['session_id']},
229
+ {
230
+ "$push": {"messages": admin_message},
231
+ "$set": {"last_updated": datetime.utcnow()}
232
+ }
233
+ )
234
+
235
+ # Update the session state with the admin message
236
+ st.session_state['chat_history'].append(admin_message)
237
+
238
+ # --- Takeover Functions ---
239
+
240
+ def activate_takeover(session_id):
241
+ """
242
+ Activates takeover mode for the given session.
243
+ """
244
+ try:
245
+ db.takeover_status.update_one(
246
+ {"session_id": session_id},
247
+ {"$set": {"active": True, "activated_at": datetime.utcnow()}},
248
+ upsert=True
249
+ )
250
+ st.success(f"Takeover activated for session {session_id[:8]}...")
251
+ except Exception as e:
252
+ st.error(f"Failed to activate takeover: {str(e)}")
253
+
254
+ def deactivate_takeover(session_id):
255
+ """
256
+ Deactivates takeover mode for the given session.
257
+ """
258
+ try:
259
+ db.takeover_status.update_one(
260
+ {"session_id": session_id},
261
+ {"$set": {"active": False}},
262
+ )
263
+ st.success(f"Takeover deactivated for session {session_id[:8]}...")
264
+ except Exception as e:
265
+ st.error(f"Failed to deactivate takeover: {str(e)}")
266
+
267
+ def handle_admin_takeover(session_id):
268
+ st.subheader("Admin Takeover")
269
+
270
+ takeover_active = db.takeover_status.find_one({"session_id": session_id})
271
+ is_active = takeover_active.get("active", False) if takeover_active else False
272
+
273
+ if is_active:
274
+ st.info("Takeover is currently active for this session.")
275
+ if st.button("Deactivate Takeover"):
276
+ deactivate_takeover(session_id)
277
+ st.success("Takeover deactivated.")
278
+ st.rerun()
279
+ else:
280
+ st.warning("Takeover is not active for this session.")
281
+ if st.button("Activate Takeover"):
282
+ activate_takeover(session_id)
283
+ st.success("Takeover activated.")
284
+ st.rerun()
285
+
286
+ if is_active:
287
+ admin_message = st.text_area("Send Message to User", key="admin_message")
288
+ if st.button("Send Admin Message"):
289
+ admin_message = st.session_state.get("admin_message", "")
290
+ if admin_message.strip():
291
+ send_admin_message(admin_message.strip())
292
+ st.success("Admin message sent successfully!")
293
+ st.session_state["admin_message"] = ""
294
+ else:
295
+ st.warning("Please enter a message to send.")
296
+
297
+ # --- View Full Chat (User Perspective) ---
298
+
299
+ def view_full_chat(session_id):
300
+ st.title(f"Full Chat View - Session: {session_id[:8]}...")
301
+
302
+ chat = db.chat_history.find_one({"session_id": session_id})
303
+
304
+ if not chat:
305
+ st.error("Chat not found.")
306
+ return
307
+
308
+ col1, col2 = st.columns([2, 1])
309
+ with col1:
310
+ st.subheader(f"Session ID: {session_id}")
311
+ with col2:
312
+ st.write(f"Last Updated: {chat.get('last_updated', 'N/A')}")
313
+
314
+ st.markdown("---")
315
+
316
+ for message in chat.get('messages', []):
317
+ role = message['role'].capitalize()
318
+ content = message['content']
319
+ timestamp = message.get('timestamp', 'N/A')
320
+
321
+ if role == 'User':
322
+ with st.chat_message("user"):
323
+ st.markdown(f"**User** - {timestamp}")
324
+ st.markdown(content)
325
+ elif role == 'Assistant':
326
+ with st.chat_message("assistant"):
327
+ st.markdown(f"**Assistant** - {timestamp}")
328
+ st.markdown(content)
329
+ elif role == 'Admin':
330
+ with st.chat_message("human"):
331
+ st.markdown(f"**Admin** - {timestamp}")
332
+ st.markdown(content)
333
+
334
+ st.markdown("---")
335
+
336
+ # Add text box to append to global memory
337
+ st.subheader("Add to Global Memory")
338
+ new_memory = st.text_area("Enter new memory item", key=f"new_memory_input_{session_id}")
339
+ if st.button("Add Memory", key=f"add_memory_button_{session_id}"):
340
+ if new_memory.strip():
341
+ append_to_global_common_memory(new_memory.strip())
342
+ st.success("New memory item added to global memory!")
343
+ # Instead of rerunning, we'll update the session state
344
+ st.session_state[f'memory_added_{session_id}'] = True
345
+ st.rerun()
346
+ else:
347
+ st.warning("Please enter a valid memory item.")
348
+
349
+ # Display success message if memory was added
350
+ if st.session_state.get(f'memory_added_{session_id}'):
351
+ st.success("Memory item added successfully!")
352
+ # Clear the flag
353
+ del st.session_state[f'memory_added_{session_id}']
354
+
355
+ st.markdown("---")
356
+ col1, col2, col3 = st.columns([1, 1, 1])
357
+ with col2:
358
+ if st.button("Back to Chat History", use_container_width=True):
359
+ st.session_state.pop('full_chat_view', None)
360
+ st.rerun()
361
+
362
+ # --- Clear Global Chat Memory---
363
+
364
+ def clear_global_common_memory():
365
+ """Clear all items from the global common memory."""
366
+ try:
367
+ global_common_memory.update_one(
368
+ {"memory_id": GLOBAL_MEMORY_ID},
369
+ {"$set": {"memory": []}},
370
+ upsert=True
371
+ )
372
+ # Invalidate the cache after clearing
373
+ get_global_common_memory.clear()
374
+ st.success("Global common memory cleared successfully!")
375
+ except Exception as e:
376
+ st.error(f"Failed to clear global common memory: {str(e)}")
377
+
378
+ def display_chat_history():
379
+ st.subheader("All Chat History")
380
+
381
+ all_chats = list(db.chat_history.find().sort("last_updated", -1))
382
+
383
+ if not all_chats:
384
+ st.info("No chat history found.")
385
+ return
386
+
387
+ for idx, chat in enumerate(all_chats):
388
+ session_id = chat['session_id']
389
+ last_updated = chat.get('last_updated', 'N/A')
390
+
391
+ with st.expander(f"Session: {session_id[:8]}... - Last Updated: {last_updated}"):
392
+ if chat.get('messages'):
393
+ last_message = chat['messages'][-1]
394
+ st.markdown(f"Last message ({last_message['role'].capitalize()}):")
395
+ st.markdown(f"> {last_message['content'][:100]}...")
396
+
397
+ if st.button(f"Show Full Chat", key=f"show_full_chat_{idx}"):
398
+ st.session_state['full_chat_view'] = session_id
399
+ st.rerun()
400
+
401
+ def trainer_intervention_tab():
402
+ st.subheader("Trainer Intervention")
403
+
404
+ # Handle admin intervention
405
+ handle_admin_intervention()
406
+
407
+ def handle_admin_intervention():
408
+ st.subheader("Review Pending Responses")
409
+ pending_responses = conversation_history.find(
410
+ {"messages.role": "assistant", "messages.status": "pending"}
411
+ )
412
+
413
+ for conversation in pending_responses:
414
+ st.write(f"Session ID: {conversation['session_id'][:8]}...")
415
+
416
+ for i, message in enumerate(conversation['messages']):
417
+ if message['role'] == 'assistant' and message.get('status') == 'pending':
418
+ user_message = conversation['messages'][i-1]['content'] if i > 0 else "N/A"
419
+ st.write(f"**User:** {user_message}")
420
+ st.write(f"**GPT:** {message['content']}")
421
+
422
+ col1, col2, col3, col4 = st.columns(4)
423
+ with col1:
424
+ if st.button("Approve", key=f"approve_{conversation['session_id']}_{i}"):
425
+ if approve_response(conversation['session_id'], i):
426
+ st.success("Response approved")
427
+ time.sleep(0.5)
428
+ st.rerun()
429
+ with col2:
430
+ if st.button("Modify", key=f"modify_{conversation['session_id']}_{i}"):
431
+ st.session_state['modifying'] = (conversation['session_id'], i)
432
+ st.rerun()
433
+ with col3:
434
+ if st.button("Regenerate", key=f"regenerate_{conversation['session_id']}_{i}"):
435
+ st.session_state['regenerating'] = (conversation['session_id'], i)
436
+ st.rerun()
437
+ with col4:
438
+ takeover_doc = db.takeover_status.find_one({"session_id": conversation['session_id']})
439
+ takeover_active = takeover_doc.get("active", False) if takeover_doc else False
440
+ if takeover_active:
441
+ if st.button("Deactivate Takeover", key=f"deactivate_takeover_{conversation['session_id']}_{i}"):
442
+ deactivate_takeover(conversation['session_id'])
443
+ st.success("Takeover deactivated.")
444
+ st.rerun()
445
+ else:
446
+ if st.button("Activate Takeover", key=f"activate_takeover_{conversation['session_id']}_{i}"):
447
+ activate_takeover(conversation['session_id'])
448
+ st.success("Takeover activated.")
449
+ st.rerun()
450
+
451
+ st.divider()
452
+
453
+ if 'regenerating' in st.session_state:
454
+ session_id, message_index = st.session_state['regenerating']
455
+ with st.form(key="regenerate_form"):
456
+ operator_input = st.text_input("Enter additional instructions for regeneration:")
457
+ submit_button = st.form_submit_button("Submit")
458
+
459
+ if submit_button:
460
+ del st.session_state['regenerating']
461
+ regenerate_response(session_id, message_index, operator_input)
462
+ st.success("Response regenerated with operator input.")
463
+ st.rerun()
464
+
465
+ if 'modifying' in st.session_state:
466
+ session_id, message_index = st.session_state['modifying']
467
+ conversation = conversation_history.find_one({"session_id": session_id})
468
+ message = conversation['messages'][message_index]
469
+
470
+ modified_content = st.text_area("Modify the response:", value=message['content'])
471
+ if st.button("Save Modification"):
472
+ save_modified_response(session_id, message_index, modified_content)
473
+ st.success("Response modified and approved")
474
+ del st.session_state['modifying']
475
+ st.rerun()
476
+
477
+ def approve_response(session_id, message_index):
478
+ try:
479
+ result = conversation_history.update_one(
480
+ {"session_id": session_id},
481
+ {"$set": {f"messages.{message_index}.status": "approved"}}
482
+ )
483
+ return result.modified_count > 0
484
+ except Exception as e:
485
+ st.error(f"Failed to approve response: {str(e)}")
486
+ return False
487
+
488
+ def save_modified_response(session_id, message_index, modified_content):
489
+ try:
490
+ conversation_history.update_one(
491
+ {"session_id": session_id},
492
+ {
493
+ "$set": {
494
+ f"messages.{message_index}.content": modified_content,
495
+ f"messages.{message_index}.status": "approved"
496
+ }
497
+ }
498
+ )
499
+ except Exception as e:
500
+ st.error(f"Failed to save modified response: {str(e)}")
501
+
502
+ def regenerate_response(session_id, message_index, operator_input):
503
+ try:
504
+ conversation = conversation_history.find_one({"session_id": session_id})
505
+ user_message = conversation['messages'][message_index - 1]['content'] if message_index > 0 else ""
506
+ new_response = get_gpt_response(user_message, system_message=operator_input)
507
+
508
+ conversation_history.update_one(
509
+ {"session_id": session_id},
510
+ {
511
+ "$set": {
512
+ f"messages.{message_index}.content": new_response,
513
+ f"messages.{message_index}.status": "pending"
514
+ }
515
+ }
516
+ )
517
+ except Exception as e:
518
+ st.error(f"Failed to regenerate response: {str(e)}")
519
+
520
+ def trainer_page():
521
+ st.title("Trainer Dashboard")
522
+
523
+ # Add auto-refresh every 10 seconds (10000 milliseconds)
524
+ st_autorefresh(interval=10000, limit=None, key="trainer_autorefresh")
525
+
526
+ tab1, tab2, tab3 = st.tabs(["Current Status", "Chat History", "Intervention"])
527
+
528
+ with tab1:
529
+ # Display current global memory
530
+ st.subheader("Current Global Memory")
531
+ global_memory = get_global_common_memory()
532
+ if global_memory:
533
+ for idx, item in enumerate(global_memory, 1):
534
+ st.text(f"{idx}. {item}")
535
+ else:
536
+ st.info("No global memory items found.")
537
+
538
+ # Add button to clear global memory
539
+ if st.button("Clear Global Memory", key="clear_global_memory"):
540
+ clear_global_common_memory()
541
+ st.success("Global memory cleared successfully!")
542
+ time.sleep(1)
543
+ st.rerun()
544
+
545
+ # Display current chats
546
+ st.subheader("Active Chats")
547
+ chats = list(conversation_history.find().sort("last_updated", -1).limit(5))
548
+ for idx, chat in enumerate(chats):
549
+ with st.expander(f"Session: {chat['session_id'][:8]}... - Last Updated: {chat.get('last_updated', 'N/A')}"):
550
+ for message in chat.get('messages', [])[-5:]:
551
+ role = message['role'].capitalize()
552
+ content = message['content']
553
+ st.markdown(f"**{role}:** {content}")
554
+
555
+ col1, col2, col3, col4 = st.columns(4)
556
+ with col1:
557
+ if st.button(f"View Full Chat", key=f"view_chat_{idx}"):
558
+ st.session_state['selected_chat'] = chat['session_id']
559
+ st.rerun()
560
+ with col2:
561
+ takeover_doc = db.takeover_status.find_one({"session_id": chat['session_id']})
562
+ takeover_active = takeover_doc.get("active", False) if takeover_doc else False
563
+ if takeover_active:
564
+ if st.button(f"Deactivate Takeover", key=f"deactivate_takeover_{idx}"):
565
+ deactivate_takeover(chat['session_id'])
566
+ st.success("Takeover deactivated.")
567
+ st.rerun()
568
+ else:
569
+ if st.button(f"Activate Takeover", key=f"activate_takeover_{idx}"):
570
+ activate_takeover(chat['session_id'])
571
+ st.success("Takeover activated.")
572
+ st.rerun()
573
+ with col3:
574
+ if st.button(f"Delete Chat", key=f"delete_chat_{idx}"):
575
+ delete_chat(chat['session_id'])
576
+ st.success(f"Chat {chat['session_id'][:8]}... deleted.")
577
+ st.rerun()
578
+ with col4:
579
+ if takeover_active:
580
+ st.text_input("Send message", key=f"takeover_message_{idx}")
581
+ if st.button("Send", key=f"send_takeover_{idx}"):
582
+ message = st.session_state[f"takeover_message_{idx}"]
583
+ if message.strip():
584
+ send_admin_message(chat['session_id'], message.strip())
585
+ st.success("Message sent.")
586
+ st.rerun()
587
+ else:
588
+ st.warning("Please enter a message to send.")
589
+
590
+ # Manual refresh button
591
+ if st.button("Refresh", key="refresh_button"):
592
+ st.rerun()
593
+
594
+ with tab2:
595
+ display_chat_history()
596
+
597
+ with tab3:
598
+ trainer_intervention_tab()
599
+
600
+ def delete_chat(session_id):
601
+ try:
602
+ result = conversation_history.delete_one({"session_id": session_id})
603
+ if result.deleted_count == 0:
604
+ st.error("Failed to delete chat. Please try again.")
605
+ except Exception as e:
606
+ st.error(f"Error deleting chat: {str(e)}")
607
+
608
+ # --- Main Function ---
609
+
610
+ def main():
611
+ try:
612
+ if 'memory_updated' in st.session_state:
613
+ del st.session_state['memory_updated']
614
+ st.rerun()
615
+
616
+ if 'full_chat_view' in st.session_state:
617
+ view_full_chat(st.session_state['full_chat_view'])
618
+ elif 'selected_chat' in st.session_state:
619
+ view_full_chat(st.session_state['selected_chat'])
620
+ else:
621
+ trainer_page()
622
+ except (RerunException, StopException):
623
+ raise
624
+ except Exception as e:
625
+ st.error(f"An unexpected error occurred: {str(e)}")
626
+
627
+ if __name__ == "__main__":
628
+ main()