# DeBERTa-base / modeling / modeling.py
# Uploaded by bozhou ("Upload 22 files"), revision 23fe031, 124 kB.
# coding=utf-8
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import imp
import os
import copy
import json
import math
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from .file_utils import cached_path
from ..models.loss import LabelSmoothingLoss
from ..models.ops import * # XSoftmax, XDropout, ACT2FN, StableDropout
from loguru import logger
# Download locations of the original Google BERT checkpoints, keyed by model name.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    model_name: "https://s3.amazonaws.com/models.huggingface.co/bert/" + model_name + ".tar.gz"
    for model_name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
    )
}
# File names of the serialized config and weights.
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
# def gelu(x):
# """Implementation of the gelu activation function.
# For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
# 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
# """
# return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
# def swish(x):
# return x * torch.sigmoid(x)
# ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`."""

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 relax_projection=0,
                 new_pos_ids=False,
                 initializer_range=0.02,
                 task_idx=None,
                 fp32_embedding=False,
                 ffn_type=0,
                 label_smoothing=None,
                 num_qkv=0,
                 seg_emb=False):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Either the vocabulary size (int) of
                `inputs_ids` in `BertModel`, or the path (str or os.PathLike) of a
                JSON config file. For a path, the keyword defaults below are
                applied first and every key of the JSON file then overrides (or
                adds) the matching attribute; `vocab_size` must come from the file.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. Strings are looked up in `ACT2FN`.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            relax_projection: If > 1, the LM prediction head projects to
                `hidden_size * relax_projection` and selects one slice per task.
            new_pos_ids: If truthy, the position table holds 4 embeddings per
                position and `task_idx` picks one of them.
            initializer_range: The stddev of the normal initializer for
                weight matrices.
            task_idx: Task index stored on the config.
            fp32_embedding: If truthy, `BertEmbeddings` casts the summed embeddings
                to half precision (despite the name).
            ffn_type: 0 for the standard BERT feed-forward block, 1 or 2 for
                `TransformerFFN`.
            label_smoothing: Label smoothing setting stored on the config.
            num_qkv: If > 1, number of independent query/key/value projection sets.
            seg_emb: If truthy, self-attention adds segment-embedding scores.

        Raises:
            ValueError: if the first argument is neither an int nor a path.
        """
        is_path = isinstance(vocab_size_or_config_json_file, (str, os.PathLike))
        if isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
        elif not is_path:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.relax_projection = relax_projection
        self.new_pos_ids = new_pos_ids
        self.initializer_range = initializer_range
        self.task_idx = task_idx
        self.fp32_embedding = fp32_embedding
        self.ffn_type = ffn_type
        self.label_smoothing = label_smoothing
        self.num_qkv = num_qkv
        self.seg_emb = seg_emb
        if is_path:
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.load(reader)
            # Values from the file take precedence over the keyword defaults above,
            # matching what `from_dict` / `from_json_file` produce.
            self.__dict__.update(json_config)

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config of type `cls` from a Python dictionary of parameters.

        The keyword defaults are applied first (with `vocab_size` = -1), then every
        entry of `json_object` overrides or adds the matching attribute.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        config.__dict__.update(json_object)
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config of type `cls` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy of attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
# try:
# from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
# except ImportError:
# print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
    """Layer normalization over the last dimension, TF style (epsilon inside the square root).

    Output is ``weight * (x - mean) / sqrt(var + eps) + bias``, where the mean and
    the biased variance are computed over the last axis.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class PositionalEmbedding(nn.Module):
    """Sinusoidal position embeddings (sine half followed by cosine half).

    Args:
        demb: embedding width; the `inv_freq` buffer holds 1 / 10000^(2i/demb)
            for i in range(ceil(demb / 2)).
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Embed a 1-D tensor of positions.

        Args:
            pos_seq: 1-D tensor of positions. Integer tensors are accepted and cast
                to the dtype of `inv_freq`.
            bsz: if given, the middle (batch) axis is expanded to this size.

        Returns:
            Tensor of shape (len(pos_seq), 1 or bsz, 2 * len(inv_freq)).
        """
        # torch.outer replaces the deprecated torch.ger; the dtype cast lets callers
        # pass integer positions such as torch.arange(n) directly.
        sinusoid_inp = torch.outer(pos_seq.to(dtype=self.inv_freq.dtype), self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        return pos_emb[:, None, :]
class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm and dropout.

    When ``config.new_pos_ids`` is truthy the position table stores four
    embeddings per position, and ``task_idx`` selects one of them per batch row.
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        # Submodules are created in the same order as before so random
        # initialization consumes the RNG identically.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.fp32_embedding = getattr(config, 'fp32_embedding', False)
        self.num_pos_emb = 4 if getattr(config, 'new_pos_ids', False) else 1
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size * self.num_pos_emb)
        # `LayerNorm` keeps its TF-style (non snake_case) attribute name so that
        # checkpoint parameter names still match.
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = StableDropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None):
        if position_ids is None:
            # Default positions 0..seq_len-1, broadcast over the batch.
            positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
            position_ids = positions.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        word_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(position_ids)
        type_part = self.token_type_embeddings(token_type_ids)
        if self.num_pos_emb > 1:
            batch_size, length = position_part.size(0), position_part.size(1)
            row_index = torch.arange(0, batch_size).long()
            # (batch, pos, num_pos_emb * hid) -> pick slot `task_idx` per row -> (batch, pos, hid)
            position_part = position_part.view(
                batch_size, length, self.num_pos_emb, -1)[row_index, :, task_idx, :]

        summed = word_part + position_part + type_part
        if self.fp32_embedding:
            summed = summed.half()
        return self.dropout(self.LayerNorm(summed))
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional extensions.

    Extensions driven by `config`:
      * ``num_qkv > 1``: the query/key/value projections produce ``num_qkv``
        independent sets of heads; ``mask_qkv`` selects which set is used.
      * ``seg_emb``: adds a segment-aware score ``(q + b_q_s) . seg_emb(seg_ids)``.
      * env var ``UNI_DEBUG_FLAG`` (any non-empty value): stores the head-averaged
        attention map of the first batch element in ``debug_attention_probs``.
    """

    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        if hasattr(config, 'num_qkv') and (config.num_qkv > 1):
            self.num_qkv = config.num_qkv
        else:
            self.num_qkv = 1
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size * self.num_qkv)
        self.key = nn.Linear(config.hidden_size,
                             self.all_head_size * self.num_qkv)
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size * self.num_qkv)
        self.dropout = StableDropout(config.attention_probs_dropout_prob)
        self.uni_debug_flag = bool(os.getenv('UNI_DEBUG_FLAG', ''))
        if self.uni_debug_flag:
            # Fixed 512x512 buffer: only maps of up to 512 positions fit.
            self.register_buffer('debug_attention_probs',
                                 torch.zeros((512, 512)))
        if hasattr(config, 'seg_emb') and config.seg_emb:
            self.b_q_s = nn.Parameter(torch.zeros(
                1, self.num_attention_heads, 1, self.attention_head_size))
            self.seg_emb = nn.Embedding(
                config.type_vocab_size, self.all_head_size)
        else:
            self.b_q_s = None
            self.seg_emb = None

    def transpose_for_scores(self, x, mask_qkv=None):
        """Split a projection into heads.

        Args:
            x: projected states; expected (batch, pos, num_qkv * all_head_size).
            mask_qkv: only used when ``num_qkv > 1``. None selects set 0, an int
                selects that set for every position, and a (batch, pos) index
                tensor selects a set per position (extra columns are truncated).

        Returns:
            Tensor of shape (batch, head, pos, attention_head_size).
        """
        if self.num_qkv > 1:
            # Each of the num_qkv sets holds num_attention_heads heads of width
            # attention_head_size (not all_head_size, which would not match
            # x's last dimension when there is more than one head).
            sz = x.size()[:-1] + (self.num_qkv,
                                  self.num_attention_heads, self.attention_head_size)
            # (batch, pos, num_qkv, head, head_hid)
            x = x.view(*sz)
            if mask_qkv is None:
                x = x[:, :, 0, :, :]
            elif isinstance(mask_qkv, int):
                x = x[:, :, mask_qkv, :, :]
            else:
                # mask_qkv: (batch, pos)
                if mask_qkv.size(1) > sz[1]:
                    mask_qkv = mask_qkv[:, :sz[1]]
                # -> x: (batch, pos, head, head_hid)
                x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand(
                    sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2)
        else:
            sz = x.size()[:-1] + (self.num_attention_heads,
                                  self.attention_head_size)
            # (batch, pos, head, head_hid)
            x = x.view(*sz)
        # (batch, head, pos, head_hid)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, history_states=None, mask_qkv=None, seg_ids=None):
        """Attend from `hidden_states` over themselves (and `history_states` if given).

        Returns:
            Context tensor of shape (batch, pos, all_head_size).
        """
        if history_states is None:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)
        else:
            # Queries come from the current states only; keys/values also cover
            # the history states prepended along the sequence axis.
            x_states = torch.cat((history_states, hidden_states), dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv)
        key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv)
        value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv)

        # Scaled dot product between "query" and "key": (batch, head, pos, pos)
        attention_scores = torch.matmul(
            query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))

        if self.seg_emb is not None:
            seg_rep = self.seg_emb(seg_ids)
            # (batch, pos, head, head_hid)
            seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size(
                1), self.num_attention_heads, self.attention_head_size)
            qs = torch.einsum('bnih,bjnh->bnij',
                              query_layer + self.b_q_s, seg_rep)
            attention_scores = attention_scores + qs

        # The mask is applied inside XSoftmax rather than added to the scores.
        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)

        if self.uni_debug_flag:
            _pos = attention_probs.size(-1)
            self.debug_attention_probs[:_pos, :_pos].copy_(
                attention_probs[0].mean(0).view(_pos, _pos))

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
class BertSelfOutput(nn.Module):
    """Projects the attention context, applies dropout, adds the residual, then LayerNorm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-5)
        self.dropout = StableDropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """Attention sub-layer: `BertSelfAttention` followed by `BertSelfOutput`."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask, history_states=None, mask_qkv=None, seg_ids=None):
        context = self.self(input_tensor, attention_mask,
                            history_states=history_states,
                            mask_qkv=mask_qkv,
                            seg_ids=seg_ids)
        # The residual branch of the output block uses the sub-layer input.
        return self.output(context, input_tensor)
class BertIntermediate(nn.Module):
    """Expansion half of the feed-forward block: hidden_size -> intermediate_size, then activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        if isinstance(activation, str):
            # Named activations are resolved through ACT2FN; callables are used as-is.
            activation = ACT2FN[activation]
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Contraction half of the feed-forward block: dense, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = StableDropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_sum = self.dropout(self.dense(hidden_states)) + input_tensor
        return self.LayerNorm(residual_sum)
