poemsforaphrodite committed
Commit b26ef8e
1 Parent(s): 343cdcf

Update app.py

Files changed (1)
  1. app.py +614 -190
app.py CHANGED
@@ -14,6 +14,23 @@ from pinecone import Pinecone, ServerlessSpec
14
  import threading # {{ edit_25: Import threading for background processing }}
15
  import tiktoken
16
  from tiktoken.core import Encoding
17
 
18
  # Set page configuration to wide mode
19
  st.set_page_config(layout="wide")
@@ -28,8 +45,11 @@ db = mongo_client['llm_evaluation_system']
28
  users_collection = db['users']
29
  results_collection = db['evaluation_results']
30
 
31
- # Initialize OpenAI client
32
- openai_client = OpenAI() # {{ edit_12: Rename OpenAI client to 'openai_client' }}
33
 
34
  # Initialize Pinecone
35
  pinecone_client = Pinecone(api_key=os.getenv('PINECONE_API_KEY')) # {{ edit_13: Initialize Pinecone client using Pinecone class }}
@@ -97,11 +117,30 @@ def generate_response(prompt, context):
97
  except Exception as e:
98
  st.error(f"Error generating response: {str(e)}")
99
  return None
100
 
101
  # Function to clear the results database
102
- def clear_results_database():
103
  try:
104
- results_collection.delete_many({})
105
  return True
106
  except Exception as e:
107
  st.error(f"Error clearing results database: {str(e)}")
@@ -228,6 +267,62 @@ def save_results(username, model, prompt, context, response, evaluation): # {{
228
  }
229
  results_collection.insert_one(result)
230
 
231
  # Function for teacher model evaluation
232
  def teacher_evaluate(prompt, context, response):
233
  try:
@@ -236,8 +331,8 @@ def teacher_evaluate(prompt, context, response):
236
  Rate each factor on a scale of 0 to 1, where 1 is the best (or least problematic for negative factors like Hallucination and Bias).
237
  Please provide scores with two decimal places, and avoid extreme scores of exactly 0 or 1 unless absolutely necessary.
238
 
239
- Prompt: {prompt}
240
  Context: {context}
 
241
  Response: {response}
242
 
243
  Factors to evaluate:
@@ -255,7 +350,7 @@ def teacher_evaluate(prompt, context, response):
255
  """
256
 
257
  evaluation_response = openai_client.chat.completions.create(
258
- model="gpt-4o-mini", # Corrected model name
259
  messages=[
260
  {"role": "system", "content": "You are an expert evaluator of language model responses."},
261
  {"role": "user", "content": evaluation_prompt}
@@ -328,14 +423,9 @@ else:
328
  st.sidebar.success(f"Welcome, {st.session_state.user}!")
329
  if st.sidebar.button("Logout"):
330
  st.session_state.user = None
331
- st.experimental_rerun()
 
332
 
333
- # Add Clear Results Database button
334
- if st.sidebar.button("Clear Results Database"):
335
- if clear_results_database(): # {{ edit_fix: Calling the newly defined clear_results_database function }}
336
- st.sidebar.success("Results database cleared successfully!")
337
- else:
338
- st.sidebar.error("Failed to clear results database.")
339
 
340
  # App content
341
  if st.session_state.user:
@@ -355,9 +445,23 @@ if st.session_state.user:
355
  if user_models:
356
  model_options = [model['model_name'] if model['model_name'] else model['model_id'] for model in user_models]
357
  selected_model = st.selectbox("Select Model to View Metrics", ["All Models"] + model_options)
358
  else:
359
  st.error("You have no uploaded models.")
360
  selected_model = "All Models"
 
361
 
362
  try:
363
  query = {"username": st.session_state.user}
@@ -369,21 +473,81 @@ if st.session_state.user:
369
  if results:
370
  df = pd.DataFrame(results)
371
 
372
- # Count tokens for prompt, context, and response
373
- df['prompt_tokens'] = df['prompt'].apply(count_tokens)
374
- df['context_tokens'] = df['context'].apply(count_tokens)
375
- df['response_tokens'] = df['response'].apply(count_tokens)
376
 
377
  # Calculate total tokens for each row
378
  df['total_tokens'] = df['prompt_tokens'] + df['context_tokens'] + df['response_tokens']
379
 
 
380
  metrics = ["Accuracy", "Hallucination", "Groundedness", "Relevance", "Recall", "Precision", "Consistency", "Bias Detection"]
381
  for metric in metrics:
382
- df[metric] = df['evaluation'].apply(lambda x: x.get(metric, {}).get('score', 0) if x else 0) * 100
383
 
384
  df['timestamp'] = pd.to_datetime(df['timestamp'])
385
  df['query_number'] = range(1, len(df) + 1) # Add query numbers
386
 
387
  @st.cache_data
388
  def create_metrics_graph(df, metrics):
389
  fig = px.line(
@@ -415,8 +579,7 @@ if st.session_state.user:
415
 
416
  # Latest Metrics
417
  st.subheader("Latest Metrics")
418
- latest_result = df.iloc[-1] # Get the last row (most recent query)
419
- latest_metrics = {metric: latest_result[metric] for metric in metrics}
420
 
421
  cols = st.columns(4)
422
  for i, (metric, value) in enumerate(latest_metrics.items()):
@@ -425,6 +588,9 @@ if st.session_state.user:
425
  st.metric(label=metric, value=f"{value:.2f}%", delta=None)
426
  st.progress(value / 100)
427
 
 
 
 
428
  # Detailed Data View
429
  st.subheader("Detailed Data View")
430
 
@@ -442,14 +608,15 @@ if st.session_state.user:
442
  # Prepare the data for display
443
  display_data = []
444
  for _, row in df.iterrows():
 
445
  display_row = {
446
- "Prompt": row['prompt'][:50] + "...", # Truncate long prompts
447
- "Context": row['context'][:50] + "...", # Truncate long contexts
448
- "Response": row['response'][:50] + "...", # Truncate long responses
449
  }
450
  # Add metrics to the display row
451
  for metric in metrics:
452
- display_row[metric] = row[metric] # Store as float, not string
453
 
454
  display_data.append(display_row)
455
 
@@ -490,20 +657,309 @@ if st.session_state.user:
490
  height=400 # Set a fixed height with scrolling
491
  )
492
 
493
- # Placeholders for future sections
494
  st.subheader("Worst Performing Slice Analysis")
495
- st.info("This section will show analysis of the worst-performing data slices.")
496
-
497
- st.subheader("UMAP Visualization")
498
- st.info("This section will contain UMAP visualizations for dimensionality reduction insights.")
499
  else:
500
  st.info("No evaluation results available for the selected model.")
501
  except Exception as e:
502
- st.error(f"Error fetching data from database: {e}")
503
  st.error("Detailed error information:")
504
- st.error(str(e))
505
- import traceback
506
  st.error(traceback.format_exc())
 
507
 
508
  elif app_mode == "Model Upload":
509
  st.title("Upload Your Model")
@@ -562,7 +1018,6 @@ if st.session_state.user:
562
  elif app_mode == "Prompt Testing":
563
  st.title("Prompt Testing")
564
 
565
- # {{ edit_6: Use model_name instead of model_id }}
566
  model_selection_option = st.radio("Select Model Option:", ["Choose Existing Model", "Add New Model"])
567
 
568
  if model_selection_option == "Choose Existing Model":
@@ -572,136 +1027,94 @@ if st.session_state.user:
572
  if not user_models:
573
  st.error("You have no uploaded models. Please upload a model first.")
574
  else:
575
- # Display model_name instead of model_id
576
- model_name = st.selectbox("Select a Model for Testing", [model['model_name'] if model['model_name'] else model['model_id'] for model in user_models])
577
  else:
578
- # Option to enter model name or upload a link
579
- new_model_option = st.radio("Add Model By:", ["Enter Model Name", "Upload Model Link"])
580
-
581
- if new_model_option == "Enter Model Name":
582
- model_name_input = st.text_input("Enter New Model Name:")
583
- if st.button("Save Model Name"):
584
- if model_name_input:
585
- # {{ edit_3: Save the new model name to user's models }}
586
- model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
587
- users_collection.update_one(
588
- {"username": st.session_state.user},
589
- {"$push": {"models": {
590
- "model_id": model_id,
591
- "model_name": model_name_input,
592
- "file_path": None,
593
- "model_link": None,
594
- "uploaded_at": datetime.now()
595
- }}}
596
- )
597
- st.success(f"Model '{model_name_input}' saved successfully as {model_id}!")
598
- model_name = model_name_input # Use model_name instead of model_id
599
- else:
600
- st.error("Please enter a valid model name.")
601
- else:
602
- model_link = st.text_input("Enter Model Link:")
603
- if st.button("Save Model Link"):
604
- if model_link:
605
- # {{ edit_4: Save the model link to user's models }}
606
- model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
607
- users_collection.update_one(
608
- {"username": st.session_state.user},
609
- {"$push": {"models": {
610
- "model_id": model_id,
611
- "model_name": None,
612
- "file_path": None,
613
- "model_link": model_link,
614
- "uploaded_at": datetime.now()
615
- }}}
616
- )
617
- st.success(f"Model link saved successfully as {model_id}!")
618
- model_name = model_id # Use model_id if model_name is not available
619
- else:
620
- st.error("Please enter a valid model link.")
621
-
622
- # Two ways to provide prompts
623
- prompt_input_method = st.radio("Choose prompt input method:", ["Single JSON", "Batch Upload"])
624
 
625
- if prompt_input_method == "Single JSON":
626
- json_input = st.text_area("Enter your JSON input:")
627
- if json_input:
628
  try:
629
- data = json.loads(json_input)
630
- st.success("JSON parsed successfully!")
631
 
632
- # Display JSON in a table format
633
- st.subheader("Input Data")
634
- df = pd.json_normalize(data)
635
- st.table(df.style.set_properties(**{
636
- 'background-color': '#f0f8ff',
637
- 'color': '#333',
638
- 'border': '1px solid #ddd'
639
- }).set_table_styles([
640
- {'selector': 'th', 'props': [('background-color', '#4CAF50'), ('color', 'white')]},
641
- {'selector': 'td', 'props': [('text-align', 'left')]}
642
- ]))
643
- except json.JSONDecodeError:
644
- st.error("Invalid JSON. Please check your input.")
645
- else:
646
- uploaded_file = st.file_uploader("Upload a JSON file with prompts, contexts, and responses", type="json")
647
- if uploaded_file is not None:
648
- try:
649
- data = json.load(uploaded_file)
650
- st.success("JSON file loaded successfully!")
651
 
652
- # Display JSON in a table format
653
- st.subheader("Input Data")
654
- df = pd.json_normalize(data)
655
- st.table(df.style.set_properties(**{
656
- 'background-color': '#f0f8ff',
657
- 'color': '#333',
658
- 'border': '1px solid #ddd'
659
- }).set_table_styles([
660
- {'selector': 'th', 'props': [('background-color', '#4CAF50'), ('color', 'white')]},
661
- {'selector': 'td', 'props': [('text-align', 'left')]}
662
- ]))
663
  except json.JSONDecodeError:
664
- st.error("Invalid JSON file. Please check your file contents.")
665
-
666
- # Function to handle background evaluation
667
- def run_evaluations(data, selected_model, username): # {{ edit_30: Add 'username' parameter }}
668
- if isinstance(data, list):
669
- for item in data:
670
- if 'response' not in item:
671
- item['response'] = generate_response(item['prompt'], item['context'])
672
- evaluation = teacher_evaluate(item['prompt'], item['context'], item['response'])
673
- save_results(username, selected_model, item['prompt'], item['context'], item['response'], evaluation) # {{ edit_31: Pass 'username' to save_results }}
674
- # Optionally, update completed prompts in session_state or another mechanism
675
  else:
676
- if 'response' not in data:
677
- data['response'] = generate_response(data['prompt'], data['context'])
678
- evaluation = teacher_evaluate(data['prompt'], data['context'], data['response'])
679
- save_results(username, selected_model, data['prompt'], data['context'], data['response'], evaluation) # {{ edit_32: Pass 'username' to save_results }}
680
- # Optionally, update completed prompts in session_state or another mechanism
 
681
 
682
- # In the Prompt Testing section
683
  if st.button("Run Test"):
684
  if not model_name:
685
  st.error("Please select or add a valid Model.")
686
- elif not data:
687
- st.error("Please provide valid JSON input.")
 
 
688
  else:
689
- # {{ edit_28: Define 'selected_model' based on 'model_name' }}
690
- selected_model = next(
691
- (m for m in user_models if (m['model_name'] == model_name) or (m['model_id'] == model_name)),
692
- None
693
- )
694
- if selected_model:
695
- with st.spinner("Starting evaluations in the background..."):
696
- evaluation_thread = threading.Thread(
697
- target=run_evaluations,
698
- args=(data, selected_model, st.session_state.user) # {{ edit_33: Pass 'username' to the thread }}
699
- )
700
- evaluation_thread.start()
701
- st.success("Evaluations are running in the background. You can navigate away or close the site.")
702
- # {{ edit_34: Optionally, track running evaluations in session_state }}
703
- else:
704
- st.error("Selected model not found.")
705
 
706
  elif app_mode == "Manage Models":
707
  st.title("Manage Your Models")
@@ -712,46 +1125,58 @@ if st.session_state.user:
712
  st.stop()
713
  user_models = user.get("models", [])
714
 
715
- # {{ edit_1: Add option to add a new model }}
716
  st.subheader("Add a New Model")
717
- add_model_option = st.radio("Add Model By:", ["Enter Model Name", "Upload Model Link"])
718
 
719
- if add_model_option == "Enter Model Name":
720
  new_model_name = st.text_input("Enter New Model Name:")
721
- if st.button("Add Model Name"):
722
- if new_model_name:
723
- model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
724
- users_collection.update_one(
725
- {"username": st.session_state.user},
726
- {"$push": {"models": {
727
- "model_id": model_id,
728
- "model_name": new_model_name,
729
- "file_path": None,
730
- "model_link": None,
731
- "uploaded_at": datetime.now()
732
- }}}
733
- )
734
- st.success(f"Model '{new_model_name}' added successfully as {model_id}!")
735
- else:
736
- st.error("Please enter a valid model name.")
737
- else:
738
- new_model_link = st.text_input("Enter Model Link:")
739
- if st.button("Add Model Link"):
740
- if new_model_link:
741
  model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
742
  users_collection.update_one(
743
  {"username": st.session_state.user},
744
- {"$push": {"models": {
745
- "model_id": model_id,
746
- "model_name": None,
747
- "file_path": None,
748
- "model_link": new_model_link,
749
- "uploaded_at": datetime.now()
750
- }}}
751
  )
752
- st.success(f"Model link added successfully as {model_id}!")
753
  else:
754
- st.error("Please enter a valid model link.")
755
 
756
  st.markdown("---")
757
 
@@ -759,11 +1184,9 @@ if st.session_state.user:
759
  st.subheader("Your Models")
760
  for model in user_models:
761
  st.markdown(f"**Model ID:** {model['model_id']}")
762
- st.write(f"**Model Type:** {model.get('model_type', 'custom').capitalize()}") # {{ edit_14: Handle missing 'model_type' with default 'custom' }}
763
  if model.get("model_name"):
764
  st.write(f"**Model Name:** {model['model_name']}")
765
- if model.get("model_link"):
766
- st.write(f"**Model Link:** [Link]({model['model_link']})")
767
  if model.get("file_path"):
768
  st.write(f"**File Path:** {model['file_path']}")
769
  st.write(f"**Uploaded at:** {model['uploaded_at']}")
@@ -794,6 +1217,9 @@ if st.session_state.user:
794
  # Convert results to a pandas DataFrame
795
  df = pd.DataFrame(user_results)
796
 
797
  # Normalize the evaluation JSON into separate columns
798
  eval_df = df['evaluation'].apply(pd.Series)
799
  for metric in ["Accuracy", "Hallucination", "Groundedness", "Relevance", "Recall", "Precision", "Consistency", "Bias Detection"]:
@@ -842,7 +1268,7 @@ if st.session_state.user:
842
  'border': '1px solid #ddd'
843
  }).set_table_styles([
844
  {'selector': 'th', 'props': [('background-color', '#f5f5f5'), ('text-align', 'center')]},
845
- {'selector': 'td', 'props': [('text-align', 'center'), ('vertical-align', 'top')]}
846
  ]).format({
847
  "Accuracy (%)": "{:.2f}",
848
  "Hallucination (%)": "{:.2f}",
@@ -863,6 +1289,4 @@ if st.session_state.user:
863
 
864
  # Add a footer
865
  st.sidebar.markdown("---")
866
- st.sidebar.info("LLM Evaluation System - v0.2")
867
-
868
- # Function to handle model upload (placeholder)
 
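Note: count_tokens, which the dashboard code in this diff wraps with safe_count_tokens, is defined elsewhere in app.py and does not appear in the hunks shown here. A minimal sketch of such a helper built on the tiktoken Encoding imported at the top of the file — the encoding name is an assumption, not taken from this commit:

import tiktoken

# Hypothetical sketch of the token-counting helper referenced by the dashboard code.
# "cl100k_base" is an assumed encoding; substitute whatever app.py actually selects.
_encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    # Return the token count for a string; empty or None input counts as zero.
    if not text:
        return 0
    return len(_encoding.encode(text))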
14
  import threading # {{ edit_25: Import threading for background processing }}
15
  import tiktoken
16
  from tiktoken.core import Encoding
17
+ from runner import run_model
18
+ from bson.objectid import ObjectId
19
+ import traceback # Add this import at the top of your file
20
+ import umap
21
+ import plotly.graph_objs as go
22
+ from sklearn.preprocessing import StandardScaler
23
+ from sklearn.cluster import KMeans
24
+ import plotly.colors as plc
25
+
26
+ # Add this helper function at the beginning of your file
27
+ def extract_prompt_text(prompt):
28
+ if isinstance(prompt, dict):
29
+ return prompt.get('prompt', '')
30
+ elif isinstance(prompt, str):
31
+ return prompt
32
+ else:
33
+ return str(prompt)
34
 
35
  # Set page configuration to wide mode
36
  st.set_page_config(layout="wide")
 
45
  users_collection = db['users']
46
  results_collection = db['evaluation_results']
47
 
48
+ # Remove or comment out this line if it exists
49
+ # openai_client = OpenAI()
50
+
51
+ # Instead, use the openai_client from runner.py
52
+ from runner import openai_client
53
 
54
  # Initialize Pinecone
55
  pinecone_client = Pinecone(api_key=os.getenv('PINECONE_API_KEY')) # {{ edit_13: Initialize Pinecone client using Pinecone class }}
 
117
  except Exception as e:
118
  st.error(f"Error generating response: {str(e)}")
119
  return None
120
+
121
+ # Add this function to update the context for a model
122
+ def update_model_context(username, model_id, context):
123
+ users_collection.update_one(
124
+ {"username": username, "models.model_id": model_id},
125
+ {"$set": {"models.$.context": context}}
126
+ )
127
+
128
 
129
  # Function to clear the results database
130
+ def clear_results_database(username, model_identifier=None):
131
  try:
132
+ if model_identifier:
133
+ # Clear results for the specific model
134
+ results_collection.delete_many({
135
+ "username": username,
136
+ "$or": [
137
+ {"model_name": model_identifier},
138
+ {"model_id": model_identifier}
139
+ ]
140
+ })
141
+ else:
142
+ # Clear all results for the user
143
+ results_collection.delete_many({"username": username})
144
  return True
145
  except Exception as e:
146
  st.error(f"Error clearing results database: {str(e)}")
 
267
  }
268
  results_collection.insert_one(result)
269
 
270
+ # Modify the run_custom_evaluations function
271
+ def run_custom_evaluations(data, selected_model, username):
272
+ try:
273
+ model_name = selected_model['model_name']
274
+ model_id = selected_model['model_id']
275
+ model_type = selected_model.get('model_type', 'Unknown').lower()
276
+
277
+ if model_type == 'simple':
278
+ # For simple models, data is already in the correct format
279
+ test_cases = data
280
+ else:
281
+ # For other models, data is split into context_dataset and questions
282
+ context_dataset, questions = data
283
+ test_cases = [
284
+ {
285
+ "prompt": extract_prompt_text(question),
286
+ "context": context_dataset,
287
+ "response": "" # This will be filled by the model
288
+ }
289
+ for question in questions
290
+ ]
291
+
292
+ for test_case in test_cases:
293
+ prompt_text = test_case["prompt"]
294
+ context = test_case["context"]
295
+
296
+ # Get the student model's response using runner.py
297
+ try:
298
+ answer = run_model(model_name, prompt_text)
299
+ if answer is None or answer == "":
300
+ st.warning(f"No response received from the model for prompt: {prompt_text}")
301
+ answer = "No response received from the model."
302
+ except Exception as model_error:
303
+ st.error(f"Error running model for prompt: {prompt_text}")
304
+ st.error(f"Error details: {str(model_error)}")
305
+ answer = f"Error: {str(model_error)}"
306
+
307
+ # Get the teacher's evaluation
308
+ try:
309
+ evaluation = teacher_evaluate(prompt_text, context, answer)
310
+ if evaluation is None:
311
+ st.warning(f"No evaluation received for prompt: {prompt_text}")
312
+ evaluation = {"Error": "No evaluation received"}
313
+ except Exception as eval_error:
314
+ st.error(f"Error in teacher evaluation for prompt: {prompt_text}")
315
+ st.error(f"Error details: {str(eval_error)}")
316
+ evaluation = {"Error": str(eval_error)}
317
+
318
+ # Save the results
319
+ save_results(username, selected_model, prompt_text, context, answer, evaluation)
320
+
321
+ st.success("Evaluation completed successfully!")
322
+ except Exception as e:
323
+ st.error(f"Error in custom evaluation: {str(e)}")
324
+ st.error(f"Detailed error: {traceback.format_exc()}")
325
+
326
  # Function for teacher model evaluation
327
  def teacher_evaluate(prompt, context, response):
328
  try:
 
331
  Rate each factor on a scale of 0 to 1, where 1 is the best (or least problematic for negative factors like Hallucination and Bias).
332
  Please provide scores with two decimal places, and avoid extreme scores of exactly 0 or 1 unless absolutely necessary.
333
 
 
334
  Context: {context}
335
+ Prompt: {prompt}
336
  Response: {response}
337
 
338
  Factors to evaluate:
 
350
  """
351
 
352
  evaluation_response = openai_client.chat.completions.create(
353
+ model="gpt-4o-mini",
354
  messages=[
355
  {"role": "system", "content": "You are an expert evaluator of language model responses."},
356
  {"role": "user", "content": evaluation_prompt}
 
423
  st.sidebar.success(f"Welcome, {st.session_state.user}!")
424
  if st.sidebar.button("Logout"):
425
  st.session_state.user = None
426
+ st.rerun()
427
+
428
 
429
 
430
  # App content
431
  if st.session_state.user:
 
445
  if user_models:
446
  model_options = [model['model_name'] if model['model_name'] else model['model_id'] for model in user_models]
447
  selected_model = st.selectbox("Select Model to View Metrics", ["All Models"] + model_options)
448
+ st.session_state['selected_model'] = selected_model # Store the selected model in session state
449
+
450
+ # Add delete dataset button
451
+ if selected_model != "All Models":
452
+ if st.button("Delete Dataset"):
453
+ if st.session_state['selected_model']:
454
+ if clear_results_database(st.session_state.user, st.session_state['selected_model']):
455
+ st.success(f"All evaluation results for {st.session_state['selected_model']} have been deleted.")
456
+ st.rerun() # Rerun the app to refresh the dashboard
457
+ else:
458
+ st.error("Failed to delete the dataset. Please try again.")
459
+ else:
460
+ st.error("No model selected. Please select a model to delete its dataset.")
461
  else:
462
  st.error("You have no uploaded models.")
463
  selected_model = "All Models"
464
+ st.session_state['selected_model'] = selected_model
465
 
466
  try:
467
  query = {"username": st.session_state.user}
 
473
  if results:
474
  df = pd.DataFrame(results)
475
 
476
+ # Check if required columns exist
477
+ required_columns = ['prompt', 'context', 'response', 'evaluation']
478
+ missing_columns = [col for col in required_columns if col not in df.columns]
479
+ if missing_columns:
480
+ st.error(f"Error: Missing columns in the data: {', '.join(missing_columns)}")
481
+ st.error("Please check the database schema and ensure all required fields are present.")
482
+ st.stop()
483
+
484
+ # Extract prompt text if needed
485
+ df['prompt'] = df['prompt'].apply(extract_prompt_text)
486
+
487
+ # Safely count tokens for prompt, context, and response
488
+ def safe_count_tokens(text):
489
+ if isinstance(text, str):
490
+ return count_tokens(text)
491
+ else:
492
+ return 0 # or some default value
493
+
494
+ df['prompt_tokens'] = df['prompt'].apply(safe_count_tokens)
495
+ df['context_tokens'] = df['context'].apply(safe_count_tokens)
496
+ df['response_tokens'] = df['response'].apply(safe_count_tokens)
497
 
498
  # Calculate total tokens for each row
499
  df['total_tokens'] = df['prompt_tokens'] + df['context_tokens'] + df['response_tokens']
500
 
501
+ # Safely extract evaluation metrics
502
  metrics = ["Accuracy", "Hallucination", "Groundedness", "Relevance", "Recall", "Precision", "Consistency", "Bias Detection"]
503
  for metric in metrics:
504
+ df[metric] = df['evaluation'].apply(lambda x: x.get(metric, {}).get('score', 0) if isinstance(x, dict) else 0) * 100
505
 
506
  df['timestamp'] = pd.to_datetime(df['timestamp'])
507
  df['query_number'] = range(1, len(df) + 1) # Add query numbers
508
 
509
+ # Set the threshold for notifications
510
+ notification_threshold = st.slider("Set Performance Threshold for Notifications (%)", min_value=0, max_value=100, value=50)
511
+
512
+ # Define the metrics to check
513
+ metrics_to_check = metrics # Or allow the user to select specific metrics
514
+
515
+ # Check for evaluations where any of the metrics are below the threshold
516
+ low_performance_mask = df[metrics_to_check].lt(notification_threshold).any(axis=1)
517
+ low_performing_evaluations = df[low_performance_mask]
518
+
519
+ # Display Notifications
520
+ if not low_performing_evaluations.empty:
521
+ st.warning(f"⚠️ You have {len(low_performing_evaluations)} evaluations with metrics below {notification_threshold}%.")
522
+ with st.expander("View Low-Performing Evaluations"):
523
+ # Display the low-performing evaluations in a table
524
+ display_columns = ['timestamp', 'model_name', 'prompt', 'response'] + metrics_to_check
525
+ low_perf_display_df = low_performing_evaluations[display_columns].copy()
526
+ low_perf_display_df['timestamp'] = low_perf_display_df['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S')
527
+
528
+ # Apply styling to highlight low scores
529
+ def highlight_low_scores(val):
530
+ if isinstance(val, float):
531
+ if val < notification_threshold:
532
+ return 'background-color: red; color: white'
533
+ return ''
534
+
535
+ styled_low_perf_df = low_perf_display_df.style.applymap(highlight_low_scores, subset=metrics_to_check)
536
+ styled_low_perf_df = styled_low_perf_df.format({metric: "{:.2f}%" for metric in metrics_to_check})
537
+
538
+ st.dataframe(
539
+ styled_low_perf_df.set_properties(**{
540
+ 'text-align': 'left',
541
+ 'border': '1px solid #ddd'
542
+ }).set_table_styles([
543
+ {'selector': 'th', 'props': [('background-color', '#333'), ('color', 'white')]},
544
+ {'selector': 'td', 'props': [('vertical-align', 'top')]}
545
+ ]),
546
+ use_container_width=True
547
+ )
548
+ else:
549
+ st.success("🎉 All your evaluations have metrics above the threshold!")
550
+
551
  @st.cache_data
552
  def create_metrics_graph(df, metrics):
553
  fig = px.line(
 
579
 
580
  # Latest Metrics
581
  st.subheader("Latest Metrics")
582
+ latest_metrics = df[metrics].mean() # Calculate the average of all metrics
 
583
 
584
  cols = st.columns(4)
585
  for i, (metric, value) in enumerate(latest_metrics.items()):
 
588
  st.metric(label=metric, value=f"{value:.2f}%", delta=None)
589
  st.progress(value / 100)
590
 
591
+ # Add an explanation for the metrics
592
+ st.info("These metrics represent the average scores across all evaluations.")
593
+
594
  # Detailed Data View
595
  st.subheader("Detailed Data View")
596
 
 
608
  # Prepare the data for display
609
  display_data = []
610
  for _, row in df.iterrows():
611
+ prompt_text = extract_prompt_text(row.get('prompt', ''))
612
  display_row = {
613
+ "Prompt": prompt_text[:50] + "..." if prompt_text else "N/A",
614
+ "Context": str(row.get('context', ''))[:50] + "..." if row.get('context') else "N/A",
615
+ "Response": str(row.get('response', ''))[:50] + "..." if row.get('response') else "N/A",
616
  }
617
  # Add metrics to the display row
618
  for metric in metrics:
619
+ display_row[metric] = row.get(metric, 0) # Use get() with a default value
620
 
621
  display_data.append(display_row)
622
 
 
657
  height=400 # Set a fixed height with scrolling
658
  )
659
 
660
+ # UMAP Visualization with Clustering
661
+ st.subheader("UMAP Visualization with Clustering")
662
+
663
+ if len(df) > 2:
664
+ # Allow user to select metrics to include
665
+ metrics = ['Accuracy', 'Hallucination', 'Groundedness', 'Relevance', 'Recall', 'Precision', 'Consistency', 'Bias Detection']
666
+ selected_metrics = st.multiselect("Select Metrics to Include in UMAP", metrics, default=metrics)
667
+
668
+ if len(selected_metrics) < 2:
669
+ st.warning("Please select at least two metrics for UMAP.")
670
+ else:
671
+ # Allow user to select number of dimensions
672
+ n_components = st.radio("Select UMAP Dimensions", [2, 3], index=1)
673
+
674
+ # Allow user to adjust UMAP parameters
675
+ n_neighbors = st.slider("n_neighbors", min_value=2, max_value=50, value=15)
676
+ min_dist = st.slider("min_dist", min_value=0.0, max_value=1.0, value=0.1, step=0.01)
677
+
678
+ # Prepare data for UMAP
679
+ X = df[selected_metrics].values
680
+
681
+ # Normalize the data
682
+ scaler = StandardScaler()
683
+ X_scaled = scaler.fit_transform(X)
684
+
685
+ # Perform UMAP dimensionality reduction
686
+ reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=42)
687
+ embedding = reducer.fit_transform(X_scaled)
688
+
689
+ # Allow user to select the number of clusters
690
+ num_clusters = st.slider("Select Number of Clusters", min_value=2, max_value=10, value=3)
691
+
692
+ # Perform KMeans clustering on the UMAP embeddings
693
+ kmeans = KMeans(n_clusters=num_clusters, random_state=42)
694
+ cluster_labels = kmeans.fit_predict(embedding)
695
+
696
+ # Create a DataFrame with the UMAP results and cluster labels
697
+ umap_columns = [f'UMAP{i+1}' for i in range(n_components)]
698
+ umap_data = {col: embedding[:, idx] for idx, col in enumerate(umap_columns)}
699
+ umap_data['Cluster'] = cluster_labels
700
+ umap_data['Model'] = df['model_name']
701
+ umap_data['Prompt'] = df['prompt']
702
+ umap_data['Response'] = df['response']
703
+ umap_data['Timestamp'] = df['timestamp']
704
+ umap_df = pd.DataFrame(umap_data)
705
+
706
+ # Include selected metrics in umap_df for hover info
707
+ for metric in selected_metrics:
708
+ umap_df[metric] = df[metric]
709
+
710
+ # Prepare customdata for hovertemplate
711
+ customdata_columns = ['Model', 'Prompt', 'Cluster'] + selected_metrics
712
+ umap_df['customdata'] = umap_df[customdata_columns].values.tolist()
713
+
714
+ # Build hovertemplate
715
+ hovertemplate = '<b>Model:</b> %{customdata[0]}<br>' + \
716
+ '<b>Prompt:</b> %{customdata[1]}<br>' + \
717
+ '<b>Cluster:</b> %{customdata[2]}<br>'
718
+ for idx, metric in enumerate(selected_metrics):
719
+ hovertemplate += f'<b>{metric}:</b> %{{customdata[{idx+3}]:.2f}}<br>'
720
+ hovertemplate += '<extra></extra>' # Hide trace info
721
+
722
+ # Define color palette for clusters
723
+ cluster_colors = plc.qualitative.Plotly
724
+ num_colors = len(cluster_colors)
725
+ if num_clusters > num_colors:
726
+ cluster_colors = plc.sample_colorscale('Rainbow', [n/(num_clusters-1) for n in range(num_clusters)])
727
+ else:
728
+ cluster_colors = cluster_colors[:num_clusters]
729
+
730
+ # Map cluster labels to colors
731
+ cluster_color_map = {label: color for label, color in zip(range(num_clusters), cluster_colors)}
732
+ umap_df['Color'] = umap_df['Cluster'].map(cluster_color_map)
733
+
734
+ # Create the UMAP plot
735
+ if n_components == 3:
736
+ # 3D plot
737
+ fig = go.Figure()
738
+
739
+ for cluster_label in sorted(umap_df['Cluster'].unique()):
740
+ cluster_data = umap_df[umap_df['Cluster'] == cluster_label]
741
+ fig.add_trace(go.Scatter3d(
742
+ x=cluster_data['UMAP1'],
743
+ y=cluster_data['UMAP2'],
744
+ z=cluster_data['UMAP3'],
745
+ mode='markers',
746
+ name=f'Cluster {cluster_label}',
747
+ marker=dict(
748
+ size=5,
749
+ color=cluster_data['Color'], # Color according to cluster
750
+ opacity=0.8,
751
+ line=dict(width=0.5, color='white')
752
+ ),
753
+ customdata=cluster_data['customdata'],
754
+ hovertemplate=hovertemplate
755
+ ))
756
+
757
+ fig.update_layout(
758
+ title='3D UMAP Visualization with Clustering',
759
+ scene=dict(
760
+ xaxis_title='UMAP Dimension 1',
761
+ yaxis_title='UMAP Dimension 2',
762
+ zaxis_title='UMAP Dimension 3'
763
+ ),
764
+ hovermode='closest',
765
+ template='plotly_dark',
766
+ height=800,
767
+ legend_title='Clusters'
768
+ )
769
+ st.plotly_chart(fig, use_container_width=True)
770
+ else:
771
+ # 2D plot
772
+ fig = go.Figure()
773
+
774
+ for cluster_label in sorted(umap_df['Cluster'].unique()):
775
+ cluster_data = umap_df[umap_df['Cluster'] == cluster_label]
776
+ fig.add_trace(go.Scatter(
777
+ x=cluster_data['UMAP1'],
778
+ y=cluster_data['UMAP2'],
779
+ mode='markers',
780
+ name=f'Cluster {cluster_label}',
781
+ marker=dict(
782
+ size=8,
783
+ color=cluster_data['Color'], # Color according to cluster
784
+ opacity=0.8,
785
+ line=dict(width=0.5, color='white')
786
+ ),
787
+ customdata=cluster_data['customdata'],
788
+ hovertemplate=hovertemplate
789
+ ))
790
+
791
+ fig.update_layout(
792
+ title='2D UMAP Visualization with Clustering',
793
+ xaxis_title='UMAP Dimension 1',
794
+ yaxis_title='UMAP Dimension 2',
795
+ hovermode='closest',
796
+ template='plotly_dark',
797
+ height=800,
798
+ legend_title='Clusters'
799
+ )
800
+ st.plotly_chart(fig, use_container_width=True)
801
+
802
+ # Selectable Data Points
803
+ st.subheader("Cluster Analysis")
804
+
805
+ # Show cluster counts
806
+ cluster_counts = umap_df['Cluster'].value_counts().sort_index().reset_index()
807
+ cluster_counts.columns = ['Cluster', 'Number of Points']
808
+ st.write("### Cluster Summary")
809
+ st.dataframe(cluster_counts)
810
+
811
+ # Allow user to select clusters to view details
812
+ selected_clusters = st.multiselect("Select Clusters to View Details", options=sorted(umap_df['Cluster'].unique()), default=sorted(umap_df['Cluster'].unique()))
813
+
814
+ if selected_clusters:
815
+ selected_data = umap_df[umap_df['Cluster'].isin(selected_clusters)]
816
+ st.write("### Details of Selected Clusters")
817
+ st.dataframe(selected_data[['Model', 'Prompt', 'Response', 'Cluster'] + selected_metrics])
818
+ else:
819
+ st.info("Select clusters to view their details.")
820
+
821
+ st.info("""
822
+ **UMAP Visualization with Clustering**
823
+
824
+ This visualization includes clustering of the evaluation data points in the UMAP space.
825
+
826
+ **Features:**
827
+
828
+ - **Clustering Algorithm**: KMeans clustering is applied on the UMAP embeddings.
829
+ - **Cluster Selection**: Choose the number of clusters to identify patterns in the data.
830
+ - **Color Coding**: Each cluster is represented by a distinct color in the plot.
831
+ - **Interactive Exploration**: Hover over points to see detailed information, including the cluster label.
832
+ - **Cluster Analysis**: View summary statistics and details of selected clusters.
833
+
834
+ **Instructions:**
835
+
836
+ - **Select Metrics**: Choose which evaluation metrics to include in the UMAP calculation.
837
+ - **Adjust UMAP Parameters**: Fine-tune `n_neighbors` and `min_dist` for clustering granularity.
838
+ - **Choose Number of Clusters**: Use the slider to set how many clusters to identify.
839
+ - **Interact with the Plot**: Hover and click on clusters to explore data points.
840
+
841
+ **Interpreting Clusters:**
842
+
843
+ - **Cluster Composition**: Clusters group evaluations with similar metric profiles.
844
+ - **Model Performance**: Analyze clusters to identify strengths and weaknesses of models.
845
+ - **Data Patterns**: Use clustering to uncover hidden patterns in your evaluation data.
846
+
847
+ **Tips:**
848
+
849
+ - Experiment with different numbers of clusters to find meaningful groupings.
850
+ - Adjust UMAP parameters to see how the clustering changes with different embeddings.
851
+ - Use the cluster details to investigate specific evaluations and prompts.
852
+
853
+ Enjoy exploring your evaluation data with clustering!
854
+ """)
855
+ else:
856
+ st.info("Not enough data for UMAP visualization. Please run more evaluations.")
857
+
858
+ # Worst Performing Slice Analysis
859
  st.subheader("Worst Performing Slice Analysis")
860
+
861
+ # Allow the user to select metrics to analyze
862
+ metrics = ['Accuracy', 'Hallucination', 'Groundedness', 'Relevance', 'Recall', 'Precision', 'Consistency', 'Bias Detection']
863
+ selected_metrics = st.multiselect("Select Metrics to Analyze", metrics, default=metrics)
864
+
865
+ if selected_metrics:
866
+ # Set a threshold for "poor performance"
867
+ threshold = st.slider("Performance Threshold (%)", min_value=0, max_value=100, value=50)
868
+
869
+ # Filter data where any of the selected metrics are below the threshold
870
+ mask = df[selected_metrics].lt(threshold).any(axis=1)
871
+ worst_performing_df = df[mask]
872
+
873
+ if not worst_performing_df.empty:
874
+ st.write(f"Found {len(worst_performing_df)} evaluations below the threshold of {threshold}% in the selected metrics.")
875
+
876
+ # Display the worst-performing prompts and their metrics
877
+ st.write("### Worst Performing Evaluations")
878
+ display_columns = ['prompt', 'response'] + selected_metrics + ['timestamp']
879
+ worst_performing_display_df = worst_performing_df[display_columns].copy()
880
+ worst_performing_display_df['timestamp'] = worst_performing_display_df['timestamp'].dt.strftime('%Y-%m-%d %H:%M:%S')
881
+
882
+ # Apply styling to highlight low scores
883
+ def highlight_low_scores(val):
884
+ if isinstance(val, float):
885
+ if val < threshold:
886
+ return 'background-color: red; color: white'
887
+ return ''
888
+
889
+ styled_worst_df = worst_performing_display_df.style.applymap(highlight_low_scores, subset=selected_metrics)
890
+ styled_worst_df = styled_worst_df.format({metric: "{:.2f}%" for metric in selected_metrics})
891
+
892
+ st.dataframe(
893
+ styled_worst_df.set_properties(**{
894
+ 'text-align': 'left',
895
+ 'border': '1px solid #ddd'
896
+ }).set_table_styles([
897
+ {'selector': 'th', 'props': [('background-color', '#333'), ('color', 'white')]},
898
+ {'selector': 'td', 'props': [('vertical-align', 'top')]}
899
+ ]),
900
+ use_container_width=True
901
+ )
902
+
903
+ # Analyze the worst-performing slices based on prompt characteristics
904
+ st.write("### Analysis by Prompt Length")
905
+
906
+ # Add a column for prompt length
907
+ worst_performing_df['Prompt Length'] = worst_performing_df['prompt'].apply(lambda x: len(x.split()))
908
+
909
+ # Define bins for prompt length ranges
910
+ bins = [0, 5, 10, 20, 50, 100, 1000]
911
+ labels = ['0-5', '6-10', '11-20', '21-50', '51-100', '100+']
912
+ worst_performing_df['Prompt Length Range'] = pd.cut(worst_performing_df['Prompt Length'], bins=bins, labels=labels, right=False)
913
+
914
+ # Group by 'Prompt Length Range' and calculate average metrics
915
+ group_metrics = worst_performing_df.groupby('Prompt Length Range')[selected_metrics].mean().reset_index()
916
+
917
+ # Display the average metrics per prompt length range
918
+ st.write("#### Average Metrics per Prompt Length Range")
919
+ group_metrics = group_metrics.sort_values('Prompt Length Range')
920
+ st.dataframe(group_metrics.style.format({metric: "{:.2f}%" for metric in selected_metrics}))
921
+
922
+ # Visualization of average metrics per prompt length range
923
+ st.write("#### Visualization of Metrics by Prompt Length Range")
924
+ melted_group_metrics = group_metrics.melt(id_vars='Prompt Length Range', value_vars=selected_metrics, var_name='Metric', value_name='Average Score')
925
+ fig = px.bar(
926
+ melted_group_metrics,
927
+ x='Prompt Length Range',
928
+ y='Average Score',
929
+ color='Metric',
930
+ barmode='group',
931
+ title='Average Metric Scores by Prompt Length Range',
932
+ labels={'Average Score': 'Average Score (%)'},
933
+ height=600
934
+ )
935
+ st.plotly_chart(fig, use_container_width=True)
936
+
937
+ # Further analysis: show counts of worst-performing evaluations per model
938
+ st.write("### Worst Performing Evaluations per Model")
939
+ model_counts = worst_performing_df['model_name'].value_counts().reset_index()
940
+ model_counts.columns = ['Model Name', 'Count of Worst Evaluations']
941
+ st.dataframe(model_counts)
942
+
943
+ # Allow user to download the worst-performing data
944
+ csv = worst_performing_df.to_csv(index=False)
945
+ st.download_button(
946
+ label="Download Worst Performing Data as CSV",
947
+ data=csv,
948
+ file_name='worst_performing_data.csv',
949
+ mime='text/csv',
950
+ )
951
+ else:
952
+ st.info("No evaluations found below the specified threshold.")
953
+ else:
954
+ st.warning("Please select at least one metric to analyze.")
955
+
956
  else:
957
  st.info("No evaluation results available for the selected model.")
958
  except Exception as e:
959
+ st.error(f"Error processing data from database: {str(e)}")
960
  st.error("Detailed error information:")
 
 
961
  st.error(traceback.format_exc())
962
+ st.stop()
963
 
964
  elif app_mode == "Model Upload":
965
  st.title("Upload Your Model")
 
1018
  elif app_mode == "Prompt Testing":
1019
  st.title("Prompt Testing")
1020
 
 
1021
  model_selection_option = st.radio("Select Model Option:", ["Choose Existing Model", "Add New Model"])
1022
 
1023
  if model_selection_option == "Choose Existing Model":
 
1027
  if not user_models:
1028
  st.error("You have no uploaded models. Please upload a model first.")
1029
  else:
1030
+ model_options = [
1031
+ f"{model['model_name']} ({model.get('model_type', 'Unknown').capitalize()})"
1032
+ for model in user_models
1033
+ ]
1034
+ selected_model = st.selectbox("Select a Model for Testing", model_options)
1035
+
1036
+ model_name = selected_model.split(" (")[0]
1037
+ model_type = selected_model.split(" (")[1].rstrip(")")
1038
  else:
1039
+ # Code for adding a new model (unchanged)
1040
+ ...
1041
+
1042
+ st.subheader("Input for Model Testing")
1043
 
1044
+ # For simple models, we'll use a single JSON file
1045
+ if model_type.lower() == "simple":
1046
+ st.write("For simple models, please upload a single JSON file containing prompts, contexts, and responses.")
1047
+ json_file = st.file_uploader("Upload Test Data JSON", type=["json"])
1048
+
1049
+ if json_file is not None:
1050
  try:
1051
+ test_data = json.load(json_file)
1052
+ st.success("Test data JSON file uploaded successfully!")
1053
 
1054
+ # Display a preview of the test data
1055
+ st.write("Preview of test data:")
1056
+ st.json(test_data[:3] if len(test_data) > 3 else test_data)
1057
 
1058
  except json.JSONDecodeError:
1059
+ st.error("Invalid JSON format. Please check your file.")
1060
+ else:
1061
+ test_data = None
1062
+ else:
1063
+ # For other model types, keep the existing separate inputs for context and questions
1064
+ context_input_method = st.radio("Choose context input method:", ["Text Input", "File Upload"])
1065
+ if context_input_method == "Text Input":
1066
1067
  else:
1068
+ context_file = st.file_uploader("Upload Context Dataset", type=["txt"])
1069
+ if context_file is not None:
1070
+ context_dataset = context_file.getvalue().decode("utf-8")
1071
+ st.success("Context file uploaded successfully!")
1072
+ else:
1073
+ context_dataset = None
1074
 
1075
+ questions_input_method = st.radio("Choose questions input method:", ["Text Input", "File Upload"])
1076
+ if questions_input_method == "Text Input":
1077
+ questions_json = st.text_area("Enter Questions (JSON format):", height=200)
1078
+ else:
1079
+ questions_file = st.file_uploader("Upload Questions JSON", type=["json"])
1080
+ if questions_file is not None:
1081
+ questions_json = questions_file.getvalue().decode("utf-8")
1082
+ st.success("Questions file uploaded successfully!")
1083
+ else:
1084
+ questions_json = None
1085
+
1086
  if st.button("Run Test"):
1087
  if not model_name:
1088
  st.error("Please select or add a valid Model.")
1089
+ elif model_type.lower() == "simple" and test_data is None:
1090
+ st.error("Please upload a valid test data JSON file.")
1091
+ elif model_type.lower() != "simple" and (not context_dataset or not questions_json):
1092
+ st.error("Please provide both context dataset and questions JSON.")
1093
  else:
1094
+ try:
1095
+ selected_model = next(
1096
+ (m for m in user_models if m['model_name'] == model_name),
1097
+ None
1098
+ )
1099
+ if selected_model:
1100
+ with st.spinner("Starting evaluations..."):
1101
+ if model_type.lower() == "simple":
1102
+ evaluation_thread = threading.Thread(
1103
+ target=run_custom_evaluations,
1104
+ args=(test_data, selected_model, st.session_state.user)
1105
+ )
1106
+ else:
1107
+ questions = json.loads(questions_json)
1108
+ evaluation_thread = threading.Thread(
1109
+ target=run_custom_evaluations,
1110
+ args=((context_dataset, questions), selected_model, st.session_state.user)
1111
+ )
1112
+ evaluation_thread.start()
1113
+ st.success("Evaluations are running in the background. You can navigate away or close the site.")
1114
+ else:
1115
+ st.error("Selected model not found.")
1116
+ except json.JSONDecodeError:
1117
+ st.error("Invalid JSON format. Please check your input.")
1118
 
1119
  elif app_mode == "Manage Models":
1120
  st.title("Manage Your Models")
 
1125
  st.stop()
1126
  user_models = user.get("models", [])
1127
 
1128
+ # Update existing models to ensure they have a model_type
1129
+ for model in user_models:
1130
+ if 'model_type' not in model:
1131
+ model['model_type'] = 'simple' # Default to 'simple' for existing models
1132
+ users_collection.update_one(
1133
+ {"username": st.session_state.user},
1134
+ {"$set": {"models": user_models}}
1135
+ )
1136
+
1137
  st.subheader("Add a New Model")
1138
+ model_type = st.radio("Select Model Type:", ["Simple Model", "Custom Model"])
1139
 
1140
+ if model_type == "Simple Model":
1141
  new_model_name = st.text_input("Enter New Model Name:")
1142
+ if st.button("Add Simple Model"):
1143
+ if new_model_name:
1144
  model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
1145
+ model_data = {
1146
+ "model_id": model_id,
1147
+ "model_name": new_model_name if model_type == "Simple Model" else selected_custom_model,
1148
+ "model_type": "simple" if model_type == "Simple Model" else "custom",
1149
+ "file_path": None,
1150
+ "model_link": None,
1151
+ "uploaded_at": datetime.now(),
1152
+ "context": None # We'll update this when running evaluations
1153
+ }
1154
  users_collection.update_one(
1155
  {"username": st.session_state.user},
1156
+ {"$push": {"models": model_data}}
1157
  )
1158
+ st.success(f"Model '{model_data['model_name']}' added successfully as {model_id}!")
1159
  else:
1160
+ st.error("Please enter a valid model name or select a custom model.")
1161
+
1162
+ else: # Custom Model
1163
+ custom_model_options = ["gpt-4o", "gpt-4o-mini"]
1164
+ selected_custom_model = st.selectbox("Select Custom Model:", custom_model_options)
1165
+
1166
+ if st.button("Add Custom Model"):
1167
+ model_id = f"{st.session_state.user}_model_{int(datetime.now().timestamp())}"
1168
+ users_collection.update_one(
1169
+ {"username": st.session_state.user},
1170
+ {"$push": {"models": {
1171
+ "model_id": model_id,
1172
+ "model_name": selected_custom_model,
1173
+ "model_type": "custom",
1174
+ "file_path": None,
1175
+ "model_link": None,
1176
+ "uploaded_at": datetime.now()
1177
+ }}}
1178
+ )
1179
+ st.success(f"Custom Model '{selected_custom_model}' added successfully as {model_id}!")
1180
 
1181
  st.markdown("---")
1182
 
 
1184
  st.subheader("Your Models")
1185
  for model in user_models:
1186
  st.markdown(f"**Model ID:** {model['model_id']}")
1187
+ st.write(f"**Model Type:** {model.get('model_type', 'simple').capitalize()}")
1188
  if model.get("model_name"):
1189
  st.write(f"**Model Name:** {model['model_name']}")
 
 
1190
  if model.get("file_path"):
1191
  st.write(f"**File Path:** {model['file_path']}")
1192
  st.write(f"**Uploaded at:** {model['uploaded_at']}")
 
1217
  # Convert results to a pandas DataFrame
1218
  df = pd.DataFrame(user_results)
1219
 
1220
+ # Extract prompt text using the helper function
1221
+ df['prompt'] = df['prompt'].apply(extract_prompt_text)
1222
+
1223
  # Normalize the evaluation JSON into separate columns
1224
  eval_df = df['evaluation'].apply(pd.Series)
1225
  for metric in ["Accuracy", "Hallucination", "Groundedness", "Relevance", "Recall", "Precision", "Consistency", "Bias Detection"]:
 
1268
  'border': '1px solid #ddd'
1269
  }).set_table_styles([
1270
  {'selector': 'th', 'props': [('background-color', '#f5f5f5'), ('text-align', 'center')]},
1271
+ {'selector': 'td', 'props': [('text-align', 'left'), ('vertical-align', 'top')]}
1272
  ]).format({
1273
  "Accuracy (%)": "{:.2f}",
1274
  "Hallucination (%)": "{:.2f}",
 
1289
 
1290
  # Add a footer
1291
  st.sidebar.markdown("---")
1292
+ st.sidebar.info("LLM Evaluation System - v0.2")
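
This commit starts importing run_model and openai_client from a runner module that is not included in the diff. A minimal sketch of the interface app.py appears to rely on, assuming runner.py wraps the OpenAI chat-completions client — the function body and environment variable below are assumptions, not the repository's actual runner.py:

# runner.py — hypothetical sketch of the interface app.py uses in this commit
import os
from openai import OpenAI

# app.py does `from runner import openai_client` and reuses this client for teacher evaluations.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def run_model(model_name: str, prompt: str) -> str:
    # app.py calls run_model(model_name, prompt_text) once per test case and
    # expects a plain text answer (or an empty string) back.
    completion = openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content or ""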