import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange, repeat
from packaging import version
from transformers.models.bert.modeling_bert import (
    BertAttention,
    BertEncoder,
    BertForMaskedLM,
    BertLayer,
    BertModel,
    BertSelfAttention,
)
from transformers.utils import logging

from .configuration_rotarybert import RotaryBertConfig

logger = logging.get_logger(__name__)


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """Precompute the rotary cos/sin tables for positions [0, seqlen)."""
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos = torch.cos(freqs).to(dtype)
    sin = torch.sin(freqs).to(dtype)
    return cos, sin


# from transformers.models.roformer import RoFormerSinusoidalPositionalEmbedding


class RotaryBertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = False
        # self.rotary_sinuses = RoFormerSinusoidalPositionalEmbedding(config.max_position_embeddings)

    # Adapted from BertSelfAttention
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ) -> Tuple[torch.Tensor]:
        if output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            # NOTE: unlike the upstream BertSdpaSelfAttention, this class does not fall back to the
            # manual attention path here (the manual path would skip the rotary embeddings), so
            # `output_attentions` and `head_mask` are effectively ignored.
            logger.warning_once(
                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the manual attention implementation, but specifying the manual implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )

        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.query(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
        key_layer = self.key(current_states)
        value_layer = self.value(current_states)

        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)

        # (bsz, nheads, seqlen, headdim) -> (bsz, seqlen, nheads, headdim), the layout expected by apply_rotary_emb_torch
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)
        # Build the cos/sin tables in the query dtype so q/k keep the same dtype as v for SDPA (e.g. under fp16/bf16)
        cos, sin = generate_cos_sin(
            query_layer.shape[1], query_layer.shape[-1], device=query_layer.device, dtype=query_layer.dtype
        )
        query_layer = apply_rotary_emb_torch(query_layer, cos, sin)
        key_layer = apply_rotary_emb_torch(key_layer, cos, sin)
        # Back to (bsz, nheads, seqlen, headdim) for scaled_dot_product_attention
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        is_causal = (
            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
        )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class RotaryBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = RotaryBertSdpaSelfAttention(config)


class RotaryBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = RotaryBertAttention(config)


class RotaryBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([RotaryBertLayer(config) for _ in range(config.num_hidden_layers)])


class RotaryBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = RotaryBertEncoder(config)
        # Re-run weight initialization so the freshly created rotary encoder gets BERT's init scheme.
        self.post_init()


class RotaryBertForMaskedLM(BertForMaskedLM):
    config_class = RotaryBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = RotaryBertModel(config)
        # Re-run weight init / tying so the MLM head is tied to the new `bert`'s word embeddings.
        self.post_init()
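

# --- Minimal usage sketch (not part of the model definition) ---
# A hedged smoke test assuming `RotaryBertConfig` accepts the standard BertConfig kwargs
# (vocab_size, hidden_size, num_attention_heads, ...). It builds a tiny model and runs one
# forward pass to check that the rotary SDPA attention path executes end to end. Because of
# the relative config import above, run this as a module inside its package
# (e.g. `python -m <package>.modeling_rotarybert`) rather than as a standalone script.
if __name__ == "__main__":
    config = RotaryBertConfig(  # assumed to mirror BertConfig's constructor
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=128,
    )
    model = RotaryBertForMaskedLM(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)  # expected: torch.Size([2, 16, vocab_size])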