Kowsher commited on
Commit
091e371
·
verified ·
1 Parent(s): 583f06a

Upload shared_model.py

Browse files
Files changed (1) hide show
  1. shared_model.py +2032 -0
shared_model.py ADDED
@@ -0,0 +1,2032 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from packaging import version
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_attn_mask_utils import (
32
+ _prepare_4d_attention_mask_for_sdpa,
33
+ _prepare_4d_causal_attention_mask_for_sdpa,
34
+ )
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutputWithPastAndCrossAttentions,
37
+ BaseModelOutputWithPoolingAndCrossAttentions,
38
+ CausalLMOutputWithCrossAttentions,
39
+ MaskedLMOutput,
40
+ MultipleChoiceModelOutput,
41
+ NextSentencePredictorOutput,
42
+ QuestionAnsweringModelOutput,
43
+ SequenceClassifierOutput,
44
+ TokenClassifierOutput,
45
+ )
46
+ from transformers.modeling_utils import PreTrainedModel
47
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
48
+ from transformers.utils import (
49
+ ModelOutput,
50
+ add_code_sample_docstrings,
51
+ add_start_docstrings,
52
+ add_start_docstrings_to_model_forward,
53
+ get_torch_version,
54
+ logging,
55
+ replace_return_docstrings,
56
+ )
57
+ from transformers import BertConfig
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+ _CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
63
+ _CONFIG_FOR_DOC = "BertConfig"
64
+
65
+ # TokenClassification docstring
66
+ _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
67
+ _TOKEN_CLASS_EXPECTED_OUTPUT = (
68
+ "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
69
+ )
70
+ _TOKEN_CLASS_EXPECTED_LOSS = 0.01
71
+
72
+ # QuestionAnswering docstring
73
+ _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
74
+ _QA_EXPECTED_OUTPUT = "'a nice puppet'"
75
+ _QA_EXPECTED_LOSS = 7.41
76
+ _QA_TARGET_START_INDEX = 14
77
+ _QA_TARGET_END_INDEX = 15
78
+
79
+ # SequenceClassification docstring
80
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
81
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
82
+ _SEQ_CLASS_EXPECTED_LOSS = 0.01
83
+
84
+
85
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
86
+ """Load tf checkpoints in a pytorch model."""
87
+ try:
88
+ import re
89
+
90
+ import numpy as np
91
+ import tensorflow as tf
92
+ except ImportError:
93
+ logger.error(
94
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
95
+ "https://www.tensorflow.org/install/ for installation instructions."
96
+ )
97
+ raise
98
+ tf_path = os.path.abspath(tf_checkpoint_path)
99
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
100
+ # Load weights from TF model
101
+ init_vars = tf.train.list_variables(tf_path)
102
+ names = []
103
+ arrays = []
104
+ for name, shape in init_vars:
105
+ logger.info(f"Loading TF weight {name} with shape {shape}")
106
+ array = tf.train.load_variable(tf_path, name)
107
+ names.append(name)
108
+ arrays.append(array)
109
+
110
+ for name, array in zip(names, arrays):
111
+ name = name.split("/")
112
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
113
+ # which are not required for using pretrained model
114
+ if any(
115
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
116
+ for n in name
117
+ ):
118
+ logger.info(f"Skipping {'/'.join(name)}")
119
+ continue
120
+ pointer = model
121
+ for m_name in name:
122
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
123
+ scope_names = re.split(r"_(\d+)", m_name)
124
+ else:
125
+ scope_names = [m_name]
126
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
127
+ pointer = getattr(pointer, "weight")
128
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
129
+ pointer = getattr(pointer, "bias")
130
+ elif scope_names[0] == "output_weights":
131
+ pointer = getattr(pointer, "weight")
132
+ elif scope_names[0] == "squad":
133
+ pointer = getattr(pointer, "classifier")
134
+ else:
135
+ try:
136
+ pointer = getattr(pointer, scope_names[0])
137
+ except AttributeError:
138
+ logger.info(f"Skipping {'/'.join(name)}")
139
+ continue
140
+ if len(scope_names) >= 2:
141
+ num = int(scope_names[1])
142
+ pointer = pointer[num]
143
+ if m_name[-11:] == "_embeddings":
144
+ pointer = getattr(pointer, "weight")
145
+ elif m_name == "kernel":
146
+ array = np.transpose(array)
147
+ try:
148
+ if pointer.shape != array.shape:
149
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
150
+ except ValueError as e:
151
+ e.args += (pointer.shape, array.shape)
152
+ raise
153
+ logger.info(f"Initialize PyTorch weight {name}")
154
+ pointer.data = torch.from_numpy(array)
155
+ return model
156
+
157
+
158
+ class BertEmbeddings(nn.Module):
159
+ """Construct the embeddings from word, position and token_type embeddings."""
160
+
161
+ def __init__(self, config):
162
+ super().__init__()
163
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
164
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
165
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
166
+
167
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
168
+ # any TensorFlow checkpoint file
169
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
170
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
171
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
172
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
173
+ self.register_buffer(
174
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
175
+ )
176
+ self.register_buffer(
177
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
178
+ )
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ token_type_ids: Optional[torch.LongTensor] = None,
184
+ position_ids: Optional[torch.LongTensor] = None,
185
+ inputs_embeds: Optional[torch.FloatTensor] = None,
186
+ past_key_values_length: int = 0,
187
+ ) -> torch.Tensor:
188
+ if input_ids is not None:
189
+ input_shape = input_ids.size()
190
+ else:
191
+ input_shape = inputs_embeds.size()[:-1]
192
+
193
+ seq_length = input_shape[1]
194
+
195
+ if position_ids is None:
196
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
197
+
198
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
199
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
200
+ # issue #5664
201
+ if token_type_ids is None:
202
+ if hasattr(self, "token_type_ids"):
203
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
204
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
205
+ token_type_ids = buffered_token_type_ids_expanded
206
+ else:
207
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.word_embeddings(input_ids)
211
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
212
+
213
+ embeddings = inputs_embeds + token_type_embeddings
214
+ if self.position_embedding_type == "absolute":
215
+ position_embeddings = self.position_embeddings(position_ids)
216
+ embeddings += position_embeddings
217
+ embeddings = self.LayerNorm(embeddings)
218
+ embeddings = self.dropout(embeddings)
219
+ return embeddings
220
+
221
+
222
+ class BertSelfAttention(nn.Module):
223
+ def __init__(self, config, position_embedding_type=None):
224
+ super().__init__()
225
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
226
+ raise ValueError(
227
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
228
+ f"heads ({config.num_attention_heads})"
229
+ )
230
+
231
+ self.num_attention_heads = config.num_attention_heads
232
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
233
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
234
+
235
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
236
+
237
+ self.query_pro = nn.Parameter(torch.ones(self.all_head_size))
238
+ self.value_pro = nn.Parameter(torch.ones(self.all_head_size))
239
+ self.key_pro = nn.Parameter(torch.ones(self.all_head_size))
240
+
241
+
242
+
243
+
244
+
245
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
246
+ self.position_embedding_type = position_embedding_type or getattr(
247
+ config, "position_embedding_type", "absolute"
248
+ )
249
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
250
+ self.max_position_embeddings = config.max_position_embeddings
251
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
252
+
253
+ self.is_decoder = config.is_decoder
254
+
255
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
256
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
257
+ x = x.view(new_x_shape)
258
+ return x.permute(0, 2, 1, 3)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: torch.Tensor,
263
+ attention_mask: Optional[torch.FloatTensor] = None,
264
+ head_mask: Optional[torch.FloatTensor] = None,
265
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
266
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
267
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
268
+ output_attentions: Optional[bool] = False,
269
+ ) -> Tuple[torch.Tensor]:
270
+ mixed_query_layer = self.query(hidden_states)
271
+
272
+ # If this is instantiated as a cross-attention module, the keys
273
+ # and values come from an encoder; the attention mask needs to be
274
+ # such that the encoder's padding tokens are not attended to.
275
+ is_cross_attention = encoder_hidden_states is not None
276
+
277
+ if is_cross_attention and past_key_value is not None:
278
+ # reuse k,v, cross_attentions
279
+ key_layer = past_key_value[0]
280
+ value_layer = past_key_value[1]
281
+ attention_mask = encoder_attention_mask
282
+ elif is_cross_attention:
283
+ encoder_att = self.query(encoder_hidden_states)
284
+ key_layer = self.transpose_for_scores(encoder_att * self.key_pro)
285
+ value_layer = self.transpose_for_scores(encoder_att * self.value_pro)
286
+ attention_mask = encoder_attention_mask
287
+ elif past_key_value is not None:
288
+ key_layer = self.transpose_for_scores(mixed_query_layer * self.key_pro)
289
+ value_layer = self.transpose_for_scores(mixed_query_layer*self.value_pro)
290
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
291
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
292
+ else:
293
+ key_layer = self.transpose_for_scores(mixed_query_layer * self.key_pro)
294
+ value_layer = self.transpose_for_scores(mixed_query_layer * self.value_pro)
295
+
296
+ query_layer = self.transpose_for_scores(mixed_query_layer * self.query_pro )
297
+
298
+
299
+ use_cache = past_key_value is not None
300
+ if self.is_decoder:
301
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
302
+ # Further calls to cross_attention layer can then reuse all cross-attention
303
+ # key/value_states (first "if" case)
304
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
305
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
306
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
307
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
308
+ past_key_value = (key_layer, value_layer)
309
+
310
+ # Take the dot product between "query" and "key" to get the raw attention scores.
311
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
312
+
313
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
314
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
315
+ if use_cache:
316
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
317
+ -1, 1
318
+ )
319
+ else:
320
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
321
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
322
+ distance = position_ids_l - position_ids_r
323
+
324
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
325
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
326
+
327
+ if self.position_embedding_type == "relative_key":
328
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
329
+ attention_scores = attention_scores + relative_position_scores
330
+ elif self.position_embedding_type == "relative_key_query":
331
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
332
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
333
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
334
+
335
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
336
+ if attention_mask is not None:
337
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
338
+ attention_scores = attention_scores + attention_mask
339
+
340
+ # Normalize the attention scores to probabilities.
341
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
342
+
343
+ # This is actually dropping out entire tokens to attend to, which might
344
+ # seem a bit unusual, but is taken from the original Transformer paper.
345
+ attention_probs = self.dropout(attention_probs)
346
+
347
+ # Mask heads if we want to
348
+ if head_mask is not None:
349
+ attention_probs = attention_probs * head_mask
350
+
351
+ context_layer = torch.matmul(attention_probs, value_layer)
352
+
353
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
354
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
355
+ context_layer = context_layer.view(new_context_layer_shape)
356
+
357
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
358
+
359
+ if self.is_decoder:
360
+ outputs = outputs + (past_key_value,)
361
+ return outputs
362
+
363
+
364
+ class BertSdpaSelfAttention(BertSelfAttention):
365
+ def __init__(self, config, position_embedding_type=None):
366
+ super().__init__(config, position_embedding_type=position_embedding_type)
367
+ self.dropout_prob = config.attention_probs_dropout_prob
368
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
369
+
370
+ # Adapted from BertSelfAttention
371
+ def forward(
372
+ self,
373
+ hidden_states: torch.Tensor,
374
+ attention_mask: Optional[torch.Tensor] = None,
375
+ head_mask: Optional[torch.FloatTensor] = None,
376
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
377
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
378
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
379
+ output_attentions: Optional[bool] = False,
380
+ ) -> Tuple[torch.Tensor]:
381
+ if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
382
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
383
+ logger.warning_once(
384
+ "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
385
+ "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
386
+ "the manual attention implementation, but specifying the manual implementation will be required from "
387
+ "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
388
+ '`attn_implementation="eager"` when loading the model.'
389
+ )
390
+ return super().forward(
391
+ hidden_states,
392
+ attention_mask,
393
+ head_mask,
394
+ encoder_hidden_states,
395
+ encoder_attention_mask,
396
+ past_key_value,
397
+ output_attentions,
398
+ )
399
+
400
+ bsz, tgt_len, _ = hidden_states.size()
401
+
402
+ query_layer = self.transpose_for_scores(self.query(hidden_states) * self.query_pro)
403
+
404
+ # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
405
+ # mask needs to be such that the encoder's padding tokens are not attended to.
406
+ is_cross_attention = encoder_hidden_states is not None
407
+
408
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
409
+ attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
410
+
411
+ # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
412
+ if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
413
+ key_layer, value_layer = past_key_value
414
+ else:
415
+ current_sent = self.query(current_states)
416
+ key_layer = self.transpose_for_scores(current_sent * self.key_pro)
417
+ value_layer = self.transpose_for_scores(current_sent * self.key_pro)
418
+ if past_key_value is not None and not is_cross_attention:
419
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
420
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
421
+
422
+
423
+ if self.is_decoder:
424
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
425
+ # Further calls to cross_attention layer can then reuse all cross-attention
426
+ # key/value_states (first "if" case)
427
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
428
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
429
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
430
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
431
+ past_key_value = (key_layer, value_layer)
432
+
433
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
434
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
435
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
436
+ if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
437
+ query_layer = query_layer.contiguous()
438
+ key_layer = key_layer.contiguous()
439
+ value_layer = value_layer.contiguous()
440
+
441
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
442
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
443
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
444
+ # a causal mask in case tgt_len == 1.
445
+ is_causal = (
446
+ True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
447
+ )
448
+
449
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
450
+ query_layer,
451
+ key_layer,
452
+ value_layer,
453
+ attn_mask=attention_mask,
454
+ dropout_p=self.dropout_prob if self.training else 0.0,
455
+ is_causal=is_causal,
456
+ )
457
+
458
+ attn_output = attn_output.transpose(1, 2)
459
+ attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
460
+
461
+ outputs = (attn_output,)
462
+ if self.is_decoder:
463
+ outputs = outputs + (past_key_value,)
464
+ return outputs
465
+
466
+
467
+ class BertSelfOutput(nn.Module):
468
+ def __init__(self, config):
469
+ super().__init__()
470
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
471
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
472
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
473
+
474
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
475
+ hidden_states = self.dense(hidden_states)
476
+ hidden_states = self.dropout(hidden_states)
477
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
478
+ return hidden_states
479
+
480
+
481
+ BERT_SELF_ATTENTION_CLASSES = {
482
+ "eager": BertSelfAttention,
483
+ "sdpa": BertSdpaSelfAttention,
484
+ }
485
+
486
+
487
+ class BertAttention(nn.Module):
488
+ def __init__(self, config, position_embedding_type=None):
489
+ super().__init__()
490
+ self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
491
+ config, position_embedding_type=position_embedding_type
492
+ )
493
+ self.output = BertSelfOutput(config)
494
+ self.pruned_heads = set()
495
+
496
+ def prune_heads(self, heads):
497
+ if len(heads) == 0:
498
+ return
499
+ heads, index = find_pruneable_heads_and_indices(
500
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
501
+ )
502
+
503
+ # Prune linear layers
504
+ self.self.query = prune_linear_layer(self.self.query, index)
505
+
506
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
507
+
508
+ # Update hyper params and store pruned heads
509
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
510
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
511
+ self.pruned_heads = self.pruned_heads.union(heads)
512
+
513
+ def forward(
514
+ self,
515
+ hidden_states: torch.Tensor,
516
+ attention_mask: Optional[torch.FloatTensor] = None,
517
+ head_mask: Optional[torch.FloatTensor] = None,
518
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
519
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
520
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
521
+ output_attentions: Optional[bool] = False,
522
+ ) -> Tuple[torch.Tensor]:
523
+ self_outputs = self.self(
524
+ hidden_states,
525
+ attention_mask,
526
+ head_mask,
527
+ encoder_hidden_states,
528
+ encoder_attention_mask,
529
+ past_key_value,
530
+ output_attentions,
531
+ )
532
+ attention_output = self.output(self_outputs[0], hidden_states)
533
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
534
+ return outputs
535
+
536
+
537
+ class BertIntermediate(nn.Module):
538
+ def __init__(self, config):
539
+ super().__init__()
540
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
541
+ if isinstance(config.hidden_act, str):
542
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
543
+ else:
544
+ self.intermediate_act_fn = config.hidden_act
545
+
546
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
547
+ hidden_states = self.dense(hidden_states)
548
+ hidden_states = self.intermediate_act_fn(hidden_states)
549
+ return hidden_states
550
+
551
+
552
+ class BertOutput(nn.Module):
553
+ def __init__(self, config):
554
+ super().__init__()
555
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
556
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
557
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
558
+
559
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
560
+ hidden_states = self.dense(hidden_states)
561
+ hidden_states = self.dropout(hidden_states)
562
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
563
+ return hidden_states
564
+
565
+
566
+ class BertLayer(nn.Module):
567
+ def __init__(self, config):
568
+ super().__init__()
569
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
570
+ self.seq_len_dim = 1
571
+ self.attention = BertAttention(config)
572
+ self.is_decoder = config.is_decoder
573
+ self.add_cross_attention = config.add_cross_attention
574
+ if self.add_cross_attention:
575
+ if not self.is_decoder:
576
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
577
+ self.crossattention = BertAttention(config, position_embedding_type="absolute")
578
+ self.intermediate = BertIntermediate(config)
579
+ self.output = BertOutput(config)
580
+
581
+ def forward(
582
+ self,
583
+ hidden_states: torch.Tensor,
584
+ attention_mask: Optional[torch.FloatTensor] = None,
585
+ head_mask: Optional[torch.FloatTensor] = None,
586
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
587
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
588
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
589
+ output_attentions: Optional[bool] = False,
590
+ ) -> Tuple[torch.Tensor]:
591
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
592
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
593
+ self_attention_outputs = self.attention(
594
+ hidden_states,
595
+ attention_mask,
596
+ head_mask,
597
+ output_attentions=output_attentions,
598
+ past_key_value=self_attn_past_key_value,
599
+ )
600
+ attention_output = self_attention_outputs[0]
601
+
602
+ # if decoder, the last output is tuple of self-attn cache
603
+ if self.is_decoder:
604
+ outputs = self_attention_outputs[1:-1]
605
+ present_key_value = self_attention_outputs[-1]
606
+ else:
607
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
608
+
609
+ cross_attn_present_key_value = None
610
+ if self.is_decoder and encoder_hidden_states is not None:
611
+ if not hasattr(self, "crossattention"):
612
+ raise ValueError(
613
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
614
+ " by setting `config.add_cross_attention=True`"
615
+ )
616
+
617
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
618
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
619
+ cross_attention_outputs = self.crossattention(
620
+ attention_output,
621
+ attention_mask,
622
+ head_mask,
623
+ encoder_hidden_states,
624
+ encoder_attention_mask,
625
+ cross_attn_past_key_value,
626
+ output_attentions,
627
+ )
628
+ attention_output = cross_attention_outputs[0]
629
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
630
+
631
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
632
+ cross_attn_present_key_value = cross_attention_outputs[-1]
633
+ present_key_value = present_key_value + cross_attn_present_key_value
634
+
635
+ layer_output = apply_chunking_to_forward(
636
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
637
+ )
638
+ outputs = (layer_output,) + outputs
639
+
640
+ # if decoder, return the attn key/values as the last output
641
+ if self.is_decoder:
642
+ outputs = outputs + (present_key_value,)
643
+
644
+ return outputs
645
+
646
+ def feed_forward_chunk(self, attention_output):
647
+ intermediate_output = self.intermediate(attention_output)
648
+ layer_output = self.output(intermediate_output, attention_output)
649
+ return layer_output
650
+
651
+
652
+ class BertEncoder(nn.Module):
653
+ def __init__(self, config):
654
+ super().__init__()
655
+ self.config = config
656
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
657
+ self.gradient_checkpointing = False
658
+
659
+ def forward(
660
+ self,
661
+ hidden_states: torch.Tensor,
662
+ attention_mask: Optional[torch.FloatTensor] = None,
663
+ head_mask: Optional[torch.FloatTensor] = None,
664
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
665
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
666
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
667
+ use_cache: Optional[bool] = None,
668
+ output_attentions: Optional[bool] = False,
669
+ output_hidden_states: Optional[bool] = False,
670
+ return_dict: Optional[bool] = True,
671
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
672
+ all_hidden_states = () if output_hidden_states else None
673
+ all_self_attentions = () if output_attentions else None
674
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
675
+
676
+ if self.gradient_checkpointing and self.training:
677
+ if use_cache:
678
+ logger.warning_once(
679
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
680
+ )
681
+ use_cache = False
682
+
683
+ next_decoder_cache = () if use_cache else None
684
+ for i, layer_module in enumerate(self.layer):
685
+ if output_hidden_states:
686
+ all_hidden_states = all_hidden_states + (hidden_states,)
687
+
688
+ layer_head_mask = head_mask[i] if head_mask is not None else None
689
+ past_key_value = past_key_values[i] if past_key_values is not None else None
690
+
691
+ if self.gradient_checkpointing and self.training:
692
+ layer_outputs = self._gradient_checkpointing_func(
693
+ layer_module.__call__,
694
+ hidden_states,
695
+ attention_mask,
696
+ layer_head_mask,
697
+ encoder_hidden_states,
698
+ encoder_attention_mask,
699
+ past_key_value,
700
+ output_attentions,
701
+ )
702
+ else:
703
+ layer_outputs = layer_module(
704
+ hidden_states,
705
+ attention_mask,
706
+ layer_head_mask,
707
+ encoder_hidden_states,
708
+ encoder_attention_mask,
709
+ past_key_value,
710
+ output_attentions,
711
+ )
712
+
713
+ hidden_states = layer_outputs[0]
714
+ if use_cache:
715
+ next_decoder_cache += (layer_outputs[-1],)
716
+ if output_attentions:
717
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
718
+ if self.config.add_cross_attention:
719
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
720
+
721
+ if output_hidden_states:
722
+ all_hidden_states = all_hidden_states + (hidden_states,)
723
+
724
+ if not return_dict:
725
+ return tuple(
726
+ v
727
+ for v in [
728
+ hidden_states,
729
+ next_decoder_cache,
730
+ all_hidden_states,
731
+ all_self_attentions,
732
+ all_cross_attentions,
733
+ ]
734
+ if v is not None
735
+ )
736
+ return BaseModelOutputWithPastAndCrossAttentions(
737
+ last_hidden_state=hidden_states,
738
+ past_key_values=next_decoder_cache,
739
+ hidden_states=all_hidden_states,
740
+ attentions=all_self_attentions,
741
+ cross_attentions=all_cross_attentions,
742
+ )
743
+
744
+
745
+ class BertPooler(nn.Module):
746
+ def __init__(self, config):
747
+ super().__init__()
748
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
749
+ self.activation = nn.Tanh()
750
+
751
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
752
+ # We "pool" the model by simply taking the hidden state corresponding
753
+ # to the first token.
754
+ first_token_tensor = hidden_states[:, 0]
755
+ pooled_output = self.dense(first_token_tensor)
756
+ pooled_output = self.activation(pooled_output)
757
+ return pooled_output
758
+
759
+
760
+ class BertPredictionHeadTransform(nn.Module):
761
+ def __init__(self, config):
762
+ super().__init__()
763
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
764
+ if isinstance(config.hidden_act, str):
765
+ self.transform_act_fn = ACT2FN[config.hidden_act]
766
+ else:
767
+ self.transform_act_fn = config.hidden_act
768
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
769
+
770
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
771
+ hidden_states = self.dense(hidden_states)
772
+ hidden_states = self.transform_act_fn(hidden_states)
773
+ hidden_states = self.LayerNorm(hidden_states)
774
+ return hidden_states
775
+
776
+
777
+ class BertLMPredictionHead(nn.Module):
778
+ def __init__(self, config):
779
+ super().__init__()
780
+ self.transform = BertPredictionHeadTransform(config)
781
+
782
+ # The output weights are the same as the input embeddings, but there is
783
+ # an output-only bias for each token.
784
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
785
+
786
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
787
+
788
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
789
+ self.decoder.bias = self.bias
790
+
791
+ def _tie_weights(self):
792
+ self.decoder.bias = self.bias
793
+
794
+ def forward(self, hidden_states):
795
+ hidden_states = self.transform(hidden_states)
796
+ hidden_states = self.decoder(hidden_states)
797
+ return hidden_states
798
+
799
+
800
+ class BertOnlyMLMHead(nn.Module):
801
+ def __init__(self, config):
802
+ super().__init__()
803
+ self.predictions = BertLMPredictionHead(config)
804
+
805
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
806
+ prediction_scores = self.predictions(sequence_output)
807
+ return prediction_scores
808
+
809
+
810
+ class BertOnlyNSPHead(nn.Module):
811
+ def __init__(self, config):
812
+ super().__init__()
813
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
814
+
815
+ def forward(self, pooled_output):
816
+ seq_relationship_score = self.seq_relationship(pooled_output)
817
+ return seq_relationship_score
818
+
819
+
820
+ class BertPreTrainingHeads(nn.Module):
821
+ def __init__(self, config):
822
+ super().__init__()
823
+ self.predictions = BertLMPredictionHead(config)
824
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
825
+
826
+ def forward(self, sequence_output, pooled_output):
827
+ prediction_scores = self.predictions(sequence_output)
828
+ seq_relationship_score = self.seq_relationship(pooled_output)
829
+ return prediction_scores, seq_relationship_score
830
+
831
+
832
+ class BertPreTrainedModel(PreTrainedModel):
833
+ """
834
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
835
+ models.
836
+ """
837
+
838
+ config_class = BertConfig
839
+ load_tf_weights = load_tf_weights_in_bert
840
+ base_model_prefix = "bert"
841
+ supports_gradient_checkpointing = True
842
+ _supports_sdpa = True
843
+
844
+ def _init_weights(self, module):
845
+ """Initialize the weights"""
846
+ if isinstance(module, nn.Linear):
847
+ # Slightly different from the TF version which uses truncated_normal for initialization
848
+ # cf https://github.com/pytorch/pytorch/pull/5617
849
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
850
+ if module.bias is not None:
851
+ module.bias.data.zero_()
852
+ elif isinstance(module, nn.Embedding):
853
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
854
+ if module.padding_idx is not None:
855
+ module.weight.data[module.padding_idx].zero_()
856
+ elif isinstance(module, nn.LayerNorm):
857
+ module.bias.data.zero_()
858
+ module.weight.data.fill_(1.0)
859
+
860
+
861
+ @dataclass
862
+ class BertForPreTrainingOutput(ModelOutput):
863
+ """
864
+ Output type of [`BertForPreTraining`].
865
+
866
+ Args:
867
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
868
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
869
+ (classification) loss.
870
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
871
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
872
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
873
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
874
+ before SoftMax).
875
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
876
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
877
+ shape `(batch_size, sequence_length, hidden_size)`.
878
+
879
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
880
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
881
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
882
+ sequence_length)`.
883
+
884
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
885
+ heads.
886
+ """
887
+
888
+ loss: Optional[torch.FloatTensor] = None
889
+ prediction_logits: torch.FloatTensor = None
890
+ seq_relationship_logits: torch.FloatTensor = None
891
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
892
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
893
+
894
+
895
+ BERT_START_DOCSTRING = r"""
896
+
897
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
898
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
899
+ etc.)
900
+
901
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
902
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
903
+ and behavior.
904
+
905
+ Parameters:
906
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
907
+ Initializing with a config file does not load the weights associated with the model, only the
908
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
909
+ """
910
+
911
+ BERT_INPUTS_DOCSTRING = r"""
912
+ Args:
913
+ input_ids (`torch.LongTensor` of shape `({0})`):
914
+ Indices of input sequence tokens in the vocabulary.
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ [What are input IDs?](../glossary#input-ids)
920
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
921
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
922
+
923
+ - 1 for tokens that are **not masked**,
924
+ - 0 for tokens that are **masked**.
925
+
926
+ [What are attention masks?](../glossary#attention-mask)
927
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
928
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
929
+ 1]`:
930
+
931
+ - 0 corresponds to a *sentence A* token,
932
+ - 1 corresponds to a *sentence B* token.
933
+
934
+ [What are token type IDs?](../glossary#token-type-ids)
935
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
936
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
937
+ config.max_position_embeddings - 1]`.
938
+
939
+ [What are position IDs?](../glossary#position-ids)
940
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
941
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
942
+
943
+ - 1 indicates the head is **not masked**,
944
+ - 0 indicates the head is **masked**.
945
+
946
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
947
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
948
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
949
+ model's internal embedding lookup matrix.
950
+ output_attentions (`bool`, *optional*):
951
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
952
+ tensors for more detail.
953
+ output_hidden_states (`bool`, *optional*):
954
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
955
+ more detail.
956
+ return_dict (`bool`, *optional*):
957
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
958
+ """
959
+
960
+
961
+ @add_start_docstrings(
962
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
963
+ BERT_START_DOCSTRING,
964
+ )
965
+ class BertModel(BertPreTrainedModel):
966
+ """
967
+
968
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
969
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
970
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
971
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
972
+
973
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
974
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
975
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
976
+ """
977
+
978
+ _no_split_modules = ["BertEmbeddings", "BertLayer"]
979
+
980
+ def __init__(self, config, add_pooling_layer=True):
981
+ super().__init__(config)
982
+ self.config = config
983
+
984
+ self.embeddings = BertEmbeddings(config)
985
+ self.encoder = BertEncoder(config)
986
+
987
+ self.pooler = BertPooler(config) if add_pooling_layer else None
988
+
989
+ self.attn_implementation = config._attn_implementation
990
+ self.position_embedding_type = config.position_embedding_type
991
+
992
+ # Initialize weights and apply final processing
993
+ self.post_init()
994
+
995
+ def get_input_embeddings(self):
996
+ return self.embeddings.word_embeddings
997
+
998
+ def set_input_embeddings(self, value):
999
+ self.embeddings.word_embeddings = value
1000
+
1001
+ def _prune_heads(self, heads_to_prune):
1002
+ """
1003
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1004
+ class PreTrainedModel
1005
+ """
1006
+ for layer, heads in heads_to_prune.items():
1007
+ self.encoder.layer[layer].attention.prune_heads(heads)
1008
+
1009
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1010
+ @add_code_sample_docstrings(
1011
+ checkpoint=_CHECKPOINT_FOR_DOC,
1012
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
1013
+ config_class=_CONFIG_FOR_DOC,
1014
+ )
1015
+ def forward(
1016
+ self,
1017
+ input_ids: Optional[torch.Tensor] = None,
1018
+ attention_mask: Optional[torch.Tensor] = None,
1019
+ token_type_ids: Optional[torch.Tensor] = None,
1020
+ position_ids: Optional[torch.Tensor] = None,
1021
+ head_mask: Optional[torch.Tensor] = None,
1022
+ inputs_embeds: Optional[torch.Tensor] = None,
1023
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1024
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1025
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1026
+ use_cache: Optional[bool] = None,
1027
+ output_attentions: Optional[bool] = None,
1028
+ output_hidden_states: Optional[bool] = None,
1029
+ return_dict: Optional[bool] = None,
1030
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1031
+ r"""
1032
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1033
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1034
+ the model is configured as a decoder.
1035
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1036
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1037
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1038
+
1039
+ - 1 for tokens that are **not masked**,
1040
+ - 0 for tokens that are **masked**.
1041
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1042
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1043
+
1044
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1045
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1046
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1047
+ use_cache (`bool`, *optional*):
1048
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1049
+ `past_key_values`).
1050
+ """
1051
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1052
+ output_hidden_states = (
1053
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1054
+ )
1055
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1056
+
1057
+ if self.config.is_decoder:
1058
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1059
+ else:
1060
+ use_cache = False
1061
+
1062
+ if input_ids is not None and inputs_embeds is not None:
1063
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1064
+ elif input_ids is not None:
1065
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1066
+ input_shape = input_ids.size()
1067
+ elif inputs_embeds is not None:
1068
+ input_shape = inputs_embeds.size()[:-1]
1069
+ else:
1070
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1071
+
1072
+ batch_size, seq_length = input_shape
1073
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1074
+
1075
+ # past_key_values_length
1076
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1077
+
1078
+ if token_type_ids is None:
1079
+ if hasattr(self.embeddings, "token_type_ids"):
1080
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1081
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1082
+ token_type_ids = buffered_token_type_ids_expanded
1083
+ else:
1084
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1085
+
1086
+ embedding_output = self.embeddings(
1087
+ input_ids=input_ids,
1088
+ position_ids=position_ids,
1089
+ token_type_ids=token_type_ids,
1090
+ inputs_embeds=inputs_embeds,
1091
+ past_key_values_length=past_key_values_length,
1092
+ )
1093
+
1094
+ if attention_mask is None:
1095
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
1096
+
1097
+ use_sdpa_attention_masks = (
1098
+ self.attn_implementation == "sdpa"
1099
+ and self.position_embedding_type == "absolute"
1100
+ and head_mask is None
1101
+ and not output_attentions
1102
+ )
1103
+
1104
+ # Expand the attention mask
1105
+ if use_sdpa_attention_masks:
1106
+ # Expand the attention mask for SDPA.
1107
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
1108
+ if self.config.is_decoder:
1109
+ extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1110
+ attention_mask,
1111
+ input_shape,
1112
+ embedding_output,
1113
+ past_key_values_length,
1114
+ )
1115
+ else:
1116
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1117
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
1118
+ )
1119
+ else:
1120
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1121
+ # ourselves in which case we just need to make it broadcastable to all heads.
1122
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
1123
+
1124
+ # If a 2D or 3D attention mask is provided for the cross-attention
1125
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1126
+ if self.config.is_decoder and encoder_hidden_states is not None:
1127
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1128
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1129
+ if encoder_attention_mask is None:
1130
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1131
+
1132
+ if use_sdpa_attention_masks:
1133
+ # Expand the attention mask for SDPA.
1134
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
1135
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1136
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
1137
+ )
1138
+ else:
1139
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1140
+ else:
1141
+ encoder_extended_attention_mask = None
1142
+
1143
+ # Prepare head mask if needed
1144
+ # 1.0 in head_mask indicate we keep the head
1145
+ # attention_probs has shape bsz x n_heads x N x N
1146
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1147
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1148
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1149
+
1150
+ encoder_outputs = self.encoder(
1151
+ embedding_output,
1152
+ attention_mask=extended_attention_mask,
1153
+ head_mask=head_mask,
1154
+ encoder_hidden_states=encoder_hidden_states,
1155
+ encoder_attention_mask=encoder_extended_attention_mask,
1156
+ past_key_values=past_key_values,
1157
+ use_cache=use_cache,
1158
+ output_attentions=output_attentions,
1159
+ output_hidden_states=output_hidden_states,
1160
+ return_dict=return_dict,
1161
+ )
1162
+ sequence_output = encoder_outputs[0]
1163
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1164
+
1165
+ if not return_dict:
1166
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1167
+
1168
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1169
+ last_hidden_state=sequence_output,
1170
+ pooler_output=pooled_output,
1171
+ past_key_values=encoder_outputs.past_key_values,
1172
+ hidden_states=encoder_outputs.hidden_states,
1173
+ attentions=encoder_outputs.attentions,
1174
+ cross_attentions=encoder_outputs.cross_attentions,
1175
+ )
1176
+
1177
+
1178
+ @add_start_docstrings(
1179
+ """
1180
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1181
+ sentence prediction (classification)` head.
1182
+ """,
1183
+ BERT_START_DOCSTRING,
1184
+ )
1185
+ class BertForPreTraining(BertPreTrainedModel):
1186
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1187
+
1188
+ def __init__(self, config):
1189
+ super().__init__(config)
1190
+
1191
+ self.bert = BertModel(config)
1192
+ self.cls = BertPreTrainingHeads(config)
1193
+
1194
+ # Initialize weights and apply final processing
1195
+ self.post_init()
1196
+
1197
+ def get_output_embeddings(self):
1198
+ return self.cls.predictions.decoder
1199
+
1200
+ def set_output_embeddings(self, new_embeddings):
1201
+ self.cls.predictions.decoder = new_embeddings
1202
+ self.cls.predictions.bias = new_embeddings.bias
1203
+
1204
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1205
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1206
+ def forward(
1207
+ self,
1208
+ input_ids: Optional[torch.Tensor] = None,
1209
+ attention_mask: Optional[torch.Tensor] = None,
1210
+ token_type_ids: Optional[torch.Tensor] = None,
1211
+ position_ids: Optional[torch.Tensor] = None,
1212
+ head_mask: Optional[torch.Tensor] = None,
1213
+ inputs_embeds: Optional[torch.Tensor] = None,
1214
+ labels: Optional[torch.Tensor] = None,
1215
+ next_sentence_label: Optional[torch.Tensor] = None,
1216
+ output_attentions: Optional[bool] = None,
1217
+ output_hidden_states: Optional[bool] = None,
1218
+ return_dict: Optional[bool] = None,
1219
+ ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
1220
+ r"""
1221
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1222
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1223
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1224
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1225
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1226
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1227
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
1228
+
1229
+ - 0 indicates sequence B is a continuation of sequence A,
1230
+ - 1 indicates sequence B is a random sequence.
1231
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1232
+ Used to hide legacy arguments that have been deprecated.
1233
+
1234
+ Returns:
1235
+
1236
+ Example:
1237
+
1238
+ ```python
1239
+ >>> from transformers import AutoTokenizer, BertForPreTraining
1240
+ >>> import torch
1241
+
1242
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1243
+ >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
1244
+
1245
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1246
+ >>> outputs = model(**inputs)
1247
+
1248
+ >>> prediction_logits = outputs.prediction_logits
1249
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1250
+ ```
1251
+ """
1252
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1253
+
1254
+ outputs = self.bert(
1255
+ input_ids,
1256
+ attention_mask=attention_mask,
1257
+ token_type_ids=token_type_ids,
1258
+ position_ids=position_ids,
1259
+ head_mask=head_mask,
1260
+ inputs_embeds=inputs_embeds,
1261
+ output_attentions=output_attentions,
1262
+ output_hidden_states=output_hidden_states,
1263
+ return_dict=return_dict,
1264
+ )
1265
+
1266
+ sequence_output, pooled_output = outputs[:2]
1267
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1268
+
1269
+ total_loss = None
1270
+ if labels is not None and next_sentence_label is not None:
1271
+ loss_fct = CrossEntropyLoss()
1272
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1273
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1274
+ total_loss = masked_lm_loss + next_sentence_loss
1275
+
1276
+ if not return_dict:
1277
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1278
+ return ((total_loss,) + output) if total_loss is not None else output
1279
+
1280
+ return BertForPreTrainingOutput(
1281
+ loss=total_loss,
1282
+ prediction_logits=prediction_scores,
1283
+ seq_relationship_logits=seq_relationship_score,
1284
+ hidden_states=outputs.hidden_states,
1285
+ attentions=outputs.attentions,
1286
+ )
1287
+
1288
+
1289
+ @add_start_docstrings(
1290
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
1291
+ )
1292
+ class BertLMHeadModel(BertPreTrainedModel):
1293
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1294
+
1295
+ def __init__(self, config):
1296
+ super().__init__(config)
1297
+
1298
+ if not config.is_decoder:
1299
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1300
+
1301
+ self.bert = BertModel(config, add_pooling_layer=False)
1302
+ self.cls = BertOnlyMLMHead(config)
1303
+
1304
+ # Initialize weights and apply final processing
1305
+ self.post_init()
1306
+
1307
+ def get_output_embeddings(self):
1308
+ return self.cls.predictions.decoder
1309
+
1310
+ def set_output_embeddings(self, new_embeddings):
1311
+ self.cls.predictions.decoder = new_embeddings
1312
+ self.cls.predictions.bias = new_embeddings.bias
1313
+
1314
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1315
+ @add_code_sample_docstrings(
1316
+ checkpoint=_CHECKPOINT_FOR_DOC,
1317
+ output_type=CausalLMOutputWithCrossAttentions,
1318
+ config_class=_CONFIG_FOR_DOC,
1319
+ )
1320
+ def forward(
1321
+ self,
1322
+ input_ids: Optional[torch.Tensor] = None,
1323
+ attention_mask: Optional[torch.Tensor] = None,
1324
+ token_type_ids: Optional[torch.Tensor] = None,
1325
+ position_ids: Optional[torch.Tensor] = None,
1326
+ head_mask: Optional[torch.Tensor] = None,
1327
+ inputs_embeds: Optional[torch.Tensor] = None,
1328
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1329
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1330
+ labels: Optional[torch.Tensor] = None,
1331
+ past_key_values: Optional[List[torch.Tensor]] = None,
1332
+ use_cache: Optional[bool] = None,
1333
+ output_attentions: Optional[bool] = None,
1334
+ output_hidden_states: Optional[bool] = None,
1335
+ return_dict: Optional[bool] = None,
1336
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1337
+ r"""
1338
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1339
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1340
+ the model is configured as a decoder.
1341
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1342
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1343
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1344
+
1345
+ - 1 for tokens that are **not masked**,
1346
+ - 0 for tokens that are **masked**.
1347
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1348
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1349
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1350
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
1351
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1352
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1353
+
1354
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1355
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1356
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1357
+ use_cache (`bool`, *optional*):
1358
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1359
+ `past_key_values`).
1360
+ """
1361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362
+ if labels is not None:
1363
+ use_cache = False
1364
+
1365
+ outputs = self.bert(
1366
+ input_ids,
1367
+ attention_mask=attention_mask,
1368
+ token_type_ids=token_type_ids,
1369
+ position_ids=position_ids,
1370
+ head_mask=head_mask,
1371
+ inputs_embeds=inputs_embeds,
1372
+ encoder_hidden_states=encoder_hidden_states,
1373
+ encoder_attention_mask=encoder_attention_mask,
1374
+ past_key_values=past_key_values,
1375
+ use_cache=use_cache,
1376
+ output_attentions=output_attentions,
1377
+ output_hidden_states=output_hidden_states,
1378
+ return_dict=return_dict,
1379
+ )
1380
+
1381
+ sequence_output = outputs[0]
1382
+ prediction_scores = self.cls(sequence_output)
1383
+
1384
+ lm_loss = None
1385
+ if labels is not None:
1386
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1387
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1388
+ labels = labels[:, 1:].contiguous()
1389
+ loss_fct = CrossEntropyLoss()
1390
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1391
+
1392
+ if not return_dict:
1393
+ output = (prediction_scores,) + outputs[2:]
1394
+ return ((lm_loss,) + output) if lm_loss is not None else output
1395
+
1396
+ return CausalLMOutputWithCrossAttentions(
1397
+ loss=lm_loss,
1398
+ logits=prediction_scores,
1399
+ past_key_values=outputs.past_key_values,
1400
+ hidden_states=outputs.hidden_states,
1401
+ attentions=outputs.attentions,
1402
+ cross_attentions=outputs.cross_attentions,
1403
+ )
1404
+
1405
+ def prepare_inputs_for_generation(
1406
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
1407
+ ):
1408
+ input_shape = input_ids.shape
1409
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1410
+ if attention_mask is None:
1411
+ attention_mask = input_ids.new_ones(input_shape)
1412
+
1413
+ # cut decoder_input_ids if past_key_values is used
1414
+ if past_key_values is not None:
1415
+ past_length = past_key_values[0][0].shape[2]
1416
+
1417
+ # Some generation methods already pass only the last input ID
1418
+ if input_ids.shape[1] > past_length:
1419
+ remove_prefix_length = past_length
1420
+ else:
1421
+ # Default to old behavior: keep only final ID
1422
+ remove_prefix_length = input_ids.shape[1] - 1
1423
+
1424
+ input_ids = input_ids[:, remove_prefix_length:]
1425
+
1426
+ return {
1427
+ "input_ids": input_ids,
1428
+ "attention_mask": attention_mask,
1429
+ "past_key_values": past_key_values,
1430
+ "use_cache": use_cache,
1431
+ }
1432
+
1433
+ def _reorder_cache(self, past_key_values, beam_idx):
1434
+ reordered_past = ()
1435
+ for layer_past in past_key_values:
1436
+ reordered_past += (
1437
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1438
+ )
1439
+ return reordered_past
1440
+
1441
+
1442
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
1443
+ class BertForMaskedLM(BertPreTrainedModel):
1444
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1445
+
1446
+ def __init__(self, config):
1447
+ super().__init__(config)
1448
+
1449
+ if config.is_decoder:
1450
+ logger.warning(
1451
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1452
+ "bi-directional self-attention."
1453
+ )
1454
+
1455
+ self.bert = BertModel(config, add_pooling_layer=False)
1456
+ self.cls = BertOnlyMLMHead(config)
1457
+
1458
+ # Initialize weights and apply final processing
1459
+ self.post_init()
1460
+
1461
+ def get_output_embeddings(self):
1462
+ return self.cls.predictions.decoder
1463
+
1464
+ def set_output_embeddings(self, new_embeddings):
1465
+ self.cls.predictions.decoder = new_embeddings
1466
+ self.cls.predictions.bias = new_embeddings.bias
1467
+
1468
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1469
+ @add_code_sample_docstrings(
1470
+ checkpoint=_CHECKPOINT_FOR_DOC,
1471
+ output_type=MaskedLMOutput,
1472
+ config_class=_CONFIG_FOR_DOC,
1473
+ expected_output="'paris'",
1474
+ expected_loss=0.88,
1475
+ )
1476
+ def forward(
1477
+ self,
1478
+ input_ids: Optional[torch.Tensor] = None,
1479
+ attention_mask: Optional[torch.Tensor] = None,
1480
+ token_type_ids: Optional[torch.Tensor] = None,
1481
+ position_ids: Optional[torch.Tensor] = None,
1482
+ head_mask: Optional[torch.Tensor] = None,
1483
+ inputs_embeds: Optional[torch.Tensor] = None,
1484
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1485
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1486
+ labels: Optional[torch.Tensor] = None,
1487
+ output_attentions: Optional[bool] = None,
1488
+ output_hidden_states: Optional[bool] = None,
1489
+ return_dict: Optional[bool] = None,
1490
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1491
+ r"""
1492
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1493
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1494
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1495
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1496
+ """
1497
+
1498
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1499
+
1500
+ outputs = self.bert(
1501
+ input_ids,
1502
+ attention_mask=attention_mask,
1503
+ token_type_ids=token_type_ids,
1504
+ position_ids=position_ids,
1505
+ head_mask=head_mask,
1506
+ inputs_embeds=inputs_embeds,
1507
+ encoder_hidden_states=encoder_hidden_states,
1508
+ encoder_attention_mask=encoder_attention_mask,
1509
+ output_attentions=output_attentions,
1510
+ output_hidden_states=output_hidden_states,
1511
+ return_dict=return_dict,
1512
+ )
1513
+
1514
+ sequence_output = outputs[0]
1515
+ prediction_scores = self.cls(sequence_output)
1516
+
1517
+ masked_lm_loss = None
1518
+ if labels is not None:
1519
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1520
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1521
+
1522
+ if not return_dict:
1523
+ output = (prediction_scores,) + outputs[2:]
1524
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1525
+
1526
+ return MaskedLMOutput(
1527
+ loss=masked_lm_loss,
1528
+ logits=prediction_scores,
1529
+ hidden_states=outputs.hidden_states,
1530
+ attentions=outputs.attentions,
1531
+ )
1532
+
1533
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1534
+ input_shape = input_ids.shape
1535
+ effective_batch_size = input_shape[0]
1536
+
1537
+ # add a dummy token
1538
+ if self.config.pad_token_id is None:
1539
+ raise ValueError("The PAD token should be defined for generation")
1540
+
1541
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1542
+ dummy_token = torch.full(
1543
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1544
+ )
1545
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1546
+
1547
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1548
+
1549
+
1550
+ @add_start_docstrings(
1551
+ """Bert Model with a `next sentence prediction (classification)` head on top.""",
1552
+ BERT_START_DOCSTRING,
1553
+ )
1554
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1555
+ def __init__(self, config):
1556
+ super().__init__(config)
1557
+
1558
+ self.bert = BertModel(config)
1559
+ self.cls = BertOnlyNSPHead(config)
1560
+
1561
+ # Initialize weights and apply final processing
1562
+ self.post_init()
1563
+
1564
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1565
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1566
+ def forward(
1567
+ self,
1568
+ input_ids: Optional[torch.Tensor] = None,
1569
+ attention_mask: Optional[torch.Tensor] = None,
1570
+ token_type_ids: Optional[torch.Tensor] = None,
1571
+ position_ids: Optional[torch.Tensor] = None,
1572
+ head_mask: Optional[torch.Tensor] = None,
1573
+ inputs_embeds: Optional[torch.Tensor] = None,
1574
+ labels: Optional[torch.Tensor] = None,
1575
+ output_attentions: Optional[bool] = None,
1576
+ output_hidden_states: Optional[bool] = None,
1577
+ return_dict: Optional[bool] = None,
1578
+ **kwargs,
1579
+ ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1580
+ r"""
1581
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1582
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1583
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1584
+
1585
+ - 0 indicates sequence B is a continuation of sequence A,
1586
+ - 1 indicates sequence B is a random sequence.
1587
+
1588
+ Returns:
1589
+
1590
+ Example:
1591
+
1592
+ ```python
1593
+ >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
1594
+ >>> import torch
1595
+
1596
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1597
+ >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
1598
+
1599
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1600
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1601
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1602
+
1603
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1604
+ >>> logits = outputs.logits
1605
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1606
+ ```
1607
+ """
1608
+
1609
+ if "next_sentence_label" in kwargs:
1610
+ warnings.warn(
1611
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1612
+ " `labels` instead.",
1613
+ FutureWarning,
1614
+ )
1615
+ labels = kwargs.pop("next_sentence_label")
1616
+
1617
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1618
+
1619
+ outputs = self.bert(
1620
+ input_ids,
1621
+ attention_mask=attention_mask,
1622
+ token_type_ids=token_type_ids,
1623
+ position_ids=position_ids,
1624
+ head_mask=head_mask,
1625
+ inputs_embeds=inputs_embeds,
1626
+ output_attentions=output_attentions,
1627
+ output_hidden_states=output_hidden_states,
1628
+ return_dict=return_dict,
1629
+ )
1630
+
1631
+ pooled_output = outputs[1]
1632
+
1633
+ seq_relationship_scores = self.cls(pooled_output)
1634
+
1635
+ next_sentence_loss = None
1636
+ if labels is not None:
1637
+ loss_fct = CrossEntropyLoss()
1638
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1639
+
1640
+ if not return_dict:
1641
+ output = (seq_relationship_scores,) + outputs[2:]
1642
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1643
+
1644
+ return NextSentencePredictorOutput(
1645
+ loss=next_sentence_loss,
1646
+ logits=seq_relationship_scores,
1647
+ hidden_states=outputs.hidden_states,
1648
+ attentions=outputs.attentions,
1649
+ )
1650
+
1651
+
1652
+ @add_start_docstrings(
1653
+ """
1654
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1655
+ output) e.g. for GLUE tasks.
1656
+ """,
1657
+ BERT_START_DOCSTRING,
1658
+ )
1659
+ class BertForSequenceClassification(BertPreTrainedModel):
1660
+ def __init__(self, config):
1661
+ super().__init__(config)
1662
+ self.num_labels = config.num_labels
1663
+ self.config = config
1664
+
1665
+ self.bert = BertModel(config)
1666
+ classifier_dropout = (
1667
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1668
+ )
1669
+ self.dropout = nn.Dropout(classifier_dropout)
1670
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1671
+
1672
+ # Initialize weights and apply final processing
1673
+ self.post_init()
1674
+
1675
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1676
+ @add_code_sample_docstrings(
1677
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1678
+ output_type=SequenceClassifierOutput,
1679
+ config_class=_CONFIG_FOR_DOC,
1680
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1681
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1682
+ )
1683
+ def forward(
1684
+ self,
1685
+ input_ids: Optional[torch.Tensor] = None,
1686
+ attention_mask: Optional[torch.Tensor] = None,
1687
+ token_type_ids: Optional[torch.Tensor] = None,
1688
+ position_ids: Optional[torch.Tensor] = None,
1689
+ head_mask: Optional[torch.Tensor] = None,
1690
+ inputs_embeds: Optional[torch.Tensor] = None,
1691
+ labels: Optional[torch.Tensor] = None,
1692
+ output_attentions: Optional[bool] = None,
1693
+ output_hidden_states: Optional[bool] = None,
1694
+ return_dict: Optional[bool] = None,
1695
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1696
+ r"""
1697
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1698
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1699
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1700
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1701
+ """
1702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1703
+
1704
+ outputs = self.bert(
1705
+ input_ids,
1706
+ attention_mask=attention_mask,
1707
+ token_type_ids=token_type_ids,
1708
+ position_ids=position_ids,
1709
+ head_mask=head_mask,
1710
+ inputs_embeds=inputs_embeds,
1711
+ output_attentions=output_attentions,
1712
+ output_hidden_states=output_hidden_states,
1713
+ return_dict=return_dict,
1714
+ )
1715
+
1716
+ pooled_output = outputs[1]
1717
+
1718
+ pooled_output = self.dropout(pooled_output)
1719
+ logits = self.classifier(pooled_output)
1720
+
1721
+ loss = None
1722
+ if labels is not None:
1723
+ if self.config.problem_type is None:
1724
+ if self.num_labels == 1:
1725
+ self.config.problem_type = "regression"
1726
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1727
+ self.config.problem_type = "single_label_classification"
1728
+ else:
1729
+ self.config.problem_type = "multi_label_classification"
1730
+
1731
+ if self.config.problem_type == "regression":
1732
+ loss_fct = MSELoss()
1733
+ if self.num_labels == 1:
1734
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1735
+ else:
1736
+ loss = loss_fct(logits, labels)
1737
+ elif self.config.problem_type == "single_label_classification":
1738
+ loss_fct = CrossEntropyLoss()
1739
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1740
+ elif self.config.problem_type == "multi_label_classification":
1741
+ loss_fct = BCEWithLogitsLoss()
1742
+ loss = loss_fct(logits, labels)
1743
+ if not return_dict:
1744
+ output = (logits,) + outputs[2:]
1745
+ return ((loss,) + output) if loss is not None else output
1746
+
1747
+ return SequenceClassifierOutput(
1748
+ loss=loss,
1749
+ logits=logits,
1750
+ hidden_states=outputs.hidden_states,
1751
+ attentions=outputs.attentions,
1752
+ )
1753
+
1754
+
1755
+ @add_start_docstrings(
1756
+ """
1757
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1758
+ softmax) e.g. for RocStories/SWAG tasks.
1759
+ """,
1760
+ BERT_START_DOCSTRING,
1761
+ )
1762
+ class BertForMultipleChoice(BertPreTrainedModel):
1763
+ def __init__(self, config):
1764
+ super().__init__(config)
1765
+
1766
+ self.bert = BertModel(config)
1767
+ classifier_dropout = (
1768
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1769
+ )
1770
+ self.dropout = nn.Dropout(classifier_dropout)
1771
+ self.classifier = nn.Linear(config.hidden_size, 1)
1772
+
1773
+ # Initialize weights and apply final processing
1774
+ self.post_init()
1775
+
1776
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1777
+ @add_code_sample_docstrings(
1778
+ checkpoint=_CHECKPOINT_FOR_DOC,
1779
+ output_type=MultipleChoiceModelOutput,
1780
+ config_class=_CONFIG_FOR_DOC,
1781
+ )
1782
+ def forward(
1783
+ self,
1784
+ input_ids: Optional[torch.Tensor] = None,
1785
+ attention_mask: Optional[torch.Tensor] = None,
1786
+ token_type_ids: Optional[torch.Tensor] = None,
1787
+ position_ids: Optional[torch.Tensor] = None,
1788
+ head_mask: Optional[torch.Tensor] = None,
1789
+ inputs_embeds: Optional[torch.Tensor] = None,
1790
+ labels: Optional[torch.Tensor] = None,
1791
+ output_attentions: Optional[bool] = None,
1792
+ output_hidden_states: Optional[bool] = None,
1793
+ return_dict: Optional[bool] = None,
1794
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1795
+ r"""
1796
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1797
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1798
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1799
+ `input_ids` above)
1800
+ """
1801
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1802
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1803
+
1804
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1805
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1806
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1807
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1808
+ inputs_embeds = (
1809
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1810
+ if inputs_embeds is not None
1811
+ else None
1812
+ )
1813
+
1814
+ outputs = self.bert(
1815
+ input_ids,
1816
+ attention_mask=attention_mask,
1817
+ token_type_ids=token_type_ids,
1818
+ position_ids=position_ids,
1819
+ head_mask=head_mask,
1820
+ inputs_embeds=inputs_embeds,
1821
+ output_attentions=output_attentions,
1822
+ output_hidden_states=output_hidden_states,
1823
+ return_dict=return_dict,
1824
+ )
1825
+
1826
+ pooled_output = outputs[1]
1827
+
1828
+ pooled_output = self.dropout(pooled_output)
1829
+ logits = self.classifier(pooled_output)
1830
+ reshaped_logits = logits.view(-1, num_choices)
1831
+
1832
+ loss = None
1833
+ if labels is not None:
1834
+ loss_fct = CrossEntropyLoss()
1835
+ loss = loss_fct(reshaped_logits, labels)
1836
+
1837
+ if not return_dict:
1838
+ output = (reshaped_logits,) + outputs[2:]
1839
+ return ((loss,) + output) if loss is not None else output
1840
+
1841
+ return MultipleChoiceModelOutput(
1842
+ loss=loss,
1843
+ logits=reshaped_logits,
1844
+ hidden_states=outputs.hidden_states,
1845
+ attentions=outputs.attentions,
1846
+ )
1847
+
1848
+
1849
+ @add_start_docstrings(
1850
+ """
1851
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1852
+ Named-Entity-Recognition (NER) tasks.
1853
+ """,
1854
+ BERT_START_DOCSTRING,
1855
+ )
1856
+ class BertForTokenClassification(BertPreTrainedModel):
1857
+ def __init__(self, config):
1858
+ super().__init__(config)
1859
+ self.num_labels = config.num_labels
1860
+
1861
+ self.bert = BertModel(config, add_pooling_layer=False)
1862
+ classifier_dropout = (
1863
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1864
+ )
1865
+ self.dropout = nn.Dropout(classifier_dropout)
1866
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1867
+
1868
+ # Initialize weights and apply final processing
1869
+ self.post_init()
1870
+
1871
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1872
+ @add_code_sample_docstrings(
1873
+ checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
1874
+ output_type=TokenClassifierOutput,
1875
+ config_class=_CONFIG_FOR_DOC,
1876
+ expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
1877
+ expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
1878
+ )
1879
+ def forward(
1880
+ self,
1881
+ input_ids: Optional[torch.Tensor] = None,
1882
+ attention_mask: Optional[torch.Tensor] = None,
1883
+ token_type_ids: Optional[torch.Tensor] = None,
1884
+ position_ids: Optional[torch.Tensor] = None,
1885
+ head_mask: Optional[torch.Tensor] = None,
1886
+ inputs_embeds: Optional[torch.Tensor] = None,
1887
+ labels: Optional[torch.Tensor] = None,
1888
+ output_attentions: Optional[bool] = None,
1889
+ output_hidden_states: Optional[bool] = None,
1890
+ return_dict: Optional[bool] = None,
1891
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1892
+ r"""
1893
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1894
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1895
+ """
1896
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1897
+
1898
+ outputs = self.bert(
1899
+ input_ids,
1900
+ attention_mask=attention_mask,
1901
+ token_type_ids=token_type_ids,
1902
+ position_ids=position_ids,
1903
+ head_mask=head_mask,
1904
+ inputs_embeds=inputs_embeds,
1905
+ output_attentions=output_attentions,
1906
+ output_hidden_states=output_hidden_states,
1907
+ return_dict=return_dict,
1908
+ )
1909
+
1910
+ sequence_output = outputs[0]
1911
+
1912
+ sequence_output = self.dropout(sequence_output)
1913
+ logits = self.classifier(sequence_output)
1914
+
1915
+ loss = None
1916
+ if labels is not None:
1917
+ loss_fct = CrossEntropyLoss()
1918
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1919
+
1920
+ if not return_dict:
1921
+ output = (logits,) + outputs[2:]
1922
+ return ((loss,) + output) if loss is not None else output
1923
+
1924
+ return TokenClassifierOutput(
1925
+ loss=loss,
1926
+ logits=logits,
1927
+ hidden_states=outputs.hidden_states,
1928
+ attentions=outputs.attentions,
1929
+ )
1930
+
1931
+
1932
+ @add_start_docstrings(
1933
+ """
1934
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1935
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1936
+ """,
1937
+ BERT_START_DOCSTRING,
1938
+ )
1939
+ class BertForQuestionAnswering(BertPreTrainedModel):
1940
+ def __init__(self, config):
1941
+ super().__init__(config)
1942
+ self.num_labels = config.num_labels
1943
+
1944
+ self.bert = BertModel(config, add_pooling_layer=False)
1945
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1946
+
1947
+ # Initialize weights and apply final processing
1948
+ self.post_init()
1949
+
1950
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1951
+ @add_code_sample_docstrings(
1952
+ checkpoint=_CHECKPOINT_FOR_QA,
1953
+ output_type=QuestionAnsweringModelOutput,
1954
+ config_class=_CONFIG_FOR_DOC,
1955
+ qa_target_start_index=_QA_TARGET_START_INDEX,
1956
+ qa_target_end_index=_QA_TARGET_END_INDEX,
1957
+ expected_output=_QA_EXPECTED_OUTPUT,
1958
+ expected_loss=_QA_EXPECTED_LOSS,
1959
+ )
1960
+ def forward(
1961
+ self,
1962
+ input_ids: Optional[torch.Tensor] = None,
1963
+ attention_mask: Optional[torch.Tensor] = None,
1964
+ token_type_ids: Optional[torch.Tensor] = None,
1965
+ position_ids: Optional[torch.Tensor] = None,
1966
+ head_mask: Optional[torch.Tensor] = None,
1967
+ inputs_embeds: Optional[torch.Tensor] = None,
1968
+ start_positions: Optional[torch.Tensor] = None,
1969
+ end_positions: Optional[torch.Tensor] = None,
1970
+ output_attentions: Optional[bool] = None,
1971
+ output_hidden_states: Optional[bool] = None,
1972
+ return_dict: Optional[bool] = None,
1973
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1974
+ r"""
1975
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1976
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1977
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1978
+ are not taken into account for computing the loss.
1979
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1980
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1981
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1982
+ are not taken into account for computing the loss.
1983
+ """
1984
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1985
+
1986
+ outputs = self.bert(
1987
+ input_ids,
1988
+ attention_mask=attention_mask,
1989
+ token_type_ids=token_type_ids,
1990
+ position_ids=position_ids,
1991
+ head_mask=head_mask,
1992
+ inputs_embeds=inputs_embeds,
1993
+ output_attentions=output_attentions,
1994
+ output_hidden_states=output_hidden_states,
1995
+ return_dict=return_dict,
1996
+ )
1997
+
1998
+ sequence_output = outputs[0]
1999
+
2000
+ logits = self.qa_outputs(sequence_output)
2001
+ start_logits, end_logits = logits.split(1, dim=-1)
2002
+ start_logits = start_logits.squeeze(-1).contiguous()
2003
+ end_logits = end_logits.squeeze(-1).contiguous()
2004
+
2005
+ total_loss = None
2006
+ if start_positions is not None and end_positions is not None:
2007
+ # If we are on multi-GPU, split add a dimension
2008
+ if len(start_positions.size()) > 1:
2009
+ start_positions = start_positions.squeeze(-1)
2010
+ if len(end_positions.size()) > 1:
2011
+ end_positions = end_positions.squeeze(-1)
2012
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
2013
+ ignored_index = start_logits.size(1)
2014
+ start_positions = start_positions.clamp(0, ignored_index)
2015
+ end_positions = end_positions.clamp(0, ignored_index)
2016
+
2017
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2018
+ start_loss = loss_fct(start_logits, start_positions)
2019
+ end_loss = loss_fct(end_logits, end_positions)
2020
+ total_loss = (start_loss + end_loss) / 2
2021
+
2022
+ if not return_dict:
2023
+ output = (start_logits, end_logits) + outputs[2:]
2024
+ return ((total_loss,) + output) if total_loss is not None else output
2025
+
2026
+ return QuestionAnsweringModelOutput(
2027
+ loss=total_loss,
2028
+ start_logits=start_logits,
2029
+ end_logits=end_logits,
2030
+ hidden_states=outputs.hidden_states,
2031
+ attentions=outputs.attentions,
2032
+ )