omarmomen committed on
Commit
62a5ce9
·
1 Parent(s): dd4341f
Files changed (3) hide show
  1. config.json +37 -0
  2. modeling_structroberta.py +1267 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "StructRoberta"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_structroberta.StructRobertaConfig",
8
+ "AutoModelForMaskedLM": "modeling_structroberta.StructRoberta"
9
+ },
10
+ "bos_token_id": 0,
11
+ "classifier_dropout": null,
12
+ "conv_size": 9,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "gelu",
15
+ "hidden_dropout_prob": 0.1,
16
+ "hidden_size": 768,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 3072,
19
+ "layer_norm_eps": 1e-05,
20
+ "max_position_embeddings": 514,
21
+ "model_type": "roberta",
22
+ "n_parser_layers": 6,
23
+ "num_attention_heads": 12,
24
+ "num_hidden_layers": 12,
25
+ "pad_token_id": 1,
26
+ "position_embedding_type": "absolute",
27
+ "relations": [
28
+ "head",
29
+ "child"
30
+ ],
31
+ "torch_dtype": "float32",
32
+ "transformers_version": "4.18.0",
33
+ "type_vocab_size": 1,
34
+ "use_cache": true,
35
+ "vocab_size": 32000,
36
+ "weight_act": "softmax"
37
+ }
modeling_structroberta.py ADDED
@@ -0,0 +1,1267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RoBERTa model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN, gelu
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ MaskedLMOutput,
33
+ )
34
+ from transformers.modeling_utils import (
35
+ PreTrainedModel,
36
+ apply_chunking_to_forward,
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ )
46
+ from transformers import RobertaConfig
47
+
48
+
49
# Module-level logger, obtained from the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)

# Names of the reference checkpoint / config / tokenizer.
# NOTE(review): presumably consumed by the docstring decorators imported above;
# no use is visible in this part of the file -- confirm.
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

# Identifiers of stock RoBERTa checkpoints.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]
64
+
65
+
66
class StructRobertaConfig(RobertaConfig):
    """`RobertaConfig` extended with four extra, parser-related settings.

    All keyword arguments not listed below are forwarded unchanged to
    `RobertaConfig.__init__`.

    Args:
        n_parser_layers: stored as `self.n_parser_layers` (default 4).
            Presumably the number of parser layers -- confirm against the model code.
        conv_size: stored as `self.conv_size` (default 9).
            Presumably the kernel size of the parser convolutions -- confirm.
        relations: stored as `self.relations` (default ``('head', 'child')``).
        weight_act: stored as `self.weight_act` (default ``'softmax'``).
    """

    # Same identifier as the stock RoBERTa configuration.
    model_type = "roberta"

    def __init__(
        self,
        n_parser_layers=4,
        conv_size=9,
        relations=('head', 'child'),
        weight_act='softmax',
        **kwargs,
    ):
        # Let the base configuration consume every standard RoBERTa option first.
        super().__init__(**kwargs)

        # Then attach the extra settings, in the same order as they are declared.
        extra_settings = (
            ("n_parser_layers", n_parser_layers),
            ("conv_size", conv_size),
            ("relations", relations),
            ("weight_act", weight_act),
        )
        for attr_name, attr_value in extra_settings:
            setattr(self, attr_name, attr_value)