class TransformerFFN(nn.Module):
    """Gated feed-forward block used instead of the standard BERT FFN when `config.ffn_type` is set.

    ffn_type == 1: out = output((wx0 x) * x)
    ffn_type == 2: out = output((wx0 x) * (wx1 x))
    followed by dropout, a residual connection and LayerNorm.

    Raises:
        ValueError: if `config.ffn_type` is not 1 or 2.
    """

    def __init__(self, config):
        super(TransformerFFN, self).__init__()
        self.ffn_type = config.ffn_type
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would leave `forward` without an output.
        if self.ffn_type not in (1, 2):
            raise ValueError(
                "TransformerFFN supports ffn_type 1 or 2, got {!r}".format(self.ffn_type))
        self.wx0 = nn.Linear(config.hidden_size, config.hidden_size)
        if self.ffn_type == 2:
            self.wx1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.output = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5)
        self.dropout = StableDropout(config.hidden_dropout_prob)

    def forward(self, x):
        x0 = self.wx0(x)
        # Type 1 gates with the input itself; type 2 with a second projection.
        x1 = self.wx1(x) if self.ffn_type == 2 else x
        out = self.output(x0 * x1)
        out = self.dropout(out)
        return self.LayerNorm(out + x)
class BertLayer(nn.Module):
    """One Transformer block: attention sub-layer followed by a feed-forward sub-layer.

    The feed-forward part is a `TransformerFFN` when ``config.ffn_type`` is truthy,
    otherwise the standard `BertIntermediate` + `BertOutput` pair.
    """

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.ffn_type = config.ffn_type
        if not self.ffn_type:
            self.intermediate = BertIntermediate(config)
            self.output = BertOutput(config)
        else:
            self.ffn = TransformerFFN(config)

    def forward(self, hidden_states, attention_mask=None, history_states=None, mask_qkv=None, seg_ids=None):
        attended = self.attention(hidden_states, attention_mask,
                                  history_states=history_states,
                                  mask_qkv=mask_qkv,
                                  seg_ids=seg_ids)
        if not self.ffn_type:
            return self.output(self.intermediate(attended), attended)
        return self.ffn(attended)
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `BertLayer`s.

    Every layer is a deep copy of one freshly built `BertLayer`, so all layers
    start from identical parameters.
    """

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        template = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(template) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, seg_ids=None):
        """Run all layers in order.

        With history (`prev_embedding` and `prev_encoded_layers` given together),
        layer 0 attends over `prev_embedding` and layer i+1 over
        `prev_encoded_layers[i]`.

        Returns:
            Every layer's output if `output_all_encoded_layers`, otherwise a
            one-element list with the final hidden states.
        """
        # history embedding and encoded layers must be given together
        assert (prev_embedding is None) == (prev_encoded_layers is None)
        with_history = (prev_embedding is not None) and (prev_encoded_layers is not None)
        history_states = prev_embedding
        outputs = []
        for depth, layer_module in enumerate(self.layer):
            if with_history:
                hidden_states = layer_module(
                    hidden_states, attention_mask, history_states=history_states,
                    mask_qkv=mask_qkv, seg_ids=seg_ids)
                history_states = prev_encoded_layers[depth]
            else:
                hidden_states = layer_module(
                    hidden_states, attention_mask=attention_mask,
                    mask_qkv=mask_qkv, seg_ids=seg_ids)
            if output_all_encoded_layers:
                outputs.append(hidden_states)
        return outputs if output_all_encoded_layers else [hidden_states]
class BertPooler(nn.Module):
    """Pools a sequence by transforming the hidden state of its first token (dense + tanh)."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only position 0 along the sequence axis contributes to the pooled vector.
        return self.activation(self.dense(hidden_states[:, 0]))
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder.

    With ``config.relax_projection > 1`` the output width is
    ``hidden_size * relax_projection``.
    """

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        relax = getattr(config, 'relax_projection', 0)
        out_size = config.hidden_size * relax if relax > 1 else config.hidden_size
        self.dense = nn.Linear(config.hidden_size, out_size)
        self.LayerNorm = BertLayerNorm(out_size, eps=1e-5)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
class BertLMPredictionHead(nn.Module):
    """Masked-LM output head whose decoder weight is tied to the input embedding matrix."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(
            bert_model_embedding_weights.size(0)))
        if hasattr(config, 'relax_projection') and (config.relax_projection > 1):
            self.relax_projection = config.relax_projection
        else:
            self.relax_projection = 0
        # Half-precision conversion is deliberately disabled here;
        # config.fp32_embedding is not consulted.
        self.fp32_embedding = False
        self.converted = False

    def type_converter(self, tensor):
        """Return `tensor` cast to half precision when `fp32_embedding` is set, else unchanged.

        A method (rather than a closure stored on the instance) so the module can
        be pickled by `torch.save` and deep-copied with the converter bound to the copy.
        """
        if self.fp32_embedding:
            return tensor.half()
        return tensor

    def forward(self, hidden_states, task_idx=None):
        """Return vocabulary logits.

        Args:
            hidden_states: encoder output; the relax_projection branch indexes it as
                3-D, presumably (batch, seq, hidden) — confirm against callers.
            task_idx: per-row task index; only used when relax_projection > 1.
        """
        if not self.converted:
            # One-time lazy cast of the transform when half precision is enabled.
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
        hidden_states = self.transform(self.type_converter(hidden_states))
        if self.relax_projection > 1:
            num_batch = hidden_states.size(0)
            num_pos = hidden_states.size(1)
            # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid)
            hidden_states = hidden_states.view(
                num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :]
        if self.fp32_embedding:
            return F.linear(self.type_converter(hidden_states),
                            self.type_converter(self.decoder.weight),
                            self.type_converter(self.bias))
        return self.decoder(hidden_states) + self.bias
class BertOnlyMLMHead(nn.Module):
    """Wraps `BertLMPredictionHead` as a standalone masked-LM head."""

    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        # No task_idx is forwarded, so the prediction head uses its default (None).
        return self.predictions(sequence_output)
