|
|
|
|
|
""" |
|
@file : modeling_glycebert.py |
|
@author: zijun |
|
@contact : [email protected] |
|
@date : 2020/9/6 18:50 |
|
@version: 1.0 |
|
@desc  : ChineseBERT model (BERT fused with glyph and pinyin embeddings)
|
""" |
|
import json |
|
import os |
|
import shutil |
|
import time |
|
import warnings |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.file_download import http_user_agent |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss, MSELoss |
|
from torch.nn import functional as F |
|
|
|
try: |
|
from transformers.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, BertModel |
|
except ImportError:  # transformers >= 4 moved the BERT modules under transformers.models.bert
|
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertOnlyMLMHead, BertPreTrainedModel, \ |
|
BertModel |
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput, \ |
|
QuestionAnsweringModelOutput, TokenClassifierOutput |
|
|
|
cache_path = Path(os.path.abspath(__file__)).parent |
|
|
|
|
|
def download_file(filename: str, path: Path): |
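    """Make sure `filename` is available in the local cache directory.

    Prefer an already-cached copy; otherwise copy the file from the checkpoint
    directory `path`; as a last resort, download it from the
    iioSnail/ChineseBERT-base repository on the Hugging Face Hub.
    """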
|
if os.path.exists(cache_path / filename): |
|
return |
|
|
|
if os.path.exists(path / filename): |
|
shutil.copyfile(path / filename, cache_path / filename) |
|
return |
|
|
|
hf_hub_download( |
|
"iioSnail/ChineseBERT-base", |
|
filename, |
|
local_dir=cache_path, |
|
user_agent=http_user_agent(), |
|
) |
|
time.sleep(0.2) |
|
|
|
|
|
class GlyceBertModel(BertModel): |
|
r""" |
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: |
|
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` |
|
            Sequence of hidden-states at the output of the last layer of the model.
|
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` |
|
Last layer hidden-state of the first token of the sequence (classification token) |
|
further processed by a Linear layer and a Tanh activation function. The Linear |
|
layer weights are trained from the next sentence prediction (classification) |
|
            objective during BERT pretraining. This output is usually *not* a good summary
            of the semantic content of the input; you're often better off averaging or pooling
            the sequence of hidden-states over the whole input sequence.
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) |
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) |
|
of shape ``(batch_size, sequence_length, hidden_size)``: |
|
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``) |
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: |
|
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
|
|
Examples:: |
|
|
|
        # A minimal sketch: assumes `input_ids` of shape (batch, seq_len) and `pinyin_ids`
        # of shape (batch, seq_len, 8) were produced by the ChineseBERT tokenizer that
        # matches the 'iioSnail/ChineseBERT-base' checkpoint.
        model = GlyceBertModel.from_pretrained('iioSnail/ChineseBERT-base')
        outputs = model(input_ids, pinyin_ids)
        last_hidden_state = outputs[0]  # (batch, seq_len, hidden_size)
|
|
|
""" |
|
|
|
def __init__(self, config): |
|
super(GlyceBertModel, self).__init__(config) |
|
self.config = config |
|
|
|
self.embeddings = FusionBertEmbeddings(config) |
|
self.encoder = BertEncoder(config) |
|
self.pooler = BertPooler(config) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
pinyin_ids=None, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
): |
|
r""" |
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): |
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention |
|
            if the model is configured as a decoder.
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): |
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask |
|
            is used in the cross-attention if the model is configured as a decoder.
|
Mask values selected in ``[0, 1]``: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
|
elif input_ids is not None: |
|
input_shape = input_ids.size() |
|
elif inputs_embeds is not None: |
|
input_shape = inputs_embeds.size()[:-1] |
|
else: |
|
raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
if attention_mask is None: |
|
attention_mask = torch.ones(input_shape, device=device) |
|
if token_type_ids is None: |
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) |
|
|
|
|
|
|
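        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves, in which case we just
        # need to make it broadcastable to all heads.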
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) |
|
|
|
|
|
|
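        # If a 2D or 3D attention mask is provided for the cross-attention, make it
        # broadcastable to [batch_size, num_heads, seq_length, seq_length].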
|
if self.config.is_decoder and encoder_hidden_states is not None: |
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() |
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) |
|
if encoder_attention_mask is None: |
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) |
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) |
|
else: |
|
encoder_extended_attention_mask = None |
|
|
|
|
|
|
|
|
|
|
|
|
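        # Prepare the head mask if needed: 1.0 in head_mask keeps the head.
        # The input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and is broadcast to one mask per layer and attention head.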
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) |
|
|
|
embedding_output = self.embeddings( |
|
input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids, token_type_ids=token_type_ids, |
|
inputs_embeds=inputs_embeds |
|
) |
|
encoder_outputs = self.encoder( |
|
embedding_output, |
|
attention_mask=extended_attention_mask, |
|
head_mask=head_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_extended_attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
sequence_output = encoder_outputs[0] |
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None |
|
|
|
if not return_dict: |
|
return (sequence_output, pooled_output) + encoder_outputs[1:] |
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=sequence_output, |
|
pooler_output=pooled_output, |
|
hidden_states=encoder_outputs.hidden_states, |
|
attentions=encoder_outputs.attentions, |
|
) |
|
|
|
|
|
class GlyceBertForMaskedLM(BertPreTrainedModel): |
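    """ChineseBERT with a masked-language-modeling head on top.

    Same usage as ``BertForMaskedLM``, except that ``forward`` additionally expects
    ``pinyin_ids`` right after ``input_ids``.
    """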
|
def __init__(self, config): |
|
super(GlyceBertForMaskedLM, self).__init__(config) |
|
|
|
self.bert = GlyceBertModel(config) |
|
self.cls = BertOnlyMLMHead(config) |
|
|
|
self.init_weights() |
|
|
|
def get_output_embeddings(self): |
|
return self.cls.predictions.decoder |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
pinyin_ids=None, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
labels=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
**kwargs |
|
): |
|
r""" |
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): |
|
Labels for computing the masked language modeling loss. |
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) |
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels |
|
in ``[0, ..., config.vocab_size]`` |
|
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): |
|
Used to hide legacy arguments that have been deprecated. |
|
""" |
|
if "masked_lm_labels" in kwargs: |
|
warnings.warn( |
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", |
|
FutureWarning, |
|
) |
|
labels = kwargs.pop("masked_lm_labels") |
|
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task." |
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." |
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
pinyin_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
prediction_scores = self.cls(sequence_output) |
|
|
|
masked_lm_loss = None |
|
if labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) |
|
|
|
if not return_dict: |
|
output = (prediction_scores,) + outputs[2:] |
|
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output |
|
|
|
return MaskedLMOutput( |
|
loss=masked_lm_loss, |
|
logits=prediction_scores, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class GlyceBertForSequenceClassification(BertPreTrainedModel): |
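    """ChineseBERT with a sequence classification/regression head on top (a dropout layer and
    a linear layer over the pooled output). ``forward`` additionally expects ``pinyin_ids``.
    """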
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.bert = GlyceBertModel(config) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
pinyin_ids=None, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
labels=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
): |
|
r""" |
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): |
|
Labels for computing the sequence classification/regression loss. |
|
Indices should be in :obj:`[0, ..., config.num_labels - 1]`. |
|
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), |
|
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
pinyin_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
pooled_output = outputs[1] |
|
|
|
pooled_output = self.dropout(pooled_output) |
|
logits = self.classifier(pooled_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
if self.num_labels == 1: |
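                # Regression (num_labels == 1): use mean-squared error.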
|
|
|
loss_fct = MSELoss() |
|
loss = loss_fct(logits.view(-1), labels.view(-1)) |
|
else: |
|
loss_fct = CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class GlyceBertForQuestionAnswering(BertPreTrainedModel): |
|
"""BERT model for Question Answering (span extraction). |
|
This module is composed of the BERT model with a linear layer on top of |
|
the sequence output that computes start_logits and end_logits |
|
|
|
Params: |
|
`config`: a BertConfig class instance with the configuration to build a new model. |
|
|
|
Inputs: |
|
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] |
|
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts |
|
`extract_features.py`, `run_classifier.py` and `run_squad.py`) |
|
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token |
|
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to |
|
a `sentence B` token (see BERT paper for more details). |
|
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices |
|
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max |
|
input sequence length in the current batch. It's the mask that we typically use for attention when |
|
a batch has varying length sentences. |
|
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. |
|
Positions are clamped to the length of the sequence and position outside of the sequence are not taken |
|
into account for computing the loss. |
|
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. |
|
Positions are clamped to the length of the sequence and position outside of the sequence are not taken |
|
into account for computing the loss. |
|
|
|
Outputs: |
|
if `start_positions` and `end_positions` are not `None`: |
|
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. |
|
if `start_positions` or `end_positions` is `None`: |
|
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end |
|
position tokens of shape [batch_size, sequence_length]. |
|
|
|
Example usage: |
|
```python |
|
# Already been converted into WordPiece token ids |
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) |
|
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) |
|
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) |
|
|
|
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, |
|
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) |
|
|
|
model = BertForQuestionAnswering(config) |
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask) |
|
``` |
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.bert = GlyceBertModel(config) |
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
pinyin_ids=None, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
start_positions=None, |
|
end_positions=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
): |
|
r""" |
|
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): |
|
Labels for position (index) of the start of the labelled span for computing the token classification loss. |
|
Positions are clamped to the length of the sequence (:obj:`sequence_length`). |
|
            Positions outside of the sequence are not taken into account for computing the loss.
|
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): |
|
Labels for position (index) of the end of the labelled span for computing the token classification loss. |
|
Positions are clamped to the length of the sequence (:obj:`sequence_length`). |
|
            Positions outside of the sequence are not taken into account for computing the loss.
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
pinyin_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
logits = self.qa_outputs(sequence_output) |
|
start_logits, end_logits = logits.split(1, dim=-1) |
|
start_logits = start_logits.squeeze(-1) |
|
end_logits = end_logits.squeeze(-1) |
|
|
|
total_loss = None |
|
if start_positions is not None and end_positions is not None: |
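            # On multi-GPU runs the position tensors can carry an extra dimension; squeeze it away.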
|
|
|
if len(start_positions.size()) > 1: |
|
start_positions = start_positions.squeeze(-1) |
|
if len(end_positions.size()) > 1: |
|
end_positions = end_positions.squeeze(-1) |
|
|
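            # Positions outside the model inputs are clamped to `ignored_index` and ignored by the loss.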
|
ignored_index = start_logits.size(1) |
|
start_positions.clamp_(0, ignored_index) |
|
end_positions.clamp_(0, ignored_index) |
|
|
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index) |
|
start_loss = loss_fct(start_logits, start_positions) |
|
end_loss = loss_fct(end_logits, end_positions) |
|
total_loss = (start_loss + end_loss) / 2 |
|
|
|
if not return_dict: |
|
output = (start_logits, end_logits) + outputs[2:] |
|
return ((total_loss,) + output) if total_loss is not None else output |
|
|
|
return QuestionAnsweringModelOutput( |
|
loss=total_loss, |
|
start_logits=start_logits, |
|
end_logits=end_logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class GlyceBertForTokenClassification(BertPreTrainedModel): |
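    """ChineseBERT with a token classification head on top (a linear layer, or a small MLP when
    ``mlp=True``, over the hidden states). ``forward`` additionally expects ``pinyin_ids``.
    """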
|
def __init__(self, config, mlp=False): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
|
|
self.bert = GlyceBertModel(config) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
if mlp: |
|
self.classifier = BertMLP(config) |
|
else: |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
self.init_weights() |
|
|
|
def forward(self, |
|
input_ids=None, |
|
pinyin_ids=None, |
|
attention_mask=None, |
|
token_type_ids=None, |
|
position_ids=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
labels=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
): |
|
r""" |
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): |
|
Labels for computing the token classification loss. |
|
Indices should be in :obj:`[0, ..., config.num_labels - 1]`. |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
pinyin_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
sequence_output = outputs[0] |
|
|
|
sequence_output = self.dropout(sequence_output) |
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss_fct = CrossEntropyLoss() |
|
|
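            # Only keep the active parts of the loss (tokens where attention_mask == 1).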
|
if attention_mask is not None: |
|
active_loss = attention_mask.view(-1) == 1 |
|
active_logits = logits.view(-1, self.num_labels) |
|
active_labels = torch.where( |
|
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) |
|
) |
|
loss = loss_fct(active_logits, active_labels) |
|
else: |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return TokenClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class FusionBertEmbeddings(nn.Module): |
|
""" |
|
Construct the embeddings from word, position, glyph, pinyin and token_type embeddings. |
|
""" |
|
|
|
def __init__(self, config): |
|
super(FusionBertEmbeddings, self).__init__() |
|
self.path = Path(config._name_or_path) |
|
config_path = cache_path / 'config' |
|
if not os.path.exists(config_path): |
|
os.makedirs(config_path) |
|
|
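        # Glyph features: every vocabulary character rendered as a 24x24 bitmap in three
        # different fonts, shipped as .npy files alongside the checkpoint.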
|
font_files = [] |
|
download_file("config/STFANGSO.TTF24.npy", self.path) |
|
download_file("config/STXINGKA.TTF24.npy", self.path) |
|
download_file("config/方正古隶繁体.ttf24.npy", self.path) |
|
for file in os.listdir(config_path): |
|
if file.endswith(".npy"): |
|
font_files.append(str(config_path / file)) |
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) |
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) |
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) |
|
self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size, config=config) |
|
self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files) |
|
|
|
|
|
|
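        # 3 fonts x 24 x 24 pixels = 1728 glyph features per character, projected to hidden_size.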
|
self.glyph_map = nn.Linear(1728, config.hidden_size) |
|
self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size) |
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
|
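        # position_ids (1, max_position_embeddings) is contiguous in memory and exported when serialized.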
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) |
|
|
|
def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): |
|
if input_ids is not None: |
|
input_shape = input_ids.size() |
|
else: |
|
input_shape = inputs_embeds.size()[:-1] |
|
|
|
seq_length = input_shape[1] |
|
|
|
if position_ids is None: |
|
position_ids = self.position_ids[:, :seq_length] |
|
|
|
if token_type_ids is None: |
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.word_embeddings(input_ids) |
|
|
|
|
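        # Fusion layer: concatenate the word, pinyin and glyph embeddings of each token and
        # project the result back to hidden_size.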
|
word_embeddings = inputs_embeds |
|
pinyin_embeddings = self.pinyin_embeddings(pinyin_ids) |
|
glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids)) |
|
|
|
concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2) |
|
inputs_embeds = self.map_fc(concat_embeddings) |
|
|
|
position_embeddings = self.position_embeddings(position_ids) |
|
token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
|
|
embeddings = inputs_embeds + position_embeddings + token_type_embeddings |
|
embeddings = self.LayerNorm(embeddings) |
|
embeddings = self.dropout(embeddings) |
|
return embeddings |
|
|
|
|
|
class PinyinEmbedding(nn.Module): |
|
|
|
def __init__(self, embedding_size: int, pinyin_out_dim: int, config): |
|
""" |
|
Pinyin Embedding Module |
|
Args: |
|
embedding_size: the size of each embedding vector |
|
            pinyin_out_dim: number of output channels of the conv layer
                (the dimensionality of the per-token pinyin embedding)
            config: model config; its ``_name_or_path`` is used to locate ``pinyin_map.json``
        """
|
super(PinyinEmbedding, self).__init__() |
|
download_file('config/pinyin_map.json', Path(config._name_or_path)) |
|
with open(cache_path / 'config' / 'pinyin_map.json') as fin: |
|
pinyin_dict = json.load(fin) |
|
self.pinyin_out_dim = pinyin_out_dim |
|
self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size) |
|
self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2, |
|
stride=1, padding=0) |
|
|
|
def forward(self, pinyin_ids): |
|
""" |
|
Args: |
|
            pinyin_ids: (bs, sentence_length, pinyin_locs)
|
|
|
Returns: |
|
pinyin_embed: (bs,sentence_length,pinyin_out_dim) |
|
""" |
|
|
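        # [bs, sentence_length, pinyin_locs] -> [bs, sentence_length, pinyin_locs, embed_size]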
|
embed = self.embedding(pinyin_ids) |
|
bs, sentence_length, pinyin_locs, embed_size = embed.shape |
|
view_embed = embed.view(-1, pinyin_locs, embed_size) |
|
input_embed = view_embed.permute(0, 2, 1) |
|
|
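        # 1-D convolution over the pinyin characters of each token, then max-pool over that
        # axis so every token gets a single pinyin vector.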
|
pinyin_conv = self.conv(input_embed) |
|
pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1]) |
|
return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim) |
|
|
|
|
|
class BertMLP(nn.Module): |
|
def __init__(self, config, ): |
|
super().__init__() |
|
self.dense_layer = nn.Linear(config.hidden_size, config.hidden_size) |
|
self.dense_to_labels_layer = nn.Linear(config.hidden_size, config.num_labels) |
|
self.activation = nn.Tanh() |
|
|
|
def forward(self, sequence_hidden_states): |
|
sequence_output = self.dense_layer(sequence_hidden_states) |
|
sequence_output = self.activation(sequence_output) |
|
sequence_output = self.dense_to_labels_layer(sequence_output) |
|
return sequence_output |
|
|
|
|
|
class GlyphEmbedding(nn.Module): |
|
"""Glyph2Image Embedding""" |
|
|
|
def __init__(self, font_npy_files: List[str]): |
|
super(GlyphEmbedding, self).__init__() |
|
font_arrays = [ |
|
np.load(np_file).astype(np.float32) for np_file in font_npy_files |
|
] |
|
self.vocab_size = font_arrays[0].shape[0] |
|
self.font_num = len(font_arrays) |
|
self.font_size = font_arrays[0].shape[-1] |
|
|
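        # Stack the fonts per character -> [vocab_size, font_num, font_size, font_size], then
        # flatten each character's glyphs into a single row used to initialize the embedding.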
|
font_array = np.stack(font_arrays, axis=1) |
|
self.embedding = nn.Embedding( |
|
num_embeddings=self.vocab_size, |
|
embedding_dim=self.font_size ** 2 * self.font_num, |
|
_weight=torch.from_numpy(font_array.reshape([self.vocab_size, -1])) |
|
) |
|
|
|
def forward(self, input_ids): |
|
""" |
|
get glyph images for batch inputs |
|
Args: |
|
input_ids: [batch, sentence_length] |
|
Returns: |
|
images: [batch, sentence_length, self.font_num*self.font_size*self.font_size] |
|
""" |
|
|
|
return self.embedding(input_ids) |
|
|
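

if __name__ == "__main__":
    # Minimal usage sketch: assumes network access to the iioSnail/ChineseBERT-base checkpoint;
    # the dummy ids below stand in for the output of the matching ChineseBERT tokenizer
    # (pinyin_ids of all zeros carry no real pinyin information).
    model = GlyceBertModel.from_pretrained("iioSnail/ChineseBERT-base")
    model.eval()
    input_ids = torch.tensor([[101, 800, 102]])           # (batch, seq_len)
    pinyin_ids = torch.zeros(1, 3, 8, dtype=torch.long)   # (batch, seq_len, pinyin_locs)
    with torch.no_grad():
        outputs = model(input_ids, pinyin_ids)
    print(outputs[0].shape)                               # (batch, seq_len, hidden_size)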