namgoodfire committed: Update README.md

---
license: llama3.1
language:
- en
base_model:
- meta-llama/Llama-3.1-8B-Instruct
tags:
- mechanistic interpretability
- sparse autoencoder
- llama
- llama-3
---

## Model Information

The Goodfire SAE (Sparse Autoencoder) for [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
is an interpreter model designed to analyze and understand
the model's internal representations. This SAE is trained specifically on layer 19 of
Llama 3.1 8B Instruct and achieves an L0 count of 121, enabling the decomposition of complex neural activations
into interpretable features. The model is optimized for interpretability tasks and model steering applications,
allowing researchers and developers to gain insight into the model's internal processing and behavior patterns.
As an open-source tool, it serves as a foundation for advancing interpretability research and enhancing control
over large language models.

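For intuition, the L0 count is the average number of SAE features that are active (nonzero after the ReLU) per token: each 4096-dimensional layer-19 activation is approximated as a sparse combination of directions from a much wider feature dictionary (the example below uses a 16x expansion, i.e. 65,536 features). The following is a minimal, self-contained sketch of that decomposition with randomly initialized weights, purely for illustration; the real encoder and decoder weights come from the checkpoint loaded in the "How to use" section.

```python
import torch

d_in, d_hidden = 4096, 4096 * 16   # dimensions matching the example further down

# Stand-in encoder/decoder; the trained SAE replaces these with learned weights.
encoder = torch.nn.Linear(d_in, d_hidden)
decoder = torch.nn.Linear(d_hidden, d_in)

x = torch.randn(2, d_in)                 # two illustrative layer-19 activations
features = torch.relu(encoder(x))        # feature activations (sparse once trained)
x_hat = decoder(features)                # reconstruction from those features
error = x - x_hat                        # residual the SAE does not explain

# With the trained SAE, the nonzero-feature count per token averages ~121 (the L0 count).
print((features > 0).sum(dim=-1).float().mean())
```
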
__Model Creator__: [Goodfire](https://huggingface.co/Goodfire), built to work with [Meta's Llama models](https://huggingface.co/meta-llama)

By using `Goodfire/Llama-3.1-8B-Instruct__model.layers.19` you agree to the [LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct/blob/main/LICENSE).

## Intended Use

By open-sourcing SAEs for leading open models, especially large-scale
models like Llama 3.1 8B, we aim to accelerate progress in interpretability research.

Our initial work with these SAEs has revealed promising applications in model steering,
enhancing jailbreaking safeguards, and interpretable classification methods.
We look forward to seeing how the research community builds upon these
foundations and uncovers new applications.

#### Feature labels

To explore the feature labels, check out the [Goodfire Ember SDK](https://www.goodfire.ai/blog/announcing-goodfire-ember/),
the first hosted mechanistic interpretability API.
The SDK provides an intuitive interface for interacting with these
features, allowing you to investigate how Llama processes information
and even steer its behavior. You can explore the SDK documentation at [docs.goodfire.ai](https://docs.goodfire.ai).

## How to use

```python
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


# Language model


InterventionInterface = Callable[[InterventionProxy], InterventionProxy]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        self.safe_mode = False  # nnsight validation is disabled by default; it slows down inference a lot. Turn it on to debug.

        # Optional hook point whose output should have the direct effect of an
        # intervention subtracted back out. None disables this cleanup step.
        self.cleanup_intervention_layer: Optional[str] = None

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer hidden layer dimensions from model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # If we input an intervention
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    intervened_acts, direct_effect_tensor = interventions[
                        hook_site
                    ](module.output[0])
                    # Add direct effect tensor as 0 if it is None
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0
                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )


# Reading out features from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)
logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)


# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    device=torch.device("cuda"),  # keep the SAE on the same device as the language model
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1


def example_intervention(activations: nnsight.InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature across all token positions
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    # forward() expects an (intervened activations, direct-effect tensor) pair,
    # so we return None for the direct-effect term here.
    return sae.decode(features) + error, None


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```
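
As a small follow-on (not part of the original example), once the SAE is loaded and layer-19 activations have been cached as above, you can look directly at which SAE features fire on each token. This sketch assumes `llama_3_1_8b`, `input_tokens`, and `sae` from the code above.

```python
# Re-run the caching forward pass so we have fresh layer-19 activations.
_, _, cache = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

acts = cache["model.layers.19"]                    # (batch, seq, 4096) residual-stream activations
sae_features = sae.encode(acts.to(sae.device, sae.dtype))

# Average number of active features per token (the SAE's L0)
print((sae_features > 0).sum(dim=-1).float().mean())

# Indices of the five most strongly active features on the final token
print(sae_features[0, -1].topk(5).indices)
```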

## Training

We trained our SAE on activations harvested from Llama-3.1-8B-Instruct on the [LMSYS-Chat-1M dataset](https://arxiv.org/pdf/2309.11998).
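
The training pipeline itself is not included in this card; the sketch below only illustrates how activations of this kind could be harvested with the `ObservableLanguageModel` from the example above. The dataset identifier and field name (`lmsys/lmsys-chat-1m`, `conversation`) are assumptions about the Hugging Face release of LMSYS-Chat-1M, not something specified here.

```python
import torch
from datasets import load_dataset  # Hugging Face `datasets` package

# Stream a handful of chats and cache layer-19 activations for them.
dataset = load_dataset("lmsys/lmsys-chat-1m", split="train", streaming=True)

harvested = []
for i, row in enumerate(dataset):
    if i >= 8:  # tiny illustrative budget; real training uses far more data
        break
    tokens = llama_3_1_8b.tokenizer.apply_chat_template(
        row["conversation"], return_tensors="pt"
    )
    _, _, cache = llama_3_1_8b.forward(
        tokens, cache_activations_at=["model.layers.19"]
    )
    harvested.append(cache["model.layers.19"].reshape(-1, 4096).cpu())

activations = torch.cat(harvested)  # (num_tokens, 4096): SAE training examples
```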

## Responsibility & Safety

Safety is at the core of everything we do at Goodfire. As a public benefit
corporation, we’re dedicated to understanding AI models to enable safer, more reliable
generative AI. You can read more about our comprehensive approach to
safety and responsible development in our detailed [safety overview](https://www.goodfire.ai/blog/our-approach-to-safety/).