class MLMHead(nn.Module):
    """Roberta Head for masked language modeling.

    dense -> GELU -> LayerNorm -> vocabulary projection. The decoder's bias is
    the same Parameter object as ``self.bias``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(width, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Share a single bias Parameter between the head and its decoder.
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        hidden = ACT2FN['gelu'](self.dense(features))
        hidden = self.layer_norm(hidden)
        # Project back to the vocabulary; the bias is applied by the decoder.
        return self.decoder(hidden)

    def _tie_weights(self):
        # Re-tie the two biases if they got disconnected (on TPU or when the bias is resized).
        self.bias = self.decoder.bias
class BertOnlyNSPHead(nn.Module):
    """Binary sequence-relationship classifier applied to the pooled output."""

    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
class BertPreTrainingHeads(nn.Module):
    """MLM prediction head plus an optional sequence-level classifier on the pooled output."""

    def __init__(self, config, bert_model_embedding_weights, with_cls=False, num_labels=2):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        # The classifier exists only when requested with a positive label count.
        self.seq_relationship = (nn.Linear(config.hidden_size, num_labels)
                                 if with_cls and num_labels > 0 else None)

    def forward(self, sequence_output, pooled_output=None, task_idx=None):
        """Return ``(prediction_scores, seq_relationship_score)``.

        The second element is None when no pooled output is given or no
        classifier was built.
        """
        prediction_scores = self.predictions(sequence_output, task_idx)
        if pooled_output is not None and self.seq_relationship is not None:
            return prediction_scores, self.seq_relationship(pooled_output)
        return prediction_scores, None
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
    """Store `config` after checking that it is a `BertConfig`.

    Extra positional/keyword arguments are accepted for subclass compatibility
    and ignored here.

    Raises:
        ValueError: if `config` is not a `BertConfig` instance.
    """
    super(PreTrainedBertModel, self).__init__()
    if isinstance(config, BertConfig):
        self.config = config
        return
    class_name = self.__class__.__name__
    raise ValueError(
        "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
        "To create a model from a Google pretrained model use "
        "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
            class_name, class_name
        ))
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def create_model(cls, *inputs, **kwargs):
    """Build a model from a JSON config and optionally load a checkpoint into it.

    Args:
        inputs[0]: checkpoint path passed to `torch.load`, or a falsy value to
            skip weight loading.
        inputs[1]: JSON config path passed to `BertConfig.from_json_file`.
        **kwargs: config overrides (see below), ``local_debug``, ``num_labels``
            and ``transformer_model_name``. Keywords not consumed here are
            forwarded to ``cls(config, ...)``.

    Returns:
        The instantiated model. Missing/unexpected checkpoint keys and load
        errors are logged as a warning, not raised.
    """
    config_file = inputs[1]
    config = BertConfig.from_json_file(config_file)
    # define new type_vocab_size (there might be different numbers of segment ids);
    # unlike the options below it is applied whenever the key is present.
    if 'type_vocab_size' in kwargs:
        config.type_vocab_size = kwargs['type_vocab_size']
    # These options override the config only when given a truthy value:
    # relax projection, new position ids, task index, max position embeddings
    # (length expansion), fp32 embedding, FFN type, label smoothing, dropout
    # probabilities, number of QKV sets, segment embedding for self-attention.
    for option in ('relax_projection', 'new_pos_ids', 'task_idx',
                   'max_position_embeddings', 'fp32_embedding', 'ffn_type',
                   'label_smoothing', 'hidden_dropout_prob',
                   'attention_probs_dropout_prob', 'num_qkv', 'seg_emb'):
        if kwargs.get(option):
            setattr(config, option, kwargs[option])
    # NOTE: `word_emb_map` is accepted (and removed below) but currently unused.
    if kwargs.get('local_debug'):
        # Shrink to a single layer with a single head for quick local debugging.
        config.num_attention_heads = 1
        config.num_hidden_layers = 1
    if kwargs.get('num_labels'):
        config.num_labels = kwargs['num_labels']
    logger.info("Model config {}".format(config))

    # Remove the keywords consumed above so they are not forwarded to `cls`
    # (`num_labels` is intentionally kept). Example of a caller's keyword set:
    #   model = BertForPreTrainingLossMask.from_pretrained(
    #       args.bert_model, state_dict=_state_dict,
    #       num_labels=cls_num_labels, num_rel=0, type_vocab_size=type_vocab_size,
    #       config_path=args.config_path, task_idx=3, num_sentlvl_labels=num_sentlvl_labels,
    #       max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing,
    #       fp32_embedding=args.fp32_embedding, relax_projection=relax_projection, new_pos_ids=args.new_pos_ids,
    #       ffn_type=args.ffn_type, hidden_dropout_prob=args.hidden_dropout_prob,
    #       attention_probs_dropout_prob=args.attention_probs_dropout_prob, num_qkv=args.num_qkv,
    #       seg_emb=args.seg_emb, local_debug=args.local_debug)
    for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection',
                      'new_pos_ids', 'task_idx', 'max_position_embeddings', 'fp32_embedding',
                      'ffn_type', 'label_smoothing', 'hidden_dropout_prob',
                      'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 'word_emb_map', 'local_debug'):
        kwargs.pop(arg_clean, None)

    # Instantiate model.
    if kwargs.get('transformer_model_name', None):
        model = cls(config, model_name=kwargs['transformer_model_name'])
    else:
        model = cls(config, **kwargs)

    if inputs[0]:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(inputs[0], map_location='cpu')
        # Read the version metadata before any renaming: the dict comprehension
        # below builds a plain dict, which would silently drop `_metadata`.
        metadata = getattr(state_dict, '_metadata', None)
        try:
            rename_to_bert = ('Bert' in model.bert.__repr__()
                              or model.bert.base_model_prefix == 'bert')
        except AttributeError:
            # No `.bert` backbone (or no base_model_prefix): keep checkpoint keys as-is.
            rename_to_bert = False
        if rename_to_bert:
            state_dict = {k.replace('roberta', 'bert'): v for k, v in state_dict.items()}
            if metadata is not None:
                # Keep metadata prefixes consistent with the renamed parameter keys.
                metadata = {k.replace('roberta', 'bert'): v for k, v in metadata.items()}

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module, prefix=''):
            # Recursively load each submodule's slice of the state dict (strict=True
            # at each level; problems are collected into the lists above).
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for child_name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + child_name + '.')

        load(model)
        logger.warning(f'Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}, error_msgs: {error_msgs}')
    return model
def from_pretrained(self, state_dict=None):
    """
    Adapt a BERT/UniLM checkpoint to `self.config` and load it into `self.model`.

    Before loading, the checkpoint is reshaped to match the configuration:
      * TF-style LayerNorm names (`gamma`/`beta`) are renamed to `weight`/`bias`;
      * `bert.embeddings.token_type_embeddings.weight` is grown (via `resize_`)
        or truncated to `config.type_vocab_size` rows; some new rows are
        initialized by copying row 0 or row 1, others are left as `resize_`
        leaves them (see the NOTE below);
      * `bert.embeddings.position_embeddings.weight` is widened from 1 to
        `4 if config.new_pos_ids else 1` stacked copies (or narrowed to the copy
        for the configured task), then grown by tiling its existing rows or
        truncated to `config.max_position_embeddings` rows;
      * the `cls.predictions.transform` dense / LayerNorm parameters are
        replicated or sliced to match `config.relax_projection`;
      * per-layer attention query/key/value weights and biases are replicated
        or sliced to match `config.num_qkv`.
    Parameters are then loaded module by module with `_load_from_state_dict`.
    Missing keys, unexpected keys and error messages are collected and logged,
    not raised. The missing keys are also stored on `self.model.missing_keys`.

    Params:
        state_dict: a state dictionary (e.g. a collections.OrderedDict obtained
            from `torch.load`). Required: loading from a serialization
            directory is not implemented here.

    Raises:
        ValueError: if `state_dict` is None.
        AssertionError: if a checkpoint tensor cannot be reshaped to the config.

    Note: the dict is shallow-copied first, so the caller's keys are never
    renamed. The tensors it holds may still be resized, or have their `.data`
    replaced, in place.
    """
    if state_dict is None:
        raise ValueError(
            "from_pretrained() needs an explicit state_dict; loading from a "
            "serialization directory is not implemented.")

    # Shallow-copy up front so the key renames below do not alter the caller's
    # dict. dict.copy() drops the `_metadata` attribute that the per-module
    # loader reads, so carry it over explicitly.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _resolve_task_idx():
        """Return config.task_idx when it is set and within [0, 3], else 0."""
        task_idx = getattr(self.config, 'task_idx', None)
        if task_idx is not None and 0 <= task_idx <= 3:
            return task_idx
        return 0

    # Rename TF-style LayerNorm parameters (gamma -> weight, beta -> bias).
    # If a key contains both substrings, the 'beta' rename wins.
    renames = []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            renames.append((key, new_key))
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)

    # initialize new segment embeddings
    _k = 'bert.embeddings.token_type_embeddings.weight'
    if (_k in state_dict) and (self.config.type_vocab_size != state_dict[_k].shape[0]):
        logger.info("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
            self.config.type_vocab_size, state_dict[_k].shape[0]))
        if self.config.type_vocab_size > state_dict[_k].shape[0]:
            # For a contiguous tensor, growing the row count with resize_ keeps
            # the existing rows; the new memory is uninitialized.
            state_dict[_k].resize_(
                self.config.type_vocab_size, state_dict[_k].shape[1])
            # L2R
            if self.config.type_vocab_size >= 3:
                state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :])
            # R2L
            if self.config.type_vocab_size >= 4:
                state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :])
            # S2S
            if self.config.type_vocab_size >= 6:
                state_dict[_k].data[4, :].copy_(state_dict[_k].data[0, :])
                state_dict[_k].data[5, :].copy_(state_dict[_k].data[1, :])
            if self.config.type_vocab_size >= 7:
                state_dict[_k].data[6, :].copy_(state_dict[_k].data[1, :])
            # NOTE(review): when type_vocab_size == 5, row 4 is never filled,
            # and rows >= 7 never are either. Confirm this is intended.
        elif self.config.type_vocab_size < state_dict[_k].shape[0]:
            state_dict[_k].data = state_dict[_k].data[:self.config.type_vocab_size, :]

    # Match the number of stacked position-embedding copies (feature axis).
    _k = 'bert.embeddings.position_embeddings.weight'
    n_config_pos_emb = 4 if self.config.new_pos_ids else 1
    if (_k in state_dict) and (n_config_pos_emb * self.config.hidden_size != state_dict[_k].shape[1]):
        logger.info("n_config_pos_emb*config.hidden_size != state_dict[bert.embeddings.position_embeddings.weight] ({0}*{1} != {2})".format(
            n_config_pos_emb, self.config.hidden_size, state_dict[_k].shape[1]))
        assert state_dict[_k].shape[1] % self.config.hidden_size == 0
        n_state_pos_emb = state_dict[_k].shape[1] // self.config.hidden_size
        assert (n_state_pos_emb == 1) != (n_config_pos_emb == 1), \
            "!!!!n_state_pos_emb == 1 xor n_config_pos_emb == 1!!!!"
        if n_state_pos_emb == 1:
            # Replicate the single table n_config_pos_emb times side by side.
            # The reshape assumes the row count already equals
            # max_position_embeddings (the length fix-up happens below).
            state_dict[_k].data = state_dict[_k].data.unsqueeze(1).repeat(
                1, n_config_pos_emb, 1).reshape((self.config.max_position_embeddings, n_config_pos_emb * self.config.hidden_size))
        elif n_config_pos_emb == 1:
            # Keep only the copy that belongs to the configured task.
            state_dict[_k].data = state_dict[_k].data.view(
                self.config.max_position_embeddings, n_state_pos_emb, self.config.hidden_size).select(1, _resolve_task_idx())

    # initialize new position embeddings (row count = max sequence length)
    if _k in state_dict and self.config.max_position_embeddings != state_dict[_k].shape[0]:
        logger.info("config.max_position_embeddings != state_dict[bert.embeddings.position_embeddings.weight] ({0} - {1})".format(
            self.config.max_position_embeddings, state_dict[_k].shape[0]))
        if self.config.max_position_embeddings > state_dict[_k].shape[0]:
            old_size = state_dict[_k].shape[0]
            state_dict[_k].resize_(
                self.config.max_position_embeddings, state_dict[_k].shape[1])
            # Fill the new rows by tiling the original `old_size` rows.
            start = old_size
            while start < self.config.max_position_embeddings:
                chunk_size = min(
                    old_size, self.config.max_position_embeddings - start)
                state_dict[_k].data[start:start + chunk_size, :].copy_(
                    state_dict[_k].data[:chunk_size, :])
                start += chunk_size
        elif self.config.max_position_embeddings < state_dict[_k].shape[0]:
            state_dict[_k].data = state_dict[_k].data[:self.config.max_position_embeddings, :]

    # initialize relax projection
    _k = 'cls.predictions.transform.dense.weight'
    relax_vector_keys = ('cls.predictions.transform.dense.bias',
                         'cls.predictions.transform.LayerNorm.weight',
                         'cls.predictions.transform.LayerNorm.bias')
    n_config_relax = 1 if (self.config.relax_projection < 1) else self.config.relax_projection
    if (_k in state_dict) and (n_config_relax * self.config.hidden_size != state_dict[_k].shape[0]):
        logger.info("n_config_relax*config.hidden_size != state_dict[cls.predictions.transform.dense.weight] ({0}*{1} != {2})".format(
            n_config_relax, self.config.hidden_size, state_dict[_k].shape[0]))
        assert state_dict[_k].shape[0] % self.config.hidden_size == 0
        n_state_relax = state_dict[_k].shape[0] // self.config.hidden_size
        assert (n_state_relax == 1) != (n_config_relax == 1), \
            "!!!!n_state_relax == 1 xor n_config_relax == 1!!!!"
        if n_state_relax == 1:
            state_dict[_k].data = state_dict[_k].data.unsqueeze(0).repeat(
                n_config_relax, 1, 1).reshape((n_config_relax * self.config.hidden_size, self.config.hidden_size))
            for _vk in relax_vector_keys:
                state_dict[_vk].data = state_dict[_vk].data.unsqueeze(
                    0).repeat(n_config_relax, 1).view(-1)
        elif n_config_relax == 1:
            _task_idx = _resolve_task_idx()
            state_dict[_k].data = state_dict[_k].data.view(
                n_state_relax, self.config.hidden_size, self.config.hidden_size).select(0, _task_idx)
            for _vk in relax_vector_keys:
                state_dict[_vk].data = state_dict[_vk].data.view(
                    n_state_relax, self.config.hidden_size).select(0, _task_idx)

    # initialize QKV (detected on layer 0, applied to every layer)
    _all_head_size = self.config.num_attention_heads * \
        int(self.config.hidden_size / self.config.num_attention_heads)
    n_config_num_qkv = 1 if (self.config.num_qkv < 1) else self.config.num_qkv
    for qkv_name in ('query', 'key', 'value'):
        _k = 'bert.encoder.layer.0.attention.self.{0}.weight'.format(qkv_name)
        if (_k in state_dict) and (n_config_num_qkv * _all_head_size != state_dict[_k].shape[0]):
            logger.info("n_config_num_qkv*_all_head_size != state_dict[_k] ({0}*{1} != {2})".format(
                n_config_num_qkv, _all_head_size, state_dict[_k].shape[0]))
            for layer_idx in range(self.config.num_hidden_layers):
                _wk = 'bert.encoder.layer.{0}.attention.self.{1}.weight'.format(
                    layer_idx, qkv_name)
                _bk = 'bert.encoder.layer.{0}.attention.self.{1}.bias'.format(
                    layer_idx, qkv_name)
                assert state_dict[_wk].shape[0] % _all_head_size == 0
                n_state_qkv = state_dict[_wk].shape[0] // _all_head_size
                assert (n_state_qkv == 1) != (n_config_num_qkv == 1), \
                    "!!!!n_state_qkv == 1 xor n_config_num_qkv == 1!!!!"
                if n_state_qkv == 1:
                    state_dict[_wk].data = state_dict[_wk].data.unsqueeze(0).repeat(
                        n_config_num_qkv, 1, 1).reshape((n_config_num_qkv * _all_head_size, _all_head_size))
                    state_dict[_bk].data = state_dict[_bk].data.unsqueeze(
                        0).repeat(n_config_num_qkv, 1).view(-1)
                elif n_config_num_qkv == 1:
                    _task_idx = _resolve_task_idx()
                    assert _task_idx != 3, "[INVALID] _task_idx=3: n_config_num_qkv=1 (should be 2)"
                    # Task 0 keeps the first stacked projection; tasks 1 and 2 keep the second.
                    _qkv_idx = 0 if _task_idx == 0 else 1
                    state_dict[_wk].data = state_dict[_wk].data.view(
                        n_state_qkv, _all_head_size, _all_head_size).select(0, _qkv_idx)
                    state_dict[_bk].data = state_dict[_bk].data.view(
                        n_state_qkv, _all_head_size).select(0, _qkv_idx)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=''):
        # Recursively load each submodule's own parameters/buffers; problems are
        # appended to the three lists above instead of being raised.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # The adaptation above works on 'bert.'-prefixed keys. A model without a
    # `.bert` submodule is treated as the bare encoder, so traversal starts at
    # the 'bert.' prefix.
    load(self.model, prefix='' if hasattr(self.model, 'bert') else 'bert.')
    self.model.missing_keys = missing_keys
    if missing_keys:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            self.model.__class__.__name__, missing_keys))
    if unexpected_keys:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            self.model.__class__.__name__, unexpected_keys))
    if error_msgs:
        logger.info('\n'.join(error_msgs))
class BertModel(PreTrainedBertModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type (segment) indices, e.g. 0 for a `sentence A` token and 1 for a `sentence B` token
            (see BERT paper for more details). Passed to the embeddings and, as `seg_ids`, to the encoder.
        `attention_mask`: an optional 0/1 mask. Accepted as [batch_size, sequence_length],
            [batch_size, from_seq_length, to_seq_length], or an already-broadcastable 4-D tensor.
            Defaults to all ones (attend everywhere).
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers`
            output as described below. Default: `True`.
        `mask_qkv`: optional, forwarded unchanged to the encoder.
        `task_idx`: optional, forwarded unchanged to the embeddings.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by the `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: a list of the full sequences of encoded-hidden-states,
              one per attention block (e.g. 12 for BERT-base, 24 for BERT-large), each of size
              [batch_size, sequence_length, hidden_size];
            - `output_all_encoded_layers=False`: only the last block's hidden states,
              of shape [batch_size, sequence_length, hidden_size].
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] produced by the
            pooler from the last hidden states (used for the Next-Sentence task in BERT),
            or None if `self.pooler` has been set to None.
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def rescale_some_parameters(self):
        """Divide each layer's attention-output and FFN-output dense weights by sqrt(2 * depth).

        `depth` is the 1-based index of the encoder layer, so deeper layers
        start with smaller output projections.
        """
        for layer_id, layer in enumerate(self.encoder.layer):
            scale = math.sqrt(2.0 * (layer_id + 1))
            layer.attention.output.dense.weight.data.div_(scale)
            layer.output.dense.weight.data.div_(scale)

    def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask):
        """Return `attention_mask` broadcast to 4-D and cast to uint8.

        A 2-D [batch, to_seq] mask becomes [batch, 1, 1, to_seq]. A 3-D
        [batch, from_seq, to_seq] mask becomes [batch, 1, from_seq, to_seq].
        A 4-D mask is used as is. A missing mask defaults to all ones.
        `token_type_ids` is accepted for interface compatibility but not used.

        Unlike the original BERT recipe, no additive -10000 bias is built here.
        The 0/1 mask is returned as uint8, and the encoder's attention is
        presumably responsible for applying it. TODO: confirm against BertEncoder.

        Raises:
            NotImplementedError: if the mask is not 2-, 3- or 4-D.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        elif attention_mask.dim() == 4:
            extended_attention_mask = attention_mask
        else:
            raise NotImplementedError(
                "attention_mask must be 2-, 3- or 4-D, got {}-D".format(attention_mask.dim()))
        return extended_attention_mask.byte()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, mask_qkv=None, task_idx=None):
        extended_attention_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
        embedding_output = self.embeddings(
            input_ids, token_type_ids, task_idx=task_idx)
        encoded_layers = self.encoder(embedding_output, extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      mask_qkv=mask_qkv, seg_ids=token_type_ids)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
