# %%
import os
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

# Registry of model factories keyed by Hugging Face model id; calling a value
# builds one of the wrapper modules defined below.
MODEL_DICT = {}
# Number of transformer blocks of each registered model, keyed like MODEL_DICT.
LAYER_DICT = {}


def _record_activations(model: nn.Module, tokenizer, layers, text: str) -> dict:
    """Run ``model`` on ``text`` and gather what the patched blocks recorded.

    Args:
        model: the underlying transformers model; its blocks must already be
            patched so they store ``attn_output``, ``mlp_output`` and
            ``block_output`` on themselves during the forward pass.
        tokenizer: tokenizer matching ``model``; called with
            ``return_tensors='pt'`` and used to decode each token.
        layers: the model's block modules, in layer order.
        text: the input string.

    Returns:
        dict with ``"attn"``, ``"mlp"`` and ``"block"`` (one tensor per layer,
        in layer order) and ``"token_texts"`` (decoded text of each token of
        the first — for a single string presumably the only — sequence).
    """
    encoded_input = tokenizer(text, return_tensors='pt')
    # Inputs must live on the same device as the weights.
    device = next(model.parameters()).device
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # The return value is not needed: the patched block forwards stash their
    # intermediate tensors on each layer module as a side effect.
    model(**encoded_input, output_hidden_states=True)

    token_ids = encoded_input['input_ids'][0].tolist()
    return {
        "attn": [blk.attn_output for blk in layers],
        "mlp": [blk.mlp_output for blk in layers],
        "block": [blk.block_output for blk in layers],
        "token_texts": [tokenizer.decode([token_id]) for token_id in token_ids],
    }


