nazneen commited on
Commit
f760ec3
1 Parent(s): 295dbc0

added instr

Browse files
app.py CHANGED
@@ -1,5 +1,5 @@
1
- ### LIBRARIES ###
2
- # # Data
3
  import numpy as np
4
  import pandas as pd
5
  import torch
@@ -62,16 +62,8 @@ def down_samp(embedding):
62
 
63
 
64
  def data_comparison(df):
65
- # set up a dropdown select bindinf
66
- # input_dropdown = alt.binding_select(options=['Negative Sentiment','Positive Sentiment'])
67
- #data_kmeans['distance_from_centroid'] = data_kmeans.apply(distance_from_centroid, axis=1)
68
-
69
  selection = alt.selection_multi(fields=['cluster','label'])
70
- color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:N', scale = alt.Scale(domain=df.cluster.tolist())), alt.value("lightgray"))
71
- # color = alt.condition(selection,
72
- # alt.Color('cluster:Q', legend=None),
73
- # # scale = alt.Scale(domain = pop_domain,range=color_range)),
74
- # alt.value('lightgray'))
75
  opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
76
 
77
  # basic chart
@@ -97,7 +89,7 @@ def data_comparison(df):
97
  selection
98
  )
99
 
100
- layered = legend | scatter
101
 
102
  layered = layered.configure_axis(
103
  grid=False
@@ -112,14 +104,12 @@ def quant_panel(embedding_df):
112
  """ Quantitative Panel Layout"""
113
 
114
  all_metrics = {}
115
- # st.warning("**Data Comparison**")
116
-
117
- # with st.expander("how to read this chart:"):
118
- # st.markdown("* each **point** is a single sentence")
119
- # st.markdown("* the **position** of each dot is determined mathematically based upon an analysis of the words in a sentence. The **closer** two points on the visualization the **more similar** the sentences are. The **further apart ** two points on the visualization the **more different** the sentences are")
120
- # st.markdown(
121
- # " * the **shape** of each point reflects whether it a positive (diamond) or negative sentiment (circle)")
122
- # st.markdown("* the **color** of each point is the ")
123
  st.altair_chart(data_comparison(down_samp(embedding_df)))
124
 
125
  def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
@@ -246,7 +236,7 @@ if __name__ == "__main__":
246
  )
247
 
248
  loss_quantile = st.sidebar.slider(
249
- "Loss Quantile", min_value=0.0, max_value=1.0,step=0.1,value=0.95
250
  )
251
 
252
  run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
@@ -280,15 +270,16 @@ if __name__ == "__main__":
280
  table_html = dataframe.to_html(
281
  columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
282
  # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
283
- st.write(dataframe)
284
- # st_aggrid.AgGrid(dataframe)
285
- # table_html = dataframe.to_html(columns=['content', 'label', 'pred', 'loss'], max_rows=100)
286
- # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
287
- # st.write(table_html)
288
 
289
  with rcol:
290
  st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
291
  commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
 
 
292
  st.write(commontokens)
293
 
294
  quant_panel(merged)
 
1
+ ## LIBRARIES ###
2
+ ## Data
3
  import numpy as np
4
  import pandas as pd
5
  import torch
 
62
 
63
 
64
  def data_comparison(df):
 
 
 
 
65
  selection = alt.selection_multi(fields=['cluster','label'])
66
+ color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:N', scale = alt.Scale(domain=df.cluster.unique().tolist())), alt.value("lightgray"))
 
 
 
 
67
  opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
68
 
69
  # basic chart
 
89
  selection
90
  )
91
 
92
+ layered = scatter | legend
93
 
94
  layered = layered.configure_axis(
95
  grid=False
 
104
  """ Quantitative Panel Layout"""
105
 
106
  all_metrics = {}
107
+ st.warning("**Error slice visualization**")
108
+
109
+ with st.expander("How to read this chart:"):
110
+ st.markdown("* Each **point** is an input example.")
111
+ st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
112
+ st.markdown("* The **shape** of each point reflects the label category -- positive (diamond) or negative sentiment (circle).")
 
 
113
  st.altair_chart(data_comparison(down_samp(embedding_df)))
114
 
115
  def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
 
236
  )
237
 
238
  loss_quantile = st.sidebar.slider(
239
+ "Loss Quantile", min_value=0.0, max_value=1.0,step=0.01,value=0.95
240
  )
241
 
242
  run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
 
270
  table_html = dataframe.to_html(
271
  columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
272
  # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
273
+ with st.expander("How to read the table:"):
274
+ st.markdown("* The table displays model error slices on the test set, sorted by loss.")
275
+ st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
276
+ st.write(dataframe,width=900, height=300)
 
277
 
278
  with rcol:
279
  st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
280
  commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
281
+ with st.expander("How to read the table:"):
282
+ st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
283
  st.write(commontokens)
284
 
285
  quant_panel(merged)
error_analysis/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (204 Bytes). View file
 
error_analysis/utils/__pycache__/style_hacks.cpython-39.pyc ADDED
Binary file (2.16 kB). View file
 
error_analysis/utils/style_hacks.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- streamlit style hacks
3
  """
4
  import streamlit as st
5
 
@@ -10,12 +10,13 @@ def init_style():
10
  <style>
11
  /* Side Bar */
12
  [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
13
- width: 225px;
14
- margin-left: -500px;
15
  }
16
  [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
17
- width: 225px;
18
- margin-left: -500px;
 
 
19
  }
20
  .css-1outpf7 {
21
  background-color:rgb(254 244 219);
@@ -23,11 +24,7 @@ def init_style():
23
  padding:10px 10px 10px 10px;
24
  }
25
 
26
- /* Main Panel*/
27
- [data-testid="stVerticalBlock"]{
28
- margin-left: -200px;
29
- padding:10px 10px 10px -200px;
30
- }
31
  .css-18e3th9 {
32
  padding:10px 10px 10px -200px;
33
  }
 
1
  """
2
+ placeholder for all streamlit style hacks
3
  """
4
  import streamlit as st
5
 
 
10
  <style>
11
  /* Side Bar */
12
  [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
13
+ width: 300px;
 
14
  }
15
  [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
16
+ width: 300px;
17
+ }
18
+ [data-testid="stSidebar"]{
19
+ flex-basis: unset;
20
  }
21
  .css-1outpf7 {
22
  background-color:rgb(254 244 219);
 
24
  padding:10px 10px 10px 10px;
25
  }
26
 
27
+ /* Main Panel*/
 
 
 
 
28
  .css-18e3th9 {
29
  padding:10px 10px 10px -200px;
30
  }