|
import torch
import torch.nn as nn

from einops import rearrange, repeat
from typing import Optional, Tuple

from transformers.models.bert.modeling_bert import (
    BertSelfAttention,
    BertAttention,
    BertLayer,
    BertEncoder,
    BertModel,
    BertForMaskedLM,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)
|
|
def rotate_half(x, interleaved=False):
    """Rotate channel pairs of the last dimension by 90 degrees: (x1, x2) -> (-x2, x1)."""
    if not interleaved:
        # GPT-NeoX style: the two halves of the head dimension form the pairs.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        # GPT-J style: even/odd channels form the pairs.
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
|
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
|
|
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """Build RoPE cos/sin tables of shape (seqlen, rotary_dim / 2) with base 10000."""
    inv_freq = 1.0 / (
        10000.0
        ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    # Compute the tables in float32 for numerical stability, then cast to the requested dtype.
    cos = torch.cos(freqs).to(dtype)
    sin = torch.sin(freqs).to(dtype)
    return cos, sin
|
|
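# Illustrative sketch (not used by the classes below): shows the shapes the rotary helpers
# expect and produce. The function name `_rotary_example` and the toy sizes are assumptions
# made for demonstration only.
def _rotary_example():
    bsz, seqlen, nheads, headdim = 2, 8, 4, 16
    x = torch.randn(bsz, seqlen, nheads, headdim)
    # cos/sin cover the full head dimension here, so every channel is rotated.
    cos, sin = generate_cos_sin(seqlen, headdim, device=x.device, dtype=x.dtype)
    out = apply_rotary_emb_torch(x, cos, sin)
    assert out.shape == x.shape
    return out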
|
|
|
|
|
class RotaryBertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        # Only needed to work around an SDPA bug with non-contiguous inputs and a custom
        # attn_mask on CUDA in torch < 2.2; assumed False for recent torch versions.
        self.require_contiguous_qkv = False
|
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        if output_attentions or head_mask is not None:
            # SDPA cannot return attention probabilities or apply a per-head mask. There is no
            # eager fallback here (the parent BertSelfAttention.forward would not apply the
            # rotary embedding), so these arguments are simply ignored.
            logger.warning_once(
                "RotaryBertSdpaSelfAttention uses `torch.nn.functional.scaled_dot_product_attention`, which "
                "does not support `output_attentions=True` or `head_mask`; these arguments are ignored."
            )
|
        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.query(hidden_states)

        # For cross-attention, keys/values come from the encoder states and the encoder mask is used.
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        key_layer = self.key(current_states)
        value_layer = self.value(current_states)

        # (bsz, seqlen, hidden_size) -> (bsz, nheads, seqlen, headdim)
        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)

        # The rotary helpers expect (bsz, seqlen, nheads, headdim), so swap the head and sequence
        # axes, rotate queries and keys, then swap back for SDPA. This assumes self-attention,
        # i.e. queries and keys share the same sequence length.
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)
        # Generate the tables in the activation dtype so SDPA receives q/k/v with matching dtypes;
        # generate_cos_sin still computes the frequencies in float32 before casting.
        cos, sin = generate_cos_sin(
            query_layer.shape[1], query_layer.shape[-1], device=query_layer.device, dtype=query_layer.dtype
        )
        query_layer, key_layer = apply_rotary_emb_torch(query_layer, cos, sin), apply_rotary_emb_torch(key_layer, cos, sin)
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)

        # Contiguity workaround for the torch < 2.2 SDPA bug (see __init__).
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()
|
        # SDPA's built-in causal masking is only valid when no explicit mask is provided.
        is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        # (bsz, nheads, seqlen, headdim) -> (bsz, seqlen, hidden_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
|
class RotaryBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = RotaryBertSdpaSelfAttention(config)


class RotaryBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = RotaryBertAttention(config)


class RotaryBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([RotaryBertLayer(config) for _ in range(config.num_hidden_layers)])


class RotaryBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = RotaryBertEncoder(config)


class RotaryBertForMaskedLM(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.bert = RotaryBertModel(config)
|
|
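# Minimal usage sketch, assuming a freshly constructed config (the hyperparameters below are
# illustrative only, not tied to any released checkpoint). Note that BertEmbeddings still adds
# its learned absolute position embeddings; the rotary embedding is applied on top of that
# inside every self-attention layer.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(
        vocab_size=30522,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
    )
    model = RotaryBertForMaskedLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    print(outputs.logits.shape)  # expected: torch.Size([2, 16, 30522])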