license: llama3.1
language:
- en
base_model:
- meta-llama/Llama-3.1-8B-Instruct
tags:
- mechanistic interpretability
- sparse autoencoder
- llama
- llama-3
Model Information
The Goodfire SAE (Sparse Autoencoder) for meta-llama/Llama-3.1-8B-Instruct is an interpreter model for analyzing the internal representations of that model. It is trained on the activations of layer 19 of Llama 3.1 8B Instruct and reaches an L0 of 91, meaning that on average about 91 features are active per token. This sparsity lets it decompose dense activations into interpretable features. The SAE is intended for interpretability work and model steering, and it is released as open source to support further interpretability research and finer control over large language model behavior.
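As a rough illustration of what the L0 figure measures, the sketch below builds an autoencoder with the same shape as the one in this card (linear encoder, ReLU, linear decoder) and counts the non-zero features per token. The sizes, the random weights, and the random stand-in activations are assumptions made for the example, so the printed numbers say nothing about the released SAE.

```python
import torch

torch.manual_seed(0)

# Toy sizes for illustration only (assumption); the released SAE is far larger.
d_in, d_hidden = 16, 64

# Same layout as the SAE in this card: linear encoder + ReLU, linear decoder.
# Weights are random here, not the trained Goodfire weights.
encoder = torch.nn.Linear(d_in, d_hidden)
decoder = torch.nn.Linear(d_hidden, d_in)

# Random stand-in for layer-19 activations: (batch, seq_len, d_in).
activations = torch.randn(2, 5, d_in)

with torch.no_grad():
    features = torch.relu(encoder(activations))
    reconstruction = decoder(features)

# L0: number of non-zero features per token, averaged over all tokens.
l0_per_token = (features > 0).sum(dim=-1).float()
mean_l0 = l0_per_token.mean().item()

# Mean squared reconstruction error, a common companion metric to L0.
mse = (activations - reconstruction).pow(2).mean().item()

print(f"mean L0: {mean_l0:.1f} of {d_hidden} features")
print(f"reconstruction MSE: {mse:.4f}")
```

With random weights roughly half of the features fire on each token. A trained SAE is useful precisely because it keeps this count low while still reconstructing the activations well.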
Model Creator: Goodfire, built to work with Meta's Llama models
By using Goodfire/Llama-3.1-8B-Instruct__model.layers.19, you agree to the LLAMA 3.1 COMMUNITY LICENSE AGREEMENT.
Intended Use
By open-sourcing SAEs for leading open models, especially large-scale models like Llama 3.1 8B, we aim to accelerate progress in interpretability research.
Our initial work with these SAEs has revealed promising applications in model steering, strengthening safeguards against jailbreaking, and interpretable classification. We look forward to seeing how the research community builds on these foundations and uncovers new applications.
Feature labels
To explore the feature labels, check out the Goodfire Ember SDK, the first hosted mechanistic interpretability API. The SDK provides an interface for inspecting these features, so you can investigate how Llama processes information and steer its behavior. The SDK documentation is available at docs.goodfire.ai.
How to use
import torch
from typing import Optional, Callable
import nnsight
from nnsight.intervention import InterventionProxy
# Autoencoder
class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f

def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    # The hidden width is the model width times the expansion factor.
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)
    return sae

# Language model
InterventionInterface = Callable[[InterventionProxy], InterventionProxy]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model
        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )
        self.tokenizer = self._model.tokenizer
        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()
        # nnsight validation is disabled by default because it slows inference
        # down a lot. Turn it on to debug.
        self.safe_mode = False
        # Optional hook point whose output has an intervention's direct effect
        # subtracted from it. Unset by default, so no correction is applied.
        self.cleanup_intervention_layer: Optional[str] = None

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)
        raise Exception(
            "Could not infer the hidden layer dimension from the model config"
        )

    def _find_module(self, hook_point: str):
        # Resolve a dotted path such as "model.layers.19" one attribute at a time.
        module = self._model
        for name in hook_point.split("."):
            module = getattr(module, name)
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # If we input an intervention
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    # An intervention may return either the new activations
                    # alone or an (activations, direct_effect) pair.
                    result = interventions[hook_site](module.output[0])
                    if isinstance(result, tuple):
                        intervened_acts, direct_effect_tensor = result
                    else:
                        intervened_acts, direct_effect_tensor = result, None

                    # Add direct effect tensor as 0 if it is None
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0

                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            # Keep only the logits for the last position.
            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v[0].detach() for k, v in cache.items()},
        )

# Reading out features from the model
llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    return_tensors="pt",
)

logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

print(features["model.layers.19"].shape)

# Intervention example
sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    # Keep the SAE on the same device as the language model.
    device=torch.device(llama_3_1_8b.device),
)

# Placeholder values: replace them with the feature index and strength you want.
PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1


def example_intervention(activations: InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature across all token positions.
    # The feature dimension is the last one, so index it with an ellipsis.
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    return sae.decode(features) + error


logits, kv_cache, features = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
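The comment about adding the error term back can be checked on a toy example. The sketch below uses a small, randomly initialized autoencoder with the same layout as SparseAutoEncoder; the sizes, FEATURE_INDEX, and DELTA are hypothetical values chosen for illustration. It shows that with the error term, the only change to the activations is a shift along the steered feature's decoder direction.

```python
import torch

torch.manual_seed(0)

d_in, d_hidden = 16, 64   # toy sizes (assumption), not the released SAE's
FEATURE_INDEX = 3         # hypothetical feature to steer
DELTA = 0.5               # hypothetical steering strength

# Random weights standing in for a trained SAE.
encoder = torch.nn.Linear(d_in, d_hidden)
decoder = torch.nn.Linear(d_hidden, d_in)

# Random stand-in for layer activations: (batch, seq_len, d_in).
activations = torch.randn(2, 5, d_in)

with torch.no_grad():
    features = torch.relu(encoder(activations))
    # Whatever the SAE fails to reconstruct is kept as an error term.
    error = activations - decoder(features)

    # Without any edit, reconstruction + error recovers the input
    # (up to floating-point rounding).
    unchanged = decoder(features) + error

    # Steer one feature at every token position.
    steered_features = features.clone()
    steered_features[..., FEATURE_INDEX] += DELTA
    steered = decoder(steered_features) + error

# decoder.weight has shape (d_in, d_hidden); column i is feature i's direction.
direction = decoder.weight[:, FEATURE_INDEX].detach()
shift = steered - activations

print(torch.allclose(unchanged, activations, atol=1e-5))
print(torch.allclose(shift, (DELTA * direction).expand_as(shift), atol=1e-5))
```

Dropping the error term would instead replace the activations with the SAE's imperfect reconstruction, so the model's behavior would change even for a steering strength of zero.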
Training
We trained our SAE on activations harvested from Llama-3.1-8B-Instruct on the LMSYS-Chat-1M dataset.
Responsibility & Safety
Safety is at the core of everything we do at Goodfire. As a public benefit corporation, we’re dedicated to understanding AI models to enable safer, more reliable generative AI. You can read more about our comprehensive approach to safety and responsible development in our detailed safety overview.