surokpro2 commited on
Commit
c049ca4
1 Parent(s): b2d604f

Update SAE/sae.py

Browse files
Files changed (1) hide show
  1. SAE/sae.py +4 -1
SAE/sae.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import torch.nn as nn
8
  import os
9
  import json
 
10
 
11
  class SparseAutoencoder(nn.Module):
12
  """
@@ -105,6 +106,7 @@ class SparseAutoencoder(nn.Module):
105
  def n_dirs(self):
106
  return self.n_dirs_local
107
 
 
108
  def encode(self, x):
109
  x = x - self.pre_bias
110
  latents_pre_act = self.encoder(x) + self.latent_bias
@@ -120,6 +122,7 @@ class SparseAutoencoder(nn.Module):
120
 
121
  return latents
122
 
 
123
  def forward(self, x):
124
  x = x - self.pre_bias
125
  latents_pre_act = self.encoder(x) + self.latent_bias
@@ -179,7 +182,7 @@ class SparseAutoencoder(nn.Module):
179
  "auxk_vals": auxk_vals,
180
  }
181
 
182
-
183
  def decode_sparse(self, inds, vals):
184
  rows, cols = inds.shape[0], self.n_dirs
185
 
 
7
  import torch.nn as nn
8
  import os
9
  import json
10
+ import spaces
11
 
12
  class SparseAutoencoder(nn.Module):
13
  """
 
106
  def n_dirs(self):
107
  return self.n_dirs_local
108
 
109
+ @spaces.GPU
110
  def encode(self, x):
111
  x = x - self.pre_bias
112
  latents_pre_act = self.encoder(x) + self.latent_bias
 
122
 
123
  return latents
124
 
125
+ @spaces.GPU
126
  def forward(self, x):
127
  x = x - self.pre_bias
128
  latents_pre_act = self.encoder(x) + self.latent_bias
 
182
  "auxk_vals": auxk_vals,
183
  }
184
 
185
+ @spaces.GPU
186
  def decode_sparse(self, inds, vals):
187
  rows, cols = inds.shape[0], self.n_dirs
188