thehonestbob commited on
Commit
259a0a5
1 Parent(s): daf313f

Upload 10 files

Browse files
config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "add_bias_logits": false,
6
+ "add_final_layer_norm": false,
7
+ "share_encoder_decoder_embeddings": true,
8
+ "encoder_normalize_embedding": false,
9
+ "decoder_normalize_embedding": false,
10
+ "architectures": [
11
+ "BartModel"
12
+ ],
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_bat.BartConfig",
15
+ "AutoModel": "modeling_bat.BartForConditionalGeneration",
16
+ "AutoModelForSeq2SeqLM": "modeling_bat.BartForConditionalGeneration"
17
+ },
18
+ "attention_dropout": 0.0,
19
+ "bos_token_id": 0,
20
+ "d_model": 1024,
21
+ "decoder_attention_heads": 16,
22
+ "decoder_ffn_dim": 4096,
23
+ "decoder_layerdrop": 0.05,
24
+ "decoder_layers": 6,
25
+ "decoder_start_token_id": 2,
26
+ "dropout": 0.1,
27
+ "early_stopping": true,
28
+ "encoder_attention_heads": 16,
29
+ "encoder_ffn_dim": 4096,
30
+ "encoder_layerdrop": 0.05,
31
+ "encoder_layers": 6,
32
+ "eos_token_id": 2,
33
+ "gradient_checkpointing": false,
34
+ "init_std": 0.02,
35
+ "is_encoder_decoder": true,
36
+ "max_length": 300,
37
+ "max_position_embeddings": 300,
38
+ "model_type": "bart",
39
+ "num_beams": 5,
40
+ "num_hidden_layers": 12,
41
+ "pad_token_id": 1,
42
+ "scale_embedding": true,
43
+ "transformers_version": "4.4.0.dev0",
44
+ "use_cache": true,
45
+ "forced_bos_token_id": 64870,
46
+ "vocab_size": 64871
47
+ }
configuration_bat.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:cb
4
+ @contact:[email protected]
5
+ @time:2023/6/6 13:25
6
+ @filename:modeling.py
7
+ @software:PyCharm
8
+ @description:根据bart进行改写
9
+ """
10
+ from transformers.models.bart.configuration_bart import BartConfig
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_bat.py ADDED
@@ -0,0 +1,1951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:cb
4
+ @contact:[email protected]
5
+ @time:2023/6/6 13:25
6
+ @filename:modeling.py
7
+ @software:PyCharm
8
+ @description:根据bart进行改写
9
+ """
10
+ import copy
11
+ import math
12
+ import random
13
+ import warnings
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ from torch import nn
19
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
20
+
21
+ from transformers.activations import ACT2FN
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutput,
24
+ BaseModelOutputWithPastAndCrossAttentions,
25
+ CausalLMOutputWithCrossAttentions,
26
+ Seq2SeqLMOutput,
27
+ Seq2SeqModelOutput,
28
+ Seq2SeqQuestionAnsweringModelOutput,
29
+ Seq2SeqSequenceClassifierOutput,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (
33
+ add_code_sample_docstrings,
34
+ add_end_docstrings,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers.models.bart.configuration_bart import BartConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CHECKPOINT_FOR_DOC = "facebook/bart-base"
46
+ _CONFIG_FOR_DOC = "BartConfig"
47
+
48
+ # Base model docstring
49
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
50
+
51
+ # SequenceClassification docstring
52
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
53
+ _SEQ_CLASS_EXPECTED_LOSS = 0.0
54
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
55
+
56
+ # QuestionAsnwering docstring
57
+ _CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
58
+ _QA_EXPECTED_LOSS = 0.59
59
+ _QA_EXPECTED_OUTPUT = "' nice puppet'"
60
+
61
+
62
+ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/bart-large",
64
+ # see all BART models at https://huggingface.co/models?filter=bart
65
+ ]
66
+
67
+
68
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
69
+ """
70
+ Shift input ids one token to the right.
71
+ """
72
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
73
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
74
+ shifted_input_ids[:, 0] = decoder_start_token_id
75
+
76
+ if pad_token_id is None:
77
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
78
+ # replace possible -100 values in labels by `pad_token_id`
79
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
80
+
81
+ return shifted_input_ids
82
+
83
+
84
+ def _make_causal_mask(
85
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
86
+ ):
87
+ """
88
+ Make causal mask used for bi-directional self-attention.
89
+ """
90
+ bsz, tgt_len = input_ids_shape
91
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
92
+ mask_cond = torch.arange(mask.size(-1), device=device)
93
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
94
+ mask = mask.to(dtype)
95
+
96
+ if past_key_values_length > 0:
97
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
98
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
99
+
100
+
101
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
102
+ """
103
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
104
+ """
105
+ bsz, src_len = mask.size()
106
+ tgt_len = tgt_len if tgt_len is not None else src_len
107
+
108
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
109
+
110
+ inverted_mask = 1.0 - expanded_mask
111
+
112
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
113
+
114
+
115
+ class BartLearnedPositionalEmbedding(nn.Embedding):
116
+ """
117
+ This module learns positional embeddings up to a fixed maximum size.
118
+ """
119
+
120
+ def __init__(self, num_embeddings: int, embedding_dim: int):
121
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
122
+ # and adjust num_embeddings appropriately. Other models don't have this hack
123
+ self.offset = 2
124
+ super().__init__(num_embeddings + self.offset, embedding_dim)
125
+
126
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
127
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
128
+
129
+ bsz, seq_len = input_ids.shape[:2]
130
+ positions = torch.arange(
131
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
132
+ ).expand(bsz, -1)
133
+
134
+ return super().forward(positions + self.offset)
135
+
136
+
137
+ class BartAttention(nn.Module):
138
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
139
+
140
+ def __init__(
141
+ self,
142
+ embed_dim: int,
143
+ num_heads: int,
144
+ dropout: float = 0.0,
145
+ is_decoder: bool = False,
146
+ bias: bool = True,
147
+ ):
148
+ super().__init__()
149
+ self.embed_dim = embed_dim
150
+ self.num_heads = num_heads
151
+ self.dropout = dropout
152
+ self.head_dim = embed_dim // num_heads
153
+
154
+ if (self.head_dim * num_heads) != self.embed_dim:
155
+ raise ValueError(
156
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
157
+ f" and `num_heads`: {num_heads})."
158
+ )
159
+ self.scaling = self.head_dim**-0.5
160
+ self.is_decoder = is_decoder
161
+
162
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
163
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
164
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
165
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
166
+
167
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
168
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states: torch.Tensor,
173
+ key_value_states: Optional[torch.Tensor] = None,
174
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
175
+ attention_mask: Optional[torch.Tensor] = None,
176
+ layer_head_mask: Optional[torch.Tensor] = None,
177
+ output_attentions: bool = False,
178
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
179
+ """Input shape: Batch x Time x Channel"""
180
+
181
+ # if key_value_states are provided this layer is used as a cross-attention layer
182
+ # for the decoder
183
+ is_cross_attention = key_value_states is not None
184
+
185
+ bsz, tgt_len, _ = hidden_states.size()
186
+
187
+ # get query proj
188
+ query_states = self.q_proj(hidden_states) * self.scaling
189
+ # get key, value proj
190
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
191
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
192
+ # the provided `key_value_states` to support prefix tuning
193
+ if (
194
+ is_cross_attention
195
+ and past_key_value is not None
196
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
197
+ ):
198
+ # reuse k,v, cross_attentions
199
+ key_states = past_key_value[0]
200
+ value_states = past_key_value[1]
201
+ elif is_cross_attention:
202
+ # cross_attentions
203
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
204
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
205
+ elif past_key_value is not None:
206
+ # reuse k, v, self_attention
207
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
208
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
209
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
210
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
211
+ else:
212
+ # self_attention
213
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
214
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
215
+
216
+ if self.is_decoder:
217
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
218
+ # Further calls to cross_attention layer can then reuse all cross-attention
219
+ # key/value_states (first "if" case)
220
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
221
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
222
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
223
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
224
+ past_key_value = (key_states, value_states)
225
+
226
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
227
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
228
+ key_states = key_states.reshape(*proj_shape)
229
+ value_states = value_states.reshape(*proj_shape)
230
+
231
+ src_len = key_states.size(1)
232
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
233
+
234
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
235
+ raise ValueError(
236
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
237
+ f" {attn_weights.size()}"
238
+ )
239
+
240
+ if attention_mask is not None:
241
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
242
+ raise ValueError(
243
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
244
+ )
245
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
246
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
247
+
248
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
249
+
250
+ if layer_head_mask is not None:
251
+ if layer_head_mask.size() != (self.num_heads,):
252
+ raise ValueError(
253
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
254
+ f" {layer_head_mask.size()}"
255
+ )
256
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
257
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
258
+
259
+ if output_attentions:
260
+ # this operation is a bit awkward, but it's required to
261
+ # make sure that attn_weights keeps its gradient.
262
+ # In order to do so, attn_weights have to be reshaped
263
+ # twice and have to be reused in the following
264
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
265
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
266
+ else:
267
+ attn_weights_reshaped = None
268
+
269
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
270
+
271
+ attn_output = torch.bmm(attn_probs, value_states)
272
+
273
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
274
+ raise ValueError(
275
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
276
+ f" {attn_output.size()}"
277
+ )
278
+
279
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
280
+ attn_output = attn_output.transpose(1, 2)
281
+
282
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
283
+ # partitioned across GPUs when using tensor-parallelism.
284
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
285
+
286
+ attn_output = self.out_proj(attn_output)
287
+
288
+ return attn_output, attn_weights_reshaped, past_key_value
289
+
290
+
291
+ class BartEncoderLayer(nn.Module):
292
+ def __init__(self, config: BartConfig):
293
+ super().__init__()
294
+ self.embed_dim = config.d_model
295
+ self.self_attn = BartAttention(
296
+ embed_dim=self.embed_dim,
297
+ num_heads=config.encoder_attention_heads,
298
+ dropout=config.attention_dropout,
299
+ )
300
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
301
+ self.dropout = config.dropout
302
+ self.activation_fn = ACT2FN[config.activation_function]
303
+ self.activation_dropout = config.activation_dropout
304
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
305
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
306
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
307
+
308
+ def forward(
309
+ self,
310
+ hidden_states: torch.FloatTensor,
311
+ attention_mask: torch.FloatTensor,
312
+ layer_head_mask: torch.FloatTensor,
313
+ output_attentions: Optional[bool] = False,
314
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
315
+ """
316
+ Args:
317
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
318
+ attention_mask (`torch.FloatTensor`): attention mask of size
319
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
320
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
321
+ `(encoder_attention_heads,)`.
322
+ output_attentions (`bool`, *optional*):
323
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
324
+ returned tensors for more detail.
325
+ """
326
+ residual = hidden_states
327
+ hidden_states, attn_weights, _ = self.self_attn(
328
+ hidden_states=hidden_states,
329
+ attention_mask=attention_mask,
330
+ layer_head_mask=layer_head_mask,
331
+ output_attentions=output_attentions,
332
+ )
333
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
334
+ hidden_states = residual + hidden_states
335
+ hidden_states = self.self_attn_layer_norm(hidden_states)
336
+
337
+ residual = hidden_states
338
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
339
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
340
+ hidden_states = self.fc2(hidden_states)
341
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
342
+ hidden_states = residual + hidden_states
343
+ hidden_states = self.final_layer_norm(hidden_states)
344
+
345
+ if hidden_states.dtype == torch.float16 and (
346
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
347
+ ):
348
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
349
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
350
+
351
+ outputs = (hidden_states,)
352
+
353
+ if output_attentions:
354
+ outputs += (attn_weights,)
355
+
356
+ return outputs
357
+
358
+
359
+ class BartDecoderLayer(nn.Module):
360
+ def __init__(self, config: BartConfig):
361
+ super().__init__()
362
+ self.embed_dim = config.d_model
363
+
364
+ self.self_attn = BartAttention(
365
+ embed_dim=self.embed_dim,
366
+ num_heads=config.decoder_attention_heads,
367
+ dropout=config.attention_dropout,
368
+ is_decoder=True,
369
+ )
370
+ self.dropout = config.dropout
371
+ self.activation_fn = ACT2FN[config.activation_function]
372
+ self.activation_dropout = config.activation_dropout
373
+
374
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
375
+ self.encoder_attn = BartAttention(
376
+ self.embed_dim,
377
+ config.decoder_attention_heads,
378
+ dropout=config.attention_dropout,
379
+ is_decoder=True,
380
+ )
381
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
382
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
383
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
384
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ encoder_hidden_states: Optional[torch.Tensor] = None,
391
+ encoder_attention_mask: Optional[torch.Tensor] = None,
392
+ layer_head_mask: Optional[torch.Tensor] = None,
393
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
394
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
395
+ output_attentions: Optional[bool] = False,
396
+ use_cache: Optional[bool] = True,
397
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
398
+ """
399
+ Args:
400
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
401
+ attention_mask (`torch.FloatTensor`): attention mask of size
402
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
403
+ encoder_hidden_states (`torch.FloatTensor`):
404
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
405
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
406
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
407
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
408
+ `(encoder_attention_heads,)`.
409
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
410
+ size `(decoder_attention_heads,)`.
411
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
412
+ output_attentions (`bool`, *optional*):
413
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
414
+ returned tensors for more detail.
415
+ """
416
+ residual = hidden_states
417
+
418
+ # Self Attention
419
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
420
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
421
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
422
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
423
+ hidden_states=hidden_states,
424
+ past_key_value=self_attn_past_key_value,
425
+ attention_mask=attention_mask,
426
+ layer_head_mask=layer_head_mask,
427
+ output_attentions=output_attentions,
428
+ )
429
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
430
+ hidden_states = residual + hidden_states
431
+ hidden_states = self.self_attn_layer_norm(hidden_states)
432
+
433
+ # Cross-Attention Block
434
+ cross_attn_present_key_value = None
435
+ cross_attn_weights = None
436
+ if encoder_hidden_states is not None:
437
+ residual = hidden_states
438
+
439
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
440
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
441
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
442
+ hidden_states=hidden_states,
443
+ key_value_states=encoder_hidden_states,
444
+ attention_mask=encoder_attention_mask,
445
+ layer_head_mask=cross_attn_layer_head_mask,
446
+ past_key_value=cross_attn_past_key_value,
447
+ output_attentions=output_attentions,
448
+ )
449
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
450
+ hidden_states = residual + hidden_states
451
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
452
+
453
+ # add cross-attn to positions 3,4 of present_key_value tuple
454
+ present_key_value = present_key_value + cross_attn_present_key_value
455
+
456
+ # Fully Connected
457
+ residual = hidden_states
458
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
459
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
460
+ hidden_states = self.fc2(hidden_states)
461
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
462
+ hidden_states = residual + hidden_states
463
+ hidden_states = self.final_layer_norm(hidden_states)
464
+
465
+ outputs = (hidden_states,)
466
+
467
+ if output_attentions:
468
+ outputs += (self_attn_weights, cross_attn_weights)
469
+
470
+ if use_cache:
471
+ outputs += (present_key_value,)
472
+
473
+ return outputs
474
+
475
+
476
+ class BartClassificationHead(nn.Module):
477
+ """Head for sentence-level classification tasks."""
478
+
479
+ def __init__(
480
+ self,
481
+ input_dim: int,
482
+ inner_dim: int,
483
+ num_classes: int,
484
+ pooler_dropout: float,
485
+ ):
486
+ super().__init__()
487
+ self.dense = nn.Linear(input_dim, inner_dim)
488
+ self.dropout = nn.Dropout(p=pooler_dropout)
489
+ self.out_proj = nn.Linear(inner_dim, num_classes)
490
+
491
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
492
+ hidden_states = self.dropout(hidden_states)
493
+ hidden_states = self.dense(hidden_states)
494
+ hidden_states = torch.tanh(hidden_states)
495
+ hidden_states = self.dropout(hidden_states)
496
+ hidden_states = self.out_proj(hidden_states)
497
+ return hidden_states
498
+
499
+
500
+ class BartPretrainedModel(PreTrainedModel):
501
+ config_class = BartConfig
502
+ base_model_prefix = "model"
503
+ supports_gradient_checkpointing = True
504
+ _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
505
+ _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
506
+
507
+ def _init_weights(self, module):
508
+ std = self.config.init_std
509
+ if isinstance(module, nn.Linear):
510
+ module.weight.data.normal_(mean=0.0, std=std)
511
+ if module.bias is not None:
512
+ module.bias.data.zero_()
513
+ elif isinstance(module, nn.Embedding):
514
+ module.weight.data.normal_(mean=0.0, std=std)
515
+ if module.padding_idx is not None:
516
+ module.weight.data[module.padding_idx].zero_()
517
+
518
+ def _set_gradient_checkpointing(self, module, value=False):
519
+ if isinstance(module, (BartDecoder, BartEncoder)):
520
+ module.gradient_checkpointing = value
521
+
522
+ @property
523
+ def dummy_inputs(self):
524
+ pad_token = self.config.pad_token_id
525
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
526
+ dummy_inputs = {
527
+ "attention_mask": input_ids.ne(pad_token),
528
+ "input_ids": input_ids,
529
+ }
530
+ return dummy_inputs
531
+
532
+
533
+ class PretrainedBartModel(BartPretrainedModel):
534
+ def __init_subclass__(self):
535
+ warnings.warn(
536
+ "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
537
+ FutureWarning,
538
+ )
539
+
540
+
541
+ BART_START_DOCSTRING = r"""
542
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
543
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
544
+ etc.)
545
+
546
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
547
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
548
+ and behavior.
549
+
550
+ Parameters:
551
+ config ([`BartConfig`]):
552
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
553
+ load the weights associated with the model, only the configuration. Check out the
554
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
555
+ """
556
+
557
+ BART_GENERATION_EXAMPLE = r"""
558
+ Summarization example:
559
+
560
+ ```python
561
+ >>> from transformers import AutoTokenizer, BartForConditionalGeneration
562
+
563
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
564
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
565
+
566
+ >>> ARTICLE_TO_SUMMARIZE = (
567
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
568
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
569
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
570
+ ... )
571
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
572
+
573
+ >>> # Generate Summary
574
+ >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
575
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
576
+ 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
577
+ ```
578
+
579
+ Mask filling example:
580
+
581
+ ```python
582
+ >>> from transformers import AutoTokenizer, BartForConditionalGeneration
583
+
584
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
585
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
586
+
587
+ >>> TXT = "My friends are <mask> but they eat too many carbs."
588
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
589
+ >>> logits = model(input_ids).logits
590
+
591
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
592
+ >>> probs = logits[0, masked_index].softmax(dim=0)
593
+ >>> values, predictions = probs.topk(5)
594
+
595
+ >>> tokenizer.decode(predictions).split()
596
+ ['not', 'good', 'healthy', 'great', 'very']
597
+ ```
598
+ """
599
+
600
+ BART_INPUTS_DOCSTRING = r"""
601
+ Args:
602
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
603
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
604
+ it.
605
+
606
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
607
+ [`PreTrainedTokenizer.__call__`] for details.
608
+
609
+ [What are input IDs?](../glossary#input-ids)
610
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
611
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
612
+
613
+ - 1 for tokens that are **not masked**,
614
+ - 0 for tokens that are **masked**.
615
+
616
+ [What are attention masks?](../glossary#attention-mask)
617
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
618
+ Indices of decoder input sequence tokens in the vocabulary.
619
+
620
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
621
+ [`PreTrainedTokenizer.__call__`] for details.
622
+
623
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
624
+
625
+ Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
626
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
627
+
628
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
629
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
630
+ for denoising pre-training following the paper.
631
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
632
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
633
+ be used by default.
634
+
635
+ If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
636
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
637
+ information on the default strategy.
638
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
639
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
640
+
641
+ - 1 indicates the head is **not masked**,
642
+ - 0 indicates the head is **masked**.
643
+
644
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
645
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
646
+
647
+ - 1 indicates the head is **not masked**,
648
+ - 0 indicates the head is **masked**.
649
+
650
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
651
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
652
+ 1]`:
653
+
654
+ - 1 indicates the head is **not masked**,
655
+ - 0 indicates the head is **masked**.
656
+
657
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
658
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
659
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
660
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
661
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
662
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
663
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
664
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
665
+
666
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
667
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
668
+
669
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
670
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
671
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
672
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
673
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
674
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
675
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
676
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
677
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
678
+ input (see `past_key_values`). This is useful if you want more control over how to convert
679
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
680
+
681
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
682
+ of `inputs_embeds`.
683
+ use_cache (`bool`, *optional*):
684
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
685
+ `past_key_values`).
686
+ output_attentions (`bool`, *optional*):
687
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
688
+ tensors for more detail.
689
+ output_hidden_states (`bool`, *optional*):
690
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
691
+ more detail.
692
+ return_dict (`bool`, *optional*):
693
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
694
+ """
695
+
696
+
697
+ class BartEncoder(BartPretrainedModel):
698
+ """
699
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
700
+ [`BartEncoderLayer`].
701
+
702
+ Args:
703
+ config: BartConfig
704
+ embed_tokens (nn.Embedding): output embedding
705
+ """
706
+
707
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
708
+ super().__init__(config)
709
+
710
+ self.dropout = config.dropout
711
+ self.layerdrop = config.encoder_layerdrop
712
+
713
+ embed_dim = config.d_model
714
+ self.padding_idx = config.pad_token_id
715
+ self.max_source_positions = config.max_position_embeddings
716
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
717
+
718
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
719
+
720
+ if embed_tokens is not None:
721
+ self.embed_tokens.weight = embed_tokens.weight
722
+
723
+ self.embed_positions = BartLearnedPositionalEmbedding(
724
+ config.max_position_embeddings,
725
+ embed_dim,
726
+ )
727
+ self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
728
+ if config.encoder_normalize_embedding:
729
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
730
+ if config.add_final_layer_norm:
731
+ self.layer_norm = nn.LayerNorm(embed_dim)
732
+ self.gradient_checkpointing = False
733
+ # Initialize weights and apply final processing
734
+ self.post_init()
735
+
736
+ def get_input_embeddings(self):
737
+ return self.embed_tokens
738
+
739
+ def set_input_embeddings(self, value):
740
+ self.embed_tokens = value
741
+
742
+ def forward(
743
+ self,
744
+ input_ids: torch.LongTensor = None,
745
+ attention_mask: Optional[torch.Tensor] = None,
746
+ head_mask: Optional[torch.Tensor] = None,
747
+ inputs_embeds: Optional[torch.FloatTensor] = None,
748
+ output_attentions: Optional[bool] = None,
749
+ output_hidden_states: Optional[bool] = None,
750
+ return_dict: Optional[bool] = None,
751
+ ) -> Union[Tuple, BaseModelOutput]:
752
+ r"""
753
+ Args:
754
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
755
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
756
+ provide it.
757
+
758
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
759
+ [`PreTrainedTokenizer.__call__`] for details.
760
+
761
+ [What are input IDs?](../glossary#input-ids)
762
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
763
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
764
+
765
+ - 1 for tokens that are **not masked**,
766
+ - 0 for tokens that are **masked**.
767
+
768
+ [What are attention masks?](../glossary#attention-mask)
769
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
770
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
771
+
772
+ - 1 indicates the head is **not masked**,
773
+ - 0 indicates the head is **masked**.
774
+
775
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
776
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
777
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
778
+ than the model's internal embedding lookup matrix.
779
+ output_attentions (`bool`, *optional*):
780
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
781
+ returned tensors for more detail.
782
+ output_hidden_states (`bool`, *optional*):
783
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
784
+ for more detail.
785
+ return_dict (`bool`, *optional*):
786
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
787
+ """
788
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
789
+ output_hidden_states = (
790
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
791
+ )
792
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
793
+
794
+ # retrieve input_ids and inputs_embeds
795
+ if input_ids is not None and inputs_embeds is not None:
796
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
797
+ elif input_ids is not None:
798
+ input = input_ids
799
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
800
+ elif inputs_embeds is not None:
801
+ input = inputs_embeds[:, :, -1]
802
+ else:
803
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
804
+
805
+ if inputs_embeds is None:
806
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
807
+
808
+ embed_pos = self.embed_positions(input)
809
+ embed_pos = embed_pos.to(inputs_embeds.device)
810
+
811
+ hidden_states = inputs_embeds + embed_pos
812
+ if self.config.encoder_normalize_embedding:
813
+ hidden_states = self.layernorm_embedding(hidden_states)
814
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
815
+
816
+ # expand attention_mask
817
+ if attention_mask is not None:
818
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
819
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
820
+
821
+ encoder_states = () if output_hidden_states else None
822
+ all_attentions = () if output_attentions else None
823
+
824
+ # check if head_mask has a correct number of layers specified if desired
825
+ if head_mask is not None:
826
+ if head_mask.size()[0] != (len(self.layers)):
827
+ raise ValueError(
828
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
829
+ f" {head_mask.size()[0]}."
830
+ )
831
+
832
+ for idx, encoder_layer in enumerate(self.layers):
833
+ if output_hidden_states:
834
+ encoder_states = encoder_states + (hidden_states,)
835
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
836
+ dropout_probability = random.uniform(0, 1)
837
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
838
+ layer_outputs = (None, None)
839
+ else:
840
+ if self.gradient_checkpointing and self.training:
841
+
842
+ def create_custom_forward(module):
843
+ def custom_forward(*inputs):
844
+ return module(*inputs, output_attentions)
845
+
846
+ return custom_forward
847
+
848
+ layer_outputs = torch.utils.checkpoint.checkpoint(
849
+ create_custom_forward(encoder_layer),
850
+ hidden_states,
851
+ attention_mask,
852
+ (head_mask[idx] if head_mask is not None else None),
853
+ )
854
+ else:
855
+ layer_outputs = encoder_layer(
856
+ hidden_states,
857
+ attention_mask,
858
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
859
+ output_attentions=output_attentions,
860
+ )
861
+
862
+ hidden_states = layer_outputs[0]
863
+
864
+ if output_attentions:
865
+ all_attentions = all_attentions + (layer_outputs[1],)
866
+
867
+ if self.config.add_final_layer_norm:
868
+ hidden_states = self.layer_norm(hidden_states)
869
+
870
+ if output_hidden_states:
871
+ encoder_states = encoder_states + (hidden_states,)
872
+
873
+ if not return_dict:
874
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
875
+ return BaseModelOutput(
876
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
877
+ )
878
+
879
+
880
+ class BartDecoder(BartPretrainedModel):
881
+ """
882
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
883
+
884
+ Args:
885
+ config: BartConfig
886
+ embed_tokens (nn.Embedding): output embedding
887
+ """
888
+
889
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
890
+ super().__init__(config)
891
+ self.dropout = config.dropout
892
+ self.layerdrop = config.decoder_layerdrop
893
+ self.padding_idx = config.pad_token_id
894
+ self.max_target_positions = config.max_position_embeddings
895
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
896
+
897
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
898
+
899
+ if embed_tokens is not None:
900
+ self.embed_tokens.weight = embed_tokens.weight
901
+
902
+ self.embed_positions = BartLearnedPositionalEmbedding(
903
+ config.max_position_embeddings,
904
+ config.d_model,
905
+ )
906
+ self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
907
+ if config.decoder_normalize_embedding:
908
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
909
+ if config.add_final_layer_norm:
910
+ self.layer_norm = nn.LayerNorm(config.d_model)
911
+ self.gradient_checkpointing = False
912
+ # Initialize weights and apply final processing
913
+ self.post_init()
914
+
915
+ def get_input_embeddings(self):
916
+ return self.embed_tokens
917
+
918
+ def set_input_embeddings(self, value):
919
+ self.embed_tokens = value
920
+
921
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
922
+ # create causal mask
923
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
924
+ combined_attention_mask = None
925
+ if input_shape[-1] > 1:
926
+ combined_attention_mask = _make_causal_mask(
927
+ input_shape,
928
+ inputs_embeds.dtype,
929
+ device=inputs_embeds.device,
930
+ past_key_values_length=past_key_values_length,
931
+ )
932
+
933
+ if attention_mask is not None:
934
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
935
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
936
+ inputs_embeds.device
937
+ )
938
+ combined_attention_mask = (
939
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
940
+ )
941
+
942
+ return combined_attention_mask
943
+
944
+ def forward(
945
+ self,
946
+ input_ids: torch.LongTensor = None,
947
+ attention_mask: Optional[torch.Tensor] = None,
948
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
949
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
950
+ head_mask: Optional[torch.Tensor] = None,
951
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
952
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
953
+ inputs_embeds: Optional[torch.FloatTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ return_dict: Optional[bool] = None,
958
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
959
+ r"""
960
+ Args:
961
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
962
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
963
+ provide it.
964
+
965
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
966
+ [`PreTrainedTokenizer.__call__`] for details.
967
+
968
+ [What are input IDs?](../glossary#input-ids)
969
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
970
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
971
+
972
+ - 1 for tokens that are **not masked**,
973
+ - 0 for tokens that are **masked**.
974
+
975
+ [What are attention masks?](../glossary#attention-mask)
976
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
977
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
978
+ of the decoder.
979
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
980
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
981
+ selected in `[0, 1]`:
982
+
983
+ - 1 for tokens that are **not masked**,
984
+ - 0 for tokens that are **masked**.
985
+
986
+ [What are attention masks?](../glossary#attention-mask)
987
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
988
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
989
+
990
+ - 1 indicates the head is **not masked**,
991
+ - 0 indicates the head is **masked**.
992
+
993
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
994
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
995
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
996
+
997
+ - 1 indicates the head is **not masked**,
998
+ - 0 indicates the head is **masked**.
999
+
1000
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1001
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1002
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1003
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1004
+
1005
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1006
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1007
+
1008
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1009
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1010
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
1011
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
1012
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
1013
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
1014
+ embedding lookup matrix.
1015
+ output_attentions (`bool`, *optional*):
1016
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1017
+ returned tensors for more detail.
1018
+ output_hidden_states (`bool`, *optional*):
1019
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1020
+ for more detail.
1021
+ return_dict (`bool`, *optional*):
1022
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1023
+ """
1024
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1025
+ output_hidden_states = (
1026
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1027
+ )
1028
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1029
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1030
+
1031
+ # retrieve input_ids and inputs_embeds
1032
+ if input_ids is not None and inputs_embeds is not None:
1033
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1034
+ elif input_ids is not None:
1035
+ input = input_ids
1036
+ input_shape = input.shape
1037
+ input_ids = input_ids.view(-1, input_shape[-1])
1038
+ elif inputs_embeds is not None:
1039
+ input_shape = inputs_embeds.size()[:-1]
1040
+ input = inputs_embeds[:, :, -1]
1041
+ else:
1042
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1043
+
1044
+ # past_key_values_length
1045
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1046
+
1047
+ if inputs_embeds is None:
1048
+ inputs_embeds = self.embed_tokens(input) * self.embed_scale
1049
+
1050
+ attention_mask = self._prepare_decoder_attention_mask(
1051
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1052
+ )
1053
+
1054
+ # expand encoder attention mask
1055
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1056
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1057
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1058
+
1059
+ # embed positions
1060
+ positions = self.embed_positions(input, past_key_values_length)
1061
+ positions = positions.to(inputs_embeds.device)
1062
+
1063
+ hidden_states = inputs_embeds + positions
1064
+ if self.config.decoder_normalize_embedding:
1065
+ hidden_states = self.layernorm_embedding(hidden_states)
1066
+
1067
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1068
+
1069
+ if self.gradient_checkpointing and self.training:
1070
+ if use_cache:
1071
+ logger.warning_once(
1072
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1073
+ )
1074
+ use_cache = False
1075
+
1076
+ # decoder layers
1077
+ all_hidden_states = () if output_hidden_states else None
1078
+ all_self_attns = () if output_attentions else None
1079
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1080
+ next_decoder_cache = () if use_cache else None
1081
+
1082
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1083
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1084
+ if attn_mask is not None:
1085
+ if attn_mask.size()[0] != (len(self.layers)):
1086
+ raise ValueError(
1087
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1088
+ f" {head_mask.size()[0]}."
1089
+ )
1090
+
1091
+ for idx, decoder_layer in enumerate(self.layers):
1092
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+ dropout_probability = random.uniform(0, 1)
1096
+ if self.training and (dropout_probability < self.layerdrop):
1097
+ continue
1098
+
1099
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1100
+
1101
+ if self.gradient_checkpointing and self.training:
1102
+
1103
+ def create_custom_forward(module):
1104
+ def custom_forward(*inputs):
1105
+ # None for past_key_value
1106
+ return module(*inputs, output_attentions, use_cache)
1107
+
1108
+ return custom_forward
1109
+
1110
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1111
+ create_custom_forward(decoder_layer),
1112
+ hidden_states,
1113
+ attention_mask,
1114
+ encoder_hidden_states,
1115
+ encoder_attention_mask,
1116
+ head_mask[idx] if head_mask is not None else None,
1117
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1118
+ None,
1119
+ )
1120
+ else:
1121
+ layer_outputs = decoder_layer(
1122
+ hidden_states,
1123
+ attention_mask=attention_mask,
1124
+ encoder_hidden_states=encoder_hidden_states,
1125
+ encoder_attention_mask=encoder_attention_mask,
1126
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1127
+ cross_attn_layer_head_mask=(
1128
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1129
+ ),
1130
+ past_key_value=past_key_value,
1131
+ output_attentions=output_attentions,
1132
+ use_cache=use_cache,
1133
+ )
1134
+ hidden_states = layer_outputs[0]
1135
+
1136
+ if use_cache:
1137
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1138
+
1139
+ if output_attentions:
1140
+ all_self_attns += (layer_outputs[1],)
1141
+
1142
+ if encoder_hidden_states is not None:
1143
+ all_cross_attentions += (layer_outputs[2],)
1144
+ if self.config.add_final_layer_norm:
1145
+ hidden_states = self.layer_norm(hidden_states)
1146
+ # add hidden states from the last decoder layer
1147
+ if output_hidden_states:
1148
+ all_hidden_states += (hidden_states,)
1149
+
1150
+ next_cache = next_decoder_cache if use_cache else None
1151
+ if not return_dict:
1152
+ return tuple(
1153
+ v
1154
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1155
+ if v is not None
1156
+ )
1157
+ return BaseModelOutputWithPastAndCrossAttentions(
1158
+ last_hidden_state=hidden_states,
1159
+ past_key_values=next_cache,
1160
+ hidden_states=all_hidden_states,
1161
+ attentions=all_self_attns,
1162
+ cross_attentions=all_cross_attentions,
1163
+ )
1164
+
1165
+
1166
+ @add_start_docstrings(
1167
+ "The bare BART Model outputting raw hidden-states without any specific head on top.",
1168
+ BART_START_DOCSTRING,
1169
+ )
1170
+ class BartModel(BartPretrainedModel):
1171
+ _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1172
+
1173
+ def __init__(self, config: BartConfig):
1174
+ super().__init__(config)
1175
+
1176
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1177
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1178
+ if self.config.share_encoder_decoder_embeddings:
1179
+ encoder_embed_tokens = decoder_embed_tokens = self.shared
1180
+ else:
1181
+ # Since the embeddings are not shared, deepcopy the embeddings here for encoder
1182
+ # and decoder to make sure they are not tied.
1183
+ encoder_embed_tokens = copy.deepcopy(self.shared)
1184
+ decoder_embed_tokens = copy.deepcopy(self.shared)
1185
+ self.shared = None
1186
+ self.encoder = BartEncoder(config, encoder_embed_tokens)
1187
+ self.decoder = BartDecoder(config, decoder_embed_tokens)
1188
+
1189
+ # Initialize weights and apply final processing
1190
+ self.post_init()
1191
+
1192
+ def get_input_embeddings(self):
1193
+ return self.shared
1194
+
1195
+ def set_input_embeddings(self, value):
1196
+ self.shared = value
1197
+ self.encoder.embed_tokens = self.shared
1198
+ self.decoder.embed_tokens = self.shared
1199
+
1200
+ def get_encoder(self):
1201
+ return self.encoder
1202
+
1203
+ def get_decoder(self):
1204
+ return self.decoder
1205
+
1206
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1207
+ @add_code_sample_docstrings(
1208
+ checkpoint=_CHECKPOINT_FOR_DOC,
1209
+ output_type=Seq2SeqModelOutput,
1210
+ config_class=_CONFIG_FOR_DOC,
1211
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1212
+ )
1213
+ def forward(
1214
+ self,
1215
+ input_ids: torch.LongTensor = None,
1216
+ attention_mask: Optional[torch.Tensor] = None,
1217
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1218
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1219
+ head_mask: Optional[torch.Tensor] = None,
1220
+ decoder_head_mask: Optional[torch.Tensor] = None,
1221
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1222
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1223
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1224
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1225
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1226
+ use_cache: Optional[bool] = None,
1227
+ output_attentions: Optional[bool] = None,
1228
+ output_hidden_states: Optional[bool] = None,
1229
+ return_dict: Optional[bool] = None,
1230
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
1231
+ # different to other models, Bart automatically creates decoder_input_ids from
1232
+ # input_ids if no decoder_input_ids are provided
1233
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1234
+ if input_ids is None:
1235
+ raise ValueError(
1236
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1237
+ "passed, `input_ids` cannot be `None`. Please pass either "
1238
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1239
+ )
1240
+
1241
+ decoder_input_ids = shift_tokens_right(
1242
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1243
+ )
1244
+
1245
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1246
+ output_hidden_states = (
1247
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1248
+ )
1249
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1250
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1251
+
1252
+ if encoder_outputs is None:
1253
+ encoder_outputs = self.encoder(
1254
+ input_ids=input_ids,
1255
+ attention_mask=attention_mask,
1256
+ head_mask=head_mask,
1257
+ inputs_embeds=inputs_embeds,
1258
+ output_attentions=output_attentions,
1259
+ output_hidden_states=output_hidden_states,
1260
+ return_dict=return_dict,
1261
+ )
1262
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1263
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1264
+ encoder_outputs = BaseModelOutput(
1265
+ last_hidden_state=encoder_outputs[0],
1266
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1267
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1268
+ )
1269
+
1270
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1271
+ decoder_outputs = self.decoder(
1272
+ input_ids=decoder_input_ids,
1273
+ attention_mask=decoder_attention_mask,
1274
+ encoder_hidden_states=encoder_outputs[0],
1275
+ encoder_attention_mask=attention_mask,
1276
+ head_mask=decoder_head_mask,
1277
+ cross_attn_head_mask=cross_attn_head_mask,
1278
+ past_key_values=past_key_values,
1279
+ inputs_embeds=decoder_inputs_embeds,
1280
+ use_cache=use_cache,
1281
+ output_attentions=output_attentions,
1282
+ output_hidden_states=output_hidden_states,
1283
+ return_dict=return_dict,
1284
+ )
1285
+
1286
+ if not return_dict:
1287
+ return decoder_outputs + encoder_outputs
1288
+
1289
+ return Seq2SeqModelOutput(
1290
+ last_hidden_state=decoder_outputs.last_hidden_state,
1291
+ past_key_values=decoder_outputs.past_key_values,
1292
+ decoder_hidden_states=decoder_outputs.hidden_states,
1293
+ decoder_attentions=decoder_outputs.attentions,
1294
+ cross_attentions=decoder_outputs.cross_attentions,
1295
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1296
+ encoder_hidden_states=encoder_outputs.hidden_states,
1297
+ encoder_attentions=encoder_outputs.attentions,
1298
+ )
1299
+
1300
+
1301
+ @add_start_docstrings(
1302
+ "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
1303
+ )
1304
+ class BartForConditionalGeneration(BartPretrainedModel):
1305
+ base_model_prefix = "model"
1306
+ _keys_to_ignore_on_load_missing = [
1307
+ r"final_logits_bias",
1308
+ r"lm_head.weight",
1309
+ "encoder.embed_tokens.weight",
1310
+ "decoder.embed_tokens.weight",
1311
+ ]
1312
+
1313
+ def __init__(self, config: BartConfig):
1314
+ super().__init__(config)
1315
+ self.model = BartModel(config)
1316
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1317
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1318
+
1319
+ # Initialize weights and apply final processing
1320
+ self.post_init()
1321
+
1322
+ def get_encoder(self):
1323
+ return self.model.get_encoder()
1324
+
1325
+ def get_decoder(self):
1326
+ return self.model.get_decoder()
1327
+
1328
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1329
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
1330
+ self._resize_final_logits_bias(new_num_tokens)
1331
+ return new_embeddings
1332
+
1333
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1334
+ old_num_tokens = self.final_logits_bias.shape[-1]
1335
+ if new_num_tokens <= old_num_tokens:
1336
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
1337
+ else:
1338
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1339
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1340
+ self.register_buffer("final_logits_bias", new_bias)
1341
+
1342
+ def get_output_embeddings(self):
1343
+ return self.lm_head
1344
+
1345
+ def set_output_embeddings(self, new_embeddings):
1346
+ self.lm_head = new_embeddings
1347
+
1348
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1349
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1350
+ @add_end_docstrings(BART_GENERATION_EXAMPLE)
1351
+ def forward(
1352
+ self,
1353
+ input_ids: torch.LongTensor = None,
1354
+ attention_mask: Optional[torch.Tensor] = None,
1355
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1356
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1357
+ head_mask: Optional[torch.Tensor] = None,
1358
+ decoder_head_mask: Optional[torch.Tensor] = None,
1359
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1360
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1361
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1362
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1363
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1364
+ labels: Optional[torch.LongTensor] = None,
1365
+ use_cache: Optional[bool] = None,
1366
+ output_attentions: Optional[bool] = None,
1367
+ output_hidden_states: Optional[bool] = None,
1368
+ return_dict: Optional[bool] = None,
1369
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
1370
+ r"""
1371
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1372
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1373
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1374
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1375
+
1376
+ Returns:
1377
+ """
1378
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1379
+
1380
+ if labels is not None:
1381
+ if use_cache:
1382
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
1383
+ use_cache = False
1384
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
1385
+ decoder_input_ids = shift_tokens_right(
1386
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1387
+ )
1388
+
1389
+ outputs = self.model(
1390
+ input_ids,
1391
+ attention_mask=attention_mask,
1392
+ decoder_input_ids=decoder_input_ids,
1393
+ encoder_outputs=encoder_outputs,
1394
+ decoder_attention_mask=decoder_attention_mask,
1395
+ head_mask=head_mask,
1396
+ decoder_head_mask=decoder_head_mask,
1397
+ cross_attn_head_mask=cross_attn_head_mask,
1398
+ past_key_values=past_key_values,
1399
+ inputs_embeds=inputs_embeds,
1400
+ decoder_inputs_embeds=decoder_inputs_embeds,
1401
+ use_cache=use_cache,
1402
+ output_attentions=output_attentions,
1403
+ output_hidden_states=output_hidden_states,
1404
+ return_dict=return_dict,
1405
+ )
1406
+
1407
+ lm_logits = self.lm_head(outputs[0])
1408
+ lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
1409
+
1410
+ masked_lm_loss = None
1411
+ if labels is not None:
1412
+ labels = labels.to(lm_logits.device)
1413
+ loss_fct = CrossEntropyLoss()
1414
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1415
+
1416
+ if not return_dict:
1417
+ output = (lm_logits,) + outputs[1:]
1418
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1419
+
1420
+ return Seq2SeqLMOutput(
1421
+ loss=masked_lm_loss,
1422
+ logits=lm_logits,
1423
+ past_key_values=outputs.past_key_values,
1424
+ decoder_hidden_states=outputs.decoder_hidden_states,
1425
+ decoder_attentions=outputs.decoder_attentions,
1426
+ cross_attentions=outputs.cross_attentions,
1427
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1428
+ encoder_hidden_states=outputs.encoder_hidden_states,
1429
+ encoder_attentions=outputs.encoder_attentions,
1430
+ )
1431
+
1432
+ def prepare_inputs_for_generation(
1433
+ self,
1434
+ decoder_input_ids,
1435
+ past_key_values=None,
1436
+ attention_mask=None,
1437
+ decoder_attention_mask=None,
1438
+ head_mask=None,
1439
+ decoder_head_mask=None,
1440
+ cross_attn_head_mask=None,
1441
+ use_cache=None,
1442
+ encoder_outputs=None,
1443
+ **kwargs,
1444
+ ):
1445
+ # cut decoder_input_ids if past_key_values is used
1446
+ if past_key_values is not None:
1447
+ decoder_input_ids = decoder_input_ids[:, -1:]
1448
+
1449
+ return {
1450
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1451
+ "encoder_outputs": encoder_outputs,
1452
+ "past_key_values": past_key_values,
1453
+ "decoder_input_ids": decoder_input_ids,
1454
+ "attention_mask": attention_mask,
1455
+ "decoder_attention_mask": decoder_attention_mask,
1456
+ "head_mask": head_mask,
1457
+ "decoder_head_mask": decoder_head_mask,
1458
+ "cross_attn_head_mask": cross_attn_head_mask,
1459
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1460
+ }
1461
+
1462
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1463
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1464
+
1465
+ @staticmethod
1466
+ def _reorder_cache(past_key_values, beam_idx):
1467
+ reordered_past = ()
1468
+ for layer_past in past_key_values:
1469
+ # cached cross_attention states don't have to be reordered -> they are always the same
1470
+ reordered_past += (
1471
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1472
+ )
1473
+ return reordered_past
1474
+
1475
+
1476
+ @add_start_docstrings(
1477
+ """
1478
+ Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
1479
+ tasks.
1480
+ """,
1481
+ BART_START_DOCSTRING,
1482
+ )
1483
+ class BartForSequenceClassification(BartPretrainedModel):
1484
+ _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1485
+
1486
+ def __init__(self, config: BartConfig, **kwargs):
1487
+ super().__init__(config, **kwargs)
1488
+ self.model = BartModel(config)
1489
+ self.classification_head = BartClassificationHead(
1490
+ config.d_model,
1491
+ config.d_model,
1492
+ config.num_labels,
1493
+ config.classifier_dropout,
1494
+ )
1495
+
1496
+ # Initialize weights and apply final processing
1497
+ self.post_init()
1498
+
1499
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1500
+ @add_code_sample_docstrings(
1501
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1502
+ output_type=Seq2SeqSequenceClassifierOutput,
1503
+ config_class=_CONFIG_FOR_DOC,
1504
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
1505
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1506
+ )
1507
+ def forward(
1508
+ self,
1509
+ input_ids: torch.LongTensor = None,
1510
+ attention_mask: Optional[torch.Tensor] = None,
1511
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1512
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1513
+ head_mask: Optional[torch.Tensor] = None,
1514
+ decoder_head_mask: Optional[torch.Tensor] = None,
1515
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1516
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1517
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1518
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1519
+ labels: Optional[torch.LongTensor] = None,
1520
+ use_cache: Optional[bool] = None,
1521
+ output_attentions: Optional[bool] = None,
1522
+ output_hidden_states: Optional[bool] = None,
1523
+ return_dict: Optional[bool] = None,
1524
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
1525
+ r"""
1526
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1527
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1528
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1529
+ """
1530
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1531
+ if labels is not None:
1532
+ use_cache = False
1533
+
1534
+ if input_ids is None and inputs_embeds is not None:
1535
+ raise NotImplementedError(
1536
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1537
+ )
1538
+
1539
+ outputs = self.model(
1540
+ input_ids,
1541
+ attention_mask=attention_mask,
1542
+ decoder_input_ids=decoder_input_ids,
1543
+ decoder_attention_mask=decoder_attention_mask,
1544
+ head_mask=head_mask,
1545
+ decoder_head_mask=decoder_head_mask,
1546
+ cross_attn_head_mask=cross_attn_head_mask,
1547
+ encoder_outputs=encoder_outputs,
1548
+ inputs_embeds=inputs_embeds,
1549
+ decoder_inputs_embeds=decoder_inputs_embeds,
1550
+ use_cache=use_cache,
1551
+ output_attentions=output_attentions,
1552
+ output_hidden_states=output_hidden_states,
1553
+ return_dict=return_dict,
1554
+ )
1555
+ hidden_states = outputs[0] # last hidden state
1556
+
1557
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
1558
+
1559
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1560
+ raise ValueError("All examples must have the same number of <eos> tokens.")
1561
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1562
+ :, -1, :
1563
+ ]
1564
+ logits = self.classification_head(sentence_representation)
1565
+
1566
+ loss = None
1567
+ if labels is not None:
1568
+ labels = labels.to(logits.device)
1569
+ if self.config.problem_type is None:
1570
+ if self.config.num_labels == 1:
1571
+ self.config.problem_type = "regression"
1572
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1573
+ self.config.problem_type = "single_label_classification"
1574
+ else:
1575
+ self.config.problem_type = "multi_label_classification"
1576
+
1577
+ if self.config.problem_type == "regression":
1578
+ loss_fct = MSELoss()
1579
+ if self.config.num_labels == 1:
1580
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1581
+ else:
1582
+ loss = loss_fct(logits, labels)
1583
+ elif self.config.problem_type == "single_label_classification":
1584
+ loss_fct = CrossEntropyLoss()
1585
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1586
+ elif self.config.problem_type == "multi_label_classification":
1587
+ loss_fct = BCEWithLogitsLoss()
1588
+ loss = loss_fct(logits, labels)
1589
+ if not return_dict:
1590
+ output = (logits,) + outputs[1:]
1591
+ return ((loss,) + output) if loss is not None else output
1592
+
1593
+ return Seq2SeqSequenceClassifierOutput(
1594
+ loss=loss,
1595
+ logits=logits,
1596
+ past_key_values=outputs.past_key_values,
1597
+ decoder_hidden_states=outputs.decoder_hidden_states,
1598
+ decoder_attentions=outputs.decoder_attentions,
1599
+ cross_attentions=outputs.cross_attentions,
1600
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1601
+ encoder_hidden_states=outputs.encoder_hidden_states,
1602
+ encoder_attentions=outputs.encoder_attentions,
1603
+ )
1604
+
1605
+
1606
+ @add_start_docstrings(
1607
+ """
1608
+ BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1609
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1610
+ """,
1611
+ BART_START_DOCSTRING,
1612
+ )
1613
+ class BartForQuestionAnswering(BartPretrainedModel):
1614
+ _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1615
+
1616
+ def __init__(self, config):
1617
+ super().__init__(config)
1618
+
1619
+ config.num_labels = 2
1620
+ self.num_labels = config.num_labels
1621
+
1622
+ self.model = BartModel(config)
1623
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1624
+
1625
+ # Initialize weights and apply final processing
1626
+ self.post_init()
1627
+
1628
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
1629
+ @add_code_sample_docstrings(
1630
+ checkpoint=_CHECKPOINT_FOR_QA,
1631
+ output_type=Seq2SeqQuestionAnsweringModelOutput,
1632
+ config_class=_CONFIG_FOR_DOC,
1633
+ expected_loss=_QA_EXPECTED_LOSS,
1634
+ expected_output=_QA_EXPECTED_OUTPUT,
1635
+ )
1636
+ def forward(
1637
+ self,
1638
+ input_ids: torch.Tensor = None,
1639
+ attention_mask: Optional[torch.Tensor] = None,
1640
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1641
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1642
+ head_mask: Optional[torch.Tensor] = None,
1643
+ decoder_head_mask: Optional[torch.Tensor] = None,
1644
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1645
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1646
+ start_positions: Optional[torch.LongTensor] = None,
1647
+ end_positions: Optional[torch.LongTensor] = None,
1648
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1649
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1650
+ use_cache: Optional[bool] = None,
1651
+ output_attentions: Optional[bool] = None,
1652
+ output_hidden_states: Optional[bool] = None,
1653
+ return_dict: Optional[bool] = None,
1654
+ ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
1655
+ r"""
1656
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1657
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1658
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
1659
+ are not taken into account for computing the loss.
1660
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1661
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1662
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
1663
+ are not taken into account for computing the loss.
1664
+ """
1665
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1666
+ if start_positions is not None and end_positions is not None:
1667
+ use_cache = False
1668
+
1669
+ outputs = self.model(
1670
+ input_ids,
1671
+ attention_mask=attention_mask,
1672
+ decoder_input_ids=decoder_input_ids,
1673
+ decoder_attention_mask=decoder_attention_mask,
1674
+ head_mask=head_mask,
1675
+ decoder_head_mask=decoder_head_mask,
1676
+ cross_attn_head_mask=cross_attn_head_mask,
1677
+ encoder_outputs=encoder_outputs,
1678
+ inputs_embeds=inputs_embeds,
1679
+ decoder_inputs_embeds=decoder_inputs_embeds,
1680
+ use_cache=use_cache,
1681
+ output_attentions=output_attentions,
1682
+ output_hidden_states=output_hidden_states,
1683
+ return_dict=return_dict,
1684
+ )
1685
+
1686
+ sequence_output = outputs[0]
1687
+
1688
+ logits = self.qa_outputs(sequence_output)
1689
+ start_logits, end_logits = logits.split(1, dim=-1)
1690
+ start_logits = start_logits.squeeze(-1).contiguous()
1691
+ end_logits = end_logits.squeeze(-1).contiguous()
1692
+
1693
+ total_loss = None
1694
+ if start_positions is not None and end_positions is not None:
1695
+ # If we are on multi-GPU, split add a dimension
1696
+ if len(start_positions.size()) > 1:
1697
+ start_positions = start_positions.squeeze(-1)
1698
+ if len(end_positions.size()) > 1:
1699
+ end_positions = end_positions.squeeze(-1)
1700
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1701
+ ignored_index = start_logits.size(1)
1702
+ start_positions = start_positions.clamp(0, ignored_index)
1703
+ end_positions = end_positions.clamp(0, ignored_index)
1704
+
1705
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1706
+ start_loss = loss_fct(start_logits, start_positions)
1707
+ end_loss = loss_fct(end_logits, end_positions)
1708
+ total_loss = (start_loss + end_loss) / 2
1709
+
1710
+ if not return_dict:
1711
+ output = (
1712
+ start_logits,
1713
+ end_logits,
1714
+ ) + outputs[1:]
1715
+ return ((total_loss,) + output) if total_loss is not None else output
1716
+
1717
+ return Seq2SeqQuestionAnsweringModelOutput(
1718
+ loss=total_loss,
1719
+ start_logits=start_logits,
1720
+ end_logits=end_logits,
1721
+ past_key_values=outputs.past_key_values,
1722
+ decoder_hidden_states=outputs.decoder_hidden_states,
1723
+ decoder_attentions=outputs.decoder_attentions,
1724
+ cross_attentions=outputs.cross_attentions,
1725
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1726
+ encoder_hidden_states=outputs.encoder_hidden_states,
1727
+ encoder_attentions=outputs.encoder_attentions,
1728
+ )
1729
+
1730
+
1731
+ class BartDecoderWrapper(BartPretrainedModel):
1732
+ """
1733
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1734
+ used in combination with the [`EncoderDecoderModel`] framework.
1735
+ """
1736
+
1737
+ def __init__(self, config):
1738
+ super().__init__(config)
1739
+ self.decoder = BartDecoder(config)
1740
+
1741
+ def forward(self, *args, **kwargs):
1742
+ return self.decoder(*args, **kwargs)
1743
+
1744
+
1745
+ @add_start_docstrings(
1746
+ """
1747
+ BART decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings).
1748
+ """,
1749
+ BART_START_DOCSTRING,
1750
+ )
1751
+ class BartForCausalLM(BartPretrainedModel):
1752
+ _keys_to_ignore_on_load_missing = ["lm_head.weight"]
1753
+
1754
+ def __init__(self, config):
1755
+ config = copy.deepcopy(config)
1756
+ config.is_decoder = True
1757
+ config.is_encoder_decoder = False
1758
+ super().__init__(config)
1759
+ self.model = BartDecoderWrapper(config)
1760
+
1761
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1762
+
1763
+ # Initialize weights and apply final processing
1764
+ self.post_init()
1765
+
1766
+ def get_input_embeddings(self):
1767
+ return self.model.decoder.embed_tokens
1768
+
1769
+ def set_input_embeddings(self, value):
1770
+ self.model.decoder.embed_tokens = value
1771
+
1772
+ def get_output_embeddings(self):
1773
+ return self.lm_head
1774
+
1775
+ def set_output_embeddings(self, new_embeddings):
1776
+ self.lm_head = new_embeddings
1777
+
1778
+ def set_decoder(self, decoder):
1779
+ self.model.decoder = decoder
1780
+
1781
+ def get_decoder(self):
1782
+ return self.model.decoder
1783
+
1784
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1785
+ def forward(
1786
+ self,
1787
+ input_ids: torch.LongTensor = None,
1788
+ attention_mask: Optional[torch.Tensor] = None,
1789
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1790
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1791
+ head_mask: Optional[torch.Tensor] = None,
1792
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1793
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1794
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1795
+ labels: Optional[torch.LongTensor] = None,
1796
+ use_cache: Optional[bool] = None,
1797
+ output_attentions: Optional[bool] = None,
1798
+ output_hidden_states: Optional[bool] = None,
1799
+ return_dict: Optional[bool] = None,
1800
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1801
+ r"""
1802
+ Args:
1803
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1804
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1805
+ provide it.
1806
+
1807
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1808
+ [`PreTrainedTokenizer.__call__`] for details.
1809
+
1810
+ [What are input IDs?](../glossary#input-ids)
1811
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1812
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1813
+
1814
+ - 1 for tokens that are **not masked**,
1815
+ - 0 for tokens that are **masked**.
1816
+
1817
+ [What are attention masks?](../glossary#attention-mask)
1818
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1819
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1820
+ if the model is configured as a decoder.
1821
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1822
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1823
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1824
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1825
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1826
+
1827
+ - 1 indicates the head is **not masked**,
1828
+ - 0 indicates the head is **masked**.
1829
+
1830
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1831
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
1832
+
1833
+ - 1 indicates the head is **not masked**,
1834
+ - 0 indicates the head is **masked**.
1835
+
1836
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1837
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1838
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1839
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1840
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1841
+
1842
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1843
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1844
+
1845
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1846
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1847
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1848
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1849
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1850
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1851
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1852
+ use_cache (`bool`, *optional*):
1853
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1854
+ (see `past_key_values`).
1855
+
1856
+ - 1 for tokens that are **not masked**,
1857
+ - 0 for tokens that are **masked**.
1858
+ output_attentions (`bool`, *optional*):
1859
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1860
+ returned tensors for more detail.
1861
+ output_hidden_states (`bool`, *optional*):
1862
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1863
+ for more detail.
1864
+ return_dict (`bool`, *optional*):
1865
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1866
+
1867
+ Returns:
1868
+
1869
+ Example:
1870
+
1871
+ ```python
1872
+ >>> from transformers import AutoTokenizer, BartForCausalLM
1873
+
1874
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
1875
+ >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
1876
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
1877
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1878
+ >>> outputs = model(**inputs)
1879
+
1880
+ >>> logits = outputs.logits
1881
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
1882
+ >>> list(logits.shape) == expected_shape
1883
+ True
1884
+ ```"""
1885
+
1886
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1887
+ output_hidden_states = (
1888
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1889
+ )
1890
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1891
+
1892
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1893
+ outputs = self.model.decoder(
1894
+ input_ids=input_ids,
1895
+ attention_mask=attention_mask,
1896
+ encoder_hidden_states=encoder_hidden_states,
1897
+ encoder_attention_mask=encoder_attention_mask,
1898
+ head_mask=head_mask,
1899
+ cross_attn_head_mask=cross_attn_head_mask,
1900
+ past_key_values=past_key_values,
1901
+ inputs_embeds=inputs_embeds,
1902
+ use_cache=use_cache,
1903
+ output_attentions=output_attentions,
1904
+ output_hidden_states=output_hidden_states,
1905
+ return_dict=return_dict,
1906
+ )
1907
+
1908
+ logits = self.lm_head(outputs[0])
1909
+
1910
+ loss = None
1911
+ if labels is not None:
1912
+ labels = labels.to(logits.device)
1913
+ loss_fct = CrossEntropyLoss()
1914
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1915
+
1916
+ if not return_dict:
1917
+ output = (logits,) + outputs[1:]
1918
+ return (loss,) + output if loss is not None else output
1919
+
1920
+ return CausalLMOutputWithCrossAttentions(
1921
+ loss=loss,
1922
+ logits=logits,
1923
+ past_key_values=outputs.past_key_values,
1924
+ hidden_states=outputs.hidden_states,
1925
+ attentions=outputs.attentions,
1926
+ cross_attentions=outputs.cross_attentions,
1927
+ )
1928
+
1929
+ def prepare_inputs_for_generation(
1930
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
1931
+ ):
1932
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1933
+ if attention_mask is None:
1934
+ attention_mask = input_ids.new_ones(input_ids.shape)
1935
+
1936
+ if past_key_values:
1937
+ input_ids = input_ids[:, -1:]
1938
+ # first step, decoder_cached_states are empty
1939
+ return {
1940
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
1941
+ "attention_mask": attention_mask,
1942
+ "past_key_values": past_key_values,
1943
+ "use_cache": use_cache,
1944
+ }
1945
+
1946
+ @staticmethod
1947
+ def _reorder_cache(past_key_values, beam_idx):
1948
+ reordered_past = ()
1949
+ for layer_past in past_key_values:
1950
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1951
+ return reordered_past
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63d8c31787dc3bd02602e3490a447a9e4b8dd902bb7053b5970a55e2ac9b55af
3
+ size 973699715
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"additional_special_tokens": ["LANG_TOK_EN", "LANG_TOK_AF", "LANG_TOK_AR", "LANG_TOK_AZ", "LANG_TOK_BE", "LANG_TOK_BG", "LANG_TOK_BN", "LANG_TOK_BS", "LANG_TOK_CS", "LANG_TOK_DA", "LANG_TOK_DE", "LANG_TOK_EL", "LANG_TOK_EO", "LANG_TOK_ES", "LANG_TOK_ET", "LANG_TOK_EU", "LANG_TOK_FA", "LANG_TOK_FI", "LANG_TOK_FR", "LANG_TOK_GL", "LANG_TOK_GU", "LANG_TOK_HE", "LANG_TOK_HI", "LANG_TOK_HR", "LANG_TOK_HU", "LANG_TOK_HY", "LANG_TOK_ID", "LANG_TOK_IT", "LANG_TOK_JA", "LANG_TOK_KA", "LANG_TOK_KK", "LANG_TOK_KO", "LANG_TOK_KU", "LANG_TOK_LT", "LANG_TOK_LV", "LANG_TOK_MK", "LANG_TOK_MN", "LANG_TOK_MR", "LANG_TOK_MS", "LANG_TOK_MT", "LANG_TOK_MY", "LANG_TOK_NB", "LANG_TOK_NL", "LANG_TOK_PL", "LANG_TOK_PT", "LANG_TOK_RO", "LANG_TOK_RU", "LANG_TOK_SK", "LANG_TOK_SL", "LANG_TOK_SQ", "LANG_TOK_SR", "LANG_TOK_SV", "LANG_TOK_TA", "LANG_TOK_TH", "LANG_TOK_TR", "LANG_TOK_UK", "LANG_TOK_UR", "LANG_TOK_VI", "LANG_TOK_ZH"]}
tokenization_bat.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:cb
4
+ @contact:[email protected]
5
+ @time:2023/5/30 14:21
6
+ @filename:tokenization.py
7
+ @software:PyCharm
8
+ @description:
9
+ """
10
+ import re
11
+ from transformers import FSMTTokenizer as fsmt
12
+
13
+
14
+ class FSMTTokenizer(fsmt):
15
+ def __init__(self, *args, **kwargs):
16
+ super(FSMTTokenizer, self).__init__(*args, **kwargs)
17
+ self.space_re = re.compile('\s*(?=[^a-zA-Z0-9 ]+)\s*')
18
+ self.reversal = False
19
+
20
+ def moses_tokenize(self, text, lang):
21
+ if lang not in self.cache_moses_tokenizer:
22
+ moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
23
+ self.cache_moses_tokenizer[lang] = moses_tokenizer
24
+ return self.cache_moses_tokenizer[lang].tokenize(
25
+ text, aggressive_dash_splits=True, return_str=False, escape=False
26
+ )
27
+
28
+ def _switch_to_input_mode(self):
29
+ if self.reversal:
30
+ self.lang_prefix, self.lang_prefix_id = 'zh', 64870
31
+ else:
32
+ self.lang_prefix, self.lang_prefix_id = 'en', 64812
33
+
34
+ def _switch_to_target_mode(self):
35
+ if self.reversal:
36
+ self.lang_prefix, self.lang_prefix_id = 'en', 64812
37
+ else:
38
+ self.lang_prefix, self.lang_prefix_id = 'zh', 64870
39
+
40
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
41
+ """
42
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
43
+ adding special tokens. A FAIRSEQ Transformer sequence has the following format:
44
+
45
+ - single sequence: `<s> X </s>`
46
+ - pair of sequences: `<s> A </s> B </s>`
47
+
48
+ Args:
49
+ token_ids_0 (`List[int]`):
50
+ List of IDs to which the special tokens will be added.
51
+ token_ids_1 (`List[int]`, *optional*):
52
+ Optional second list of IDs for sequence pairs.
53
+
54
+ Returns:
55
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
56
+ """
57
+ sep = [self.sep_token_id]
58
+ token_ids_0 = [self.lang_prefix_id] + token_ids_0
59
+ # no bos used in fairseq
60
+ if token_ids_1 is None:
61
+ return token_ids_0 + sep
62
+ return token_ids_0 + sep + token_ids_1 + sep
63
+
64
+ def moses_pipeline(self, text, lang):
65
+ text = self.moses_punct_norm(text, lang)
66
+ return text
67
+
68
+ def _tokenize(self, text, lang="en", bypass_tokenizer=False):
69
+ """
70
+ 原版FSMTTokenizer会把中文标点英文化,故重写
71
+ :param text:
72
+ :param lang:
73
+ :param bypass_tokenizer:
74
+ :return:
75
+ """
76
+ if self.do_lower_case:
77
+ text = text.lower()
78
+ if bypass_tokenizer:
79
+ text = text.split()
80
+ else:
81
+ text = self.moses_pipeline(text, lang=self.lang_prefix)
82
+ text = self.moses_tokenize(text, lang=self.lang_prefix)
83
+
84
+ split_tokens = []
85
+ for token in text:
86
+ if token:
87
+ split_tokens.extend(list(self.bpe(token).split(" ")))
88
+
89
+ return split_tokens
90
+
91
+ def convert_tokens_to_string(self, tokens):
92
+ """
93
+ 删除非英文字母前后的空格,业务上处理更合适
94
+ :param tokens:
95
+ :return:
96
+ """
97
+ tokens = super(FSMTTokenizer, self).convert_tokens_to_string(tokens)
98
+ tokens = self.space_re.sub('', tokens)
99
+ return tokens
100
+
101
+
102
+ if __name__ == '__main__':
103
+ tokenizer = FSMTTokenizer.from_pretrained(r'./')
104
+ r = tokenizer(['hello'], text_target=['你好朋友'])
105
+ print(r)
106
+ tokenizer.reversal = True
107
+ r = tokenizer(['你好朋友'], text_target=['hello'])
108
+ # # r['input_ids'] += r['labels']
109
+ # # r['labels'] += r['input_ids']
110
+ print(r)
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "langs": [
3
+ "en",
4
+ "zh"
5
+ ],
6
+ "model_max_length": 300,
7
+ "tokenizer_class": "FSMTTokenizer",
8
+ "auto_map": {
9
+ "AutoTokenizer": [
10
+ "tokenization_bat.FSMTTokenizer",
11
+ null
12
+ ]
13
+ }
14
+ }
vocab-src.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab-tgt.json ADDED
The diff for this file is too large to render. See raw diff