"""Basic model. Predicts tags for every token""" | |
from typing import Dict, Optional, List, Any | |
import numpy | |
import torch | |
import torch.nn.functional as F | |
from allennlp.data import Vocabulary | |
from allennlp.models.model import Model | |
from allennlp.modules import TimeDistributed, TextFieldEmbedder | |
from allennlp.nn import InitializerApplicator, RegularizerApplicator | |
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits | |
from allennlp.training.metrics import CategoricalAccuracy | |
from overrides import overrides | |
from torch.nn.modules.linear import Linear | |
class Seq2Labels(Model):
    """
    This ``Seq2Labels`` model embeds a sequence of text with a ``TextFieldEmbedder``, then
    predicts two tags for every token in the sequence: a correction label and an
    error-detection tag.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    predictor_dropout : ``float``, optional (default=``0.0``)
        Dropout probability applied to the embedded text before the label projection layer.
    labels_namespace : ``str``, optional (default=``"labels"``)
        Vocabulary namespace of the correction labels. Unless you did something unusual,
        the default value should be what you want.
    detect_namespace : ``str``, optional (default=``"d_tags"``)
        Vocabulary namespace of the error-detection tags.
    verbose_metrics : ``bool``, optional (default=``False``)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    label_smoothing : ``float``, optional (default=``0.0``)
        Label smoothing applied to the cross-entropy loss over correction labels.
    confidence : ``float``, optional (default=``0.0``)
        Inference-time bias added to the predicted probability of the label with
        index 0 in the ``labels`` namespace.
    del_confidence : ``float``, optional (default=``0.0``)
        Inference-time bias added to the predicted probability of the label with
        index 1 in the ``labels`` namespace.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout: float = 0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 del_confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Seq2Labels, self).__init__(vocab, regularizer)

        self.label_namespaces = [labels_namespace,
                                 detect_namespace]
        self.text_field_embedder = text_field_embedder
        self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
        self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
        self.label_smoothing = label_smoothing
        self.confidence = confidence
        self.del_conf = del_confidence
        # Index of the "INCORRECT" detection tag, used to read per-token error probabilities.
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)
        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))

        # Two independent projection heads over the embedder output: one over
        # correction labels and one over error-detection tags.
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))
        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))

        self.metrics = {"accuracy": CategoricalAccuracy()}

        initializer(self)
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                d_tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor of shape ``(batch_size, num_tokens)`` containing the integer gold
            correction labels (from the ``labels`` namespace).
        d_tags : torch.LongTensor, optional (default = None)
            A torch tensor of shape ``(batch_size, num_tokens)`` containing the integer gold
            error-detection tags (from the ``d_tags`` namespace).
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` with the
            unnormalised log probabilities of the correction labels.
        logits_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` with the
            unnormalised log probabilities of the error-detection tags.
        class_probabilities_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` with a
            distribution over correction labels for each word.
        class_probabilities_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` with a
            distribution over error-detection tags for each word.
        max_error_probability : torch.FloatTensor
            A tensor of shape ``(batch_size,)`` holding, for each sentence, the highest
            per-token probability of the ``INCORRECT`` tag.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised, returned only when both ``labels`` and ``d_tags``
            are provided.
        """
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)

        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]
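        # Inference-time confidence bias: add extra probability mass to the first two
        # label classes (assumed here, as in the standard GECToR label vocabulary, to be
        # the keep and delete operations). With the default value of 0.0 this is a no-op.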
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)

        output_dict = {"logits_labels": logits_labels,
                       "logits_d_tags": logits_d,
                       "class_probabilities_labels": class_probabilities_labels,
                       "class_probabilities_d_tags": class_probabilities_d,
                       "max_error_probability": incorr_prob}

        if labels is not None and d_tags is not None:
            loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask,
                                                             label_smoothing=self.label_smoothing)
            loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask)
            for metric in self.metrics.values():
                metric(logits_labels, labels, mask.float())
                metric(logits_d, d_tags, mask.float())
            output_dict["loss"] = loss_labels + loss_d

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels, and
        adds the predicted tag sequences to the dictionary under each label namespace
        (``labels`` and ``d_tags``).
        """
        for label_namespace in self.label_namespaces:
            all_predictions = output_dict[f'class_probabilities_{label_namespace}']
            all_predictions = all_predictions.cpu().data.numpy()
            if all_predictions.ndim == 3:
                predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
            else:
                predictions_list = [all_predictions]
            all_tags = []
            for predictions in predictions_list:
                argmax_indices = numpy.argmax(predictions, axis=-1)
                tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
                        for x in argmax_indices]
                all_tags.append(tags)
            output_dict[label_namespace] = all_tags
        return output_dict
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}
        return metrics_to_return
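

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): how the inference-time confidence
# bias applied in ``forward`` interacts with the position-wise argmax used in
# ``decode``. The tiny label inventory and the bias values below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_labels = ["$KEEP", "$DELETE", "$APPEND_the"]  # hypothetical 3-label namespace
    keep_bias, delete_bias = 0.2, 0.0  # stand-ins for ``confidence`` / ``del_confidence``

    # One sentence of two tokens with borderline per-token label distributions.
    class_probabilities = torch.tensor([[[0.45, 0.05, 0.50],
                                         [0.30, 0.35, 0.35]]])
    bias = torch.tensor([keep_bias, delete_bias] + [0.0] * (len(toy_labels) - 2))
    biased = class_probabilities + bias  # same broadcasted addition as in ``forward``

    for name, probs in [("raw", class_probabilities), ("biased", biased)]:
        indices = numpy.argmax(probs.numpy(), axis=-1)[0]
        print(name, [toy_labels[i] for i in indices])
    # The extra mass on index 0 flips borderline decisions towards $KEEP.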