spitzc32 committed on
Commit
24d0437
·
1 Parent(s): d13f1d3

Added initial structure of the model

Dockerfile ADDED
@@ -0,0 +1,11 @@
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,26 @@
1
+ from model.layer import Bi_LSTM_CRF
2
+ from flair.data import Sentence
3
+
4
+ tagger = Bi_LSTM_CRF.load("checkpoints/best-model.pt")
5
+
6
+ def model(word: str):
7
+ """
8
+ A function for serving the model for PHI classification.
9
+ :param word: paragraph of text whose tokens will be classified.
10
+
11
+ :returns: dict containing the recognized entity texts ("labels") and
12
+ their respective tag classifications ("tags").
13
+ """
14
+ txt = Sentence(word)
15
+ tagger.predict(txt)
16
+ labels, tags = [], []
17
+
18
+ for entity in txt.get_spans('ner'):
19
+ labels.append(entity.text)
20
+ tags.append(entity.get_label("ner").value)
21
+
22
+ return {
23
+ "labels": labels,
24
+ "tags": tags
25
+ }
26
+
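The Dockerfile's CMD runs uvicorn against app.main:app, which expects an ASGI application object that this commit does not yet include; a minimal sketch of such a wrapper around the model() function above is given below. The app/main.py module path and the fastapi dependency are assumptions (neither fastapi nor uvicorn is listed in requirements.txt at this point).

# app/main.py -- hypothetical FastAPI wrapper around the tagging function above
from fastapi import FastAPI

from app import model  # the model() function defined in app.py; import path is an assumption

app = FastAPI()

@app.get("/predict")
def predict(text: str) -> dict:
    # delegate to the Bi_LSTM_CRF tagger; returns {"labels": [...], "tags": [...]}
    return model(text)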
model/.DS_Store ADDED
Binary file (6.15 kB).
model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ import model.embedding
2
+ import model.layer
model/embedding/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch.nn as nn
2
+ from flair.embeddings import (
3
+ TransformerWordEmbeddings,
4
+ FlairEmbeddings,
5
+ CharacterEmbeddings,
6
+ StackedEmbeddings,
7
+ OneHotEmbeddings
8
+ )
9
+ from flair.data import Sentence
10
+
11
+
12
+ class PretrainedEmbeddings():
13
+ """
14
+ This is the implementation of the PretrainedEmbeddings we will use to embed our own
15
+ corpus for the purpose of generating a Tensor (the pre_word_embeds) that we will pass
16
+ to the model.
17
+
18
+ Plan:
19
+ * Word-level Embeddings: We will utilize BERT-based transformer word embeddings
20
+ for richer contextual word representations
21
+ * Context-level Embeddings: We will stick to Flair Embeddings first then go check if
22
+ pooled flair is better than FlairEmbeddings
23
+ """
24
+
25
+ def __init__(self,
26
+ word_embedding: str,
27
+ forward_embedding: str,
28
+ backward_embedding: str
29
+ ) -> None:
30
+ self.word_embedding = word_embedding
31
+ self.forward_embedding = forward_embedding
32
+ self.backward_embedding = backward_embedding
33
+
34
+
35
+ def forward(self):
36
+ # Firstly, we need to call out all pretrained embeddings accessible in
37
+ # Flair for our requirement
38
+ flair_forward_embedding = FlairEmbeddings(self.forward_embedding)
39
+ flair_backward_embedding = FlairEmbeddings(self.backward_embedding)
40
+
41
+ bert_embedding = TransformerWordEmbeddings(model=self.word_embedding,
42
+ fine_tune=True,
43
+ use_context=True,)
44
+
45
+ # Next Concatenate all embeddings above
46
+ stacked_embeddings = StackedEmbeddings(
47
+ embeddings=[
48
+ flair_forward_embedding,
49
+ flair_backward_embedding,
50
+ bert_embedding,
51
+ ])
52
+
53
+ return stacked_embeddings
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
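As a rough usage sketch (not part of the commit; the embedding identifiers below are common Flair/Hugging Face names chosen purely for illustration), the class could be exercised like this:

# Hypothetical usage of PretrainedEmbeddings; the model identifiers are illustrative only.
stacked = PretrainedEmbeddings(
    word_embedding="bert-base-cased",    # Hugging Face transformer id
    forward_embedding="news-forward",    # Flair forward character LM
    backward_embedding="news-backward",  # Flair backward character LM
).forward()
# 'stacked' is a StackedEmbeddings instance that can embed flair Sentence objects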
model/layer/__init__.py ADDED
@@ -0,0 +1,652 @@
1
+ import logging
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8
+ from tqdm import tqdm
9
+
10
+ import flair.nn
11
+ from part.data import Label, Span
12
+ from flair.data import Dictionary, Sentence
13
+ from flair.datasets import DataLoader, FlairDatapointDataset
14
+ from flair.embeddings import TokenEmbeddings
15
+ from flair.file_utils import cached_path
16
+ from flair.training_utils import store_embeddings
17
+
18
+ from model.layer.bioes import get_spans_from_bio
19
+ from model.layer.lstm import LSTM
20
+ from model.layer.crf import CRF
21
+ from model.layer.viterbi import ViterbiDecoder, ViterbiLoss
22
+
23
+ log = logging.getLogger("flair")
24
+
25
+
26
+ class Bi_LSTM_CRF(flair.nn.Classifier[Sentence]):
27
+ def __init__(
28
+ self,
29
+ embeddings: TokenEmbeddings,
30
+ tag_dictionary: Dictionary,
31
+ tag_type: str,
32
+ rnn: Optional[torch.nn.RNN] = None,
33
+ tag_format: str = "BIOES",
34
+ hidden_size: int = 256,
35
+ rnn_layers: int = 1,
36
+ bidirectional: bool = True,
37
+ use_crf: bool = True,
38
+ ave_embeddings: bool = True,
39
+ dropout: float = 0.0,
40
+ word_dropout: float = 0.05,
41
+ locked_dropout: float = 0.5,
42
+ loss_weights: Dict[str, float] = None,
43
+ init_from_state_dict: bool = False,
44
+ allow_unk_predictions: bool = False,
45
+ ):
46
+ """
47
+ BiLSTM Span CRF class for predicting labels for single tokens. Can be parameterized by several attributes.
48
+ Span prediction is utilized if there are nested entities such as Address and Organization. Since the
49
+ researchers observed that tokens can vary in length for a given dataset, span handling is only
50
+ incorporated when the data needs it.
51
+
52
+ :param embeddings: Embeddings to use during training and prediction
53
+ :param tag_dictionary: Dictionary containing all tags from corpus which can be predicted
54
+ :param tag_type: type of tag which is going to be predicted in case a corpus has multiple annotations
55
+ :param rnn: (Optional) Takes a torch.nn.Module as parameter by which you can pass a shared RNN between
56
+ different tasks.
57
+ :param hidden_size: Hidden size of RNN layer
58
+ :param rnn_layers: number of RNN layers
59
+ :param bidirectional: If True, RNN becomes bidirectional
60
+ :param use_crf: If True, use a Conditional Random Field for prediction, else linear map to tag space.
61
+ :param ave_embeddings: If True, add a linear layer on top of embeddings, if you want to imitate
62
+ fine-tuning of non-trainable embeddings.
63
+ :param dropout: If > 0, then use dropout.
64
+ :param word_dropout: If > 0, then use word dropout.
65
+ :param locked_dropout: If > 0, then use locked dropout.
66
+ :param loss_weights: Dictionary of weights for labels for the loss function
67
+ (if any label's weight is unspecified it will default to 1.0)
68
+ :param init_from_state_dict: Indicator whether we are loading a model from state dict
69
+ since we need to transform previous models' weights into CRF instance weights
70
+ """
71
+ super(Bi_LSTM_CRF, self).__init__()
72
+
73
+ # ----- Create the internal tag dictionary -----
74
+ self.tag_type = tag_type
75
+ self.tag_format = tag_format.upper()
76
+ if init_from_state_dict:
77
+ self.label_dictionary = tag_dictionary
78
+ else:
79
+ # span-labels need special encoding (BIO or BIOES)
80
+ if tag_dictionary.span_labels:
81
+ # the big question is whether the label dictionary should contain an UNK or not
82
+ # without UNK, we cannot evaluate on data that contains labels not seen in test
83
+ # with UNK, the model learns less well if there are no UNK examples
84
+ self.label_dictionary = Dictionary(add_unk=allow_unk_predictions)
85
+ assert self.tag_format in ["BIOES", "BIO"]
86
+ for label in tag_dictionary.get_items():
87
+ if label == "<unk>":
88
+ continue
89
+ self.label_dictionary.add_item("O")
90
+ if self.tag_format == "BIOES":
91
+ self.label_dictionary.add_item("S-" + label)
92
+ self.label_dictionary.add_item("B-" + label)
93
+ self.label_dictionary.add_item("E-" + label)
94
+ self.label_dictionary.add_item("I-" + label)
95
+ if self.tag_format == "BIO":
96
+ self.label_dictionary.add_item("B-" + label)
97
+ self.label_dictionary.add_item("I-" + label)
98
+ else:
99
+ self.label_dictionary = tag_dictionary
100
+
101
+ # is this a span prediction problem?
102
+ self.predict_spans = self._determine_if_span_prediction_problem(self.label_dictionary)
103
+
104
+ self.tagset_size = len(self.label_dictionary)
105
+ log.info(f"SequenceTagger predicts: {self.label_dictionary}")
106
+
107
+ # ----- Embeddings -----
108
+ # We set the first initial embeddings gathered from Flair
109
+ # Stacked and concatenated then ave. using Linear
110
+ self.embeddings = embeddings
111
+ embedding_dim: int = embeddings.embedding_length
112
+
113
+ # ----- Initial loss weights parameters -----
114
+ # This is for reiteration process of training.
115
+ # Initially we don't have any loss weights, but as we proceed to training,
116
+ # we get loss computations from the evaluation stage.
117
+ self.weight_dict = loss_weights
118
+ self.loss_weights = self._init_loss_weights(loss_weights) if loss_weights else None
119
+
120
+ # ----- RNN specific parameters -----
121
+ # These parameters are for setting up the self.RNN
122
+ self.hidden_size = hidden_size if not rnn else rnn.hidden_size
123
+ self.rnn_layers = rnn_layers if not rnn else rnn.num_layers
124
+ self.bidirectional = bidirectional if not rnn else rnn.bidirectional
125
+
126
+ # ----- Conditional Random Field parameters -----
127
+ self.use_crf = use_crf
128
+ # Previously trained models have been trained without an explicit CRF, thus it is required to check
129
+ # whether we are loading a model from state dict in order to skip or add START and STOP token
130
+ if use_crf and not init_from_state_dict and not self.label_dictionary.start_stop_tags_are_set():
131
+ self.label_dictionary.set_start_stop_tags()
132
+ self.tagset_size += 2
133
+
134
+ # ----- Dropout parameters -----
135
+ # dropouts
136
+ self.use_dropout: float = dropout
137
+ self.use_word_dropout: float = word_dropout
138
+ self.use_locked_dropout: float = locked_dropout
139
+
140
+ if dropout > 0.0:
141
+ self.dropout = torch.nn.Dropout(dropout)
142
+
143
+ if word_dropout > 0.0:
144
+ self.word_dropout = flair.nn.WordDropout(word_dropout)
145
+
146
+ if locked_dropout > 0.0:
147
+ self.locked_dropout = flair.nn.LockedDropout(locked_dropout)
148
+
149
+ # ----- Model layers -----
150
+ # Initialize Embedding Linear Dim for the purpose of ave them
151
+ self.ave_embeddings = ave_embeddings
152
+ if self.ave_embeddings:
153
+ self.embedding2nn = torch.nn.Linear(embedding_dim, embedding_dim)
154
+
155
+ # ----- RNN layer -----
156
+ # If shared RNN provided, else create one for model
157
+ self.rnn: torch.nn.RNN = (
158
+ rnn
159
+ if rnn
160
+ else LSTM(
161
+ rnn_layers,
162
+ hidden_size,
163
+ bidirectional,
164
+ rnn_input_dim=embedding_dim,
165
+ )
166
+ )
167
+
168
+ num_directions = 2 if self.bidirectional else 1
169
+ hidden_output_dim = self.rnn.hidden_size * num_directions
170
+
171
+
172
+ # final linear map to tag space
173
+ self.linear = torch.nn.Linear(hidden_output_dim, len(self.label_dictionary))
174
+
175
+
176
+ # the loss function is Viterbi if using CRF, else regular Cross Entropy Loss
177
+ self.loss_function = (
178
+ ViterbiLoss(self.label_dictionary)
179
+ if use_crf
+ else torch.nn.CrossEntropyLoss(weight=self.loss_weights, reduction="sum")
+ )
180
+
181
+ # if using CRF, we also require a CRF and a Viterbi decoder
182
+ if use_crf:
183
+ self.crf = CRF(self.label_dictionary, self.tagset_size, init_from_state_dict)
184
+ self.viterbi_decoder = ViterbiDecoder(self.label_dictionary)
185
+
186
+ self.to(flair.device)
187
+
188
+ @property
189
+ def label_type(self):
190
+ return self.tag_type
191
+
192
+ def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor:
193
+ """
194
+ Initializes the loss weights based on the given dictionary:
195
+ :param loss_weights: dictionary - contains loss weights
196
+ """
197
+ n_classes = len(self.label_dictionary)
198
+ weight_list = [1.0 for _ in range(n_classes)]
199
+ for i, tag in enumerate(self.label_dictionary.get_items()):
200
+ if tag in loss_weights.keys():
201
+ weight_list[i] = loss_weights[tag]
202
+
203
+ return torch.tensor(weight_list).to(flair.device)
204
+
205
+ def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
206
+ """
207
+ Calculates the loss of the forward propagation of the model
208
+ :param sentences: either a list of sentences or a single sentence
209
+ """
210
+ # if there are no sentences, there is no loss
211
+ if len(sentences) == 0:
212
+ return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
213
+
214
+ # forward pass to get scores
215
+ scores, gold_labels = self.forward(sentences) # type: ignore
216
+
217
+ # calculate loss given scores and labels
218
+ return self._calculate_loss(scores, gold_labels)
219
+
220
+ def forward(self, sentences: Union[List[Sentence], Sentence]):
221
+ """
222
+ Forward propagation through network. Returns gold labels of batch in addition.
223
+ :param sentences: Batch of current sentences
224
+ """
225
+ if not isinstance(sentences, list):
226
+ sentences = [sentences]
227
+ self.embeddings.embed(sentences)
228
+
229
+ # make a zero-padded tensor for the whole sentence
230
+ lengths, sentence_tensor = self._make_padded_tensor_for_batch(sentences)
231
+
232
+ # sort tensor in decreasing order based on lengths of sentences in batch
233
+ sorted_lengths, length_indices = lengths.sort(dim=0, descending=True)
234
+ sentences = [sentences[i] for i in length_indices]
235
+ sentence_tensor = sentence_tensor[length_indices]
236
+
237
+ # ----- Forward Propagation -----
238
+ # apply the dropout layers we initialized for the regularization
239
+ # of our inputs
240
+ if self.use_dropout:
241
+ sentence_tensor = self.dropout(sentence_tensor)
242
+ if self.use_word_dropout:
243
+ sentence_tensor = self.word_dropout(sentence_tensor)
244
+ if self.use_locked_dropout:
245
+ sentence_tensor = self.locked_dropout(sentence_tensor)
246
+
247
+ # Average the embeddings using Linear Transform
248
+ if self.ave_embeddings:
249
+ sentence_tensor = self.embedding2nn(sentence_tensor)
250
+
251
+ # Pack the padded sentence tensor and feed it through
252
+ # our LSTM model
253
+ sentence_tensor, output_lengths = self.rnn(sentence_tensor, sorted_lengths)
254
+
255
+ # Regularize the sentence tensor computed from the LSTM model
256
+ if self.use_dropout:
257
+ sentence_tensor = self.dropout(sentence_tensor)
258
+ if self.use_locked_dropout:
259
+ sentence_tensor = self.locked_dropout(sentence_tensor)
260
+
261
+ # linear map to tag space
262
+ features = self.linear(sentence_tensor)
263
+
264
+ # Depending on whether we are using CRF or a linear layer, scores is either:
265
+ # -- A tensor of shape (batch size, sequence length, tagset size, tagset size) for CRF
266
+ # -- A tensor of shape (aggregated sequence length for all sentences in batch, tagset size) for linear layer
267
+ if self.use_crf:
268
+ features = self.crf(features)
269
+ scores = (features, sorted_lengths, self.crf.transitions)
270
+ else:
271
+ scores = self._get_scores_from_features(features, sorted_lengths)
272
+
273
+ # get the gold labels
274
+ gold_labels = self._get_gold_labels(sentences)
275
+
276
+ return scores, gold_labels
277
+
278
+ def _calculate_loss(self, scores, labels) -> Tuple[torch.Tensor, int]:
279
+
280
+ if not any(labels):
281
+ return torch.tensor(0.0, requires_grad=True, device=flair.device), 1
282
+
283
+ labels = torch.tensor(
284
+ [
285
+ self.label_dictionary.get_idx_for_item(label[0])
286
+ if len(label) > 0
287
+ else self.label_dictionary.get_idx_for_item("O")
288
+ for label in labels
289
+ ],
290
+ dtype=torch.long,
291
+ device=flair.device,
292
+ )
293
+
294
+ return self.loss_function(scores, labels), len(labels)
295
+
296
+ def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]:
297
+ """
298
+ Makes zero-padded tensors shaped by the longest sentence and the embedding length so that the batch
299
+ matches the shape the embeddings need when fed into our LSTM model.
300
+ :param sentences: Batch of current sentences
301
+ """
302
+ names = self.embeddings.get_names()
303
+ tok_lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
304
+ longest_token_sequence_in_batch: int = max(tok_lengths)
305
+ zero_tensor = torch.zeros(
306
+ self.embeddings.embedding_length * longest_token_sequence_in_batch,
307
+ dtype=torch.float,
308
+ device=flair.device,
309
+ )
310
+ all_embs = list()
311
+ for sentence in sentences:
312
+ all_embs += [emb for token in sentence for emb in token.get_each_embedding(names)]
313
+ nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)
314
+
315
+ if nb_padding_tokens > 0:
316
+ t = zero_tensor[: self.embeddings.embedding_length * nb_padding_tokens]
317
+ all_embs.append(t)
318
+
319
+ sentence_tensor = torch.cat(all_embs).view(
320
+ [
321
+ len(sentences),
322
+ longest_token_sequence_in_batch,
323
+ self.embeddings.embedding_length,
324
+ ]
325
+ )
326
+ return torch.tensor(tok_lengths, dtype=torch.long), sentence_tensor
327
+
328
+ @staticmethod
329
+ def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor):
330
+ """
331
+ Trims current batch tensor in shape (batch size, sequence length, tagset size) in such a way that all
332
+ pads are going to be removed.
333
+ :param features: torch.tensor containing all features from forward propagation
334
+ :param lengths: length from each sentence in batch in order to trim padding tokens
335
+ """
336
+ features_formatted = []
337
+ for feat, lens in zip(features, lengths):
338
+ features_formatted.append(feat[:lens])
339
+ scores = torch.cat(features_formatted)
340
+
341
+ return scores
342
+
343
+ def _get_gold_labels(self, sentences: Union[List[Sentence], Sentence]):
344
+ """
345
+ Extracts gold labels from each sentence.
346
+ :param sentences: List of sentences in batch
347
+ """
348
+ # spans need to be encoded as token-level predictions
349
+ if self.predict_spans:
350
+ all_sentence_labels = []
351
+ for sentence in sentences:
352
+ sentence_labels = ["O"] * len(sentence)
353
+ for label in sentence.get_labels(self.label_type):
354
+ span: Span = label.data_point
355
+ if self.tag_format == "BIOES":
356
+ if len(span) == 1:
357
+ sentence_labels[span[0].idx - 1] = "S-" + label.value
358
+ else:
359
+ sentence_labels[span[0].idx - 1] = "B-" + label.value
360
+ sentence_labels[span[-1].idx - 1] = "E-" + label.value
361
+ for i in range(span[0].idx, span[-1].idx - 1):
362
+ sentence_labels[i] = "I-" + label.value
363
+ else:
364
+ sentence_labels[span[0].idx - 1] = "B-" + label.value
365
+ for i in range(span[0].idx, span[-1].idx):
366
+ sentence_labels[i] = "I-" + label.value
367
+ all_sentence_labels.extend(sentence_labels)
368
+ labels = [[label] for label in all_sentence_labels]
369
+
370
+ # all others are regular labels for each token
371
+ else:
372
+ labels = [[token.get_label(self.label_type, "O").value] for sentence in sentences for token in sentence]
373
+
374
+ return labels
375
+
376
+ def predict(
377
+ self,
378
+ sentences: Union[List[Sentence], Sentence],
379
+ mini_batch_size: int = 32,
380
+ return_probabilities_for_all_classes: bool = False,
381
+ verbose: bool = False,
382
+ label_name: Optional[str] = None,
383
+ return_loss=False,
384
+ embedding_storage_mode="none",
385
+ force_token_predictions: bool = False,
386
+ ): # type: ignore
387
+ """
388
+ Predicts labels for current batch with CRF.
389
+ :param sentences: List of sentences in batch
390
+ :param mini_batch_size: batch size for test data
391
+ :param return_probabilities_for_all_classes: Whether to return probabilities for all classes
392
+ :param verbose: whether to use progress bar
393
+ :param label_name: which label to predict
394
+ :param return_loss: whether to return loss value
395
+ :param embedding_storage_mode: determines where to store embeddings - can be "gpu", "cpu" or None.
396
+ """
397
+ if label_name is None:
398
+ label_name = self.tag_type
399
+
400
+ with torch.no_grad():
401
+ if not sentences:
402
+ return sentences
403
+
404
+ # make sure its a list
405
+ if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
406
+ sentences = [sentences]
407
+
408
+ # filter empty sentences
409
+ sentences = [sentence for sentence in sentences if len(sentence) > 0]
410
+
411
+ # reverse sort all sequences by their length
412
+ reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)
413
+
414
+ if len(reordered_sentences) == 0:
415
+ return sentences
416
+
417
+ dataloader = DataLoader(
418
+ dataset=FlairDatapointDataset(reordered_sentences),
419
+ batch_size=mini_batch_size,
420
+ )
421
+ # progress bar for verbosity
422
+ if verbose:
423
+ dataloader = tqdm(dataloader, desc="Batch inference")
424
+
425
+ overall_loss = torch.zeros(1, device=flair.device)
426
+ batch_no = 0
427
+ label_count = 0
428
+ for batch in dataloader:
429
+
430
+ batch_no += 1
431
+
432
+ # stop if all sentences are empty
433
+ if not batch:
434
+ continue
435
+
436
+ # get features from forward propagation
437
+ features, gold_labels = self.forward(batch)
438
+
439
+ # remove previously predicted labels of this type
440
+ for sentence in batch:
441
+ sentence.remove_labels(label_name)
442
+
443
+ # if return_loss, get loss value
444
+ if return_loss:
445
+ loss = self._calculate_loss(features, gold_labels)
446
+ overall_loss += loss[0]
447
+ label_count += loss[1]
448
+
449
+ # Sort batch in same way as forward propagation
450
+ lengths = torch.LongTensor([len(sentence) for sentence in batch])
451
+ _, sort_indices = lengths.sort(dim=0, descending=True)
452
+ batch = [batch[i] for i in sort_indices]
453
+
454
+ # make predictions
455
+ if self.use_crf:
456
+ predictions, all_tags = self.viterbi_decoder.decode(
457
+ features, return_probabilities_for_all_classes, batch
458
+ )
459
+ else:
460
+ predictions, all_tags = self._standard_inference(
461
+ features, batch, return_probabilities_for_all_classes
462
+ )
463
+
464
+ # add predictions to Sentence
465
+ for sentence, sentence_predictions in zip(batch, predictions):
466
+
467
+ # BIOES-labels need to be converted to spans
468
+ if self.predict_spans and not force_token_predictions:
469
+ sentence_tags = [label[0] for label in sentence_predictions]
470
+ sentence_scores = [label[1] for label in sentence_predictions]
471
+ predicted_spans = get_spans_from_bio(sentence_tags, sentence_scores)
472
+ for predicted_span in predicted_spans:
473
+ span: Span = sentence[predicted_span[0][0] : predicted_span[0][-1] + 1]
474
+ span.add_label(label_name, value=predicted_span[2], score=predicted_span[1])
475
+
476
+ # token-labels can be added directly ("O" and legacy "_" predictions are skipped)
477
+ else:
478
+ for token, label in zip(sentence.tokens, sentence_predictions):
479
+ if label[0] in ["O", "_"]:
480
+ continue
481
+ token.add_label(typename=label_name, value=label[0], score=label[1])
482
+
483
+ # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
484
+ for (sentence, sent_all_tags) in zip(batch, all_tags):
485
+ for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
486
+ token.add_tags_proba_dist(label_name, token_all_tags)
487
+
488
+ store_embeddings(sentences, storage_mode=embedding_storage_mode)
489
+
490
+ if return_loss:
491
+ return overall_loss, label_count
492
+
493
+ def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool):
494
+ """
495
+ Softmax over emission scores from forward propagation.
496
+ :param features: sentence tensor from forward propagation
497
+ :param batch: list of sentence
498
+ :param probabilities_for_all_classes: whether to return score for each tag in tag dictionary
499
+ """
500
+ softmax_batch = F.softmax(features, dim=1).cpu()
501
+ scores_batch, prediction_batch = torch.max(softmax_batch, dim=1)
502
+ predictions = []
503
+ all_tags = []
504
+
505
+ for sentence in batch:
506
+ scores = scores_batch[: len(sentence)]
507
+ predictions_for_sentence = prediction_batch[: len(sentence)]
508
+ predictions.append(
509
+ [
510
+ (self.label_dictionary.get_item_for_index(prediction), score.item())
511
+ for token, score, prediction in zip(sentence, scores, predictions_for_sentence)
512
+ ]
513
+ )
514
+ scores_batch = scores_batch[len(sentence) :]
515
+ prediction_batch = prediction_batch[len(sentence) :]
516
+
517
+ if probabilities_for_all_classes:
518
+ lengths = [len(sentence) for sentence in batch]
519
+ all_tags = self._all_scores_for_token(batch, softmax_batch, lengths)
520
+
521
+ return predictions, all_tags
522
+
523
+ def _all_scores_for_token(self, sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
524
+ """
525
+ Returns all scores for each tag in tag dictionary.
526
+ :param scores: Scores for current sentence.
527
+ """
528
+ scores = scores.numpy()
529
+ tokens = [token for sentence in sentences for token in sentence]
530
+ prob_all_tags = [
531
+ [
532
+ Label(token, self.label_dictionary.get_item_for_index(score_id), score)
533
+ for score_id, score in enumerate(score_dist)
534
+ ]
535
+ for score_dist, token in zip(scores, tokens)
536
+ ]
537
+
538
+ prob_tags_per_sentence = []
539
+ previous = 0
540
+ for length in lengths:
541
+ prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
542
+ previous += length
543
+ return prob_tags_per_sentence
544
+
545
+ def _get_state_dict(self):
546
+ """Returns the state dictionary for this model."""
547
+ model_state = {
548
+ **super()._get_state_dict(),
549
+ "embeddings": self.embeddings,
550
+ "hidden_size": self.hidden_size,
551
+ "tag_dictionary": self.label_dictionary,
552
+ "tag_format": self.tag_format,
553
+ "tag_type": self.tag_type,
554
+ "use_crf": self.use_crf,
555
+ "rnn_layers": self.rnn_layers,
556
+ "use_dropout": self.use_dropout,
557
+ "use_word_dropout": self.use_word_dropout,
558
+ "use_locked_dropout": self.use_locked_dropout,
559
+ "ave_embeddings": self.ave_embeddings,
560
+ "weight_dict": self.weight_dict,
561
+ }
562
+
563
+ return model_state
564
+
565
+ @classmethod
566
+ def _init_model_with_state_dict(cls, state, **kwargs):
567
+
568
+ if state["use_crf"]:
569
+ if "transitions" in state["state_dict"]:
570
+ state["state_dict"]["crf.transitions"] = state["state_dict"]["transitions"]
571
+ del state["state_dict"]["transitions"]
572
+
573
+ return super()._init_model_with_state_dict(
574
+ state,
575
+ embeddings=state.get("embeddings"),
576
+ tag_dictionary=state.get("tag_dictionary"),
577
+ tag_format=state.get("tag_format", "BIOES"),
578
+ tag_type=state.get("tag_type"),
579
+ use_crf=state.get("use_crf"),
580
+ rnn_layers=state.get("rnn_layers"),
581
+ hidden_size=state.get("hidden_size"),
582
+ dropout=state.get("use_dropout", 0.0),
583
+ word_dropout=state.get("use_word_dropout", 0.0),
584
+ locked_dropout=state.get("use_locked_dropout", 0.0),
585
+ ave_embeddings=state.get("ave_embeddings", True),
586
+ loss_weights=state.get("weight_dict"),
587
+ init_from_state_dict=True,
588
+ **kwargs,
589
+ )
590
+
591
+ @staticmethod
592
+ def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
593
+ filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
594
+ if len(sentences) != len(filtered_sentences):
595
+ log.warning(f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
596
+ return filtered_sentences
597
+
598
+ def _determine_if_span_prediction_problem(self, dictionary: Dictionary) -> bool:
599
+ for item in dictionary.get_items():
600
+ if item.startswith("B-") or item.startswith("S-") or item.startswith("I-"):
601
+ return True
602
+ return False
603
+
604
+ def _print_predictions(self, batch, gold_label_type):
605
+
606
+ lines = []
607
+ if self.predict_spans:
608
+ for datapoint in batch:
609
+ # all labels default to "O"
610
+ for token in datapoint:
611
+ token.set_label("gold_bio", "O")
612
+ token.set_label("predicted_bio", "O")
613
+
614
+ # set gold token-level
615
+ for gold_label in datapoint.get_labels(gold_label_type):
616
+ gold_span: Span = gold_label.data_point
617
+ prefix = "B-"
618
+ for token in gold_span:
619
+ token.set_label("gold_bio", prefix + gold_label.value)
620
+ prefix = "I-"
621
+
622
+ # set predicted token-level
623
+ for predicted_label in datapoint.get_labels("predicted"):
624
+ predicted_span: Span = predicted_label.data_point
625
+ prefix = "B-"
626
+ for token in predicted_span:
627
+ token.set_label("predicted_bio", prefix + predicted_label.value)
628
+ prefix = "I-"
629
+
630
+ # now print labels in CoNLL format
631
+ for token in datapoint:
632
+ eval_line = (
633
+ f"{token.text} "
634
+ f"{token.get_label('gold_bio').value} "
635
+ f"{token.get_label('predicted_bio').value}\n"
636
+ )
637
+ lines.append(eval_line)
638
+ lines.append("\n")
639
+
640
+ else:
641
+ for datapoint in batch:
642
+ # print labels in CoNLL format
643
+ for token in datapoint:
644
+ eval_line = (
645
+ f"{token.text} "
646
+ f"{token.get_label(gold_label_type).value} "
647
+ f"{token.get_label('predicted').value}\n"
648
+ )
649
+ lines.append(eval_line)
650
+ lines.append("\n")
651
+ return lines
652
+
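To show how the pieces are meant to fit together, here is a rough training sketch (not part of this commit; the corpus layout, embedding names, and hyperparameters are assumptions). It relies on flair's standard ColumnCorpus and ModelTrainer and writes its output to the checkpoints/ folder from which app.py loads best-model.pt.

# Hypothetical training script; paths and hyperparameters are illustrative only.
from flair.datasets import ColumnCorpus
from flair.trainers import ModelTrainer

from model.embedding import PretrainedEmbeddings
from model.layer import Bi_LSTM_CRF

# CoNLL-style corpus with a token column and a PHI/NER tag column
corpus = ColumnCorpus("data/", {0: "text", 1: "ner"})
tag_dictionary = corpus.make_label_dictionary(label_type="ner")

embeddings = PretrainedEmbeddings(
    "bert-base-cased", "news-forward", "news-backward"
).forward()

tagger = Bi_LSTM_CRF(
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type="ner",
    hidden_size=256,
    use_crf=True,
)

# best-model.pt ends up under checkpoints/, matching the path loaded in app.py
ModelTrainer(tagger, corpus).train("checkpoints/", max_epochs=10)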
model/layer/bioes.py ADDED
@@ -0,0 +1,62 @@
1
+ from collections import defaultdict
2
+ from typing import Dict
3
+
4
+
5
+ def get_spans_from_bio(bioes_tags, bioes_scores=None):
6
+ # add a dummy "O" to close final prediction
7
+ bioes_tags.append("O")
8
+ # return complex list
9
+ found_spans = []
10
+ # internal variables
11
+ current_tag_weights: Dict[str, float] = defaultdict(lambda: 0.0)
12
+ previous_tag = "O-"
13
+ current_span = []
14
+ current_span_scores = []
15
+ for idx, bioes_tag in enumerate(bioes_tags):
16
+
17
+ # non-set tags are OUT tags
18
+ if bioes_tag == "" or bioes_tag == "O" or bioes_tag == "_":
19
+ bioes_tag = "O-"
20
+
21
+ # anything that is not OUT is IN
22
+ in_span = False if bioes_tag == "O-" else True
23
+
24
+ # does this prediction start a new span?
25
+ starts_new_span = False
26
+
27
+ # begin and single tags start new spans
28
+ if bioes_tag[0:2] in ["B-", "S-"]:
29
+ starts_new_span = True
30
+
31
+ # in IOB format, an I tag starts a span if it follows an O or is a different span
32
+ if bioes_tag[0:2] == "I-" and previous_tag[2:] != bioes_tag[2:]:
33
+ starts_new_span = True
34
+
35
+ # single tags that change prediction start new spans
36
+ if bioes_tag[0:2] in ["S-"] and previous_tag[2:] != bioes_tag[2:]:
37
+ starts_new_span = True
38
+
39
+ # if an existing span is ended (either by reaching O or starting a new span)
40
+ if (starts_new_span or not in_span) and len(current_span) > 0:
41
+ # determine score and value
42
+ span_score = sum(current_span_scores) / len(current_span_scores)
43
+ span_value = sorted(current_tag_weights.items(), key=lambda k_v: k_v[1], reverse=True)[0][0]
44
+
45
+ # append to result list
46
+ found_spans.append((current_span, span_score, span_value))
47
+
48
+ # reset for-loop variables for new span
49
+ current_span = []
50
+ current_span_scores = []
51
+ current_tag_weights = defaultdict(lambda: 0.0)
52
+
53
+ if in_span:
54
+ current_span.append(idx)
55
+ current_span_scores.append(bioes_scores[idx] if bioes_scores else 1.0)
56
+ weight = 1.1 if starts_new_span else 1.0
57
+ current_tag_weights[bioes_tag[2:]] += weight
58
+
59
+ # remember previous tag
60
+ previous_tag = bioes_tag
61
+
62
+ return found_spans
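As a small worked example of the helper above (traced by hand from the logic, not part of the commit):

get_spans_from_bio(["B-PER", "E-PER", "O", "S-LOC"])
# -> [([0, 1], 1.0, "PER"), ([3], 1.0, "LOC")]
# each span is (token indices, averaged confidence, majority tag value)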
model/layer/crf.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+
3
+ import flair
4
+
5
+ START_TAG: str = "<START>"
6
+ STOP_TAG: str = "<STOP>"
7
+
8
+
9
+ class CRF(torch.nn.Module):
10
+ """
11
+ Conditional Random Field Implementation according to sgrvinod and modified to not
12
+ only look at the current word, but also at the previously seen annotation.
13
+ """
14
+
15
+ def __init__(self, tag_dictionary, tagset_size: int, init_from_state_dict: bool):
16
+ """
17
+ :param tag_dictionary: tag dictionary in order to find ID for start and stop tags
18
+ :param tagset_size: number of tag from tag dictionary
19
+ :param init_from_state_dict: whether we load pretrained model from state dict
20
+ """
21
+ super(CRF, self).__init__()
22
+
23
+ self.tagset_size = tagset_size
24
+ # Transitions are used in the following way: transitions[to, from].
25
+ self.transitions = torch.nn.Parameter(torch.randn(tagset_size, tagset_size))
26
+ # If we are not using a pretrained model and train a fresh one, we need to set transitions from any tag
27
+ # to START-tag and from STOP-tag to any other tag to -10000.
28
+ if not init_from_state_dict:
29
+ self.transitions.detach()[tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
30
+
31
+ self.transitions.detach()[:, tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
32
+ self.to(flair.device)
33
+
34
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
35
+ """
36
+ Forward propagation of Conditional Random Field.
37
+ :param features: output from LSTM Layer in shape (batch size, seq len, hidden size)
38
+ :return: CRF scores (emission scores for each token + transitions prob from previous state) in
39
+ shape (batch_size, seq len, tagset size, tagset size)
40
+ """
41
+ batch_size, seq_len = features.size()[:2]
42
+
43
+ emission_scores = features
44
+ emission_scores = emission_scores.unsqueeze(-1).expand(batch_size, seq_len, self.tagset_size, self.tagset_size)
45
+
46
+ crf_scores = emission_scores + self.transitions.unsqueeze(0).unsqueeze(0)
47
+ return crf_scores
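With this layout, crf_scores[b, t, i, j] is the emission score of tag i for token t plus the transition score for moving from tag j to tag i, which is exactly the quantity that ViterbiLoss and ViterbiDecoder index into.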
model/layer/lstm.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
3
+ import flair
4
+
5
+ class LSTM(torch.nn.Module):
6
+ """
7
+ Simple LSTM Implementation that returns the features used for (1)CRF and (2)Span Classifier
8
+
9
+ """
10
+ def __init__(self, rnn_layers: int, hidden_size: int, bidirectional: bool, rnn_input_dim: int,):
11
+ """
12
+ :param rnn_layers: number of rnn layers to be used, default 1
13
+ :param hidden_size: hidden size of the LSTM layer
14
+ :param bidirectional: whether we use a bidirectional LSTM or not, default True
15
+ :param rnn_input_dim: the shape of our max sentence token and embeddings
16
+ """
17
+ super(LSTM, self).__init__()
18
+
19
+ self.hidden_size = hidden_size
20
+ self.rnn_input_dim = rnn_input_dim
21
+ self.num_layers = rnn_layers
22
+ self.dropout = 0.0 if rnn_layers == 1 else 0.5
23
+ self.bidirectional = bidirectional
24
+ self.batch_first = True
25
+ self.lstm = torch.nn.LSTM(
26
+ self.rnn_input_dim,
27
+ self.hidden_size,
28
+ num_layers=self.num_layers,
29
+ dropout=self.dropout,
30
+ bidirectional=self.bidirectional,
31
+ batch_first=self.batch_first,
32
+ )
33
+
34
+ self.to(flair.device)
35
+
36
+ def forward(self, sentence_tensor: torch.Tensor, sorted_lengths: torch.Tensor) -> torch.Tensor:
37
+ """
38
+ Forward propagation of LSTM Model by packing the tensors.
39
+ :param sentence_tensor: padded embedding tensor in shape (batch size, seq len, embedding dim)
40
+ :param sorted_lengths: lengths of the sentences in the batch
41
+ :return: unpacked LSTM output in shape (batch size, seq len, hidden size * num directions), plus the output lengths
42
+ """
43
+ packed = pack_padded_sequence(sentence_tensor, sorted_lengths, batch_first=True, enforce_sorted=False)
44
+ rnn_output, hidden = self.lstm(packed)
45
+ sentence_tensor, output_lengths = pad_packed_sequence(rnn_output, batch_first=True)
46
+
47
+ return sentence_tensor, output_lengths
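Packing the padded batch before the LSTM call means the recurrence only runs over real tokens rather than the zero padding, and pad_packed_sequence restores the (batch size, seq len, hidden size * num directions) layout that the linear tag layer in Bi_LSTM_CRF expects.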
model/layer/span.py ADDED
@@ -0,0 +1,211 @@
1
+ from functools import lru_cache
2
+ from itertools import chain
3
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def enumerate_spans(n):
11
+ for i in range(n):
12
+ for j in range(i, n):
13
+ yield (i, j)
14
+
15
+ @lru_cache # type: ignore
16
+ def get_all_spans(n: int) -> torch.Tensor:
17
+ return torch.tensor(list(enumerate_spans(n)), dtype=torch.long)
18
+
19
+
20
+ class SpanClassifier(nn.Module):
21
+ num_additional_labels = 1
22
+
23
+ def __init__(self, encoder, scorer: "SpanScorer"):
24
+ super().__init__()
25
+ self.encoder = encoder
26
+ self.scorer = scorer
27
+
28
+ def forward(
29
+ self, *input_ids: Sequence[torch.Tensor]
30
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
31
+ hs, lengths = self.encoder(*input_ids)
32
+ spans = list(map(get_all_spans, lengths))
33
+ scores = self.scorer(hs, spans)
34
+ return spans, scores
35
+
36
+ @torch.no_grad()
37
+ def decode(
38
+ self,
39
+ spans: Sequence[torch.Tensor],
40
+ scores: Sequence[torch.Tensor],
41
+ ) -> List[List[Tuple[int, int, int]]]:
42
+ spans_flatten = torch.cat(spans)
43
+ scores_flatten = torch.cat(scores)
44
+ assert len(spans_flatten) == len(scores_flatten)
45
+ labels_flatten = scores_flatten.argmax(dim=1).cpu()
46
+ mask = labels_flatten < self.scorer.num_labels - 1
47
+ mentions = torch.hstack((spans_flatten[mask], labels_flatten[mask, None]))
48
+
49
+ output = []
50
+ offset = 0
51
+ sizes = [m.sum() for m in torch.split(mask, [len(idxs) for idxs in spans])]
52
+ for size in sizes:
53
+ output.append([tuple(m) for m in mentions[offset : offset + size].tolist()])
54
+ offset += size
55
+ return output # type: ignore
56
+
57
+ def compute_metrics(
58
+ self,
59
+ spans: Sequence[torch.Tensor],
60
+ scores: Sequence[torch.Tensor],
61
+ true_mentions: Sequence[Sequence[Tuple[int, int, int]]],
62
+ decode=True,
63
+ ) -> Dict[str, Any]:
64
+ assert len(spans) == len(scores) == len(true_mentions)
65
+ num_labels = self.scorer.num_labels
66
+ true_labels = []
67
+ for spans_i, scores_i, true_mentions_i in zip(spans, scores, true_mentions):
68
+ assert len(spans_i) == len(scores_i)
69
+ span2idx = {tuple(s): idx for idx, s in enumerate(spans_i.tolist())}
70
+ labels_i = torch.full((len(spans_i),), fill_value=num_labels - 1)
71
+ for (start, end, label) in true_mentions_i:
72
+ idx = span2idx.get((start, end))
73
+ if idx is not None:
74
+ labels_i[idx] = label
75
+ true_labels.append(labels_i)
76
+
77
+ scores_flatten = torch.cat(scores)
78
+ true_labels_flatten = torch.cat(true_labels).to(scores_flatten.device)
79
+ assert len(scores_flatten) == len(true_labels_flatten)
80
+ loss = F.cross_entropy(scores_flatten, true_labels_flatten)
81
+ accuracy = categorical_accuracy(scores_flatten, true_labels_flatten)
82
+ result = {"loss": loss, "accuracy": accuracy}
83
+
84
+ if decode:
85
+ pred_mentions = self.decode(spans, scores)
86
+ tp, fn, fp = 0, 0, 0
87
+ for pred_mentions_i, true_mentions_i in zip(pred_mentions, true_mentions):
88
+ pred, gold = set(pred_mentions_i), set(true_mentions_i)
89
+ tp += len(gold & pred)
90
+ fn += len(gold - pred)
91
+ fp += len(pred - gold)
92
+ result["precision"] = (tp, tp + fp)
93
+ result["recall"] = (tp, tp + fn)
94
+ result["mentions"] = pred_mentions
95
+
96
+ return result
97
+
98
+
99
+ @torch.no_grad()
100
+ def categorical_accuracy(
101
+ y: torch.Tensor, t: torch.Tensor, ignore_index: Optional[int] = None
102
+ ) -> Tuple[int, int]:
103
+ pred = y.argmax(dim=1)
104
+ if ignore_index is not None:
105
+ mask = t == ignore_index
106
+ ignore_cnt = mask.sum()
107
+ pred.masked_fill_(mask, ignore_index)
108
+ count = ((pred == t).sum() - ignore_cnt).item()
109
+ total = (t.numel() - ignore_cnt).item()
110
+ else:
111
+ count = (pred == t).sum().item()
112
+ total = t.numel()
113
+ return count, total
114
+
115
+
116
+ class SpanScorer(torch.nn.Module):
117
+ def __init__(self, num_labels: int):
118
+ super().__init__()
119
+ self.num_labels = num_labels
120
+
121
+ def forward(
122
+ self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
123
+ ):
124
+ raise NotImplementedError
125
+
126
+
127
+ class BaselineSpanScorer(SpanScorer):
128
+ def __init__(
129
+ self,
130
+ input_size: int,
131
+ num_labels: int,
132
+ mlp_units: Union[int, Sequence[int]] = 150,
133
+ mlp_dropout: float = 0.0,
134
+ feature="concat",
135
+ ):
136
+ super().__init__(num_labels)
137
+ input_size *= 2 if feature == "concat" else 1
138
+ self.mlp = MLP(input_size, num_labels, mlp_units, F.relu, mlp_dropout)
139
+ self.feature = feature
140
+
141
+ def forward(
142
+ self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
143
+ ):
144
+ max_length = xs.size(1)
145
+ xs_flatten = xs.reshape(-1, xs.size(-1))
146
+ spans_flatten = torch.cat([idxs + max_length * i for i, idxs in enumerate(spans)])
147
+ features = self._compute_feature(xs_flatten, spans_flatten)
148
+ scores = self.mlp(features)
149
+ return torch.split(scores, [len(idxs) for idxs in spans])
150
+
151
+ def _compute_feature(self, xs, spans):
152
+ if self.feature == "concat":
153
+ return xs[spans.ravel()].view(len(spans), -1)
154
+ elif self.feature == "minus":
155
+ begins, ends = spans.T
156
+ return xs[ends] - xs[begins]
157
+ else:
158
+ raise NotImplementedError
159
+
160
+
161
+ class MLP(nn.Sequential):
162
+ def __init__(
163
+ self,
164
+ in_features: int,
165
+ out_features: Optional[int],
166
+ units: Optional[Union[int, Sequence[int]]] = None,
167
+ activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
168
+ dropout: float = 0.0,
169
+ bias: bool = True,
170
+ ):
171
+ units = [units] if isinstance(units, int) else units
172
+ if not units and out_features is None:
173
+ raise ValueError("'out_features' or 'units' must be specified")
174
+ layers = []
175
+ for u in units or []:
176
+ layers.append(MLP.Layer(in_features, u, activate, dropout, bias))
177
+ in_features = u
178
+ if out_features is not None:
179
+ layers.append(MLP.Layer(in_features, out_features, None, 0.0, bias))
180
+ super().__init__(*layers)
181
+
182
+ class Layer(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_features: int,
186
+ out_features: int,
187
+ activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
188
+ dropout: float = 0.0,
189
+ bias: bool = True,
190
+ ):
191
+ super().__init__()
192
+ if activate is not None and not callable(activate):
193
+ raise TypeError("activate must be callable: type={}".format(type(activate)))
194
+ self.linear = nn.Linear(in_features, out_features, bias)
195
+ self.activate = activate
196
+ self.dropout = nn.Dropout(dropout)
197
+
198
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
199
+ h = self.linear(x)
200
+ if self.activate is not None:
201
+ h = self.activate(h)
202
+ return self.dropout(h)
203
+
204
+ def extra_repr(self) -> str:
205
+ return "{}, activate={}, dropout={}".format(
206
+ self.linear.extra_repr(), self.activate, self.dropout.p
207
+ )
208
+
209
+ def __repr__(self):
210
+ return "{}.{}({})".format(MLP.__name__, self._get_name(), self.extra_repr())
211
+
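For orientation: enumerate_spans(3) yields the six index pairs (0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2), and get_all_spans caches that enumeration as a LongTensor per sentence length. The scorer therefore assigns a label to every contiguous span of each sentence, with the extra class added by num_additional_labels acting as the "no mention" label that decode() filters out.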
model/layer/viterbi.py ADDED
@@ -0,0 +1,241 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn
6
+ from torch.nn.functional import softmax
7
+ from torch.nn.utils.rnn import pack_padded_sequence
8
+
9
+ import flair
10
+ from flair.data import Dictionary, Label, List, Sentence
11
+
12
+ START_TAG: str = "<START>"
13
+ STOP_TAG: str = "<STOP>"
14
+
15
+
16
+ class ViterbiLoss(torch.nn.Module):
17
+ """
18
+ Calculates the loss for each sequence up to its length t.
19
+ """
20
+
21
+ def __init__(self, tag_dictionary: Dictionary):
22
+ """
23
+ :param tag_dictionary: tag_dictionary of task
24
+ """
25
+ super(ViterbiLoss, self).__init__()
26
+ self.tag_dictionary = tag_dictionary
27
+ self.tagset_size = len(tag_dictionary)
28
+ self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
29
+ self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
30
+
31
+ def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
32
+ """
33
+ Forward propagation of Viterbi Loss
34
+ :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
35
+ lengths of sentences in batch, transitions from CRF
36
+ :param targets: true tags for sentences which will be converted to matrix indices.
37
+ :return: average Viterbi Loss over batch size
38
+ """
39
+ features, lengths, transitions = features_tuple
40
+
41
+ batch_size = features.size(0)
42
+ seq_len = features.size(1)
43
+
44
+ targets, targets_matrix_indices = self._format_targets(targets, lengths)
45
+ targets_matrix_indices = torch.tensor(targets_matrix_indices, dtype=torch.long).unsqueeze(2).to(flair.device)
46
+
47
+ # scores_at_targets[range(features.shape[0]), lengths.values -1]
48
+ # Squeeze crf scores matrices in 1-dim shape and gather scores at targets by matrix indices
49
+ scores_at_targets = torch.gather(features.view(batch_size, seq_len, -1), 2, targets_matrix_indices)
50
+ scores_at_targets = pack_padded_sequence(scores_at_targets, lengths, batch_first=True)[0]
51
+ transitions_to_stop = transitions[
52
+ np.repeat(self.stop_tag, features.shape[0]),
53
+ [target[length - 1] for target, length in zip(targets, lengths)],
54
+ ]
55
+ gold_score = scores_at_targets.sum() + transitions_to_stop.sum()
56
+
57
+ scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)
58
+
59
+ for t in range(max(lengths)):
60
+ batch_size_t = sum(
61
+ [length > t for length in lengths]
62
+ ) # since batch is ordered, we can save computation time by reducing our effective batch_size
63
+
64
+ if t == 0:
65
+ # Initially, get scores from <start> tag to all other tags
66
+ scores_upto_t[:batch_size_t] = (
67
+ scores_upto_t[:batch_size_t] + features[:batch_size_t, t, :, self.start_tag]
68
+ )
69
+ else:
70
+ # We add scores at current timestep to scores accumulated up to previous timestep, and log-sum-exp
71
+ # Remember, the cur_tag of the previous timestep is the prev_tag of this timestep
72
+ scores_upto_t[:batch_size_t] = self._log_sum_exp(
73
+ features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(1), dim=2
74
+ )
75
+
76
+ all_paths_scores = self._log_sum_exp(scores_upto_t + transitions[self.stop_tag].unsqueeze(0), dim=1).sum()
77
+
78
+ viterbi_loss = all_paths_scores - gold_score
79
+
80
+ return viterbi_loss
81
+
82
+ @staticmethod
83
+ def _log_sum_exp(tensor, dim):
84
+ """
85
+ Calculates the log-sum-exponent of a tensor's dimension in a numerically stable way.
86
+ :param tensor: tensor
87
+ :param dim: dimension to calculate log-sum-exp of
88
+ :return: log-sum-exp
89
+ """
90
+ m, _ = torch.max(tensor, dim)
91
+ m_expanded = m.unsqueeze(dim).expand_as(tensor)
92
+ return m + torch.log(torch.sum(torch.exp(tensor - m_expanded), dim))
93
+
94
+ def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
95
+ """
96
+ Formats targets into matrix indices.
97
+ CRF scores contain per sentence, per token a (tagset_size x tagset_size) matrix, containing emission score for
98
+ token j + transition prob from previous token i. Means, if we think of our rows as "to tag" and our columns
99
+ as "from tag", the matrix in cell [10,5] would contain the emission score for tag 10 + transition score
100
+ from previous tag 5 and could directly be addressed through the 1-dim indices (10 + tagset_size * 5) = 70,
101
+ if our tagset consists of 12 tags.
102
+ :param targets: targets as in tag dictionary
103
+ :param lengths: lengths of sentences in batch
104
+ """
105
+ targets_per_sentence = []
106
+
107
+ targets_list = targets.tolist()
108
+ for cut in lengths:
109
+ targets_per_sentence.append(targets_list[:cut])
110
+ targets_list = targets_list[cut:]
111
+
112
+ for t in targets_per_sentence:
113
+ t += [self.tag_dictionary.get_idx_for_item(STOP_TAG)] * (int(lengths.max().item()) - len(t))
114
+
115
+ matrix_indices = list(
116
+ map(
117
+ lambda s: [self.tag_dictionary.get_idx_for_item(START_TAG) + (s[0] * self.tagset_size)]
118
+ + [s[i] + (s[i + 1] * self.tagset_size) for i in range(0, len(s) - 1)],
119
+ targets_per_sentence,
120
+ )
121
+ )
122
+
123
+ return targets_per_sentence, matrix_indices
124
+
125
+
126
+ class ViterbiDecoder:
127
+ """
128
+ Decodes a given sequence using the Viterbi algorithm.
129
+ """
130
+
131
+ def __init__(self, tag_dictionary: Dictionary):
132
+ """
133
+ :param tag_dictionary: Dictionary of tags for sequence labeling task
134
+ """
135
+ self.tag_dictionary = tag_dictionary
136
+ self.tagset_size = len(tag_dictionary)
137
+ self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
138
+ self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)
139
+
140
+ def decode(
141
+ self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence]
142
+ ) -> Tuple[List, List]:
143
+ """
144
+ Decoding function returning the most likely sequence of tags.
145
+ :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
146
+ lengths of sentence in batch, transitions of CRF
147
+ :param probabilities_for_all_classes: whether to return probabilities for all tags
148
+ :return: decoded sequences
149
+ """
150
+ features, lengths, transitions = features_tuple
151
+ all_tags = []
152
+
153
+ batch_size = features.size(0)
154
+ seq_len = features.size(1)
155
+
156
+ # Create a tensor to hold accumulated sequence scores at each current tag
157
+ scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size).to(flair.device)
158
+ # Create a tensor to hold back-pointers
159
+ # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag
160
+ # Let pads be the <end> tag index, since that was the last tag in the decoded sequence
161
+ backpointers = (
162
+ torch.ones((batch_size, seq_len + 1, self.tagset_size), dtype=torch.long, device=flair.device)
163
+ * self.stop_tag
164
+ )
165
+
166
+ for t in range(seq_len):
167
+ batch_size_t = sum([length > t for length in lengths]) # effective batch size (sans pads) at this timestep
168
+ terminates = [i for i, length in enumerate(lengths) if length == t + 1]
169
+
170
+ if t == 0:
171
+ scores_upto_t[:batch_size_t, t] = features[:batch_size_t, t, :, self.start_tag]
172
+ backpointers[:batch_size_t, t, :] = (
173
+ torch.ones((batch_size_t, self.tagset_size), dtype=torch.long) * self.start_tag
174
+ )
175
+ else:
176
+ # We add scores at current timestep to scores accumulated up to previous timestep, and
177
+ # choose the previous timestep that corresponds to the max. accumulated score for each current timestep
178
+ scores_upto_t[:batch_size_t, t], backpointers[:batch_size_t, t, :] = torch.max(
179
+ features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t, t - 1].unsqueeze(1), dim=2
180
+ )
181
+
182
+ # If sentence is over, add transition to STOP-tag
183
+ if terminates:
184
+ scores_upto_t[terminates, t + 1], backpointers[terminates, t + 1, :] = torch.max(
185
+ scores_upto_t[terminates, t].unsqueeze(1) + transitions[self.stop_tag].unsqueeze(0), dim=2
186
+ )
187
+
188
+ # Decode/trace best path backwards
189
+ decoded = torch.zeros((batch_size, backpointers.size(1)), dtype=torch.long, device=flair.device)
190
+ pointer = torch.ones((batch_size, 1), dtype=torch.long, device=flair.device) * self.stop_tag
191
+
192
+ for t in list(reversed(range(backpointers.size(1)))):
193
+ decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
194
+ pointer = decoded[:, t].unsqueeze(1)
195
+
196
+ # Sanity check
197
+ assert torch.equal(
198
+ decoded[:, 0], torch.ones((batch_size), dtype=torch.long, device=flair.device) * self.start_tag
199
+ )
200
+
201
+ # remove start-tag and backscore to stop-tag
202
+ scores_upto_t = scores_upto_t[:, :-1, :]
203
+ decoded = decoded[:, 1:]
204
+
205
+ # Max + Softmax to get confidence score for predicted label and append label to each token
206
+ scores = softmax(scores_upto_t, dim=2)
207
+ confidences = torch.max(scores, dim=2)
208
+
209
+ tags = []
210
+ for tag_seq, tag_seq_conf, length_seq in zip(decoded, confidences.values, lengths):
211
+ tags.append(
212
+ [
213
+ (self.tag_dictionary.get_item_for_index(tag), conf.item())
214
+ for tag, conf in list(zip(tag_seq, tag_seq_conf))[:length_seq]
215
+ ]
216
+ )
217
+
218
+ if probabilities_for_all_classes:
219
+ all_tags = self._all_scores_for_token(scores.cpu(), lengths, sentences)
220
+
221
+ return tags, all_tags
222
+
223
+ def _all_scores_for_token(self, scores: torch.Tensor, lengths: torch.IntTensor, sentences: List[Sentence]):
224
+ """
225
+ Returns all scores for each tag in tag dictionary.
226
+ :param scores: Scores for current sentence.
227
+ """
228
+ scores = scores.numpy()
229
+ prob_tags_per_sentence = []
230
+ for scores_sentence, length, sentence in zip(scores, lengths, sentences):
231
+ scores_sentence = scores_sentence[:length]
232
+ prob_tags_per_sentence.append(
233
+ [
234
+ [
235
+ Label(token, self.tag_dictionary.get_item_for_index(score_id), score)
236
+ for score_id, score in enumerate(score_dist)
237
+ ]
238
+ for score_dist, token in zip(scores_sentence, sentence)
239
+ ]
240
+ )
241
+ return prob_tags_per_sentence
part/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from part.data import *
2
+ from part.dropout import *
part/data.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import Dict, List, Optional
2
+ from flair.data import _PartOfSentence, DataPoint, Label
3
+
4
+ class Token(_PartOfSentence):
5
+ """
6
+ This class represents one word in a tokenized sentence. Each token may have any number of tags. It may also point
7
+ to its head in a dependency tree.
8
+
9
+ :param text: Single text(Token) from the sequence
10
+ :param head_id: the location of the text (For Document)
11
+ :param whitespace_after: if token has whitespace
12
+ :param start_position: what character number in document does this token start?
13
+ :param sentence: If token belongs to sentence, indicate here which var it belongs to
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ text: str,
19
+ head_id: Optional[int] = None,
20
+ whitespace_after: int = 1,
21
+ start_position: int = 0,
22
+ sentence=None,
23
+ ):
24
+ super().__init__(sentence=sentence)
25
+
26
+ self.form: str = text
27
+ self._internal_index: Optional[int] = None
28
+ self.head_id: Optional[int] = head_id
29
+ self.whitespace_after: int = whitespace_after
30
+
31
+ self.start_pos = start_position
32
+ self.end_pos = start_position + len(text)
33
+
34
+ self._embeddings: Dict = {}
35
+ self.tags_proba_dist: Dict[str, List[Label]] = {}
36
+
37
+ @property
38
+ def idx(self) -> int:
39
+ if isinstance(self._internal_index, int):
40
+ return self._internal_index
41
+ else:
42
+ raise ValueError
43
+
44
+ @property
45
+ def text(self):
46
+ return self.form
47
+
48
+ @property
49
+ def unlabeled_identifier(self) -> str:
50
+ return f'Token[{self.idx-1}]: "{self.text}"'
51
+
52
+ def add_tags_proba_dist(self, tag_type: str, tags: List[Label]):
53
+ self.tags_proba_dist[tag_type] = tags
54
+
55
+ def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
56
+ if tag_type in self.tags_proba_dist:
57
+ return self.tags_proba_dist[tag_type]
58
+ return []
59
+
60
+ def get_head(self):
61
+ return self.sentence.get_token(self.head_id)
62
+
63
+ @property
64
+ def start_position(self) -> int:
65
+ return self.start_pos
66
+
67
+ @property
68
+ def end_position(self) -> int:
69
+ return self.end_pos
70
+
71
+ @property
72
+ def embedding(self):
73
+ return self.get_embedding()
74
+
75
+ def __repr__(self):
76
+ return self.__str__()
77
+
78
+ def add_label(self, typename: str, value: str, score: float = 1.0):
79
+ """
80
+ The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
81
+ Therefore, labels get added only to the Sentence if it exists
82
+ """
83
+ if self.sentence:
84
+ super().add_label(typename=typename, value=value, score=score)
85
+ else:
86
+ DataPoint.add_label(self, typename=typename, value=value, score=score)
87
+
88
+ def set_label(self, typename: str, value: str, score: float = 1.0):
89
+ """
90
+ The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
91
+ Therefore, labels get set only to the Sentence if it exists
92
+ """
93
+ if self.sentence:
94
+ super().set_label(typename=typename, value=value, score=score)
95
+ else:
96
+ DataPoint.set_label(self, typename=typename, value=value, score=score)
97
+
98
+
99
+ class Span(_PartOfSentence):
100
+ """
101
+ This class represents one textual span consisting of Tokens. It may be used when the
102
+ tokens form a nested structure, i.e. several tokens combine into one longer phrase.
103
+
104
+ :param tokens: List of tokens in the span
105
+ """
106
+
107
+ def __init__(self, tokens: List[Token]):
108
+ super().__init__(tokens[0].sentence)
109
+ self.tokens = tokens
110
+ super()._init_labels()
111
+
112
+ @property
113
+ def start_position(self) -> int:
114
+ return self.tokens[0].start_position
115
+
116
+ @property
117
+ def end_position(self) -> int:
118
+ return self.tokens[-1].end_position
119
+
120
+ @property
121
+ def text(self) -> str:
122
+ return " ".join([t.text for t in self.tokens])
123
+
124
+ @property
125
+ def unlabeled_identifier(self) -> str:
126
+ return f'Span[{self.tokens[0].idx -1}:{self.tokens[-1].idx}]: "{self.text}"'
127
+
128
+ def __repr__(self):
129
+ return self.__str__()
130
+
131
+ def __getitem__(self, idx: int) -> Token:
132
+ return self.tokens[idx]
133
+
134
+ def __iter__(self):
135
+ return iter(self.tokens)
136
+
137
+ def __len__(self) -> int:
138
+ return len(self.tokens)
139
+
140
+ @property
141
+ def embedding(self):
142
+ pass
part/dropout.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+
3
+
4
+ class LockedDropout(torch.nn.Module):
5
+ """
6
+ Implementation of locked (or variational) dropout.
7
+ Randomly drops out entire parameters in embedding space.
8
+
9
+ :param dropout_rate: fraction of the input units to drop, between 0 and 1.
10
+ :param batch_first: whether the input is shaped (batch, seq len, features) rather than (seq len, batch, features)
11
+ :param inplace: kept for the module repr only; the mask is always applied out of place
12
+ """
13
+
14
+ def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
15
+ super(LockedDropout, self).__init__()
16
+ self.dropout_rate = dropout_rate
17
+ self.batch_first = batch_first
18
+ self.inplace = inplace
19
+
20
+ def forward(self, x):
21
+ if not self.training or not self.dropout_rate:
22
+ return x
23
+
24
+ if not self.batch_first:
25
+ m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
26
+ else:
27
+ m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate)
28
+
29
+ mask = torch.autograd.Variable(m, requires_grad=False) / (1 - self.dropout_rate)
30
+ mask = mask.expand_as(x)
31
+ return mask * x
32
+
33
+ def extra_repr(self):
34
+ inplace_str = ", inplace" if self.inplace else ""
35
+ return "p={}{}".format(self.dropout_rate, inplace_str)
36
+
37
+
38
+ class WordDropout(torch.nn.Module):
39
+ """
40
+ Implementation of word dropout. Randomly drops out entire words
41
+ (or characters) in embedding space.
42
+ """
43
+
44
+ def __init__(self, dropout_rate=0.05, inplace=False):
45
+ super(WordDropout, self).__init__()
46
+ self.dropout_rate = dropout_rate
47
+ self.inplace = inplace
48
+
49
+ def forward(self, x):
50
+ if not self.training or not self.dropout_rate:
51
+ return x
52
+
53
+ m = x.data.new(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
54
+
55
+ mask = torch.autograd.Variable(m, requires_grad=False)
56
+ return mask * x
57
+
58
+ def extra_repr(self):
59
+ inplace_str = ", inplace" if self.inplace else ""
60
+ return "p={}{}".format(self.dropout_rate, inplace_str)
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ flair
5
+ numpy
6
+ pandas
7
+ nltk
8
+ panel
9
+ hvplot