charlieoneill committed on
Commit
63a794c
1 Parent(s): b481357

Create topk_sae.py

Browse files
Files changed (1) hide show
  1. topk_sae.py +261 -0
topk_sae.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ import numpy as np
6
+ from torch.utils.data import DataLoader, TensorDataset
7
+ from tqdm import tqdm
8
+ import wandb
9
+ import os
10
+ import glob
11
+
12
# Module-level device selection: model parameters, the dead-latent counter and
# each training batch are all placed on this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
class FastAutoencoder(nn.Module):
    """Top-k sparse autoencoder over d_model-dimensional embeddings.

    Encoding keeps only the k largest (ReLU'd) pre-activations out of n_dirs
    latent directions; a wider 4*k selection supports a "multi-k"
    reconstruction, and, when `auxk` is set, the top-auxk pre-activations among
    recently-unfired ("dead") latents are returned for an auxiliary loss.
    A per-latent step counter tracks how long each latent has gone unfired.
    """

    def __init__(self, n_dirs: int, d_model: int, k: int, auxk: int, multik: int, dead_steps_threshold: int = 266):
        super().__init__()
        self.n_dirs = n_dirs                # dictionary size (number of latent directions)
        self.d_model = d_model              # input embedding dimensionality
        self.k = k                          # active latents kept per example
        self.auxk = auxk                    # size of the dead-latent selection; None disables AuxK
        # NOTE(review): `multik` is stored but forward() hard-codes 4*self.k for
        # the multi-k selection — confirm which of the two is intended.
        self.multik = multik
        # Steps a latent may go unfired before it is considered dead.
        self.dead_steps_threshold = dead_steps_threshold

        self.encoder = nn.Linear(d_model, n_dirs, bias=False)
        self.decoder = nn.Linear(n_dirs, d_model, bias=False)

        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.latent_bias = nn.Parameter(torch.zeros(n_dirs))

        # Per-latent count of steps since it last fired. Plain tensor rather
        # than a registered buffer: it is excluded from state_dict and pinned
        # to the module-level `device` instead of following .to()/.cuda().
        self.stats_last_nonzero = torch.zeros(n_dirs, dtype=torch.long, device=device)

    def forward(self, x):
        """Return (recons, info) for a batch `x` of shape (..., d_model).

        `recons` is the top-k reconstruction; `info` carries the indices and
        values of the k, 4*k and aux-k selections plus pre/post-activation
        latents. Side effect: advances `stats_last_nonzero`, so every call
        (including eval) updates the dead-latent statistics.
        """
        x = x - self.pre_bias
        latents_pre_act = self.encoder(x) + self.latent_bias

        # Main top-k selection (ReLU clamps negative selected values to 0).
        topk_values, topk_indices = torch.topk(latents_pre_act, k=self.k, dim=-1)
        topk_values = F.relu(topk_values)
        # Wider 4*k selection used for the multi-k reconstruction.
        multik_values, multik_indices = torch.topk(latents_pre_act, k=4*self.k, dim=-1)
        multik_values = F.relu(multik_values)

        # Densify the sparse selections back to full n_dirs-wide vectors.
        latents = torch.zeros_like(latents_pre_act)
        latents.scatter_(-1, topk_indices, topk_values)
        multik_latents = torch.zeros_like(latents_pre_act)
        multik_latents.scatter_(-1, multik_indices, multik_values)

        # Update stats_last_nonzero: age every latent, reset those that fired.
        self.stats_last_nonzero += 1
        self.stats_last_nonzero.scatter_(0, topk_indices.unique(), 0)

        recons = self.decoder(latents) + self.pre_bias
        multik_recons = self.decoder(multik_latents) + self.pre_bias

        # AuxK: among latents dead for > dead_steps_threshold steps, select the
        # top-auxk pre-activations for the auxiliary revival loss.
        if self.auxk is not None:
            # Create dead latents mask
            dead_mask = (self.stats_last_nonzero > self.dead_steps_threshold).float()

            # Apply mask to latents_pre_act
            dead_latents_pre_act = latents_pre_act * dead_mask

            # Select top-k_aux from dead latents
            auxk_values, auxk_indices = torch.topk(dead_latents_pre_act, k=self.auxk, dim=-1)
            auxk_values = F.relu(auxk_values)

        else:
            auxk_values, auxk_indices = None, None

        return recons, {
            "topk_indices": topk_indices,
            "topk_values": topk_values,
            "multik_indices": multik_indices,
            "multik_values": multik_values,
            "multik_recons": multik_recons,
            "auxk_indices": auxk_indices,
            "auxk_values": auxk_values,
            "latents_pre_act": latents_pre_act,
            "latents_post_act": latents,
        }

    def decode_sparse(self, indices, values):
        """Decode one sparse code given `indices` into n_dirs and matching
        `values` (assumes a single unbatched example — TODO confirm callers)."""
        latents = torch.zeros(self.n_dirs, device=indices.device)
        latents.scatter_(-1, indices, values)
        return self.decoder(latents) + self.pre_bias

    def print_tensor_info(self, tensor, name):
        """Debug helper: print a tensor's shape, dtype and device."""
        print(f"{name} - Shape: {tensor.shape}, Dtype: {tensor.dtype}, Device: {tensor.device}")

    def decode_clamp(self, latents, clamp):
        """Re-sparsify `latents` to their top-64 (ReLU'd) entries, scale
        element-wise by `clamp`, and decode."""
        topk_values, topk_indices = torch.topk(latents, k = 64, dim=-1)
        topk_values = F.relu(topk_values)
        latents = torch.zeros_like(latents)
        latents.scatter_(-1, topk_indices, topk_values)
        # multiply latents by clamp, which is 1D but has the same size as each latent vector
        latents = latents * clamp

        return self.decoder(latents) + self.pre_bias

    def decode_at_k(self, latents, k):
        """Re-sparsify `latents` to their top-k (ReLU'd) entries and decode."""
        topk_values, topk_indices = torch.topk(latents, k=k, dim=-1)
        topk_values = F.relu(topk_values)
        latents = torch.zeros_like(latents)
        latents.scatter_(-1, topk_indices, topk_values)

        return self.decoder(latents) + self.pre_bias