82
+
83
class Conv1d(nn.Module):
    """1D convolution along the sequence axis that preserves the sequence length.

    The input is transposed so that axis 1 (sequence) becomes the convolved
    axis and axis 2 (features) becomes the channel axis, then transposed back.
    """

    def __init__(self, hidden_size, kernel_size, dilation=1):
        """Initialization.

        Args:
            hidden_size: dimension of input embeddings (used as both the input
                and output channel count)
            kernel_size: convolution kernel size
            dilation: the spacing between the kernel points
        """
        super(Conv1d, self).__init__()

        if kernel_size % 2 == 0:
            # Even kernel: symmetric padding of (k / 2) * d over-pads, producing
            # `dilation` surplus output positions that forward() trims.
            padding = (kernel_size // 2) * dilation
            self.shift = True
        else:
            # Odd kernel: this padding yields exactly the input length.
            padding = ((kernel_size - 1) // 2) * dilation
            self.shift = False
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            padding=padding,
            dilation=dilation)

    def forward(self, x):
        """Compute convolution.

        Args:
            x: input embeddings; axis 1 is convolved, axis 2 holds the features
        Returns:
            conv_output: convolution results, same shape as `x`
        """
        conv_output = self.conv(x.transpose(1, 2)).transpose(1, 2)
        if self.shift:
            # Output length is L + 2 * padding - dilation * (k - 1) = L + dilation
            # for an even kernel, so drop `dilation` leading positions (this is
            # the original `[:, 1:]` when dilation == 1).
            return conv_output[:, self.conv.dilation[0]:]
        return conv_output
122
+
123
class RobertaEmbeddings(nn.Module):
    """
    Sum of word, token-type and (optionally) absolute position embeddings, followed by LayerNorm and dropout.

    Position ids are derived with an offset of `padding_idx + 1` (see `forward`).
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        # Replaced further down by a padding-aware table of the same shape.
        self.position_embeddings = nn.Embedding(max_positions, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute name `LayerNorm` (not snake_case) is part of the checkpoint key names.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Persistent buffer of shape (1, max_position_embeddings).
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            # Non-persistent all-zero token type ids, used when the caller passes none.
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long),
                persistent=False,
            )

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(max_positions, hidden, padding_idx=self.padding_idx)

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Embed the input.

        Args:
            input_ids: token ids; when given they are used for shapes, position ids and word lookup.
            token_type_ids: optional; defaults to zeros.
            position_ids: optional; derived from `input_ids` or `inputs_embeds` when missing.
            inputs_embeds: pre-computed word embeddings, used when `input_ids` is None.
            past_key_values_length: forwarded to `create_position_ids_from_input_ids`.

        Returns:
            The normalized, dropped-out embedding tensor.
        """
        if position_ids is None:
            if input_ids is None:
                # Padding cannot be inferred from raw embeddings: use sequential ids.
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
            else:
                # Derive position ids from the token ids (padding-aware helper).
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)

        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Slice the registered all-zero buffer and broadcast it over the batch.
                zeros_row = self.token_type_ids[:, :seq_length]
                token_type_ids = zeros_row.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        Build sequential position ids for directly supplied embeddings.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor of shape `inputs_embeds.size()[:-1]`, each row counting up from `padding_idx + 1`
        """
        leading_shape = inputs_embeds.size()[:-1]
        first_position = self.padding_idx + 1
        sequential_ids = torch.arange(
            first_position, leading_shape[1] + first_position, dtype=torch.long, device=inputs_embeds.device
        )
        return sequential_ids.unsqueeze(0).expand(leading_shape)
211
+
212
+
213
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
214
class RobertaSelfAttention(nn.Module):
    """Multi-head attention with an optional parser gate.

    When `parser_att_mask` is given to `forward`, the scores are passed through
    a sigmoid and multiplied by that mask instead of being softmax-normalized.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        n_heads = config.num_attention_heads
        self.num_attention_heads = n_heads
        self.attention_head_size = int(config.hidden_size / n_heads)
        self.all_head_size = n_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # An explicit argument wins over the config; the fallback is "absolute".
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last axis into (heads, head_size) and move the head axis to position 1."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        parser_att_mask=None,
    ) -> Tuple[torch.Tensor]:
        """Run attention.

        Returns:
            A tuple starting with the context tensor, followed by the attention
            weights if `output_attentions` is truthy, followed by the
            `(key, value)` pair if the module is a decoder.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder (or its cache),
            # and the encoder's mask replaces the self-attention mask.
            attention_mask = encoder_attention_mask
            if past_key_value is not None:
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Prepend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        if self.is_decoder:
            # Decoders hand the (possibly extended) keys/values back to the caller.
            past_key_value = (key_layer, value_layer)

        # Raw scores: dot product of every query with every key.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            cols = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            # Shift signed distances so they index the embedding table from 0.
            positional_embedding = self.distance_embedding(rows - cols + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
            elif self.position_embedding_type == "relative_key_query":
                from_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                from_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + from_query + from_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive.
            attention_scores = attention_scores + attention_mask

        if parser_att_mask is not None:
            # Parser-gated path: independent sigmoid gates scaled by the parser mask.
            attention_probs = torch.sigmoid(attention_scores) * parser_att_mask
        else:
            # Standard path: normalize over the key axis.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout on the attention weights.
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Weighted sum of values, then merge the head axis back into the feature axis.
        context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            outputs = outputs + (attention_probs,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
344
+
345
+
346
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
347
class RobertaSelfOutput(nn.Module):
    """Dense projection + dropout, added to a residual input and layer-normalized."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
