# modeling_rotarybert.py
# RotaryBERT (BERT with rotary position embeddings) modeling code for
# speciesLM k6, metazoa, upstream (specieslm-metazoa-upstream-k6).
from typing import Tuple

import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers.models.bert.modeling_bert import BertSelfAttention, BertAttention, BertLayer, BertEncoder, BertModel, BertForMaskedLM
from transformers.utils import logging

logger = logging.get_logger(__name__)


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    inv_freq = 1.0 / (
        10000.0
        ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos = torch.cos(freqs).to(dtype)
    sin = torch.sin(freqs).to(dtype)
    return cos, sin
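

# Minimal usage sketch for the two helpers above (an illustrative addition, not used by the
# model code): rotary embeddings rotate the first `rotary_dim` channels of every head and
# leave the remaining channels untouched, so the output has the same shape as the input.
def _rotary_embedding_example():
    batch_size, seqlen, nheads, headdim = 2, 16, 4, 8
    x = torch.randn(batch_size, seqlen, nheads, headdim)  # (batch_size, seqlen, nheads, headdim)
    cos, sin = generate_cos_sin(seqlen, headdim, device=x.device, dtype=x.dtype)
    out = apply_rotary_emb_torch(x, cos, sin)
    assert out.shape == x.shape
    return out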


# from transformers.models.roformer import RoFormerSinusoidalPositionalEmbedding
class RotaryBertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = False
        # self.rotary_sinuses = RoFormerSinusoidalPositionalEmbedding(config.max_position_embeddings)

    # Adapted from BertSelfAttention
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ) -> Tuple[torch.Tensor]:
        if output_attentions or head_mask is not None:
            # Unlike the upstream BertSdpaSelfAttention, there is no manual-attention fallback
            # here: `output_attentions` and `head_mask` are simply ignored.
            logger.warning_once(
                "RotaryBertSdpaSelfAttention uses `torch.nn.functional.scaled_dot_product_attention`, which does "
                "not support `output_attentions=True` or `head_mask`; these arguments are ignored."
            )
        bsz, tgt_len, _ = hidden_states.size()
        query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
        # Note: `past_key_value` caching (e.g. for prefix tuning) is not handled in this adaptation.
        key_layer = self.key(current_states)
        value_layer = self.value(current_states)
        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)
        # Apply rotary position embeddings to queries and keys: `transpose_for_scores` returns
        # (bsz, nheads, seqlen, headdim), so permute to (bsz, seqlen, nheads, headdim) for
        # `apply_rotary_emb_torch`, then permute back for scaled_dot_product_attention.
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)
        cos, sin = generate_cos_sin(query_layer.shape[1], query_layer.shape[-1], device=query_layer.device, dtype=torch.float32)
        query_layer, key_layer = apply_rotary_emb_torch(query_layer, cos, sin), apply_rotary_emb_torch(key_layer, cos, sin)
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()
        is_causal = (
            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
        )
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class RotaryBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = RotaryBertSdpaSelfAttention(config)


class RotaryBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = RotaryBertAttention(config)


class RotaryBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([RotaryBertLayer(config) for _ in range(config.num_hidden_layers)])


class RotaryBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = RotaryBertEncoder(config)


class RotaryBertForMaskedLM(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.bert = RotaryBertModel(config)
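

# Minimal usage sketch (an illustrative addition, not part of the original file), assuming a
# small, randomly initialised BertConfig. A real speciesLM checkpoint would instead be loaded
# via `RotaryBertForMaskedLM.from_pretrained(...)` together with its own config and tokenizer.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(
        vocab_size=4096,  # placeholder value, not the actual k-mer vocabulary size
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
    )
    model = RotaryBertForMaskedLM(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 32))
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    print(logits.shape)  # expected: torch.Size([1, 32, 4096])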