111
+
112
def unit_norm_decoder_(autoencoder: FastAutoencoder) -> None:
    """Project every decoder column (one per latent direction) onto the unit
    sphere, in place and outside autograd."""
    with torch.no_grad():
        weight = autoencoder.decoder.weight
        weight.div_(weight.norm(dim=0, keepdim=True))
115
+
116
def unit_norm_decoder_grad_adjustment_(autoencoder: FastAutoencoder) -> None:
    """Remove, in place, the gradient component parallel to each decoder
    column, keeping optimizer steps tangent to the unit sphere that
    unit_norm_decoder_() re-projects onto. No-op when there is no gradient."""
    grad = autoencoder.decoder.weight.grad
    if grad is None:
        return
    with torch.no_grad():
        weight = autoencoder.decoder.weight
        parallel_component = (weight * grad).sum(dim=0, keepdim=True)
        grad.sub_(parallel_component * weight)
121
+
122
def mse(output, target):
    """Mean squared error between `output` and `target` (mean reduction)."""
    loss = F.mse_loss(output, target)
    return loss
124
+
125
def normalized_mse(recon, xs):
    """Reconstruction MSE of `recon` against `xs`, normalized by the MSE of
    the best constant predictor (the per-dimension mean of `xs`)."""
    baseline = xs.mean(dim=0, keepdim=True).expand_as(xs)
    return F.mse_loss(recon, xs) / F.mse_loss(baseline, xs)
127
+
128
def loss_fn(ae, x, recons, info, auxk_coef, multik_coef):
    """Total SAE training loss.

    Combines the normalized reconstruction MSE, a multi-k term weighted by
    `multik_coef`, and — when `ae.auxk` is set — an auxiliary loss that
    reconstructs the residual error from the selected dead latents.

    Returns (total_loss, recons_loss, auxk_loss).
    """
    recons_loss = normalized_mse(recons, x)
    recons_loss = recons_loss + multik_coef * normalized_mse(info["multik_recons"], x)

    if ae.auxk is None:
        return recons_loss, recons_loss, torch.tensor(0.0, device=device)

    # Residual the main reconstruction failed to capture; detached so the
    # auxiliary term does not backprop through the main reconstruction.
    residual = x - recons.detach()
    aux_latents = torch.zeros_like(info["latents_pre_act"])
    aux_latents.scatter_(-1, info["auxk_indices"], info["auxk_values"])
    residual_hat = ae.decoder(aux_latents)  # reconstruct residual via dead latents
    auxk_loss = normalized_mse(residual_hat, residual)
    return recons_loss + auxk_coef * auxk_loss, recons_loss, auxk_loss
144
+
145
def init_from_data_(ae, data_sample):
    """Data-dependent in-place initialization of an autoencoder.

    pre_bias <- per-dimension median of `data_sample`; decoder <- Xavier then
    unit-norm columns; encoder <- decoder transpose (tied start);
    latent_bias <- zeros.
    """
    # Center inputs on the per-dimension median of the sample.
    median = torch.median(data_sample, dim=0).values
    ae.pre_bias.data = median

    # Random decoder, projected so every column has unit norm.
    nn.init.xavier_uniform_(ae.decoder.weight)
    unit_norm_decoder_(ae)

    # Tied initialization: the encoder starts as the decoder's transpose.
    ae.encoder.weight.data = ae.decoder.weight.t().clone()

    nn.init.zeros_(ae.latent_bias)
