from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import math
import warnings
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from torchmetrics.classification import MulticlassAccuracy


# sinusoidal positional encoding
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer."""

    def __init__(
        self,
        num_embeddings=1028,
        embedding_dim=1024,
        num_quantization_layers=8,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )

        # initialize embeddings
        for i in range(num_quantization_layers):
            self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H]"""
        num_quant, B, T = input_ids.shape
        summed_embeddings = torch.zeros(
            B, T, self.embeddings[0].embedding_dim, device=input_ids.device
        )
        for i in range(num_quant):
            summed_embeddings += self.embeddings[i](input_ids[i])
        return summed_embeddings


class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class LlamaNAR(LlamaModel):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        # placeholder default config (vocab_size=0, hidden_size=256, intermediate_size=1024,
        # 1 layer, 1 head); only used to initialize the LlamaModel base class
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size)
        self.multi_embedding = MultiEmbedding(
            num_quantization_layers=8, embedding_dim=hidden_size
        )
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length, num_quant = input_ids.shape
        input_ids = input_ids.permute(2, 0, 1)  # [num_quant, B, T]
        inputs_embeds = self.multi_embedding(input_ids)
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


class LlamaNAREmb(LlamaModel):
    """LlamaNAR model that works directly with embeddings input.

    This variant of LlamaNAR takes pre-computed embeddings as input instead of
    token IDs that need to be embedded.
    """

    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        # placeholder default config (vocab_size=0, hidden_size=256, intermediate_size=1024,
        # 1 layer, 1 head); only used to initialize the LlamaModel base class
        config=LlamaConfig(0, 256, 1024, 1, 1),
    ):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    config=LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LlamaAdaptiveRMSNorm(hidden_size)
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
""" bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = ( mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) ) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ).to(inputs_embeds.device) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, )-> torch.Tensor: """ Returns: hidden_states: Tensor of shape (batch_size, sequence_length, hidden_size) """ if inputs_embeds is None: raise ValueError("inputs_embeds must be provided for LlamaNAREmb") if input_ids is not None: warnings.warn("input_ids is ignored in LlamaNAREmb, use inputs_embeds instead") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length, hidden_size = inputs_embeds.shape seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = ( past_key_values[idx] if past_key_values is not None else None ) if self.gradient_checkpointing and self.training: raise NotImplementedError def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( 
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return hidden_states


if __name__ == "__main__":
    config = LlamaConfig(hidden_size=1024, num_attention_heads=8, num_hidden_layers=8)
    model = LlamaNAR(config=config)

    # Build dummy inputs
    batch_size = 2
    seq_length = 10
    n_q = 8
    input_ids = torch.randint(0, 1028, (batch_size, seq_length, n_q))  # random codebook indices
    inputs_embeds = torch.randn(batch_size, seq_length, config.hidden_size)  # random input embeddings
    attention_mask = torch.ones(batch_size, seq_length)  # all positions visible
    length = torch.tensor([4, 10])  # valid length of each sample

    # Forward pass: LlamaNAR returns the final hidden states
    hidden_states = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
        length=length,
    )

    # Print output shape
    print("Hidden States Shape:", hidden_states.shape)
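
    # A minimal smoke test for LlamaNAREmb as well (not part of the original demo;
    # the layer count and names below are illustrative assumptions). It reuses the
    # pre-computed `inputs_embeds` above, since this variant skips token embedding
    # and should return hidden states with the same (batch, seq, hidden) shape.
    emb_model = LlamaNAREmb(
        hidden_size=config.hidden_size, num_heads=8, num_layers=2, config=config
    )
    emb_hidden_states = emb_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
    )
    print("LlamaNAREmb Hidden States Shape:", emb_hidden_states.shape)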