359
+
360
+
361
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
362
class RobertaAttention(nn.Module):
    """Attention sub-layer: `RobertaSelfAttention` followed by `RobertaSelfOutput` (residual + norm)."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = RobertaSelfOutput(config)
        # Heads already removed by `prune_heads`.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and record them as pruned."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Shrink the three input projections, then the output projection along its input axis.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the new layer sizes.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        parser_att_mask=None,
    ) -> Tuple[torch.Tensor]:
        """Run attention and the output projection.

        Returns:
            `(attention_output, *extras)` where `extras` are whatever the inner
            attention module returned after its context tensor.
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            parser_att_mask=parser_att_mask,
        )
        # Residual connection uses the sub-layer's own input.
        attention_output = self.output(context, hidden_states)
        return (attention_output,) + tuple(extras)
411
+
412
+
413
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
414
class RobertaIntermediate(nn.Module):
    """Feed-forward expansion: linear map to `intermediate_size` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry of ACT2FN; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `act(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
427
+
428
+
429
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
430
class RobertaOutput(nn.Module):
    """Feed-forward contraction: projects back to `hidden_size`, then dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
442
+
443
+
444
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
445
class RobertaLayer(nn.Module):
    """One transformer block: self-attention, optional cross-attention (decoders only), chunked feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Axis along which the feed-forward pass may be chunked.
        self.seq_len_dim = 1
        self.attention = RobertaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        parser_att_mask=None,
    ) -> Tuple[torch.Tensor]:
        """Run the block.

        `parser_att_mask` is forwarded to the self-attention only; the
        cross-attention call does not receive it.

        Returns:
            `(layer_output, *attention_weights)` plus, for decoders, the
            key/value cache as the last element.
        """
        has_cache = past_key_value is not None

        # The first two cache entries belong to self-attention.
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value[:2] if has_cache else None,
            parser_att_mask=parser_att_mask,
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            # For decoders the trailing element is the self-attention cache.
            present_key_value = self_attention_outputs[-1]
            outputs = self_attention_outputs[1:-1]
        else:
            # Attention weights, if they were requested.
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # The last two cache entries belong to cross-attention.
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value[-2:] if has_cache else None,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            # Append cross-attention weights, if they were requested.
            outputs = outputs + cross_attention_outputs[1:-1]

            # Extend the cache with the cross-attention keys/values.
            present_key_value = present_key_value + cross_attention_outputs[-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        if self.is_decoder:
            # Decoders return the key/value cache as the final element.
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward sub-layer applied to one chunk: expansion, then contraction with residual."""
        return self.output(self.intermediate(attention_output), attention_output)
530
+
531
+
532
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
533
class RobertaEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `RobertaLayer` blocks with optional per-layer parser masks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        parser_att_mask=None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: indexed by layer; entry `i` goes to layer `i`. May be None.
            encoder_hidden_states / encoder_attention_mask: passed unchanged to every layer.
            past_key_values: indexed by layer; entry `i` goes to layer `i`. May be None.
            use_cache: when truthy, the last output of every layer is collected.
            output_attentions / output_hidden_states: collect per-layer attentions / hidden states.
            return_dict: return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
            parser_att_mask: indexed by layer; entry `i` goes to layer `i`. When
                None (the default), every layer receives None.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or a tuple of its
            non-None fields when `return_dict` is falsy.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            # Guard the default: indexing None would raise a TypeError.
            layer_parser_att_mask = parser_att_mask[i] if parser_att_mask is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module, layer_past):
                    # `layer_past` is bound per layer so that a later recomputation
                    # does not pick up another iteration's value.
                    def custom_forward(*inputs):
                        # The parser mask travels as the last checkpoint input so the
                        # checkpointed layer sees the same mask as the regular path.
                        *layer_inputs, layer_mask = inputs
                        return module(*layer_inputs, layer_past, output_attentions, parser_att_mask=layer_mask)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    layer_parser_att_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    parser_att_mask=layer_parser_att_mask,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The cache is the last element of a layer's output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
