#
# Zhoubo
#
""" FLASH: https://arxiv.org/abs/2202.10447 """
import copy
import math
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from transformers.activations import ACT2FN

from .modeling import *
from .ops import XSoftmax, sequence_masking
from .bert import *
from .config import ModelConfig
from .cache_utils import load_model_state


class ScaleNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scala = nn.Parameter(torch.ones(1))

    def forward(self, x):
        mean_square = (x ** 2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps) * self.scala
        return x


class OffsetScale(nn.Module):
    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        # nn.init.normal_(self.gamma, std=0.02)
        # nn.init.xavier_uniform_(self.gamma)

    def forward(self, x):
        if self.gamma.shape[0] == 1:
            # Single head: plain element-wise affine, output keeps the shape of x.
            return (x * self.gamma) + self.beta
        # Multiple heads: expand to [..., heads, dim] so callers can unbind the head axis.
        return torch.einsum('...d,hd->...hd', x, self.gamma) + self.beta


class ScaledSinuEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device=device).type_as(self.inv_freq)
        sinu = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
        return emb * self.scale


def RoPE(x, dim):
    """Apply rotary position embeddings (RoPE) to ``x``.

    :param x: input tensor
    :param dim: dimension(s) along which positions are counted
    :return: tensor of the same shape as ``x``
    """
    shape = x.shape
    if isinstance(dim, int):
        dim = [dim]
    spatial_shape = [shape[i] for i in dim]
    total_len = 1
    for i in spatial_shape:
        total_len *= i
    position = torch.reshape(
        torch.arange(total_len, dtype=torch.float, device=x.device), spatial_shape)
    for i in range(dim[-1] + 1, len(shape) - 1, 1):
        position = torch.unsqueeze(position, dim=-1)
    half_size = shape[-1] // 2
    freq_seq = -torch.arange(half_size, dtype=torch.float, device=x.device) / float(half_size)
    inv_freq = 10000 ** -freq_seq
    sinusoid = torch.einsum("...,d->...d", position, inv_freq)
    sin = torch.sin(sinusoid).repeat_interleave(2, -1)
    cos = torch.cos(sinusoid).repeat_interleave(2, -1)
    # Rotate adjacent feature pairs: (x1, x2) -> (-x2, x1).
    tensor_cross = torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
    # x1, x2 = torch.chunk(x, 2, dim=-1)
    # return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return x * cos + tensor_cross * sin


def rel_pos_bias(seq_len, s):
    a = torch.rand([1, s], dtype=torch.float)
    b = torch.rand([1, s], dtype=torch.float)
    w = torch.rand([2 * seq_len - 1], dtype=torch.float)
    if seq_len <= 512:
        # Build the full Toeplitz-style bias matrix from w.
        t = F.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
        t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
        r = (2 * seq_len - 1) // 2
        t = t[..., r:-r]
    else:
        # For long sequences, approximate the bias with a low-rank RoPE product.
        a = RoPE(a.repeat(seq_len, 1), dim=[0])
        b = RoPE(b.repeat(seq_len, 1), dim=[0])
        t = torch.einsum("mk,nk->mn", a, b)
    return t


def squared_relu(x, attention_mask, dim=-1):
    rmask = ~(attention_mask.bool())
    x = x.masked_fill(rmask, 0)
    return torch.square(F.relu(x))


def attention_normalize(a, axis=-1, mask=None, fn='softmax'):
    if fn == 'softmax':
        return XSoftmax.apply(a, mask, axis)
    else:
        # Count the valid (non "-inf") positions along `axis` for length normalization.
        mask_ = a > -float('inf') / 10
        # mask_ = mask_.byte()
        mask_ = torch.sum(mask_, dim=axis, keepdim=True)
        l = torch.maximum(mask_, torch.ones_like(mask_)).float()
        if fn == 'squared_relu':
            rmask = ~(mask.bool())
            a = a.masked_fill(rmask, 0)
            return torch.square(F.relu(a)) / l
        elif fn == 'softmax_plus':
            return XSoftmax.apply(a * torch.log(l) / math.log(512), mask, axis)
    return a
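

# Illustrative usage sketch (not from the original module; sizes are assumptions):
# RoPE is shape-preserving and only mixes rotary phase information into the last
# (feature) dimension, with positions counted along the dimension(s) given by `dim`.
def _rope_usage_example():
    x = torch.randn(2, 16, 64)      # [batch, seq_len, dim] (hypothetical sizes)
    x_rot = RoPE(x, dim=1)          # rotate along the sequence dimension
    assert x_rot.shape == x.shape
    return x_rot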


class GAULinear(nn.Linear):
    def init_weight(self):
        nn.init.xavier_uniform_(self.weight)


class GatedAttentionUnit(nn.Module):
    """
    GAU Block: Gated Attention Unit
    """
    def __init__(
            self,
            max_seq_length,
            hidden_size,
            attention_key_size=128,
            activation='swish',
            use_bias=True,
            attention_norm_type='squared_relu',
            attention_scale=True,
            dropout=0.1,
            pre_norm=False,
            norm_type="layer_norm",
            eps=1e-5,
            shift_token=False,
            use_rel_bias=False,
            add_residual=True,
            **kwargs):
        super(GatedAttentionUnit, self).__init__(**kwargs)
        self.max_seq_length = max_seq_length
        self.units = hidden_size
        self.intermediate_size = self.units * 2
        self.key_size = attention_key_size
        self.activation = activation
        self.use_bias = use_bias
        self.attention_norm_type = attention_norm_type
        self.attention_scale = attention_scale
        self.dropout = StableDropout(dropout)
        # Fused input projection: produces U, V and the shared QK representation in one matmul.
        self.i_dense = nn.Sequential(
            nn.Linear(self.units, 2 * self.intermediate_size + self.key_size, bias=self.use_bias),
            nn.SiLU()
        )
        self.o_dense = nn.Sequential(
            nn.Linear(self.intermediate_size, self.units, bias=self.use_bias),
            self.dropout
        )
        self.q_scaleoffset = OffsetScale(self.key_size)
        self.k_scaleoffset = OffsetScale(self.key_size)
        self.pre_norm = pre_norm
        self.norm = (nn.LayerNorm(hidden_size, eps=eps)
                     if norm_type.lower() == "layer_norm" else ScaleNorm(eps=eps))
        self.add_residual = add_residual

    def forward(self, x, attention_mask=None, **kwargs):
        shortcut = x
        if self.pre_norm:
            x = self.norm(x)
        x = self.i_dense(x)
        u, v, qk = torch.split(
            x, [self.intermediate_size, self.intermediate_size, self.key_size], dim=-1)
        q, k = self.q_scaleoffset(qk), self.k_scaleoffset(qk)
        # Apply RoPE to the stacked (q, k) pair along the sequence dimension.
        qk = RoPE(torch.stack([q, k], 2), dim=1)
        q, k = qk[:, :, 0], qk[:, :, 1]
        a = torch.einsum('bmd,bnd->bmn', q, k)
        if self.attention_scale:
            a = a / self.key_size**0.5
        a = sequence_masking(a, attention_mask, '-inf', -1)
        A = attention_normalize(a, -1, mask=attention_mask, fn=self.attention_norm_type)
        if self.dropout:
            A = self.dropout(A)
        out = self.o_dense(u * torch.einsum('bmn,bnd->bmd', A, v))
        if self.add_residual:
            out = out + shortcut
        if not self.pre_norm:
            out = self.norm(out)
        return out

    # Reference TF/Keras implementation kept for comparison:
    # # Apply RoPE
    # if p_bias == 'rotary':
    #     qk = K.stack([q, k], 2)
    #     qk = apply_rotary_position_embeddings(inputs[n], qk)[0]
    #     q, k = qk[:, :, 0], qk[:, :, 1]
    # # Attention
    # a = tf.einsum('bmd,bnd->bmn', q, k)
    # if self.attention_scale:
    #     a = a / self.key_size**0.5
    # if a_bias is not None:
    #     a = a + a_bias
    # a = sequence_masking(a, mask, '-inf', -1)
    # A = attention_normalize(a, -1, self.normalization)
    # if self.attention_dropout:
    #     A = Dropout(self.attention_dropout)(A)
    # # Compute the output
    # o = self.o_dense(u * tf.einsum('bmn,bnd->bmd', A, v))
    # return o
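

# Illustrative sketch (hypothetical sizes; assumes the DeBERTa-style StableDropout
# imported above behaves like nn.Dropout): the fused projection of GatedAttentionUnit
# maps hidden_size to 2 * intermediate_size + key_size, which forward() splits into
# the two gating/value branches U, V and the shared low-dimensional QK representation.
def _gau_projection_example():
    gau = GatedAttentionUnit(max_seq_length=512, hidden_size=768)
    x = torch.randn(2, 16, 768)                      # [batch, seq_len, hidden]
    z = gau.i_dense(x)                               # [2, 16, 2 * 1536 + 128]
    u, v, qk = torch.split(
        z, [gau.intermediate_size, gau.intermediate_size, gau.key_size], dim=-1)
    return u.shape, v.shape, qk.shape                # (2, 16, 1536) twice and (2, 16, 128)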


class GAU(nn.Module):
    def __init__(
            self,
            max_seq_length,
            hidden_size,
            expansion_factor=2,
            s=128,
            norm_type="layer_norm",
            eps=1e-5,
            hidden_act="silu",
            shift_token=False,
            use_rel_bias=False,
            attention_norm_type='softmax',
            pre_norm=False,
            dropout=0,
            add_residual=True):
        super(GAU, self).__init__()
        self.max_seq_length = max_seq_length
        self.shift_token = shift_token
        hidden_dim = int(expansion_factor * hidden_size)
        self.norm = (nn.LayerNorm(hidden_size, eps=eps)
                     if norm_type == "layer_norm" else ScaleNorm(eps=eps))
        self.use_rel_bias = use_rel_bias
        self.attention_norm_type = attention_norm_type
        # if attention_norm_type == 'relu':
        #     self.attention_norm_func = squared_relu
        # else:
        #     self.attention_norm_func = XSoftmax.apply
        # self.norm = norm_klass(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.to_hidden = nn.Sequential(
            nn.Linear(hidden_size, hidden_dim * 2),
            nn.SiLU()
        )
        self.to_qk = nn.Sequential(
            nn.Linear(hidden_size, s),
            nn.SiLU()
        )
        self.offsetscale = OffsetScale(s, heads=2)
        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_size),
            nn.Dropout(dropout)
        )
        self.add_residual = add_residual
        self.act_fn = ACT2FN[hidden_act]
        self.pre_norm = pre_norm

    def forward(
            self,
            x,
            relative_pos=None,
            attention_mask=None):
        seq_len, device = x.shape[-2], x.device
        if self.pre_norm:
            normed_x = self.norm(x)
        else:
            normed_x = x
        v, gate = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)
        # Shared QK representation, split into q and k by a per-head offset/scale, then RoPE.
        base = self.offsetscale(qk)
        base = RoPE(base, 1)
        q, k = base.unbind(dim=-2)
        sim = torch.einsum('b i d, b j d -> b i j', q, k)
        if relative_pos is not None:
            sim = sim + relative_pos
        if attention_mask is not None:
            if attention_mask.dim() < 3:
                attention_mask = einops.rearrange(attention_mask, 'b j -> b 1 j')
            # attn = attn.masked_fill(~attention_mask.bool(), 0.)
        attn = attention_normalize(sim, mask=attention_mask, fn=self.attention_norm_type)
        # attn = F.relu(sim) ** 2 / seq_len  # / q.size(-1)
        # logger.info(attn.max())
        attn = self.dropout(attn)
        # if self.causal:
        #     causal_mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device).triu(1)
        #     attn = attn.masked_fill(causal_mask, 0.)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = out * gate
        out = self.to_out(out)
        if self.add_residual:
            out = out + x
        if not self.pre_norm:
            out = self.norm(out)
        return out


class GAULayer(nn.Module):
    def __init__(self, config, shift_token=False, use_ffn=False):
        super(GAULayer, self).__init__()
        self.attention = GatedAttentionUnit(
            config.max_position_embeddings,
            config.hidden_size,
            shift_token=shift_token,
            use_rel_bias=config.use_rel_bias,
            norm_type=config.norm_type,
            attention_norm_type=config.attention_norm_type,
            pre_norm=config.pre_norm,
            dropout=config.hidden_dropout_prob)
        if use_ffn:
            self.intermediate = BertIntermediate(config)
            self.output = BertOutput(config)
        self.use_ffn = use_ffn

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None,
                relative_pos=None, rel_embeddings=None):
        attention_output = self.attention(
            hidden_states, attention_mask=attention_mask, relative_pos=relative_pos)
        if self.use_ffn:
            intermediate_output = self.intermediate(attention_output)
            layer_output = self.output(intermediate_output, attention_output)
            return layer_output
        else:
            return attention_output
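

# Illustrative usage sketch (hypothetical sizes; relies on the XSoftmax op imported
# above for the default 'softmax' normalization): a single GAU block maps
# [batch, seq_len, hidden] to [batch, seq_len, hidden], given a [batch, seq_len]
# 0/1 mask of valid tokens.
def _gau_block_example():
    gau = GAU(max_seq_length=512, hidden_size=256, s=64)
    x = torch.randn(2, 16, 256)
    mask = torch.ones(2, 16, dtype=torch.long)
    out = gau(x, attention_mask=mask)    # [2, 16, 256]
    return out.shape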


class FlashBlock(nn.Module):
    """
    FLASH Block: Fast Linear Attention with a Single Head
    """
    def __init__(self, model_size, sequence_length, chunk_size=256, expansion_factor=2, s=128,
                 norm_type="layer_norm", eps=1e-5, hidden_act="silu"):
        super(FlashBlock, self).__init__()
        self.s = s
        self.eps = eps
        self.norm_type = norm_type
        self.model_size = model_size
        self.chunk_size = chunk_size
        self.hidden_act = hidden_act
        self.sequence_length = sequence_length
        self.expansion_factor = expansion_factor
        self.e = int(self.model_size * self.expansion_factor)
        self.dense1 = nn.Linear(self.model_size, 2 * self.e + self.s, bias=True)
        self.gamma = nn.Parameter(torch.rand((4, self.s)))
        self.beta = nn.Parameter(torch.rand((4, self.s)))
        self.dense2 = nn.Linear(self.e, self.model_size)
        self.LayerNorm = (nn.LayerNorm(model_size, eps=self.eps)
                          if norm_type == "layer_norm" else ScaleNorm(eps=self.eps))
        nn.init.xavier_normal_(self.dense1.weight)
        self.act_fn = ACT2FN[self.hidden_act]

    def global_linear_attention(self, query, key, value, causal):
        if causal:
            kv = torch.einsum("bgcs, bgce->bgse", key, value)
            kv = torch.cumsum(kv, dim=1)
            lin_v = torch.einsum("bgcs, bgse->bgce", query, kv)
            return lin_v
        else:
            kv = torch.einsum("bgcs, bgce->bse", key, value)
            lin_v = torch.einsum("bgcs, bse->bgce", query, kv)
            return lin_v

    def segment_ids_to_mask(self, segment_ids, causal=False):
        """Generate the segment mask from the segment ids.
        The segment mask is used to remove the attention between tokens in different documents.
        """
        min_ids, max_ids = torch.min(segment_ids, dim=-1).values, torch.max(segment_ids, dim=-1).values
        # 1.0 indicates in the same group and 0.0 otherwise
        mask = torch.logical_and(torch.less_equal(min_ids[:, :, None], max_ids[:, None, :]),
                                 torch.greater_equal(max_ids[:, :, None], min_ids[:, None, :]))
        mask = mask.to(torch.float32)
        if causal:
            g = segment_ids.size()[1]
            # 1 - triu keeps only the strictly lower-triangular part, i.e. strictly earlier chunks.
            causal_mask = 1.0 - torch.triu(torch.ones([g, g], dtype=torch.float32))
            mask *= causal_mask
        mask = torch.div(mask, torch.sum(mask, dim=-1, keepdim=True))
        return mask

    def forward(self, x, causal=False, attention_mask=None, sequence_mask=None, segment_ids=None, **kwargs):
        """
        inputs: [batch_size, num_chunk, chunk_length, model_size]
        """
        _, g, n, d = x.size()
        shortcut, x = x, self.LayerNorm(x)
        # Linear projection to obtain Z, see Eq. (4) of the paper.
        uv = self.dense1(x)
        # Split uv along the last dimension into Ug: [C*e], Vg: [C*e], Zg: [C*s] (Section 3.2).
        # u: [batch_size, num_chunk, chunk_length, self.e]
        # v: [batch_size, num_chunk, chunk_length, self.e]
        # z: [batch_size, num_chunk, chunk_length, self.s]
        u, v, z = torch.split(self.act_fn(uv), [self.e, self.e, self.s], dim=-1)
        # Produce quad_q, quad_k, lin_q, lin_k: a cheap per-head offset/scale of z,
        # followed by RoPE position information.
        z = torch.einsum("...r, hr->...hr", z, self.gamma) + self.beta
        z = RoPE(z, dim=[1, 2])
        # Unbind along dim -2 to obtain quad_q, quad_k, lin_q and lin_k.
        quad_q, quad_k, lin_q, lin_k = torch.unbind(z, dim=-2)
        # Global linear attention.
        lin_v = self.global_linear_attention(lin_q, lin_k, v, causal)
        if causal:
            # Linear attention part, see Eq. (7) of the paper.
            lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, dtype=x.dtype)
            mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
            cum_lin_kv = torch.einsum('bhke, bgh->bgke', lin_kv, mask)
            linear = torch.einsum("bgnk, bgke->bgne", lin_q, cum_lin_kv)
            # Quadratic (local, per-chunk) attention.
            quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k)
            bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
            # relu(...)**2 attention kernel from the paper.
            kernel = torch.square(F.relu(quad_qk / n + bias))
            causal_mask = torch.triu(torch.ones([n, n], dtype=x.dtype))
            quadratic = torch.einsum("bgnm, bgme->bgne", kernel * causal_mask, v)
        else:
            lin_kv = torch.einsum("bgnk, bgne->bgke", lin_k, lin_v) / torch.tensor(n, dtype=x.dtype)
            mask = self.segment_ids_to_mask(segment_ids=segment_ids, causal=causal)
            lin_kv = torch.einsum("bhke, bgh->bgke", lin_kv, mask)
            linear = torch.einsum("bgnk, bgke->bgne", lin_q, lin_kv)
            # Quadratic (local, per-chunk) attention.
            quad_qk = torch.einsum("bgnk, bgmk->bgnm", quad_q, quad_k)
            bias = rel_pos_bias(self.sequence_length, self.s)[:, :n, :n]
            # relu(...)**2 attention kernel from the paper.
            kernel = torch.square(F.relu(quad_qk / n + bias))
            quadratic = torch.einsum("bgnm, bgme->bgne", kernel, v)
        x = u * (quadratic + linear)
        x = self.dense2(x)
        x = x + shortcut
        return x
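

# Illustrative sketch (hypothetical sizes): FLASH operates on non-overlapping chunks,
# so a [batch, seq_len, hidden] tensor is reshaped into the
# [batch_size, num_chunk, chunk_length, model_size] layout expected by FlashBlock.forward.
def _flash_chunking_example(chunk_size=256):
    x = torch.randn(2, 1024, 512)                          # [batch, seq_len, hidden]
    b, t, d = x.shape
    chunks = x.reshape(b, t // chunk_size, chunk_size, d)  # [2, 4, 256, 512]
    return chunks.shape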


class RelativePositionBias(nn.Module):
    def __init__(
            self,
            scale,
            causal=False,
            num_buckets=32,
            max_distance=128):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
            relative_position,
            causal=True,
            num_buckets=32,
            max_distance=128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = einops.rearrange(values, 'i j 1 -> i j')
        return bias * self.scale


class FlashEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config, with_position=False):
        super(FlashEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        self.with_position = with_position
        if with_position:
            self.position_embeddings = ScaledSinuEmbedding(config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = StableDropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, token_mask=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        if self.with_position:
            position_embeddings = self.position_embeddings(words_embeddings)
        else:
            position_embeddings = 0
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # if self.num_pos_emb > 1:
        #     num_batch = position_embeddings.size(0)
        #     num_pos = position_embeddings.size(1)
        #     position_embeddings = position_embeddings.view(
        #         num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        # if self.fp32_embedding:
        #     embeddings = embeddings.half()
        embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, token_mask)
        embeddings = self.dropout(embeddings)
        return {
            'embeddings': embeddings,
            'position_embeddings': position_embeddings}
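

# Illustrative sketch (hypothetical sizes, untrained bias table): the bias is built from
# the (query, key) relative-position grid of the score matrix passed in, so it can be
# added directly to [*, seq_len, seq_len] attention scores.
def _relative_position_bias_example():
    rel_bias = RelativePositionBias(scale=1.0)
    scores = torch.randn(2, 32, 32)       # [batch, q_len, k_len]
    bias = rel_bias(scores)               # [32, 32]
    return (scores + bias).shape          # [2, 32, 32]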


class GAUEncoder(nn.Module):
    def __init__(self, config, shift_token=False):
        super().__init__()
        layer = GAULayer(config, shift_token=shift_token)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def get_attention_mask(self, attention_mask):
        if attention_mask.dim() <= 2:
            # Expand a [batch, seq_len] padding mask into a pairwise [batch, seq_len, seq_len] mask.
            extended_attention_mask = attention_mask.unsqueeze(1)
            attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
        return attention_mask

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False,
                query_states=None, relative_pos=None):
        all_encoder_layers = []
        att_matrices = []
        if isinstance(hidden_states, Sequence):
            next_kv = hidden_states[0]
        else:
            next_kv = hidden_states
        # rel_embeddings = self.get_rel_embedding()
        for i, layer_module in enumerate(self.layer):
            output_states = layer_module(next_kv, attention_mask,
                                         query_states=query_states, relative_pos=relative_pos)
            if return_att:
                output_states, att_m = output_states
            # if i == 0 and self.with_conv:
            #     prenorm = output_states  # output['prenorm_states']
            #     output_states = self.conv(hidden_states, prenorm, input_mask)
            if query_states is not None:
                query_states = output_states
            if isinstance(hidden_states, Sequence):
                next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
            else:
                next_kv = output_states
            if output_all_encoded_layers:
                all_encoder_layers.append(output_states)
                if return_att:
                    att_matrices.append(att_m)
        if not output_all_encoded_layers:
            all_encoder_layers.append(output_states)
            if return_att:
                att_matrices.append(att_m)
        return {
            'hidden_states': all_encoder_layers,
            'attention_matrices': att_matrices
        }


class FlashEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = GatedAttentionUnit(config.max_position_embeddings, config.hidden_size)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, token_mask=None, output_all_encoded_layers=True,
                prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None):
        # History embedding and encoded layers must be given simultaneously.
        assert (prev_embedding is None) == (prev_encoded_layers is None)
        all_encoder_layers = []
        if (prev_embedding is not None) and (prev_encoded_layers is not None):
            history_states = prev_embedding
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states, attention_mask, history_states=history_states,
                    mask_qkv=mask_qkv, seg_ids=seg_ids)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
                if prev_encoded_layers is not None:
                    history_states = prev_encoded_layers[i]
        else:
            for layer_module in self.layer:
                hidden_states = layer_module(
                    hidden_states, attention_mask=attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
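

# Illustrative sketch mirroring GAUEncoder.get_attention_mask: a [batch, seq_len]
# padding mask becomes the pairwise [batch, seq_len, seq_len] mask consumed by the
# GAU layers (1 only where both the query and the key position are valid tokens).
def _pairwise_mask_example():
    pad_mask = torch.tensor([[1, 1, 1, 0]])               # last position is padding
    ext = pad_mask.unsqueeze(1)                           # [1, 1, 4]
    pairwise = ext * ext.squeeze(-2).unsqueeze(-1)        # [1, 4, 4]
    return pairwise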


# class FlashQuadModel(BertModel):
#     def __init__(self, config, pooler=False, shift_token=False, causal=False) -> None:
#         super().__init__(config)
#         self.embeddings = FlashEmbeddings(config)
#         self.encoder = GAUEncoder(config, causal=causal, shift_token=shift_token)
#         if not pooler:
#             self.pooler = None
#         self.apply(self.init_bert_weights)


class FlashQuadModel(torch.nn.Module):
    """
    Parameters:
        config:
            A model config class instance with the configuration to build a new model.
            The schema is similar to `BertConfig`.
        pre_trained:
            The pre-trained model; it can be a physical path to a pre-trained checkpoint
            or a released configuration, i.e. [**base, large, base_mnli, large_mnli**]
    """
    def __init__(self, config=None, pre_trained=None, pooler=False, shift_token=False, causal=False, **kwargs):
        super().__init__()
        state = None
        if pre_trained is not None:
            state, model_config = load_model_state(pre_trained)
            if config is not None and model_config is not None:
                # Keep the architecture-defining fields from the checkpoint, override the rest.
                for k in config.__dict__:
                    if k not in ['hidden_size', 'intermediate_size', 'num_attention_heads',
                                 'num_hidden_layers', 'vocab_size', 'max_position_embeddings']:
                        model_config.__dict__[k] = config.__dict__[k]
            config = copy.copy(model_config)
        self.embeddings = FlashEmbeddings(config, with_position=True)
        self.encoder = GAUEncoder(config, shift_token=shift_token)
        if not pooler:
            self.pooler = None
        self.config = config
        self.pre_trained = pre_trained
        self.apply_state(state)

    def get_attention_mask(self, input_ids=None, token_type_ids=None, attention_mask=None, input_mask=None):
        if attention_mask is None:
            if input_mask is not None:
                return input_mask.unsqueeze(-1).expand(
                    input_mask.size(0), input_mask.size(1), input_mask.size(1))
            else:
                return torch.ones_like(input_ids, dtype=torch.uint8).unsqueeze(-1).expand(
                    input_ids.size(0), input_ids.size(1), input_ids.size(1))
        else:
            if attention_mask.dim() == 2:
                if input_mask is not None:
                    attention_mask = attention_mask * input_mask
                return attention_mask.unsqueeze(-1).expand(
                    attention_mask.size(0), attention_mask.size(1), attention_mask.size(-1))
            if attention_mask.dim() == 4:
                attention_mask = attention_mask.squeeze(2)
            if attention_mask.dim() == 3:
                if input_mask is not None:
                    return attention_mask * input_mask.unsqueeze(-1).expand(
                        input_mask.size(0), input_mask.size(1), attention_mask.size(-1))
                else:
                    return attention_mask

    def forward(self, input_ids, input_mask, attention_mask=None, token_type_ids=None,
                output_all_encoded_layers=True, position_ids=None, return_att=False):
        """
        Args:
            input_ids:
                a torch.LongTensor of shape [batch_size, sequence_length] \
                with the word token indices in the vocabulary
            attention_mask:
                an optional parameter for input mask or attention mask.

                - If it's an input mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length] with indices \
                  selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
                  input sequence length in the current batch. It's the mask that we typically use for attention when \
                  a batch has varying length sentences.
                - If it's an attention mask then it will be a torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
                  In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.
            token_type_ids:
                an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
                type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
                a `sentence B` token (see BERT paper for more details).
            output_all_encoded_layers:
                whether to output results of all encoder layers, default: True

        Returns:
            - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
              the last layer of stacked transformer layers
            - Attention matrix of self-attention layers if `return_att=True`

        Example::

            # Batch of wordPiece token ids.
            # Each sample was padded with zeros to the maximum length of the batch.
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            # Mask of valid input ids
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            # FlashQuadModel initialized from a pre-trained checkpoint
            model = FlashQuadModel(pre_trained='base')
            encoder_layers = model(input_ids, attention_mask)
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # input_mask = torch.ones_like(input_ids)
        if input_mask is None:
            # Derive the input mask from the token type ids: everything up to the last
            # non-zero segment id counts as a valid token.
            idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), dim=1), [-1])
            input_mask = idxs > 0
            if not torch.any(input_mask):
                input_mask = torch.ones_like(input_ids)
        attention_mask = self.get_attention_mask(input_ids, token_type_ids, attention_mask, input_mask)
        embedding_output = self.embeddings(
            input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, input_mask)
        encoder_output = self.encoder(
            embedding_output['embeddings'], attention_mask,
            output_all_encoded_layers=output_all_encoded_layers, return_att=return_att)
        encoder_output.update(embedding_output)
        return encoder_output

    def apply_state(self, state=None):
        """ Load state from a previously loaded model state dictionary.

        Args:
            state (:obj:`dict`, optional): State dictionary as returned by torch.nn.Module.state_dict(), default: `None`. \
                If it's `None`, the pre-trained state loaded via the constructor is used to re-initialize the model.
        """
        if self.pre_trained is None and state is None:
            return
        if state is None:
            state, config = load_model_state(self.pre_trained)
            self.config = config
        prefix = ''
        for k in state:
            if 'embeddings.' in k:
                if not k.startswith('embeddings.'):
                    prefix = k[:k.index('embeddings.')]
                break
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        self._load_from_state_dict(state, prefix=prefix, local_metadata=None, strict=True,
                                   missing_keys=missing_keys, unexpected_keys=unexpected_keys,
                                   error_msgs=error_msgs)
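

# Illustrative end-to-end sketch (assumed config fields and sizes, untrained weights;
# relies on the BERT-style helpers imported above behaving as in the DeBERTa reference
# implementation): FlashQuadModel only reads a handful of BERT-style attributes from its
# config, so a SimpleNamespace stands in for ModelConfig here. The forward pass returns
# a dict with 'hidden_states', 'embeddings' and 'position_embeddings'.
def _flash_quad_model_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        vocab_size=1000, type_vocab_size=2, hidden_size=256, num_hidden_layers=2,
        max_position_embeddings=128, hidden_dropout_prob=0.1, use_rel_bias=False,
        norm_type="layer_norm", attention_norm_type="softmax", pre_norm=False)
    model = FlashQuadModel(config=cfg)
    input_ids = torch.randint(0, 1000, (2, 16))
    input_mask = torch.ones(2, 16, dtype=torch.long)
    output = model(input_ids, input_mask, output_all_encoded_layers=False)
    return output['hidden_states'][-1].shape        # [2, 16, 256]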


class FlashModel(BertModel):
    def __init__(self, config) -> None:
        super().__init__(config)
        self.encoder = FlashEncoder(config)
        self.apply(self.init_bert_weights)


if __name__ == '__main__':
    # Minimal smoke test. FlashModel expects a BertConfig-like ModelConfig instance
    # rather than raw integers; the attribute values below are placeholders and assume
    # ModelConfig exposes a default constructor with BERT-style field names.
    config = ModelConfig()
    config.vocab_size = 21128
    config.type_vocab_size = 2
    config.hidden_size = 768
    config.num_hidden_layers = 2
    config.max_position_embeddings = 512
    config.hidden_dropout_prob = 0.1
    model = FlashModel(config)