namgoodfire
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -45,237 +45,16 @@ and even steer its behavior. You can explore the SDK documentation at [docs.good
|
|
45 |
|
46 |
## How to use
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
class SparseAutoEncoder(torch.nn.Module):
|
60 |
-
def __init__(
|
61 |
-
self,
|
62 |
-
d_in: int,
|
63 |
-
d_hidden: int,
|
64 |
-
device: torch.device,
|
65 |
-
dtype: torch.dtype = torch.bfloat16,
|
66 |
-
):
|
67 |
-
super().__init__()
|
68 |
-
self.d_in = d_in
|
69 |
-
self.d_hidden = d_hidden
|
70 |
-
self.device = device
|
71 |
-
self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
|
72 |
-
self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
|
73 |
-
self.dtype = dtype
|
74 |
-
self.to(self.device, self.dtype)
|
75 |
-
|
76 |
-
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
77 |
-
"""Encode a batch of data using a linear, followed by a ReLU."""
|
78 |
-
return torch.nn.functional.relu(self.encoder_linear(x))
|
79 |
-
|
80 |
-
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
81 |
-
"""Decode a batch of data using a linear."""
|
82 |
-
return self.decoder_linear(x)
|
83 |
-
|
84 |
-
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
85 |
-
"""SAE forward pass. Returns the reconstruction and the encoded features."""
|
86 |
-
f = self.encode(x)
|
87 |
-
return self.decode(f), f
|
88 |
-
|
89 |
-
|
90 |
-
def load_sae(
|
91 |
-
path: str,
|
92 |
-
d_model: int,
|
93 |
-
expansion_factor: int,
|
94 |
-
device: torch.device = torch.device("cpu"),
|
95 |
-
):
|
96 |
-
sae = SparseAutoEncoder(
|
97 |
-
d_model,
|
98 |
-
d_model * expansion_factor,
|
99 |
-
device,
|
100 |
-
)
|
101 |
-
sae_dict = torch.load(
|
102 |
-
path, weights_only=True, map_location=device
|
103 |
-
)
|
104 |
-
sae.load_state_dict(sae_dict)
|
105 |
-
|
106 |
-
return sae
|
107 |
-
|
108 |
-
|
109 |
-
# Lanngugae model
|
110 |
-
|
111 |
-
|
112 |
-
InterventionInterface = Callable[[InterventionProxy], InterventionProxy]
|
113 |
-
|
114 |
-
|
115 |
-
class ObservableLanguageModel:
|
116 |
-
def __init__(
|
117 |
-
self,
|
118 |
-
model: str,
|
119 |
-
device: str = "cuda",
|
120 |
-
dtype: torch.dtype = torch.bfloat16,
|
121 |
-
):
|
122 |
-
self.dtype = dtype
|
123 |
-
self.device = device
|
124 |
-
self._original_model = model
|
125 |
-
|
126 |
-
|
127 |
-
self._model = nnsight.LanguageModel(
|
128 |
-
self._original_model,
|
129 |
-
device_map=device,
|
130 |
-
torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype
|
131 |
-
)
|
132 |
-
|
133 |
-
self.tokenizer = self._model.tokenizer
|
134 |
-
|
135 |
-
self.d_model = self._attempt_to_infer_hidden_layer_dimensions()
|
136 |
-
|
137 |
-
self.safe_mode = False # Nsight validation is disabled by default, slows down inference a lot. Turn on to debug.
|
138 |
-
|
139 |
-
def _attempt_to_infer_hidden_layer_dimensions(self):
|
140 |
-
config = self._model.config
|
141 |
-
if hasattr(config, "hidden_size"):
|
142 |
-
return int(config.hidden_size)
|
143 |
-
|
144 |
-
raise Exception(
|
145 |
-
"Could not infer hidden number of layer dimensions from model config"
|
146 |
-
)
|
147 |
-
|
148 |
-
def _find_module(self, hook_point: str):
|
149 |
-
submodules = hook_point.split(".")
|
150 |
-
module = self._model
|
151 |
-
while submodules:
|
152 |
-
module = getattr(module, submodules.pop(0))
|
153 |
-
return module
|
154 |
-
|
155 |
-
def forward(
|
156 |
-
self,
|
157 |
-
inputs: torch.Tensor,
|
158 |
-
cache_activations_at: Optional[list[str]] = None,
|
159 |
-
interventions: Optional[dict[str, InterventionInterface]] = None,
|
160 |
-
use_cache: bool = True,
|
161 |
-
past_key_values: Optional[tuple[torch.Tensor]] = None,
|
162 |
-
) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
|
163 |
-
cache: dict[str, torch.Tensor] = {}
|
164 |
-
with self._model.trace(
|
165 |
-
inputs,
|
166 |
-
scan=self.safe_mode,
|
167 |
-
validate=self.safe_mode,
|
168 |
-
use_cache=use_cache,
|
169 |
-
past_key_values=past_key_values,
|
170 |
-
):
|
171 |
-
# If we input an intervention
|
172 |
-
if interventions:
|
173 |
-
for hook_site in interventions.keys():
|
174 |
-
if interventions[hook_site] is None:
|
175 |
-
continue
|
176 |
-
|
177 |
-
module = self._find_module(hook_site)
|
178 |
-
|
179 |
-
if self.cleanup_intervention_layer:
|
180 |
-
last_layer = self._find_module(
|
181 |
-
self.cleanup_intervention_layer
|
182 |
-
)
|
183 |
-
else:
|
184 |
-
last_layer = None
|
185 |
-
|
186 |
-
intervened_acts, direct_effect_tensor = interventions[
|
187 |
-
hook_site
|
188 |
-
](module.output[0])
|
189 |
-
# Add direct effect tensor as 0 if it is None
|
190 |
-
if direct_effect_tensor is None:
|
191 |
-
direct_effect_tensor = 0
|
192 |
-
# We only modify module.output[0]
|
193 |
-
if use_cache:
|
194 |
-
module.output = (
|
195 |
-
intervened_acts,
|
196 |
-
module.output[1],
|
197 |
-
)
|
198 |
-
if last_layer:
|
199 |
-
last_layer.output = (
|
200 |
-
last_layer.output[0] - direct_effect_tensor,
|
201 |
-
last_layer.output[1],
|
202 |
-
)
|
203 |
-
else:
|
204 |
-
module.output = (intervened_acts,)
|
205 |
-
if last_layer:
|
206 |
-
last_layer.output = (
|
207 |
-
last_layer.output[0] - direct_effect_tensor,
|
208 |
-
)
|
209 |
-
|
210 |
-
if cache_activations_at is not None:
|
211 |
-
for hook_point in cache_activations_at:
|
212 |
-
module = self._find_module(hook_point)
|
213 |
-
cache[hook_point] = module.output.save()
|
214 |
-
|
215 |
-
if not past_key_values:
|
216 |
-
logits = self._model.output[0][:, -1, :].save()
|
217 |
-
else:
|
218 |
-
logits = self._model.output[0].squeeze(1).save()
|
219 |
-
|
220 |
-
kv_cache = self._model.output.past_key_values.save()
|
221 |
-
|
222 |
-
return (
|
223 |
-
logits.value.detach(),
|
224 |
-
kv_cache.value,
|
225 |
-
{k: v[0].detach() for k, v in cache.items()},
|
226 |
-
)
|
227 |
-
|
228 |
-
|
229 |
-
# Reading out features from the model
|
230 |
-
|
231 |
-
llama_3_1_8b = ObservableLanguageModel(
|
232 |
-
"meta-llama/Llama-3.1-8B-Instruct",
|
233 |
-
)
|
234 |
-
|
235 |
-
input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
|
236 |
-
[
|
237 |
-
{"role": "user", "content": "Hello, how are you?"},
|
238 |
-
],
|
239 |
-
return_tensors="pt",
|
240 |
-
)
|
241 |
-
logits, kv_cache, features = llama_3_1_8b.forward(
|
242 |
-
input_tokens,
|
243 |
-
cache_activations_at=["model.layers.19"],
|
244 |
-
)
|
245 |
-
|
246 |
-
print(features["model.layers.19"].shape)
|
247 |
-
|
248 |
-
|
249 |
-
# Intervention example
|
250 |
-
|
251 |
-
sae = load_sae(
|
252 |
-
path="./llama-3-8b-d-hidden.pth",
|
253 |
-
d_model=4096,
|
254 |
-
expansion_factor=16,
|
255 |
-
)
|
256 |
-
|
257 |
-
PIRATE_FEATURE_INDEX = 0
|
258 |
-
VALUE_TO_MODIFY = 0.1
|
259 |
-
|
260 |
-
def example_intervention(activations: nnsight.InterventionProxy):
|
261 |
-
features = sae.encode(activations).detach()
|
262 |
-
reconstructed_acts = sae.decode(features).detach()
|
263 |
-
error = activations - reconstructed_acts
|
264 |
-
|
265 |
-
# Modify feature at index 0 across all token positions
|
266 |
-
features[:, 0] += 0.1
|
267 |
-
|
268 |
-
# Very important to add the error term back in!
|
269 |
-
return sae.decode(features) + error
|
270 |
-
|
271 |
-
|
272 |
-
logits, kv_cache, features = llama_3_1_8b.forward(
|
273 |
-
input_tokens,
|
274 |
-
interventions={"model.layers.19": example_intervention},
|
275 |
-
)
|
276 |
-
|
277 |
-
print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
|
278 |
-
```
|
279 |
|
280 |
## Training
|
281 |
|
|
|
45 |
|
46 |
## How to use
|
47 |
|
48 |
+
View the notebook guide below to get started with the SAE.
|
49 |
+
|
50 |
+
<a href="https://colab.research.google.com/drive/1IBMQtJqy8JiRk1Q48jDEgTISmtxhlCRL" target="_blank">
|
51 |
+
<img
|
52 |
+
src="https://colab.research.google.com/assets/colab-badge.svg"
|
53 |
+
alt="Open in Colab"
|
54 |
+
width="200px"
|
55 |
+
style={{ pointerEvents: "none" }}
|
56 |
+
/>
|
57 |
+
</a>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
## Training
|
60 |
|