Ashmi Banerjee commited on
Commit
29cc4c5
·
1 Parent(s): 9075b46

somewhat works with the new data

Browse files
db/schema.py CHANGED
@@ -3,12 +3,16 @@ from typing import List, Dict, Optional
3
  from datetime import datetime
4
 
5
 
 
 
 
 
 
 
6
  class Response(BaseModel):
7
  config_id: str
8
- query_v: Dict[str, int]
9
- query_p0: Dict[str, int]
10
- query_p1: Dict[str, int]
11
- comment: Optional[str]
12
  timestamp: str
13
 
14
 
 
3
  from datetime import datetime
4
 
5
 
6
+ class ModelRatings(BaseModel):
7
+ query_v_ratings: Dict[str, int]
8
+ query_p0_ratings: Dict[str, int]
9
+ query_p1_ratings: Dict[str, int]
10
+
11
+
12
  class Response(BaseModel):
13
  config_id: str
14
+ model_ratings: Dict[str, ModelRatings]
15
+ comment: Optional[str] = None
 
 
16
  timestamp: str
17
 
18
 
utils/loaders.py CHANGED
@@ -1,6 +1,6 @@
1
  import pandas as pd
2
  import os
3
-
4
  from datasets import load_dataset
5
  from dotenv import load_dotenv
6
 
@@ -9,7 +9,7 @@ HF_TOKEN = os.getenv("HF_TOKEN")
9
  REPO_NAME = os.getenv("DATA_REPO")
10
  DATA_FILES = os.getenv("GEMINI_DATA_FILES")
11
 
12
-
13
  def load_data():
14
  try:
15
  data = pd.read_csv("data/user-evaluation/merged.csv")[:5]
 
1
  import pandas as pd
2
  import os
3
+ import streamlit as st
4
  from datasets import load_dataset
5
  from dotenv import load_dotenv
6
 
 
9
  REPO_NAME = os.getenv("DATA_REPO")
10
  DATA_FILES = os.getenv("GEMINI_DATA_FILES")
11
 
12
+ @st.cache_data
13
  def load_data():
14
  try:
15
  data = pd.read_csv("data/user-evaluation/merged.csv")[:5]
views/nav_buttons.py CHANGED
@@ -1,4 +1,4 @@
1
- from db.schema import Feedback
2
  from db.crud import save_feedback, read
3
  import streamlit as st
4
  from datetime import datetime
@@ -32,8 +32,9 @@ def any_value_zero(response):
32
  )
33
 
34
 
35
- def navigation_buttons(data, ratings_v, ratings_p0, ratings_p1):
36
  """Display navigation buttons."""
 
37
  current_index = st.session_state.current_index
38
 
39
  col1, col2, col3 = st.columns([1, 1, 2])
@@ -44,8 +45,14 @@ def navigation_buttons(data, ratings_v, ratings_p0, ratings_p1):
44
  st.rerun()
45
 
46
  with col2: # Next button
 
 
 
 
 
 
47
  if st.button("Next"):
48
- if any(rating == 0 for rating in [ratings_v, ratings_p0, ratings_p1]):
49
  st.warning("Please provide a rating before proceeding.")
50
  else:
51
  if current_index < len(data) - 1:
 
1
+ from db.schema import Feedback, Response
2
  from db.crud import save_feedback, read
3
  import streamlit as st
4
  from datetime import datetime
 
32
  )
33
 
34
 
35
+ def navigation_buttons(data, response: Response):
36
  """Display navigation buttons."""
37
+ #TODO change the function to accept the response object
38
  current_index = st.session_state.current_index
39
 
40
  col1, col2, col3 = st.columns([1, 1, 2])
 
45
  st.rerun()
46
 
47
  with col2: # Next button
48
+ all_ratings = []
49
+ for model_ratings in response.model_ratings.values():
50
+ all_ratings.extend(model_ratings.query_v_ratings.values())
51
+ all_ratings.extend(model_ratings.query_p0_ratings.values())
52
+ all_ratings.extend(model_ratings.query_p1_ratings.values())
53
+
54
  if st.button("Next"):
55
+ if any(rating == 0 for rating in all_ratings):
56
  st.warning("Please provide a rating before proceeding.")
57
  else:
58
  if current_index < len(data) - 1:
views/questions_screen.py CHANGED
@@ -1,9 +1,10 @@
1
- from db.schema import Response
2
  import streamlit as st
