---
license: llama3.1
language:
- en
tags:
- mechanistic interpretability
- sparse autoencoder
- llama
- llama-3
---

## Model Information

The Goodfire SAE (Sparse Autoencoder) for [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) is an interpreter model designed to analyze and understand the model's internal representations. This SAE is trained specifically on layer 19 of Llama 3.1 8B Instruct and achieves an L0 count of 91 (an average of 91 active features per token), enabling the decomposition of complex neural activations into interpretable features. The model is optimized for interpretability tasks and model steering applications, allowing researchers and developers to gain insights into the model's internal processing and behavior patterns. As an open-source tool, it serves as a foundation for advancing interpretability research and enhancing control over large language model operations.

__Model Creator__: [Goodfire](https://huggingface.co/Goodfire), built to work with [Meta's Llama models](https://huggingface.co/meta-llama)

By using __`Goodfire/Llama-3.1-8B-Instruct__model.layers.19`__ you agree to the [LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct/blob/main/LICENSE).

## Intended Use

By open-sourcing SAEs for leading open models such as Llama 3.1 8B, we aim to accelerate progress in interpretability research. Our initial work with these SAEs has revealed promising applications in model steering, strengthening safeguards against jailbreaks, and interpretable classification methods. We look forward to seeing how the research community builds upon these foundations and uncovers new applications.

#### Feature labels

To explore the feature labels, check out the [Goodfire Ember SDK](https://www.goodfire.ai/blog/announcing-goodfire-ember/), the first hosted mechanistic interpretability API.
The SDK provides an intuitive interface for interacting with these features, allowing you to investigate how Llama processes information and even steer its behavior. You can explore the SDK documentation at [docs.goodfire.ai](https://docs.goodfire.ai).

## How to use

```python
import torch
from typing import Optional, Callable

import nnsight
from nnsight.intervention import InterventionProxy


# Autoencoder


class SparseAutoEncoder(torch.nn.Module):
    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.device = device
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.dtype = dtype
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass.
        Returns the reconstruction and the encoded features."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
):
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
    )
    sae_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(sae_dict)

    return sae


# Language model


# An intervention receives the hooked activations and returns a tuple of
# (modified activations, optional direct-effect tensor or None).
InterventionInterface = Callable[
    [InterventionProxy], tuple[InterventionProxy, Optional[InterventionProxy]]
]


class ObservableLanguageModel:
    def __init__(
        self,
        model: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device
        self._original_model = model

        self._model = nnsight.LanguageModel(
            self._original_model,
            device_map=device,
            torch_dtype=getattr(torch, dtype) if isinstance(dtype, str) else dtype,
        )

        self.tokenizer = self._model.tokenizer

        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

        # Optional module path whose output has each intervention's
        # direct-effect tensor subtracted in forward(); None disables this.
        self.cleanup_intervention_layer: Optional[str] = None

        # nnsight validation is disabled by default because it slows down
        # inference a lot. Turn it on to debug.
        self.safe_mode = False
    def _attempt_to_infer_hidden_layer_dimensions(self):
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise Exception(
            "Could not infer the hidden layer dimension (hidden_size) from the model config"
        )

    def _find_module(self, hook_point: str):
        submodules = hook_point.split(".")
        module = self._model
        while submodules:
            module = getattr(module, submodules.pop(0))
        return module

    def forward(
        self,
        inputs: torch.Tensor,
        cache_activations_at: Optional[list[str]] = None,
        interventions: Optional[dict[str, InterventionInterface]] = None,
        use_cache: bool = True,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor], dict[str, torch.Tensor]]:
        cache: dict[str, torch.Tensor] = {}
        with self._model.trace(
            inputs,
            scan=self.safe_mode,
            validate=self.safe_mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
        ):
            # Apply interventions, if any were passed in
            if interventions:
                for hook_site in interventions.keys():
                    if interventions[hook_site] is None:
                        continue

                    module = self._find_module(hook_site)

                    if self.cleanup_intervention_layer:
                        last_layer = self._find_module(
                            self.cleanup_intervention_layer
                        )
                    else:
                        last_layer = None

                    intervened_acts, direct_effect_tensor = interventions[
                        hook_site
                    ](module.output[0])
                    # Treat a missing direct-effect tensor as 0
                    if direct_effect_tensor is None:
                        direct_effect_tensor = 0

                    # We only modify module.output[0]
                    if use_cache:
                        module.output = (
                            intervened_acts,
                            module.output[1],
                        )
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                                last_layer.output[1],
                            )
                    else:
                        module.output = (intervened_acts,)
                        if last_layer is not None:
                            last_layer.output = (
                                last_layer.output[0] - direct_effect_tensor,
                            )

            if cache_activations_at is not None:
                for hook_point in cache_activations_at:
                    module = self._find_module(hook_point)
                    cache[hook_point] = module.output.save()

            if not past_key_values:
                logits = self._model.output[0][:, -1, :].save()
            else:
                logits = self._model.output[0].squeeze(1).save()
            kv_cache = self._model.output.past_key_values.save()

        return (
            logits.value.detach(),
            kv_cache.value,
            {k: v.value[0].detach() for k, v in cache.items()},
        )


# Reading out activations from the model

llama_3_1_8b = ObservableLanguageModel(
    "meta-llama/Llama-3.1-8B-Instruct",
)

input_tokens = llama_3_1_8b.tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
)
logits, kv_cache, activations = llama_3_1_8b.forward(
    input_tokens,
    cache_activations_at=["model.layers.19"],
)

# Hidden states at the output of layer 19, typically (batch, seq_len, d_model).
# Pass them through sae.encode (below) to get SAE feature activations.
print(activations["model.layers.19"].shape)

# Intervention example

sae = load_sae(
    path="./llama-3-8b-d-hidden.pth",  # path to the downloaded SAE weights
    d_model=4096,
    expansion_factor=16,
    device=torch.device("cuda"),  # same device as the language model
)

PIRATE_FEATURE_INDEX = 0
VALUE_TO_MODIFY = 0.1


def example_intervention(activations: InterventionProxy):
    features = sae.encode(activations).detach()
    reconstructed_acts = sae.decode(features).detach()
    error = activations - reconstructed_acts

    # Modify the chosen feature (last dimension) across all token positions
    features[..., PIRATE_FEATURE_INDEX] += VALUE_TO_MODIFY

    # Very important to add the error term back in!
    # forward() expects (intervened_acts, direct_effect_tensor); None means no direct effect.
    return sae.decode(features) + error, None


logits, kv_cache, activations = llama_3_1_8b.forward(
    input_tokens,
    interventions={"model.layers.19": example_intervention},
)

print(llama_3_1_8b.tokenizer.decode(logits[-1].argmax(-1)))
```

## Training

We trained our SAE on activations harvested from Llama-3.1-8B-Instruct on the [LMSYS-Chat-1M dataset](https://arxiv.org/pdf/2309.11998).

## Responsibility & Safety

Safety is at the core of everything we do at Goodfire. As a public benefit corporation, we’re dedicated to understanding AI models to enable safer, more reliable generative AI. You can read more about our comprehensive approach to safety and responsible development in our detailed [safety overview](https://www.goodfire.ai/blog/our-approach-to-safety/).
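The intervention example under How to use relies on two properties of the SAE. First, adding the reconstruction error back makes an unedited pass reproduce the original activations. Second, because the decoder is affine, raising one feature by a fixed value shifts the activations at every position by that value times the feature's decoder column. The sketch below checks both on CPU with a randomly initialised toy SAE that has the same linear-plus-ReLU encoder and linear decoder as `SparseAutoEncoder`; the sizes, seed, and feature index are hypothetical stand-ins, not the released layer 19 weights. It also computes L0 as the mean number of active features per token.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes for a CPU-only demo. The released SAE uses
# d_model=4096 and expansion_factor=16, loaded from trained weights.
d_in, d_hidden = 8, 32

# Same structure as SparseAutoEncoder: linear + ReLU encoder, linear decoder.
encoder = torch.nn.Linear(d_in, d_hidden)
decoder = torch.nn.Linear(d_hidden, d_in)


def encode(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(encoder(x))


def decode(f: torch.Tensor) -> torch.Tensor:
    return decoder(f)


# Stand-in for hooked activations with shape (batch, seq_len, d_in).
acts = torch.randn(1, 5, d_in)

with torch.no_grad():
    features = encode(acts)
    error = acts - decode(features)

    # L0: count of non-zero features per token, averaged over all tokens.
    l0 = (features > 0).float().sum(dim=-1).mean()

    # Without any edit, decode(features) + error gives back the input
    # (up to floating-point rounding).
    unchanged = decode(features) + error

    # Steer one feature at every token position; the feature axis is the last one.
    feature_index, value = 3, 0.1
    steered_features = features.clone()
    steered_features[..., feature_index] += value
    steered = decode(steered_features) + error

    # Because decode is affine, the shift equals value * decoder column.
    delta = steered - acts
    expected = value * decoder.weight[:, feature_index]

print(f"L0 = {l0.item():.1f} of {d_hidden} features")
print("max |delta - expected| =", (delta - expected).abs().max().item())
```

The same `l0` expression can be applied to `sae.encode(...)` outputs on your own prompts; the value will vary with the inputs.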