Spaces (runtime error)
spitzc32 committed · commit 24d0437 · 1 parent: d13f1d3
Added initial structure of the model
Files changed:
- Dockerfile (+11 -0)
- app.py (+26 -0)
- model/.DS_Store (+0 -0)
- model/__init__.py (+2 -0)
- model/embedding/__init__.py (+63 -0)
- model/layer/__init__.py (+652 -0)
- model/layer/bioes.py (+62 -0)
- model/layer/crf.py (+47 -0)
- model/layer/lstm.py (+47 -0)
- model/layer/span.py (+211 -0)
- model/layer/viterbi.py (+241 -0)
- part/__init__.py (+2 -0)
- part/data.py (+142 -0)
- part/dropout.py (+60 -0)
- requirements.txt (+9 -0)
Dockerfile ADDED @@ -0,0 +1,11 @@

FROM python:3.9

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED @@ -0,0 +1,26 @@

from model.layer import Bi_LSTM_CRF
from flair.data import Sentence

tagger = Bi_LSTM_CRF.load("checkpoints/best-model.pt")

def model(word: str):
    """
    A function for serving the model for PHI classification.
    :param word: list of word tokens in a paragraph.

    :returns: dict that contains the labeled
        tags and their respective classification.
    """
    txt = Sentence(word)
    tagger.predict(txt)
    labels, tags = [], []

    for entity in txt.get_spans('ner'):
        labels.append(entity.text)
        tags.append(entity.get_label("ner").value)

    return {
        "labels": labels,
        "tags": tags
    }
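Note that the Dockerfile's CMD points uvicorn at app.main:app, while app.py above only defines the model() helper and no ASGI app object. A minimal sketch of the wiring that CMD appears to expect is below; it is not part of this commit, and the module path, endpoint name, and query parameter are all assumptions:

# app/main.py -- hypothetical FastAPI wiring for the uvicorn CMD; not in this commit
from fastapi import FastAPI

from app import model  # assumes the model() helper above is importable as module `app`

app = FastAPI()

@app.get("/predict")
def predict(text: str):
    # delegate to the Flair tagger wrapped by model()
    return model(text)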
model/.DS_Store ADDED Binary file (6.15 kB)
model/__init__.py ADDED @@ -0,0 +1,2 @@

import model.embedding
import model.layer
model/embedding/__init__.py ADDED @@ -0,0 +1,63 @@

import torch.nn as nn
from flair.embeddings import (
    TransformerWordEmbeddings,
    FlairEmbeddings,
    CharacterEmbeddings,
    StackedEmbeddings,
    OneHotEmbeddings
)
from flair.data import Sentence


class PretrainedEmbeddings():
    """
    This is the implementation of the PretrainedEmbeddings we will use to embed our own
    corpus for the purpose of generating a Tensor (the pre_word_embeds) that we will pass
    to the model.

    Plan:
    * Word-level embeddings: we will utilize BERT-based transformer word embeddings
      in order to achieve more functionality.
    * Context-level embeddings: we will stick to FlairEmbeddings first, then check whether
      pooled Flair embeddings are better than FlairEmbeddings.
    """

    def __init__(self,
                 word_embedding: str,
                 forward_embedding: str,
                 backward_embedding: str
                 ) -> None:
        self.word_embedding = word_embedding
        self.forward_embedding = forward_embedding
        self.backward_embedding = backward_embedding

    def forward(self):
        # First, gather all pretrained embeddings accessible in
        # Flair for our requirement
        flair_forward_embedding = FlairEmbeddings(self.forward_embedding)
        flair_backward_embedding = FlairEmbeddings(self.backward_embedding)

        bert_embedding = TransformerWordEmbeddings(model=self.word_embedding,
                                                   fine_tune=True,
                                                   use_context=True,)

        # Next, concatenate all embeddings above
        stacked_embeddings = StackedEmbeddings(
            embeddings=[
                flair_forward_embedding,
                flair_backward_embedding,
                bert_embedding,
            ])

        return stacked_embeddings
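As a quick sanity check, the stacked embedding returned by forward() can be applied to a Flair Sentence directly. A minimal sketch (not part of this commit; the model identifiers are common Flair/HuggingFace names, assumed here for illustration):

from flair.data import Sentence

# hypothetical identifiers; any Flair forward/backward LM and HF transformer would do
embedder = PretrainedEmbeddings(
    word_embedding="bert-base-uncased",
    forward_embedding="news-forward",
    backward_embedding="news-backward",
)
stacked = embedder.forward()

sentence = Sentence("John visited Berlin in May .")
stacked.embed(sentence)
for token in sentence:
    # one concatenated vector per token (Flair fwd + Flair bwd + BERT)
    print(token.text, token.embedding.shape)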
model/layer/__init__.py ADDED @@ -0,0 +1,652 @@

import logging
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm

import flair.nn
from part import *
from flair.data import Dictionary, Label, Sentence, Span
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.training_utils import store_embeddings

from model.layer.bioes import get_spans_from_bio
from model.layer.lstm import LSTM
from model.layer.crf import CRF
from model.layer.viterbi import ViterbiDecoder, ViterbiLoss

log = logging.getLogger("flair")


class Bi_LSTM_CRF(flair.nn.Classifier[Sentence]):
    def __init__(
        self,
        embeddings: TokenEmbeddings,
        tag_dictionary: Dictionary,
        tag_type: str,
        rnn: Optional[torch.nn.RNN] = None,
        tag_format: str = "BIOES",
        hidden_size: int = 256,
        rnn_layers: int = 1,
        bidirectional: bool = True,
        use_crf: bool = True,
        ave_embeddings: bool = True,
        dropout: float = 0.0,
        word_dropout: float = 0.05,
        locked_dropout: float = 0.5,
        loss_weights: Optional[Dict[str, float]] = None,
        init_from_state_dict: bool = False,
        allow_unk_predictions: bool = False,
    ):
        """
        BiLSTM Span CRF class for predicting labels for single tokens. Can be parameterized by several attributes.
        Span prediction is utilized if there are nested entities such as Address and Organization. Since the researchers
        observed that tokens have different lengths for a given dataset, we made the Span useful by incorporating it
        only if the data needs it.

        :param embeddings: Embeddings to use during training and prediction
        :param tag_dictionary: Dictionary containing all tags from corpus which can be predicted
        :param tag_type: type of tag which is going to be predicted in case a corpus has multiple annotations
        :param rnn: (Optional) Takes a torch.nn.Module as parameter by which you can pass a shared RNN between
            different tasks.
        :param hidden_size: Hidden size of RNN layer
        :param rnn_layers: number of RNN layers
        :param bidirectional: If True, RNN becomes bidirectional
        :param use_crf: If True, use a Conditional Random Field for prediction, else linear map to tag space.
        :param ave_embeddings: If True, add a linear layer on top of embeddings, if you want to imitate
            fine-tuning of non-trainable embeddings.
        :param dropout: If > 0, then use dropout.
        :param word_dropout: If > 0, then use word dropout.
        :param locked_dropout: If > 0, then use locked dropout.
        :param loss_weights: Dictionary of weights for labels for the loss function
            (if any label's weight is unspecified it will default to 1.0)
        :param init_from_state_dict: Indicator whether we are loading a model from state dict
            since we need to transform previous models' weights into CRF instance weights
        """
        super(Bi_LSTM_CRF, self).__init__()

        # ----- Create the internal tag dictionary -----
        self.tag_type = tag_type
        self.tag_format = tag_format.upper()
        if init_from_state_dict:
            self.label_dictionary = tag_dictionary
        else:
            # span-labels need special encoding (BIO or BIOES)
            if tag_dictionary.span_labels:
                # the big question is whether the label dictionary should contain an UNK or not
                # without UNK, we cannot evaluate on data that contains labels not seen in test
                # with UNK, the model learns less well if there are no UNK examples
                self.label_dictionary = Dictionary(add_unk=allow_unk_predictions)
                assert self.tag_format in ["BIOES", "BIO"]
                for label in tag_dictionary.get_items():
                    if label == "<unk>":
                        continue
                    self.label_dictionary.add_item("O")
                    if self.tag_format == "BIOES":
                        self.label_dictionary.add_item("S-" + label)
                        self.label_dictionary.add_item("B-" + label)
                        self.label_dictionary.add_item("E-" + label)
                        self.label_dictionary.add_item("I-" + label)
                    if self.tag_format == "BIO":
                        self.label_dictionary.add_item("B-" + label)
                        self.label_dictionary.add_item("I-" + label)
            else:
                self.label_dictionary = tag_dictionary

        # is this a span prediction problem?
        self.predict_spans = self._determine_if_span_prediction_problem(self.label_dictionary)

        self.tagset_size = len(self.label_dictionary)
        log.info(f"SequenceTagger predicts: {self.label_dictionary}")

        # ----- Embeddings -----
        # We set the first initial embeddings gathered from Flair,
        # stacked and concatenated, then averaged using Linear
        self.embeddings = embeddings
        embedding_dim: int = embeddings.embedding_length

        # ----- Initial loss weights parameters -----
        # This is for the reiteration process of training.
        # Initially we don't have any loss weights, but as we proceed to training,
        # we get loss computations from the evaluation stage.
        self.weight_dict = loss_weights
        self.loss_weights = self._init_loss_weights(loss_weights) if loss_weights else None

        # ----- RNN specific parameters -----
        # These parameters are for setting up self.rnn
        self.hidden_size = hidden_size if not rnn else rnn.hidden_size
        self.rnn_layers = rnn_layers if not rnn else rnn.num_layers
        self.bidirectional = bidirectional if not rnn else rnn.bidirectional

        # ----- Conditional Random Field parameters -----
        self.use_crf = use_crf
        # Previously trained models have been trained without an explicit CRF, thus it is required to check
        # whether we are loading a model from state dict in order to skip or add START and STOP token
        if use_crf and not init_from_state_dict and not self.label_dictionary.start_stop_tags_are_set():
            self.label_dictionary.set_start_stop_tags()
            self.tagset_size += 2

        # ----- Dropout parameters -----
        self.use_dropout: float = dropout
        self.use_word_dropout: float = word_dropout
        self.use_locked_dropout: float = locked_dropout

        if dropout > 0.0:
            self.dropout = torch.nn.Dropout(dropout)

        if word_dropout > 0.0:
            self.word_dropout = flair.nn.WordDropout(word_dropout)

        if locked_dropout > 0.0:
            self.locked_dropout = flair.nn.LockedDropout(locked_dropout)

        # ----- Model layers -----
        # Initialize a linear layer over the embedding dim for the purpose of averaging the embeddings
        self.ave_embeddings = ave_embeddings
        if self.ave_embeddings:
            self.embedding2nn = torch.nn.Linear(embedding_dim, embedding_dim)

        # ----- RNN layer -----
        # Use the shared RNN if provided, else create one for the model
        self.rnn: torch.nn.RNN = (
            rnn
            if rnn
            else LSTM(
                rnn_layers,
                hidden_size,
                bidirectional,
                rnn_input_dim=embedding_dim,
            )
        )

        num_directions = 2 if self.bidirectional else 1
        hidden_output_dim = self.rnn.hidden_size * num_directions

        # final linear map to tag space
        self.linear = torch.nn.Linear(hidden_output_dim, len(self.label_dictionary))

        # the loss function is Viterbi if using CRF, else regular Cross Entropy Loss
        self.loss_function = (
            ViterbiLoss(self.label_dictionary)
            if use_crf
            else torch.nn.CrossEntropyLoss(weight=self.loss_weights, reduction="sum")
        )

        # if using CRF, we also require a CRF and a Viterbi decoder
        if use_crf:
            self.crf = CRF(self.label_dictionary, self.tagset_size, init_from_state_dict)
            self.viterbi_decoder = ViterbiDecoder(self.label_dictionary)

        self.to(flair.device)

    @property
    def label_type(self):
        return self.tag_type

    def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor:
        """
        Initializes the loss weights based on the given dictionary:
        :param loss_weights: dictionary containing loss weights
        """
        n_classes = len(self.label_dictionary)
        weight_list = [1.0 for _ in range(n_classes)]
        for i, tag in enumerate(self.label_dictionary.get_items()):
            if tag in loss_weights.keys():
                weight_list[i] = loss_weights[tag]

        return torch.tensor(weight_list).to(flair.device)

    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
        """
        Calculates the loss of the forward propagation of the model.
        :param sentences: either a list of sentences or just a sentence
        """
        # if there are no sentences, there is no loss
        if len(sentences) == 0:
            return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0

        # forward pass to get scores
        scores, gold_labels = self.forward(sentences)  # type: ignore

        # calculate loss given scores and labels
        return self._calculate_loss(scores, gold_labels)

    def forward(self, sentences: Union[List[Sentence], Sentence]):
        """
        Forward propagation through network. Returns gold labels of batch in addition.
        :param sentences: Batch of current sentences
        """
        if not isinstance(sentences, list):
            sentences = [sentences]
        self.embeddings.embed(sentences)

        # make a zero-padded tensor for the whole sentence
        lengths, sentence_tensor = self._make_padded_tensor_for_batch(sentences)

        # sort tensor in decreasing order based on lengths of sentences in batch
        sorted_lengths, length_indices = lengths.sort(dim=0, descending=True)
        sentences = [sentences[i] for i in length_indices]
        sentence_tensor = sentence_tensor[length_indices]

        # ----- Forward Propagation -----
        # apply the dropouts we initialized for the regularization of our inputs
        if self.use_dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)
        if self.use_locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        # Average the embeddings using a linear transform
        if self.ave_embeddings:
            sentence_tensor = self.embedding2nn(sentence_tensor)

        # This packs our sentence tensor, the process for weighting our LSTM model
        sentence_tensor, output_lengths = self.rnn(sentence_tensor, sorted_lengths)

        # Regularize the sentence tensor computed from the LSTM model
        if self.use_dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        # linear map to tag space
        features = self.linear(sentence_tensor)

        # Depending on whether we are using CRF or a linear layer, scores is either:
        # -- A tensor of shape (batch size, sequence length, tagset size, tagset size) for CRF
        # -- A tensor of shape (aggregated sequence length for all sentences in batch, tagset size) for linear layer
        if self.use_crf:
            features = self.crf(features)
            scores = (features, sorted_lengths, self.crf.transitions)
        else:
            scores = self._get_scores_from_features(features, sorted_lengths)

        # get the gold labels
        gold_labels = self._get_gold_labels(sentences)

        return scores, gold_labels

    def _calculate_loss(self, scores, labels) -> Tuple[torch.Tensor, int]:

        if not any(labels):
            return torch.tensor(0.0, requires_grad=True, device=flair.device), 1

        labels = torch.tensor(
            [
                self.label_dictionary.get_idx_for_item(label[0])
                if len(label) > 0
                else self.label_dictionary.get_idx_for_item("O")
                for label in labels
            ],
            dtype=torch.long,
            device=flair.device,
        )

        return self.loss_function(scores, labels), len(labels)

    def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Makes zero-padded tensors in the shape of the longest sentence and the embedding_length to match
        the shape of the embedding when feeding to our LSTM model.
        :param sentences: Batch of current sentences
        """
        names = self.embeddings.get_names()
        tok_lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(tok_lengths)
        zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )
        all_embs = list()
        for sentence in sentences:
            all_embs += [emb for token in sentence for emb in token.get_each_embedding(names)]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = zero_tensor[: self.embeddings.embedding_length * nb_padding_tokens]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )
        return torch.tensor(tok_lengths, dtype=torch.long), sentence_tensor

    @staticmethod
    def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor):
        """
        Trims the current batch tensor in shape (batch size, sequence length, tagset size) in such a way that all
        pads are removed.
        :param features: torch.Tensor containing all features from forward propagation
        :param lengths: length of each sentence in batch in order to trim padding tokens
        """
        features_formatted = []
        for feat, lens in zip(features, lengths):
            features_formatted.append(feat[:lens])
        scores = torch.cat(features_formatted)

        return scores

    def _get_gold_labels(self, sentences: Union[List[Sentence], Sentence]):
        """
        Extracts gold labels from each sentence.
        :param sentences: List of sentences in batch
        """
        # spans need to be encoded as token-level predictions
        if self.predict_spans:
            all_sentence_labels = []
            for sentence in sentences:
                sentence_labels = ["O"] * len(sentence)
                for label in sentence.get_labels(self.label_type):
                    span: Span = label.data_point
                    if self.tag_format == "BIOES":
                        if len(span) == 1:
                            sentence_labels[span[0].idx - 1] = "S-" + label.value
                        else:
                            sentence_labels[span[0].idx - 1] = "B-" + label.value
                            sentence_labels[span[-1].idx - 1] = "E-" + label.value
                            for i in range(span[0].idx, span[-1].idx - 1):
                                sentence_labels[i] = "I-" + label.value
                    else:
                        sentence_labels[span[0].idx - 1] = "B-" + label.value
                        for i in range(span[0].idx, span[-1].idx):
                            sentence_labels[i] = "I-" + label.value
                all_sentence_labels.extend(sentence_labels)
            labels = [[label] for label in all_sentence_labels]

        # all others are regular labels for each token
        else:
            labels = [[token.get_label(self.label_type, "O").value] for sentence in sentences for token in sentence]

        return labels

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
        force_token_predictions: bool = False,
    ):  # type: ignore
        """
        Predicts labels for the current batch with the CRF.
        :param sentences: List of sentences in batch
        :param mini_batch_size: batch size for test data
        :param return_probabilities_for_all_classes: whether to return probabilities for all classes
        :param verbose: whether to use progress bar
        :param label_name: which label to predict
        :param return_loss: whether to return loss value
        :param embedding_storage_mode: determines where to store embeddings - can be "gpu", "cpu" or None.
        """
        if label_name is None:
            label_name = self.tag_type

        with torch.no_grad():
            if not sentences:
                return sentences

            # make sure it's a list
            if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
                sentences = [sentences]

            # filter empty sentences
            sentences = [sentence for sentence in sentences if len(sentence) > 0]

            # reverse sort all sequences by their length
            reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)

            if len(reordered_sentences) == 0:
                return sentences

            dataloader = DataLoader(
                dataset=FlairDatapointDataset(reordered_sentences),
                batch_size=mini_batch_size,
            )
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader, desc="Batch inference")

            overall_loss = torch.zeros(1, device=flair.device)
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                # stop if all sentences are empty
                if not batch:
                    continue

                # get features from forward propagation
                features, gold_labels = self.forward(batch)

                # remove previously predicted labels of this type
                for sentence in batch:
                    sentence.remove_labels(label_name)

                # if return_loss, get loss value
                if return_loss:
                    loss = self._calculate_loss(features, gold_labels)
                    overall_loss += loss[0]
                    label_count += loss[1]

                # sort batch in the same way as forward propagation
                lengths = torch.LongTensor([len(sentence) for sentence in batch])
                _, sort_indices = lengths.sort(dim=0, descending=True)
                batch = [batch[i] for i in sort_indices]

                # make predictions
                if self.use_crf:
                    predictions, all_tags = self.viterbi_decoder.decode(
                        features, return_probabilities_for_all_classes, batch
                    )
                else:
                    predictions, all_tags = self._standard_inference(
                        features, batch, return_probabilities_for_all_classes
                    )

                # add predictions to Sentence
                for sentence, sentence_predictions in zip(batch, predictions):

                    # BIOES-labels need to be converted to spans
                    if self.predict_spans and not force_token_predictions:
                        sentence_tags = [label[0] for label in sentence_predictions]
                        sentence_scores = [label[1] for label in sentence_predictions]
                        predicted_spans = get_spans_from_bio(sentence_tags, sentence_scores)
                        for predicted_span in predicted_spans:
                            span: Span = sentence[predicted_span[0][0] : predicted_span[0][-1] + 1]
                            span.add_label(label_name, value=predicted_span[2], score=predicted_span[1])

                    # token-labels can be added directly ("O" and legacy "_" predictions are skipped)
                    else:
                        for token, label in zip(sentence.tokens, sentence_predictions):
                            if label[0] in ["O", "_"]:
                                continue
                            token.add_label(typename=label_name, value=label[0], score=label[1])

                # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

                store_embeddings(sentences, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count

    def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool):
        """
        Softmax over emission scores from forward propagation.
        :param features: sentence tensor from forward propagation
        :param batch: list of sentences
        :param probabilities_for_all_classes: whether to return the score for each tag in the tag dictionary
        """
        softmax_batch = F.softmax(features, dim=1).cpu()
        scores_batch, prediction_batch = torch.max(softmax_batch, dim=1)
        predictions = []
        all_tags = []

        for sentence in batch:
            scores = scores_batch[: len(sentence)]
            predictions_for_sentence = prediction_batch[: len(sentence)]
            predictions.append(
                [
                    (self.label_dictionary.get_item_for_index(prediction), score.item())
                    for token, score, prediction in zip(sentence, scores, predictions_for_sentence)
                ]
            )
            scores_batch = scores_batch[len(sentence) :]
            prediction_batch = prediction_batch[len(sentence) :]

        if probabilities_for_all_classes:
            lengths = [len(sentence) for sentence in batch]
            all_tags = self._all_scores_for_token(batch, softmax_batch, lengths)

        return predictions, all_tags

    def _all_scores_for_token(self, sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
        """
        Returns all scores for each tag in the tag dictionary.
        :param scores: scores for the current sentence.
        """
        scores = scores.numpy()
        tokens = [token for sentence in sentences for token in sentence]
        prob_all_tags = [
            [
                Label(token, self.label_dictionary.get_item_for_index(score_id), score)
                for score_id, score in enumerate(score_dist)
            ]
            for score_dist, token in zip(scores, tokens)
        ]

        prob_tags_per_sentence = []
        previous = 0
        for length in lengths:
            prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
            previous += length
        return prob_tags_per_sentence

    def _get_state_dict(self):
        """Returns the state dictionary for this model."""
        model_state = {
            **super()._get_state_dict(),
            "embeddings": self.embeddings,
            "hidden_size": self.hidden_size,
            "tag_dictionary": self.label_dictionary,
            "tag_format": self.tag_format,
            "tag_type": self.tag_type,
            "use_crf": self.use_crf,
            "rnn_layers": self.rnn_layers,
            "use_dropout": self.use_dropout,
            "use_word_dropout": self.use_word_dropout,
            "use_locked_dropout": self.use_locked_dropout,
            "ave_embeddings": self.ave_embeddings,
            "weight_dict": self.weight_dict,
        }

        return model_state

    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):

        if state["use_crf"]:
            if "transitions" in state["state_dict"]:
                state["state_dict"]["crf.transitions"] = state["state_dict"]["transitions"]
                del state["state_dict"]["transitions"]

        return super()._init_model_with_state_dict(
            state,
            embeddings=state.get("embeddings"),
            tag_dictionary=state.get("tag_dictionary"),
            tag_format=state.get("tag_format", "BIOES"),
            tag_type=state.get("tag_type"),
            use_crf=state.get("use_crf"),
            rnn_layers=state.get("rnn_layers"),
            hidden_size=state.get("hidden_size"),
            dropout=state.get("use_dropout", 0.0),
            word_dropout=state.get("use_word_dropout", 0.0),
            locked_dropout=state.get("use_locked_dropout", 0.0),
            ave_embeddings=state.get("ave_embeddings", True),
            loss_weights=state.get("weight_dict"),
            init_from_state_dict=True,
            **kwargs,
        )

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
        return filtered_sentences

    def _determine_if_span_prediction_problem(self, dictionary: Dictionary) -> bool:
        for item in dictionary.get_items():
            if item.startswith("B-") or item.startswith("S-") or item.startswith("I-"):
                return True
        return False

    def _print_predictions(self, batch, gold_label_type):

        lines = []
        if self.predict_spans:
            for datapoint in batch:
                # all labels default to "O"
                for token in datapoint:
                    token.set_label("gold_bio", "O")
                    token.set_label("predicted_bio", "O")

                # set gold token-level
                for gold_label in datapoint.get_labels(gold_label_type):
                    gold_span: Span = gold_label.data_point
                    prefix = "B-"
                    for token in gold_span:
                        token.set_label("gold_bio", prefix + gold_label.value)
                        prefix = "I-"

                # set predicted token-level
                for predicted_label in datapoint.get_labels("predicted"):
                    predicted_span: Span = predicted_label.data_point
                    prefix = "B-"
                    for token in predicted_span:
                        token.set_label("predicted_bio", prefix + predicted_label.value)
                        prefix = "I-"

                # now print labels in CoNLL format
                for token in datapoint:
                    eval_line = (
                        f"{token.text} "
                        f"{token.get_label('gold_bio').value} "
                        f"{token.get_label('predicted_bio').value}\n"
                    )
                    lines.append(eval_line)
                lines.append("\n")

        else:
            for datapoint in batch:
                # print labels in CoNLL format
                for token in datapoint:
                    eval_line = (
                        f"{token.text} "
                        f"{token.get_label(gold_label_type).value} "
                        f"{token.get_label('predicted').value}\n"
                    )
                    lines.append(eval_line)
                lines.append("\n")
        return lines
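For orientation, a minimal construction sketch follows (not part of this commit): the corpus, the label-type name, and the embedding identifiers are assumptions that follow Flair's standard API, combined with the PretrainedEmbeddings helper from model/embedding.

from flair.datasets import CONLL_03  # hypothetical choice; any token-level corpus with span labels works

from model.embedding import PretrainedEmbeddings

corpus = CONLL_03()
tag_dictionary = corpus.make_label_dictionary(label_type="ner")

embedder = PretrainedEmbeddings(
    word_embedding="bert-base-uncased",
    forward_embedding="news-forward",
    backward_embedding="news-backward",
)

tagger = Bi_LSTM_CRF(
    embeddings=embedder.forward(),   # stacked Flair + BERT token embeddings
    tag_dictionary=tag_dictionary,
    tag_type="ner",
    hidden_size=256,
    use_crf=True,
)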
model/layer/bioes.py ADDED @@ -0,0 +1,62 @@

from collections import defaultdict
from typing import Dict


def get_spans_from_bio(bioes_tags, bioes_scores=None):
    # add a dummy "O" to close the final prediction
    bioes_tags.append("O")
    # return complex list
    found_spans = []
    # internal variables
    current_tag_weights: Dict[str, float] = defaultdict(lambda: 0.0)
    previous_tag = "O-"
    current_span = []
    current_span_scores = []
    for idx, bioes_tag in enumerate(bioes_tags):

        # non-set tags are OUT tags
        if bioes_tag == "" or bioes_tag == "O" or bioes_tag == "_":
            bioes_tag = "O-"

        # anything that is not OUT is IN
        in_span = False if bioes_tag == "O-" else True

        # does this prediction start a new span?
        starts_new_span = False

        # begin and single tags start new spans
        if bioes_tag[0:2] in ["B-", "S-"]:
            starts_new_span = True

        # in IOB format, an I tag starts a span if it follows an O or is a different span
        if bioes_tag[0:2] == "I-" and previous_tag[2:] != bioes_tag[2:]:
            starts_new_span = True

        # single tags that change prediction start new spans
        if bioes_tag[0:2] in ["S-"] and previous_tag[2:] != bioes_tag[2:]:
            starts_new_span = True

        # if an existing span is ended (either by reaching O or starting a new span)
        if (starts_new_span or not in_span) and len(current_span) > 0:
            # determine score and value
            span_score = sum(current_span_scores) / len(current_span_scores)
            span_value = sorted(current_tag_weights.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]

            # append to result list
            found_spans.append((current_span, span_score, span_value))

            # reset for-loop variables for new span
            current_span = []
            current_span_scores = []
            current_tag_weights = defaultdict(lambda: 0.0)

        if in_span:
            current_span.append(idx)
            current_span_scores.append(bioes_scores[idx] if bioes_scores else 1.0)
            weight = 1.1 if starts_new_span else 1.0
            current_tag_weights[bioes_tag[2:]] += weight

        # remember previous tag
        previous_tag = bioes_tag

    return found_spans
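A small worked example of get_spans_from_bio (not part of this commit): for the BIOES sequence below, tokens 0-1 form one PER span and token 3 a single-token LOC span; with no scores supplied, each span's score defaults to 1.0.

tags = ["B-PER", "E-PER", "O", "S-LOC"]
print(get_spans_from_bio(tags))
# [([0, 1], 1.0, 'PER'), ([3], 1.0, 'LOC')]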
model/layer/crf.py ADDED @@ -0,0 +1,47 @@

import torch

import flair

START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"


class CRF(torch.nn.Module):
    """
    Conditional Random Field implementation according to sgrvinod, modified to not
    only look at the current word, but also at the previously seen annotation.
    """

    def __init__(self, tag_dictionary, tagset_size: int, init_from_state_dict: bool):
        """
        :param tag_dictionary: tag dictionary in order to find the IDs for the start and stop tags
        :param tagset_size: number of tags from the tag dictionary
        :param init_from_state_dict: whether we load a pretrained model from state dict
        """
        super(CRF, self).__init__()

        self.tagset_size = tagset_size
        # Transitions are used in the following way: transitions[to, from].
        self.transitions = torch.nn.Parameter(torch.randn(tagset_size, tagset_size))
        # If we are not using a pretrained model and train a fresh one, we need to set transitions from any tag
        # to START-tag and from STOP-tag to any other tag to -10000.
        if not init_from_state_dict:
            self.transitions.detach()[tag_dictionary.get_idx_for_item(START_TAG), :] = -10000

            self.transitions.detach()[:, tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
        self.to(flair.device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation of the Conditional Random Field.
        :param features: output from LSTM layer in shape (batch size, seq len, hidden size)
        :return: CRF scores (emission scores for each token + transition prob from previous state) in
            shape (batch_size, seq len, tagset size, tagset size)
        """
        batch_size, seq_len = features.size()[:2]

        emission_scores = features
        emission_scores = emission_scores.unsqueeze(-1).expand(batch_size, seq_len, self.tagset_size, self.tagset_size)

        crf_scores = emission_scores + self.transitions.unsqueeze(0).unsqueeze(0)
        return crf_scores
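The broadcast in forward() is easiest to see on shapes: each emission score is replicated across the "from tag" axis, and the (tagset_size x tagset_size) transition matrix is added once per token. A quick shape check (not part of this commit; the tiny tag dictionary and random emissions are assumptions for illustration):

import torch
import flair
from flair.data import Dictionary

d = Dictionary(add_unk=False)
for item in ["O", "B-PER", "I-PER", "<START>", "<STOP>"]:
    d.add_item(item)

crf = CRF(d, tagset_size=len(d), init_from_state_dict=False)
emissions = torch.randn(2, 7, len(d), device=flair.device)  # (batch, seq len, tagset size)
scores = crf(emissions)
print(scores.shape)  # torch.Size([2, 7, 5, 5]) -> (batch, seq len, to-tag, from-tag)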
model/layer/lstm.py ADDED @@ -0,0 +1,47 @@

from typing import Tuple

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import flair

class LSTM(torch.nn.Module):
    """
    Simple LSTM implementation that returns the features used for the (1) CRF and (2) Span classifier.
    """
    def __init__(self, rnn_layers: int, hidden_size: int, bidirectional: bool, rnn_input_dim: int,):
        """
        :param rnn_layers: number of RNN layers to be used, default 1
        :param hidden_size: hidden size of the LSTM layer
        :param bidirectional: whether we use a bidirectional LSTM or not, default True
        :param rnn_input_dim: the shape of our max sentence token and embeddings
        """
        super(LSTM, self).__init__()

        self.hidden_size = hidden_size
        self.rnn_input_dim = rnn_input_dim
        self.num_layers = rnn_layers
        self.dropout = 0.0 if rnn_layers == 1 else 0.5
        self.bidirectional = bidirectional
        self.batch_first = True
        self.lstm = torch.nn.LSTM(
            self.rnn_input_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            bidirectional=self.bidirectional,
            batch_first=self.batch_first,
        )

        self.to(flair.device)

    def forward(self, sentence_tensor: torch.Tensor, sorted_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward propagation of the LSTM model by packing the tensors.
        :param sentence_tensor: padded embeddings in shape (batch size, seq len, embedding dim)
        :param sorted_lengths: lengths of the sentences in the batch
        :return: padded LSTM output in shape (batch size, seq len, hidden size * num directions) and the output lengths
        """
        packed = pack_padded_sequence(sentence_tensor, sorted_lengths, batch_first=True, enforce_sorted=False)
        rnn_output, hidden = self.lstm(packed)
        sentence_tensor, output_lengths = pad_packed_sequence(rnn_output, batch_first=True)

        return sentence_tensor, output_lengths
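A quick sketch exercising the wrapper with random embeddings (not part of this commit; dimensions are arbitrary): pack_padded_sequence drops the padding during the LSTM pass and pad_packed_sequence restores a padded batch afterwards.

import torch
import flair

lstm = LSTM(rnn_layers=1, hidden_size=256, bidirectional=True, rnn_input_dim=100)
batch = torch.randn(4, 12, 100, device=flair.device)  # (batch, max seq len, embedding dim)
lengths = torch.tensor([12, 9, 7, 3])                 # CPU lengths; enforce_sorted=False handles ordering
out, out_lengths = lstm(batch, lengths)
print(out.shape)  # torch.Size([4, 12, 512]) -- hidden_size * 2 directions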
model/layer/span.py ADDED @@ -0,0 +1,211 @@

from functools import lru_cache
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def enumerate_spans(n):
    for i in range(n):
        for j in range(i, n):
            yield (i, j)

@lru_cache  # type: ignore
def get_all_spans(n: int) -> torch.Tensor:
    return torch.tensor(list(enumerate_spans(n)), dtype=torch.long)


class SpanClassifier(nn.Module):
    num_additional_labels = 1

    def __init__(self, encoder, scorer: "SpanScorer"):
        super().__init__()
        self.encoder = encoder
        self.scorer = scorer

    def forward(
        self, *input_ids: Sequence[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        hs, lengths = self.encoder(*input_ids)
        spans = list(map(get_all_spans, lengths))
        scores = self.scorer(hs, spans)
        return spans, scores

    @torch.no_grad()
    def decode(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
    ) -> List[List[Tuple[int, int, int]]]:
        spans_flatten = torch.cat(spans)
        scores_flatten = torch.cat(scores)
        assert len(spans_flatten) == len(scores_flatten)
        labels_flatten = scores_flatten.argmax(dim=1).cpu()
        mask = labels_flatten < self.scorer.num_labels - 1
        mentions = torch.hstack((spans_flatten[mask], labels_flatten[mask, None]))

        output = []
        offset = 0
        sizes = [m.sum() for m in torch.split(mask, [len(idxs) for idxs in spans])]
        for size in sizes:
            output.append([tuple(m) for m in mentions[offset : offset + size].tolist()])
            offset += size
        return output  # type: ignore

    def compute_metrics(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
        true_mentions: Sequence[Sequence[Tuple[int, int, int]]],
        decode=True,
    ) -> Dict[str, Any]:
        assert len(spans) == len(scores) == len(true_mentions)
        num_labels = self.scorer.num_labels
        true_labels = []
        for spans_i, scores_i, true_mentions_i in zip(spans, scores, true_mentions):
            assert len(spans_i) == len(scores_i)
            span2idx = {tuple(s): idx for idx, s in enumerate(spans_i.tolist())}
            labels_i = torch.full((len(spans_i),), fill_value=num_labels - 1)
            for (start, end, label) in true_mentions_i:
                idx = span2idx.get((start, end))
                if idx is not None:
                    labels_i[idx] = label
            true_labels.append(labels_i)

        scores_flatten = torch.cat(scores)
        true_labels_flatten = torch.cat(true_labels).to(scores_flatten.device)
        assert len(scores_flatten) == len(true_labels_flatten)
        loss = F.cross_entropy(scores_flatten, true_labels_flatten)
        accuracy = categorical_accuracy(scores_flatten, true_labels_flatten)
        result = {"loss": loss, "accuracy": accuracy}

        if decode:
            pred_mentions = self.decode(spans, scores)
            tp, fn, fp = 0, 0, 0
            for pred_mentions_i, true_mentions_i in zip(pred_mentions, true_mentions):
                pred, gold = set(pred_mentions_i), set(true_mentions_i)
                tp += len(gold & pred)
                fn += len(gold - pred)
                fp += len(pred - gold)
            result["precision"] = (tp, tp + fp)
            result["recall"] = (tp, tp + fn)
            result["mentions"] = pred_mentions

        return result


@torch.no_grad()
def categorical_accuracy(
    y: torch.Tensor, t: torch.Tensor, ignore_index: Optional[int] = None
) -> Tuple[int, int]:
    pred = y.argmax(dim=1)
    if ignore_index is not None:
        mask = t == ignore_index
        ignore_cnt = mask.sum()
        pred.masked_fill_(mask, ignore_index)
        count = ((pred == t).sum() - ignore_cnt).item()
        total = (t.numel() - ignore_cnt).item()
    else:
        count = (pred == t).sum().item()
        total = t.numel()
    return count, total


class SpanScorer(torch.nn.Module):
    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels

    def forward(
        self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
    ):
        raise NotImplementedError


class BaselineSpanScorer(SpanScorer):
    def __init__(
        self,
        input_size: int,
        num_labels: int,
        mlp_units: Union[int, Sequence[int]] = 150,
        mlp_dropout: float = 0.0,
        feature="concat",
    ):
        super().__init__(num_labels)
        input_size *= 2 if feature == "concat" else 1
        self.mlp = MLP(input_size, num_labels, mlp_units, F.relu, mlp_dropout)
        self.feature = feature

    def forward(
        self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
    ):
        max_length = xs.size(1)
        xs_flatten = xs.reshape(-1, xs.size(-1))
        spans_flatten = torch.cat([idxs + max_length * i for i, idxs in enumerate(spans)])
        features = self._compute_feature(xs_flatten, spans_flatten)
        scores = self.mlp(features)
        return torch.split(scores, [len(idxs) for idxs in spans])

    def _compute_feature(self, xs, spans):
        if self.feature == "concat":
            return xs[spans.ravel()].view(len(spans), -1)
        elif self.feature == "minus":
            begins, ends = spans.T
            return xs[ends] - xs[begins]
        else:
            raise NotImplementedError


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: Optional[int],
        units: Optional[Union[int, Sequence[int]]] = None,
        activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        units = [units] if isinstance(units, int) else units
        if not units and out_features is None:
            raise ValueError("'out_features' or 'units' must be specified")
        layers = []
        for u in units or []:
            layers.append(MLP.Layer(in_features, u, activate, dropout, bias))
            in_features = u
        if out_features is not None:
            layers.append(MLP.Layer(in_features, out_features, None, 0.0, bias))
        super().__init__(*layers)

    class Layer(nn.Module):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
            dropout: float = 0.0,
            bias: bool = True,
        ):
            super().__init__()
            if activate is not None and not callable(activate):
                raise TypeError("activate must be callable: type={}".format(type(activate)))
            self.linear = nn.Linear(in_features, out_features, bias)
            self.activate = activate
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.linear(x)
            if self.activate is not None:
                h = self.activate(h)
            return self.dropout(h)

        def extra_repr(self) -> str:
            return "{}, activate={}, dropout={}".format(
                self.linear.extra_repr(), self.activate, self.dropout.p
            )

        def __repr__(self):
            return "{}.{}({})".format(MLP.__name__, self._get_name(), self.extra_repr())
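enumerate_spans generates all O(n²) contiguous (start, end) pairs, which is what the scorer classifies. A tiny demonstration (not part of this commit; the sizes below are arbitrary):

import torch

print(list(enumerate_spans(3)))
# [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
print(get_all_spans(3).shape)  # torch.Size([6, 2]) -- cached per sentence length

# scoring random encoder states for two sentences of lengths 3 and 2
scorer = BaselineSpanScorer(input_size=8, num_labels=4, feature="concat")
xs = torch.randn(2, 3, 8)                      # (batch, padded seq len, hidden)
spans = [get_all_spans(3), get_all_spans(2)]   # per-sentence span index tensors
scores = scorer(xs, spans)
print([s.shape for s in scores])  # [torch.Size([6, 4]), torch.Size([3, 4])]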
model/layer/viterbi.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn
|
6 |
+
from torch.nn.functional import softmax
|
7 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
8 |
+
|
9 |
+
import flair
|
10 |
+
from flair.data import Dictionary, Label, List, Sentence
|
11 |
+
|
12 |
+
START_TAG: str = "<START>"
|
13 |
+
STOP_TAG: str = "<STOP>"
|
14 |
+
|
15 |
+
|
16 |
+
class ViterbiLoss(torch.nn.Module):
|
17 |
+
"""
|
18 |
+
Calculates the loss for each sequence up to its length t.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, tag_dictionary: Dictionary):
|
22 |
+
"""
|
23 |
+
:param tag_dictionary: tag_dictionary of task
|
24 |
+
"""
|
25 |
+
super(ViterbiLoss, self).__init__()
|
26 |
+
self.tag_dictionary = tag_dictionary
|
27 |
+
self.tagset_size = len(tag_dictionary)
|
28 |
+
self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
|
29 |
+
self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
|
30 |
+
|
31 |
+
def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
|
32 |
+
"""
|
33 |
+
Forward propagation of Viterbi Loss
|
34 |
+
:param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
|
35 |
+
lengths of sentences in batch, transitions from CRF
|
36 |
+
:param targets: true tags for sentences which will be converted to matrix indices.
|
37 |
+
:return: average Viterbi Loss over batch size
|
38 |
+
"""
|
39 |
+
features, lengths, transitions = features_tuple
|
40 |
+
|
41 |
+
batch_size = features.size(0)
|
42 |
+
seq_len = features.size(1)
|
43 |
+
|
44 |
+
targets, targets_matrix_indices = self._format_targets(targets, lengths)
|
45 |
+
targets_matrix_indices = torch.tensor(targets_matrix_indices, dtype=torch.long).unsqueeze(2).to(flair.device)
|
46 |
+
|
47 |
+
# scores_at_targets[range(features.shape[0]), lengths.values -1]
|
48 |
+
# Squeeze crf scores matrices in 1-dim shape and gather scores at targets by matrix indices
|
49 |
+
scores_at_targets = torch.gather(features.view(batch_size, seq_len, -1), 2, targets_matrix_indices)
|
50 |
+
scores_at_targets = pack_padded_sequence(scores_at_targets, lengths, batch_first=True)[0]
|
51 |
+
transitions_to_stop = transitions[
|
52 |
+
np.repeat(self.stop_tag, features.shape[0]),
|
53 |
+
[target[length - 1] for target, length in zip(targets, lengths)],
|
54 |
+
]
|
55 |
+
gold_score = scores_at_targets.sum() + transitions_to_stop.sum()
|
56 |
+
|
57 |
+
scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)
|
58 |
+
|
59 |
+
for t in range(max(lengths)):
|
60 |
+
batch_size_t = sum(
|
61 |
+
[length > t for length in lengths]
|
62 |
+
) # since batch is ordered, we can save computation time by reducing our effective batch_size
|
63 |
+
|
64 |
+
if t == 0:
|
65 |
+
# Initially, get scores from <start> tag to all other tags
|
66 |
+
scores_upto_t[:batch_size_t] = (
|
67 |
+
scores_upto_t[:batch_size_t] + features[:batch_size_t, t, :, self.start_tag]
|
68 |
+
)
|
69 |
+
else:
|
70 |
+
# We add scores at current timestep to scores accumulated up to previous timestep, and log-sum-exp
|
71 |
+
# Remember, the cur_tag of the previous timestep is the prev_tag of this timestep
|
72 |
+
scores_upto_t[:batch_size_t] = self._log_sum_exp(
|
73 |
+
features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(1), dim=2
|
74 |
+
)
|
75 |
+
|
76 |
+
all_paths_scores = self._log_sum_exp(scores_upto_t + transitions[self.stop_tag].unsqueeze(0), dim=1).sum()
|
77 |
+
|
78 |
+
viterbi_loss = all_paths_scores - gold_score
|
79 |
+
|
80 |
+
return viterbi_loss
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def _log_sum_exp(tensor, dim):
|
84 |
+
"""
|
85 |
+
Calculates the log-sum-exponent of a tensor's dimension in a numerically stable way.
|
86 |
+
:param tensor: tensor
|
87 |
+
:param dim: dimension to calculate log-sum-exp of
|
88 |
+
:return: log-sum-exp
|
89 |
+
"""
|
90 |
+
m, _ = torch.max(tensor, dim)
|
91 |
+
m_expanded = m.unsqueeze(dim).expand_as(tensor)
|
92 |
+
return m + torch.log(torch.sum(torch.exp(tensor - m_expanded), dim))
|
93 |
+
    def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
        """
        Formats targets into matrix indices.
        CRF scores contain, per sentence and per token, a (tagset_size x tagset_size) matrix holding the emission
        score for token j plus the transition score from the previous token i. That means, if we think of our rows
        as "to tag" and our columns as "from tag", the matrix cell [10, 5] contains the emission score for tag 10
        plus the transition score from previous tag 5, and can be addressed directly through the 1-dim index
        (10 + tagset_size * 5) = 70, if our tagset consists of 12 tags.
        :param targets: targets as in tag dictionary
        :param lengths: lengths of sentences in batch
        """
        targets_per_sentence = []

        targets_list = targets.tolist()
        for cut in lengths:
            targets_per_sentence.append(targets_list[:cut])
            targets_list = targets_list[cut:]

        # Pad every sentence's tag sequence to the maximum length with the STOP tag
        for t in targets_per_sentence:
            t += [self.tag_dictionary.get_idx_for_item(STOP_TAG)] * (int(lengths.max().item()) - len(t))

        matrix_indices = list(
            map(
                lambda s: [self.tag_dictionary.get_idx_for_item(START_TAG) + (s[0] * self.tagset_size)]
                + [s[i] + (s[i + 1] * self.tagset_size) for i in range(0, len(s) - 1)],
                targets_per_sentence,
            )
        )

        return targets_per_sentence, matrix_indices


class ViterbiDecoder:
    """
    Decodes a given sequence using the Viterbi algorithm.
    """

    def __init__(self, tag_dictionary: Dictionary):
        """
        :param tag_dictionary: Dictionary of tags for sequence labeling task
        """
        self.tag_dictionary = tag_dictionary
        self.tagset_size = len(tag_dictionary)
        self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
        self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)

    def decode(
        self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence]
    ) -> Tuple[List, List]:
        """
        Decoding function returning the most likely sequence of tags.
        :param features_tuple: CRF scores from the forward method in shape (batch size, seq len, tagset size, tagset size),
            lengths of sentences in batch, transitions of CRF
        :param probabilities_for_all_classes: whether to return probabilities for all tags
        :param sentences: list of the batch's sentences
        :return: decoded sequences
        """
        features, lengths, transitions = features_tuple
        all_tags = []

        batch_size = features.size(0)
        seq_len = features.size(1)

        # Create a tensor to hold accumulated sequence scores at each current tag
        scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size).to(flair.device)
        # Create a tensor to hold back-pointers,
        # i.e. indices of the previous tag that correspond to the maximum accumulated score at the current tag.
        # Let pads be the <end> tag index, since that was the last tag in the decoded sequence.
        backpointers = (
            torch.ones((batch_size, seq_len + 1, self.tagset_size), dtype=torch.long, device=flair.device)
            * self.stop_tag
        )

        for t in range(seq_len):
            batch_size_t = sum([length > t for length in lengths])  # effective batch size (sans pads) at this timestep
            terminates = [i for i, length in enumerate(lengths) if length == t + 1]

            if t == 0:
                scores_upto_t[:batch_size_t, t] = features[:batch_size_t, t, :, self.start_tag]
                backpointers[:batch_size_t, t, :] = (
                    torch.ones((batch_size_t, self.tagset_size), dtype=torch.long) * self.start_tag
                )
            else:
                # We add scores at the current timestep to scores accumulated up to the previous timestep, and
                # choose the previous tag that corresponds to the max. accumulated score for each current tag
                scores_upto_t[:batch_size_t, t], backpointers[:batch_size_t, t, :] = torch.max(
                    features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t, t - 1].unsqueeze(1), dim=2
                )

            # If a sentence is over, add the transition to the STOP tag
            if terminates:
                scores_upto_t[terminates, t + 1], backpointers[terminates, t + 1, :] = torch.max(
                    scores_upto_t[terminates, t].unsqueeze(1) + transitions[self.stop_tag].unsqueeze(0), dim=2
                )

        # Decode/trace best path backwards
        decoded = torch.zeros((batch_size, backpointers.size(1)), dtype=torch.long, device=flair.device)
        pointer = torch.ones((batch_size, 1), dtype=torch.long, device=flair.device) * self.stop_tag

        for t in list(reversed(range(backpointers.size(1)))):
            decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
            pointer = decoded[:, t].unsqueeze(1)

        # Sanity check
        assert torch.equal(
            decoded[:, 0], torch.ones((batch_size), dtype=torch.long, device=flair.device) * self.start_tag
        )

        # Remove the start tag and the backscore to the stop tag
        scores_upto_t = scores_upto_t[:, :-1, :]
        decoded = decoded[:, 1:]

        # Max + softmax to get a confidence score for the predicted label and append a label to each token
        scores = softmax(scores_upto_t, dim=2)
        confidences = torch.max(scores, dim=2)

        tags = []
        for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths):
            tags.append(
                [
                    (self.tag_dictionary.get_item_for_index(tag), conf.item())
                    for tag, conf in list(zip(tag_seq, tag_seq_conf))[:length_seq]
                ]
            )

        if probabilities_for_all_classes:
            all_tags = self._all_scores_for_token(scores.cpu(), lengths, sentences)

        return tags, all_tags

    def _all_scores_for_token(self, scores: torch.Tensor, lengths: torch.IntTensor, sentences: List[Sentence]):
        """
        Returns all scores for each tag in the tag dictionary.
        :param scores: softmax scores for the current batch of sentences
        :param lengths: lengths of sentences in batch
        :param sentences: list of the batch's sentences
        """
        scores = scores.numpy()
        prob_tags_per_sentence = []
        for scores_sentence, length, sentence in zip(scores, lengths, sentences):
            scores_sentence = scores_sentence[:length]
            prob_tags_per_sentence.append(
                [
                    [
                        Label(token, self.tag_dictionary.get_item_for_index(score_id), score)
                        for score_id, score in enumerate(score_dist)
                    ]
                    for score_dist, token in zip(scores_sentence, sentence)
                ]
            )
        return prob_tags_per_sentence
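Two details in viterbi.py are easy to get wrong, so here is a small standalone sanity check. This sketch is illustrative only and not part of the commit: it checks the 1-dim matrix addressing described in the _format_targets docstring, and shows why _log_sum_exp subtracts the max before exponentiating.

import torch

# 1) Flattened addressing from the _format_targets docstring: with a tagset of
#    12 tags, the cell for "to tag 10, from tag 5" is addressed by the 1-dim
#    index 10 + tagset_size * 5 = 70.
tagset_size, to_tag, from_tag = 12, 10, 5
assert to_tag + tagset_size * from_tag == 70

# 2) Numerically stable log-sum-exp, as in ViterbiLoss._log_sum_exp: the naive
#    form overflows for large scores, the max-shifted form does not.
x = torch.tensor([[1000.0, 1001.0, 999.0]])
naive = torch.log(torch.exp(x).sum(dim=1))  # exp(1000.) overflows in float32 -> inf
m, _ = torch.max(x, dim=1)
stable = m + torch.log(torch.exp(x - m.unsqueeze(1)).sum(dim=1))
assert torch.isinf(naive).all()
assert torch.allclose(stable, torch.logsumexp(x, dim=1))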
part/__init__.py
ADDED
@@ -0,0 +1,2 @@
from part.data import *
from part.dropout import *
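Because part/__init__.py re-exports everything from part.data and part.dropout via wildcard imports, downstream code can pull these classes from the package root. A small illustrative sketch, not part of the commit:

import part

# Token/Span and both dropout modules are available at package level
# thanks to the wildcard imports above.
token = part.Token("admission")
dropout = part.WordDropout(dropout_rate=0.05)
print(token.text, dropout)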
part/data.py
ADDED
@@ -0,0 +1,142 @@
from typing import Dict, List, Optional
from flair.data import _PartOfSentence, DataPoint, Label

class Token(_PartOfSentence):
    """
    This class represents one word in a tokenized sentence. Each token may have any number of tags. It may also point
    to its head in a dependency tree.

    :param text: the token's text
    :param head_id: index of this token's head within the sentence (for dependency trees)
    :param whitespace_after: whether the token is followed by whitespace
    :param start_position: the character offset at which this token starts in the document
    :param sentence: the Sentence this token belongs to, if any
    """

    def __init__(
        self,
        text: str,
        head_id: Optional[int] = None,
        whitespace_after: int = 1,
        start_position: int = 0,
        sentence=None,
    ):
        super().__init__(sentence=sentence)

        self.form: str = text
        self._internal_index: Optional[int] = None
        self.head_id: Optional[int] = head_id
        self.whitespace_after: int = whitespace_after

        self.start_pos = start_position
        self.end_pos = start_position + len(text)

        self._embeddings: Dict = {}
        self.tags_proba_dist: Dict[str, List[Label]] = {}

    @property
    def idx(self) -> int:
        if isinstance(self._internal_index, int):
            return self._internal_index
        else:
            raise ValueError

    @property
    def text(self):
        return self.form

    @property
    def unlabeled_identifier(self) -> str:
        return f'Token[{self.idx - 1}]: "{self.text}"'

    def add_tags_proba_dist(self, tag_type: str, tags: List[Label]):
        self.tags_proba_dist[tag_type] = tags

    def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
        if tag_type in self.tags_proba_dist:
            return self.tags_proba_dist[tag_type]
        return []

    def get_head(self):
        return self.sentence.get_token(self.head_id)

    @property
    def start_position(self) -> int:
        return self.start_pos

    @property
    def end_position(self) -> int:
        return self.end_pos

    @property
    def embedding(self):
        return self.get_embedding()

    def __repr__(self):
        return self.__str__()

    def add_label(self, typename: str, value: str, score: float = 1.0):
        """
        The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
        Therefore, labels are added to the Sentence only if it exists.
        """
        if self.sentence:
            super().add_label(typename=typename, value=value, score=score)
        else:
            DataPoint.add_label(self, typename=typename, value=value, score=score)

    def set_label(self, typename: str, value: str, score: float = 1.0):
        """
        The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
        Therefore, labels are set on the Sentence only if it exists.
        """
        if self.sentence:
            super().set_label(typename=typename, value=value, score=score)
        else:
            DataPoint.set_label(self, typename=typename, value=value, score=score)


class Span(_PartOfSentence):
    """
    This class represents one textual span consisting of Tokens. It is used when tokens are nested,
    i.e. when several tokens together form a longer phrase.

    :param tokens: List of tokens in the span
    """

    def __init__(self, tokens: List[Token]):
        super().__init__(tokens[0].sentence)
        self.tokens = tokens
        super()._init_labels()

    @property
    def start_position(self) -> int:
        return self.tokens[0].start_position

    @property
    def end_position(self) -> int:
        return self.tokens[-1].end_position

    @property
    def text(self) -> str:
        return " ".join([t.text for t in self.tokens])

    @property
    def unlabeled_identifier(self) -> str:
        return f'Span[{self.tokens[0].idx - 1}:{self.tokens[-1].idx}]: "{self.text}"'

    def __repr__(self):
        return self.__str__()

    def __getitem__(self, idx: int) -> Token:
        return self.tokens[idx]

    def __iter__(self):
        return iter(self.tokens)

    def __len__(self) -> int:
        return len(self.tokens)

    @property
    def embedding(self):
        pass
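Since part/data.py mirrors flair's data model, a minimal sketch of how a detached Token behaves follows. It is illustrative only, not part of the commit; it assumes flair's _PartOfSentence and DataPoint base classes are importable as above, and the "B-DOCTOR" tag value is a hypothetical PHI label.

from part.data import Token

# A standalone token: character offsets are derived from start_position and the text length.
token = Token("Dr.", start_position=0)
assert token.text == "Dr."
assert token.start_position == 0 and token.end_position == 3

# With no Sentence attached, add_label falls back to DataPoint.add_label (see the override above).
token.add_label("ner", "B-DOCTOR", score=0.98)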
part/dropout.py
ADDED
@@ -0,0 +1,60 @@
import torch


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout.
    Randomly drops out entire parameters in embedding space, reusing the same dropout mask at every timestep.

    :param dropout_rate: the fraction of input units to drop, between 0 and 1
    :param batch_first: whether the input is shaped (batch, seq, features) rather than (seq, batch, features)
    :param inplace: whether the operation is reported as in-place in the module repr
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super(LockedDropout, self).__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        # Sample one mask per sequence (not per timestep) and reuse it across the time dimension
        if not self.batch_first:
            m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
        else:
            m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate)

        # Rescale surviving units by 1 / (1 - p) so expected activations stay unchanged
        mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
        mask = mask.expand_as(x)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words
    (or characters) in embedding space.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super(WordDropout, self).__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        if not self.training or not self.dropout_rate:
            return x

        # Sample one keep/drop decision per word position, shared across the embedding dimension
        m = x.data.new(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)

        mask = torch.autograd.Variable(m, requires_grad=False)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
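The difference between the two dropout variants is easiest to see on a tensor of ones. A minimal sketch, illustrative only and not part of the commit:

import torch
from part.dropout import LockedDropout, WordDropout

torch.manual_seed(0)
x = torch.ones(2, 4, 6)  # (batch, seq_len, embedding_dim)

locked = LockedDropout(dropout_rate=0.5)
locked.train()
out = locked(x)
# The same embedding dimensions are zeroed at every timestep of a sequence,
# and the survivors are rescaled to 1 / (1 - 0.5) = 2.0.
assert torch.equal(out[:, 0, :], out[:, 1, :])

word = WordDropout(dropout_rate=0.05)
word.train()
out = word(x)
# Whole word vectors are either zeroed or passed through unchanged (no rescaling).
assert set(out.flatten().tolist()) <= {0.0, 1.0}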
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
torchvision
torchaudio
flair
numpy
pandas
nltk
panel
hvplot