Spaces:
Running
on
Zero
Running
on
Zero
Update SAE/sae.py
Browse files- SAE/sae.py +4 -1
SAE/sae.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
import os
|
9 |
import json
|
|
|
10 |
|
11 |
class SparseAutoencoder(nn.Module):
|
12 |
"""
|
@@ -105,6 +106,7 @@ class SparseAutoencoder(nn.Module):
|
|
105 |
def n_dirs(self):
|
106 |
return self.n_dirs_local
|
107 |
|
|
|
108 |
def encode(self, x):
|
109 |
x = x - self.pre_bias
|
110 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
@@ -120,6 +122,7 @@ class SparseAutoencoder(nn.Module):
|
|
120 |
|
121 |
return latents
|
122 |
|
|
|
123 |
def forward(self, x):
|
124 |
x = x - self.pre_bias
|
125 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
@@ -179,7 +182,7 @@ class SparseAutoencoder(nn.Module):
|
|
179 |
"auxk_vals": auxk_vals,
|
180 |
}
|
181 |
|
182 |
-
|
183 |
def decode_sparse(self, inds, vals):
|
184 |
rows, cols = inds.shape[0], self.n_dirs
|
185 |
|
|
|
7 |
import torch.nn as nn
|
8 |
import os
|
9 |
import json
|
10 |
+
import spaces
|
11 |
|
12 |
class SparseAutoencoder(nn.Module):
|
13 |
"""
|
|
|
106 |
def n_dirs(self):
|
107 |
return self.n_dirs_local
|
108 |
|
109 |
+
@spaces.GPU
|
110 |
def encode(self, x):
|
111 |
x = x - self.pre_bias
|
112 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
|
|
122 |
|
123 |
return latents
|
124 |
|
125 |
+
@spaces.GPU
|
126 |
def forward(self, x):
|
127 |
x = x - self.pre_bias
|
128 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
|
|
182 |
"auxk_vals": auxk_vals,
|
183 |
}
|
184 |
|
185 |
+
@spaces.GPU
|
186 |
def decode_sparse(self, inds, vals):
|
187 |
rows, cols = inds.shape[0], self.n_dirs
|
188 |
|