# Source: recast-llama3.2-f4t4_old / modeling_recast_llama.py
# Uploaded by appledora ("Upload modeling_recast_llama.py with huggingface_hub"),
# commit 4098668 (verified).
# filename: recastmlp_llama_model.py
from .configuration_recast_llama import RECAST1B_llama
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
logger = logging.get_logger(__name__)
class MLPTemplateBank(nn.Module):
    """Bank of flattened MLP weight templates that are mixed by per-layer coefficients.

    Three parameter tables are held, one each for the up, gate and down
    projections.  Every table has ``num_templates`` rows and each row is one
    flattened ``intermediate_size * hidden_size`` weight matrix.
    """

    def __init__(self, config, num_templates):
        """Allocate and initialise the three template tables.

        Args:
            config: Object exposing ``hidden_size`` and ``intermediate_size``.
            num_templates: Number of templates (rows) stored per projection.
        """
        super().__init__()
        self.num_templates = num_templates
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        flat_width = self.intermediate_size * self.hidden_size
        flat_width_down = self.hidden_size * self.intermediate_size

        # Allocate and initialise in up -> gate -> down order so that both the
        # parameter registration order and the random draws stay the same.
        self.up_templates = nn.Parameter(torch.empty(num_templates, flat_width))
        nn.init.kaiming_normal_(self.up_templates)

        self.gate_templates = nn.Parameter(torch.empty(num_templates, flat_width))
        nn.init.kaiming_normal_(self.gate_templates)

        self.down_templates = nn.Parameter(
            torch.empty(num_templates, flat_width_down)
        )
        nn.init.kaiming_normal_(self.down_templates)

    def forward(self, up_coeffs, gate_coeffs, down_coeffs):
        """Combine the templates into concrete projection matrices.

        Each coefficient tensor is matrix-multiplied with its template table
        and the product is viewed as a single weight matrix, so the
        coefficient tensors must be 2-D with one row and ``num_templates``
        columns.

        Args:
            up_coeffs: Mixing coefficients for the up projection.
            gate_coeffs: Mixing coefficients for the gate projection.
            down_coeffs: Mixing coefficients for the down projection.

        Returns:
            Tuple ``(gate_weights, up_weights, down_weights)``; note the
            return order differs from the argument order.  Gate and up are
            ``(intermediate_size, hidden_size)``, down is
            ``(hidden_size, intermediate_size)``.
        """
        wide, narrow = self.intermediate_size, self.hidden_size

        gate_matrix = torch.mm(gate_coeffs, self.gate_templates).view(wide, narrow)
        up_matrix = torch.mm(up_coeffs, self.up_templates).view(wide, narrow)
        down_matrix = torch.mm(down_coeffs, self.down_templates).view(narrow, wide)

        return gate_matrix, up_matrix, down_matrix
