from typing import Optional, List, Tuple, Any
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchcrf import CRF
from transformers import logging, RobertaForTokenClassification

# Only report errors from the transformers logger.
logging.set_verbosity_error()


class NerOutput(OrderedDict):
    """Ordered-dict model output that also supports attribute and index access."""

    loss: Optional[torch.FloatTensor] = torch.FloatTensor([0.0])
    tags: Optional[List[int]] = []

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = {k: v for (k, v) in self.items()}
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Keep the dict entry and the attribute in sync.
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        return tuple(self[k] for k in self.keys())


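# NerOutput behaves like a small ModelOutput: fields can be read by attribute, by key,
# or by index. Illustrative example (the values are made up):
#
#   out = NerOutput(loss=torch.tensor(0.5), tags=[1, 2, 3])
#   out.loss, out["loss"], out[0]    # all three refer to the same tensor

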
class PhoBertSoftmax(RobertaForTokenClassification):
    def __init__(self, config, **kwargs):
        super(PhoBertSoftmax, self).__init__(config=config, **kwargs)
        self.num_labels = config.num_labels

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        seq_output = self.roberta(input_ids=input_ids,
                                  token_type_ids=token_type_ids,
                                  attention_mask=attention_mask,
                                  head_mask=None)[0]
        seq_output = self.dropout(seq_output)
        logits = self.classifier(seq_output)
        probs = F.log_softmax(logits, dim=2)
        # Keep predictions only at positions flagged by label_masks; the result is a
        # single flat list of tag ids for the whole batch.
        label_masks = label_masks.view(-1) != 0
        seq_tags = torch.masked_select(torch.argmax(probs, dim=2).view(-1), label_masks).tolist()

        if labels is not None:
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(logits.view(-1, self.num_labels), labels.view(-1))
            return NerOutput(loss=loss, tags=seq_tags)
        else:
            return NerOutput(tags=seq_tags)


class PhoBertCrf(RobertaForTokenClassification):
    def __init__(self, config):
        super(PhoBertCrf, self).__init__(config=config)
        self.num_labels = config.num_labels
        self.crf = CRF(config.num_labels, batch_first=True)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        seq_outputs = self.roberta(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   head_mask=None)[0]

        batch_size, max_len, feat_dim = seq_outputs.shape
        # Gather, for every word, the hidden state of its first sub-word token.
        range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
        seq_outputs = seq_outputs[range_vector, valid_ids]
        seq_outputs = self.dropout(seq_outputs)
        logits = self.classifier(seq_outputs)
        seq_tags = self.crf.decode(logits, mask=label_masks != 0)

        if labels is not None:
            # torchcrf returns the log-likelihood; negate it to obtain a loss to minimize.
            log_likelihood = self.crf(logits, labels, mask=label_masks.type(torch.uint8))
            return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags)
        else:
            return NerOutput(tags=seq_tags)


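# Illustrative note on the sub-word gathering used by the CRF heads (the tensors below
# are assumptions chosen for the example, not values produced by the data pipeline):
#
#   hidden = encoder_output                     # (batch, seq_len, hidden)
#   valid_ids = torch.tensor([[1, 2, 4, 0]])    # index of each word's first sub-word
#   rows = torch.arange(hidden.size(0)).unsqueeze(1)
#   word_hidden = hidden[rows, valid_ids]       # (batch, n_words, hidden)

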
class PhoBertLstmCrf(RobertaForTokenClassification):
    def __init__(self, config):
        super(PhoBertLstmCrf, self).__init__(config=config)
        self.num_labels = config.num_labels
        self.lstm = nn.LSTM(input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)

    @staticmethod
    def sort_batch(src_tensor, lengths):
        """
        Sort a mini-batch by sequence length, longest sequences first, and return the
        sorted batch, the sorted lengths, and the permutation that restores the original
        order, so the output can be fed to pack_padded_sequence(...).
        """
        seq_lengths, perm_idx = lengths.sort(0, descending=True)
        seq_tensor = src_tensor[perm_idx]
        _, reversed_idx = perm_idx.sort(0, descending=False)
        return seq_tensor, seq_lengths, reversed_idx

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        seq_outputs = self.roberta(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   head_mask=None)[0]

        batch_size, max_len, feat_dim = seq_outputs.shape
        seq_lens = torch.sum(label_masks, dim=-1)
        # Gather, for every word, the hidden state of its first sub-word token.
        range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
        seq_outputs = seq_outputs[range_vector, valid_ids]

        # Sort by length, run the packed BiLSTM, then restore the original batch order.
        sorted_seq_outputs, sorted_seq_lens, reversed_idx = self.sort_batch(src_tensor=seq_outputs,
                                                                            lengths=seq_lens)
        packed_words = pack_padded_sequence(sorted_seq_outputs, sorted_seq_lens.cpu(), batch_first=True)
        lstm_outs, _ = self.lstm(packed_words)
        lstm_outs, _ = pad_packed_sequence(lstm_outs, batch_first=True, total_length=max_len)
        seq_outputs = lstm_outs[reversed_idx]

        seq_outputs = self.dropout(seq_outputs)
        logits = self.classifier(seq_outputs)
        seq_tags = self.crf.decode(logits, mask=label_masks != 0)

        if labels is not None:
            # torchcrf returns the log-likelihood; negate it to obtain a loss to minimize.
            log_likelihood = self.crf(logits, labels, mask=label_masks.type(torch.uint8))
            return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags)
        else:
            return NerOutput(tags=seq_tags)
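

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. The checkpoint name,
    # label count, and dummy tensors below are assumptions chosen for illustration only.
    from transformers import RobertaConfig

    config = RobertaConfig.from_pretrained("vinai/phobert-base", num_labels=9)
    model = PhoBertCrf.from_pretrained("vinai/phobert-base", config=config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)
    valid_ids = torch.arange(16).unsqueeze(0)   # identity mapping: keep every position
    label_masks = torch.ones_like(input_ids)    # score every position with the CRF

    with torch.no_grad():
        output = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       valid_ids=valid_ids,
                       label_masks=label_masks)
    print(output.tags)   # one list of predicted label ids per sentence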