import copy
import torch
import os
import random
import json

from .ops import *
from .bert import *
from .bert import BertLayer
from .config import ModelConfig
from .cache_utils import load_model_state
from .nnmodule import NNModule

__all__ = ['WywLM']


def flatten_states(q_states, mask_index):
    """Flatten states to [batch * seq_len, hidden] and keep only the rows selected by mask_index."""
    q_states = q_states.reshape((-1, q_states.size(-1)))
    q_states = q_states.index_select(0, mask_index)
    return q_states
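
# Illustrative shapes (not executed): if q_states is [2, 4, 768] and
# mask_index = torch.tensor([1, 5]), the states are flattened to [8, 768] and
# rows 1 and 5 are returned, i.e. a [2, 768] tensor aligned with the masked labels.
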
class UGDecoder(torch.nn.Module):
    """Decoder head that produces a masked-LM branch and a causal (left-to-right)
    branch on top of the encoder outputs."""

    def __init__(self, config, vocab_size):
        super().__init__()
        self.config = config
        self.position_biased_input = getattr(config, 'position_biased_input', True)

    def forward(self, ctx_layers, word_embedding, input_ids, z_states, attention_mask,
                encoder, target_ids=None, relative_pos=None, decode=False, s2s_idx=None):
        causal_outputs, lm_outputs = self.emd_context_layer(ctx_layers, z_states, attention_mask,
                                                            encoder, target_ids, input_ids,
                                                            relative_pos=relative_pos, decode=decode,
                                                            word_embedding=word_embedding, s2s_idx=s2s_idx)
        return causal_outputs[-1], lm_outputs[-1]

    def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids,
                          relative_pos=None, decode=False, word_embedding=None, s2s_idx=None):
        # Expand the attention mask to the 4D shape expected by the encoder layers.
        if attention_mask.dim() <= 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            att_mask = extended_attention_mask.byte()
            attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask.unsqueeze(1)

        if not self.position_biased_input:
            # Masked-LM branch: refine the absolute-position states with the
            # second-to-last hidden states, reusing the top encoder layer twice.
            lm_outputs = []
            hidden_states = encoder_layers[-2]
            layers = [encoder.layer[-1] for _ in range(2)]
            z_states += hidden_states
            query_states = z_states
            query_mask = attention_mask
            rel_embeddings = encoder.get_rel_embedding()
            for layer in layers:
                output = layer(hidden_states, query_mask, return_att=False,
                               query_states=query_states, relative_pos=relative_pos,
                               rel_embeddings=rel_embeddings)
                query_states = output
                lm_outputs.append(query_states)

            # Causal branch: embed the target ids and run them through the shared
            # layers under a lower-triangular (left-to-right) attention mask.
            attention_mask = torch.tril(
                torch.ones((input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])),
                diagonal=0).to(input_ids.device)
            causal_outputs = []
            target_embd = word_embedding(target_ids)
            target_embd += z_states.detach()
            output = layers[-2](target_embd, attention_mask, return_att=False,
                                query_states=target_embd, relative_pos=relative_pos,
                                rel_embeddings=encoder.get_rel_embedding())
            causal_outputs.append(output)
            output = layers[-1](output, attention_mask, return_att=False,
                                query_states=query_states, relative_pos=relative_pos,
                                rel_embeddings=encoder.get_rel_embedding())
            causal_outputs.append(output)
        else:
            # With position-biased input embeddings no extra decoding passes are run.
            causal_outputs = [encoder_layers[-1]]
            lm_outputs = [encoder_layers[-1]]
        return causal_outputs, lm_outputs
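
# Note: `emd_context_layer` above runs two extra passes that reuse the top encoder
# layer (weight sharing), presumably in the spirit of DeBERTa's enhanced mask
# decoder ("EMD"): one pass refines the absolute-position states into the MLM
# branch, the other re-embeds the target ids under a lower-triangular mask to form
# the causal branch.
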
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
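
# Example (illustrative):
#   >>> ids = torch.tensor([[5, 6, 7, -100]])
#   >>> shift_tokens_right(ids, pad_token_id=0, decoder_start_token_id=1)
#   tensor([[1, 5, 6, 7]])
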
class WywLMLoss(torch.nn.Module):
    """Combines the dictionary entry/description comparison loss, the masked-LM loss
    and the causal LM loss used for pre-training."""

    def __init__(self, config) -> None:
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        hidden_size = getattr(config, 'embedding_size', config.hidden_size)
        # Binary classifier over [entry, description, entry - description] features.
        self.compare = torch.nn.Linear(hidden_size * 3, 2)
        self.lm_head = BertLMPredictionHead(config, config.vocab_size)

    def forward(self, logits, lm_logits, target_ids, dict_pos, input_ids, target_ids_s2s, decode=False, ebd_weight=None, task=0):
        loss_compare = torch.tensor(0).to(logits).float()
        mlm_loss = torch.tensor(0).to(logits).float()
        lm_loss = torch.tensor(0).to(logits).float()

        if task == 1:
            # Build (entry, description) pairs as positives and (entry, description
            # of a different entry) pairs as negatives for the comparison head.
            compare_logits = []
            compare_labels = []
            for bi, sample_pos in enumerate(dict_pos):
                num_pos = int((sample_pos > 0).sum().detach().cpu().numpy() / 4) - 1
                if num_pos <= 1:
                    continue
                for pi in range(num_pos):
                    pos = sample_pos[pi]
                    entry_logits = logits[bi][pos[0]: pos[1]]
                    desc_logits = logits[bi][pos[2]: pos[3]]
                    # Sample a negative entry; retry if it matches the positive one.
                    neg_num = random.randint(0, num_pos)
                    ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
                    ids_pos = input_ids[bi][pos[0]: pos[1]]
                    if neg_num == pi or (ids_neg.shape == ids_pos.shape and torch.all(ids_neg == ids_pos)):
                        neg_num = -1
                        for ni in range(num_pos):
                            neg_num = random.randint(0, num_pos)
                            ids_neg = input_ids[bi][sample_pos[neg_num][0]: sample_pos[neg_num][1]]
                            if neg_num != pi and (ids_neg.shape != ids_pos.shape or not torch.all(ids_neg == ids_pos)):
                                break
                        else:
                            neg_num = -1
                    if neg_num == -1:
                        continue
                    neg_desc_logits = logits[bi][sample_pos[neg_num][2]: sample_pos[neg_num][3]]
                    if torch.any(torch.isnan(neg_desc_logits)):
                        print('error')
                    entry_logits = entry_logits.mean(dim=0, keepdim=True).float()
                    desc_logits = desc_logits.mean(dim=0, keepdim=True).float()
                    neg_desc_logits = neg_desc_logits.mean(dim=0, keepdim=True).float()
                    compare_logits.append(torch.concat([entry_logits, desc_logits, entry_logits - desc_logits], dim=1))
                    compare_logits.append(torch.concat([entry_logits, neg_desc_logits, entry_logits - neg_desc_logits], dim=1))
                    compare_labels += [1, 0]
            if len(compare_logits) > 0:
                compare_logits = torch.concat(compare_logits, dim=0).to(logits.dtype)
                compare_pred = self.compare(compare_logits)
                loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()

            if torch.all(loss_compare == 0):
                # No valid pairs in this batch: run a dummy pair through the compare
                # head so that its parameters still receive gradients.
                entry_logits = logits[0][0].unsqueeze(0)
                compare_logits = torch.concat([entry_logits, entry_logits, entry_logits - entry_logits], dim=1)
                compare_pred = self.compare(compare_logits)
                compare_labels = [1]
                loss_compare = self.loss_fn(compare_pred, torch.tensor(compare_labels, dtype=torch.long, device=compare_logits.device)).mean()

        if task == 0:
            # Causal LM loss over the sequence-to-sequence target positions.
            _mask_index = (target_ids_s2s > 0).view(-1).nonzero().view(-1)
            lm_logits_ = flatten_states(lm_logits, _mask_index)
            lm_pred = self.lm_head(lm_logits_, ebd_weight).float()
            lm_labels = target_ids_s2s.clone().reshape(-1)
            lm_labels = lm_labels.index_select(0, _mask_index)
            lm_loss = self.loss_fn(lm_pred, lm_labels.long())

        # Masked-LM loss over the masked positions.
        _mask_index = (target_ids > 0).view(-1).nonzero().view(-1)
        mlm_logits = flatten_states(logits, _mask_index)
        mlm_pred = self.lm_head(mlm_logits, ebd_weight).float()
        mlm_labels = target_ids.view(-1)
        mlm_labels = mlm_labels.index_select(0, _mask_index)
        mlm_loss = self.loss_fn(mlm_pred, mlm_labels.long())
        return loss_compare, mlm_loss, lm_loss
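
# The three losses returned above are combined by MaskedLanguageModel.forward
# below as `loss_compare * 10 + mlm_loss + lm_loss`.
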
class WywLM(torch.nn.Module):
    """ DeBERTa encoder
    This module is composed of an input embedding layer and stacked transformer layers with disentangled attention.

    Parameters:
        config:
            A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
            for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`

        pre_trained:
            The pre-trained DeBERTa model, it can be a physical path to a pre-trained DeBERTa model or one of the released configurations, \
            i.e. [**base, large, base_mnli, large_mnli**]

    """
    def __init__(self, config=None, pre_trained=None):
        super().__init__()
        state = None
        if pre_trained is not None:
            state, model_config = load_model_state(pre_trained)
            if config is not None and model_config is not None:
                # Keep the size-defining fields from the checkpoint config and copy
                # everything else over from the user-supplied config.
                for k in config.__dict__:
                    if k not in ['hidden_size',
                                 'intermediate_size',
                                 'num_attention_heads',
                                 'num_hidden_layers',
                                 'vocab_size',
                                 'max_position_embeddings']:
                        model_config.__dict__[k] = config.__dict__[k]
            config = copy.copy(model_config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.config = config
        self.pre_trained = pre_trained
        self.apply_state(state)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, output_all_encoded_layers=True, position_ids=None, return_att=False):
        """
        Args:
            input_ids:
                a torch.LongTensor of shape [batch_size, sequence_length] \
                with the word token indices in the vocabulary

            attention_mask:
                an optional parameter for input mask or attention mask.

                - If it's an input mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length] with indices \
                  selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
                  input sequence length in the current batch. It's the mask that we typically use for attention when \
                  a batch has varying length sentences.

                - If it's an attention mask, then it will be a torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
                  In this case, it's a mask indicating which tokens in the sequence should be attended to by other tokens in the sequence.

            token_type_ids:
                an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
                type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
                a `sentence B` token (see the BERT paper for more details).

            output_all_encoded_layers:
                whether to output the results of all encoder layers; default `True`

        Returns:

            - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
              the last layer of the stacked transformer layers

            - Attention matrix of the self-attention layers if `return_att=True`

        Example::

            # Batch of wordPiece token ids.
            # Each sample was padded with zeros to the maximum length of the batch.
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            # Mask of valid input ids
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

            # WywLM model initialized with a pre-trained base model
            model = WywLM(pre_trained='base')

            encoder_layers = model(input_ids, attention_mask=attention_mask)

        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
            token_mask = torch.ones_like(input_ids)
        else:
            # Mark every position up to and including the last type-1 token as valid.
            idxs = torch.flip(torch.cumsum(torch.flip(token_type_ids, [-1]), dim=1), [-1])
            token_mask = idxs > 0
            token_mask = token_mask.byte()
        ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, token_mask)
        embedding_output = ebd_output['embeddings']
        encoder_output = self.encoder(embedding_output,
                                      attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers, return_att=return_att)
        encoder_output.update(ebd_output)
        return encoder_output
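
    # Note: forward() above returns the dict produced by the encoder, updated with
    # the embedding outputs; the keys read elsewhere in this module are
    # 'hidden_states', 'position_embeddings' and 'embeddings'.
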
    def apply_state(self, state=None):
        """ Load state from a previously loaded model state dictionary.

        Args:
            state (:obj:`dict`, optional): State dictionary as returned by torch.nn.Module.state_dict(), default: `None`. \
                If it's `None`, then the pre-trained state loaded via the constructor will be used to re-initialize \
                the model.
        """
        if self.pre_trained is None and state is None:
            return
        if state is None:
            state, config = load_model_state(self.pre_trained)
            self.config = config

        # Detect an optional key prefix (e.g. added by a wrapping module) by
        # locating the first 'embeddings.' entry in the checkpoint.
        prefix = ''
        for k in state:
            if 'embeddings.' in k:
                if not k.startswith('embeddings.'):
                    prefix = k[:k.index('embeddings.')]
                break

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        self._load_from_state_dict(state, prefix=prefix, local_metadata=None, strict=True,
                                   missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)

class MaskedLanguageModel(NNModule):
    """ Masked language model
    """
    def __init__(self, config, *wargs, **kwargs):
        super().__init__(config)
        self.backbone = WywLM(config)

        self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
        self.position_buckets = getattr(config, 'position_buckets', -1)
        if self.max_relative_positions < 1:
            self.max_relative_positions = config.max_position_embeddings

        self.lm_predictions = UGDecoder(self.backbone.config, self.backbone.embeddings.word_embeddings.weight.size(0))
        self.device = None
        self.loss = WywLMLoss(config)

        self.apply(self.init_weights)

    def forward(self, samples, position_ids=None):
        task = samples['task']
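        # Task 0 appears to be the seq2seq-style objective (it reads the 's2s_*'
        # fields and yields the causal lm_loss); other task ids use the plain MLM
        # fields, with task 1 adding the entry/description comparison loss.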
        if task == 0:
            input_ids = samples['s2s_input_ids']
            type_ids = samples['s2s_token_type_ids']
            attention_mask = samples['s2s_attention_mask']
            labels = samples['s2s_masked_lm_labels']
            dict_pos = samples['dict_pos']
            s2s_label = samples['s2s_label']
        else:
            input_ids = samples['input_ids']
            type_ids = samples['token_type_ids']
            attention_mask = samples['attention_mask']
            labels = samples['masked_lm_labels']
            dict_pos = samples['dict_pos']
            s2s_label = samples['s2s_label']

        if self.device is None:
            self.device = list(self.parameters())[0].device

        input_ids = input_ids.to(self.device)

        type_ids = None  # token type ids are not used by the backbone here
        lm_labels = labels.to(self.device)
        s2s_label = s2s_label.to(self.device)
        attention_mask = attention_mask.to(self.device)

        encoder_output = self.backbone(input_ids, attention_mask, type_ids, output_all_encoded_layers=True, position_ids=position_ids)
        encoder_layers = encoder_output['hidden_states']
        z_states = encoder_output['position_embeddings']
        ctx_layer = encoder_layers[-1]
        mlm_loss = torch.tensor(0).to(ctx_layer).float()
        lm_loss = torch.tensor(0).to(ctx_layer).float()
        lm_logits = None
        label_inputs = None
        loss = torch.tensor(0).to(ctx_layer).float()
        loss_compare = torch.tensor(0).to(ctx_layer).float()

        ebd_weight = self.backbone.embeddings.word_embeddings.weight
        lm_logits, mlm_logits = self.lm_predictions(encoder_layers, self.backbone.embeddings.word_embeddings,
                                                    input_ids, z_states,
                                                    attention_mask, self.backbone.encoder,
                                                    target_ids=lm_labels)

        loss_compare, mlm_loss, lm_loss = self.loss(mlm_logits,
                                                    lm_logits,
                                                    lm_labels,
                                                    dict_pos,
                                                    target_ids_s2s=s2s_label,
                                                    decode=False,
                                                    ebd_weight=ebd_weight,
                                                    input_ids=input_ids,
                                                    task=task)
        loss = loss_compare * 10 + mlm_loss + lm_loss

        return {
            'logits': lm_logits,
            'labels': lm_labels,
            's2s_label': s2s_label,
            'loss': loss.float(),
            'loss_compare': loss_compare.float(),
            'lm_loss': lm_loss.float(),
            'mlm_loss': mlm_loss.float()
        }