630
+
631
+
632
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
633
class RobertaPooler(nn.Module):
    """Pools a sequence by transforming the hidden state at sequence position 0."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # Only the first position along axis 1 contributes to the pooled output.
        return self.activation(self.dense(hidden_states[:, 0]))
646
+
647
+
648
class RobertaPreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialization, the gradient-checkpointing switch and ignore-list
    maintenance on top of `PreTrainedModel`.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # Identity transform: zero bias (when present), unit scale.
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Linear):
            # Plain normal initialization with the configured standard deviation.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # The padding row, if any, is reset to zeros.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` if it is a `RobertaEncoder`."""
        if not isinstance(module, RobertaEncoder):
            return
        module.gradient_checkpointing = value

    def update_keys_to_ignore(self, config, del_keys_to_ignore):
        """Remove some keys from ignore list"""
        if config.tie_word_embeddings:
            return

        def without_deleted(keys):
            # Build a new list so the class-level list is not mutated.
            return [k for k in keys if k not in del_keys_to_ignore]

        self._keys_to_ignore_on_save = without_deleted(self._keys_to_ignore_on_save)
        self._keys_to_ignore_on_load_missing = without_deleted(self._keys_to_ignore_on_load_missing)
688
+
689
+
690
# Generic class-level documentation text for the RoBERTa model classes.
# NOTE(review): no use of this constant is visible in this part of the file;
# presumably it is consumed by a docstring decorator -- confirm before removing.
ROBERTA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`RobertaConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# Documentation text for the common `forward` arguments. The `{0}` fields are
# format placeholders for the input shape.
# NOTE(review): presumably filled in with `str.format` by a docstring
# decorator that is not visible in this part of the file -- confirm.
ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`RobertaTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
755
+
756
+
757
class RobertaModel(RobertaPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

    Compared with the stock implementation, `forward` takes one extra argument, `parser_att_mask`, which is handed
    to the encoder unchanged.

    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762

    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config, add_pooling_layer=True):
        """Build the embeddings, the encoder and, if `add_pooling_layer` is true, the pooler."""
        super().__init__(config)
        self.config = config

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        # `self.pooler` is None when pooling is disabled; `forward` checks for that.
        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        parser_att_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        parser_att_mask (*optional*):
            Passed through unchanged to `self.encoder`. `StructRoberta.forward` supplies a tensor of shape
            `(num_hidden_layers, batch_size, num_attention_heads, sequence_length, sequence_length)` here; how the
            encoder consumes it is defined outside this class.
        """
        # Per-call flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is forced off unless the model is configured as a decoder.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # Default: attend to every position, cached ones included.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # Default token types: the embeddings' `token_type_ids` attribute when
        # it exists, otherwise all zeros.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        # `parser_att_mask` is forwarded to the encoder exactly as received.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            parser_att_mask=parser_att_mask
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # Tuple output keeps a slot for the pooled output even when it is None.
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
933
+
934
+
935
class StructRoberta(RobertaPreTrainedModel):
    """RoBERTa masked-language model whose encoder receives a parser-derived attention mask.

    A stack of convolutional "parser" layers runs on the token embeddings and
    predicts two scalars per position (`distance` and `height`). From these,
    soft `head` / `child` / sibling relation matrices are computed, mixed per
    layer and per attention head with learned weights, and passed to the
    `RobertaModel` encoder as `parser_att_mask`.
    """

    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Build the parser, the relation weights, the RoBERTa encoder and the LM head.

        Beyond the standard RoBERTa fields, the config attributes read here
        and by the other methods are `conv_size`, `n_parser_layers`,
        `relations` and `weight_act`.
        """
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `StructRoberta` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # Parser trunk: `n_parser_layers` blocks of Conv1d -> LayerNorm -> Tanh.
        self.parser_layers = nn.ModuleList([
            nn.Sequential(Conv1d(config.hidden_size, config.conv_size),
                          nn.LayerNorm(config.hidden_size, elementwise_affine=False),
                          nn.Tanh()) for _ in range(config.n_parser_layers)])

        # One scalar per position, produced through a width-2 convolution.
        self.distance_ff = nn.Sequential(
            Conv1d(config.hidden_size, 2),
            nn.LayerNorm(config.hidden_size, elementwise_affine=False), nn.Tanh(),
            nn.Linear(config.hidden_size, 1))

        # One scalar per position, produced through a position-wise MLP.
        self.height_ff = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.LayerNorm(config.hidden_size, elementwise_affine=False), nn.Tanh(),
            nn.Linear(config.hidden_size, 1))

        # Raw mixing weights over relations, one set per (layer, attention head).
        n_rel = len(config.relations)
        self._rel_weight = nn.Parameter(torch.zeros((config.num_hidden_layers, config.num_attention_heads, n_rel)))
        self._rel_weight.data.normal_(0, 0.1)

        # Log-scale factors; see the `scaler` property.
        self._scaler = nn.Parameter(torch.zeros(2))

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        self.pad = config.pad_token_id

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head's vocabulary projection layer."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's vocabulary projection layer."""
        self.lm_head.decoder = new_embeddings

    @property
    def scaler(self):
        """Positive scale factors, `exp(_scaler)`; [0] is used by `compute_block`, [1] by `compute_head`."""
        return self._scaler.exp()

    @property
    def rel_weight(self):
        """Relation mixing weights after the activation named by `config.weight_act`.

        Raises:
            ValueError: if `config.weight_act` is neither `'sigmoid'` nor `'softmax'`.
        """
        activation = self.config.weight_act
        if activation == 'sigmoid':
            return torch.sigmoid(self._rel_weight)
        if activation == 'softmax':
            return torch.softmax(self._rel_weight, dim=-1)
        # Previously this fell through and returned None, which only surfaced
        # later as an opaque einsum failure in `generate_mask`.
        raise ValueError(
            f"Unsupported weight_act {activation!r}; expected 'sigmoid' or 'softmax'."
        )

    def compute_block(self, distance, height):
        """Compute constituents from distance and height.

        Args:
            distance: 2-D tensor (indexed as batch x length).
            height: 2-D tensor of the same shape as `distance`.
        Returns:
            block_p: 4-D tensor, outer product of the left and right boundary
                distributions for every position.
            block: 3-D tensor combining the cumulative left distribution
                (lower triangle, diagonal included) with the reverse-cumulative
                right distribution (strict upper triangle).
        """

        # Pairwise comparison of every distance with every height, scaled.
        beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]

        gamma = torch.sigmoid(-beta_logits)
        ones = torch.ones_like(gamma)

        # Left side: running minimum taken right-to-left over the strict lower
        # triangle, then differenced and clipped to the lower triangle.
        block_mask_left = cummin(
            gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
        block_mask_left = block_mask_left - F.pad(
            block_mask_left[:, :, :-1], (1, 0), value=0)
        block_mask_left.tril_(0)

        # Right side: exclusive running minimum taken left-to-right over the
        # upper triangle, then differenced and clipped to the upper triangle.
        block_mask_right = cummin(
            gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
        block_mask_right = block_mask_right - F.pad(
            block_mask_right[:, :, 1:], (0, 1), value=0)
        block_mask_right.triu_(0)

        block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
        block = cumsum(block_mask_left).tril(0) + cumsum(
            block_mask_right, reverse=True).triu(1)

        return block_p, block

    def compute_head(self, height):
        """Estimate head for each constituent.

        Args:
            height: 2-D tensor (batch x length).
        Returns:
            head_p: tensor of shape (batch, length, length, length). Entry
                `[b, i, j, :]` is a softmax over scaled heights restricted to
                positions `h` with `i <= h <= j`.
        """

        _, length = height.size()
        head_logits = height * self.scaler[1]
        index = torch.arange(length, device=height.device)

        # mask[i, j, h] is True exactly when i <= h <= j.
        mask = (index[:, None, None] <= index[None, None, :]) * (
            index[None, None, :] <= index[None, :, None])
        head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
        head_logits.masked_fill_(~mask[None, :, :, :], -1e9)

        head_p = torch.softmax(head_logits, dim=-1)

        return head_p

    def parse(self, x):
        """Parse input sentence.

        Args:
            x: input token ids, 2-D (batch x length) (required).
        Returns:
            distance: syntactic distance
            height: syntactic height
        """

        # True on non-padding tokens; the shifted copy moves the mask one step
        # to the left and appends False.
        mask = (x != self.pad)
        mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)

        h = self.roberta.embeddings(x)
        for i in range(self.config.n_parser_layers):
            # Zero out padding positions before every parser layer.
            h = h.masked_fill(~mask[:, :, None], 0)
            h = self.parser_layers[i](h)

        height = self.height_ff(h).squeeze(-1)
        height.masked_fill_(~mask, -1e9)

        distance = self.distance_ff(h).squeeze(-1)
        distance.masked_fill_(~mask_shifted, 1e9)

        # Calibrating the distance and height to the same level
        length = distance.size(1)
        # height_max[b, i, j]: running maximum of height over positions i..j
        # (entries with j < i are zeroed).
        height_max = height[:, None, :].expand(-1, length, -1)
        height_max = torch.cummax(
            height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
            dim=-1)[0].triu(0)

        # How far the distance at position i-1 (left) and at position j
        # (right) exceeds that running maximum, clipped at zero.
        margin_left = torch.relu(
            F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
        margin_right = torch.relu(distance[:, None, :] - height_max)
        # Element-wise minimum of the two margins, upper triangle only.
        margin = torch.where(margin_left > margin_right, margin_right,
                             margin_left).triu(0)

        margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
        margin.masked_fill_(~margin_mask, 0)
        # A single scalar for the whole batch.
        margin = margin.max()

        distance = distance - margin

        return distance, height

    def generate_mask(self, x, distance, height):
        """Compute head and sibling distribution for each token.

        Args:
            x: input token ids, 2-D (batch x length); used for its size and device.
            distance: output of `parse`.
            height: output of `parse`.
        Returns:
            att_mask: tensor of shape (num_hidden_layers, batch,
                num_attention_heads, length, length).
            cibling: sibling relation matrix (batch x length x length).
            head: head relation matrix (batch x length x length).
            block: second output of `compute_block`.

        The relations that enter `att_mask` are chosen by the names listed in
        `config.relations`; the sibling relation is spelled `'cibling'` there.
        """

        bsz, length = x.size()

        eye = torch.eye(length, device=x.device, dtype=torch.bool)
        eye = eye[None, :, :].expand((bsz, -1, -1))

        block_p, block = self.compute_block(distance, height)
        head_p = self.compute_head(height)
        # Sum the head distribution over all spans, weighted by block_p.
        head = torch.einsum('blij,bijh->blh', block_p, head_p)
        # The diagonal (a token related to itself) is zeroed.
        head = head.masked_fill(eye, 0)
        child = head.transpose(1, 2)
        cibling = torch.bmm(head, child).masked_fill(eye, 0)

        rel_list = []
        if 'head' in self.config.relations:
            rel_list.append(head)
        if 'child' in self.config.relations:
            rel_list.append(child)
        if 'cibling' in self.config.relations:
            rel_list.append(cibling)

        rel = torch.stack(rel_list, dim=1)

        rel_weight = self.rel_weight

        # Mix the relations with the per-layer, per-head weights.
        dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
        att_mask = dep.reshape(self.config.num_hidden_layers, bsz, self.config.num_attention_heads, length, length)

        return att_mask, cibling, head, block

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Required by this model even though the signature allows `None`: the parser runs on the token ids and
            uses them to locate padding.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Raises:
            ValueError: if `input_ids` is `None`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            # `parse` compares the ids against the pad id and embeds them, so
            # `inputs_embeds` alone cannot drive the parser.
            raise ValueError("StructRoberta requires `input_ids`; the parser cannot run on `inputs_embeds` alone")

        distance, height = self.parse(input_ids)
        # Only the attention mask is used here; the other outputs are discarded.
        att_mask, _cibling, _head, _block = self.generate_mask(input_ids, distance, height)

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            parser_att_mask=att_mask,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Index 1 of the base model's tuple is the pooled-output slot, so skip it.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1178
+
1179
class RobertaLMHead(nn.Module):
    """Masked-LM head: dense -> GELU -> LayerNorm -> projection onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        hidden, vocab = config.hidden_size, config.vocab_size

        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(hidden, vocab)
        # The output bias is a parameter of the head itself and is shared
        # with the decoder, so both names refer to one tensor.
        self.bias = nn.Parameter(torch.zeros(vocab))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Map hidden states to vocabulary logits; extra keyword arguments are ignored."""
        transformed = self.layer_norm(gelu(self.dense(features)))

        # project back to size of vocabulary with bias
        return self.decoder(transformed)

    def _tie_weights(self):
        """Point `self.bias` back at the decoder's bias."""
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias
1204
+
1205
+
1206
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor of token ids; positions are counted along dimension 1.
        padding_idx: id of the padding token; padding positions receive this value.
        past_key_values_length: offset added to the position of every non-padding token.

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # Multiplying by the mask resets padding slots to 0 before the final shift.
    offsets = (running_count + past_key_values_length) * not_padding
    return offsets.long() + padding_idx
