nazneen committed on
Commit
08090c3
1 Parent(s): 609167a

k means clustering
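In short, the commit clusters the sentence embeddings of high-loss examples with NLTK's KMeansClusterer (cosine distance) and surfaces the resulting cluster ids in the error-slice table, the chart tooltip, and the legend. A minimal sketch of the same clustering call, on invented toy vectors rather than the app's embeddings:

```python
# Toy illustration only -- mirrors the call pattern of the new clustering() in app.py.
import numpy as np
import nltk
from nltk.cluster import KMeansClusterer

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 8))               # stand-in for 40 example embeddings

kclusterer = KMeansClusterer(
    3,                                     # number of clusters
    distance=nltk.cluster.util.cosine_distance,
    repeats=25,
    avoid_empty_clusters=True,
)
assigned_clusters = kclusterer.cluster(X, assign_clusters=True)   # one id per row
centroids = kclusterer.means()             # one mean vector per cluster
print(assigned_clusters[:10], len(centroids))
```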

Files changed (1)
  1. app.py +78 -48
app.py CHANGED
@@ -1,6 +1,5 @@
 ### LIBRARIES ###
 # # Data
-from matplotlib.pyplot import legend
 import numpy as np
 import pandas as pd
 import torch
@@ -10,11 +9,15 @@ from math import floor
 from datasets import load_dataset
 from collections import defaultdict
 from transformers import AutoTokenizer
+pd.options.display.float_format = '${:,.2f}'.format
 
 # Analysis
 # from gensim.models.doc2vec import Doc2Vec
 # from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
-# import nltk
+import nltk
+from nltk.cluster import KMeansClusterer
+import scipy.spatial.distance as sdist
+from scipy.spatial import distance_matrix
 # nltk.download('punkt') #make sure that punkt is downloaded
 
 # App & Visualization
@@ -23,11 +26,11 @@ import altair as alt
 import plotly.graph_objects as go
 from streamlit_vega_lite import altair_component
 
+
+
 # utils
 from random import sample
-from error_analysis import utils as ut
-import os
-
+# from PIL import Image
 
 
 def down_samp(embedding):
@@ -61,12 +64,14 @@ def down_samp(embedding):
 def data_comparison(df):
     # set up a dropdown select bindinf
     # input_dropdown = alt.binding_select(options=['Negative Sentiment','Positive Sentiment'])
-    selection = alt.selection_multi(fields=['slice','label'])
-    color = alt.condition(alt.datum.slice == 'high-loss', alt.value("orange"), alt.value("steelblue"))
+    #data_kmeans['distance_from_centroid'] = data_kmeans.apply(distance_from_centroid, axis=1)
+
+    selection = alt.selection_multi(fields=['cluster','label'])
+    color = alt.condition(alt.datum.slice == 'high-loss', alt.Color('cluster:N', scale = alt.Scale(domain=df.cluster.tolist())), alt.value("lightgray"))
     # color = alt.condition(selection,
-    #                       alt.Color('slice:Q', legend=None),
-    #                       # scale = alt.Scale(domain = pop_domain,range=color_range)),
-    #                       alt.value('lightgray'))
+    #                       alt.Color('cluster:Q', legend=None),
+    #                       # scale = alt.Scale(domain = pop_domain,range=color_range)),
+    #                       alt.value('lightgray'))
     opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))
 
     # basic chart
@@ -75,7 +80,7 @@ def data_comparison(df):
         y=alt.Y('y', axis=None),
         color=color,
         shape=alt.Shape('label', scale=alt.Scale(range=['circle', 'diamond'])),
-        tooltip=['slice','content','label','pred'],
+        tooltip=['cluster','slice','content','label','pred'],
         opacity=opacity
     ).properties(
         width=1500,
@@ -83,28 +88,21 @@ def data_comparison(df):
     ).interactive()
 
     legend = alt.Chart(df).mark_point().encode(
-        y=alt.Y('slice:N', axis=alt.Axis(orient='right'), title="",),
+        y=alt.Y('cluster:O', axis=alt.Axis(orient='right'), title=""),
         x=alt.X("label"),
         shape=alt.Shape('label', scale=alt.Scale(
-            range=['circle', 'diamond']), legend=None),
-        color=color
+            range=['circle', 'diamond']), legend=None),
+        color=color,
     ).add_selection(
         selection
     )
-
-    layered = legend | scatter
+
+    layered = scatter |legend
 
     layered = layered.configure_axis(
         grid=False
     ).configure_view(
         strokeOpacity=0
-    ).configure_legend(
-        strokeColor='gray',
-        fillColor='#EEEEEE',
-        padding=10,
-        cornerRadius=10,
-        orient='top-right'
-
     )
 
     return layered
@@ -166,7 +164,36 @@ def get_data(spotlight, emb):
     return pd.concat([pd.DataFrame(np.transpose(np.vstack([dataset[:num_examples]['content'],
                       dataset[:num_examples]['label'], preds, losses])), columns=['content', 'label', 'pred', 'loss']), embeddings], axis=1)
 
+@st.cache(ttl=600)
+def clustering(data,num_clusters):
+
+    X = np.array(data['embedding'].tolist())
+
+    kclusterer = KMeansClusterer(
+        num_clusters, distance=nltk.cluster.util.cosine_distance,
+        repeats=25,avoid_empty_clusters=True)
+
+    assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
+
+    data['cluster'] = pd.Series(assigned_clusters, index=data.index).astype('int')
+    data['centroid'] = data['cluster'].apply(lambda x: kclusterer.means()[x])
+
+
+    return data, assigned_clusters
+
+def kmeans(df, num_clusters=3):
+    data_hl = df.loc[df['slice'] == 'high-loss']
+    data_kmeans,clusters = clustering(data_hl,num_clusters)
+    merged = pd.merge(df, data_kmeans, left_index=True, right_index=True, how='outer', suffixes=('', '_y'))
+    merged.drop(merged.filter(regex='_y$').columns.tolist(),axis=1,inplace=True)
+    merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
+    return merged
 
+@st.cache(ttl=600)
+def distance_from_centroid(row):
+    return sdist.norm(row['embedding'] - row['centroid'].tolist())
+
+@st.cache(ttl=600)
 def topic_distribution(weights, smoothing=0.01):
     topic_frequencies = defaultdict(float)
     topic_frequencies_spotlight = defaultdict(float)
@@ -196,15 +223,10 @@ def topic_distribution(weights, smoothing=0.01):
 
 if __name__ == "__main__":
     ### STREAMLIT APP CONGFIG ###
-    os.system("pip --ignore-installed streamlit ")
     st.set_page_config(layout="wide", page_title="Error Slice Analysis")
 
-    ut.init_style()
-
-    lcol, rcol = st.columns([2, 3])
+    lcol, rcol = st.columns([2, 2])
     # ******* loading the mode and the data
-    with st.sidebar:
-        st.title('Error Analysis')
     dataset = st.sidebar.selectbox(
         "Dataset",
         ["amazon_polarity", "squad", "movielens", "waterbirds"],
@@ -221,15 +243,19 @@ if __name__ == "__main__":
         index=0
     )
 
-    loss_quantile = st.sidebar.selectbox(
-        "Loss Quantile",
-        [0.98, 0.95, 0.9, 0.8, 0.75],
-        index = 1
+    loss_quantile = st.sidebar.slider(
+        "Loss Quantile", min_value=0.0, max_value=1.0,step=0.1,value=0.95
     )
+
+    run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
+
+    num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
+
     ### LOAD DATA AND SESSION VARIABLES ###
-    data_df = pd.read_parquet('./assets/data/amazon_polarity.test.parquet')
-    data_df.reset_index(drop=True, inplace=True)
-    embedding_umap = data_df[['x','y']]
+    data = pd.read_parquet('./assets/data/amazon_polarity.test.parquet')
+    embedding_umap = data[['x','y']]
+    emb_df = pd.read_parquet('./assets/data/amazon_test_emb.parquet')
+    data_df = pd.DataFrame([data['content'], data['label'], data['pred'], data['loss'], emb_df['embedding'], data['x'], data['y']]).transpose()
     if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
     if "selected_slice" not in st.session_state:
@@ -237,26 +263,30 @@ if __name__ == "__main__":
     if "embedding" not in st.session_state:
         st.session_state["embedding"] = embedding_umap
 
+    data_df['loss'] = data_df['loss'].astype(float)
+    losses = data_df['loss']
+    high_loss = losses.quantile(loss_quantile)
+    data_df['slice'] = 'high-loss'
+    data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
+
+    if run_kmeans == 'True':
+        merged = kmeans(data_df,num_clusters=num_clusters)
     with lcol:
         st.markdown('<h3>Error Slices</h3>',unsafe_allow_html=True)
-        dataframe = data_df[['content', 'label', 'pred', 'loss']].sort_values(
+        dataframe = merged[['content', 'label', 'pred', 'loss', 'cluster']].sort_values(
             by=['loss'], ascending=False)
         table_html = dataframe.to_html(
-            columns=['content', 'label', 'pred', 'loss'], max_rows=100)
+            columns=['content', 'label', 'pred', 'loss', 'cluster'], max_rows=50)
         # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
         st.write(dataframe)
-        st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
-        commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
-        st.write(commontokens)
         # st_aggrid.AgGrid(dataframe)
        # table_html = dataframe.to_html(columns=['content', 'label', 'pred', 'loss'], max_rows=100)
        # table_html = table_html.replace("<th>", '<th align="left">') # left-align the headers
        # st.write(table_html)
 
-    with rcol:
-        data_df['loss'] = data_df['loss'].astype(float)
-        losses = data_df['loss']
-        high_loss = losses.quantile(loss_quantile)
-        data_df['slice'] = 'high-loss'
-        data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
-        quant_panel(data_df)
+    with rcol:
+        st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
+        commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
+        st.write(commontokens)
+
+        quant_panel(merged)
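For reference, a self-contained toy run of the merge step the new kmeans() helper relies on: only high-loss rows get real cluster ids, and the outer merge plus fillna(num_clusters) assigns every low-loss row an overflow id. The frame and the stand-in cluster ids below are invented for illustration; in app.py the ids come from the KMeansClusterer call shown above.

```python
# Hypothetical data; mirrors the slice labeling and cluster-id merge in the commit.
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
data_df = pd.DataFrame({"content": [f"example {i}" for i in range(10)],
                        "loss": rng.random(10)})

loss_quantile = 0.8
high_loss = data_df["loss"].quantile(loss_quantile)
data_df["slice"] = "high-loss"
data_df["slice"] = data_df["slice"].where(data_df["loss"] > high_loss, "low-loss")

# cluster only the high-loss rows (stand-in ids instead of KMeansClusterer output)
data_hl = data_df.loc[data_df["slice"] == "high-loss"].copy()
data_hl["cluster"] = range(len(data_hl))

# outer merge on the index, then give every low-loss row the overflow id
num_clusters = 3
merged = pd.merge(data_df, data_hl, left_index=True, right_index=True,
                  how="outer", suffixes=("", "_y"))
merged.drop(merged.filter(regex="_y$").columns.tolist(), axis=1, inplace=True)
merged["cluster"] = merged["cluster"].fillna(num_clusters).astype("int")
print(merged[["loss", "slice", "cluster"]])
```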