class BertModelIncr(BertModel):
    """`BertModel` variant used for incremental (step-by-step) encoding.

    It differs from `BertModel` in three ways:
      * it takes explicit `position_ids`;
      * it forwards `prev_embedding` / `prev_encoded_layers` to the encoder.
        These are presumably states cached from earlier steps; confirm against
        BertEncoder;
      * it also returns the embedding output.

    The constructor is inherited unchanged from `BertModel`.
    """

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, prev_embedding=None,
                prev_encoded_layers=None, mask_qkv=None, task_idx=None):
        """Encode one step.

        Returns:
            (embedding_output, encoded_layers, pooled_output).
            - `encoded_layers` is a list of per-layer hidden states, or only the
              last one when `output_all_encoded_layers` is False.
            - `pooled_output` is None when `self.pooler` is None.
        """
        extended_attention_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
        embedding_output = self.embeddings(
            input_ids, token_type_ids, position_ids, task_idx=task_idx)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      prev_embedding=prev_embedding,
                                      prev_encoded_layers=prev_encoded_layers,
                                      mask_qkv=mask_qkv, seg_ids=token_type_ids)
        sequence_output = encoded_layers[-1]
        # Mirror BertModel.forward: tolerate a model whose pooler was removed.
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return embedding_output, encoded_layers, pooled_output
class BertForPreTraining(PreTrainedBertModel):
    """BERT encoder topped with the two standard pre-training heads.

    Heads:
        - masked language modeling (token-level logits over the vocabulary);
        - next-sentence prediction (2-way logits from the pooled output).

    Params:
        config: a BertConfig instance describing the model to build.

    Inputs:
        `input_ids`: torch.LongTensor [batch_size, sequence_length] of WordPiece token ids.
        `token_type_ids`: optional torch.LongTensor [batch_size, sequence_length]
            of segment ids (0 = sentence A, 1 = sentence B).
        `attention_mask`: optional torch.LongTensor [batch_size, sequence_length]
            with 1 for real tokens and 0 for padding.
        `masked_lm_labels`: optional torch.LongTensor [batch_size, sequence_length]
            with values in [-1, 0, ..., vocab_size - 1]; positions labelled -1
            do not contribute to the loss.
        `next_sentence_label`: optional torch.LongTensor [batch_size] with values
            in [0, 1]: 0 = sentence B continues A, 1 = sentence B is random.
        `mask_qkv`, `task_idx`: optional, forwarded to the underlying BertModel.

    Outputs:
        If both `masked_lm_labels` and `next_sentence_label` are given: the
        scalar sum of the masked-LM and next-sentence cross-entropy losses.
        Otherwise: a tuple (masked-LM logits [batch_size, sequence_length,
        vocab_size], next-sentence logits [batch_size, 2]).

    Example::

        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
        model = BertForPreTraining(config)
        mlm_logits, nsp_logits = model(input_ids, token_type_ids, input_mask)
    """

    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        # The heads receive the word-embedding matrix (presumably for weight tying).
        word_embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls = BertPreTrainingHeads(config, word_embedding_weight)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, mask_qkv=None, task_idx=None):
        sequence_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        mlm_logits, nsp_logits = self.cls(sequence_output, pooled_output)

        # Without both label tensors there is nothing to score: return raw logits.
        if masked_lm_labels is None or next_sentence_label is None:
            return mlm_logits, nsp_logits

        criterion = CrossEntropyLoss(ignore_index=-1)
        flat_mlm_logits = mlm_logits.view(-1, self.config.vocab_size)
        mlm_loss = criterion(flat_mlm_logits, masked_lm_labels.view(-1))
        nsp_loss = criterion(nsp_logits.view(-1, 2), next_sentence_label.view(-1))
        return mlm_loss + nsp_loss
class BertPreTrainingPairTransform(nn.Module):
    """Fuse two hidden vectors into one: concat -> Linear -> activation.

    `pair_x` and `pair_y` are concatenated on their last axis, so each must end
    in `config.hidden_size`. The 2*hidden_size result is projected back to
    hidden_size, and the configured activation is applied. The activation is
    looked up in ACT2FN when given as a string, or used directly when callable.
    No LayerNorm is applied.
    """

    def __init__(self, config):
        super(BertPreTrainingPairTransform, self).__init__()
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        if isinstance(activation, str):
            activation = ACT2FN[activation]
        self.transform_act_fn = activation

    def forward(self, pair_x, pair_y):
        joined = torch.cat((pair_x, pair_y), dim=-1)
        return self.transform_act_fn(self.dense(joined))
class BertPreTrainingPairRel(nn.Module):
    """Score (x, relation, y) triples and return a per-pair negative log-sigmoid loss.

    - x and y are fused by `R_xy`.
    - The relation id is embedded by `rel_emb`.
    - The score is their dot product over the hidden axis.
    - `pair_pos_neg_mask` multiplies the score before the log-sigmoid.
      It presumably holds +1 for positive pairs and -1 for negatives; confirm
      with callers.
    """

    def __init__(self, config, num_rel=0):
        super(BertPreTrainingPairRel, self).__init__()
        self.R_xy = BertPreTrainingPairTransform(config)
        self.rel_emb = nn.Embedding(num_rel, config.hidden_size)

    def forward(self, pair_x, pair_y, pair_r, pair_pos_neg_mask):
        fused = self.R_xy(pair_x, pair_y)
        relation = self.rel_emb(pair_r)
        # Unpacking enforces a 3-D (batch, num_pair, hidden) fused tensor;
        # the sizes themselves are not needed.
        _batch, _num_pair, _hidden = fused.size()
        scores = torch.sum(fused * relation, dim=-1)
        signs = pair_pos_neg_mask.type_as(scores)
        # Negate so that a higher log-likelihood means a lower loss.
        return torch.neg(F.logsigmoid(scores * signs))
