---
license: mit
---

# Gemma 2b - IT - Residual Stream SAEs

These SAEs are a follow-up to my other [Gemma-2b SAEs](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs), which were trained on the base model.

These SAEs were trained with [SAE Lens](https://github.com/jbloomAus/SAELens). All training hyperparameters, along with the version of the library used, are specified in cfg.json.
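
For example, you can download and inspect a cfg.json directly from the Hub. This is a minimal sketch; the repo id and file path are my assumptions based on this card's title and the SAE ID used below, so adjust them if they don't match the repository's file list:

```python
import json

from huggingface_hub import hf_hub_download

# NOTE: repo_id and filename are assumed from this card's title and the SAE ID;
# check the repository's file list if the download fails.
cfg_path = hf_hub_download(
    repo_id="jbloom/Gemma-2b-IT-Residual-Stream-SAEs",
    filename="blocks.12.hook_resid_post/cfg.json",
)
with open(cfg_path) as f:
    cfg = json.load(f)

# Print every training hyperparameter (including the SAE Lens version field).
for key, value in sorted(cfg.items()):
    print(f"{key}: {value}")
```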

They can be loaded with SAE Lens in a few ways. The preferred method is the following:

```python
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE, ActivationsStore

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gemma-2b-it", device=device)
sae, cfg, sparsity = SAE.from_pretrained(
    "gemma-2b-it-res-jb",  # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
    "blocks.12.hook_resid_post",  # change this to another SAE ID in the release if desired.
    device=device,
)

# For loading activations or tokens from the training dataset.
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # Fairly conservative parameters here, so the same settings can be
    # used for larger models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)

```
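
Once loaded, you can sanity-check the SAE by encoding residual stream activations and measuring sparsity and reconstruction error. A minimal sketch reusing `model`, `sae`, and `activation_store` from the snippet above (it assumes the SAE Lens `encode`/`decode` methods and the hook name stored in `sae.cfg.hook_name`):

```python
# Sample a batch of tokens from the training distribution.
batch_tokens = activation_store.get_batch_tokens()

# Cache the residual stream activations at the SAE's hook point.
_, cache = model.run_with_cache(batch_tokens, names_filter=sae.cfg.hook_name)
acts = cache[sae.cfg.hook_name]

# Encode into sparse feature activations, then decode back to a reconstruction.
feature_acts = sae.encode(acts)
reconstruction = sae.decode(feature_acts)

# Mean L0 (average number of active features per token) and reconstruction MSE.
l0 = (feature_acts > 0).float().sum(-1).mean()
mse = (reconstruction - acts).pow(2).mean()
print(f"mean L0: {l0.item():.1f}, reconstruction MSE: {mse.item():.4f}")
```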

## SAEs

### Resid Post 12

Stats:
- 16,384 features (expansion factor 8).
- CE loss score of 98.13% (see the sketch below for one common way to compute this metric).
- Mean L0 of 58 (in practice, L0 is log-normally distributed and heavily right-tailed).
- Fewer than 500 dead features.
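
For reference, here is a sketch of one common way to compute a CE loss score: compare the model's loss with the SAE reconstruction spliced into the residual stream against both a clean run and a zero-ablation baseline. This may not be exactly how the figure above was computed; it reuses `model`, `sae`, and `activation_store` from the loading snippet.

```python
from functools import partial

import torch

def splice_in_reconstruction(activations, hook, sae):
    # Replace the residual stream with the SAE's reconstruction of it.
    return sae.decode(sae.encode(activations))

def zero_ablate(activations, hook):
    # Baseline: zero out the hook point entirely.
    return torch.zeros_like(activations)

tokens = activation_store.get_batch_tokens()
ce_clean = model(tokens, return_type="loss")
ce_sae = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(sae.cfg.hook_name, partial(splice_in_reconstruction, sae=sae))],
)
ce_zero = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(sae.cfg.hook_name, zero_ablate)],
)

# Fraction of the loss increase under zero-ablation that the SAE recovers.
ce_score = (ce_zero - ce_sae) / (ce_zero - ce_clean)
print(f"CE loss score: {ce_score.item():.2%}")
```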

Notes:
- This SAE was trained on [tokenized OpenWebText](https://huggingface.co/datasets/chanind/openwebtext-gemma).
- The sparsity JSON was computed from too few samples, so I wouldn't trust it.