metadata
license: llama3.1
language:
  - en
base_model:
  - meta-llama/Llama-3.1-8B-Instruct
tags:
  - mechanistic interpretability
  - sparse autoencoder
  - llama
  - llama-3

Model Information

The Goodfire SAE (Sparse Autoencoder) for meta-llama/Llama-3.1-8B-Instruct is an interpreter model designed to analyze the model's internal representations. It is trained on the output of layer 19 of Llama 3.1 8B (the model.layers.19 hook point used in the example below) and achieves an L0 of 91, meaning that on average about 91 features are active per token. This sparsity enables the decomposition of complex neural activations into interpretable features. The model is optimized for interpretability tasks and model steering, giving researchers and developers insight into the model's internal processing and behavior patterns. As an open-source tool, it serves as a foundation for advancing interpretability research and for finer-grained control over large language model behavior.
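L0 is the mean number of nonzero entries in the encoded feature vector per token. With the dimensions used in the loading example under How to use (d_model=4096, expansion_factor=16), the SAE has 4096 × 16 = 65,536 features, so an L0 of 91 means roughly 91 / 65,536 ≈ 0.14% of features fire on a given token. The plain-Python sketch below shows how such an average is computed, using made-up feature values; it is an illustration, not Goodfire's evaluation code.

```python
def average_l0(feature_rows: list[list[float]]) -> float:
    """Mean number of nonzero SAE features per token.

    Each inner list is one token's encoded feature vector (the output of a
    ReLU encoder), so every entry is >= 0.
    """
    if not feature_rows:
        raise ValueError("need at least one token")
    active_counts = [sum(1 for value in row if value != 0.0) for row in feature_rows]
    return sum(active_counts) / len(active_counts)


# Three tokens, six hypothetical features each.
features = [
    [0.0, 1.3, 0.0, 0.0, 0.2, 0.0],  # 2 active
    [0.0, 0.0, 0.0, 0.7, 0.0, 0.0],  # 1 active
    [0.4, 0.0, 2.1, 0.0, 0.0, 0.9],  # 3 active
]
print(average_l0(features))  # 2.0
```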

Model Creator: Goodfire, built to work with Meta's Llama models

By using Goodfire/Llama-3.1-8B-Instruct__model.layers.19, you agree to the LLAMA 3.1 COMMUNITY LICENSE AGREEMENT.

Intended Use

By open-sourcing SAEs for leading open models, especially large-scale models like Llama 3.1 8B, we aim to accelerate progress in interpretability research.

Our initial work with these SAEs has revealed promising applications in model steering, strengthening safeguards against jailbreaks, and interpretable classification methods. We look forward to seeing how the research community builds on these foundations and uncovers new applications.

Feature labels

To explore the feature labels, check out the Goodfire Ember SDK, the first hosted mechanistic interpretability API. The SDK provides an intuitive interface for interacting with these features, letting you investigate how Llama processes information and even steer its behavior. The SDK documentation is available at docs.goodfire.ai.
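Feature labels are keyed by feature index, so a common first step is finding which features fire most strongly on a token. The plain-Python sketch below, with hypothetical activation values, ranks the active features of one token's encoded vector; it is independent of the Ember SDK. On the torch tensors returned by sae.encode, torch.topk(features, k, dim=-1) plays a similar role, although it does not filter out inactive (zero) entries.

```python
import heapq


def top_k_features(feature_vector: list[float], k: int = 5) -> list[tuple[int, float]]:
    """Return (feature_index, activation) pairs for the k strongest active features."""
    # Keep only features that actually fired (ReLU outputs are zero otherwise).
    active = [(index, value) for index, value in enumerate(feature_vector) if value > 0.0]
    # Largest activations first.
    return heapq.nlargest(k, active, key=lambda pair: pair[1])


# One token's hypothetical encoded features.
token_features = [0.0, 0.8, 0.0, 2.5, 0.1, 0.0, 1.7]
print(top_k_features(token_features, k=3))  # [(3, 2.5), (6, 1.7), (1, 0.8)]
```

The returned indices are what you would look up against a set of feature labels.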

How to use

```python
import torch
from typing import Callable, Optional

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


# Language model

# An intervention receives the hook point's hidden states and returns a tuple
# (new_hidden_states, direct_effect_tensor_or_None).
InterventionInterface = Callable[
    [InterventionProxy], tuple[InterventionProxy, Optional[InterventionProxy]]
]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # Optional hook point from whose output each intervention's direct-effect
        # tensor is subtracted. None (the default) disables this correction.
        self.cleanup_intervention_layer: Optional[str] = None

        self.safe_mode = False  # nnsight validation is disabled by default because it slows inference a lot. Turn on to debug.

    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer the hidden layer dimension from the model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # Apply any requested interventions
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    intervened_acts, direct_effect_tensor = interventions[
                        hook_site
                    ](module.output[0])
                    # Treat a missing direct-effect tensor as 0
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0
                    # We only modify module.output[0] (the hidden states)
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()

            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            # Keep only the hidden states (element 0 of each layer's output tuple)
            {k: v.value[0].detach() for k, v in cache.items()},
        )


# Reading out activations from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,  # end with the assistant header so the next token starts the reply
    return_tensors="pt",
)
logits, kv_cache, cached_activations = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

# Residual-stream activations after layer 19: (batch, seq_len, d_model)
print(cached_activations["model.layers.19"].shape)


# Loading the SAE and encoding the cached activations into features

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",
    d_model=4096,
    expansion_factor=16,
    device=torch.device("cuda"),  # must match the device the model runs on
)

with torch.no_grad():
    layer_19_features = sae.encode(cached_activations["model.layers.19"])
print(layer_19_features.shape)  # (batch, seq_len, 4096 * 16)


# Intervention example

PIRATE_FEATURE_INDEX = 0  # placeholder: replace with the index of the feature to steer
VALUE_TO_MODIFY = 0.1


def example_intervention(acts: InterventionProxy):
    features = sae.encode(acts).detach()
    reconstructed_acts = sae.decode(features).detach()
    # Reconstruction error: the part of the activation the SAE does not capture
    error = acts - reconstructed_acts

    # Add VALUE_TO_MODIFY to feature PIRATE_FEATURE_INDEX at every token position
    # (features has shape (batch, seq_len, d_hidden))
    features[:, :, PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    # Return (new activations, direct-effect tensor); None means no cleanup.
    return sae.decode(features) + error, None


logits, kv_cache, _ = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

# Most likely next token under the intervention
print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```
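Why adding the error term back matters: the SAE's decoder is affine (a single torch.nn.Linear), so decoding edited features f' = f + δ·e_i and adding back the reconstruction error x − decode(f) gives exactly x + δ·W_dec[:, i], where W_dec[:, i] is column i of sae.decoder_linear.weight. The original activation is shifted along one feature direction, and whatever the SAE fails to reconstruct is preserved. The plain-Python sketch below checks this identity on a hypothetical SAE with 2 input dimensions and 3 features (made-up weights, not the released ones); with bfloat16 tensors the identity holds only up to rounding.

```python
def affine(weight, bias, vector):
    """Compute weight @ vector + bias, with weight given as a list of rows."""
    return [sum(w * v for w, v in zip(row, vector)) + b for row, b in zip(weight, bias)]


# Hypothetical tiny SAE: d_in = 2, d_hidden = 3.
W_enc = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]   # shape (d_hidden, d_in)
b_enc = [0.0, 0.0, 0.0]
W_dec = [[0.5, 0.0, 0.2], [0.0, 0.5, -0.1]]     # shape (d_in, d_hidden)
b_dec = [0.1, -0.1]


def encode(x):
    # Linear followed by ReLU, mirroring SparseAutoEncoder.encode
    return [max(0.0, h) for h in affine(W_enc, b_enc, x)]


def decode(f):
    # Plain affine map, mirroring SparseAutoEncoder.decode
    return affine(W_dec, b_dec, f)


x = [2.0, 1.0]
f = encode(x)
error = [xi - ri for xi, ri in zip(x, decode(f))]

# Steer: add delta to one feature, decode, then add the error back.
feature_index, delta = 2, 0.5
steered_f = list(f)
steered_f[feature_index] += delta
steered_x = [di + ei for di, ei in zip(decode(steered_f), error)]

# Equivalent direct shift along decoder column `feature_index`.
expected = [xi + delta * row[feature_index] for xi, row in zip(x, W_dec)]
print(steered_x, expected)
```

Because the shift does not depend on f, the same edit also works for a feature that was inactive (zero) on that token.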

Training

We trained our SAE on activations harvested from Llama-3.1-8B-Instruct on the LMSYS-Chat-1M dataset.

Responsibility & Safety

Safety is at the core of everything we do at Goodfire. As a public benefit corporation, we’re dedicated to understanding AI models to enable safer, more reliable generative AI. You can read more about our comprehensive approach to safety and responsible development in our detailed safety overview.