gemma-2-2b layer 20 width 65k SAE seems very off
I have been evaluating gemma-2-2b SAEs on a dataset of medical text. Looking at the 16k-width SAEs on layer 20, the metrics seem to be about what I'd expect:
{
"l0_139": {
"l2_loss": 148.585,
"l1_loss": 2728.04,
"l0": 183.4201708984375,
"frac_variance_explained": -0.40267578125,
"cossim": 0.92232421875,
"l2_ratio": 0.93435546875,
"relative_reconstruction_bias": 1.97640625,
"loss_original": 1.8141272115707396,
"loss_reconstructed": 2.10102525472641,
"loss_zero": 12.452932243347169,
"frac_recovered": 0.9730078125,
"frac_alive": 0.9940185546875,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 139,
"layer": 20,
"width": "16k"
}
},
"l0_22": {
"l2_loss": 284.485,
"l1_loss": 2719.16,
"l0": 55.3458203125,
"frac_variance_explained": -31.495,
"cossim": 0.87486328125,
"l2_ratio": 0.91640625,
"relative_reconstruction_bias": 15.23140625,
"loss_original": 1.8141272115707396,
"loss_reconstructed": 2.4616863882541655,
"loss_zero": 12.452932243347169,
"frac_recovered": 0.9390829825401306,
"frac_alive": 0.82476806640625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 22,
"layer": 20,
"width": "16k"
}
},
"l0_294": {
"l2_loss": 130.0175,
"l1_loss": 3763.92,
"l0": 352.6443994140625,
"frac_variance_explained": -0.01845703125,
"cossim": 0.9406640625,
"l2_ratio": 0.94486328125,
"relative_reconstruction_bias": 1.71236328125,
"loss_original": 1.8141272115707396,
"loss_reconstructed": 2.0600193762779235,
"loss_zero": 12.452932243347169,
"frac_recovered": 0.9768525409698486,
"frac_alive": 0.99761962890625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 294,
"layer": 20,
"width": "16k"
}
},
"l0_38": {
"l2_loss": 251.58,
"l1_loss": 2645.76,
"l0": 73.9402734375,
"frac_variance_explained": -20.34841796875,
"cossim": 0.889765625,
"l2_ratio": 0.9233984375,
"relative_reconstruction_bias": 11.4728125,
"loss_original": 1.8141272115707396,
"loss_reconstructed": 2.3733639335632324,
"loss_zero": 12.452932243347169,
"frac_recovered": 0.947366454899311,
"frac_alive": 0.89910888671875,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 38,
"layer": 20,
"width": "16k"
}
},
"l0_71": {
"l2_loss": 189.87,
"l1_loss": 2500.32,
"l0": 109.7097705078125,
"frac_variance_explained": -4.80037109375,
"cossim": 0.90638671875,
"l2_ratio": 0.92884765625,
"relative_reconstruction_bias": 4.6397265625,
"loss_original": 1.8141272115707396,
"loss_reconstructed": 2.1981925880908966,
"loss_zero": 12.452932243347169,
"frac_recovered": 0.9638544994592667,
"frac_alive": 0.96929931640625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 71,
"layer": 20,
"width": "16k"
}
}
}
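As far as I can tell, frac_recovered in these dumps matches the usual loss-recovered definition, (loss_zero - loss_reconstructed) / (loss_zero - loss_original). A quick check, just arithmetic on the reported numbers for the l0_139 entry rather than a rerun of the eval:

# recompute loss recovered for the 16k, l0_139 entry from the reported losses
loss_original = 1.8141272115707396
loss_reconstructed = 2.10102525472641
loss_zero = 12.452932243347169
frac_recovered = (loss_zero - loss_reconstructed) / (loss_zero - loss_original)
print(frac_recovered)  # ~0.9730, matching the reported frac_recovered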
However, the 65k SAEs for layer 20 have really weird metrics, including very poor loss recovered (i.e. Equation 10 from the gated SAEs paper: https://arxiv.org/pdf/2404.16014), despite having a low L2 loss. I thought it might be a quirk of the dataset, but I have reproduced it somewhat on monology/pile-uncopyrighted:
{
"l0_114": {
"l2_loss": 65.14174501419068,
"l1_loss": 326.4906903076172,
"l0": 19.7434326171875,
"frac_variance_explained": -1.1298050680756568,
"cossim": 0.44833588257431983,
"l2_ratio": 1.458713674545288,
"relative_reconstruction_bias": 3.926573168039322,
"loss_original": 2.151599160730839,
"loss_reconstructed": 12.79894030570984,
"loss_zero": 12.452933530807496,
"frac_recovered": -0.03705257594643627,
"frac_alive": 0.1755828857421875,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 114,
"layer": 20,
"width": "65k"
}
},
"l0_20": {
"l2_loss": 78.6240915298462,
"l1_loss": 274.0826930999756,
"l0": 6.4778857421875,
"frac_variance_explained": -8.00341603398323,
"cossim": 0.3754740992188454,
"l2_ratio": 1.6491711509227753,
"relative_reconstruction_bias": 8.075071120262146,
"loss_original": 2.151599160730839,
"loss_reconstructed": 18.244347710609436,
"loss_zero": 12.452933530807496,
"frac_recovered": -0.5657323953509331,
"frac_alive": 0.02691650390625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 20,
"layer": 20,
"width": "65k"
}
},
"l0_221": {
"l2_loss": 61.26867036819458,
"l1_loss": 394.06997283935544,
"l0": 30.2818212890625,
"frac_variance_explained": -0.004639597833156586,
"cossim": 0.47954541400074957,
"l2_ratio": 1.4224228554964065,
"relative_reconstruction_bias": 2.8707287490367888,
"loss_original": 2.151599160730839,
"loss_reconstructed": 10.630927562713623,
"loss_zero": 12.452933530807496,
"frac_recovered": 0.17530182713409886,
"frac_alive": 0.2276763916015625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 221,
"layer": 20,
"width": "65k"
}
},
"l0_34": {
"l2_loss": 77.58435577392578,
"l1_loss": 281.88170654296874,
"l0": 8.2050439453125,
"frac_variance_explained": -10.130443168580532,
"cossim": 0.41340469181537626,
"l2_ratio": 1.6560911977291106,
"relative_reconstruction_bias": 9.29422394156456,
"loss_original": 2.151599160730839,
"loss_reconstructed": 17.128004446029664,
"loss_zero": 12.452933530807496,
"frac_recovered": -0.4539519951120019,
"frac_alive": 0.065582275390625,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 34,
"layer": 20,
"width": "65k"
}
},
"l0_61": {
"l2_loss": 77.927738571167,
"l1_loss": 314.41664611816407,
"l0": 14.6854248046875,
"frac_variance_explained": -7.959465856552124,
"cossim": 0.41553613662719724,
"l2_ratio": 1.6819834589958191,
"relative_reconstruction_bias": 7.942268486022949,
"loss_original": 2.151599160730839,
"loss_reconstructed": 15.694214601516723,
"loss_zero": 12.452933530807496,
"frac_recovered": -0.31723272004863245,
"frac_alive": 0.1244659423828125,
"hyperparameters": {
"n_inputs": 200,
"context_length": 1024,
"l0": 61,
"layer": 20,
"width": "65k"
}
}
}
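To make the contrast concrete, here is the same recomputation over the 65k dump above (again just arithmetic on the reported losses, not a rerun of the eval): splicing in the 65k reconstructions hurts the loss nearly as much as, or more than, zero-ablating the activations.

# loss recovered for each 65k entry, recomputed from the reported losses above
loss_zero = 12.452933530807496
loss_original = 2.151599160730839
reconstructed = {
    "l0_114": 12.79894030570984,
    "l0_20": 18.244347710609436,
    "l0_221": 10.630927562713623,
    "l0_34": 17.128004446029664,
    "l0_61": 15.694214601516723,
}
for name, loss_rec in reconstructed.items():
    frac = (loss_zero - loss_rec) / (loss_zero - loss_original)
    print(name, round(frac, 3))
# l0_114 -0.034, l0_20 -0.562, l0_221 0.177, l0_34 -0.454, l0_61 -0.315

(These land within about 0.003 of the reported frac_recovered values; I assume the small gap is just how the eval averages over batches.)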
I will evaluate some other SAEs and other Gemma models to see whether this is a problem specific to this SAE at this layer of this model. I did all of the evaluation with the dictionary_learning repo (https://github.com/saprmarks/dictionary_learning), but it would be good if someone could sanity-check me / tell me if I'm missing something.