1220
+
1221
+
1222
def cumprod(x, reverse=False, exclusive=False):
    """Cumulative product along the last dimension of a 3-D tensor.

    Args:
        x: 3-D tensor (the exclusive shift slices it with three indices).
        reverse: accumulate from the last element towards the first.
        exclusive: shift by one so each output excludes its own element;
            the vacated slot is filled with 1.

    Returns:
        Tensor with the same shape as `x`.
    """
    seq = x.flip([-1]) if reverse else x

    if exclusive:
        # Drop the final element and prepend the multiplicative identity.
        seq = F.pad(seq[:, :, :-1], (1, 0), value=1)

    running = seq.cumprod(-1)

    return running.flip([-1]) if reverse else running
1235
+
1236
+
1237
def cumsum(x, reverse=False, exclusive=False):
    """Cumulative sum along the last dimension of a 3-D tensor.

    The sum is computed as a matrix product with a triangular matrix of ones.

    Args:
        x: 3-D tensor; the last dimension is the one accumulated over.
        reverse: accumulate from the last element towards the first.
        exclusive: exclude each element from its own running total.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    _, _, length = x.size()
    # One (length, length) triangle in the input's own dtype: a float32
    # matrix would make the product fail for float64 / half-precision input,
    # and a separate copy per batch element is unnecessary because matmul
    # broadcasts the 2-D operand over the batch.
    ones = torch.ones(length, length, device=x.device, dtype=x.dtype)
    if reverse:
        # Column j collects rows k >= j (k > j when exclusive).
        w = ones.tril(-1 if exclusive else 0)
    else:
        # Column j collects rows k <= j (k < j when exclusive).
        w = ones.triu(1 if exclusive else 0)
    return torch.matmul(x, w)
1254
+
1255
+
1256
def cummin(x, reverse=False, exclusive=False, max_value=1e9):
    """Running minimum along the last dimension of a 3-D tensor.

    Args:
        x: 3-D tensor (the exclusive shift slices it with three indices).
        reverse: scan from the last element towards the first.
        exclusive: shift by one so each output excludes its own element;
            the vacated slot is filled with `max_value`.
        max_value: fill value used by the exclusive shift.

    Returns:
        Tensor with the same shape as `x`.
    """
    if not reverse:
        if exclusive:
            # Drop the last element and prepend the filler.
            x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
        return x.cummin(-1)[0]

    if exclusive:
        # Drop the first element and append the filler.
        x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
    # A right-to-left scan is a left-to-right scan on the flipped tensor.
    running_min, _ = x.flip([-1]).cummin(-1)
    return running_min.flip([-1])
1267
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16292cfcec36b9e1eff9d6a286a39946a32a2df5fbc2a2a5e0e28f07cf7e4137
3
+ size 577194047