nazneen committed
Commit 7cba420
1 Parent(s): bbb450c
Files changed (1)
  1. app.py +14 -10
app.py CHANGED
@@ -198,7 +198,7 @@ def topic_distribution(weights, smoothing=0.01):
     # return(topic_frequencies[category], topic_frequencies_spotlight[category], topic_ratios[category], category)

 def populate_session(dataset,model):
-    data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
+    data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
     if model == 'albert-base-v2-yelp-polarity':
         tokenizer = AutoTokenizer.from_pretrained('textattack/'+model)
     else:
@@ -208,7 +208,9 @@ def populate_session(dataset,model):
     if "selected_slice" not in st.session_state:
         st.session_state["selected_slice"] = None

-
+@st.cache(ttl=600)
+def read_file_to_df(file):
+    return pd.read_parquet(file)

 if __name__ == "__main__":
     ### STREAMLIT APP CONGFIG ###
@@ -235,7 +237,7 @@ if __name__ == "__main__":
     ### LOAD DATA AND SESSION VARIABLES ###
     ##uncomment the next next line to run dynamically and not from file
     #populate_session(dataset, model)
-    data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
+    data_df = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'.parquet')
     loss_quantile = st.sidebar.slider(
         "Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
     )
@@ -250,7 +252,7 @@
         st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
         #uncomment the next two lines to run dynamically and not from file
         #commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
-        commontokens = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'_commontokens.parquet')
+        commontokens = read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_commontokens.parquet')
         with st.expander("How to read the table:"):
             st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
         st.write(commontokens)
@@ -260,20 +262,22 @@
     num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)

     if run_kmeans == 'True':
-        merged = kmeans(data_df,num_clusters=num_clusters)
+        with st.spinner(text='running kmeans...'):
+            merged = kmeans(data_df,num_clusters=num_clusters)
     with lcol:
         st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
-        dataframe=pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
+        with st.expander("How to read the table:"):
+            st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
+            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
+            st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
+        with st.spinner(text='loading error slice...'):
+            dataframe=read_file_to_df('./assets/data/'+dataset+ '_'+ model+'_error-slices.parquet')
         #uncomment the next next line to run dynamically and not from file
         # dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
         #     by=['loss'], ascending=False)
         # table_html = dataframe.to_html(
         #     columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
         # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
-        with st.expander("How to read the table:"):
-            st.markdown("* *Error slice* refers to the subset of evaluation dataset the model performs poorly on.")
-            st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
-            st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
         st.write(dataframe,width=900, height=300)

     quant_panel(merged)
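
The core change in this commit is routing every `pd.read_parquet` call through a single module-level `read_file_to_df` helper decorated with `@st.cache(ttl=600)`. Because Streamlit re-executes the whole script on every widget interaction, the uncached reads were hitting disk on each rerun; the cached helper reuses the loaded DataFrame for up to ten minutes per file path. Below is a minimal self-contained sketch of the same pattern; the file path and the `describe()` call are hypothetical stand-ins, and `st.cache` is the pre-1.18 Streamlit API used here (newer versions would use `@st.cache_data`).

```python
import pandas as pd
import streamlit as st

@st.cache(ttl=600)  # reuse the result for up to 600 s across reruns
def read_file_to_df(file):
    # Streamlit reruns the entire script on every widget change;
    # caching keeps this disk read from repeating each time.
    return pd.read_parquet(file)

if __name__ == "__main__":
    # Hypothetical path for illustration; the app assembles its real
    # paths from the selected dataset and model names.
    df = read_file_to_df('./assets/data/example.parquet')
    # Slow steps get a spinner so the UI shows progress, mirroring
    # the kmeans and error-slice loads in the commit.
    with st.spinner(text='processing...'):
        st.write(df.describe())
```

Defining the helper at module level, above `if __name__ == "__main__":`, lets every call site (`data_df`, `commontokens`, `dataframe`) share one cache keyed by file path.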