# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This piece of code is modified based on https://github.com/huggingface/transformers

import torch
from torch import nn
# `Sequence` must come from `collections.abc`: the `collections.Sequence`
# alias was removed in Python 3.10, where importing it raises ImportError.
from collections.abc import Sequence
from packaging import version

from .ops import *
from .disentangled_attention import *
from .da_utils import *

__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']


class BertSelfOutput(nn.Module):
    """Projection + dropout + residual + layer norm applied after self-attention."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        """Return ``MaskedLayerNorm(LayerNorm, dropout(dense(hidden_states)) + input_states)``.

        Args:
            hidden_states: Attention output to project.
            input_states: Residual added after dropout.
            mask: Accepted for interface compatibility; it is NOT passed to
                ``MaskedLayerNorm`` here.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states += input_states
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
        return hidden_states


class BertAttention(nn.Module):
    """``DisentangledSelfAttention`` followed by a ``BertSelfOutput`` residual block."""

    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.config = config

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None,
                relative_pos=None, rel_embeddings=None):
        """Run attention, then the output block.

        The residual fed to ``BertSelfOutput`` is ``query_states`` when given,
        otherwise ``hidden_states``.

        Returns:
            The attention output, or ``(attention_output, attention_probs)``
            when ``return_att`` is true.
        """
        output = self.self(hidden_states, attention_mask, return_att, query_states=query_states,
                           relative_pos=relative_pos, rel_embeddings=rel_embeddings)
        self_output = output['hidden_states']
        att_matrix = output['attention_probs']
        if query_states is None:
            query_states = hidden_states
        attention_output = self.output(self_output, query_states, attention_mask)

        if return_att:
            return (attention_output, att_matrix)
        return attention_output


class BertIntermediate(nn.Module):
    """Feed-forward expansion: ``act(dense(x))`` from hidden_size to intermediate_size."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be a name looked up in ACT2FN or a callable used as-is.
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    """Feed-forward contraction + dropout + residual + layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_states, mask=None):
        """Return ``MaskedLayerNorm(LayerNorm, dropout(dense(hidden_states)) + input_states)``.

        ``mask`` is accepted for interface compatibility but not used.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states += input_states
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
        return hidden_states


class BertLayer(nn.Module):
    """One transformer layer: attention block followed by the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, return_att=False, query_states=None,
                relative_pos=None, rel_embeddings=None):
        """Return the layer output, or ``(layer_output, attention_probs)`` if ``return_att``."""
        attention_output = self.attention(hidden_states, attention_mask, return_att=return_att,
                                          query_states=query_states, relative_pos=relative_pos,
                                          rel_embeddings=rel_embeddings)
        if return_att:
            attention_output, att_matrix = attention_output
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output, attention_mask)
        if return_att:
            return (layer_output, att_matrix)
        return layer_output


class ConvLayer(nn.Module):
    """1-D convolution over the sequence, added as a residual to another state."""

    def __init__(self, config):
        super().__init__()
        kernel_size = getattr(config, 'conv_kernel_size', 3)
        groups = getattr(config, 'conv_groups', 1)
        self.conv_act = getattr(config, 'conv_act', 'tanh')
        # Padding (k-1)//2 keeps the sequence length unchanged for odd kernel sizes.
        self.conv = torch.nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size,
                                    padding=(kernel_size - 1) // 2, groups=groups)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, residual_states, input_mask):
        """Convolve ``hidden_states``, zero masked positions, and add to ``residual_states``.

        Args:
            hidden_states: Assumed (batch, seq, hidden) -- TODO confirm; axes 1 and 2
                are swapped so Conv1d sees hidden as the channel axis.
            residual_states: Tensor the activated conv output is added to.
            input_mask: 0/1 mask; positions where it is 0 are zeroed in the conv output.
        """
        out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        # Older torch releases (< 1.2) lack bool masks for masked_fill_.
        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
            rmask = (1 - input_mask).bool()
        else:
            rmask = (1 - input_mask).byte()
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
        out = ACT2FN[self.conv_act](self.dropout(out))
        output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask)
        return output_states


