nazneen commited on
Commit
050aca6
1 Parent(s): 406f76b

interactive legend

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -80,8 +80,8 @@ def data_comparison(df):
80
  ).interactive()
81
 
82
  legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
83
- x=alt.X("label"),
84
- y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=""),
85
  shape=alt.Shape('label:N', scale=alt.Scale(
86
  range=['circle', 'diamond']), legend=None),
87
  color=color,
@@ -247,6 +247,22 @@ if __name__ == "__main__":
247
  data_df['slice'] = 'high-loss'
248
  data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  with rcol:
251
  with st.spinner(text='loading...'):
252
  st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
@@ -264,20 +280,6 @@ if __name__ == "__main__":
264
  if run_kmeans == 'True':
265
  with st.spinner(text='running kmeans...'):
266
  merged = kmeans(data_df,num_clusters=num_clusters)
267
- with lcol:
268
- st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
269
- with st.expander("How to read the table:"):
270
- st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
271
- st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
272
- st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
273
- with st.spinner(text='loading error slice...'):
274
- dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
275
- #uncomment the next next line to run dynamically and not from file
276
- # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
277
- # by=['loss'], ascending=False)
278
- # table_html = dataframe.to_html(
279
- # columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
280
- # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
281
- st.write(dataframe,width=900, height=300)
282
  with st.spinner(text='loading visualization...'):
283
  quant_panel(merged)
 
80
  ).interactive()
81
 
82
  legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
83
+ x=alt.X("label:N"),
84
+ y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), sort='descending', title=''),
85
  shape=alt.Shape('label:N', scale=alt.Scale(
86
  range=['circle', 'diamond']), legend=None),
87
  color=color,
 
247
  data_df['slice'] = 'high-loss'
248
  data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
249
 
250
+ with lcol:
251
+ st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
252
+ with st.expander("How to read the table:"):
253
+ st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
254
+ st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
255
+ st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
256
+ with st.spinner(text='loading error slice...'):
257
+ dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
258
+ #uncomment the next next line to run dynamically and not from file
259
+ # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
260
+ # by=['loss'], ascending=False)
261
+ # table_html = dataframe.to_html(
262
+ # columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
263
+ # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
264
+ st.write(dataframe,width=900, height=300)
265
+
266
  with rcol:
267
  with st.spinner(text='loading...'):
268
  st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
 
280
  if run_kmeans == 'True':
281
  with st.spinner(text='running kmeans...'):
282
  merged = kmeans(data_df,num_clusters=num_clusters)
283
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  with st.spinner(text='loading visualization...'):
285
  quant_panel(merged)