jbloom committed
Commit 18fcaa2
1 Parent(s): 85ded2f

Update README.md

Files changed (1): README.md (+55 -3)

---
license: mit
---

# Gemma 2b - IT - Residual Stream SAEs

These SAEs are a follow-up to my other [Gemma-2b SAEs](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs), which were trained on the base model.

These SAEs were trained with [SAE Lens](https://github.com/jbloomAus/SAELens), and the library version used is stored in cfg.json.

All training hyperparameters are specified in cfg.json.
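
If you want to inspect those hyperparameters without loading the SAE, you can download cfg.json directly. This is a minimal sketch, assuming the repo stores each SAE's config in a folder named after its hook point; the repo id, file layout, and key names here are assumptions, not confirmed by this README:

```python
import json

from huggingface_hub import hf_hub_download

# Hypothetical repo id and file path; adjust to the actual repo layout.
cfg_path = hf_hub_download(
    repo_id="jbloom/Gemma-2b-IT-Residual-Stream-SAEs",
    filename="blocks.12.hook_resid_post/cfg.json",
)
with open(cfg_path) as f:
    cfg = json.load(f)

# The exact version key name may vary across SAE Lens releases.
print(cfg.get("sae_lens_version"), cfg.get("sae_lens_training_version"))
```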
They are loadable with SAE Lens via a few methods. The preferred method is the following:

```python
import torch
from transformer_lens import HookedTransformer
from sae_lens import SparseAutoencoder, ActivationsStore

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# The instruction-tuned model these SAEs were trained on.
model = HookedTransformer.from_pretrained("gemma-2b-it")
sae, cfg, sparsity = SparseAutoencoder.from_pretrained(
    "gemma-2b-it-res-jb",  # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
    "blocks.12.hook_resid_post",  # change this to another SAE ID in the release if desired.
)

# For loading activations or tokens from the training dataset.
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # Fairly conservative parameters here, so the same settings can be
    # used for larger models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)
```
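
Once loaded, here is a minimal sketch of running the model and inspecting the SAE's features on some text. It assumes the SAE object exposes `encode`/`decode` methods (true in recent SAE Lens versions; older releases may differ) and reuses the `model`, `sae`, and `activation_store` objects from above:

```python
prompt = "The quick brown fox jumps over the lazy dog."

# Cache activations at the SAE's hook point.
_, cache = model.run_with_cache(prompt)
acts = cache["blocks.12.hook_resid_post"]  # shape: [batch, seq, d_model]

# Encode to sparse feature activations, then reconstruct.
feature_acts = sae.encode(acts)
reconstruction = sae.decode(feature_acts)

# L0 per token: the number of active features. For this SAE the mean
# should land near the reported L0 of 58.
l0 = (feature_acts > 0).float().sum(-1)
print("mean L0:", l0.mean().item())
print("MSE:", (reconstruction - acts).pow(2).mean().item())

# Tokens from the training distribution can also be sampled from the
# store (assuming get_batch_tokens exists in your SAE Lens version).
batch_tokens = activation_store.get_batch_tokens()
```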

## SAEs

### Resid Post 12

Stats:
- 16384 features (expansion factor 8), achieving a CE loss score of 98.13% (see the sketch after this list).
- Mean L0 of 58 (in practice, L0 is log-normally distributed and heavily right-tailed).
- Fewer than 500 dead features.
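
A CE loss score like the one above is conventionally defined as the fraction of the cross-entropy loss gap recovered when the SAE reconstruction is spliced into the forward pass, relative to zero-ablating the activation. Here is a minimal sketch under that assumed definition (this is not code from this repo), reusing the `model` and `sae` objects from above:

```python
import torch

hook_name = "blocks.12.hook_resid_post"
tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")

def splice_sae(acts, hook):
    # Replace the residual stream with the SAE reconstruction.
    return sae.decode(sae.encode(acts)).to(acts.dtype)

def zero_ablate(acts, hook):
    return torch.zeros_like(acts)

ce_clean = model(tokens, return_type="loss")
ce_sae = model.run_with_hooks(
    tokens, return_type="loss", fwd_hooks=[(hook_name, splice_sae)]
)
ce_zero = model.run_with_hooks(
    tokens, return_type="loss", fwd_hooks=[(hook_name, zero_ablate)]
)

# Fraction of the clean-vs-ablated loss gap recovered by the SAE.
score = (ce_zero - ce_sae) / (ce_zero - ce_clean)
print(f"CE loss score: {score.item():.2%}")
```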

Notes:
- This SAE was trained on [OpenWebText, tokenized for Gemma](https://huggingface.co/datasets/chanind/openwebtext-gemma).
- The sparsity json was computed from too few samples, so I wouldn't trust it.