def _llama_decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value=None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Replacement ``forward`` for the Llama decoder-layer class.

    Computes ``input_layernorm -> self_attn -> residual add ->
    post_attention_layernorm -> mlp -> residual add`` and stores clones of the
    intermediates on the layer module:

    * ``self.attn_output``  - self-attention output, before the residual add
    * ``self.mlp_output``   - MLP output, before the residual add
    * ``self.block_output`` - the layer's final hidden states

    Returns ``(hidden_states,)``, followed by the attention weights when
    ``output_attentions`` and by the cache entry when ``use_cache``.

    NOTE(review): expects ``self_attn`` to return three values
    (output, weights, cache); this matches the transformers release this was
    written against (see the v4.45 note above) — newer releases may differ.
    """
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        **kwargs,
    )
    self.attn_output = hidden_states.clone()
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    self.mlp_output = hidden_states.clone()
    hidden_states = residual + hidden_states
    self.block_output = hidden_states.clone()

    outputs = (hidden_states,)
    if output_attentions:
        outputs += (self_attn_weights,)
    if use_cache:
        outputs += (present_key_value,)
    return outputs


class Llama(nn.Module):
    """Llama causal-LM wrapper that exposes per-layer intermediate activations.

    Loads ``model_id`` through ``transformers.pipeline`` on the CPU in
    bfloat16 and patches the decoder-layer class with
    ``_llama_decoder_layer_forward`` so every layer records its attention,
    MLP and block outputs during a forward pass.

    NOTE: the patch is applied to the decoder-layer *class*, so it affects
    every instance of that class in this process, not only this model.

    Raises:
        ValueError: if the ``HF_ACCESS_TOKEN`` environment variable is unset.
    """

    def __init__(self, model_id="meta-llama/Meta-Llama-3.1-8B"):
        super().__init__()
        # Imported lazily: transformers is only needed once a Llama is built.
        import transformers

        access_token = os.getenv("HF_ACCESS_TOKEN")
        if access_token is None:
            raise ValueError("HF_ACCESS_TOKEN environment variable must be set")
        pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            token=access_token,
            device='cpu',
        )
        tokenizer = pipeline.tokenizer
        model = pipeline.model

        # Patching the class (not each instance) covers every decoder layer,
        # assuming they all share layers[0]'s class; forward() reads the
        # recorded attributes from every layer, so it relies on this.
        # NOTE(review): overriding __call__ too bypasses nn.Module.__call__,
        # so hooks registered on the decoder layers will not fire.
        layer_cls = model.model.layers[0].__class__
        setattr(layer_cls, "forward", _llama_decoder_layer_forward)
        setattr(layer_cls, "__call__", _llama_decoder_layer_forward)

        self.model = model
        self.tokenizer = tokenizer

    @torch.no_grad()
    def forward(self, text: str):
        """Run ``text`` through the model and return per-layer activations.

        Returns a dict with ``"attn"``, ``"mlp"``, ``"block"`` (one tensor per
        decoder layer, in layer order) and ``"token_texts"`` (decoded text of
        each input token).
        """
        return _record_activations(self.model, self.tokenizer, self.model.model.layers, text)


MODEL_DICT["meta-llama/Meta-Llama-3.1-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3.1-8B")
LAYER_DICT["meta-llama/Meta-Llama-3.1-8B"] = 32
MODEL_DICT["meta-llama/Meta-Llama-3-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3-8B")
LAYER_DICT["meta-llama/Meta-Llama-3-8B"] = 32


def _gpt2_block_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    """Replacement ``forward`` for the GPT-2 block class.

    Computes ``ln_1 -> attn -> residual add``, an optional cross-attention
    sub-block, then ``ln_2 -> mlp -> residual add``, and stores clones of the
    intermediates on the block module:

    * ``self.attn_output``  - self-attention output, before the residual add
      (the cross-attention output is not recorded)
    * ``self.mlp_output``   - MLP output, before the residual add
    * ``self.block_output`` - the block's final hidden states

    Raises:
        ValueError: if ``encoder_hidden_states`` is given but the block has no
            ``crossattention`` sub-module.
    """
    residual = hidden_states
    hidden_states = self.ln_1(hidden_states)
    attn_outputs = self.attn(
        hidden_states,
        layer_past=layer_past,
        attention_mask=attention_mask,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
    )
    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
    outputs = attn_outputs[1:]
    # residual connection
    self.attn_output = attn_output.clone()
    hidden_states = attn_output + residual

    if encoder_hidden_states is not None:
        # add one self-attention block for cross-attention
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                "cross-attention layers by setting `config.add_cross_attention=True`"
            )
        residual = hidden_states
        hidden_states = self.ln_cross_attn(hidden_states)
        cross_attn_outputs = self.crossattention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attn_output = cross_attn_outputs[0]
        # residual connection
        hidden_states = residual + attn_output
        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

    residual = hidden_states
    hidden_states = self.ln_2(hidden_states)
    feed_forward_hidden_states = self.mlp(hidden_states)
    # residual connection
    self.mlp_output = feed_forward_hidden_states.clone()
    hidden_states = residual + feed_forward_hidden_states

    if use_cache:
        outputs = (hidden_states,) + outputs
    else:
        outputs = (hidden_states,) + outputs[1:]

    self.block_output = hidden_states.clone()
    return outputs  # hidden_states, present, (attentions, cross_attentions)


class GPT2(nn.Module):
    """GPT-2 wrapper that exposes per-block intermediate activations.

    Loads the ``gpt2`` tokenizer and base model and patches the block class
    with ``_gpt2_block_forward`` so every block records its attention, MLP
    and block outputs during a forward pass.

    NOTE: the patch is applied to the block *class*, so it affects every
    instance of that class in this process, not only this model.
    """

    def __init__(self):
        super().__init__()
        # Imported lazily: transformers is only needed once a GPT2 is built.
        from transformers import GPT2Tokenizer, GPT2Model

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')

        # Class-level patch covers every block, assuming they all share
        # h[0]'s class; forward() reads the recorded attributes from each.
        setattr(model.h[0].__class__, "forward", _gpt2_block_forward)

        self.model = model
        self.tokenizer = tokenizer

    @torch.no_grad()
    def forward(self, text: str):
        """Run ``text`` through the model and return per-block activations.

        Returns a dict with ``"attn"``, ``"mlp"``, ``"block"`` (one tensor per
        block, in block order) and ``"token_texts"`` (decoded text of each
        input token).
        """
        return _record_activations(self.model, self.tokenizer, self.model.h, text)


MODEL_DICT["gpt2"] = GPT2
LAYER_DICT["gpt2"] = 12


def download_all_models():
    """Instantiate every registered model once, presumably to populate the
    Hugging Face cache.

    Best effort: a model that fails to load is reported and skipped. Each
    instance is discarded right away, so only one model is resident at a time.
    """
    for model_name, factory in MODEL_DICT.items():
        print(f"Downloading {model_name}")
        try:
            factory()
        except Exception as e:
            print(f"Error downloading {model_name}: {e}")


if __name__ == '__main__':
    # Swap in MODEL_DICT["gpt2"]() to try GPT-2 (12 layers, no HF_ACCESS_TOKEN needed).
    model = MODEL_DICT["meta-llama/Meta-Llama-3-8B"]()
    text = """
1. The majestic giraffe, with its towering height and distinctive long neck, roams the savannas of Africa. These gentle giants use their elongated tongues to pluck leaves from the tallest trees, making them well-adapted to their environment. Their unique coat patterns, much like human fingerprints, are unique to each individual.
"""
    if torch.cuda.is_available():
        model = model.cuda()
    output = model(text)
    print(output["block"][1].shape)
    print(output["token_texts"])