class BertForPreTrainingLossMask(PreTrainedBertModel):
    """Multi-task BERT pre-training model (refer to BertForPreTraining).

    Each example carries a `task_idx` that selects its objective in `forward`:
        0, 2, 3 -> masked LM on the positions in `masked_pos`. Task 0 also
                   gets next-sentence prediction.
        1       -> next-token LM over the whole sequence, using `cls_ar`.
        4       -> next-token LM weighted by `token_type_ids`, using `cls_seq2seq`.
    When masks are auto-built, task 1 gets a causal (left-to-right) mask, task 2
    its transpose (right-to-left), and task 3 a mask where segment A is fully
    visible and the rest is causal.
    """

    def __init__(self, config, num_labels=2, num_rel=0, num_sentlvl_labels=0, no_nsp=False):
        super(BertForPreTrainingLossMask, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight, with_cls=True, num_labels=num_labels)
        self.cls_ar = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight, with_cls=False, num_labels=num_labels)
        self.cls_seq2seq = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight, with_cls=False, num_labels=num_labels)
        self.num_sentlvl_labels = num_sentlvl_labels
        self.cls2 = None
        if self.num_sentlvl_labels > 0:
            self.secondary_pred_proj = nn.Embedding(
                num_sentlvl_labels, config.hidden_size)
            self.cls2 = BertPreTrainingHeads(
                config, self.secondary_pred_proj.weight, num_labels=num_sentlvl_labels)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        if no_nsp:
            self.crit_next_sent = None
        else:
            self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
        self.num_labels = num_labels
        self.num_rel = num_rel
        if self.num_rel > 0:
            self.crit_pair_rel = BertPreTrainingPairRel(
                config, num_rel=num_rel)
        if hasattr(config, 'label_smoothing') and config.label_smoothing:
            self.crit_mask_lm_smoothed = LabelSmoothingLoss(
                config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none')
        else:
            self.crit_mask_lm_smoothed = None
        self.apply(self.init_bert_weights)
        self.bert.rescale_some_parameters()
        # NOTE: `cls2` and `crit_pair_rel` are constructed here but no loss in
        # `forward` currently uses them.

    def forward(self, input_ids, token_type_ids, attention_mask, masked_lm_labels, masked_pos,
                task_idx, num_tokens_a=None, num_tokens_b=None, masked_weights=None,
                next_sentence_label=None, mask_qkv=None,
                target=None, feed_dict=None):
        """Run the encoder, then either predict (no labels) or compute the multi-task loss.

        Args:
            input_ids: [batch, seq_len] token ids.
            token_type_ids, attention_mask: if BOTH are None they are built here
                from `task_idx`, `num_tokens_a` and `num_tokens_b`.
            masked_lm_labels: None selects the inference path. Otherwise the
                labels for the gathered `masked_pos` positions.
            masked_pos: [batch, num_masked] positions to gather for the masked-LM
                head. Optional at inference; required when MLM rows exist in training.
            task_idx: [batch] per-example task id (see the class docstring).
            num_tokens_a, num_tokens_b: [batch] segment lengths; only used when
                the masks are auto-built.
            masked_weights: [batch, num_masked] per-position loss weights;
                required in training when MLM rows exist.
            next_sentence_label: [batch] NSP labels (-1 is ignored); optional.
            mask_qkv: forwarded to BertModel.
            target, feed_dict: unused; kept for interface compatibility.

        Returns:
            Inference path (`masked_lm_labels` is None): the tuple
            `(prediction_scores, seq_relationship_score)`.

            Training path: a dict with keys 'loss', 'masked_lm_loss',
            'next_sentence_loss', 'lm_loss', 'seq2seq_loss', 'batch_size',
            'mlm_acc', 'lm_acc', 'seq2seq_acc', 'logits' and 'labels'.
        """
        # Normalize index/label dtypes. Optional inputs may legitimately be None
        # (e.g. inference without labels), so cast only when present.
        if masked_pos is not None:
            masked_pos = masked_pos.to(torch.int64)
        task_idx = task_idx.to(torch.int64)
        if masked_weights is not None:
            masked_weights = masked_weights.to(torch.int64)
        if next_sentence_label is not None:
            next_sentence_label = next_sentence_label.to(torch.long)

        if token_type_ids is None and attention_mask is None:
            task_0 = (task_idx == 0)
            task_1 = (task_idx == 1)
            task_2 = (task_idx == 2)
            task_3 = (task_idx == 3)
            sequence_length = input_ids.shape[-1]
            index_matrix = torch.arange(sequence_length).view(
                1, sequence_length).to(input_ids.device)
            num_tokens = num_tokens_a + num_tokens_b
            base_mask = (index_matrix < num_tokens.view(-1, 1)).type_as(input_ids)
            segment_a_mask = (index_matrix < num_tokens_a.view(-1, 1)).type_as(input_ids)
            # Segment ids per task:
            #   task 0 -> A = 0, B = 1;  task 1 -> 2;  task 2 -> 3;  task 3 -> A = 4, B = 5.
            # Padding beyond num_tokens is 0.
            token_type_ids = (task_idx + 1 + task_3.type_as(task_idx)).view(-1, 1) * base_mask
            token_type_ids = token_type_ids - segment_a_mask * \
                (task_0 | task_3).type_as(segment_a_mask).view(-1, 1)
            # Key positions vary along the last axis, query positions along the middle axis.
            index_matrix = index_matrix.view(1, 1, sequence_length)
            index_matrix_t = index_matrix.view(1, sequence_length, 1)
            tril = index_matrix <= index_matrix_t
            attention_mask_task_0 = (
                index_matrix < num_tokens.view(-1, 1, 1)) & (index_matrix_t < num_tokens.view(-1, 1, 1))
            attention_mask_task_1 = tril & attention_mask_task_0
            attention_mask_task_2 = torch.transpose(
                tril, dim0=-2, dim1=-1) & attention_mask_task_0
            attention_mask_task_3 = ((index_matrix < num_tokens_a.view(-1, 1, 1)) | tril) & attention_mask_task_0
            attention_mask = (attention_mask_task_0 & task_0.view(-1, 1, 1)) | \
                (attention_mask_task_1 & task_1.view(-1, 1, 1)) | \
                (attention_mask_task_2 & task_2.view(-1, 1, 1)) | \
                (attention_mask_task_3 & task_3.view(-1, 1, 1))
            # NOTE(review): rows with task_idx 4 get an all-zero mask here;
            # presumably such batches supply their own masks.
            attention_mask = attention_mask.byte()

        sequence_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask,
            output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)

        def gather_seq_out_by_pos(seq, pos):
            """Select seq[b, pos[b, i], :] for each batch row b -> [batch, num_pos, hidden]."""
            return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1)).to(torch.int64))

        def loss_mask_and_normalize(loss, mask):
            """Weight `loss` by `mask` and divide the sum by the total mask weight."""
            mask = mask.type_as(loss)
            loss = loss * mask
            denominator = torch.sum(mask) + 1e-5
            return (loss / denominator).sum()

        def cal_acc(score, target, mask=None):
            """Fraction of argmax predictions equal to `target`, optionally mask-weighted."""
            score = torch.argmax(score, -1)
            cmp = score == target
            label_num = cmp.size(0) * cmp.size(1)
            if mask is not None:
                cmp = cmp * mask
                label_num = torch.sum(mask)
            t = torch.sum(cmp)
            return t / (label_num + 1e-5)

        if masked_lm_labels is None:
            if masked_pos is None:
                prediction_scores, seq_relationship_score = self.cls(
                    sequence_output, pooled_output, task_idx=task_idx)
            else:
                sequence_output_masked = gather_seq_out_by_pos(
                    sequence_output, masked_pos)
                prediction_scores, seq_relationship_score = self.cls(
                    sequence_output_masked, pooled_output, task_idx=task_idx)
            return prediction_scores, seq_relationship_score

        masked_lm_labels = masked_lm_labels.to(torch.long)

        # --- masked LM (tasks 0, 2, 3) ---
        mlm_mask = (task_idx == 0) | (task_idx == 2) | (task_idx == 3)
        sequence_output_masked = sequence_output[mlm_mask]
        if sequence_output_masked.size(0) > 0:
            masked_pos = masked_pos[mlm_mask]
            sequence_output_masked = gather_seq_out_by_pos(
                sequence_output_masked, masked_pos)
            # The NSP score is computed on the full batch's pooled output.
            prediction_scores_masked, seq_relationship_score = self.cls(
                sequence_output_masked, pooled_output, task_idx=task_idx)
            if self.crit_mask_lm_smoothed:
                masked_lm_loss = self.crit_mask_lm_smoothed(
                    F.log_softmax(prediction_scores_masked.type_as(sequence_output_masked), dim=-1),
                    masked_lm_labels[mlm_mask])
            else:
                masked_lm_loss = self.crit_mask_lm(
                    prediction_scores_masked.transpose(1, 2).type_as(sequence_output_masked),
                    masked_lm_labels[mlm_mask])
            masked_lm_loss = loss_mask_and_normalize(
                masked_lm_loss.type_as(sequence_output_masked), masked_weights[mlm_mask])
            mlm_acc = cal_acc(prediction_scores_masked, masked_lm_labels[mlm_mask], masked_weights[mlm_mask])
            logits = prediction_scores_masked
        else:
            masked_lm_loss = torch.tensor(0.0).type_as(sequence_output_masked).to(sequence_output.device)
            mlm_acc = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)
            logits = torch.tensor([]).type_as(sequence_output_masked).to(sequence_output_masked.device)

        # --- next-token LM (task 1, "AR") ---
        ar_mask = task_idx == 1
        sequence_output_lm = sequence_output[ar_mask]
        if sequence_output_lm.size(0) > 0:
            prediction_scores_sequence_output_lm, _ = self.cls_ar(
                sequence_output_lm, None, task_idx=task_idx)
            # Position i predicts token i + 1.
            prediction_scores_sequence_output_lm = prediction_scores_sequence_output_lm[:, :-1, :]
            ar_label = input_ids[ar_mask]
            ar_label = ar_label[:, 1:].to(torch.long)
            if self.crit_mask_lm_smoothed:
                lm_loss = self.crit_mask_lm_smoothed(
                    F.log_softmax(prediction_scores_sequence_output_lm.type_as(sequence_output), dim=-1), ar_label)
            else:
                lm_loss = self.crit_mask_lm(
                    prediction_scores_sequence_output_lm.transpose(1, 2).type_as(sequence_output), ar_label)
            # Position i counts if any non-zero token id occurs at index >= i.
            # This drops trailing zero padding.
            # NOTE(review): it keeps the one prediction of the first padding id
            # after the last real token. Confirm this is intended.
            ar_mask_loss = torch.cumsum(torch.flip(input_ids, [-1]), -1)
            ar_mask_loss = ar_mask_loss > 0
            ar_mask_loss = torch.flip(ar_mask_loss, [-1])[:, :-1]
            lm_loss = loss_mask_and_normalize(
                lm_loss.type_as(sequence_output), ar_mask_loss[ar_mask])
            lm_acc = cal_acc(prediction_scores_sequence_output_lm, ar_label, ar_mask_loss[ar_mask])
        else:
            lm_loss = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)
            lm_acc = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)

        # --- seq2seq LM (task 4), weighted by the shifted token_type_ids ---
        seq2seq_mask = task_idx == 4
        sequence_output_seq2seq = sequence_output[seq2seq_mask]
        if sequence_output_seq2seq.size(0) > 0:
            prediction_scores_sequence_output_seq2seq, _ = self.cls_seq2seq(
                sequence_output_seq2seq, None, task_idx=task_idx)
            prediction_scores_sequence_output_seq2seq = prediction_scores_sequence_output_seq2seq[:, :-1, :]
            seq2seq_label = input_ids[seq2seq_mask]
            seq2seq_label = seq2seq_label[:, 1:].to(torch.long)
            if self.crit_mask_lm_smoothed:
                seq2seq_loss = self.crit_mask_lm_smoothed(
                    F.log_softmax(prediction_scores_sequence_output_seq2seq.type_as(sequence_output), dim=-1),
                    seq2seq_label)
            else:
                seq2seq_loss = self.crit_mask_lm(
                    prediction_scores_sequence_output_seq2seq.transpose(1, 2).type_as(sequence_output), seq2seq_label)
            seq2seq_loss = loss_mask_and_normalize(
                seq2seq_loss.type_as(sequence_output), token_type_ids[seq2seq_mask][:, 1:])
            seq2seq_acc = cal_acc(prediction_scores_sequence_output_seq2seq, seq2seq_label,
                                  token_type_ids[seq2seq_mask][:, 1:])
        else:
            seq2seq_loss = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)
            seq2seq_acc = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)

        # --- next sentence prediction (task 0 rows only) ---
        if self.crit_next_sent is None or next_sentence_label is None or sequence_output_masked.size(0) == 0:
            next_sentence_loss = torch.tensor(0.0).type_as(masked_lm_loss).to(masked_lm_loss.device)
        else:
            nsp_rows = task_idx == 0
            next_sentence_loss = self.crit_next_sent(
                seq_relationship_score[nsp_rows], next_sentence_label[nsp_rows])
            # Average over task-0 rows. The +1e-5 keeps the denominator non-zero,
            # so no "empty" special case is needed.
            next_sentence_mask = nsp_rows.type_as(next_sentence_loss)
            denominator = torch.sum(next_sentence_mask).type_as(sequence_output) + 1e-5
            next_sentence_loss = (next_sentence_loss / denominator).sum()

        # The NSP term is up-weighted by a fixed factor of 10 in the total loss.
        loss = {'loss': masked_lm_loss + next_sentence_loss * 10 + lm_loss + seq2seq_loss,
                'masked_lm_loss': masked_lm_loss,
                'next_sentence_loss': next_sentence_loss,
                'lm_loss': lm_loss,
                'seq2seq_loss': seq2seq_loss,
                'batch_size': input_ids.size(0),
                'mlm_acc': mlm_acc,
                'lm_acc': lm_acc,
                'seq2seq_acc': seq2seq_acc,
                'logits': logits,
                'labels': masked_lm_labels}
        return loss