157
+
158
def train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=None, save_dir="../models", model_name=""):
    """Train the autoencoder, logging every step to wandb and checkpointing
    once per epoch (only the latest epoch's file is kept).

    Args:
        ae: FastAutoencoder to train; modified in place.
        train_loader: DataLoader yielding (embedding_batch,) tuples.
        optimizer: optimizer over ae.parameters().
        epochs: number of full passes over train_loader.
        k: active-latent count, logged as the L0 norm.
        auxk_coef: weight of the auxiliary dead-latent loss.
        multik_coef: weight of the multi-k reconstruction loss.
        clip_grad: optional max gradient norm; None disables clipping.
        save_dir: directory for checkpoints (created if missing).
        model_name: prefix used for checkpoint files.
    """
    os.makedirs(save_dir, exist_ok=True)
    step = 0
    num_batches = len(train_loader)
    for epoch in range(epochs):
        ae.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            x = batch[0].to(device)
            recons, info = ae(x)
            loss, recons_loss, auxk_loss = loss_fn(ae, x, recons, info, auxk_coef, multik_coef)
            loss.backward()
            step += 1

            _log_train_step(ae, loss, recons_loss, auxk_loss, k, step, num_batches)

            # Keep the decoder update tangent to the unit sphere, then step
            # and re-project the columns to unit norm.
            unit_norm_decoder_grad_adjustment_(ae)
            if clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(ae.parameters(), clip_grad)
            optimizer.step()
            unit_norm_decoder_(ae)

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        _save_checkpoint(ae, save_dir, model_name, epoch)


def _log_train_step(ae, loss, recons_loss, auxk_loss, k, step, num_batches):
    """Push per-step training metrics to wandb."""
    # Proportion of latents that have not fired within roughly the last epoch
    # (num_batches steps) — a coarser window than ae.dead_steps_threshold.
    dead_latents_prop = (ae.stats_last_nonzero > num_batches).float().mean().item()
    wandb.log({
        "total_loss": loss.item(),
        "reconstruction_loss": recons_loss.item(),
        "auxiliary_loss": auxk_loss.item(),
        "dead_latents_proportion": dead_latents_prop,
        "l0_norm": k,
        "step": step
    })


def _save_checkpoint(ae, save_dir, model_name, epoch):
    """Save this epoch's weights, removing older checkpoints for model_name."""
    # Delete previous model saves for this configuration
    for old_model in glob.glob(os.path.join(save_dir, f"{model_name}_epoch_*.pth")):
        os.remove(old_model)
    # Save new model
    save_path = os.path.join(save_dir, f"{model_name}_epoch_{epoch+1}.pth")
    torch.save(ae.state_dict(), save_path)
    print(f"Model saved to {save_path}")
206
+
207
def main():
    """Entry point: configure a top-k SAE, load abstract embeddings, train."""
    # Architecture and optimization hyperparameters.
    d_model = 1536
    n_dirs = 3072  # alternative: 9216
    k = 64
    auxk = k * 2  # alternative: 256
    multik = 128
    batch_size = 1024
    lr = 1e-4
    auxk_coef = 1 / 32
    clip_grad = 1.0
    multik_coef = 0  # turn it off

    # Dataset switch: cs.LG vs astro-ph abstract embeddings.
    csLG = False

    # Run name / checkpoint prefix, and per-dataset epoch budget.
    model_name = f"{k}_{n_dirs}_{auxk}_auxk_csLG" if csLG else f"{k}_{n_dirs}_{auxk}_auxk"
    epochs = 137 if csLG else 50

    wandb.init(project="saerch", name=model_name, config={
        "n_dirs": n_dirs,
        "d_model": d_model,
        "k": k,
        "auxk": auxk,
        "batch_size": batch_size,
        "lr": lr,
        "epochs": epochs,
        "auxk_coef": auxk_coef,
        "multik_coef": multik_coef,
        "clip_grad": clip_grad,
        "device": device.type
    })

    if csLG:
        data = np.load("../data/vector_store_csLG/abstract_embeddings.npy")
        print("Doing csLG...")
    else:
        data = np.load("../data/vector_store_astroPH/abstract_embeddings.npy")
        print("Doing astro.ph...")
    data_tensor = torch.from_numpy(data).float()
    print(f"Data shape: {data_tensor.shape}")
    dataset = TensorDataset(data_tensor)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Build the model and initialize from a sample of the data.
    ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik).to(device)
    init_from_data_(ae, data_tensor[:10000].to(device))
    optimizer = optim.Adam(ae.parameters(), lr=lr)

    train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=clip_grad, model_name=model_name)

    wandb.finish()


if __name__ == "__main__":
    main()