ncut-pytorch / backbone_text.py
huzey's picture
added text models
7311cdd
raw
history blame
9.66 kB
# %%
#
from typing import List, Union
import torch
import os
from torch import nn
from typing import Optional, Tuple
from functools import partial
MODEL_DICT = {}
LAYER_DICT = {}
class Llama(nn.Module):
def __init__(self, model_id="meta-llama/Meta-Llama-3.1-8B"):
super().__init__()
import transformers
access_token = os.getenv("HF_ACCESS_TOKEN")
if access_token is None:
raise ValueError("HF_ACCESS_TOKEN environment variable must be set")
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
token=access_token,
device='cpu',
)
tokenizer = pipeline.tokenizer
model = pipeline.model
def new_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
self.attn_output = hidden_states.clone()
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
self.mlp_output = hidden_states.clone()
hidden_states = residual + hidden_states
self.block_output = hidden_states.clone()
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
# for layer in model.model.layers:
# setattr(layer.__class__, "forward", new_forward)
# setattr(layer.__class__, "__call__", new_forward)
setattr(model.model.layers[0].__class__, "forward", new_forward)
setattr(model.model.layers[0].__class__, "__call__", new_forward)
self.model = model
self.tokenizer = tokenizer
@torch.no_grad()
def forward(self, text: str):
encoded_input = self.tokenizer(text, return_tensors='pt')
device = next(self.model.parameters()).device
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
output = self.model(**encoded_input, output_hidden_states=True)
attn_outputs, mlp_outputs, block_outputs = [], [], []
for i, blk in enumerate(self.model.model.layers):
attn_outputs.append(blk.attn_output)
mlp_outputs.append(blk.mlp_output)
block_outputs.append(blk.block_output)
token_ids = encoded_input['input_ids']
token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]
return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}
MODEL_DICT["meta-llama/Meta-Llama-3.1-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3.1-8B")
LAYER_DICT["meta-llama/Meta-Llama-3.1-8B"] = 32
MODEL_DICT["meta-llama/Meta-Llama-3-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3-8B")
LAYER_DICT["meta-llama/Meta-Llama-3-8B"] = 32
class GPT2(nn.Module):
def __init__(self):
super().__init__()
from transformers import GPT2Tokenizer, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
def new_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
self.attn_output = attn_output.clone()
hidden_states = attn_output + residual
if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
self.mlp_output = feed_forward_hidden_states.clone()
hidden_states = residual + feed_forward_hidden_states
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
self.block_output = hidden_states.clone()
return outputs # hidden_states, present, (attentions, cross_attentions)
setattr(model.h[0].__class__, "forward", new_forward)
self.model = model
self.tokenizer = tokenizer
@torch.no_grad()
def forward(self, text: str):
encoded_input = self.tokenizer(text, return_tensors='pt')
device = next(self.model.parameters()).device
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
output = self.model(**encoded_input, output_hidden_states=True)
attn_outputs, mlp_outputs, block_outputs = [], [], []
for i, blk in enumerate(self.model.h):
attn_outputs.append(blk.attn_output)
mlp_outputs.append(blk.mlp_output)
block_outputs.append(blk.block_output)
token_ids = encoded_input['input_ids']
token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]
return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}
MODEL_DICT["gpt2"] = GPT2
LAYER_DICT["gpt2"] = 12
def download_all_models():
for model_name in MODEL_DICT:
print(f"Downloading {model_name}")
try:
model = MODEL_DICT[model_name]()
except Exception as e:
print(f"Error downloading {model_name}: {e}")
continue
if __name__ == '__main__':
model = MODEL_DICT["meta-llama/Meta-Llama-3-8B"]()
# model = MODEL_DICT["gpt2"]()
text = """
1. The majestic giraffe, with its towering height and distinctive long neck, roams the savannas of Africa. These gentle giants use their elongated tongues to pluck leaves from the tallest trees, making them well-adapted to their environment. Their unique coat patterns, much like human fingerprints, are unique to each individual.
"""
model = model.cuda()
output = model(text)
print(output["block"][1].shape)
print(output["token_texts"])