Spaces:
Running
Running
Ashmi Banerjee
committed on
Commit
·
29cc4c5
1
Parent(s):
9075b46
somewhat works with the new data
Browse files
- db/schema.py +8 -4
- utils/loaders.py +2 -2
- views/nav_buttons.py +10 -3
- views/questions_screen.py +31 -13
db/schema.py
CHANGED
@@ -3,12 +3,16 @@ from typing import List, Dict, Optional
|
|
3 |
from datetime import datetime
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class Response(BaseModel):
|
7 |
config_id: str
|
8 |
-
|
9 |
-
|
10 |
-
query_p1: Dict[str, int]
|
11 |
-
comment: Optional[str]
|
12 |
timestamp: str
|
13 |
|
14 |
|
|
|
3 |
from datetime import datetime
|
4 |
|
5 |
|
6 |
+
class ModelRatings(BaseModel):
|
7 |
+
query_v_ratings: Dict[str, int]
|
8 |
+
query_p0_ratings: Dict[str, int]
|
9 |
+
query_p1_ratings: Dict[str, int]
|
10 |
+
|
11 |
+
|
12 |
class Response(BaseModel):
|
13 |
config_id: str
|
14 |
+
model_ratings: Dict[str, ModelRatings]
|
15 |
+
comment: Optional[str] = None
|
|
|
|
|
16 |
timestamp: str
|
17 |
|
18 |
|
utils/loaders.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import pandas as pd
|
2 |
import os
|
3 |
-
|
4 |
from datasets import load_dataset
|
5 |
from dotenv import load_dotenv
|
6 |
|
@@ -9,7 +9,7 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
9 |
REPO_NAME = os.getenv("DATA_REPO")
|
10 |
DATA_FILES = os.getenv("GEMINI_DATA_FILES")
|
11 |
|
12 |
-
|
13 |
def load_data():
|
14 |
try:
|
15 |
data = pd.read_csv("data/user-evaluation/merged.csv")[:5]
|
|
|
1 |
import pandas as pd
|
2 |
import os
|
3 |
+
import streamlit as st
|
4 |
from datasets import load_dataset
|
5 |
from dotenv import load_dotenv
|
6 |
|
|
|
9 |
REPO_NAME = os.getenv("DATA_REPO")
|
10 |
DATA_FILES = os.getenv("GEMINI_DATA_FILES")
|
11 |
|
12 |
+
@st.cache_data
|
13 |
def load_data():
|
14 |
try:
|
15 |
data = pd.read_csv("data/user-evaluation/merged.csv")[:5]
|
views/nav_buttons.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from db.schema import Feedback
|
2 |
from db.crud import save_feedback, read
|
3 |
import streamlit as st
|
4 |
from datetime import datetime
|
@@ -32,8 +32,9 @@ def any_value_zero(response):
|
|
32 |
)
|
33 |
|
34 |
|
35 |
-
def navigation_buttons(data,
|
36 |
"""Display navigation buttons."""
|
|
|
37 |
current_index = st.session_state.current_index
|
38 |
|
39 |
col1, col2, col3 = st.columns([1, 1, 2])
|
@@ -44,8 +45,14 @@ def navigation_buttons(data, ratings_v, ratings_p0, ratings_p1):
|
|
44 |
st.rerun()
|
45 |
|
46 |
with col2: # Next button
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
if st.button("Next"):
|
48 |
-
if any(rating == 0 for rating in
|
49 |
st.warning("Please provide a rating before proceeding.")
|
50 |
else:
|
51 |
if current_index < len(data) - 1:
|
|
|
1 |
+
from db.schema import Feedback, Response
|
2 |
from db.crud import save_feedback, read
|
3 |
import streamlit as st
|
4 |
from datetime import datetime
|
|
|
32 |
)
|
33 |
|
34 |
|
35 |
+
def navigation_buttons(data, response: Response):
|
36 |
"""Display navigation buttons."""
|
37 |
+
#TODO change the function to accept the response object
|
38 |
current_index = st.session_state.current_index
|
39 |
|
40 |
col1, col2, col3 = st.columns([1, 1, 2])
|
|
|
45 |
st.rerun()
|
46 |
|
47 |
with col2: # Next button
|
48 |
+
all_ratings = []
|
49 |
+
for model_ratings in response.model_ratings.values():
|
50 |
+
all_ratings.extend(model_ratings.query_v_ratings.values())
|
51 |
+
all_ratings.extend(model_ratings.query_p0_ratings.values())
|
52 |
+
all_ratings.extend(model_ratings.query_p1_ratings.values())
|
53 |
+
|
54 |
if st.button("Next"):
|
55 |
+
if any(rating == 0 for rating in all_ratings):
|
56 |
st.warning("Please provide a rating before proceeding.")
|
57 |
else:
|
58 |
if current_index < len(data) - 1:
|
views/questions_screen.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
-
from db.schema import Response
|
2 |
import streamlit as st
|
3 |
from datetime import datetime
|
4 |
import os
|
5 |
from dotenv import load_dotenv
|
6 |
from views.nav_buttons import navigation_buttons
|
|
|
7 |
st.set_page_config(layout="wide")
|
8 |
load_dotenv()
|
9 |
|
@@ -80,23 +81,40 @@ def questions_screen(data):
|
|
80 |
st.text_area("", config['context'], height=300, disabled=False)
|
81 |
|
82 |
# Render queries and collect ratings
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Additional comments
|
91 |
comment = st.text_area("Additional Comments (Optional):")
|
92 |
-
if "persona_alignment" in
|
93 |
-
|
|
|
94 |
# Collecting the response data
|
95 |
response = Response(
|
96 |
config_id=config["config_id"],
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
comment=comment,
|
101 |
timestamp=datetime.now().isoformat()
|
102 |
)
|
@@ -107,6 +125,6 @@ def questions_screen(data):
|
|
107 |
st.session_state.responses.append(response)
|
108 |
|
109 |
# Navigation buttons
|
110 |
-
navigation_buttons(data,
|
111 |
except IndexError:
|
112 |
print("Survey completed!")
|
|
|
1 |
+
from db.schema import Response, ModelRatings
|
2 |
import streamlit as st
|
3 |
from datetime import datetime
|
4 |
import os
|
5 |
from dotenv import load_dotenv
|
6 |
from views.nav_buttons import navigation_buttons
|
7 |
+
|
8 |
st.set_page_config(layout="wide")
|
9 |
load_dotenv()
|
10 |
|
|
|
81 |
st.text_area("", config['context'], height=300, disabled=False)
|
82 |
|
83 |
# Render queries and collect ratings
|
84 |
+
g_query_v_ratings = render_query_ratings("Query_v", config, "gemini_query_v", current_index)
|
85 |
+
g_query_p0_ratings = render_query_ratings("Query_p0",
|
86 |
+
config, "gemini_query_p0", current_index, has_persona_alignment=True)
|
87 |
+
g_query_p1_ratings = render_query_ratings("Query_p1",
|
88 |
+
config, "gemini_query_p1",
|
89 |
+
current_index, has_persona_alignment=True)
|
90 |
+
|
91 |
+
l_query_v_ratings = render_query_ratings("Query_v", config, "llama_query_v", current_index)
|
92 |
+
l_query_p0_ratings = render_query_ratings("Query_p0",
|
93 |
+
config, "llama_query_p0", current_index, has_persona_alignment=True)
|
94 |
+
l_query_p1_ratings = render_query_ratings("Query_p1",
|
95 |
+
config, "llama_query_p1",
|
96 |
+
current_index, has_persona_alignment=True)
|
97 |
|
98 |
# Additional comments
|
99 |
comment = st.text_area("Additional Comments (Optional):")
|
100 |
+
if "persona_alignment" in g_query_v_ratings or "persona_alignment" in l_query_v_ratings:
|
101 |
+
g_query_v_ratings.pop("persona_alignment")
|
102 |
+
l_query_v_ratings.pop("persona_alignment")
|
103 |
# Collecting the response data
|
104 |
response = Response(
|
105 |
config_id=config["config_id"],
|
106 |
+
model_ratings={
|
107 |
+
"gemini": ModelRatings(
|
108 |
+
query_v_ratings=g_query_v_ratings,
|
109 |
+
query_p0_ratings=g_query_p0_ratings,
|
110 |
+
query_p1_ratings=g_query_p1_ratings,
|
111 |
+
),
|
112 |
+
"llama": ModelRatings(
|
113 |
+
query_v_ratings=l_query_v_ratings,
|
114 |
+
query_p0_ratings=l_query_p0_ratings,
|
115 |
+
query_p1_ratings=l_query_p1_ratings,
|
116 |
+
)
|
117 |
+
},
|
118 |
comment=comment,
|
119 |
timestamp=datetime.now().isoformat()
|
120 |
)
|
|
|
125 |
st.session_state.responses.append(response)
|
126 |
|
127 |
# Navigation buttons
|
128 |
+
navigation_buttons(data, response)
|
129 |
except IndexError:
|
130 |
print("Survey completed!")
|