namgoodfire committed
Commit 307eb7b · verified · 1 Parent(s): 11a32ac

Update README.md

Files changed (1):
  1. README.md +10 -231

README.md CHANGED
@@ -45,237 +45,16 @@ and even steer its behavior. You can explore the SDK documentation at [docs.good
 
 ## How to use
 
- ```python
- import torch
- from typing import Optional, Callable
-
- import nnsight
- from nnsight.intervention import InterventionProxy
-
-
- # Autoencoder
-
-
- class SparseAutoEncoder(torch.nn.Module):
-     def __init__(
-         self,
-         d_in: int,
-         d_hidden: int,
-         device: torch.device,
-         dtype: torch.dtype = torch.bfloat16,
-     ):
-         super().__init__()
-         self.d_in = d_in
-         self.d_hidden = d_hidden
-         self.device = device
-         self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
-         self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
-         self.dtype = dtype
-         self.to(self.device, self.dtype)
-
-     def encode(self, x: torch.Tensor) -> torch.Tensor:
-         """Encode a batch of data using a linear layer followed by a ReLU."""
-         return torch.nn.functional.relu(self.encoder_linear(x))
-
-     def decode(self, x: torch.Tensor) -> torch.Tensor:
-         """Decode a batch of data using a linear layer."""
-         return self.decoder_linear(x)
-
-     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-         """SAE forward pass. Returns the reconstruction and the encoded features."""
-         f = self.encode(x)
-         return self.decode(f), f
-
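- # Shape note (following from the Linear layers above): encode maps
- # [..., d_in] to [..., d_hidden]; decode maps back to [..., d_in].
-
-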
- def load_sae(
-     path: str,
-     d_model: int,
-     expansion_factor: int,
-     device: torch.device = torch.device("cpu"),
- ):
-     sae = SparseAutoEncoder(
-         d_model,
-         d_model * expansion_factor,
-         device,
-     )
-     sae_dict = torch.load(path, weights_only=True, map_location=device)
-     sae.load_state_dict(sae_dict)
-
-     return sae
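-
- # Note: the checkpoint is assumed to be a plain SparseAutoEncoder
- # state_dict (encoder_linear / decoder_linear tensors), which is why
- # weights_only=True suffices above.
-
-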
- # Language model
-
-
- # An intervention takes a module's activations and returns the modified
- # activations plus an optional direct-effect tensor (see forward below).
- InterventionInterface = Callable[
-     [InterventionProxy],
-     tuple[InterventionProxy, Optional[InterventionProxy]],
- ]
-
-
- class ObservableLanguageModel:
-     def __init__(
-         self,
-         model: str,
-         device: str = "cuda",
-         dtype: torch.dtype = torch.bfloat16,
-     ):
-         self.dtype = dtype
-         self.device = device
-         self._original_model = model
-
-         self._model = nnsight.LanguageModel(
-             self._original_model,
-             device_map=device,
-             torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
-         )
-
-         self.tokenizer = self._model.tokenizer
-
-         self.d_model = self._attempt_to_infer_hidden_layer_dimensions()
-
-         # Optional hook point whose output is corrected by an intervention's
-         # direct-effect tensor; forward() reads this, so initialize it here.
-         self.cleanup_intervention_layer: Optional[str] = None
-
-         # nnsight validation is disabled by default since it slows down
-         # inference considerably. Enable it to debug.
-         self.safe_mode = False
-
-     def _attempt_to_infer_hidden_layer_dimensions(self):
-         config = self._model.config
-         if hasattr(config, "hidden_size"):
-             return int(config.hidden_size)
-
-         raise Exception(
-             "Could not infer the hidden dimension from the model config"
-         )
-
-     def _find_module(self, hook_point: str):
-         submodules = hook_point.split(".")
-         module = self._model
-         while submodules:
-             module = getattr(module, submodules.pop(0))
-         return module
-
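-     # Example: _find_module("model.layers.19") walks model -> layers -> 19,
-     # i.e. the 20th transformer block (0-indexed).
-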
-     def forward(
-         self,
-         inputs: torch.Tensor,
-         cache_activations_at: Optional[list[str]] = None,
-         interventions: Optional[dict[str, InterventionInterface]] = None,
-         use_cache: bool = True,
-         past_key_values: Optional[tuple[torch.Tensor]] = None,
-     ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
-         cache: dict[str, torch.Tensor] = {}
-         with self._model.trace(
-             inputs,
-             scan=self.safe_mode,
-             validate=self.safe_mode,
-             use_cache=use_cache,
-             past_key_values=past_key_values,
-         ):
-             # Apply any interventions that were passed in.
-             if interventions:
-                 for hook_site in interventions.keys():
-                     if interventions[hook_site] is None:
-                         continue
-
-                     module = self._find_module(hook_site)
-
-                     if self.cleanup_intervention_layer:
-                         last_layer = self._find_module(
-                             self.cleanup_intervention_layer
-                         )
-                     else:
-                         last_layer = None
-
-                     intervened_acts, direct_effect_tensor = interventions[
-                         hook_site
-                     ](module.output[0])
-                     # Treat a missing direct-effect tensor as zero.
-                     if direct_effect_tensor is None:
-                         direct_effect_tensor = 0
-                     # We only modify module.output[0]; the rest of the output
-                     # tuple (e.g. the KV cache) is passed through unchanged.
-                     if use_cache:
-                         module.output = (
-                             intervened_acts,
-                             module.output[1],
-                         )
-                         if last_layer:
-                             last_layer.output = (
-                                 last_layer.output[0] - direct_effect_tensor,
-                                 last_layer.output[1],
-                             )
-                     else:
-                         module.output = (intervened_acts,)
-                         if last_layer:
-                             last_layer.output = (
-                                 last_layer.output[0] - direct_effect_tensor,
-                             )
-
-             if cache_activations_at is not None:
-                 for hook_point in cache_activations_at:
-                     module = self._find_module(hook_point)
-                     cache[hook_point] = module.output.save()
-
-             if not past_key_values:
-                 logits = self._model.output[0][:, -1, :].save()
-             else:
-                 logits = self._model.output[0].squeeze(1).save()
-
-             kv_cache = self._model.output.past_key_values.save()
-
-         return (
-             logits.value.detach(),
-             kv_cache.value,
-             {k: v[0].detach() for k, v in cache.items()},
-         )
-
-
- # Reading out features from the model
-
- llama_3_1_8b = ObservableLanguageModel(
-     "meta-llama/Llama-3.1-8B-Instruct",
- )
-
- input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
-     [
-         {"role": "user", "content": "Hello, how are you?"},
-     ],
-     return_tensors="pt",
- )
- logits, kv_cache, features = llama_3_1_8b.forward(
-     input_tokens,
-     cache_activations_at=["model.layers.19"],
- )
-
- print(features["model.layers.19"].shape)
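- # Expected: (batch, seq_len, 4096); 4096 is Llama-3.1-8B's hidden size.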
-
-
- # Intervention example
-
- sae = load_sae(
-     path="./llama-3-8b-d-hidden.pth",
-     d_model=4096,
-     expansion_factor=16,
- )
-
- PIRATE_FEATURE_INDEX = 0
- VALUE_TO_MODIFY = 0.1
-
- def example_intervention(activations: InterventionProxy):
-     features = sae.encode(activations).detach()
-     reconstructed_acts = sae.decode(features).detach()
-     error = activations - reconstructed_acts
-
-     # Nudge the chosen feature across all token positions.
-     features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY
-
-     # Very important to add the reconstruction error back in, so anything
-     # the SAE fails to capture still reaches the rest of the model. No
-     # direct-effect tensor here, hence None.
-     return sae.decode(features) + error, None
-
-
- logits, kv_cache, features = llama_3_1_8b.forward(
-     input_tokens,
-     interventions={"model.layers.19": example_intervention},
- )
-
- print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
- ```
+ View the notebook guide below to get started with the SAE.
+
+ <a href="https://colab.research.google.com/drive/1IBMQtJqy8JiRk1Q48jDEgTISmtxhlCRL" target="_blank">
+   <img
+     src="https://colab.research.google.com/assets/colab-badge.svg"
+     alt="Open in Colab"
+     width="200"
+     style="pointer-events: none;"
+   />
+ </a>
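+
+ If you would rather start in code, here is a minimal sketch. It assumes the
+ `SparseAutoEncoder`/`load_sae` helpers from the previous revision of this
+ README, and the checkpoint path below is hypothetical:
+
+ ```python
+ import torch
+
+ # Hypothetical local path to the downloaded SAE weights.
+ sae = load_sae(
+     path="./llama-3-8b-d-hidden.pth",
+     d_model=4096,           # hidden size of Llama-3.1-8B-Instruct
+     expansion_factor=16,    # 4096 * 16 = 65,536 SAE features
+ )
+
+ # Stand-in for residual-stream activations captured at model.layers.19
+ # of meta-llama/Llama-3.1-8B-Instruct ([batch, seq, 4096]).
+ activations = torch.randn(1, 8, 4096, dtype=torch.bfloat16)
+
+ features = sae.encode(activations)       # [1, 8, 65536], nonnegative
+ reconstruction = sae.decode(features)    # [1, 8, 4096]
+ ```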
 
 ## Training