# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch InternLM2 model."""
import math
import queue
import threading
import warnings
from typing import List, Optional, Tuple, Union
from functools import partial

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from timm.models.layers import DropPath

compute_ARank = False  # [ARank] Set this to True to compute attention rank

try:
    from transformers.generation.streamers import BaseStreamer
except Exception:  # best-effort optional import; do not swallow KeyboardInterrupt/SystemExit
    BaseStreamer = None

from .configuration_holistic_embedding import HolisticEmbeddingConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "HolisticEmbeddingConfig"

# Keys looked up in `config.special_token_maps` by `HolisticEmbedding.get_ignore_mask`.
# NOTE(review): these names were referenced but never defined in this file; the values
# below are assumed to match the tokenizer's image start/end delimiters -- TODO confirm.
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'

flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None


def _import_flash_attn():
    """Bind the flash-attn entry points to the module-level placeholders.

    Raises:
        ImportError: if the `flash_attn` package cannot be imported.
    """
    global flash_attn_func, flash_attn_varlen_func
    global pad_input, index_first_axis, unpad_input
    try:
        from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
        from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
    except ImportError as err:
        raise ImportError("flash_attn is not installed.") from err


_import_flash_attn()


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    """Derive the bookkeeping needed to strip padding from a batch.

    Args:
        attention_mask: 2D mask whose non-zero entries mark the tokens to keep.

    Returns:
        A tuple `(indices, cu_seqlens, max_seqlen_in_batch)`: flat indices of the kept
        tokens, cumulative per-row sequence lengths with a leading 0 (int32), and the
        longest per-row length as a Python int.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Build an additive causal mask of shape `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`.

    Positions a query may attend to hold 0; future positions hold `torch.finfo(dtype).min`.
    Columns belonging to the cached past are always 0.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    # Zero the lower triangle (including the diagonal): column j <= row i.
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    The result is additive: entries that were 1 become 0 and entries that were 0 become
    `torch.finfo(dtype).min`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
class InternLM2RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        InternLM2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise over the last dimension in float32, then cast back and scale."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
class InternLM2RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily grown cos/sin cache."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for positions `[0, seq_len)`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return `(cos, sin)` tables for the first `seq_len` positions, cast to `x.dtype`.

        Args:
            x: tensor of shape `[bs, num_attention_heads, seq_len, head_size]`; only its
                dtype, device and (when `seq_len` is None) sequence dimension are used.
            seq_len: number of positions required. Defaults to `x.shape[-2]`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash on the comparison below.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
    """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the cache via _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin cache with positions divided by `scaling_factor`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
    """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must be set before super().__init__, which builds the cache via _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin cache, enlarging the base once `seq_len` exceeds the trained length."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers.model.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    `cos` / `sin` are gathered at `position_ids` and given an extra axis at
    `unsqueeze_dim` so they broadcast against `q` and `k`.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class InternLM2MLP(nn.Module):
    """Gated feed-forward block: `w2(act(w1(x)) * w3(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))

        return down_proj


# Copied from transformers.model.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Modified from transformers.model.llama.modeling_llama.LlamaAttention
class InternLM2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: HolisticEmbeddingConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Fused projection: per KV head, `num_key_value_groups` query slices followed by one key and one value.
        self.wqkv = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.attention_bias,
        )

        self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling` and return it.

        Raises:
            ValueError: if the scaling type is neither 'dynamic' nor 'linear'.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = InternLM2RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.config.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "dynamic":
                self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "linear":
                self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
        return self.rotary_emb

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `[bsz, seq_len, hidden]` to `[bsz, num_heads, seq_len, head_dim]`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + softmax) attention.

        Args:
            hidden_states: input of shape `(batch, q_len, hidden_size)`.
            attention_mask: either an additive mask of shape `(batch, 1, q_len, kv_seq_len)`
                or a 2D mask; a 2D mask is replaced by a purely causal `(batch, 1, q_len, q_len)`
                mask (see note in the body).
            position_ids: positions used to index the rotary tables.
            past_key_value: cached `(key, value)` tensors to prepend, if any.
            output_attentions: whether to return the attention probabilities.
            use_cache: whether to return the updated `(key, value)` cache.

        Returns:
            `(attn_output, attn_weights or None, past_key_value or None)`.

        Raises:
            ValueError: if the attention weights, mask or output have an unexpected shape.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if attention_mask is not None and len(attention_mask.shape) == 2:
            # Flash Attention Mode to Attention Mode
            # NOTE: the values of the 2D mask are not read here, so padding information is
            # discarded and only causality is enforced. The mask is q_len x q_len, so this
            # path fails the shape check below when a past key/value cache is supplied.
            new_attention_mask = torch.zeros(bsz, 1, q_len, q_len, device=hidden_states.device)
            upper_tri_indices = torch.triu_indices(row=q_len, col=q_len, offset=1)
            new_attention_mask[:, :, upper_tri_indices[0], upper_tri_indices[1]] = -65504.
            attention_mask = new_attention_mask

        qkv_states = self.wqkv(hidden_states)

        qkv_states = rearrange(
            qkv_states,
            "b q (h gs d) -> b q h gs d",
            gs=2 + self.num_key_value_groups,
            d=self.head_dim,
        )

        query_states = qkv_states[..., : self.num_key_value_groups, :]
        query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
        key_states = qkv_states[..., -2, :]
        value_states = qkv_states[..., -1, :]

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.wo(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
class InternLM2FlashAttention2(InternLM2Attention):
    """
    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Flash-attention forward pass.

        Same arguments as `InternLM2Attention.forward`, except that `attention_mask` is the
        2D padding mask handed to `_flash_attention_forward`. Attention weights are never
        returned: the second element of the result is always None.
        """
        # InternLM2FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.wqkv(hidden_states)

        qkv_states = rearrange(
            qkv_states,
            "b q (h gs d) -> b q h gs d",
            gs=2 + self.num_key_value_groups,
            d=self.head_dim,
        )

        query_states = qkv_states[..., : self.num_key_value_groups, :]
        query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
        key_states = qkv_states[..., -2, :]
        value_states = qkv_states[..., -1, :]

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # Back to (batch, seq, heads, head_dim) before calling the flash-attention helpers.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.wo(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Length of the query sequence; causal masking is skipped when it equals 1.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Contains at least one padding token in the sequence
        causal = self.is_causal and query_length != 1
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padded positions from q/k/v and return the varlen bookkeeping.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q.to(torch.int64),
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


INTERNLM2_ATTENTION_CLASSES = {
    "eager": InternLM2Attention,
    "flash_attention_2": InternLM2FlashAttention2,
}


# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
class InternLM2DecoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual and optional DropPath."""

    def __init__(self, config: HolisticEmbeddingConfig, drop_path_rate=0.0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config

        # Computing the attention rank needs the eager implementation, whatever the config says.
        if compute_ARank:
            attention_cls = InternLM2Attention
        else:
            attention_cls = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation]
        self.attention = attention_cls(config=config)

        self.feed_forward = InternLM2MLP(config)
        self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): positions forwarded to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            A tuple starting with the layer output, followed by the attention weights when
            `output_attentions` is set and by the present key/value when `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.attention_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + self.drop_path1(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + self.drop_path2(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class VisionEmbeddings(nn.Module):
    """Patchify an image with a strided convolution and add interpolated position embeddings."""

    def __init__(self, config: HolisticEmbeddingConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=self.config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

        self.post_init()

    def post_init(self):
        """Initialise Conv2d / Linear weights from N(0, 0.02) and zero their biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            if isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resize the patch position grid to `H x W`; returns `(1, H * W, C)`."""
        target_dtype = pos_embed.dtype
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
            reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
        return pos_embed

    def forward(self,
                pixel_values: torch.FloatTensor,
                use_cls_token=False,
                ) -> torch.Tensor:
        """Embed `pixel_values` into a sequence of patch tokens.

        Args:
            pixel_values: image batch fed to the patch convolution.
            use_cls_token: when True, prepend the class embedding (and its position).

        Returns:
            Patch embeddings with position embeddings added, one row per patch
            (plus one leading row when `use_cls_token` is True).
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, height, width]
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        if use_cls_token:
            class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
            embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
            assert not self.config.use_2d_sincos_pos_embed, '2D SinCos pos embed is not supported with use_cls_token'
            position_embedding = torch.cat([
                self.position_embedding[:, :1, :],
                self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
            ], dim=1)
            embeddings = embeddings + position_embedding
        else:
            position_embedding = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width).to(target_dtype)
            embeddings = patch_embeds + position_embedding

        return embeddings