class BertForExtractiveSummarization(PreTrainedBertModel):
    """Score selected token positions with a two-way pretraining head.

    Layout mirrors BertForPreTraining: a `BertModel` encoder followed by a
    `BertPreTrainingHeads` (`cls2`) that is constructed from the weight of a
    2-row embedding (`secondary_pred_proj`) with `num_labels=2`.
    NOTE(review): presumably that weight acts as the head's output projection,
    giving two scores per gathered position — confirm in BertPreTrainingHeads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.secondary_pred_proj = nn.Embedding(2, config.hidden_size)
        self.cls2 = BertPreTrainingHeads(
            config, self.secondary_pred_proj.weight, num_labels=2)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_pos_2=None, masked_weights_2=None, task_idx=None, mask_qkv=None):
        """Return softmax probabilities for the positions listed in `masked_pos_2`.

        Args:
            input_ids, token_type_ids, attention_mask, task_idx, mask_qkv:
                forwarded to the encoder (`task_idx` also to the head).
            masked_pos_2: index tensor of shape (batch, num_positions) selecting
                which encoder outputs to score along the sequence axis.
            masked_weights_2: returned unchanged.

        Returns:
            Tuple `(probs, masked_pos_2, masked_weights_2)`, where `probs` is the
            softmax over the last dimension of `cls2`'s scores for the gathered
            positions.
        """
        encoded, _ = self.bert(input_ids, token_type_ids, attention_mask,
                               output_all_encoded_layers=False,
                               mask_qkv=mask_qkv, task_idx=task_idx)
        # Broadcast the position indices over the hidden dimension and pick the
        # corresponding hidden vectors out of the sequence axis (dim 1).
        hidden_size = encoded.size(-1)
        index = masked_pos_2.unsqueeze(2).expand(-1, -1, hidden_size)
        picked = encoded.gather(1, index)
        scores, _ = self.cls2(picked, None, task_idx=task_idx)
        return torch.softmax(scores, dim=-1), masked_pos_2, masked_weights_2
class BertForSeq2SeqDecoder(PreTrainedBertModel):
    """Incremental BERT decoder for sequence-to-sequence / left-to-right generation.

    Wraps `BertModelIncr`, which is called with the cached `prev_embedding` /
    `prev_encoded_layers` of earlier positions. It also wraps a
    `BertPreTrainingHeads` built from the word-embedding weight matrix.

    Decoding entry points:
      * `get_prediction_scores`: a single step. It appends `mask_word_id` to the
        inputs and returns log-probabilities for that position (batch size 1 only).
      * `forward`: dispatches to `get_prediction_scores` unless `rule == 3`. With
        `rule == 3` it runs greedy decoding (`search_beam_size == 1`) or
        `beam_search` (`search_beam_size > 1`).

    Params:
        config: a BertConfig instance.
        mask_word_id: token id of the placeholder whose hidden state is scored.
        num_labels: forwarded to `BertPreTrainingHeads`.
        num_rel: when > 0, a `BertPreTrainingPairRel` head is also built.
        search_beam_size: beam width used by `forward` when `rule == 3`.
        length_penalty: exponent of the ((5 + length) / 6) normaliser applied to
            candidate final scores in beam search; disabled when <= 0.
        eos_id / sos_id: end-of-sequence / start-of-sequence token ids.
        forbid_duplicate_ngrams: in beam search, block tokens that would repeat an
            n-gram of size `ngram_size`. Tokens in `forbid_ignore_set` are exempt.
        not_predict_set: token ids whose scores are always pushed to -100000.0.
        min_len: in beam search, EOS is blocked for the first `min_len`
            generated positions.
        mode: must be "s2s" or "l2r". It is stored but not read by this class's
            methods.
        pos_shift: when True, the previous prediction (initially `sos_id`) is fed
            at the next position instead of appending `mask_word_id`.
        use_rule: stored but not read by this class's methods.
        rule: selects the decoding path in `forward`.
    """

    def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0,
                 search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0,
                 forbid_duplicate_ngrams=False, forbid_ignore_set=None, not_predict_set=None,
                 ngram_size=3, min_len=0, mode="s2s", pos_shift=False, use_rule=False, rule=0):
        super(BertForSeq2SeqDecoder, self).__init__(config)
        self.bert = BertModelIncr(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels)
        self.apply(self.init_bert_weights)
        self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
        self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1)
        self.mask_word_id = mask_word_id
        self.num_labels = num_labels
        self.num_rel = num_rel
        if self.num_rel > 0:
            self.crit_pair_rel = BertPreTrainingPairRel(
                config, num_rel=num_rel)
        self.search_beam_size = search_beam_size
        self.length_penalty = length_penalty
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.forbid_duplicate_ngrams = forbid_duplicate_ngrams
        self.forbid_ignore_set = forbid_ignore_set
        self.not_predict_set = not_predict_set
        self.ngram_size = ngram_size
        self.min_len = min_len
        assert mode in ("s2s", "l2r")
        self.mode = mode
        self.pos_shift = pos_shift
        self.use_rule = use_rule
        self.rule = rule

    def _mask_disallowed_tokens(self, scores, decode_mask):
        """Push scores of disallowed vocabulary entries down by 100000.

        Entries in `not_predict_set` are overwritten in place with -100000.0.
        Then `decode_mask.unsqueeze(1) * -100000.0` is added, so non-zero mask
        entries are suppressed. `scores` is indexed as (rows, step, vocab).
        NOTE(review): decode_mask is assumed to be (rows, vocab) — confirm with callers.
        """
        if self.not_predict_set:
            for token_id in self.not_predict_set:
                scores[:, :, token_id].fill_(-100000.0)
        return scores + decode_mask.unsqueeze(1) * -100000.0

    def get_prediction_scores(self, input_ids, token_type_ids, position_ids, attention_mask, decode_mask, task_idx=None, mask_qkv=None):
        """Run a single decoding step and return next-token log-probabilities.

        `mask_word_id` is appended to `input_ids`, the whole sequence is encoded
        without any cache, and the hidden state of that last position is scored.

        Args:
            input_ids: tensor of shape (1, seq_len). Only batch size 1 is supported.
            token_type_ids, position_ids: sliced to their first `seq_len + 1` columns.
            attention_mask: sliced to `[:, :seq_len + 1, :seq_len + 1]`.
            decode_mask: suppression mask added via `_mask_disallowed_tokens`.
            task_idx: forwarded to the prediction head.
            mask_qkv: forwarded to the encoder.

        Returns:
            Log-softmax scores for the appended position, shape (1, 1, V).
        """
        assert input_ids.size(0) == 1, 'get_prediction_scores only supports batch size 1'
        mask_ids = input_ids.new_full((1, 1), self.mask_word_id)
        input_ids = torch.cat((input_ids, mask_ids), dim=-1)
        end = input_ids.size(1)  # original length + 1 (the appended mask slot)
        curr_token_type_ids = token_type_ids[:, :end]
        curr_attention_mask = attention_mask[:, :end, :end]
        curr_position_ids = position_ids[:, :end]
        _, new_encoded_layers, _ = self.bert(
            input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
            output_all_encoded_layers=True, prev_embedding=None,
            prev_encoded_layers=None, mask_qkv=mask_qkv)
        last_hidden = new_encoded_layers[-1][:, -1:, :]
        prediction_scores, _ = self.cls(last_hidden, None, task_idx=task_idx)
        prediction_scores = self._mask_disallowed_tokens(prediction_scores, decode_mask)
        return torch.nn.functional.log_softmax(prediction_scores, dim=-1)

    def forward(self, input_ids, token_type_ids, position_ids, attention_mask, decode_mask, task_idx=None, mask_qkv=None):
        """Decode starting from `input_ids`.

        When `self.rule != 3`, this returns `get_prediction_scores(...)` (a single step).
        Otherwise positions `input_ids.size(1)` .. `token_type_ids.size(1) - 1` are
        decoded, with `beam_search` when `search_beam_size > 1` and greedily otherwise.

        Args:
            input_ids: tensor of shape (batch, input_length) with the prefix tokens.
            token_type_ids, position_ids: cover input and output positions; their
                second dimension bounds the total decoded length.
            attention_mask: indexed as `[:, query_pos, key_pos]` and sliced per step.
            decode_mask: suppression mask, see `_mask_disallowed_tokens`.
            task_idx: forwarded to the prediction head.
            mask_qkv: forwarded to the encoder.

        Returns:
            Greedy path: tensor of shape (batch, token_type_ids.size(1) - input_length)
            holding the argmax token ids. Beam path: see `beam_search`.
        """
        if self.rule != 3:
            # Pass the caller's task_idx / mask_qkv through instead of dropping them.
            return self.get_prediction_scores(
                input_ids, token_type_ids, position_ids, attention_mask, decode_mask,
                task_idx=task_idx, mask_qkv=mask_qkv)
        if self.search_beam_size > 1:
            return self.beam_search(
                input_ids, token_type_ids, position_ids, attention_mask, decode_mask,
                task_idx=task_idx, mask_qkv=mask_qkv)

        input_shape = list(input_ids.size())
        batch_size = input_shape[0]
        input_length = input_shape[1]
        output_length = token_type_ids.size(1)
        output_ids = []
        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id)
        next_pos = input_length
        if self.pos_shift:
            sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id)

        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            if self.pos_shift:
                # Feed the previous prediction (or SOS on the first step) directly.
                if next_pos == input_length:
                    x_input_ids = torch.cat((curr_ids, sos_ids), dim=1)
                    start_pos = 0
                else:
                    x_input_ids = curr_ids
                    start_pos = next_pos
            else:
                # Append a [MASK]-style placeholder whose hidden state is scored.
                start_pos = next_pos - curr_length
                x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)
            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = self.bert(
                x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                output_all_encoded_layers=True, prev_embedding=prev_embedding,
                prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv)
            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores, _ = self.cls(last_hidden, None, task_idx=task_idx)
            prediction_scores = self._mask_disallowed_tokens(prediction_scores, decode_mask)
            _, max_ids = torch.max(prediction_scores, dim=-1)
            output_ids.append(max_ids)

            # Extend the incremental cache. Without pos_shift, the trailing placeholder
            # position is dropped because the real token replaces it next step.
            if self.pos_shift:
                if prev_embedding is None:
                    prev_embedding = new_embedding
                else:
                    prev_embedding = torch.cat((prev_embedding, new_embedding), dim=1)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [x for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip(
                        prev_encoded_layers, new_encoded_layers)]
            else:
                if prev_embedding is None:
                    prev_embedding = new_embedding[:, :-1, :]
                else:
                    prev_embedding = torch.cat(
                        (prev_embedding, new_embedding[:, :-1, :]), dim=1)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [x[:, :-1, :] for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1)
                                           for x in zip(prev_encoded_layers, new_encoded_layers)]
            curr_ids = max_ids
            next_pos += 1
        return torch.cat(output_ids, dim=1)

    def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, decode_mask, task_idx=None, mask_qkv=None):
        """Beam-search decoding with beam width K = `self.search_beam_size`.

        Positions `input_ids.size(1)` .. `token_type_ids.size(1) - 1` are decoded.
        Before the top-K selection, candidate log-probabilities are restricted by:
          * the n-gram blocking mask from the previous step (`forbid_duplicate_ngrams`);
          * EOS blocking for the first `min_len` generated positions;
          * `not_predict_set` and `decode_mask`;
          * a mask forbidding any token already generated on the same beam.
        Continuations of a beam whose last token is `eos_id` get an extra
        -100000.0 penalty.

        Returns:
            dict with keys 'pred_seq', 'scores', 'wids', 'ptrs'. Each value is a
            tensor with one row per batch item, zero-padded along dim 1 to
            `token_type_ids.size(1)` and placed on `input_ids.device`.
              * 'pred_seq': the best token sequence, ending at the chosen EOS or
                at the last step that is still valid.
              * 'scores': per-step cumulative scores of the K beams.
              * 'wids': per-step token ids of the K beams.
              * 'ptrs': per-step back-pointers of the K beams.
        """
        input_shape = list(input_ids.size())
        batch_size = input_shape[0]
        input_length = input_shape[1]
        output_length = token_type_ids.size(1)
        vocab_len = decode_mask.size(-1)
        K = self.search_beam_size

        def first_expand(x):
            # (B, ...) -> (B * K, ...): each batch row repeated K times consecutively.
            shape = list(x.size())
            x = torch.reshape(x, shape[:1] + [1] + shape[1:])
            x = x.repeat(*([1, K] + [1] * (len(shape) - 1)))
            return torch.reshape(x, [shape[0] * K] + shape[1:])

        def select_beam_items(x, ids):
            # x: (B * K, ...); ids: (B, K) beam indices in [0, K). Returns x with the
            # beams of each batch item gathered according to ids.
            id_shape = list(ids.size())
            id_rank = len(id_shape)
            assert len(id_shape) == 2
            x_shape = list(x.size())
            x = torch.reshape(x, [batch_size, K] + x_shape[1:])
            x_rank = len(x_shape) + 1
            assert x_rank >= 2
            if id_rank < x_rank:
                ids = torch.reshape(ids, id_shape + [1] * (x_rank - id_rank))
                ids = ids.expand(id_shape + x_shape[1:])
            y = torch.gather(x, 1, ids)
            return torch.reshape(y, x_shape)

        def get_dup_ngram_candidates(seq, n):
            # Sorted tokens that, appended to seq, would repeat an n-gram already in
            # seq. Tokens in forbid_ignore_set are never blocked, and nothing is
            # blocked when the (n-1)-token tail contains one of them.
            cands = set()
            if len(seq) < n:
                return []
            tail = seq[-(n - 1):]
            if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail):
                return []
            for i in range(len(seq) - (n - 1)):
                mismatch = False
                for j in range(n - 1):
                    if tail[j] != seq[i + j]:
                        mismatch = True
                        break
                if (not mismatch) and not (self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)):
                    cands.add(seq[i + n - 1])
            return list(sorted(cands))

        prev_embedding = None
        prev_encoded_layers = None
        curr_ids = input_ids
        mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id)
        next_pos = input_length
        if self.pos_shift:
            sos_ids = input_ids.new(batch_size, 1).fill_(self.sos_id)
        total_scores = []
        beam_masks = []
        step_ids = []
        step_back_ptrs = []
        # partial_seqs[b * K + k]: tokens generated so far on beam k of batch item b.
        # Used both for the "no repeated token" mask and for n-gram blocking.
        partial_seqs = []
        forbid_word_mask = None
        buf_matrix = None

        while next_pos < output_length:
            curr_length = curr_ids.size(1)
            is_first = (prev_embedding is None)
            if self.pos_shift:
                if next_pos == input_length:
                    x_input_ids = torch.cat((curr_ids, sos_ids), dim=1)
                    start_pos = 0
                else:
                    x_input_ids = curr_ids
                    start_pos = next_pos
            else:
                start_pos = next_pos - curr_length
                x_input_ids = torch.cat((curr_ids, mask_ids), dim=1)
            curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1]
            curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1]
            curr_position_ids = position_ids[:, start_pos:next_pos + 1]
            new_embedding, new_encoded_layers, _ = self.bert(
                x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask,
                output_all_encoded_layers=True, prev_embedding=prev_embedding,
                prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv)
            last_hidden = new_encoded_layers[-1][:, -1:, :]
            prediction_scores, _ = self.cls(last_hidden, None, task_idx=task_idx)
            log_scores = torch.nn.functional.log_softmax(prediction_scores, dim=-1)
            if forbid_word_mask is not None:
                log_scores += (forbid_word_mask * -100000.0)
            if self.min_len and (next_pos - input_length + 1 <= self.min_len):
                log_scores[:, :, self.eos_id].fill_(-100000.0)
            log_scores = self._mask_disallowed_tokens(log_scores, decode_mask)
            if not is_first:
                # Forbid every token already generated on the same beam.
                log_scores = log_scores.squeeze(dim=1)
                mask_left = np.zeros((batch_size * K, vocab_len), dtype=float)
                for row, seq in enumerate(partial_seqs):
                    for tok in seq:
                        mask_left[row][tok] = -100000.0
                mask_left = torch.tensor(mask_left).type_as(log_scores)
                log_scores = (log_scores + mask_left).unsqueeze(dim=1)

            kk_scores, kk_ids = torch.topk(log_scores, k=K)
            if len(total_scores) == 0:
                k_ids = torch.reshape(kk_ids, [batch_size, K])
                back_ptrs = torch.zeros(batch_size, K, dtype=torch.long, device=kk_ids.device)
                k_scores = torch.reshape(kk_scores, [batch_size, K])
            else:
                last_eos = torch.reshape(beam_masks[-1], [batch_size * K, 1, 1])
                last_seq_scores = torch.reshape(total_scores[-1], [batch_size * K, 1, 1])
                kk_scores += last_eos * (-100000.0) + last_seq_scores
                kk_scores = torch.reshape(kk_scores, [batch_size, K * K])
                k_scores, k_ids = torch.topk(kk_scores, k=K)
                # Integer division: flat index in [0, K*K) -> source beam in [0, K).
                # (torch.div would return floats on current PyTorch and break gather.)
                back_ptrs = k_ids // K
                kk_ids = torch.reshape(kk_ids, [batch_size, K * K])
                k_ids = torch.gather(kk_ids, 1, k_ids)
            step_back_ptrs.append(back_ptrs)
            step_ids.append(k_ids)
            beam_masks.append(torch.eq(k_ids, self.eos_id).float())
            total_scores.append(k_scores)

            # Rebuild each beam's token history following the back-pointers.
            wids = k_ids.tolist()
            ptrs = back_ptrs.tolist()
            if is_first:
                partial_seqs = [[wids[b][k]] for b in range(batch_size) for k in range(K)]
            else:
                partial_seqs = [partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]
                                for b in range(batch_size) for k in range(K)]

            # Extend the incremental cache and reorder it to follow the surviving beams.
            if self.pos_shift:
                if prev_embedding is None:
                    prev_embedding = first_expand(new_embedding)
                else:
                    prev_embedding = torch.cat((prev_embedding, new_embedding), dim=1)
                    prev_embedding = select_beam_items(prev_embedding, back_ptrs)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [first_expand(x) for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip(
                        prev_encoded_layers, new_encoded_layers)]
                    prev_encoded_layers = [select_beam_items(x, back_ptrs)
                                           for x in prev_encoded_layers]
            else:
                if prev_embedding is None:
                    prev_embedding = first_expand(new_embedding[:, :-1, :])
                else:
                    prev_embedding = torch.cat(
                        (prev_embedding, new_embedding[:, :-1, :]), dim=1)
                    prev_embedding = select_beam_items(prev_embedding, back_ptrs)
                if prev_encoded_layers is None:
                    prev_encoded_layers = [first_expand(x[:, :-1, :])
                                           for x in new_encoded_layers]
                else:
                    prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1)
                                           for x in zip(prev_encoded_layers, new_encoded_layers)]
                    prev_encoded_layers = [select_beam_items(x, back_ptrs)
                                           for x in prev_encoded_layers]
            curr_ids = torch.reshape(k_ids, [batch_size * K, 1])

            if is_first:
                # From now on every batch row is replicated once per beam.
                token_type_ids = first_expand(token_type_ids)
                decode_mask = first_expand(decode_mask)
                position_ids = first_expand(position_ids)
                attention_mask = first_expand(attention_mask)
                mask_ids = first_expand(mask_ids)
                if mask_qkv is not None:
                    mask_qkv = first_expand(mask_qkv)

            if self.forbid_duplicate_ngrams and len(partial_seqs[0]) >= self.ngram_size:
                dup_cands = [get_dup_ngram_candidates(seq, self.ngram_size)
                             for seq in partial_seqs]
                if max(len(x) for x in dup_cands) > 0:
                    if buf_matrix is None:
                        buf_matrix = np.zeros(
                            (batch_size * K, log_scores.size(-1)), dtype=float)
                    else:
                        buf_matrix.fill(0)
                    for bk, cands in enumerate(dup_cands):
                        for wid in cands:
                            buf_matrix[bk, wid] = 1.0
                    # (B*K, 1, V), created directly on the scores' device (no hard .cuda()).
                    forbid_word_mask = torch.tensor(
                        buf_matrix, dtype=log_scores.dtype,
                        device=log_scores.device).unsqueeze(1)
                else:
                    forbid_word_mask = None
            next_pos += 1

        # [(batch, beam)] per step -> plain Python lists for back tracking.
        total_scores = [x.tolist() for x in total_scores]
        step_ids = [x.tolist() for x in step_ids]
        step_back_ptrs = [x.tolist() for x in step_back_ptrs]
        traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []}
        for b in range(batch_size):
            scores = [x[b] for x in total_scores]
            wids_list = [x[b] for x in step_ids]
            ptrs = [x[b] for x in step_back_ptrs]
            traces['scores'].append(scores)
            traces['wids'].append(wids_list)
            traces['ptrs'].append(ptrs)
            # The first frame where every beam emitted EOS is the last valid frame;
            # later frames are ignored.
            last_frame_id = len(scores) - 1
            for i, wids in enumerate(wids_list):
                if all(wid == self.eos_id for wid in wids):
                    last_frame_id = i
                    break
            # Pick the best beam that ends in EOS (or any beam on the last frame),
            # applying the length penalty.
            max_score = -math.inf
            frame_id = -1
            pos_in_frame = -1
            for fid in range(last_frame_id + 1):
                for i, wid in enumerate(wids_list[fid]):
                    if wid == self.eos_id or fid == last_frame_id:
                        s = scores[fid][i]
                        if self.length_penalty > 0:
                            s /= math.pow((5 + fid + 1) / 6.0, self.length_penalty)
                        if s > max_score:
                            max_score = s
                            frame_id = fid
                            pos_in_frame = i
            if frame_id == -1:
                traces['pred_seq'].append([0])
            else:
                seq = [wids_list[frame_id][pos_in_frame]]
                for fid in range(frame_id, 0, -1):
                    pos_in_frame = ptrs[fid][pos_in_frame]
                    seq.append(wids_list[fid - 1][pos_in_frame])
                seq.reverse()
                traces['pred_seq'].append(seq)

        def _pad_sequence(sequences, max_len, padding_value=0):
            # Stack variable-length tensors into (len(sequences), max_len, ...).
            trailing_dims = sequences[0].size()[1:]
            out_dims = (len(sequences), max_len) + trailing_dims
            out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
            for i, tensor in enumerate(sequences):
                length = tensor.size(0)
                # use index notation to prevent duplicate references to the tensor
                out_tensor[i, :length, ...] = tensor
            return out_tensor

        # Convert to tensors so DataParallel can gather the results.
        for k in ('pred_seq', 'scores', 'wids', 'ptrs'):
            ts_list = traces[k]
            if not isinstance(ts_list[0], torch.Tensor):
                dt = torch.float if k == 'scores' else torch.long
                ts_list = [torch.tensor(it, dtype=dt) for it in ts_list]
            traces[k] = _pad_sequence(
                ts_list, output_length, padding_value=0).to(input_ids.device)
        return traces
class BertForMaskedLM(PreTrainedBertModel):
    """BERT model with the masked language modeling head.

    This module comprises the BERT model followed by the masked language modeling head.

    Params:
        config: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size - 1]. All labels set to -1 are ignored (masked); the loss
            is only computed for the labels set in [0, ..., vocab_size - 1].
        `mask_qkv`, `task_idx`: optional; forwarded unchanged to the underlying `BertModel`.

    Outputs:
        if `masked_lm_labels` is not `None`:
            Outputs the masked language modeling loss (cross-entropy averaged over the non-ignored positions).
        if `masked_lm_labels` is `None`:
            Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForMaskedLM(config)
    masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
        self.bert = BertModel(config)
        self.lm_head = MLMHead(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, mask_qkv=None, task_idx=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                       output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        prediction_scores = self.lm_head(sequence_output)
        if masked_lm_labels is not None:
            # -1 marks positions that do not contribute to the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            return masked_lm_loss
        else:
            return prediction_scores
class BertForNextSentencePrediction(PreTrainedBertModel):
    """BERT model with next sentence prediction head.

    This module comprises the BERT model followed by the next sentence classification head.

    Params:
        config: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `next_sentence_label`: next sentence classification labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1]; entries equal to -1 are ignored by the loss.
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
        `mask_qkv`, `task_idx`: optional; forwarded unchanged to the underlying `BertModel`.

    Outputs:
        if `next_sentence_label` is not `None`:
            Outputs the next sentence classification (cross-entropy) loss.
        if `next_sentence_label` is `None`:
            Outputs the next sentence classification logits of shape [batch_size, 2].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, mask_qkv=None, task_idx=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        seq_relationship_score = self.cls(pooled_output)
        if next_sentence_label is not None:
            # -1 labels are excluded from the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(
                seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            return next_sentence_loss
        else:
            return seq_relationship_score
class BertForSequenceClassification(PreTrainedBertModel):
    """BERT model for classification (or regression).

    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the output, tensor of shape [batch_size]:
            - torch.long: class indices in [0, ..., num_labels - 1]; CrossEntropy loss is used.
            - any floating dtype: regression targets; MSE loss against the flattened logits
              is used (intended for `num_labels == 1`).
        `mask_qkv`, `task_idx`: optional; forwarded unchanged to the underlying `BertModel`.

    Outputs:
        if `labels` is not `None`:
            Outputs the loss described above (or `None`, with a warning, for any other label dtype).
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, num_labels=2):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None):
        _, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if labels is None:
            return logits
        if labels.dtype == torch.long:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif labels.is_floating_point():
            # Covers float16/32/64 and bfloat16. Match the logits' dtype so
            # mse_loss does not see mixed dtypes.
            loss_fct = MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1).type_as(logits))
        else:
            import warnings
            warnings.warn('unknown labels.dtype: {}'.format(labels.dtype))
            loss = None
        return loss
