Trained Sparse Autoencoders on Pythia 2.8B
I trained SAEs on the MLP_out activations of Pythia 2.8B. I trained them using github.com/magikarp01/facts-sae, a fork of github.com/saprmarks/dictionary_learning designed for efficient multi-GPU (not yet multi-node) training. I saved checkpoints every 10k steps but have not uploaded them all; message me if you want more intermediate checkpoints.
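For readers unfamiliar with the setup, here is a minimal NumPy sketch of a standard ReLU sparse autoencoder and its training objective. It assumes the common dictionary_learning-style formulation: the input minus a decoder bias is encoded, passed through a ReLU, and decoded linearly, and the loss is reconstruction error plus an L1 penalty on feature activations. The function names, the toy dimensions and the squared-L2 reconstruction term are illustrative assumptions, not the fork's actual code. The real SAEs use d_model = 2560 and 40960 features, and `l1_penalty` defaults to the 1e-3 sparsity penalty used for these runs.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """ReLU SAE: encode the bias-centered input, then decode back to activation space."""
    # x: (batch, d_model); W_enc: (d_model, d_dict); W_dec: (d_dict, d_model)
    f = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)  # non-negative feature activations
    x_hat = f @ W_dec + b_dec                          # reconstruction of x
    return f, x_hat

def sae_loss(x, f, x_hat, l1_penalty=1e-3):
    """Squared-L2 reconstruction error plus L1 sparsity penalty, averaged over the batch."""
    recon = np.sum((x - x_hat) ** 2, axis=-1).mean()
    sparsity = np.sum(np.abs(f), axis=-1).mean()
    return recon + l1_penalty * sparsity

# Toy sizes with the same 16x expansion factor as the real SAEs.
rng = np.random.default_rng(0)
d_model, batch = 8, 4
d_dict = 16 * d_model
W_enc = rng.normal(scale=0.1, size=(d_model, d_dict))
W_dec = rng.normal(size=(d_dict, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)  # unit-norm dictionary rows
b_enc = np.zeros(d_dict)
b_dec = np.zeros(d_model)

x = rng.normal(size=(batch, d_model))  # stand-in for a batch of MLP_out activations
f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
print(f.shape, x_hat.shape, sae_loss(x, f, x_hat))
```

Keeping decoder rows at unit norm stops the optimizer from shrinking feature activations (and so the L1 term) by scaling up decoder weights. Some trainers use the unsquared L2 norm for the reconstruction term instead.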
My original goal was to analyze these SAEs, specifically to determine how much they contribute to performance on a Sports Facts dataset. I'm currently working on other projects and haven't had time to do this yet, but hopefully some results will come out of these SAEs in the future.
SAE Setup
- Training Dataset: Uncopyrighted Pile, at monology/pile-uncopyrighted
- Model: 32-layer Pythia 2.8B
- Activation: MLP_out, so d_model of 2560
- Layers Trained: 0, 1, 2, 15
- Batch Size: 2560 for layers 0, 1, 2; 2048 for layer 15
- Training Tokens: 1e9 for layers 0, 2, 15; slightly less than 2e9 for layer 1
- Training Steps: 4e5 for layers 0, 2; 7.5e5 for layer 1; 5e5 for layer 15
- Dictionary Size: 16x activation, so 40960
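The numbers in the setup list are internally consistent if batch size is counted in tokens (activation vectors) rather than sequences. That interpretation is an assumption, but the arithmetic below bears it out. This small sketch recomputes tokens as batch size × steps and the dictionary size as 16 × d_model. The names `runs` and `tokens_seen` are illustrative only.

```python
D_MODEL = 2560
EXPANSION = 16

# (batch_size, training_steps) per layer, taken from the setup list
runs = {
    0: (2560, 400_000),
    1: (2560, 750_000),
    2: (2560, 400_000),
    15: (2048, 500_000),
}

def tokens_seen(batch_size: int, steps: int) -> int:
    """Tokens processed, assuming each batch row is one token's activation vector."""
    return batch_size * steps

dict_size = EXPANSION * D_MODEL
print("dictionary size:", dict_size)
for layer, (bs, steps) in sorted(runs.items()):
    print(f"layer {layer}: {tokens_seen(bs, steps):.3e} tokens")
```

Running it prints a dictionary size of 40960, 1.024e+09 tokens for layers 0, 2 and 15, and 1.920e+09 for layer 1. These match the rounded figures in the list.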
Training Hyperparameters
- Learning Rate: 3e-4
- Sparsity Penalty: 1e-3
- Warmup Steps: 5000
- Resample Steps: 50000
- Optimizer: Constrained Adam
- Scheduler: LambdaLR, with the learning rate warmed up linearly from 0 over the first warmup_steps steps
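LambdaLR multiplies the base learning rate by whatever a user-supplied function returns for the current step. A minimal sketch of a linear-warmup multiplier consistent with the hyperparameters above (3e-4 base LR, 5000 warmup steps) is shown here. The helper name `make_warmup_lambda` is hypothetical. The sketch assumes a single warmup at the start of training.

```python
def make_warmup_lambda(warmup_steps: int):
    """Return an LR multiplier that ramps linearly from 0 to 1, then stays at 1."""
    def lr_lambda(step: int) -> float:
        return min(step / warmup_steps, 1.0)
    return lr_lambda

BASE_LR = 3e-4
WARMUP_STEPS = 5000
lr_lambda = make_warmup_lambda(WARMUP_STEPS)

for step in (0, 2500, 5000, 100_000):
    print(step, BASE_LR * lr_lambda(step))

# With PyTorch, the same function would be passed as:
#   torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
```

If the trainer also restarts warmup after each dead-neuron resampling (every 50000 steps here), the step would instead be taken modulo the resample interval. I have not confirmed whether the fork does this.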
SAE Metrics
Thanks
Thanks to Nat Friedman/NFDG for letting me use H100s from the Andromeda Cluster during downtime, and thanks to Sam Marks/NDIF for the original SAE training repo and for helping me distribute the SAEs. This work was done late in my MATS training phase with Neel Nanda.