Commit f4775d2 (verified) by nhatminh
1 Parent(s): ea8042d

Upload 8 files
Files changed (7)
  1. block.py +470 -0
  2. embedding.py +62 -0
  3. mha.py +662 -0
  4. mlp.py +194 -0
  5. modeling_xlm_roberta.py +1119 -0
  6. stochastic_depth.py +97 -0
  7. xlm_padding.py +218 -0
block.py ADDED
@@ -0,0 +1,470 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
+# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
+
+# Copyright (c) 2024, Tri Dao.
+
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.fx
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from .mha import MHA
+from .mlp import Mlp
+
+try:
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
+except ImportError:
+    layer_norm_fn, RMSNorm = None, None
+
+
+def stochastic_depth(
+    input: Tensor, p: float, mode: str, training: bool = True
+) -> Tensor:
+    """
+    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
+    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
+    branches of residual architectures.
+    Args:
+        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
+            being its batch i.e. a batch with ``N`` rows.
+        p (float): probability of the input to be zeroed.
+        mode (str): ``"batch"`` or ``"row"``.
+            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
+            randomly selected rows from the batch.
+        training: apply stochastic depth if it is ``True``. Default: ``True``
+    Returns:
+        Tensor[N, ...]: The randomly zeroed tensor.
+    """
+    if p < 0.0 or p > 1.0:
+        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
+    if mode not in ["batch", "row"]:
+        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
+    if not training or p == 0.0:
+        return input
+
+    survival_rate = 1.0 - p
+    if mode == "row":
+        size = [input.shape[0]] + [1] * (input.ndim - 1)
+    else:
+        size = [1] * input.ndim
+    noise = torch.empty(size, dtype=input.dtype, device=input.device)
+    noise = noise.bernoulli_(survival_rate)
+    if survival_rate > 0.0:
+        noise.div_(survival_rate)
+    return input * noise
+
+
+torch.fx.wrap("stochastic_depth")
+
+
+class StochasticDepth(nn.Module):
+    """
+    See :func:`stochastic_depth`.
+    """
+
+    def __init__(self, p: float, mode: str) -> None:
+        super().__init__()
+        self.p = p
+        self.mode = mode
+
+    def forward(self, input: Tensor) -> Tensor:
+        return stochastic_depth(input, self.p, self.mode, self.training)
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
+        return s
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        mixer_cls=None,
+        mlp_cls=None,
+        norm_cls=nn.LayerNorm,
+        dropout_cls=nn.Dropout,
+        prenorm=True,
+        resid_dropout1=0.0,
+        resid_dropout2=0.0,
+        drop_path1=0.0,
+        drop_path2=0.0,
+        fused_dropout_add_ln=False,
+        return_residual=False,
+        residual_in_fp32=False,
+        sequence_parallel=False,
+        mark_shared_params=False,
+    ):
+        """
+        For prenorm=True, this Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
+        the hidden_states (output of the MLP) and the residual.
+        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+
+        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
+        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
+
+        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
+        This is for performance reasons: for the post-norm architecture, returning the input allows us
+        to fuse the backward of nn.Linear with the residual connection.
+        """
+        super().__init__()
+        self.prenorm = prenorm
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        self.return_residual = return_residual
+        self.residual_in_fp32 = residual_in_fp32
+        if self.residual_in_fp32:
+            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
+        if mixer_cls is None:
+            mixer_cls = partial(MHA, num_heads=dim // 64)
+        if mlp_cls is None:
+            mlp_cls = partial(Mlp, hidden_features=4 * dim)
+        self.mixer = mixer_cls(dim)
+        self.dropout1 = dropout_cls(resid_dropout1)
+        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
+        self.norm1 = norm_cls(dim)
+        self.mlp = mlp_cls(dim)
+        if not isinstance(self.mlp, nn.Identity):
+            self.dropout2 = dropout_cls(resid_dropout2)
+            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
+            self.norm2 = norm_cls(dim)
+
+        if self.fused_dropout_add_ln:
+            assert layer_norm_fn is not None, "Triton is not installed"
+            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
+                self.dropout1, nn.Dropout
+            )
+
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
+        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
+        if sequence_parallel:
+            for p in self.norm1.parameters():
+                p._sequence_parallel = True
+            if hasattr(self, "norm2"):
+                for p in self.norm2.parameters():
+                    p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, "norm2"):
+                for p in self.norm2.parameters():
+                    p._shared_params = True
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        residual: Optional[Tensor] = None,
+        mixer_subset=None,
+        mixer_kwargs=None,
+    ):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states: the sequence to the encoder layer (required).
+            residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
+        """
+        if self.prenorm:
+            if not self.fused_dropout_add_ln:
+                dropped = self.drop_path1(self.dropout1(hidden_states))
+                residual = (dropped + residual) if residual is not None else dropped
+                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+                if self.residual_in_fp32:
+                    residual = residual.to(torch.float32)
+            else:
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(
+                        torch.ones(
+                            hidden_states.shape[:-1],
+                            device=hidden_states.device,
+                            dtype=hidden_states.dtype,
+                        )
+                    )
+                hidden_states, residual = layer_norm_fn(
+                    hidden_states,
+                    self.norm1.weight,
+                    self.norm1.bias,
+                    residual=residual,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
+                    rowscale=rowscale1,
+                    prenorm=True,
+                    residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
+                )
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            if mixer_subset is not None:
+                mixer_kwargs["mixer_subset"] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
+            if not isinstance(self.mlp, nn.Identity):
+                if not self.fused_dropout_add_ln:
+                    dropped = self.drop_path2(self.dropout2(hidden_states))
+                    residual = (dropped + residual) if residual is not None else dropped
+                    hidden_states = self.norm2(
+                        residual.to(dtype=self.norm2.weight.dtype)
+                    )
+                    if self.residual_in_fp32:
+                        residual = residual.to(torch.float32)
+                else:
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(
+                            torch.ones(
+                                hidden_states.shape[:-1],
+                                device=hidden_states.device,
+                                dtype=hidden_states.dtype,
+                            )
+                        )
+                    hidden_states, residual = layer_norm_fn(
+                        hidden_states,
+                        self.norm2.weight,
+                        self.norm2.bias,
+                        residual=residual,
+                        eps=self.norm2.eps,
+                        dropout_p=self.dropout2.p if self.training else 0.0,
+                        rowscale=rowscale2,
+                        prenorm=True,
+                        residual_in_fp32=self.residual_in_fp32,
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
+                    )
+                hidden_states = self.mlp(hidden_states)
+            return hidden_states, residual
+        else:
+            assert residual is None
+            mixer_out = self.mixer(
+                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
+            )
+            if self.return_residual:  # mixer out is actually a pair here
+                mixer_out, hidden_states = mixer_out
+            if not self.fused_dropout_add_ln:
+                hidden_states = self.norm1(
+                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
+                        dtype=self.norm1.weight.dtype
+                    )
+                )
+            else:
+                if self.drop_path1.p == 0 or not self.training:
+                    rowscale1 = None
+                else:
+                    rowscale1 = self.drop_path1(
+                        torch.ones(
+                            mixer_out.shape[:-1],
+                            device=mixer_out.device,
+                            dtype=mixer_out.dtype,
+                        )
+                    )
+                hidden_states = layer_norm_fn(
+                    mixer_out,
+                    self.norm1.weight,
+                    self.norm1.bias,
+                    residual=hidden_states,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
+                    rowscale=rowscale1,
+                    prenorm=False,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
+                )
+            if not isinstance(self.mlp, nn.Identity):
+                mlp_out = self.mlp(hidden_states)
+                if self.return_residual:  # mlp out is actually a pair here
+                    mlp_out, hidden_states = mlp_out
+                if not self.fused_dropout_add_ln:
+                    hidden_states = self.norm2(
+                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
+                            dtype=self.norm2.weight.dtype
+                        )
+                    )
+                else:
+                    if self.drop_path2.p == 0 or not self.training:
+                        rowscale2 = None
+                    else:
+                        rowscale2 = self.drop_path2(
+                            torch.ones(
+                                mlp_out.shape[:-1],
+                                device=mlp_out.device,
+                                dtype=mlp_out.dtype,
+                            )
+                        )
+                    hidden_states = layer_norm_fn(
+                        mlp_out,
+                        self.norm2.weight,
+                        self.norm2.bias,
+                        residual=hidden_states,
+                        eps=self.norm2.eps,
+                        dropout_p=self.dropout2.p if self.training else 0.0,
+                        rowscale=rowscale2,
+                        prenorm=False,
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
+                    )
+            return hidden_states
+
+
+class ParallelBlock(nn.Module):
+    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
+    and PaLM.
+    """
+
+    def __init__(
+        self,
+        dim,
+        mixer_cls=None,
+        mlp_cls=None,
+        norm_cls=nn.LayerNorm,
+        dropout_cls=nn.Dropout,
+        resid_dropout1=0.0,
+        resid_dropout2=0.0,
+        tied_norm=False,
+        fused_dropout_add_ln=False,
+        residual_in_fp32=False,
+        sequence_parallel=False,
+        mark_shared_params=False,
+    ):
+        """
+        This Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
+        the hidden_states (output1 of the MHA / MLP) and the residual.
+        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+        """
+        super().__init__()
+        self.tied_norm = tied_norm
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        self.residual_in_fp32 = residual_in_fp32
+        if mixer_cls is None:
+            mixer_cls = partial(MHA, num_heads=dim // 64)
+        if mlp_cls is None:
+            mlp_cls = partial(Mlp, hidden_features=4 * dim)
+        self.mixer = mixer_cls(dim)
+        self.dropout1 = dropout_cls(resid_dropout1)
+        self.norm1 = norm_cls(dim)
+        self.mlp = mlp_cls(dim)
+        self.dropout2 = dropout_cls(resid_dropout2)
+        if not self.tied_norm:
+            self.norm2 = norm_cls(dim)
+
+        if self.fused_dropout_add_ln:
+            assert layer_norm_fn is not None, "Triton is not installed"
+            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
+                self.dropout1, nn.Dropout
+            )
+
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
+        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
+        if sequence_parallel:
+            for p in self.norm1.parameters():
+                p._sequence_parallel = True
+            if hasattr(self, "norm2"):
+                for p in self.norm2.parameters():
+                    p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, "norm2"):
+                for p in self.norm2.parameters():
+                    p._shared_params = True
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        hidden_states1: Tensor,
+        hidden_states2: Optional[Tensor] = None,
+        residual: Optional[Tensor] = None,
+        mixer_kwargs=None,
+    ):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states1: the output of the previous attention (mixer) or embedding layer.
+            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
+            residual: the residual stream from the previous block (if any).
+        """
+        # TODO: Ideally we should only do the allgather / allreduce once for
+        # the Linear to MLP & Attention
+        if not self.fused_dropout_add_ln:
+            dropped1 = self.dropout1(hidden_states1)
+            # For the very 1st block, we only want 1 dropout, not two different dropouts
+            if hidden_states2 is not None:
+                dropped2 = self.dropout2(hidden_states2)
+                residual = (
+                    (residual + dropped1 + dropped2)
+                    if residual is not None
+                    else dropped1 + dropped2
+                )
+            else:
+                residual = (residual + dropped1) if residual is not None else dropped1
+            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            hidden_states2 = (
+                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                if not self.tied_norm
+                else hidden_states1
+            )
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+        else:
+            weight2, bias2 = (
+                (self.norm2.weight, self.norm2.bias)
+                if not self.tied_norm
+                else (None, None)
+            )
+            hidden_states1, *rest, residual = layer_norm_fn(
+                hidden_states1,
+                self.norm1.weight,
+                self.norm1.bias,
+                residual=residual,
+                x1=hidden_states2,
+                weight1=weight2,
+                bias1=bias2,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
+                prenorm=True,
+                residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm1, RMSNorm),
+            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
+            else:
+                (hidden_states2,) = rest
+        if mixer_kwargs is None:
+            mixer_kwargs = {}
+        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
+        hidden_states2 = self.mlp(hidden_states2)
+        return hidden_states1, hidden_states2, residual
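Editor's note, a minimal usage sketch (not part of the commit) of how the prenorm Block threads the (hidden_states, residual) pair described in its docstring. It assumes block.py and its sibling modules are importable as local modules and uses illustrative sizes.

    import torch
    from block import Block  # assumption: the uploaded files are on the import path

    dim = 768                                   # hidden size; a multiple of 64 for the default head count
    blocks = [Block(dim, prenorm=True) for _ in range(2)]

    x = torch.randn(2, 128, dim)                # (batch, seqlen, dim)
    hidden_states, residual = blocks[0](x)      # very first block: residual starts as None
    hidden_states, residual = blocks[1](hidden_states, residual)  # later blocks consume the running residual
    print(hidden_states.shape, residual.shape)  # torch.Size([2, 128, 768]) for both

Because the dropout-add-LayerNorm is folded into the start of each block, the final residual still has to be added and normalized once more by the caller (the encoder does this after the last block).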
embedding.py ADDED
@@ -0,0 +1,62 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
+# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
+
+# Copyright (c) 2022, Tri Dao.
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+
+from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
+
+
+class XLMRobertaEmbeddings(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        vocab_size,
+        max_position_embeddings,
+        type_vocab_size,
+        padding_idx=None,
+        device=None,
+        dtype=None,
+    ):
+        """
+        If max_position_embeddings <= 0, there are no position embeddings.
+        If type_vocab_size <= 0, there are no token type embeddings.
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+        )
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(
+                max_position_embeddings, embed_dim, **factory_kwargs
+            )
+        if self.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        """
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        token_type_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
+                # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+        if self.type_vocab_size > 0:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
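Editor's note, a small CPU-only sketch (not part of the commit) of how the default position ids are derived when none are passed; it assumes padding_idx=1, the XLM-R pad token.

    import torch
    from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids

    input_ids = torch.tensor([[0, 100, 101, 2, 1, 1]])   # last two tokens are padding
    print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
    # tensor([[2, 3, 4, 5, 1, 1]]): real tokens count up from padding_idx + 1, pad positions stay at padding_idx

This padding-aware numbering matches the original fairseq/XLM-R convention, so the learned position embeddings line up with checkpoints trained in Hugging Face transformers.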
mha.py ADDED
@@ -0,0 +1,662 @@
+# Copyright (c) 2023, Tri Dao.
+# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+try:
+    from flash_attn import (
+        flash_attn_kvpacked_func,
+        flash_attn_qkvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+        flash_attn_with_kvcache,
+    )
+except ImportError:
+    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
+    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
+    flash_attn_with_kvcache = None
+
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
+except ImportError:
+    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
+
+
+class FlashSelfAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
+        super().__init__()
+        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
+        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+        self.window_size = window_size
+        self.deterministic = deterministic
+
+    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value.
+                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
+                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
+                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            causal: if passed, will override self.causal
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into qkv.
+            max_seqlen: int. Maximum sequence length in the batch.
+        Returns:
+        --------
+            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
+                else (B, S, H, D).
+        """
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+        causal = self.causal if causal is None else causal
+        unpadded = cu_seqlens is not None
+
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_varlen_qkvpacked_func(
+                qkv,
+                cu_seqlens,
+                max_seqlen,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=self.softmax_scale,
+                causal=causal,
+                alibi_slopes=None,
+                window_size=self.window_size,
+                deterministic=self.deterministic,
+            )
+        else:
+            return flash_attn_qkvpacked_func(
+                qkv,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=self.softmax_scale,
+                causal=causal,
+                alibi_slopes=None,
+                window_size=self.window_size,
+                deterministic=self.deterministic,
+            )
+
+
+class FlashCrossAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(
+        self,
+        causal=False,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        window_size=(-1, -1),
+        deterministic=False,
+    ):
+        super().__init__()
+        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
+        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+        self.window_size = window_size
+        self.deterministic = deterministic
+
+    def forward(
+        self,
+        q,
+        kv,
+        causal=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        cu_seqlens_k=None,
+        max_seqlen_k=None,
+    ):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q: The tensor containing the query. (B, Sq, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
+            causal: if passed, will override self.causal
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+            max_seqlen: int. Maximum sequence length in the batch of q.
+            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into kv.
+            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda and kv.is_cuda
+        causal = self.causal if causal is None else causal
+        unpadded = cu_seqlens is not None
+
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            assert cu_seqlens_k is not None
+            assert cu_seqlens_k.dtype == torch.int32
+            assert max_seqlen_k is not None
+            assert isinstance(max_seqlen_k, int)
+            return flash_attn_varlen_kvpacked_func(
+                q,
+                kv,
+                cu_seqlens,
+                cu_seqlens_k,
+                max_seqlen,
+                max_seqlen_k,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=self.softmax_scale,
+                causal=causal,
+                alibi_slopes=None,
+                window_size=self.window_size,
+                deterministic=self.deterministic,
+            )
+        else:
+            batch_size, seqlen_q = q.shape[0], q.shape[1]
+            seqlen_k = kv.shape[1]
+            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+            return flash_attn_kvpacked_func(
+                q,
+                kv,
+                self.drop.p if self.training else 0.0,
+                causal=causal,
+                softmax_scale=self.softmax_scale,
+                alibi_slopes=None,
+                window_size=self.window_size,
+                deterministic=self.deterministic,
+            )
+
+
+class SelfAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    def forward(self, qkv, causal=None, key_padding_mask=None):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            causal: if passed, will override self.causal
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, S)
+        """
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        causal = self.causal if causal is None else causal
+        q, k, v = qkv.unbind(dim=2)
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full(
+                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
+            )
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+        if causal:
+            # "triu_tril_cuda_template" not implemented for 'BFloat16'
+            # So we have to construct the mask in float
+            causal_mask = torch.triu(
+                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+            )
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + causal_mask.to(dtype=scores.dtype)
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention_drop = self.drop(attention)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
+        return output
+
+
+class CrossAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+        super().__init__()
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.drop = nn.Dropout(attention_dropout)
+
+    def forward(self, q, kv, causal=None, key_padding_mask=None):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q: The tensor containing the query. (B, Sq, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
+            causal: if passed, will override self.causal
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, Sk)
+        """
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        causal = self.causal if causal is None else causal
+        seqlen_k = kv.shape[1]
+        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+        if kv.shape[3] != q.shape[2]:  # MQA/GQA
+            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
+        k, v = kv.unbind(dim=2)
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full(
+                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
+            )
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+        if causal:
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(
+                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
+            )
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = (
+                seqlen_k
+                if key_padding_mask is None
+                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            )
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
+        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention_drop = self.drop(attention)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
+        return output
+
+
+class LinearResidual(nn.Linear):
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return super().forward(input), input
+
+
+def _update_kv_cache(kv, inference_params, layer_idx):
+    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+    # Pre-allocate memory for key-values for inference.
+    num_heads, head_dim = kv.shape[-2:]
+    if layer_idx not in inference_params.key_value_memory_dict:
+        kv_cache = torch.empty(
+            inference_params.max_batch_size,
+            inference_params.max_seqlen,
+            2,
+            num_heads,
+            head_dim,
+            dtype=kv.dtype,
+            device=kv.device,
+        )
+        inference_params.key_value_memory_dict[layer_idx] = kv_cache
+    else:
+        kv_cache = inference_params.key_value_memory_dict[layer_idx]
+    # Adjust key and value for inference
+    batch_start = inference_params.batch_size_offset
+    batch_end = batch_start + kv.shape[0]
+    sequence_start = inference_params.seqlen_offset
+    sequence_end = sequence_start + kv.shape[1]
+    assert batch_end <= kv_cache.shape[0]
+    assert sequence_end <= kv_cache.shape[1]
+    assert kv_cache is not None
+    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    return kv_cache[batch_start:batch_end, :sequence_end, ...]
+
+
+class MHA(nn.Module):
+    """Multi-head self-attention and cross-attention"""
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        num_heads_kv=None,
+        cross_attn=False,
+        qkv_proj_bias=True,
+        out_proj_bias=True,
+        dropout=0.0,
+        softmax_scale=None,
+        causal=False,
+        layer_idx=None,
+        dwconv=False,
+        window_size=(-1, -1),
+        fused_bias_fc=False,
+        use_flash_attn=False,
+        return_residual=False,
+        checkpointing=False,
+        device=None,
+        dtype=None,
+    ) -> None:
+        """
+        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
+        return_residual: whether to return the input x along with the output. This is for
+            performance reasons: for the post-norm architecture, returning the input allows us
+            to fuse the backward of nn.Linear with the residual connection.
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.cross_attn = cross_attn
+        self.causal = causal
+        self.layer_idx = layer_idx
+        self.dwconv = dwconv
+        self.use_flash_attn = use_flash_attn
+        self.return_residual = return_residual
+        self.checkpointing = checkpointing
+
+        if window_size != (-1, -1):
+            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
+
+        self.num_heads = num_heads
+        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
+        assert (
+            self.num_heads % self.num_heads_kv == 0
+        ), "num_heads must be divisible by num_heads_kv"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        self.head_dim = self.embed_dim // num_heads
+        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
+        kv_dim = 2 * self.head_dim * self.num_heads_kv
+
+        if fused_bias_fc and FusedDense is None:
+            raise ImportError("fused_dense is not installed")
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        linear_resid_cls = (
+            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
+        )
+        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
+        inner_attn_cls = (
+            partial(FlashSelfAttention, window_size=window_size)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, window_size=window_size)
+            if use_flash_attn
+            else CrossAttention
+        )
+        if not self.cross_attn:
+            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
+        else:
+            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
+        if self.dwconv:
+            if self.num_heads_kv == self.num_heads:
+                self.dwconv_qkv = nn.Conv1d(
+                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
+                )
+            else:
+                self.dwconv_q = nn.Conv1d(
+                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
+                )
+                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
+        self.inner_attn = inner_attn_cls(
+            causal=causal,
+            softmax_scale=softmax_scale,
+            attention_dropout=dropout,
+        )
+        self.inner_cross_attn = inner_cross_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+
+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        assert not self.dwconv, "Generation does not support dwconv yet"
+        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+        return _update_kv_cache(kv, inference_params, self.layer_idx)
+
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
+        """
+        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
+        """
+        assert inference_params is not None and inference_params.seqlen_offset > 0
+        assert self.use_flash_attn
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.seqlen_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=False,
+            alibi_slopes=None,
+        )
+        return context
+
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.seqlen_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.seqlen_offset
+            )
+            return flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+                alibi_slopes=None,
+            )
+
+    def forward(
+        self,
+        x,
+        x_kv=None,
+        key_padding_mask=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        mixer_subset=None,
+        inference_params=None,
+        **kwargs,
+    ):
+        """
+        Arguments:
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
+                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
+                is the sum of the sequence lengths in the batch.
+            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into x. Only applicable when using
+                FlashAttention.
+            max_seqlen: int. Maximum sequence length in the batch.
+            key_padding_mask: boolean mask, True means to keep, False means to mask out.
+                (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
+            inference_params: for generation. Adapted from Megatron-LM (and Apex)
+                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+            assert key_padding_mask is None
+            assert self.use_flash_attn
+            assert not self.dwconv
+        if key_padding_mask is not None:
+            assert cu_seqlens is None
+            assert max_seqlen is None
+            assert not self.use_flash_attn
+        if inference_params is not None:
+            assert key_padding_mask is None
+            assert cu_seqlens is None and max_seqlen is None
+            assert not self.dwconv
+
+        kwargs = (
+            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
+            if self.use_flash_attn
+            else {"key_padding_mask": key_padding_mask, **kwargs}
+        )
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.seqlen_offset
+            )
+        )
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else max_seqlen
+        )
+        batch, seqlen = x.shape[:2]
+        if not self.cross_attn and self.num_heads_kv == self.num_heads:
+            assert x_kv is None and mixer_subset is None
+            if not self.return_residual:
+                qkv = self.Wqkv(x)
+            else:
+                qkv, x = self.Wqkv(x)
+            if self.dwconv:
+                qkv = rearrange(
+                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            if (
+                inference_params is None
+                or inference_params.seqlen_offset == 0
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
+            ):
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_attn(qkv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                else:
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
+            else:
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
+        else:
+            if self.cross_attn:
+                if not self.return_residual:
+                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+                    kv = self.Wkv(x_kv if x_kv is not None else x)
+                else:
+                    if x_kv is not None:
+                        kv, x_kv = self.Wkv(x_kv)
+                    else:
+                        kv, x = self.Wkv(x)
+                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
+            else:
+                assert self.num_heads_kv != self.num_heads
+                if not self.return_residual:
+                    qkv = self.Wqkv(x)
+                else:
+                    qkv, x = self.Wqkv(x)
+                q = qkv[..., : self.num_heads * self.head_dim]
+                kv = qkv[..., self.num_heads * self.head_dim :]
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            if self.dwconv:
+                q = rearrange(
+                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+                kv = rearrange(
+                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+            if (
+                inference_params is None
+                or inference_params.seqlen_offset == 0
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
+            ):
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_cross_attn(q, kv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(
+                            self.inner_cross_attn, q, kv, **kwargs
+                        )
+                else:
+                    context = self._update_kvcache_attention(q, kv, inference_params)
+            else:
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        return out if not self.return_residual else (out, x)
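Editor's note, a short CPU-only sketch (not part of the commit) of the cu_seqlens convention used throughout this file; the actual varlen kernels require flash_attn on CUDA, so the MHA call is shown only as a hedged comment.

    import torch

    # Three sequences of lengths 5, 3, 7 packed into one "total" dimension of 15 tokens.
    seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    print(cu_seqlens)                # tensor([ 0,  5,  8, 15], dtype=torch.int32)
    max_seqlen = int(seqlens.max())  # 7

    # With flash_attn installed and fp16/bf16 tensors on CUDA, the unpadded call would look like:
    #   mha = MHA(embed_dim=768, num_heads=12, use_flash_attn=True).cuda().half()
    #   out = mha(x_packed, cu_seqlens=cu_seqlens.cuda(), max_seqlen=max_seqlen)   # x_packed: (15, 768)

Token i of sequence b lives at row cu_seqlens[b] + i of the packed tensor, which is exactly how the xlm_padding helpers index into the unpadded hidden states.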
mlp.py ADDED
@@ -0,0 +1,194 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
+# Commit id: c3b219665292c61a51153d0ded4473c494296382
+
+# Copyright (c) 2023, Tri Dao.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+
+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+except ImportError:
+    ColumnParallelLinear, RowParallelLinear = None, None
+
+try:
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
+except ImportError:
+    FusedMLP, ParallelFusedMLP = None, None
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.gelu,
+        bias1=True,
+        bias2=True,
+        return_residual=False,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = hidden_features if hidden_features is not None else in_features * 4
+        self.return_residual = return_residual
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)
+
+
+class ParallelMLP(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.gelu,
+        process_group: ProcessGroup = None,
+        sequence_parallel=True,
+        bias1=True,
+        bias2=True,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        assert ColumnParallelLinear is not None, "Need to install fused_dense"
+        assert RowParallelLinear is not None, "Need to install fused_dense"
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = hidden_features if hidden_features is not None else in_features * 4
+        self.fc1 = ColumnParallelLinear(
+            in_features,
+            hidden_features,
+            process_group,
+            bias=bias1,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
+        self.activation = activation
+        self.fc2 = RowParallelLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias2,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
+
+    def forward(self, x):
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y
+
+
+class GatedMlp(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        activation=F.sigmoid,
+        bias1=True,
+        bias2=True,
+        multiple_of=128,
+        return_residual=False,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.return_residual = return_residual
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
+        self.activation = activation
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)
+
+
+class ParallelGatedMlp(nn.Module):
+    """Parallel GatedMlp"""
+
+    def __init__(
+        self,
+        in_features,
+        process_group,
+        hidden_features=None,
+        out_features=None,
+        activation=F.sigmoid,
+        bias1=True,
+        bias2=True,
+        multiple_of=128,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        if ColumnParallelLinear is None or RowParallelLinear is None:
+            raise ImportError("fused_dense is not installed")
+        self.fc1 = ColumnParallelLinear(
+            in_features,
+            2 * hidden_features,
+            process_group,
+            bias=bias1,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
+        self.activation = activation
+        self.fc2 = RowParallelLinear(
+            hidden_features,
+            out_features,
+            process_group,
+            bias=bias2,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y
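Editor's note, a quick sketch (not part of the commit, illustrative sizes) of the GatedMlp gating path; it assumes mlp.py is importable as a local module and that flash_attn's fused swiglu is not installed, so the plain PyTorch gating branch runs on CPU.

    import torch
    import torch.nn.functional as F
    from mlp import GatedMlp  # assumption: mlp.py is on the import path

    mlp = GatedMlp(in_features=64, activation=F.silu, multiple_of=32)
    x = torch.randn(4, 64)
    print(mlp(x).shape)       # torch.Size([4, 64])
    # fc1 projects to 2 * hidden_features and the second half gates the first:
    # hidden_features = round_up(int(8 * 64 / 3), 32) = 192, so fc1.out_features == 384

With the default activation=F.sigmoid the same layer reduces to a plain GLU via F.glu; the encoder in this repo, however, uses the non-gated Mlp with approximate GELU.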
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
+ from transformers.models.bert.modeling_bert import (
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ BertForPreTrainingOutput,
32
+ )
33
+
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
+ from .embedding import XLMRobertaEmbeddings
45
+ from .mha import MHA
46
+ from .mlp import FusedMLP, Mlp
47
+
48
+ try:
49
+ from flash_attn.ops.fused_dense import FusedDense
50
+ except ImportError:
51
+ FusedDense = None
52
+
53
+ try:
54
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
55
+ except ImportError:
56
+ layer_norm_fn = None
57
+
58
+
59
+ try:
60
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
61
+ except ImportError:
62
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
+
64
+ try:
65
+ from tqdm.autonotebook import trange
66
+ except ImportError:
67
+ trange = None
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
74
+ if not getattr(config, "use_flash_attn", False):
75
+ return False
76
+ if not torch.cuda.is_available():
77
+ return False
78
+ if importlib.util.find_spec("flash_attn") is None:
79
+ logger.warning(
80
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
81
+ )
82
+ return False
83
+ return True
84
+
85
+
86
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
87
+ use_flash_attn = get_use_flash_attn(config)
88
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
+
90
+ mixer_cls = partial(
91
+ MHA,
92
+ num_heads=config.num_attention_heads,
93
+ cross_attn=cross_attn,
94
+ dropout=config.attention_probs_dropout_prob,
95
+ causal=False,
96
+ fused_bias_fc=fused_bias_fc,
97
+ use_flash_attn=use_flash_attn,
98
+ return_residual=return_residual,
99
+ )
100
+ return mixer_cls
101
+
102
+
103
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
104
+ inner_dim = config.intermediate_size
105
+ fused_mlp = getattr(config, "fused_mlp", False)
106
+ if fused_mlp:
107
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108
+ "fused_mlp only " "supports approximate gelu"
109
+ )
110
+ if not fused_mlp:
111
+ approximate = (
112
+ "tanh"
113
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114
+ else "none"
115
+ )
116
+ mlp_cls = partial(
117
+ Mlp,
118
+ hidden_features=inner_dim,
119
+ activation=partial(F.gelu, approximate=approximate),
120
+ return_residual=return_residual,
121
+ )
122
+ else:
123
+ if FusedMLP is None:
124
+ raise ImportError("fused_dense is not installed")
125
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127
+ if isinstance(mlp_checkpoint_lvl, Sequence):
128
+ assert layer_idx is not None
129
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130
+ mlp_cls = partial(
131
+ FusedMLP,
132
+ hidden_features=inner_dim,
133
+ checkpoint_lvl=mlp_checkpoint_lvl,
134
+ return_residual=return_residual,
135
+ )
136
+ return mlp_cls
137
+
138
+
139
+ def create_block(config, layer_idx=None):
140
+ last_layer_subset = getattr(config, "last_layer_subset", False)
141
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144
+ # one layer) so we just choose not to return residual in this case.
145
+ return_residual = not cross_attn
146
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149
+ block = Block(
150
+ config.hidden_size,
151
+ mixer_cls,
152
+ mlp_cls,
153
+ norm_cls=norm_cls,
154
+ prenorm=False,
155
+ resid_dropout1=config.hidden_dropout_prob,
156
+ resid_dropout2=config.hidden_dropout_prob,
157
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158
+ return_residual=return_residual,
159
+ )
160
+ return block
161
+
162
+
163
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164
+ def _init_weights(module, initializer_range=0.02):
165
+ if isinstance(module, nn.Linear):
166
+ nn.init.normal_(module.weight, std=initializer_range)
167
+ if module.bias is not None:
168
+ nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.padding_idx is not None:
172
+ nn.init.zeros_(module.weight[module.padding_idx])
173
+
174
+
175
+ class XLMRobertaEncoder(nn.Module):
176
+ def __init__(self, config: XLMRobertaFlashConfig):
177
+ super().__init__()
178
+ self.use_flash_attn = get_use_flash_attn(config)
179
+ self.layers = nn.ModuleList(
180
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181
+ )
182
+ self._grad_checkpointing = False
183
+
184
+ @property
185
+ def gradient_checkpointing(self):
186
+ return self._grad_checkpointing
187
+
188
+ @gradient_checkpointing.setter
189
+ def gradient_checkpointing(self, value):
190
+ self._grad_checkpointing = value
191
+
192
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193
+ """If subset_mask is not None, we only want output for the subset of the sequence.
194
+ This means that we only compute the last layer output for these tokens.
195
+ subset_mask: (batch, seqlen), dtype=torch.bool
196
+ """
197
+ if key_padding_mask is None or not self.use_flash_attn:
198
+ mixer_kwargs = (
199
+ {"key_padding_mask": key_padding_mask.bool()}
200
+ if key_padding_mask is not None
201
+ else None
202
+ )
203
+ for layer in self.layers:
204
+ if self._grad_checkpointing:
205
+ hidden_states = torch.utils.checkpoint.checkpoint(
206
+ layer,
207
+ hidden_states,
208
+ use_reentrant=False,
209
+ mixer_kwargs=mixer_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213
+ if subset_mask is not None:
214
+ hidden_states = hidden_states[subset_mask]
215
+ else:
216
+ batch, seqlen = hidden_states.shape[:2]
217
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218
+ hidden_states, key_padding_mask
219
+ )
220
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221
+ if subset_mask is None:
222
+ for layer in self.layers:
223
+ if self._grad_checkpointing:
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ layer,
226
+ hidden_states,
227
+ use_reentrant=False,
228
+ mixer_kwargs=mixer_kwargs,
229
+ )
230
+ else:
231
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233
+ else:
234
+ for layer in self.layers[:-1]:
235
+ if self._grad_checkpointing:
236
+ hidden_states = torch.utils.checkpoint.checkpoint(
237
+ layer,
238
+ hidden_states,
239
+ use_reentrant=False,
240
+ mixer_kwargs=mixer_kwargs,
241
+ )
242
+ else:
243
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244
+ if key_padding_mask is not None:
245
+ subset_idx = torch.nonzero(
246
+ subset_mask[key_padding_mask], as_tuple=False
247
+ ).flatten()
248
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
249
+ dim=-1, dtype=torch.int32
250
+ )
251
+ subset_cu_seqlens = F.pad(
252
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
253
+ (1, 0),
254
+ )
255
+ else:
256
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258
+ subset_cu_seqlens = F.pad(
259
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
260
+ (1, 0),
261
+ )
262
+ hidden_states_subset, hidden_states = index_first_axis_residual(
263
+ hidden_states, subset_idx
264
+ )
265
+ # It's ok to set max_seqlen_q to be much larger
266
+ mixer_kwargs = {
267
+ "x_kv": hidden_states,
268
+ "cu_seqlens": subset_cu_seqlens,
269
+ "max_seqlen": max_seqlen_in_batch,
270
+ "cu_seqlens_k": cu_seqlens,
271
+ "max_seqlen_k": max_seqlen_in_batch,
272
+ }
273
+ if self._grad_checkpointing:
274
+ hidden_states = torch.utils.checkpoint.checkpoint(
275
+ self.layers[-1],
276
+ hidden_states_subset,
277
+ use_reentrant=False,
278
+ mixer_kwargs=mixer_kwargs,
279
+ )
280
+ else:
281
+ hidden_states = self.layers[-1](
282
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
283
+ )
284
+ return hidden_states
285
+
286
+
287
+ class XLMRobertaPooler(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
291
+ if fused_bias_fc and FusedDense is None:
292
+ raise ImportError("fused_dense is not installed")
293
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
295
+ self.activation = nn.Tanh()
296
+
297
+ def forward(self, hidden_states, pool=True):
298
+ # We "pool" the model by simply taking the hidden state corresponding
299
+ # to the first token.
300
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301
+ pooled_output = self.dense(first_token_tensor)
302
+ pooled_output = self.activation(pooled_output)
303
+ return pooled_output
304
+
305
+
306
+ class XLMRobertaPredictionHeadTransform(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
310
+ if fused_bias_fc and FusedDense is None:
311
+ raise ImportError("fused_dense is not installed")
312
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
314
+ raise ImportError("Triton is not installed")
315
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
317
+ approximate = (
318
+ "tanh"
319
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320
+ else "none"
321
+ )
322
+ self.transform_act_fn = nn.GELU(approximate=approximate)
323
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+
325
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.transform_act_fn(hidden_states)
328
+ if not self.fused_dropout_add_ln:
329
+ hidden_states = self.layer_norm(hidden_states)
330
+ else:
331
+ hidden_states = layer_norm_fn(
332
+ hidden_states,
333
+ self.layer_norm.weight,
334
+ self.layer_norm.bias,
335
+ eps=self.layer_norm.eps,
336
+ )
337
+ return hidden_states
338
+
339
+
340
+ class XLMRobertaLMPredictionHead(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
344
+ if fused_bias_fc and FusedDense is None:
345
+ raise ImportError("fused_dense is not installed")
346
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
+
348
+ self.transform = XLMRobertaPredictionHeadTransform(config)
349
+
350
+ # The output weights are the same as the input embeddings, but there is
351
+ # an output-only bias for each token.
352
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
+
354
+ def forward(self, hidden_states):
355
+ hidden_states = self.transform(hidden_states)
356
+ hidden_states = self.decoder(hidden_states)
357
+ return hidden_states
358
+
359
+
360
+ class XLMRobertaPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super().__init__()
363
+ self.predictions = XLMRobertaLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
373
+ """An abstract class to handle weights initialization and
374
+ a simple interface for downloading and loading pretrained models.
375
+ """
376
+
377
+ config_class = XLMRobertaFlashConfig
378
+ base_model_prefix = "roberta"
379
+ supports_gradient_checkpointing = True
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, XLMRobertaEncoder):
383
+ module.gradient_checkpointing = value
384
+
385
+ @classmethod
386
+ def from_pretrained(
387
+ cls,
388
+ *args,
389
+ **kwargs,
390
+ ):
391
+ if 'torch_dtype' not in kwargs:
392
+ kwargs['torch_dtype'] = 'auto'
393
+ return super().from_pretrained(*args, **kwargs)
394
+
395
+
396
+
397
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
398
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399
+ super().__init__(config)
400
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
401
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
402
+ config.vocab_size += self.pad_vocab_size_multiple - (
403
+ config.vocab_size % self.pad_vocab_size_multiple
404
+ )
405
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
407
+ raise ImportError("Triton is not installed")
408
+ assert config.hidden_act in [
409
+ "gelu",
410
+ "gelu_new",
411
+ "gelu_fast",
412
+ "gelu_pytorch_tanh",
413
+ ]
414
+
415
+ self.embeddings = XLMRobertaEmbeddings(
416
+ config.hidden_size,
417
+ config.vocab_size,
418
+ config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419
+ config.type_vocab_size,
420
+ padding_idx=config.pad_token_id,
421
+ )
422
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.encoder = XLMRobertaEncoder(config)
425
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
+
427
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
+
429
+
430
+ @torch.inference_mode()
431
+ def encode(
432
+ self: 'XLMRobertaModel',
433
+ sentences: Union[str, List[str]],
434
+ batch_size: int = 32,
435
+ show_progress_bar: Optional[bool] = None,
436
+ output_value: str = 'sentence_embedding',
437
+ convert_to_numpy: bool = True,
438
+ convert_to_tensor: bool = False,
439
+ device: Optional[torch.device] = None,
440
+ normalize_embeddings: bool = False,
441
+ truncate_dim: Optional[int] = None,
442
+ **tokenizer_kwargs,
443
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444
+ """
445
+ Computes sentence embeddings
446
+ Args:
447
+ sentences(`str` or `List[str]`):
448
+ Sentence or sentences to be encoded
449
+ batch_size(`int`, *optional*, defaults to 32):
450
+ Batch size for the computation
451
+ show_progress_bar(`bool`, *optional*, defaults to None):
452
+ Show a progress bar when encoding sentences.
453
+ If set to None, progress bar is only shown when
454
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456
+ Default sentence_embedding, to get sentence embeddings.
457
+ Can be set to token_embeddings to get wordpiece token embeddings.
458
+ Set to None, to get all output values
459
+ convert_to_numpy(`bool`, *optional*, defaults to True):
460
+ If true, the output is a list of numpy vectors.
461
+ Else, it is a list of pytorch tensors.
462
+ convert_to_tensor(`bool`, *optional*, defaults to False):
463
+ If true, you get one large tensor as return.
464
+ Overwrites any setting from convert_to_numpy
465
+ device(`torch.device`, *optional*, defaults to None):
466
+ Which torch.device to use for the computation
467
+ normalize_embeddings(`bool`, *optional*, defaults to False):
468
+ If set to true, returned vectors will have length 1. In that case, the
469
+ faster dot-product (util.dot_score) instead of cosine similarity can
470
+ be used.
471
+ truncate_dim(`int`, *optional*, defaults to None):
472
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
473
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474
+ Keyword arguments for the tokenizer
475
+ Returns:
476
+ By default, a list of tensors is returned.
477
+ If convert_to_tensor, a stacked tensor is returned.
478
+ If convert_to_numpy, a numpy matrix is returned.
479
+ """
480
+ from transformers import AutoTokenizer
481
+
482
+ self.tokenizer = AutoTokenizer.from_pretrained(
483
+ self.name_or_path, trust_remote_code=True
484
+ )
485
+
486
+ is_training = self.training
487
+ self.eval()
488
+
489
+ if show_progress_bar is None:
490
+ show_progress_bar = (
491
+ logger.getEffectiveLevel() == logging.INFO
492
+ or logger.getEffectiveLevel() == logging.DEBUG
493
+ )
494
+
495
+ if convert_to_tensor:
496
+ convert_to_numpy = False
497
+
498
+ if output_value != 'sentence_embedding':
499
+ convert_to_tensor = False
500
+ convert_to_numpy = False
501
+
502
+ input_was_string = False
503
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504
+ sentences = [sentences]
505
+ input_was_string = True
506
+
507
+ if device is not None:
508
+ self.to(device)
509
+
510
+ permutation = np.argsort([-len(i) for i in sentences])
511
+ inverse_permutation = np.argsort(permutation)
512
+ sentences = [sentences[idx] for idx in permutation]
513
+
514
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517
+ )
518
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
+
520
+ all_embeddings = []
521
+
522
+ if trange is not None:
523
+ range_iter = trange(
524
+ 0,
525
+ len(sentences),
526
+ batch_size,
527
+ desc="Encoding",
528
+ disable=not show_progress_bar,
529
+ )
530
+ else:
531
+ range_iter = range(0, len(sentences), batch_size)
532
+
533
+ for i in range_iter:
534
+ encoded_input = self.tokenizer(
535
+ sentences[i : i + batch_size],
536
+ return_tensors='pt',
537
+ **tokenizer_kwargs,
538
+ ).to(self.device)
539
+ token_embs = self.forward(**encoded_input)[0]
540
+
541
+ # Accumulate in fp32 to avoid overflow
542
+ token_embs = token_embs.float()
543
+
544
+ if output_value == 'token_embeddings':
545
+ raise NotImplementedError
546
+ elif output_value is None:
547
+ raise NotImplementedError
548
+ else:
549
+ if self.config.emb_pooler == 'cls':
550
+ embeddings = self.cls_pooling(
551
+ token_embs, encoded_input['attention_mask']
552
+ )
553
+ else:
554
+ embeddings = self.mean_pooling(
555
+ token_embs, encoded_input['attention_mask']
556
+ )
557
+
558
+ if normalize_embeddings:
559
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
+
561
+ if convert_to_numpy:
562
+ embeddings = embeddings.cpu()
563
+ all_embeddings.extend(embeddings)
564
+
565
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
+
567
+ truncate_dim = truncate_dim or self.config.truncate_dim
568
+ if truncate_dim:
569
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
+
571
+ if convert_to_tensor:
572
+ all_embeddings = torch.stack(all_embeddings)
573
+ elif convert_to_numpy:
574
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
+
576
+ if input_was_string:
577
+ all_embeddings = all_embeddings[0]
578
+
579
+ self.train(is_training)
580
+ return all_embeddings
581
+
582
+
583
+ def truncate_embeddings(self, embeddings, truncate_dim):
584
+ if not self.config.matryoshka_dimensions:
585
+ logger.warning(
586
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587
+ )
588
+ return embeddings
589
+ elif truncate_dim in self.config.matryoshka_dimensions:
590
+ return [tensor[:truncate_dim] for tensor in embeddings]
591
+ else:
592
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
+
595
+ def mean_pooling(
596
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597
+ ):
598
+ input_mask_expanded = (
599
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600
+ )
601
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602
+ input_mask_expanded.sum(1), min=1e-9
603
+ )
604
+
605
+
606
+ def cls_pooling(
607
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608
+ ):
609
+ return token_embeddings[:,0]
610
+
611
+
612
+ def forward(
613
+ self,
614
+ input_ids,
615
+ position_ids=None,
616
+ token_type_ids=None,
617
+ attention_mask=None,
618
+ masked_tokens_mask=None,
619
+ return_dict=None,
620
+ **kwargs,
621
+ ):
622
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623
+ we only want the output for the masked tokens. This means that we only compute the last
624
+ layer output for these tokens.
625
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626
+ """
627
+
628
+ if kwargs:
629
+ for key, value in kwargs.items():
630
+ if value is not None:
631
+ logger.warning(
632
+ 'Flash attention implementation does not support kwargs: %s',
633
+ key,
634
+ )
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ hidden_states = self.embeddings(
641
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642
+ )
643
+ # TD [2022-12-18]: Don't need to force residual in fp32
644
+ # BERT puts embedding LayerNorm before embedding dropout.
645
+ if not self.fused_dropout_add_ln:
646
+ hidden_states = self.emb_ln(hidden_states)
647
+ else:
648
+ hidden_states = layer_norm_fn(
649
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650
+ )
651
+ hidden_states = self.emb_drop(hidden_states)
652
+
653
+ if masked_tokens_mask is not None:
654
+ batch_size, seqlen = input_ids.shape[:2]
655
+ # We also need the first column for the CLS token
656
+ first_col_mask = torch.zeros(
657
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658
+ )
659
+ first_col_mask[:, 0] = True
660
+ subset_mask = masked_tokens_mask | first_col_mask
661
+ else:
662
+ subset_mask = None
663
+
664
+ sequence_output = self.encoder(
665
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666
+ )
667
+
668
+ if masked_tokens_mask is None:
669
+ pooled_output = (
670
+ self.pooler(sequence_output) if self.pooler is not None else None
671
+ )
672
+ else:
673
+ # TD [2022-03-01]: the indexing here is very tricky.
674
+ if attention_mask is not None:
675
+ subset_idx = subset_mask[attention_mask]
676
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677
+ sequence_output = sequence_output[
678
+ masked_tokens_mask[attention_mask][subset_idx]
679
+ ]
680
+ else:
681
+ pool_input = sequence_output[first_col_mask[subset_mask]]
682
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683
+ pooled_output = (
684
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
685
+ )
686
+
687
+ if not return_dict:
688
+ return sequence_output, pooled_output
689
+
690
+ return BaseModelOutputWithPoolingAndCrossAttentions(
691
+ last_hidden_state=sequence_output,
692
+ pooler_output=pooled_output,
693
+ )
694
+
695
+
696
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
+
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ if config.is_decoder:
703
+ logger.warning(
704
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705
+ "bi-directional self-attention."
706
+ )
707
+
708
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709
+ self.lm_head = XLMRobertaLMHead(config)
710
+
711
+ # Initialize weights and apply final processing
712
+ self.post_init()
713
+
714
+ def get_input_embeddings(self):
715
+ return self.roberta.embeddings.word_embeddings
716
+
717
+ def get_output_embeddings(self):
718
+ return self.lm_head.decoder
719
+
720
+ def set_output_embeddings(self, new_embeddings):
721
+ self.lm_head.decoder = new_embeddings
722
+
723
+ def forward(
724
+ self,
725
+ input_ids: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.FloatTensor] = None,
727
+ token_type_ids: Optional[torch.LongTensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ head_mask: Optional[torch.FloatTensor] = None,
730
+ inputs_embeds: Optional[torch.FloatTensor] = None,
731
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
732
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
733
+ labels: Optional[torch.LongTensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738
+ r"""
739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
743
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
744
+ Used to hide legacy arguments that have been deprecated.
745
+ """
746
+ return_dict = (
747
+ return_dict if return_dict is not None else self.config.use_return_dict
748
+ )
749
+
750
+ outputs = self.roberta(
751
+ input_ids,
752
+ attention_mask=attention_mask,
753
+ token_type_ids=token_type_ids,
754
+ position_ids=position_ids,
755
+ head_mask=head_mask,
756
+ inputs_embeds=inputs_embeds,
757
+ encoder_hidden_states=encoder_hidden_states,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ output_attentions=output_attentions,
760
+ output_hidden_states=output_hidden_states,
761
+ return_dict=return_dict,
762
+ )
763
+ sequence_output = outputs[0]
764
+ prediction_scores = self.lm_head(sequence_output)
765
+
766
+ masked_lm_loss = None
767
+ if labels is not None:
768
+ # move labels to correct device to enable model parallelism
769
+ labels = labels.to(prediction_scores.device)
770
+ loss_fct = CrossEntropyLoss()
771
+ masked_lm_loss = loss_fct(
772
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773
+ )
774
+
775
+ if not return_dict:
776
+ output = (prediction_scores,) + outputs[2:]
777
+ return (
778
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779
+ )
780
+
781
+ return MaskedLMOutput(
782
+ loss=masked_lm_loss,
783
+ logits=prediction_scores,
784
+ hidden_states=outputs.hidden_states,
785
+ attentions=outputs.attentions,
786
+ )
787
+
788
+
789
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790
+ class XLMRobertaClassificationHead(nn.Module):
791
+ """Head for sentence-level classification tasks."""
792
+
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
796
+ if fused_bias_fc and FusedDense is None:
797
+ raise ImportError("fused_dense is not installed")
798
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
800
+ classifier_dropout = (
801
+ config.classifier_dropout
802
+ if config.classifier_dropout is not None
803
+ else config.hidden_dropout_prob
804
+ )
805
+ self.dropout = nn.Dropout(classifier_dropout)
806
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
+
808
+ def forward(self, features, **kwargs):
809
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810
+ x = self.dropout(x)
811
+ x = self.dense(x)
812
+ x = torch.tanh(x)
813
+ x = self.dropout(x)
814
+ x = self.out_proj(x)
815
+ return x
816
+
817
+
818
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.num_labels = config.num_labels
823
+ self.config = config
824
+
825
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826
+ self.classifier = XLMRobertaClassificationHead(config)
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def forward(
832
+ self,
833
+ input_ids: Optional[torch.LongTensor] = None,
834
+ attention_mask: Optional[torch.FloatTensor] = None,
835
+ token_type_ids: Optional[torch.LongTensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ head_mask: Optional[torch.FloatTensor] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ labels: Optional[torch.LongTensor] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844
+ r"""
845
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
+ """
850
+ return_dict = (
851
+ return_dict if return_dict is not None else self.config.use_return_dict
852
+ )
853
+
854
+ outputs = self.roberta(
855
+ input_ids,
856
+ attention_mask=attention_mask,
857
+ token_type_ids=token_type_ids,
858
+ position_ids=position_ids,
859
+ head_mask=head_mask,
860
+ inputs_embeds=inputs_embeds,
861
+ output_attentions=output_attentions,
862
+ output_hidden_states=output_hidden_states,
863
+ return_dict=return_dict,
864
+ )
865
+ sequence_output = outputs[0]
866
+ logits = self.classifier(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # move labels to correct device to enable model parallelism
871
+ labels = labels.to(logits.device)
872
+ if self.config.problem_type is None:
873
+ if self.num_labels == 1:
874
+ self.config.problem_type = "regression"
875
+ elif self.num_labels > 1 and (
876
+ labels.dtype == torch.long or labels.dtype == torch.int
877
+ ):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ loss_fct = MSELoss()
884
+ if self.num_labels == 1:
885
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
886
+ else:
887
+ loss = loss_fct(logits, labels)
888
+ elif self.config.problem_type == "single_label_classification":
889
+ loss_fct = CrossEntropyLoss()
890
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891
+ elif self.config.problem_type == "multi_label_classification":
892
+ loss_fct = BCEWithLogitsLoss()
893
+ loss = loss_fct(logits, labels)
894
+
895
+ if not return_dict:
896
+ output = (logits,) + outputs[2:]
897
+ return ((loss,) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutput(
900
+ loss=loss,
901
+ logits=logits,
902
+ hidden_states=outputs.hidden_states,
903
+ attentions=outputs.attentions,
904
+ )
905
+
906
+
907
+ @torch.inference_mode()
908
+ def compute_score(
909
+ self,
910
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
911
+ batch_size: int = 32,
912
+ max_length: Optional[int] = None,
913
+ ) -> List[float]:
914
+
915
+ if not hasattr(self, "_tokenizer"):
916
+ from transformers import AutoTokenizer
917
+
918
+ self._tokenizer = AutoTokenizer.from_pretrained(
919
+ self.name_or_path, trust_remote_code=True
920
+ )
921
+
922
+ assert isinstance(sentence_pairs, list)
923
+ if isinstance(sentence_pairs[0], str):
924
+ sentence_pairs = [sentence_pairs]
925
+
926
+ all_scores = []
927
+ for start_index in range(
928
+ 0, len(sentence_pairs), batch_size
929
+ ):
930
+ sentences_batch = sentence_pairs[
931
+ start_index : start_index + batch_size
932
+ ]
933
+ inputs = self._tokenizer(
934
+ sentences_batch,
935
+ padding=True,
936
+ truncation=True,
937
+ return_tensors='pt',
938
+ max_length=max_length,
939
+ ).to(self.device)
940
+ scores = (
941
+ self.forward(**inputs, return_dict=True)
942
+ .logits.view(
943
+ -1,
944
+ )
945
+ .float()
946
+ )
947
+ scores = torch.sigmoid(scores)
948
+ all_scores.extend(scores.cpu().numpy().tolist())
949
+
950
+ if len(all_scores) == 1:
951
+ return all_scores[0]
952
+ return all_scores
953
+
954
+ def predict(
955
+ self,
956
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
957
+ batch_size: int = 32,
958
+ max_length: Optional[int] = None,
959
+ ) -> List[float]:
960
+ # used for beir evaluation
961
+ return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
962
+
963
+ def rerank(
964
+ self,
965
+ query: str,
966
+ documents: List[str],
967
+ batch_size: int = 32,
968
+ max_length: int = 1024,
969
+ max_query_length: int = 512,
970
+ overlap_tokens: int = 80,
971
+ top_n: Optional[int] = None,
972
+ **kwargs,
973
+ ):
974
+ assert max_length >= max_query_length * 2, (
975
+ f'max_length ({max_length}) must be greater than or equal to '
976
+ f'max_query_length ({max_query_length}) * 2'
977
+ )
978
+
979
+ if not hasattr(self, "_tokenizer"):
980
+ from transformers import AutoTokenizer
981
+
982
+ self._tokenizer = AutoTokenizer.from_pretrained(
983
+ self.name_or_path, trust_remote_code=True
984
+ )
985
+
986
+ # preproc of tokenization
987
+ sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
988
+ query,
989
+ documents,
990
+ tokenizer=self._tokenizer,
991
+ max_length=max_length,
992
+ max_query_length=max_query_length,
993
+ overlap_tokens=overlap_tokens,
994
+ )
995
+
996
+ tot_scores = []
997
+ with torch.no_grad():
998
+ for k in range(0, len(sentence_pairs), batch_size):
999
+ batch = self._tokenizer.pad(
1000
+ sentence_pairs[k : k + batch_size],
1001
+ padding=True,
1002
+ max_length=max_length,
1003
+ pad_to_multiple_of=None,
1004
+ return_tensors="pt",
1005
+ )
1006
+ batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
1007
+ scores = (
1008
+ self.forward(**batch_on_device, return_dict=True)
1009
+ .logits.view(
1010
+ -1,
1011
+ )
1012
+ .float()
1013
+ )
1014
+ scores = torch.sigmoid(scores)
1015
+ tot_scores.extend(scores.cpu().numpy().tolist())
1016
+
1017
+ # ranking
1018
+ merge_scores = [0 for _ in range(len(documents))]
1019
+ for pid, score in zip(sentence_pairs_pids, tot_scores):
1020
+ merge_scores[pid] = max(merge_scores[pid], score)
1021
+
1022
+ merge_scores_argsort = np.argsort(merge_scores)[::-1]
1023
+ sorted_documents = []
1024
+ sorted_scores = []
1025
+ for mid in merge_scores_argsort:
1026
+ sorted_scores.append(merge_scores[mid])
1027
+ sorted_documents.append(documents[mid])
1028
+
1029
+ top_n = min(top_n or len(sorted_documents), len(sorted_documents))
1030
+
1031
+ return [
1032
+ {
1033
+ 'document': sorted_documents[i],
1034
+ 'relevance_score': sorted_scores[i],
1035
+ 'index': merge_scores_argsort[i],
1036
+ }
1037
+ for i in range(top_n)
1038
+ ]
1039
+
1040
+
1041
+ def reranker_tokenize_preproc(
1042
+ query: str,
1043
+ passages: List[str],
1044
+ tokenizer=None,
1045
+ max_length: int = 1024,
1046
+ max_query_length: int = 512,
1047
+ overlap_tokens: int = 80,
1048
+ ):
1049
+ from copy import deepcopy
1050
+
1051
+ assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
1052
+ sep_id = tokenizer.sep_token_id
1053
+
1054
+ def _merge_inputs(chunk1_raw, chunk2):
1055
+ chunk1 = deepcopy(chunk1_raw)
1056
+ chunk1['input_ids'].append(sep_id)
1057
+ chunk1['input_ids'].extend(chunk2['input_ids'])
1058
+ chunk1['input_ids'].append(sep_id)
1059
+ chunk1['attention_mask'].append(chunk2['attention_mask'][0])
1060
+ chunk1['attention_mask'].extend(chunk2['attention_mask'])
1061
+ chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
1062
+ if 'token_type_ids' in chunk1:
1063
+ token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
1064
+ chunk1['token_type_ids'].extend(token_type_ids)
1065
+ return chunk1
1066
+
1067
+ # Note: a long query will be truncated to `max_query_length` tokens (512 by default)
1068
+ query_inputs = tokenizer.encode_plus(
1069
+ query, truncation=True, padding=False, max_length=max_query_length
1070
+ )
1071
+
1072
+ max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
1073
+ # assert (
1074
+ # max_passage_inputs_length > 100
1075
+ # ), "Your query is too long! Please make sure your query less than 500 tokens!"
1076
+
1077
+ overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
1078
+
1079
+ res_merge_inputs = []
1080
+ res_merge_inputs_pids = []
1081
+ for pid, passage in enumerate(passages):
1082
+ passage_inputs = tokenizer.encode_plus(
1083
+ passage,
1084
+ truncation=False,
1085
+ padding=False,
1086
+ add_special_tokens=False,
1087
+ max_length=0,
1088
+ )
1089
+ passage_inputs_length = len(passage_inputs['input_ids'])
1090
+
1091
+ if passage_inputs_length <= max_passage_inputs_length:
1092
+ qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
1093
+ res_merge_inputs.append(qp_merge_inputs)
1094
+ res_merge_inputs_pids.append(pid)
1095
+ else:
1096
+ start_id = 0
1097
+ while start_id < passage_inputs_length:
1098
+ end_id = start_id + max_passage_inputs_length
1099
+ # make sure the length of the last chunk is `max_passage_inputs_length`
1100
+ if end_id >= passage_inputs_length:
1101
+ sub_passage_inputs = {
1102
+ k: v[-max_passage_inputs_length:]
1103
+ for k, v in passage_inputs.items()
1104
+ }
1105
+ else:
1106
+ sub_passage_inputs = {
1107
+ k: v[start_id:end_id] for k, v in passage_inputs.items()
1108
+ }
1109
+ start_id = (
1110
+ end_id - overlap_tokens_implt
1111
+ if end_id < passage_inputs_length
1112
+ else end_id
1113
+ )
1114
+
1115
+ qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
1116
+ res_merge_inputs.append(qp_merge_inputs)
1117
+ res_merge_inputs_pids.append(pid)
1118
+
1119
+ return res_merge_inputs, res_merge_inputs_pids
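A usage sketch for the classes above, assuming the standard transformers remote-code loading path; the checkpoint names are placeholders, not real model IDs:

    from transformers import AutoModel, AutoModelForSequenceClassification

    # Embedding model: XLMRobertaModel.encode() tokenizes, batches, pools and optionally normalizes.
    model = AutoModel.from_pretrained("org/xlm-roberta-flash-embedding", trust_remote_code=True)
    embeddings = model.encode(["first sentence", "second sentence"], normalize_embeddings=True)

    # Reranker: XLMRobertaForSequenceClassification exposes compute_score() and rerank().
    reranker = AutoModelForSequenceClassification.from_pretrained(
        "org/xlm-roberta-flash-reranker", trust_remote_code=True
    )
    scores = reranker.compute_score([("a query", "a candidate passage")])
    ranked = reranker.rerank("a query", ["passage one", "passage two"], top_n=1)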
stochastic_depth.py ADDED
@@ -0,0 +1,97 @@
1
+ # Implementation modified from torchvision:
2
+ # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
3
+ #
4
+ # License:
5
+ # BSD 3-Clause License
6
+ #
7
+ # Copyright (c) Soumith Chintala 2016,
8
+ # All rights reserved.
9
+ #
10
+ # Redistribution and use in source and binary forms, with or without
11
+ # modification, are permitted provided that the following conditions are met:
12
+ #
13
+ # * Redistributions of source code must retain the above copyright notice, this
14
+ # list of conditions and the following disclaimer.
15
+ #
16
+ # * Redistributions in binary form must reproduce the above copyright notice,
17
+ # this list of conditions and the following disclaimer in the documentation
18
+ # and/or other materials provided with the distribution.
19
+ #
20
+ # * Neither the name of the copyright holder nor the names of its
21
+ # contributors may be used to endorse or promote products derived from
22
+ # this software without specific prior written permission.
23
+ #
24
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
+
35
+ import torch
36
+ import torch.fx
37
+ from torch import nn, Tensor
38
+
39
+
40
+ def stochastic_depth(
41
+ input: Tensor, p: float, mode: str, training: bool = True
42
+ ) -> Tensor:
43
+ """
44
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
45
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
46
+ branches of residual architectures.
47
+
48
+ Args:
49
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
50
+ being its batch i.e. a batch with ``N`` rows.
51
+ p (float): probability of the input to be zeroed.
52
+ mode (str): ``"batch"`` or ``"row"``.
53
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
54
+ randomly selected rows from the batch.
55
+ training: apply stochastic depth if is ``True``. Default: ``True``
56
+
57
+ Returns:
58
+ Tensor[N, ...]: The randomly zeroed tensor.
59
+ """
60
+ if p < 0.0 or p > 1.0:
61
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
62
+ if mode not in ["batch", "row"]:
63
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
64
+ if not training or p == 0.0:
65
+ return input
66
+
67
+ survival_rate = 1.0 - p
68
+ if mode == "row":
69
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
70
+ else:
71
+ size = [1] * input.ndim
72
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
73
+ noise = noise.bernoulli_(survival_rate)
74
+ if survival_rate > 0.0:
75
+ noise.div_(survival_rate)
76
+ return input * noise
77
+
78
+
79
+ torch.fx.wrap("stochastic_depth")
80
+
81
+
82
+ class StochasticDepth(nn.Module):
83
+ """
84
+ See :func:`stochastic_depth`.
85
+ """
86
+
87
+ def __init__(self, p: float, mode: str) -> None:
88
+ super().__init__()
89
+ self.p = p
90
+ self.mode = mode
91
+
92
+ def forward(self, input: Tensor) -> Tensor:
93
+ return stochastic_depth(input, self.p, self.mode, self.training)
94
+
95
+ def __repr__(self) -> str:
96
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
97
+ return s
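A short usage sketch for the module above (per the docstring, "row" mode drops the branch independently per sample); the tensors here are illustrative:

    import torch

    drop_path = StochasticDepth(p=0.1, mode="row")
    x = torch.randn(8, 16)
    branch = torch.randn(8, 16)
    # In training mode, roughly 10% of rows get a zeroed branch and the survivors are rescaled by 1/0.9.
    out = x + drop_path(branch)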
xlm_padding.py ADDED
@@ -0,0 +1,218 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
22
+ ).reshape(-1, *other_shape)
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ (indices,) = ctx.saved_tensors
27
+ assert grad_output.ndim >= 2
28
+ other_shape = grad_output.shape[1:]
29
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
30
+ grad_input = torch.zeros(
31
+ [ctx.first_axis_dim, grad_output.shape[1]],
32
+ device=grad_output.device,
33
+ dtype=grad_output.dtype,
34
+ )
35
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
+ # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
38
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
+
40
+
41
+ index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
+ class IndexPutFirstAxis(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, values, indices, first_axis_dim):
47
+ ctx.save_for_backward(indices)
48
+ assert indices.ndim == 1
49
+ assert values.ndim >= 2
50
+ output = torch.zeros(
51
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
93
+ indices = indices.expand_as(grad_output)
94
+ grad_input.scatter_add_(0, indices, grad_output)
95
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
96
+
97
+
98
+ index_first_axis_residual = IndexFirstAxisResidual.apply
99
+
100
+
101
+ def unpad_input(hidden_states, attention_mask):
102
+ """
103
+ Arguments:
104
+ hidden_states: (batch, seqlen, ...)
105
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
106
+ Return:
107
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
108
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
109
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
110
+ max_seqlen_in_batch: int
111
+ """
112
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
116
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
+ # bool mask, then call nonzero to get the indices, then index with those. The indices tensor is @dim
118
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
119
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
120
+ # so we write custom forward and backward to make it a bit faster.
121
+ return (
122
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
123
+ indices,
124
+ cu_seqlens,
125
+ max_seqlen_in_batch,
126
+ )
127
+
128
+
129
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
130
+ """
131
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
132
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
133
+
134
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
135
+ ```
136
+ [
137
+ [2, 3, 0, 0, 0, 0],
138
+ [3, 2, 0, 0, 0, 0],
139
+ [6, 0, 0, 0, 0, 0]
140
+ ]
141
+ ```
142
+ , which refers to the 3D-attention mask:
143
+ ```
144
+ [
145
+ [
146
+ [1, 0, 0, 0, 0, 0],
147
+ [1, 1, 0, 0, 0, 0],
148
+ [0, 0, 1, 0, 0, 0],
149
+ [0, 0, 1, 1, 0, 0],
150
+ [0, 0, 1, 1, 1, 0],
151
+ [0, 0, 0, 0, 0, 1]
152
+ ],
153
+ [
154
+ [1, 0, 0, 0, 0, 0],
155
+ [1, 1, 0, 0, 0, 0],
156
+ [1, 1, 1, 0, 0, 0],
157
+ [0, 0, 0, 1, 0, 0],
158
+ [0, 0, 0, 1, 1, 0],
159
+ [0, 0, 0, 0, 0, 1]
160
+ ],
161
+ [
162
+ [1, 0, 0, 0, 0, 0],
163
+ [1, 1, 0, 0, 0, 0],
164
+ [1, 1, 1, 0, 0, 0],
165
+ [1, 1, 1, 1, 0, 0],
166
+ [1, 1, 1, 1, 1, 0],
167
+ [1, 1, 1, 1, 1, 1]
168
+ ]
169
+ ]
170
+ ```.
171
+
172
+ Arguments:
173
+ hidden_states: (batch, seqlen, ...)
174
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
175
+ Return:
176
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
177
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
178
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
179
+ max_seqlen_in_batch: int
180
+ """
181
+ length = attention_mask_in_length.sum(dim=-1)
182
+ seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
184
+ seqlen) < length.unsqueeze(
185
+ 1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The indices tensor is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
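A round-trip sketch for unpad_input and pad_input above; the shapes are illustrative:

    import torch

    batch, seqlen, dim = 2, 4, 8
    hidden = torch.randn(batch, seqlen, dim)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])   # 1 = real token, 0 = padding

    unpadded, indices, cu_seqlens, max_len = unpad_input(hidden, mask)
    # unpadded: (5, dim), only the real tokens
    # cu_seqlens: tensor([0, 3, 5], dtype=torch.int32); max_len: 3

    restored = pad_input(unpadded, indices, batch, seqlen)   # (2, 4, 8); padded slots come back as zeros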