---
license: llama3.1
language:
- en
base_model:
- meta-llama/Llama-3.1-8B-Instruct
tags:
- mechanistic interpretability
- sparse autoencoder
- llama
- llama-3
---

## Model Information

The Goodfire SAE (Sparse Autoencoder) for [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
is an interpreter model designed to analyze and understand
the model's internal representations. This SAE is trained specifically on layer 19 of
Llama 3.1 8B Instruct and achieves an L0 count of 121, enabling the decomposition of complex neural activations
into interpretable features. The model is optimized for interpretability tasks and model steering applications,
allowing researchers and developers to gain insights into the model's internal processing and behavior patterns.
As an open-source tool, it serves as a foundation for advancing interpretability research and enhancing control
over large language model operations.

__Model Creator__: [Goodfire](https://huggingface.co/Goodfire), built to work with [Meta's Llama models](https://huggingface.co/meta-llama)

By using the __Goodfire/Llama-3.1-8B-Instruct__ SAE for `model.layers.19`, you agree to the [LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct/blob/main/LICENSE)


## Intended Use

By open-sourcing SAEs for leading open models such as Llama 3.1 8B Instruct,
we aim to accelerate progress in interpretability research.

Our initial work with these SAEs has revealed promising applications in model steering,
strengthening safeguards against jailbreaks, and interpretable classification methods.
We look forward to seeing how the research community builds upon these
foundations and uncovers new applications.

#### Feature labels

To explore the feature labels, check out the [Goodfire Ember SDK](https://www.goodfire.ai/blog/announcing-goodfire-ember/),
the first hosted mechanistic interpretability API.
The SDK provides an intuitive interface for interacting with these
features, allowing you to investigate how Llama processes information
and even steer its behavior. You can explore the SDK documentation at [docs.goodfire.ai](https://docs.goodfire.ai).

## How to use

```python
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


# Language model


InterventionInterface = Callable[[InterventionProxy], InterventionProxy]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # Optional hook point used to subtract an intervention's direct effect
        # from a later layer; referenced in forward(), so it must be defined.
        self.cleanup_intervention_layer: Optional[str] = None

        self.safe_mode = False  # nnsight validation is disabled by default; it slows down inference a lot. Turn on to debug.

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer the hidden dimension from the model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # Apply any requested interventions at their hook sites
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    # An intervention may return either the modified activations
                    # alone or an (activations, direct_effect) pair; handle both.
                    result = interventions[hook_site](module.output[0])
                    if isinstance(result, tuple):
                        intervened_acts, direct_effect_tensor = result
                    else:
                        intervened_acts, direct_effect_tensor = result, None
                    # Treat a missing direct effect tensor as 0
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0
                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )


# Reading out features from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)
logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)
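# The cached value is the layer 19 residual stream, so for a single chat prompt
# the printed shape should be (1, sequence_length, 4096) for Llama 3.1 8B.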


# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    device=torch.device("cuda"),  # keep the SAE on the same device as the model
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1

def example_intervention(activations: nnsight.InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature index across all token positions
    features[:, :, PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    return sae.decode(features) + error


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

# Decode the most likely next token under the steered model
print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```
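
The model information above reports an L0 count of 121 for this SAE, i.e. the average number of features active per token. As a rough sanity check you can estimate the average L0 on your own prompts with the objects defined above. The snippet below is an illustrative sketch that reuses `llama_3_1_8b`, `input_tokens`, and `sae` from the example; the exact data and protocol behind the reported number are not specified here.

```python
# Illustrative sketch: estimate average L0 on a single prompt.
# Reuses `llama_3_1_8b`, `input_tokens`, and `sae` from the example above.
import torch

# Re-cache the layer 19 residual stream for the prompt
_, _, cached = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

with torch.no_grad():
    acts = cached["model.layers.19"].to(sae.device, sae.dtype)  # (batch, seq_len, 4096)
    feature_acts = sae.encode(acts)  # (batch, seq_len, 4096 * 16)

    # L0 = average number of non-zero feature activations per token position
    l0_per_token = (feature_acts > 0).sum(dim=-1).float()
    print(f"Average L0 on this prompt: {l0_per_token.mean().item():.1f}")
```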

## Training

We trained our SAE on activations harvested from Llama-3.1-8B-Instruct on the [LMSYS-Chat-1M dataset](https://arxiv.org/pdf/2309.11998).
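
The sketch below illustrates what harvesting of this kind can look like with the `ObservableLanguageModel` wrapper from the example above. It is an assumption-laden illustration, not a description of our training pipeline: the dataset identifier `lmsys/lmsys-chat-1m`, the `conversation` field name, and the small sample cap are placeholders, and the dataset is gated on Hugging Face.

```python
# Illustrative sketch only; not the actual SAE training pipeline.
# Assumes `llama_3_1_8b` (ObservableLanguageModel) from the example above,
# plus the `datasets` library and access to the gated lmsys/lmsys-chat-1m dataset.
import torch
from datasets import load_dataset

dataset = load_dataset("lmsys/lmsys-chat-1m", split="train", streaming=True)

activation_buffer = []
for i, row in enumerate(dataset):
    if i >= 100:  # small cap for the sketch
        break
    tokens = llama_3_1_8b.tokenizer.apply_chat_template(
        row["conversation"], return_tensors="pt"  # field name is an assumption
    )
    _, _, cached = llama_3_1_8b.forward(
        tokens, cache_activations_at=["model.layers.19"]
    )
    # Flatten to (num_tokens, 4096) and move to CPU to limit GPU memory use
    activation_buffer.append(cached["model.layers.19"].reshape(-1, 4096).cpu())

harvested_activations = torch.cat(activation_buffer)
print(harvested_activations.shape)  # training data for an SAE at layer 19
```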

## Responsibility & Safety

Safety is at the core of everything we do at Goodfire. As a public benefit
corporation, we’re dedicated to understanding AI models to enable safer, more reliable
generative AI. You can read more about our comprehensive approach to
safety and responsible development in our detailed [safety overview](https://www.goodfire.ai/blog/our-approach-to-safety/).