charlieoneill committed
Commit 10845f0
Parent: 63a794c

Update topk_sae.py

Files changed (1): topk_sae.py (+1, -108)
topk_sae.py CHANGED
@@ -1,11 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
 import numpy as np
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
-import wandb
 import os
 import glob
 
@@ -153,109 +151,4 @@ def init_from_data_(ae, data_sample):
     # encoder as transpose of decoder
     ae.encoder.weight.data = ae.decoder.weight.t().clone()
 
-    nn.init.zeros_(ae.latent_bias)
-
-def train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=None, save_dir="../models", model_name=""):
-    os.makedirs(save_dir, exist_ok=True)
-    step = 0
-    num_batches = len(train_loader)
-    for epoch in range(epochs):
-        ae.train()
-        total_loss = 0
-        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
-            optimizer.zero_grad()
-            x = batch[0].to(device)
-            recons, info = ae(x)
-            loss, recons_loss, auxk_loss = loss_fn(ae, x, recons, info, auxk_coef, multik_coef)
-            loss.backward()
-            step += 1
-
-            # calculate proportion of dead latents (not fired in last num_batches = 1 epoch)
-            dead_latents_prop = (ae.stats_last_nonzero > num_batches).float().mean().item()
-
-            wandb.log({
-                "total_loss": loss.item(),
-                "reconstruction_loss": recons_loss.item(),
-                "auxiliary_loss": auxk_loss.item(),
-                "dead_latents_proportion": dead_latents_prop,
-                "l0_norm": k,
-                "step": step
-            })
-
-            unit_norm_decoder_grad_adjustment_(ae)
-
-            if clip_grad is not None:
-                torch.nn.utils.clip_grad_norm_(ae.parameters(), clip_grad)
-
-            optimizer.step()
-            unit_norm_decoder_(ae)
-
-            total_loss += loss.item()
-
-        avg_loss = total_loss / len(train_loader)
-        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
-
-        # Delete previous model saves for this configuration
-        for old_model in glob.glob(os.path.join(save_dir, f"{model_name}_epoch_*.pth")):
-            os.remove(old_model)
-
-        # Save new model
-        save_path = os.path.join(save_dir, f"{model_name}_epoch_{epoch+1}.pth")
-        torch.save(ae.state_dict(), save_path)
-        print(f"Model saved to {save_path}")
-
-def main():
-    d_model = 1536
-    n_dirs = 3072 #9216
-    k = 64 #64
-    auxk = k*2 #256
-    multik = 128
-    batch_size = 1024
-    lr = 1e-4
-    auxk_coef = 1/32
-    clip_grad = 1.0
-    multik_coef = 0 # turn it off
-
-    csLG = False
-
-    # Create model name
-    model_name = f"{k}_{n_dirs}_{auxk}_auxk" if not csLG else f"{k}_{n_dirs}_{auxk}_auxk_csLG"
-    epochs = 50 if not csLG else 137
-
-    wandb.init(project="saerch", name=model_name, config={
-        "n_dirs": n_dirs,
-        "d_model": d_model,
-        "k": k,
-        "auxk": auxk,
-        "batch_size": batch_size,
-        "lr": lr,
-        "epochs": epochs,
-        "auxk_coef": auxk_coef,
-        "multik_coef": multik_coef,
-        "clip_grad": clip_grad,
-        "device": device.type
-    })
-
-    if not csLG:
-        data = np.load("../data/vector_store_astroPH/abstract_embeddings.npy")
-        print("Doing astro.ph...")
-    else:
-        data = np.load("../data/vector_store_csLG/abstract_embeddings.npy")
-        print("Doing csLG...")
-    data_tensor = torch.from_numpy(data).float()
-    # Print shape
-    print(f"Data shape: {data_tensor.shape}")
-    dataset = TensorDataset(data_tensor)
-    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-
-    ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik).to(device)
-    init_from_data_(ae, data_tensor[:10000].to(device))
-
-    optimizer = optim.Adam(ae.parameters(), lr=lr)
-
-    train(ae, train_loader, optimizer, epochs, k, auxk_coef, multik_coef, clip_grad=clip_grad, model_name=model_name)
-
-    wandb.finish()
-
-if __name__ == "__main__":
-    main()
+    nn.init.zeros_(ae.latent_bias)
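
After this commit, topk_sae.py no longer carries its own wandb-instrumented train()/main() driver, so the optimization loop has to live in the calling script. The sketch below is a minimal external driver reconstructed from the deleted code, assuming the module still exposes FastAutoencoder, loss_fn, init_from_data_, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_, and device (none of which are touched here); the data path and hyperparameters are illustrative, mirroring the deleted main().

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from topk_sae import (FastAutoencoder, device, init_from_data_, loss_fn,
                      unit_norm_decoder_, unit_norm_decoder_grad_adjustment_)

# Values mirror the deleted main(): d_model=1536, n_dirs=3072, k=64, auxk=k*2, multik=128.
n_dirs, d_model, k, auxk, multik = 3072, 1536, 64, 128, 128

# Illustrative path -- point this at your own embedding matrix.
data = torch.from_numpy(np.load("abstract_embeddings.npy")).float()
loader = DataLoader(TensorDataset(data), batch_size=1024, shuffle=True)

ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik).to(device)
init_from_data_(ae, data[:10000].to(device))  # decoder from data, encoder = decoder^T, latent bias zeroed
optimizer = torch.optim.Adam(ae.parameters(), lr=1e-4)

for epoch in range(50):
    ae.train()
    for batch in loader:
        optimizer.zero_grad()
        x = batch[0].to(device)
        recons, info = ae(x)
        loss, recons_loss, auxk_loss = loss_fn(ae, x, recons, info, 1/32, 0)  # auxk_coef, multik_coef
        loss.backward()
        unit_norm_decoder_grad_adjustment_(ae)            # decoder-norm bookkeeping, as in the deleted train()
        torch.nn.utils.clip_grad_norm_(ae.parameters(), 1.0)
        optimizer.step()
        unit_norm_decoder_(ae)                            # renormalize decoder directions after the step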