import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange, repeat
from packaging import version
from transformers.models.bert.modeling_bert import (
    BertAttention,
    BertEncoder,
    BertForMaskedLM,
    BertLayer,
    BertModel,
    BertSelfAttention,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)
def rotate_half(x, interleaved=False):
    """Rotate half of the features of `x` (split-halves style when `interleaved=False`,
    interleaved even/odd pairs when `interleaved=True`)."""
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    # Standard RoPE frequencies: theta_i = 10000^(-2i / rotary_dim).
    inv_freq = 1.0 / (
        10000.0
        ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos = torch.cos(freqs).to(dtype)
    sin = torch.sin(freqs).to(dtype)
    return cos, sin
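
# Illustrative shape example (not executed): with headdim = 64 and seqlen = 128,
#   cos, sin = generate_cos_sin(128, 64, device, dtype)   # each of shape (128, 32)
#   q = torch.randn(batch, 128, nheads, 64)               # (batch, seqlen, nheads, headdim)
#   q_rot = apply_rotary_emb_torch(q, cos, sin)            # same shape, RoPE applied to all 64 features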
# from transformers.models.roformer import RoFormerSinusoidalPositionalEmbedding


class RotaryBertSdpaSelfAttention(BertSelfAttention):
    """BERT self-attention that applies rotary position embeddings (RoPE) to queries and keys
    and computes attention with `torch.nn.functional.scaled_dot_product_attention`."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = False
        # self.rotary_sinuses = RoFormerSinusoidalPositionalEmbedding(config.max_position_embeddings)

    # Adapted from transformers' BertSdpaSelfAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor]:
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "RotaryBertSdpaSelfAttention uses `torch.nn.functional.scaled_dot_product_attention`, which does "
                "not support `output_attentions=True` or `head_mask`; these arguments are ignored. Load the model "
                'with `attn_implementation="eager"` if you need them.'
            )
        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.query(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        key_layer = self.key(current_states)
        value_layer = self.value(current_states)

        # (bsz, seqlen, hidden) -> (bsz, nheads, seqlen, headdim)
        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)

        # apply_rotary_emb_torch expects (bsz, seqlen, nheads, headdim), so swap the head and
        # sequence dimensions, apply RoPE to queries and keys, then swap back.
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)
        cos, sin = generate_cos_sin(
            query_layer.shape[1], query_layer.shape[-1], device=query_layer.device, dtype=torch.float32
        )
        query_layer, key_layer = apply_rotary_emb_torch(query_layer, cos, sin), apply_rotary_emb_torch(key_layer, cos, sin)
        query_layer, key_layer = query_layer.permute(0, 2, 1, 3), key_layer.permute(0, 2, 1, 3)

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        # Only use causal masking for decoder self-attention when no explicit mask is given.
        is_causal = (
            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
        )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        # (bsz, nheads, seqlen, headdim) -> (bsz, seqlen, hidden)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
class RotaryBertAttention(BertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.self = RotaryBertSdpaSelfAttention(config)


class RotaryBertLayer(BertLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = RotaryBertAttention(config)


class RotaryBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList([RotaryBertLayer(config) for _ in range(config.num_hidden_layers)])
class RotaryBertModel(BertModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)
        self.encoder = RotaryBertEncoder(config)


class RotaryBertForMaskedLM(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        # Match BertForMaskedLM, which builds its backbone without the pooling layer.
        self.bert = RotaryBertModel(config, add_pooling_layer=False)
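

# Minimal usage sketch (illustrative, not part of the original module): build a small model from
# an arbitrary example config and run a forward pass as a shape check.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=512,
    )
    model = RotaryBertForMaskedLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])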