class BertForMultipleChoice(PreTrainedBertModel):
    """BERT model for multiple choice tasks.

    This module is composed of the BERT model with a linear layer on top of
    the pooled output. Every (example, choice) pair is encoded independently and
    scored with a single logit; the logits are then regrouped per example.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_choices`: the number of choices per example. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
            and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].
        `mask_qkv`, `task_idx`: optional; forwarded unchanged to the underlying `BertModel`.

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_choices].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_choices = 2
    model = BertForMultipleChoice(config, num_choices)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, num_choices=2):
        super(BertForMultipleChoice, self).__init__(config)
        self.num_choices = num_choices
        self.bert = BertModel(config)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None):
        # Fold the choice axis into the batch axis: (B, C, L) -> (B * C, L).
        # The optional inputs stay None instead of crashing on `.view`.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = (token_type_ids.view(-1, token_type_ids.size(-1))
                               if token_type_ids is not None else None)
        flat_attention_mask = (attention_mask.view(-1, attention_mask.size(-1))
                               if attention_mask is not None else None)
        _, pooled_output = self.bert(
            flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, self.num_choices)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            return loss
        else:
            return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel):
    """BERT model for token-level classification.

    This module is composed of the BERT model with a linear layer on top of
    the full hidden state of the last layer.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences. When given, only positions where it equals 1
            contribute to the loss.
        `labels`: per-token labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, ..., num_labels - 1].
        `mask_qkv`, `task_idx`: optional; forwarded unchanged to the underlying `BertModel`.

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForTokenClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, num_labels=2):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, mask_qkv=None, task_idx=None):
        sequence_output, _ = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, mask_qkv=mask_qkv, task_idx=task_idx)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active (non-padding) parts of the loss.
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
class BertForQuestionAnswering(PreTrainedBertModel):
    """BERT model for Question Answering (span extraction).

    This module is composed of the BERT model with a linear layer on top of
    the sequence output that computes start_logits and end_logits.

    Params:
        `config`: either
            - a BertConfig class instance with the configuration to build a new model, or
            - a str with the name of a pre-trained model to load selected in the list of:
                . `bert-base-uncased`
                . `bert-large-uncased`
                . `bert-base-cased`
                . `bert-large-cased`
                . `bert-base-multilingual-uncased`
                . `bert-base-multilingual-cased`
                . `bert-base-chinese`
            The pre-trained model will be downloaded and cached if needed.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
            Positions are clamped to [0, sequence_length]; a position equal to sequence_length (i.e. outside
            the sequence) is ignored by the loss. The caller's tensor is not modified.
        `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
            Same clamping/ignoring rules as `start_positions`; the caller's tensor is not modified.
        `task_idx`: optional value forwarded unchanged to `BertModel`.

    Outputs:
        if `start_positions` and `end_positions` are not `None`:
            Outputs the total_loss, which is the mean of the CrossEntropy losses for the start and end
            token positions.
        if `start_positions` or `end_positions` is `None`:
            Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
            position tokens of shape [batch_size, sequence_length].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForQuestionAnswering(config)
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.bert = BertModel(config)
        # NOTE: no dropout is applied to the sequence output before `qa_outputs`.
        # self.dropout = StableDropout(config.hidden_dropout_prob)
        # Two outputs per token: start logit and end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, task_idx=None):
        """Return (start_logits, end_logits), or the mean span loss when both positions are given."""
        sequence_output, _ = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, task_idx=task_idx)
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        if start_positions is None or end_positions is None:
            return start_logits, end_logits

        # Drop an extra trailing dimension (e.g. [batch_size, 1]); the original
        # code attributes this to splitting across multiple GPUs.
        if start_positions.dim() > 1:
            start_positions = start_positions.squeeze(-1)
        if end_positions.dim() > 1:
            end_positions = end_positions.squeeze(-1)

        # Positions outside the model inputs are clamped to sequence_length,
        # which the loss then ignores. Clamp out of place: the previous
        # in-place clamp_ silently rewrote the caller's label tensors.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        return (start_loss + end_loss) / 2
class EnhancedMaskDecoder(torch.nn.Module):
    """Masked-LM prediction head with an optional Enhanced Mask Decoder (EMD) pass.

    If `config.position_biased_input` is False (the attribute defaults to True
    when absent), the encoder's last layer is applied twice more. It runs on
    the second-to-last hidden states, starting from the query
    `z_states + hidden_states`. Otherwise the last encoder layer output is
    used as-is. Only positions whose target id is > 0 are scored.
    """
    def __init__(self, config, vocab_size):
        super().__init__()
        self.config = config
        self.position_biased_input = getattr(config, 'position_biased_input', True)
        self.lm_head = BertLMPredictionHead(config, vocab_size)

    def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None):
        """Score the target positions and return their unreduced cross-entropy.

        Args:
            ctx_layers: sequence of encoder layer outputs (the last two are used).
            ebd_weight: passed to `self.lm_head` together with the hidden states.
            target_ids: target token ids; positions with id > 0 are scored.
            input_ids, input_mask: forwarded to `emd_context_layer` (unused there).
            z_states, attention_mask, encoder, relative_pos: see `emd_context_layer`.

        Returns:
            (lm_logits, lm_labels, lm_loss): float32 logits flattened to 2-D, one
            row per selected position (assuming `lm_head` maps rows independently);
            the target ids at those positions; and the per-position loss
            (reduction='none').
        """
        mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
        ctx_layer = mlm_ctx_layers[-1]
        lm_logits = self.lm_head(ctx_layer, ebd_weight).float()
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        flat_targets = target_ids.view(-1)
        label_index = (flat_targets > 0).nonzero().view(-1)
        lm_labels = flat_targets.index_select(0, label_index)
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        lm_loss = loss_fct(lm_logits, lm_labels.long())
        return lm_logits, lm_labels, lm_loss

    def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=None):
        """Return the hidden states at target positions, flattened to 2-D.

        Args:
            encoder_layers: sequence of encoder layer outputs.
            z_states: added (out of place) to `encoder_layers[-2]` to form the
                first EMD query; not modified.
            attention_mask: mask with at most 2 dims (expanded to a pairwise
                mask), 3 dims (a head axis is inserted), or already expanded.
            encoder: object exposing `layer` (indexable) and `get_rel_embedding()`.
            target_ids: positions with id > 0 are kept.
            input_ids, input_mask: unused.
            relative_pos: forwarded to the re-applied layer.

        Returns:
            A list with one 2-D tensor per produced state (two for the EMD
            path, one otherwise), each holding only the target positions.
        """
        if attention_mask.dim() <= 2:
            # For a [batch, seq] mask this yields [batch, 1, seq, seq]:
            # entry (i, j) is 1 only when both positions are unmasked.
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            att_mask = extended_attention_mask.byte()
            attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)
        hidden_states = encoder_layers[-2]
        if not self.position_biased_input:
            # The same (last) encoder layer is applied twice.
            layers = [encoder.layer[-1] for _ in range(2)]
            # Out-of-place add: the previous `z_states += ...` modified the
            # caller's tensor and fails on leaf tensors that require grad.
            query_states = z_states + hidden_states
            query_mask = attention_mask
            outputs = []
            rel_embeddings = encoder.get_rel_embedding()
            for layer in layers:
                # TODO: pass relative pos ids
                query_states = layer(hidden_states, query_mask, return_att=False, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
                outputs.append(query_states)
        else:
            outputs = [encoder_layers[-1]]

        mask_index = (target_ids > 0).view(-1).nonzero().view(-1)
        return [q.view(-1, q.size(-1)).index_select(0, mask_index) for q in outputs]