class SharedLlamaMLP(nn.Module):
    """Gated MLP whose projection weights are produced by a shared template bank.

    The layer owns only three coefficient rows (and optional biases); the
    actual weight matrices are requested from ``bank`` on every forward pass.
    """

    def __init__(self, config, bank):
        """Create the per-layer coefficients and optional biases.

        Args:
            config: Object exposing ``hidden_size``, ``intermediate_size``,
                ``num_templates`` and ``mlp_bias``.
            bank: Callable module taking ``(up, gate, down)`` coefficients and
                returning ``(gate, up, down)`` weight matrices.
        """
        super().__init__()
        self.config = config
        self.bank = bank
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        def new_coefficients():
            # One row of mixing weights, one entry per template.
            return nn.Parameter(torch.zeros(1, config.num_templates))

        self.up_coefficients = new_coefficients()
        self.gate_coefficients = new_coefficients()
        self.down_coefficients = new_coefficients()

        # Draw the initial values in up -> gate -> down order.
        for coeffs in (
            self.up_coefficients,
            self.gate_coefficients,
            self.down_coefficients,
        ):
            nn.init.normal_(coeffs, mean=0.0, std=1.0)

        bias_widths = (
            ("gate_bias", self.intermediate_size),
            ("up_bias", self.intermediate_size),
            ("down_bias", self.hidden_size),
        )
        for bias_name, width in bias_widths:
            bias = nn.Parameter(torch.zeros(width)) if config.mlp_bias else None
            self.register_parameter(bias_name, bias)

        self.act_fn = F.silu

    def forward(self, x):
        """Apply ``down(act(gate(x)) * up(x))`` using bank-generated weights.

        Args:
            x: Input tensor whose last dimension matches the weights returned
                by the bank.

        Returns:
            The output of the down projection.
        """
        gate_w, up_w, down_w = self.bank(
            self.up_coefficients, self.gate_coefficients, self.down_coefficients
        )

        gated = self.act_fn(F.linear(x, gate_w, self.gate_bias))
        lifted = F.linear(x, up_w, self.up_bias)
        return F.linear(gated * lifted, down_w, self.down_bias)
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Cross-entropy with an optional externally supplied normaliser.

    Args:
        source: Logits passed to ``nn.functional.cross_entropy``.
        target: Target class indices.
        num_items_in_batch: When given, the loss is summed and divided by this
            value; when ``None`` the mean reduction is used.
        ignore_index: Target value excluded from the loss.
        **kwargs: Accepted and ignored.

    Returns:
        The scalar loss tensor.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(
            source, target, ignore_index=ignore_index, reduction="mean"
        )

    summed = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction="sum"
    )
    return summed / num_items_in_batch
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRotaryEmbedding,
LlamaRMSNorm,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
class RECAST1B_llamaModel(PreTrainedModel):
    """Llama decoder stack whose MLPs draw their weights from shared template banks.

    Every layer is a stock ``LlamaDecoderLayer`` whose ``mlp`` attribute is
    replaced by a ``SharedLlamaMLP``.  Layers are split into
    ``config.num_groups`` consecutive groups and all layers of a group share
    one ``MLPTemplateBank``.

    The banks are kept in a plain Python list (``self.banks``), not an
    ``nn.ModuleList``; each bank is registered as a submodule through the
    ``SharedLlamaMLP`` instances that hold it, so its parameters live under
    ``layers.<i>.mlp.bank.*`` in the state dict.  Do not turn ``self.banks``
    into a registered container without migrating existing checkpoints.
    """

    config_class = RECAST1B_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Build embeddings, rotary embedding, template banks and decoder layers.

        Args:
            config: Model configuration.  Besides the attributes required by
                ``LlamaDecoderLayer`` it must expose ``num_groups`` and
                ``num_templates``.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        # NOTE(review): the rotary embedding is built from the upstream
        # Llama-3.2-1b config fetched from the hub, not from `config`.  This
        # needs hub access at construction time and ignores any rope settings
        # on `config`.  Left unchanged because the project config is not
        # visible here - confirm whether `config` can be used directly.
        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.2-1b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create all template banks first (one per layer group).
        self.banks = []
        for _ in range(config.num_groups):
            self.banks.append(MLPTemplateBank(config, config.num_templates))

        # At least one layer per group, so the division below is never by zero
        # when there are more groups than layers.
        layers_per_group = max(1, config.num_hidden_layers // config.num_groups)
        last_bank = len(self.banks) - 1

        # Create standard decoder layers and swap in the shared MLP.
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            decoder_layer = LlamaDecoderLayer(config, layer_idx)
            # Clamp so that leftover layers (layer count not divisible by the
            # group count) reuse the last bank instead of indexing past it.
            group_idx = min(layer_idx // layers_per_group, last_bank)
            decoder_layer.mlp = SharedLlamaMLP(config, bank=self.banks[group_idx])
            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        ``output_attentions``, ``output_hidden_states``, ``use_cache`` and
        ``return_dict`` default to the corresponding config values when left
        as ``None``.

        Returns:
            A ``BaseModelOutputWithPast``, or when ``return_dict`` is false a
            tuple of the non-``None`` values among
            ``(last_hidden_state, cache, all_hidden_states, all_attentions)``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are supplied.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Set up cache position if not provided: positions continue after
        # whatever is already stored in the cache.
        if cache_position is None:
            if past_key_values is None:
                past_seen_tokens = 0
            elif isinstance(past_key_values, Cache):
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                past_seen_tokens = (
                    past_key_values[0][0].size(-2) if past_key_values else 0
                )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # Set up position IDs if not provided.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds
        # Position embeddings are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        # Index of the cache entry in a layer's output tuple, when present.
        cache_slot = 2 if output_attentions else 1

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional here, so they must follow the
                # decoder layer's parameter order: `cache_position` comes
                # before `position_embeddings`.  Omitting it would bind the
                # position embeddings to the `cache_position` parameter.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                if len(layer_outputs) > cache_slot:
                    next_decoder_cache = layer_outputs[cache_slot]
                else:
                    # Guard for decoder-layer implementations that update the
                    # cache object in place and do not return it.
                    next_decoder_cache = past_key_values

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local ``.pt`` training checkpoint or defer to the parent loader.

        A string path ending in ``.pt`` is treated as a ``torch.save``
        checkpoint containing a ``"model_state_dict"`` entry.  The model
        config is taken from ``kwargs["config"]`` or, failing that, looked up
        in the directory that contains the checkpoint.  Anything else is
        forwarded unchanged to ``PreTrainedModel.from_pretrained``.
        """
        is_local_checkpoint = isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt")

        if not is_local_checkpoint:
            logger.info("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        logger.info("Loading from local checkpoint")
        config = kwargs.get("config", None)
        if config is None:
            import os

            # The checkpoint itself is a binary file, not a config, so look
            # for the config next to it.
            checkpoint_dir = os.path.dirname(
                os.path.abspath(pretrained_model_name_or_path)
            )
            config = AutoConfig.from_pretrained(
                checkpoint_dir, trust_remote_code=True
            )

        model = cls(config)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
        state_dict = checkpoint["model_state_dict"]
        logger.info(
            f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
        )

        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )
        if missing_keys:
            logger.warning(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys: {unexpected_keys}")
        return model

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the attention mask handed to the decoder layers.

        Returns ``None`` when the attention backend can handle causality by
        itself, the unchanged 2-D mask for flash attention when it contains
        padding, and otherwise a 4-D additive causal mask.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """Create a ``(batch_size, 1, sequence_length, target_length)`` additive causal mask.

        A 4-D ``attention_mask`` is returned unchanged.  Otherwise masked
        positions hold the minimum value of ``dtype`` and visible positions
        hold zero; positions marked 0 in a 2-D ``attention_mask`` are masked
        as well.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the mask only for key positions beyond each query's cache position.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Zero in the sum means: causally visible (0) and padding (0).
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask
class RECAST1B_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-modelling head on top of ``RECAST1B_llamaModel``."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECAST1B_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Create the decoder stack and the (bias-free) output projection."""
        super().__init__(config)
        self.model = RECAST1B_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder stack."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder stack."""
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: int = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        """Next-token cross-entropy loss.

        Logits are upcast to float, logits and labels are shifted by one
        position so that position ``t`` predicts token ``t + 1``, and the
        flattened pairs are passed to ``fixed_cross_entropy``.

        Args:
            logits: Model logits; the last dimension must be ``vocab_size``.
            labels: Target token ids aligned with ``logits`` before shifting.
            vocab_size: Size of the last logits dimension.
            num_items_in_batch: Optional normaliser; ``None`` selects the mean
                over non-ignored targets.
            ignore_index: Label value excluded from the loss.
        """
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
            kwargs:
                Forwarded to the decoder and to the loss function. An optional
                `num_items_in_batch` entry is consumed here and used as the loss normaliser;
                without it the loss is the mean over non-ignored target tokens.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Take the normaliser out of kwargs so it is passed to the loss exactly
        # once; leaving it in would collide with the explicit keyword below.
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]

        # Only compute necessary logits (`-0:` keeps every position).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # The normaliser counts target tokens, not sequences: dividing the
            # summed token loss by the batch size would scale the loss with
            # the sequence length.  With no caller-supplied count the loss
            # falls back to the mean over non-ignored targets.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        With a non-empty cache only the last token (and its position id) is
        fed to the model.  Position ids are derived from ``attention_mask``
        when not supplied.  ``inputs_embeds`` are used only while there is no
        cached state, i.e. on the first step.
        """
        has_past = bool(past_key_values)

        if has_past:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_past:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step.
        # An empty cache object counts as "no past" too, not only `None`.
        if inputs_embeds is not None and not has_past:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a local ``.pt`` training checkpoint or defer to the parent loader.

        A string path ending in ``.pt`` is treated as a ``torch.save``
        checkpoint containing a ``"model_state_dict"`` entry.  The model
        config is taken from ``kwargs["config"]`` or, failing that, looked up
        in the directory that contains the checkpoint.  Anything else is
        forwarded unchanged to ``PreTrainedModel.from_pretrained``.
        """
        is_local_checkpoint = isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt")

        if not is_local_checkpoint:
            logger.info("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        logger.info("Loading from local checkpoint")
        config = kwargs.get("config", None)
        if config is None:
            import os

            # The checkpoint itself is a binary file, not a config, so look
            # for the config next to it.
            checkpoint_dir = os.path.dirname(
                os.path.abspath(pretrained_model_name_or_path)
            )
            config = AutoConfig.from_pretrained(
                checkpoint_dir, trust_remote_code=True
            )

        model = cls(config)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
        state_dict = checkpoint["model_state_dict"]

        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )
        if missing_keys:
            logger.warning(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys: {unexpected_keys}")
        return model