K00B404 commited on
Commit
68a5c1a
·
verified ·
1 Parent(s): 795ae71

Update src/subpages/hidden_states.py

Browse files
Files changed (1) hide show
  1. src/subpages/hidden_states.py +39 -89
src/subpages/hidden_states.py CHANGED
@@ -1,6 +1,3 @@
1
- """
2
- For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements marked by a small black border.
3
- """
4
  import numpy as np
5
  import plotly.express as px
6
  import plotly.graph_objects as go
@@ -9,79 +6,24 @@ import streamlit as st
9
  from src.subpages.page import Context, Page
10
 
11
 
12
- @st.cache
13
- def reduce_dim_svd(X, n_iter: int, random_state=42):
14
- """Dimensionality reduction using truncated SVD (aka LSA).
 
15
 
16
- This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
 
 
17
 
18
- Args:
19
- X: Training data
20
- n_iter (int): Desired dimensionality of output data. Must be strictly less than the number of features.
21
- random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
22
 
23
- Returns:
24
- ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
25
- """
26
- from sklearn.decomposition import TruncatedSVD
27
 
28
- svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
29
- return svd.fit_transform(X)
30
-
31
-
32
- @st.cache
33
- def reduce_dim_pca(X, random_state=42):
34
- """Principal component analysis (PCA).
35
-
36
- Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
37
-
38
- Args:
39
- X: Training data
40
- random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
41
-
42
- Returns:
43
- ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
44
- """
45
- from sklearn.decomposition import PCA
46
-
47
- return PCA(n_components=2, random_state=random_state).fit_transform(X)
48
-
49
-
50
- @st.cache
51
- def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
52
- """Uniform Manifold Approximation and Projection
53
-
54
- Finds a low dimensional embedding of the data that approximates an underlying manifold.
55
-
56
- Args:
57
- X: Training data
58
- n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
59
- min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
60
- metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
61
-
62
- Returns:
63
- ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
64
- """
65
- from umap import UMAP
66
-
67
- return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
68
-
69
-
70
- class HiddenStatesPage(Page):
71
- name = "Hidden States"
72
- icon = "grid-3x3"
73
-
74
- def _get_widget_defaults(self):
75
- return {
76
- "n_tokens": 1_000,
77
- "svd_n_iter": 5,
78
- "svd_random_state": 42,
79
- "umap_n_neighbors": 15,
80
- "umap_metric": "euclidean",
81
- "umap_min_dist": 0.1,
82
- }
83
-
84
- def render(self, context: Context):
85
  st.title("Embeddings")
86
 
87
  with st.expander("💡", expanded=True):
@@ -90,7 +32,6 @@ class HiddenStatesPage(Page):
90
  )
91
 
92
  col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
93
- df = context.df_tokens_merged.copy()
94
  dim_algo = "SVD"
95
  n_tokens = 100
96
 
@@ -100,7 +41,7 @@ class HiddenStatesPage(Page):
100
  "#tokens",
101
  key="n_tokens",
102
  min_value=100,
103
- max_value=len(df["tokens"].unique()),
104
  step=100,
105
  )
106
 
@@ -131,30 +72,30 @@ class HiddenStatesPage(Page):
131
  pass
132
 
133
  with col2:
134
- sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
135
 
136
- X = np.array(df["hidden_states"].tolist())
137
  transformed_hidden_states = None
138
  if dim_algo == "SVD":
139
- transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
140
  elif dim_algo == "PCA":
141
- transformed_hidden_states = reduce_dim_pca(X)
142
  elif dim_algo == "UMAP":
143
- transformed_hidden_states = reduce_dim_umap(
144
  X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
145
  )
146
 
147
  assert isinstance(transformed_hidden_states, np.ndarray)
148
- df["x"] = transformed_hidden_states[:, 0]
149
- df["y"] = transformed_hidden_states[:, 1]
150
- df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
151
- df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
152
- df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
153
- df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
154
- df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
155
- df["disagreements"] = df["labels"] != df["preds"]
156
-
157
- subset = df[:n_tokens]
158
  disagreements_trace = go.Scatter(
159
  x=subset[subset["disagreements"]]["x"],
160
  y=subset[subset["disagreements"]]["y"],
@@ -192,3 +133,12 @@ class HiddenStatesPage(Page):
192
  )
193
  fig.add_trace(disagreements_trace)
194
  st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import plotly.express as px
3
  import plotly.graph_objects as go
 
6
  from src.subpages.page import Context, Page
7
 
8
 
9
+ class HiddenStatesVisualizer:
10
+ def __init__(self, context: Context):
11
+ self.context = context
12
+ self.df = context.df_tokens_merged.copy()
13
 
14
+ def _reduce_dim_svd(self, X, n_iter: int, random_state=42):
15
+ # Implement your SVD reduction here
16
+ pass
17
 
18
+ def _reduce_dim_pca(self, X, random_state=42):
19
+ # Implement your PCA reduction here
20
+ pass
 
21
 
22
+ def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
23
+ # Implement your UMAP reduction here
24
+ pass
 
25
 
26
+ def visualize_hidden_states(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  st.title("Embeddings")
28
 
29
  with st.expander("💡", expanded=True):
 
32
  )
33
 
34
  col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
 
35
  dim_algo = "SVD"
36
  n_tokens = 100
37
 
 
41
  "#tokens",
42
  key="n_tokens",
43
  min_value=100,
44
+ max_value=len(self.df["tokens"].unique()),
45
  step=100,
46
  )
47
 
 
72
  pass
73
 
74
  with col2:
75
+ sents = self.df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
76
 
77
+ X = np.array(self.df["hidden_states"].tolist())
78
  transformed_hidden_states = None
79
  if dim_algo == "SVD":
80
+ transformed_hidden_states = self._reduce_dim_svd(X, n_iter=svd_n_iter) # type: ignore
81
  elif dim_algo == "PCA":
82
+ transformed_hidden_states = self._reduce_dim_pca(X)
83
  elif dim_algo == "UMAP":
84
+ transformed_hidden_states = self._reduce_dim_umap(
85
  X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric # type: ignore
86
  )
87
 
88
  assert isinstance(transformed_hidden_states, np.ndarray)
89
+ self.df["x"] = transformed_hidden_states[:, 0]
90
+ self.df["y"] = transformed_hidden_states[:, 1]
91
+ self.df["sent0"] = self.df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
92
+ self.df["sent1"] = self.df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
93
+ self.df["sent2"] = self.df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
94
+ self.df["sent3"] = self.df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
95
+ self.df["sent4"] = self.df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
96
+ self.df["disagreements"] = self.df["labels"] != self.df["preds"]
97
+
98
+ subset = self.df[:n_tokens]
99
  disagreements_trace = go.Scatter(
100
  x=subset[subset["disagreements"]]["x"],
101
  y=subset[subset["disagreements"]]["y"],
 
133
  )
134
  fig.add_trace(disagreements_trace)
135
  st.plotly_chart(fig)
136
+
137
+
138
+ class HiddenStatesPage(Page):
139
+ name = "Hidden States"
140
+ icon = "grid-3x3"
141
+
142
+ def render(self, context: Context):
143
+ visualizer = HiddenStatesVisualizer(context)
144
+ visualizer.visualize_hidden_states()