# coding=utf-8 # Copyright 2023 Language Technology Group from University of Oslo and The HuggingFace Inc. team. # And Copyright 2024 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Base implementation of the LTG-BERT/ELC-BERT Model is from Language Technology Group from University of Oslo and The HuggingFace Inc., Team # The StructFormer components is from The Google Research Authors - the authors were Yikang Shen and Yi Tay and Che Zheng and Dara Bahri and Donald Metzler and Aaron Courville # (and the code can be from here: https://github.com/google-research/google-research/tree/master/structformer), both were using Apache license, Version 2.0 """ PyTorch LTG-(ELC)-ParserBERT model.""" import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.utils import checkpoint from .configuration_ltgbert import LtgBertConfig from transformers.modeling_utils import PreTrainedModel from transformers.activations import gelu_new from transformers.modeling_outputs import ( MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, BaseModelOutput, ) from transformers.pytorch_utils import softmax_backward_data from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, ) _CHECKPOINT_FOR_DOC = "ltg/bnc-bert-span" _CONFIG_FOR_DOC = "LtgBertConfig" LTG_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ "bnc-bert-span", "bnc-bert-span-2x", "bnc-bert-span-0.5x", "bnc-bert-span-0.25x", "bnc-bert-span-order", "bnc-bert-span-document", "bnc-bert-span-word", "bnc-bert-span-subword", "norbert3-xs", "norbert3-small", "norbert3-base", "norbert3-large", "norbert3-oversampled-base", "norbert3-ncc-base", "norbert3-nak-base", "norbert3-nb-base", "norbert3-wiki-base", "norbert3-c4-base", ] class Conv1d(nn.Module): """1D convolution layer.""" def __init__(self, hidden_size, kernel_size, dilation=1): """Initialization. Args: hidden_size: dimension of input embeddings kernel_size: convolution kernel size dilation: the spacing between the kernel points """ super(Conv1d, self).__init__() if kernel_size % 2 == 0: padding = (kernel_size // 2) * dilation self.shift = True else: padding = ((kernel_size - 1) // 2) * dilation self.shift = False self.conv = nn.Conv1d( hidden_size, hidden_size, kernel_size, padding=padding, dilation=dilation ) def forward(self, x): """Compute convolution. Args: x: input embeddings Returns: conv_output: convolution results """ if self.shift: return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:] else: return self.conv(x.transpose(1, 2)).transpose(1, 2) def cumprod(x, reverse=False, exclusive=False): """cumulative product.""" if reverse: x = x.flip([-1]) if exclusive: x = F.pad(x[:, :, :-1], (1, 0), value=1) cx = x.cumprod(-1) if reverse: cx = cx.flip([-1]) return cx def cumsum(x, reverse=False, exclusive=False): """cumulative sum.""" bsz, _, length = x.size() device = x.device if reverse: if exclusive: w = torch.ones([bsz, length, length], device=device).tril(-1) else: w = torch.ones([bsz, length, length], device=device).tril(0) cx = torch.bmm(x, w) else: if exclusive: w = torch.ones([bsz, length, length], device=device).triu(1) else: w = torch.ones([bsz, length, length], device=device).triu(0) cx = torch.bmm(x, w) return cx def cummin(x, reverse=False, exclusive=False, max_value=1e4): """cumulative min.""" if reverse: if exclusive: x = F.pad(x[:, :, 1:], (0, 1), value=max_value) x = x.flip([-1]).cummin(-1)[0].flip([-1]) else: if exclusive: x = F.pad(x[:, :, :-1], (1, 0), value=max_value) x = x.cummin(-1)[0] return x class ParserNetwork(nn.Module): def __init__( self, config, pad=0, n_parser_layers=4, conv_size=9, relations=("head", "child"), weight_act="softmax", ): """ hidden_size: dimension of input embeddings nlayers: number of layers ntokens: number of output categories nhead: number of self-attention heads dropout: dropout rate pad: pad token index n_parser_layers: number of parsing layers conv_size: convolution kernel size for parser relations: relations that are used to compute self attention weight_act: relations distribution activation function """ super(ParserNetwork, self).__init__() self.hidden_size = config.hidden_size self.num_hidden_layers = config.num_hidden_layers self.num_attention_heads = config.num_attention_heads self.parser_layers = nn.ModuleList( [ nn.Sequential( Conv1d(self.hidden_size, conv_size), nn.LayerNorm(self.hidden_size, elementwise_affine=False), nn.Tanh(), ) for _ in range(n_parser_layers) ] ) self.distance_ff = nn.Sequential( Conv1d(self.hidden_size, 2), nn.LayerNorm(self.hidden_size, elementwise_affine=False), nn.Tanh(), nn.Linear(self.hidden_size, 1), ) self.height_ff = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.LayerNorm(self.hidden_size, elementwise_affine=False), nn.Tanh(), nn.Linear(self.hidden_size, 1), ) n_rel = len(relations) self._rel_weight = nn.Parameter( torch.zeros((self.num_hidden_layers, self.num_attention_heads, n_rel)) ) self._rel_weight.data.normal_(0, 0.1) self._scaler = nn.Parameter(torch.zeros(2)) self.n_parse_layers = n_parser_layers self.weight_act = weight_act self.relations = relations self.pad = pad @property def scaler(self): return self._scaler.exp() @property def rel_weight(self): if self.weight_act == "sigmoid": return torch.sigmoid(self._rel_weight) elif self.weight_act == "softmax": return torch.softmax(self._rel_weight, dim=-1) def parse(self, x, h): """ Parse input sentence. Args: x: input tokens (required). h: static embeddings Returns: distance: syntactic distance height: syntactic height """ mask = x != self.pad mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0) for i in range(self.n_parse_layers): h = h.masked_fill(~mask[:, :, None], 0) h = self.parser_layers[i](h) height = self.height_ff(h).squeeze(-1) height.masked_fill_(~mask, -1e4) distance = self.distance_ff(h).squeeze(-1) distance.masked_fill_(~mask_shifted, 1e4) # Calbrating the distance and height to the same level length = distance.size(1) height_max = height[:, None, :].expand(-1, length, -1) height_max = torch.cummax( height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e4, dim=-1 )[0].triu(0) margin_left = torch.relu( F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e4) - height_max ) margin_right = torch.relu(distance[:, None, :] - height_max) margin = torch.where( margin_left > margin_right, margin_right, margin_left ).triu(0) margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1) margin.masked_fill_(~margin_mask, 0) margin = margin.max() distance = distance - margin return distance, height def compute_block(self, distance, height): """Compute constituents from distance and height.""" beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0] gamma = torch.sigmoid(-beta_logits) ones = torch.ones_like(gamma) block_mask_left = cummin( gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1 ) block_mask_left = block_mask_left - F.pad( block_mask_left[:, :, :-1], (1, 0), value=0 ) block_mask_left.tril_(0) block_mask_right = cummin( gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1 ) block_mask_right = block_mask_right - F.pad( block_mask_right[:, :, 1:], (0, 1), value=0 ) block_mask_right.triu_(0) block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :] block = cumsum(block_mask_left).tril(0) + cumsum( block_mask_right, reverse=True ).triu(1) return block_p, block def compute_head(self, height): """Estimate head for each constituent.""" _, length = height.size() head_logits = height * self.scaler[1] index = torch.arange(length, device=height.device) mask = (index[:, None, None] <= index[None, None, :]) * ( index[None, None, :] <= index[None, :, None] ) head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1) head_logits.masked_fill_(~mask[None, :, :, :], -1e4) head_p = torch.softmax(head_logits, dim=-1) return head_p def generate_mask(self, x, distance, height): """Compute head and cibling distribution for each token.""" batch_size, length = x.size() eye = torch.eye(length, device=x.device, dtype=torch.bool) eye = eye[None, :, :].expand((batch_size, -1, -1)) block_p, block = self.compute_block(distance, height) head_p = self.compute_head(height) head = torch.einsum("blij,bijh->blh", block_p, head_p) head = head.masked_fill(eye, 0) child = head.transpose(1, 2) cibling = torch.bmm(head, child).masked_fill(eye, 0) rel_list = [] if "head" in self.relations: rel_list.append(head) if "child" in self.relations: rel_list.append(child) if "cibling" in self.relations: rel_list.append(cibling) rel = torch.stack(rel_list, dim=1) rel_weight = self.rel_weight dep = torch.einsum("lhr,brij->lbhij", rel_weight, rel) att_mask = dep.reshape( self.num_hidden_layers, batch_size, self.num_attention_heads, length, length ) return att_mask, cibling, head, block def forward(self, x, embeddings): """ Pass the x tokens through the parse network, get the syntactic height and distances and compute the distribution for each token """ x = torch.transpose(x, 0, 1) embeddings = torch.transpose(embeddings, 0, 1) distance, height = self.parse(x, embeddings) att_mask, cibling, head, block = self.generate_mask(x, distance, height) return att_mask, cibling, head, block class Encoder(nn.Module): def __init__(self, config, activation_checkpointing=False): super().__init__() self.layers = nn.ModuleList( [EncoderLayer(config, i) for i in range(config.num_hidden_layers)] ) for i, layer in enumerate(self.layers): layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i))) layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i))) self.activation_checkpointing = activation_checkpointing def forward(self, hidden_states, attention_mask, relative_embedding): hidden_states, attention_probs = [hidden_states], [] for i in range(len(self.layers)): if self.activation_checkpointing: hidden_state, attention_p = checkpoint.checkpoint( self.layers[i], hidden_states, attention_mask, relative_embedding ) else: hidden_state, attention_p = self.layers[i]( hidden_states, attention_mask[i], relative_embedding ) hidden_states.append(hidden_state) attention_probs.append(attention_p) return hidden_states, attention_probs class MaskClassifier(nn.Module): def __init__(self, config, subword_embedding): super().__init__() self.nonlinearity = nn.Sequential( nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=False ), nn.Linear(config.hidden_size, config.hidden_size), nn.GELU(), nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=False ), nn.Dropout(config.hidden_dropout_prob), nn.Linear(subword_embedding.size(1), subword_embedding.size(0)), ) self.initialize(config.hidden_size, subword_embedding) def initialize(self, hidden_size, embedding): std = math.sqrt(2.0 / (5.0 * hidden_size)) nn.init.trunc_normal_( self.nonlinearity[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) self.nonlinearity[-1].weight = embedding self.nonlinearity[1].bias.data.zero_() self.nonlinearity[-1].bias.data.zero_() def forward(self, x, masked_lm_labels=None): if masked_lm_labels is not None: x = torch.index_select( x.flatten(0, 1), 0, torch.nonzero(masked_lm_labels.flatten() != -100).squeeze(), ) x = self.nonlinearity(x) return x class EncoderLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() self.attention = Attention(config) self.mlp = FeedForward(config) temp = torch.zeros(layer_num + 1) temp[-1] = 1 self.prev_layer_weights = nn.Parameter(temp) def forward(self, hidden_states, padding_mask, relative_embedding): prev_layer_weights = F.softmax(self.prev_layer_weights, dim=-1) x = prev_layer_weights[0] * hidden_states[0] for i, hidden_state in enumerate(hidden_states[1:]): x = x + prev_layer_weights[i + 1] * hidden_state attention_output, attention_probs = self.attention( x, padding_mask, relative_embedding ) x = attention_output x = x + self.mlp(x) return x, attention_probs class GeGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) x = x * gelu_new(gate) return x class FeedForward(nn.Module): def __init__(self, config): super().__init__() self.mlp = nn.Sequential( nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False ), nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False), GeGLU(), nn.LayerNorm( config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False, ), nn.Linear(config.intermediate_size, config.hidden_size, bias=False), nn.Dropout(config.hidden_dropout_prob), ) self.initialize(config.hidden_size) def initialize(self, hidden_size): std = math.sqrt(2.0 / (5.0 * hidden_size)) nn.init.trunc_normal_( self.mlp[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) nn.init.trunc_normal_( self.mlp[-2].weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) def forward(self, x): return self.mlp(x) class MaskedSoftmax(torch.autograd.Function): @staticmethod def forward(self, x, mask, dim): self.dim = dim x.masked_fill_(mask, float("-inf")) x = torch.softmax(x, self.dim) x.masked_fill_(mask, 0.0) self.save_for_backward(x) return x @staticmethod def backward(self, grad_output): (output,) = self.saved_tensors input_grad = softmax_backward_data(self, grad_output, output, self.dim, output) return input_grad, None, None class Attention(nn.Module): def __init__(self, config): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0: raise ValueError( f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}" ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_size = config.hidden_size // config.num_attention_heads self.in_proj_qk = nn.Linear( config.hidden_size, 2 * config.hidden_size, bias=True ) self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) self.pre_layer_norm = nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=False ) self.post_layer_norm = nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=True ) position_indices = torch.arange( config.max_position_embeddings, dtype=torch.long ).unsqueeze(1) - torch.arange( config.max_position_embeddings, dtype=torch.long ).unsqueeze( 0 ) position_indices = self.make_log_bucket_position( position_indices, config.position_bucket_size, config.max_position_embeddings, ) position_indices = config.position_bucket_size - 1 + position_indices self.register_buffer("position_indices", position_indices, persistent=True) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.scale = 1.0 / math.sqrt(3 * self.head_size) self.initialize() def make_log_bucket_position(self, relative_pos, bucket_size, max_position): sign = torch.sign(relative_pos) mid = bucket_size // 2 abs_pos = torch.where( (relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1), ) log_pos = ( torch.ceil( torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1) ).int() + mid ) bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long() return bucket_pos def initialize(self): std = math.sqrt(2.0 / (5.0 * self.hidden_size)) nn.init.trunc_normal_( self.in_proj_qk.weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) nn.init.trunc_normal_( self.in_proj_v.weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) nn.init.trunc_normal_( self.out_proj.weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) self.in_proj_qk.bias.data.zero_() self.in_proj_v.bias.data.zero_() self.out_proj.bias.data.zero_() def compute_attention_scores(self, hidden_states, relative_embedding): key_len, batch_size, _ = hidden_states.size() query_len = key_len if self.position_indices.size(0) < query_len: position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze( 1 ) - torch.arange(query_len, dtype=torch.long).unsqueeze(0) position_indices = self.make_log_bucket_position( position_indices, self.position_bucket_size, 512 ) position_indices = self.position_bucket_size - 1 + position_indices self.position_indices = position_indices.to(hidden_states.device) hidden_states = self.pre_layer_norm(hidden_states) query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D] value = self.in_proj_v(hidden_states) # shape: [T, B, D] query = query.reshape( query_len, batch_size * self.num_heads, self.head_size ).transpose(0, 1) key = key.reshape( key_len, batch_size * self.num_heads, self.head_size ).transpose(0, 1) value = value.view( key_len, batch_size * self.num_heads, self.head_size ).transpose(0, 1) attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale) query_pos, key_pos = self.in_proj_qk(self.dropout(relative_embedding)).chunk( 2, dim=-1 ) # shape: [2T-1, D] query_pos = query_pos.view( -1, self.num_heads, self.head_size ) # shape: [2T-1, H, D] key_pos = key_pos.view( -1, self.num_heads, self.head_size ) # shape: [2T-1, H, D] query = query.view(batch_size, self.num_heads, query_len, self.head_size) key = key.view(batch_size, self.num_heads, query_len, self.head_size) attention_c_p = torch.einsum( "bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale ) attention_p_c = torch.einsum( "bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1) ) position_indices = self.position_indices[:query_len, :key_len].expand( batch_size, self.num_heads, -1, -1 ) attention_c_p = attention_c_p.gather(3, position_indices) attention_p_c = attention_p_c.gather(2, position_indices) attention_scores = attention_scores.view( batch_size, self.num_heads, query_len, key_len ) attention_scores.add_(attention_c_p) attention_scores.add_(attention_p_c) return attention_scores, value def compute_output(self, attention_probs, value): attention_probs = self.dropout(attention_probs) context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D] context = context.transpose(0, 1).reshape( context.size(1), -1, self.hidden_size ) # shape: [Q, B, H*D] context = self.out_proj(context) context = self.post_layer_norm(context) context = self.dropout(context) return context def forward(self, hidden_states, attention_mask, relative_embedding): attention_scores, value = self.compute_attention_scores( hidden_states, relative_embedding ) attention_probs = torch.sigmoid(attention_scores) * attention_mask return self.compute_output(attention_probs, value), attention_probs.detach() class Embedding(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size) self.word_layer_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False ) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.relative_embedding = nn.Parameter( torch.empty(2 * config.position_bucket_size - 1, config.hidden_size) ) self.relative_layer_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.initialize() def initialize(self): std = math.sqrt(2.0 / (5.0 * self.hidden_size)) nn.init.trunc_normal_( self.relative_embedding, mean=0.0, std=std, a=-2 * std, b=2 * std ) nn.init.trunc_normal_( self.word_embedding.weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) def forward(self, input_ids): word_embedding = self.dropout( self.word_layer_norm(self.word_embedding(input_ids)) ) relative_embeddings = self.relative_layer_norm(self.relative_embedding) return word_embedding, relative_embeddings # # HuggingFace wrappers # class LtgBertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = LtgBertConfig base_model_prefix = "bnc-bert" supports_gradient_checkpointing = True def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, Encoder): module.activation_checkpointing = value def _init_weights(self, _): pass # everything is already initialized LTG_BERT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LtgBertConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ LTG_BERT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `({0})`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare LTG-BERT transformer outputting raw hidden-states without any specific head on top.", LTG_BERT_START_DOCSTRING, ) class LtgBertModel(LtgBertPreTrainedModel): def __init__(self, config, add_mlm_layer=False): super().__init__(config) self.config = config self.embedding = Embedding(config) self.parser_network = ParserNetwork(config, pad=config.pad_token_id) self.transformer = Encoder(config, activation_checkpointing=False) self.classifier = ( MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None ) def get_input_embeddings(self): return self.embedding.word_embedding def set_input_embeddings(self, value): self.embedding.word_embedding = value def get_contextualized_embeddings( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: if input_ids is not None: input_shape = input_ids.size() else: raise ValueError("You have to specify input_ids") batch_size, seq_length = input_shape device = input_ids.device static_embeddings, relative_embedding = self.embedding(input_ids.t()) att_mask, cibling, head, block = self.parser_network( input_ids.t(), static_embeddings ) contextualized_embeddings, attention_probs = self.transformer( static_embeddings, att_mask, relative_embedding ) contextualized_embeddings = [ e.transpose(0, 1) for e in contextualized_embeddings ] last_layer = contextualized_embeddings[-1] contextualized_embeddings = [contextualized_embeddings[0]] + [ contextualized_embeddings[i] - contextualized_embeddings[i - 1] for i in range(1, len(contextualized_embeddings)) ] return last_layer, contextualized_embeddings, attention_probs @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(input_ids, attention_mask) if not return_dict: return ( sequence_output, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return BaseModelOutput( last_hidden_state=sequence_output, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, ) @add_start_docstrings( """LTG-BERT model with a `language modeling` head on top.""", LTG_BERT_START_DOCSTRING, ) class LtgBertForMaskedLM(LtgBertModel): _keys_to_ignore_on_load_unexpected = ["head"] def __init__(self, config): super().__init__(config, add_mlm_layer=True) def get_output_embeddings(self): return self.classifier.nonlinearity[-1].weight def set_output_embeddings(self, new_embeddings): self.classifier.nonlinearity[-1].weight = new_embeddings @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(input_ids, attention_mask) subword_prediction = self.classifier(sequence_output) masked_lm_loss = None if labels is not None: masked_lm_loss = F.cross_entropy( subword_prediction.flatten(0, 1), labels.flatten() ) if not return_dict: output = ( subword_prediction, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return ( ((masked_lm_loss,) + output) if masked_lm_loss is not None else output ) return MaskedLMOutput( loss=masked_lm_loss, logits=subword_prediction, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, ) class Classifier(nn.Module): def __init__(self, config, num_labels: int): super().__init__() drop_out = getattr(config, "classifier_dropout", config.hidden_dropout_prob) self.nonlinearity = nn.Sequential( nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=False ), nn.Linear(config.hidden_size, config.hidden_size), nn.GELU(), nn.LayerNorm( config.hidden_size, config.layer_norm_eps, elementwise_affine=False ), nn.Dropout(drop_out), nn.Linear(config.hidden_size, num_labels), ) self.initialize(config.hidden_size) def initialize(self, hidden_size): std = math.sqrt(2.0 / (5.0 * hidden_size)) nn.init.trunc_normal_( self.nonlinearity[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) nn.init.trunc_normal_( self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std ) self.nonlinearity[1].bias.data.zero_() self.nonlinearity[-1].bias.data.zero_() def forward(self, x): x = self.nonlinearity(x) return x @add_start_docstrings( """ LTG-BERT model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, LTG_BERT_START_DOCSTRING, ) class LtgBertForSequenceClassification(LtgBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config): super().__init__(config, add_mlm_layer=False) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output[:, 0, :]) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = nn.MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = ( logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, ) @add_start_docstrings( """ LTG-BERT model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, LTG_BERT_START_DOCSTRING, ) class LtgBertForTokenClassification(LtgBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config): super().__init__(config, add_mlm_layer=False) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = ( logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, ) @add_start_docstrings( """ LTG-BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, LTG_BERT_START_DOCSTRING, ) class LtgBertForQuestionAnswering(LtgBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config): super().__init__(config, add_mlm_layer=False) self.num_labels = config.num_labels self.head = Classifier(config, self.num_labels) @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, start_positions: Optional[torch.Tensor] = None, end_positions: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(input_ids, attention_mask) logits = self.head(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = ( start_logits, end_logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, ) @add_start_docstrings( """ LTG-BERT model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, LTG_BERT_START_DOCSTRING, ) class LtgBertForMultipleChoice(LtgBertModel): _keys_to_ignore_on_load_unexpected = ["classifier"] _keys_to_ignore_on_load_missing = ["head"] def __init__(self, config): super().__init__(config, add_mlm_layer=False) self.num_labels = getattr(config, "num_labels", 2) self.head = Classifier(config, self.num_labels) @add_start_docstrings_to_model_forward( LTG_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) num_choices = input_ids.shape[1] flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_attention_mask = ( attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None ) ( sequence_output, contextualized_embeddings, attention_probs, ) = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask) logits = self.head(sequence_output) reshaped_logits = logits.view(-1, num_choices) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(reshaped_logits, labels) if not return_dict: output = ( reshaped_logits, *([contextualized_embeddings] if output_hidden_states else []), *([attention_probs] if output_attentions else []), ) return ((loss,) + output) if loss is not None else output return MultipleChoiceModelOutput( loss=loss, logits=reshaped_logits, hidden_states=contextualized_embeddings if output_hidden_states else None, attentions=attention_probs if output_attentions else None, )