class HolisticEmbedding(PreTrainedModel):
    """Joint image/text embedding module built from vision patches, token embeddings and decoder layers."""

    config_class = HolisticEmbeddingConfig
    _supports_flash_attn_2 = True

    def __init__(self, config: HolisticEmbeddingConfig):
        super().__init__(config)

        self.config = config
        self.hidden_size = self.config.hidden_size
        self.gradient_checkpointing = True

        self.vision_embeddings = VisionEmbeddings(config)
        self.llm_text_embeddings = nn.Embedding(self.config.llm_vocab_size, self.config.llm_hidden_size)
        self.special_token_maps = config.special_token_maps
        if len(self.special_token_maps) > 0:
            self.special_text_embeddings = nn.Embedding(len(config.special_token_maps), self.config.llm_hidden_size)

        assert self.config.use_ls is False, 'LS is not supported in InternLM2'
        if hasattr(config, 'drop_path_rate'):
            # Drop-path rate grows linearly with depth.
            dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        else:
            dpr = [0.0] * config.num_hidden_layers
        self.encoder = nn.ModuleList([
            InternLM2DecoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)
        ])

        if self.config.use_pixel_shuffle_proj:
            self.pixel_shuffle_proj = nn.Sequential(
                nn.Linear(int(config.hidden_size / (config.downsample_ratio * config.downsample_ratio)), config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size)
            )

        self.num_img_tokens = (self.config.image_size // self.config.patch_size) ** 2

    def set_gradient_checkpointing(self):
        """Turn gradient checkpointing on for this module and flag every encoder layer."""
        self.gradient_checkpointing = True
        for layer in self.encoder:
            layer.gradient_checkpointing = True

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Bicubically resize the vision position embeddings from `old_size` to `new_size` images.

        The class-token position is kept as is; `vision_embeddings.image_size` is updated.
        """
        pos_emb = self.vision_embeddings.position_embedding
        _, _, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.vision_embeddings.position_embedding = nn.Parameter(pos_emb)
        self.vision_embeddings.image_size = new_size
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def replace_img_tokens(self, input_ids, hidden_states, vision_hidden_states):
        """Overwrite (in place) the rows of `hidden_states` at image-context positions with vision features."""
        img_context_token_mask = (input_ids == self.config.img_context_token_id)
        # Multiplying by 0.0 keeps the text-embedding rows in the autograd graph.
        hidden_states[img_context_token_mask] = hidden_states[img_context_token_mask] * 0.0 + vision_hidden_states.flatten(0, 1)

        return hidden_states

    def get_ignore_mask(self, input_ids):
        """Return a boolean mask that is True where `input_ids` is an image start/end token id."""
        ignore_ids = torch.tensor(
            [self.special_token_maps[token] for token in [IMG_START_TOKEN, IMG_END_TOKEN]],
            device=input_ids.device)
        ignore_mask = torch.isin(input_ids, ignore_ids)

        return ignore_mask

    def get_text_mask(self, input_ids):
        """Return a boolean mask that is True wherever `input_ids` is not the image-context token."""
        txt_mask = (input_ids != self.config.img_context_token_id)

        return txt_mask

    def get_input_embeddings(self, input_ids):
        """Embed `input_ids`, routing ids beyond the LLM vocabulary to the special-token table."""
        special_mask = input_ids > self.llm_text_embeddings.weight.shape[0] - 1
        # Out-of-vocabulary ids are zeroed so the lookup stays in range; their rows are masked out below.
        llm_embeddings = self.llm_text_embeddings(input_ids * (~special_mask).to(input_ids))

        if len(self.special_token_maps) > 0:
            special_embeddings = self.special_text_embeddings((input_ids - self.llm_text_embeddings.weight.shape[0]) * special_mask.to(input_ids))
            special_mask = special_mask.unsqueeze(-1)
            text_embeddings = llm_embeddings * (~special_mask).to(llm_embeddings) + \
                special_embeddings * special_mask.to(llm_embeddings)
        else:
            text_embeddings = llm_embeddings

        return text_embeddings

    def get_txt_embeddings(self, input_ids):
        """Return the LLM token embeddings of all non-image-context tokens as a 2D tensor."""
        txt_mask = (input_ids != self.config.img_context_token_id)
        txt_embeddings = self.llm_text_embeddings(input_ids[txt_mask])
        txt_embeddings = txt_embeddings.reshape(-1, txt_embeddings.shape[-1])

        return txt_embeddings

    def get_txt_feature(self, input_ids, feature):
        """Select the rows of `feature` (B, L, C) at non-image-context positions, as (-1, C)."""
        _, _, C = feature.shape
        txt_mask = (input_ids != self.config.img_context_token_id)
        txt_feature = feature[txt_mask].reshape(-1, C)

        return txt_feature

    def get_img_feature(self, input_ids, feature):
        """Select the rows of `feature` (B, L, C) at image-context positions, as (-1, C)."""
        _, _, C = feature.shape
        img_mask = (input_ids == self.config.img_context_token_id)
        img_feature = feature[img_mask].reshape(-1, C)

        return img_feature

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on a square token grid.

        With `pixel_shuffle_loc == 'post'` the input is a flat `(tokens, C)` tensor that is
        first regrouped per image and flattened again at the end.
        """
        if getattr(self.config, 'pixel_shuffle_loc', 'pre') == 'post':
            x = x.view(x.shape[0]//self.num_img_tokens, self.num_img_tokens, -1)

        n, l, c = x.size()
        h = w = int(l ** 0.5)
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.reshape(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).reshape(n, int(l * scale_factor * scale_factor), int(c / (scale_factor * scale_factor))).contiguous()

        if getattr(self.config, 'pixel_shuffle_loc', 'pre') == 'post':
            x = x.view(int(x.shape[0]*self.num_img_tokens*(self.config.downsample_ratio**2)), -1)
        return x

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        """Embed text (and optionally images) and run them through the encoder layers.

        Args:
            input_ids: token ids; image-context positions receive the vision features.
            attention_mask: mask forwarded unchanged to every encoder layer.
            position_ids: positions for the rotary embedding; defaults to `arange(seq_len)`.
            pixel_values: optional 4D image batch.
            output_hidden_states: accepted for interface compatibility; currently unused.
            return_dict: accepted for interface compatibility; currently unused.
            use_cache: when truthy, collect each layer's present key/value.

        Returns:
            `(img_feature, hidden_states, next_past_key_values)`.

        Raises:
            ValueError: if `pixel_values` is given but is not 4-dimensional.
        """
        if pixel_values is not None:
            if len(pixel_values.shape) == 4:
                if self.gradient_checkpointing and self.training:
                    vision_hidden_states = torch.utils.checkpoint.checkpoint(self.vision_embeddings, pixel_values)
                else:
                    vision_hidden_states = self.vision_embeddings(pixel_values)

                if self.config.use_pixel_shuffle_proj and getattr(self.config, 'pixel_shuffle_loc', 'pre') == 'pre':
                    vision_hidden_states = self.pixel_shuffle(vision_hidden_states, scale_factor=self.config.downsample_ratio)
                    if self.gradient_checkpointing and self.training:
                        vision_hidden_states = torch.utils.checkpoint.checkpoint(self.pixel_shuffle_proj, vision_hidden_states)
                    else:
                        vision_hidden_states = self.pixel_shuffle_proj(vision_hidden_states)

                hidden_states = self.get_input_embeddings(input_ids)
                hidden_states = self.replace_img_tokens(input_ids, hidden_states, vision_hidden_states)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        if position_ids is None:
            position_ids = torch.arange(
                hidden_states.shape[1], device=hidden_states.device
            ).unsqueeze(0)

        next_past_key_values = []
        for layer_module in self.encoder:
            if self.gradient_checkpointing and self.training:
                assert use_cache is None, 'Gradient checkpointing is not compatible with cache'
                # Positional args: past_key_value=None, output_attentions=False, use_cache=False.
                outputs = torch.utils.checkpoint.checkpoint(layer_module,
                                                            hidden_states,
                                                            attention_mask,
                                                            position_ids,
                                                            None,
                                                            False,
                                                            False,
                                                            )
                hidden_states = outputs[0]
            else:
                outputs = layer_module(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                )
                hidden_states = outputs[0]
            if use_cache:
                next_past_key_values.append(outputs[-1])

        img_feature = self.get_img_feature(input_ids, hidden_states)

        if self.config.use_pixel_shuffle_proj and getattr(self.config, 'pixel_shuffle_loc', 'pre') == 'post':
            img_feature = self.pixel_shuffle(img_feature, scale_factor=self.config.downsample_ratio)
            img_feature = self.pixel_shuffle_proj(img_feature)

        return img_feature, hidden_states, next_past_key_values