class BertEncoder(nn.Module):
    """Modified BertEncoder with relative position bias support."""

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.relative_attention = getattr(config, 'relative_attention', False)
        if self.relative_attention:
            self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.position_buckets = getattr(config, 'position_buckets', -1)
            # Table covers positive and negative offsets, hence the factor of 2.
            pos_ebd_size = self.max_relative_positions * 2
            if self.position_buckets > 0:
                pos_ebd_size = self.position_buckets * 2
            self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)

        # '|'-separated options, e.g. 'layer_norm'; 'none' by default.
        self.norm_rel_ebd = [x.strip() for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|')]
        if 'layer_norm' in self.norm_rel_ebd:
            self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)

        kernel_size = getattr(config, 'conv_kernel_size', 0)
        self.with_conv = False
        if kernel_size > 0:
            self.with_conv = True
            self.conv = ConvLayer(config)

    def get_rel_embedding(self):
        """Return the relative-position embedding table (optionally layer-normed), or None."""
        rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
        if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd):
            rel_embeddings = self.LayerNorm(rel_embeddings)
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        """Expand ``attention_mask`` to the 4-D form passed to the layers.

        A mask with at most 2 dims (presumably (batch, seq) of 0/1 -- confirm)
        becomes a (batch, 1, seq, seq) byte mask whose entry (i, j) is the
        product of the mask at positions i and j. A 3-D mask gets a singleton
        axis inserted at dim 1. Other ranks are returned unchanged.
        """
        if attention_mask.dim() <= 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
            attention_mask = attention_mask.byte()
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)
        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        """Build relative positions via ``build_relative_position`` when needed.

        Only computed when relative attention is on and ``relative_pos`` was not
        supplied; the query length comes from ``query_states`` if given, else
        from ``hidden_states`` (dim -2 in both cases).
        """
        if self.relative_attention and relative_pos is None:
            q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
            relative_pos = build_relative_position(q, hidden_states.size(-2),
                                                   bucket_size=self.position_buckets,
                                                   max_position=self.max_relative_positions)
        return relative_pos

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True,
                return_att=False, query_states=None, relative_pos=None):
        """Run the layer stack.

        Args:
            hidden_states: Input to the first layer, or a Sequence of per-layer
                key/value inputs (element 0 feeds the first layer; later elements
                are used only when ``query_states`` is given).
            attention_mask: Mask accepted by ``get_attention_mask``.
            output_all_encoded_layers: Keep every layer's output, or only the last.
            return_att: Also collect attention probabilities.
            query_states: Optional queries; when given they are replaced by each
                layer's output while keys/values stay fixed (or advance through
                the Sequence form of ``hidden_states``).
            relative_pos: Precomputed relative positions; built if None.

        Returns:
            dict with ``'hidden_states'`` (list of layer outputs) and
            ``'attention_matrices'`` (list, filled only when ``return_att``).
        """
        if attention_mask.dim() <= 2:
            input_mask = attention_mask
        else:
            input_mask = (attention_mask.sum(-2) > 0).byte()
        attention_mask = self.get_attention_mask(attention_mask)

        # For the Sequence form, use the first layer's input wherever a single
        # tensor is required (a list/tuple has no `.size()` and cannot be convolved).
        if isinstance(hidden_states, Sequence):
            next_kv = hidden_states[0]
        else:
            next_kv = hidden_states
        relative_pos = self.get_rel_pos(next_kv, query_states, relative_pos)

        all_encoder_layers = []
        att_matrices = []
        rel_embeddings = self.get_rel_embedding()
        for i, layer_module in enumerate(self.layer):
            output_states = layer_module(next_kv, attention_mask, return_att, query_states=query_states,
                                         relative_pos=relative_pos, rel_embeddings=rel_embeddings)
            if return_att:
                output_states, att_m = output_states

            if i == 0 and self.with_conv:
                # At i == 0, next_kv is still the first layer's input.
                prenorm = output_states
                output_states = self.conv(next_kv, prenorm, input_mask)

            if query_states is not None:
                query_states = output_states
                if isinstance(hidden_states, Sequence):
                    next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
            else:
                next_kv = output_states

            if output_all_encoded_layers:
                all_encoder_layers.append(output_states)
                if return_att:
                    att_matrices.append(att_m)
        if not output_all_encoded_layers:
            all_encoder_layers.append(output_states)
            if return_att:
                att_matrices.append(att_m)
        return {
            'hidden_states': all_encoder_layers,
            'attention_matrices': att_matrices
        }


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        padding_idx = getattr(config, 'padding_idx', 0)
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=padding_idx)

        # When False, position embeddings are still computed and returned but
        # not added to the input embeddings.
        self.position_biased_input = getattr(config, 'position_biased_input', True)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

        if config.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)

        if self.embedding_size != config.hidden_size:
            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.output_to_half = False
        self.config = config

    def forward(self, input_ids, token_type_ids=None, position_ids=None, mask=None):
        """Embed ``input_ids`` (indexed as (batch, seq) via ``size(1)``).

        Returns:
            dict with ``'embeddings'`` (summed, projected to hidden_size if
            needed, normalized with ``mask``, dropped out) and
            ``'position_embeddings'``.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids.long())

        embeddings = words_embeddings
        if self.config.type_vocab_size > 0:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings += token_type_embeddings

        if self.position_biased_input:
            embeddings += position_embeddings

        if self.embedding_size != self.config.hidden_size:
            embeddings = self.embed_proj(embeddings)
        embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
        embeddings = self.dropout(embeddings)
        return {
            'embeddings': embeddings,
            'position_embeddings': position_embeddings}


class BertLMPredictionHead(nn.Module):
    """Masked-LM head: dense -> activation -> layer norm -> vocabulary logits."""

    def __init__(self, config, vocab_size):
        super().__init__()
        self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, self.embedding_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

        self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states, embeding_weight):
        """Return ``hidden @ embeding_weight.T + bias``.

        ``embeding_weight`` is presumably (vocab_size, embedding_size) -- confirm;
        it is cast to ``hidden_states``' dtype/device before the matmul.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
        logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
        return logits


class AR_MASK(object):
    """Auto-regressive (causal) attention mask builder."""

    def get_attention_mask(self, input_ids=None, token_type_ids=None):
        """Return a (batch, seq, seq) uint8 lower-triangular mask on ``input_ids``' device.

        ``token_type_ids`` is accepted for interface parity with ``Prefix_MASK`` but unused.
        """
        seq_len = input_ids.size(1)
        mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device))
        mask = mask.unsqueeze(0).expand(input_ids.size(0), seq_len, seq_len)
        return mask


class Prefix_MASK(object):
    """Attention mask driven by the running sum of ``token_type_ids``."""

    def get_attention_mask(self, input_ids=None, token_type_ids=None):
        """Return a (batch, seq, seq) byte mask with ``mask[b, i, j] = c[b, j] <= c[b, i]``.

        ``c`` is ``cumsum(token_type_ids, dim=1)``. For example, with
        token_type_ids 0 on a leading span and 1 afterwards, the leading span
        attends bidirectionally within itself and later positions attend causally.
        """
        idxs = torch.cumsum(token_type_ids, dim=1)
        mask = idxs[:, None, :] <= idxs[:, :, None]
        return mask.byte().to(input_ids.device)