add cache
app.py (CHANGED)
@@ -198,7 +198,7 @@ def topic_distribution(weights, smoothing=0.01):
     # return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)
 
 def populate_session(dataset,model):
-    data_df =
+    data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
     if model == 'albert-base-v2-yelp-polarity':
         tokenizer = AutoTokenizer.from_pretrained('textattack/'+model)
     else:
@@ -208,7 +208,9 @@ def populate_session(dataset,model):
     if "selected_slice" not in st.session_state:
         st.session_state["selected_slice"] = None
 
-
+@st.cache(ttl=600)
+def read_file_to_df(file):
+    return pd.read_parquet(file)
 
 if __name__ == "__main__":
     ### STREAMLIT APP CONGFIG ###
@@ -235,7 +237,7 @@ if __name__ == "__main__":
     ### LOAD DATA AND SESSION VARIABLES ###
     ##uncomment the next next line to run dynamically and not from file
     #populate_session(dataset, model)
-    data_df =
+    data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
     loss_quantile = st.sidebar.slider(
         "Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
     )
@@ -250,7 +252,7 @@ if __name__ == "__main__":
     st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
     #uncomment the next two lines to run dynamically and not from file
     #commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
-    commontokens =
+    commontokens = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_commontokens.parquet')
     with st.expander("How to read the table:"):
         st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
     st.write(commontokens)
@@ -260,20 +262,22 @@ if __name__ == "__main__":
     num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
 
     if run_kmeans == 'True':
-
+        with st.spinner(text='running kmeans...'):
+            merged = kmeans(data_df,num_clusters=num_clusters)
     with lcol:
         st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
-
+        with st.expander("How to read the table:"):
+            st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
+            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
+            st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
+        with st.spinner(text='loading error slice...'):
+            dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
        #uncomment the next next line to run dynamically and not from file
        # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
        #     by=['loss'], ascending=False)
        # table_html = dataframe.to_html(
        #     columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
        # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
-        with st.expander("How to read the table:"):
-            st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
-            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
-            st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
        st.write(dataframe,width=900, height=300)
 
    quant_panel(merged)
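For reference, below is a minimal, self-contained sketch of the caching pattern this commit applies: read_file_to_df is wrapped in st.cache with a 600-second TTL, so Streamlit reruns that request the same path reuse a cached DataFrame instead of re-reading the parquet file, and cached entries expire after 600 seconds. The file path here is illustrative, not one of the app's real assets.

# Sketch of the caching pattern from the diff; assumes streamlit and pandas are installed.
import pandas as pd
import streamlit as st

@st.cache(ttl=600)  # memoize by argument; cached entries expire after 600 seconds
def read_file_to_df(file):
    # Executes only on a cache miss; later reruns with the same path reuse the stored DataFrame.
    return pd.read_parquet(file)

df = read_file_to_df('./assets/data/example.parquet')  # hypothetical file for illustration
st.write(df)

Note that recent Streamlit releases deprecate st.cache in favor of st.cache_data, which accepts the same ttl argument; the decorator shown above matches the one used in the diff.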