3
  from datetime import datetime
4
  import os
5
  from dotenv import load_dotenv
6
  from views.nav_buttons import navigation_buttons
 
7
  st.set_page_config(layout="wide")
8
  load_dotenv()
9
 
@@ -80,23 +81,40 @@ def questions_screen(data):
80
  st.text_area("", config['context'], height=300, disabled=False)
81
 
82
  # Render queries and collect ratings
83
- query_v_ratings = render_query_ratings("Query_v", config, "gemini_query_v", current_index)
84
- query_p0_ratings = render_query_ratings("Query_p0",
85
- config, "gemini_query_p0", current_index, has_persona_alignment=True)
86
- query_p1_ratings = render_query_ratings("Query_p1",
87
- config, "gemini_query_p1",
88
- current_index, has_persona_alignment=True)
 
 
 
 
 
 
 
89
 
90
  # Additional comments
91
  comment = st.text_area("Additional Comments (Optional):")
92
- if "persona_alignment" in query_v_ratings:
93
- query_v_ratings.pop("persona_alignment")
 
94
  # Collecting the response data
95
  response = Response(
96
  config_id=config["config_id"],
97
- query_v=query_v_ratings,
98
- query_p0=query_p0_ratings,
99
- query_p1=query_p1_ratings,
 
 
 
 
 
 
 
 
 
100
  comment=comment,
101
  timestamp=datetime.now().isoformat()
102
  )
@@ -107,6 +125,6 @@ def questions_screen(data):
107
  st.session_state.responses.append(response)
108
 
109
  # Navigation buttons
110
- navigation_buttons(data, query_v_ratings["clarity"], query_p0_ratings["clarity"], query_p1_ratings["clarity"])
111
  except IndexError:
112
  print("Survey completed!")
 
1
+ from db.schema import Response, ModelRatings
2
  import streamlit as st
3
  from datetime import datetime
4
  import os
5
  from dotenv import load_dotenv
6
  from views.nav_buttons import navigation_buttons
7
+
8
  st.set_page_config(layout="wide")
9
  load_dotenv()
10
 
 
81
  st.text_area("", config['context'], height=300, disabled=False)
82
 
83
  # Render queries and collect ratings
84
+ g_query_v_ratings = render_query_ratings("Query_v", config, "gemini_query_v", current_index)
85
+ g_query_p0_ratings = render_query_ratings("Query_p0",
86
+ config, "gemini_query_p0", current_index, has_persona_alignment=True)
87
+ g_query_p1_ratings = render_query_ratings("Query_p1",
88
+ config, "gemini_query_p1",
89
+ current_index, has_persona_alignment=True)
90
+
91
+ l_query_v_ratings = render_query_ratings("Query_v", config, "llama_query_v", current_index)
92
+ l_query_p0_ratings = render_query_ratings("Query_p0",
93
+ config, "llama_query_p0", current_index, has_persona_alignment=True)
94
+ l_query_p1_ratings = render_query_ratings("Query_p1",
95
+ config, "llama_query_p1",
96
+ current_index, has_persona_alignment=True)
97
 
98
  # Additional comments
99
  comment = st.text_area("Additional Comments (Optional):")
100
+ if "persona_alignment" in g_query_v_ratings or "persona_alignment" in l_query_v_ratings:
101
+ g_query_v_ratings.pop("persona_alignment")
102
+ l_query_v_ratings.pop("persona_alignment")
103
  # Collecting the response data
104
  response = Response(
105
  config_id=config["config_id"],
106
+ model_ratings={
107
+ "gemini": ModelRatings(
108
+ query_v_ratings=g_query_v_ratings,
109
+ query_p0_ratings=g_query_p0_ratings,
110
+ query_p1_ratings=g_query_p1_ratings,
111
+ ),
112
+ "llama": ModelRatings(
113
+ query_v_ratings=l_query_v_ratings,
114
+ query_p0_ratings=l_query_p0_ratings,
115
+ query_p1_ratings=l_query_p1_ratings,
116
+ )
117
+ },
118
  comment=comment,
119
  timestamp=datetime.now().isoformat()
120
  )
 
125
  st.session_state.responses.append(response)
126
 
127
  # Navigation buttons
128
+ navigation_buttons(data, response)
129
  except IndexError:
130
  print("Survey completed!")