Yuanfei committed on
Commit 6b710a1 (verified)
1 Parent(s): ca6b592

Update modeling_gplm.py

Files changed (1)
  1. modeling_gplm.py +1145 -1161
modeling_gplm.py CHANGED
@@ -1,1161 +1,1145 @@
1
- #!/usr/bin/env python
2
- # encoding: utf-8
3
- '''
4
- @license: (C) Copyright 2021, Hey.
5
- @author: Hey
6
- @email: [email protected]
7
- @tel: 137****6540
8
- @datetime: 2023/7/24 10:01
9
- @project: LucaOne
10
- @file: modeling_gplm
11
- @desc: LucaOne Model Detail
12
- '''
13
- import math
14
- from typing import Dict, Optional, Sequence, Tuple, List, Union
15
- import uuid
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import Tensor, nn
19
- from torch.nn import Parameter
20
-
21
-
22
- def gelu(x):
23
- return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
24
-
25
-
26
- def symmetrize(x):
27
- return x + x.transpose(-1, -2)
28
-
29
-
30
- def apc(x):
31
- a1 = x.sum(-1, keepdims=True)
32
- a2 = x.sum(-2, keepdims=True)
33
- a12 = x.sum((-1, -2), keepdims=True)
34
-
35
- avg = a1 * a2
36
- avg.div_(a12) # in-place to reduce memory
37
- normalized = x - avg
38
- return normalized
39
-
40
-
41
- class LucaGPLM1LayerNorm(nn.Module):
42
- def __init__(self, hidden_size, eps=1e-12, affine=True):
43
- """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
44
- super().__init__()
45
- self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
46
- self.eps = eps
47
- self.affine = bool(affine)
48
- if self.affine:
49
- self.weight = nn.Parameter(torch.ones(hidden_size))
50
- self.bias = nn.Parameter(torch.zeros(hidden_size))
51
- else:
52
- self.weight, self.bias = None, None
53
-
54
- def forward(self, x):
55
- dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
56
- means = x.mean(dims, keepdim=True)
57
- x_zeromean = x - means
58
- variances = x_zeromean.pow(2).mean(dims, keepdim=True)
59
- x = x_zeromean / torch.sqrt(variances + self.eps)
60
- if self.affine:
61
- x = (self.weight * x) + self.bias
62
- return x
63
-
64
-
65
- try:
66
- # Optimized LayerNorm
67
- from apex.normalization import FusedLayerNorm as _FusedLayerNorm
68
- class LucaGPLM1bLayerNorm(_FusedLayerNorm):
69
- @torch.jit.unused
70
- def forward(self, x):
71
- if not x.is_cuda:
72
- return super().forward(x)
73
- else:
74
- with torch.cuda.device(x.device):
75
- return super().forward(x)
76
-
77
- except ImportError as e:
78
- print("import apex err:", e)
79
- from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
80
-
81
-
82
- class LucaGPLMTransformerLayer(nn.Module):
83
- """LucaGPLM Transformer layer block."""
84
-
85
- def __init__(
86
- self,
87
- embed_dim,
88
- ffn_embed_dim,
89
- attention_heads,
90
- add_bias_kv=True,
91
- use_lucagplm1b_layer_norm=False,
92
- use_rotary_embeddings: bool = False,
93
- ):
94
- '''
95
- Transformer-Encoder
96
- :param embed_dim: token embedding dim
97
- :param ffn_embed_dim: fully connected layer dim
98
- :param attention_heads: heads num
99
- :param add_bias_kv: key-value layer add bias
100
- :param use_lucagplm1b_layer_norm: whether to use lucagplm 1b layer norm
101
- :param use_rotary_embeddings: whether to use rotary embedding
102
- '''
103
- super().__init__()
104
- self.embed_dim = embed_dim
105
- self.ffn_embed_dim = ffn_embed_dim
106
- self.attention_heads = attention_heads
107
- self.use_rotary_embeddings = use_rotary_embeddings
108
- self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
109
-
110
- def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
111
- LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
112
-
113
- # pre layer norm
114
- self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
115
-
116
- self.self_attn = LucaGPLMMultiheadAttention(
117
- self.embed_dim,
118
- self.attention_heads,
119
- add_bias_kv=add_bias_kv,
120
- add_zero_attn=False,
121
- use_rotary_embeddings=self.use_rotary_embeddings,
122
- )
123
-
124
- # post layer norm
125
- self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
126
-
127
- # dimension increase by the fully connected layer
128
- self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
129
-
130
- # dimension reduction by the fully connected layer
131
- self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
132
-
133
- def forward(
134
- self,
135
- x,
136
- self_attn_mask=None,
137
- self_attn_padding_mask=None,
138
- need_head_weights=False
139
- ):
140
- residual = x
141
- x = self.pre_layer_norm(x)
142
- x, attn = self.self_attn(
143
- query=x,
144
- key=x,
145
- value=x,
146
- key_padding_mask=self_attn_padding_mask,
147
- need_weights=True,
148
- need_head_weights=need_head_weights,
149
- attn_mask=self_attn_mask,
150
- )
151
- x = residual + x
152
-
153
- residual = x
154
- x = self.post_layer_norm(x)
155
- x = gelu(self.fc1(x))
156
- x = self.fc2(x)
157
- x = residual + x
158
-
159
- return x, attn
160
-
161
-
162
- class AxialTransformerLayer(nn.Module):
163
- def __init__(
164
- self,
165
- embedding_dim: int = 768,
166
- ffn_embedding_dim: int = 3072,
167
- num_attention_heads: int = 8,
168
- dropout: float = 0.1,
169
- attention_dropout: float = 0.1,
170
- activation_dropout: float = 0.1,
171
- max_tokens_per_msa: int = 2**14,
172
- ) -> None:
173
- super().__init__()
174
-
175
- # Initialize parameters
176
- self.embedding_dim = embedding_dim
177
- self.dropout_prob = dropout
178
-
179
- row_self_attention = RowSelfAttention(
180
- embedding_dim,
181
- num_attention_heads,
182
- dropout=dropout,
183
- max_tokens_per_msa=max_tokens_per_msa,
184
- )
185
-
186
- column_self_attention = ColumnSelfAttention(
187
- embedding_dim,
188
- num_attention_heads,
189
- dropout=dropout,
190
- max_tokens_per_msa=max_tokens_per_msa,
191
- )
192
-
193
- feed_forward_layer = FeedForwardNetwork(
194
- embedding_dim,
195
- ffn_embedding_dim,
196
- activation_dropout=activation_dropout,
197
- max_tokens_per_msa=max_tokens_per_msa,
198
- )
199
-
200
- self.row_self_attention = self.build_residual(row_self_attention)
201
- self.column_self_attention = self.build_residual(column_self_attention)
202
- self.feed_forward_layer = self.build_residual(feed_forward_layer)
203
-
204
- def build_residual(self, layer: nn.Module):
205
- return NormalizedResidualBlock(
206
- layer,
207
- self.embedding_dim,
208
- self.dropout_prob,
209
- )
210
-
211
- def forward(
212
- self,
213
- x: torch.Tensor,
214
- self_attn_mask: Optional[torch.Tensor] = None,
215
- self_attn_padding_mask: Optional[torch.Tensor] = None,
216
- need_head_weights: bool = False,
217
- ):
218
- x, row_attn = self.row_self_attention(
219
- x,
220
- self_attn_mask=self_attn_mask,
221
- self_attn_padding_mask=self_attn_padding_mask,
222
- )
223
- x, column_attn = self.column_self_attention(
224
- x,
225
- self_attn_mask=self_attn_mask,
226
- self_attn_padding_mask=self_attn_padding_mask,
227
- )
228
- x = self.feed_forward_layer(x)
229
- if need_head_weights:
230
- return x, column_attn, row_attn
231
- else:
232
- return x
233
-
234
-
235
- class LearnedPositionalEmbedding(nn.Embedding):
236
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
237
- if padding_idx is not None:
238
- num_embeddings_ = num_embeddings + padding_idx + 1
239
- else:
240
- num_embeddings_ = num_embeddings
241
- super().__init__(num_embeddings_, embedding_dim, padding_idx)
242
- self.max_positions = num_embeddings
243
-
244
- def forward(self, input: torch.Tensor):
245
- """Input is expected to be of size [bsz x seqlen]."""
246
- if input.size(1) > self.max_positions:
247
- raise ValueError(
248
- f"Sequence length {input.size(1)} above maximum "
249
- f" sequence length of {self.max_positions}"
250
- )
251
- mask = input.ne(self.padding_idx).int()
252
- positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
253
- return F.embedding(
254
- positions,
255
- self.weight,
256
- self.padding_idx,
257
- self.max_norm,
258
- self.norm_type,
259
- self.scale_grad_by_freq,
260
- self.sparse,
261
- )
262
-
263
-
264
- class SinusoidalPositionalEmbedding(nn.Module):
265
- def __init__(self, embed_dim, padding_idx, learned=False):
266
- super().__init__()
267
- self.embed_dim = embed_dim
268
- self.padding_idx = padding_idx
269
- self.register_buffer("_float_tensor", torch.FloatTensor(1))
270
- self.weights = None
271
-
272
- def forward(self, x):
273
- bsz, seq_len = x.shape
274
- max_pos = self.padding_idx + 1 + seq_len
275
- if self.weights is None or max_pos > self.weights.size(0):
276
- self.weights = self.get_embedding(max_pos)
277
- self.weights = self.weights.type_as(self._float_tensor)
278
-
279
- positions = self.make_positions(x)
280
- return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
281
-
282
- def make_positions(self, x):
283
- mask = x.ne(self.padding_idx)
284
- range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
285
- positions = range_buf.expand_as(x)
286
- return positions * mask.long() + self.padding_idx * (1 - mask.long())
287
-
288
- def get_embedding(self, num_embeddings):
289
- half_dim = self.embed_dim // 2
290
- emb = math.log(10000) / (half_dim - 1)
291
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
292
- emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
293
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
294
- if self.embed_dim % 2 == 1:
295
- # zero pad
296
- emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
297
- if self.padding_idx is not None:
298
- emb[self.padding_idx, :] = 0
299
- return emb
300
-
301
-
302
- class RobertaLMHead(nn.Module):
303
- def __init__(self, embed_dim, output_dim, weight):
304
- super().__init__()
305
- self.dense = nn.Linear(embed_dim, embed_dim)
306
- self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
307
- self.weight = weight
308
- self.bias = nn.Parameter(torch.zeros(output_dim))
309
-
310
- def forward(self, features):
311
- x = self.dense(features)
312
- x = gelu(x)
313
- x = self.layer_norm(x)
314
- # project back to size of vocabulary with bias
315
- x = F.linear(x, self.weight) + self.bias
316
- return x
317
-
318
-
319
- class ContactPredictionHead(nn.Module):
320
- def __init__(
321
- self,
322
- in_features: int,
323
- prepend_bos: bool,
324
- append_eos: bool,
325
- bias=True,
326
- eos_idx: Optional[int] = None,
327
- ):
328
- super().__init__()
329
- self.in_features = in_features
330
- self.prepend_bos = prepend_bos
331
- self.append_eos = append_eos
332
- if append_eos and eos_idx is None:
333
- raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
334
- self.eos_idx = eos_idx
335
- self.regression = nn.Linear(in_features, 1, bias)
336
- self.activation = nn.Sigmoid()
337
-
338
- def forward(self, tokens, attentions):
339
- # remove eos token attentions
340
- if self.append_eos:
341
- eos_mask = tokens.ne(self.eos_idx).to(attentions)
342
- eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
343
- attentions = attentions * eos_mask[:, None, None, :, :]
344
- attentions = attentions[..., :-1, :-1]
345
- # remove cls token attentions
346
- if self.prepend_bos:
347
- attentions = attentions[..., 1:, 1:]
348
- batch_size, layers, heads, seqlen, _ = attentions.size()
349
- attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
350
-
351
- # features: B x C x T x T
352
- attentions = attentions.to(
353
- self.regression.weight.device
354
- ) # attentions always float32, may need to convert to float16
355
- attentions = apc(symmetrize(attentions))
356
- attentions = attentions.permute(0, 2, 3, 1)
357
- return self.activation(self.regression(attentions).squeeze(3))
358
-
359
-
360
- class NormalizedResidualBlock(nn.Module):
361
- def __init__(
362
- self,
363
- layer: nn.Module,
364
- embedding_dim: int,
365
- dropout: float = 0.1,
366
- ):
367
- super().__init__()
368
- self.embedding_dim = embedding_dim
369
-
370
- self.layer = layer
371
- self.dropout_module = nn.Dropout(
372
- dropout,
373
- )
374
- self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
375
-
376
- def forward(self, x, *args, **kwargs):
377
- residual = x
378
- x = self.layer_norm(x)
379
- outputs = self.layer(x, *args, **kwargs)
380
- if isinstance(outputs, tuple):
381
- x, *out = outputs
382
- else:
383
- x = outputs
384
- out = None
385
-
386
- x = self.dropout_module(x)
387
- x = residual + x
388
-
389
- if out is not None:
390
- return (x,) + tuple(out)
391
- else:
392
- return x
393
-
394
-
395
- class FeedForwardNetwork(nn.Module):
396
- def __init__(
397
- self,
398
- embedding_dim: int,
399
- ffn_embedding_dim: int,
400
- activation_dropout: float = 0.1,
401
- max_tokens_per_msa: int = 2**14,
402
- ):
403
- super().__init__()
404
- self.embedding_dim = embedding_dim
405
- self.ffn_embedding_dim = ffn_embedding_dim
406
- self.max_tokens_per_msa = max_tokens_per_msa
407
- self.activation_fn = nn.GELU()
408
- self.activation_dropout_module = nn.Dropout(
409
- activation_dropout,
410
- )
411
- self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
412
- self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
413
-
414
- def forward(self, x):
415
- x = self.activation_fn(self.fc1(x))
416
- x = self.activation_dropout_module(x)
417
- x = self.fc2(x)
418
- return x
419
-
420
-
421
- class RowSelfAttention(nn.Module):
422
- """Compute self-attention over rows of a 2D input."""
423
-
424
- def __init__(
425
- self,
426
- embed_dim,
427
- num_heads,
428
- dropout=0.0,
429
- max_tokens_per_msa: int = 2 ** 16,
430
- ):
431
- super().__init__()
432
- self.num_heads = num_heads
433
- self.dropout = dropout
434
- self.head_dim = embed_dim // num_heads
435
- self.scaling = self.head_dim ** -0.5
436
- self.max_tokens_per_msa = max_tokens_per_msa
437
- self.attn_shape = "hnij"
438
-
439
- self.k_proj = nn.Linear(embed_dim, embed_dim)
440
- self.v_proj = nn.Linear(embed_dim, embed_dim)
441
- self.q_proj = nn.Linear(embed_dim, embed_dim)
442
-
443
- self.out_proj = nn.Linear(embed_dim, embed_dim)
444
- self.dropout_module = nn.Dropout(dropout)
445
-
446
- def align_scaling(self, q):
447
- num_rows = q.size(0)
448
- return self.scaling / math.sqrt(num_rows)
449
-
450
- def _batched_forward(
451
- self,
452
- x,
453
- self_attn_mask=None,
454
- self_attn_padding_mask=None,
455
- ):
456
- num_rows, num_cols, batch_size, embed_dim = x.size()
457
- max_rows = max(1, self.max_tokens_per_msa // num_cols)
458
- attns = 0
459
- scaling = self.align_scaling(x)
460
- for start in range(0, num_rows, max_rows):
461
- attn_weights = self.compute_attention_weights(
462
- x[start : start + max_rows],
463
- scaling,
464
- self_attn_mask=self_attn_mask,
465
- self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
466
- if self_attn_padding_mask is not None
467
- else None,
468
- )
469
- attns += attn_weights
470
- attn_probs = attns.softmax(-1)
471
- attn_probs = self.dropout_module(attn_probs)
472
-
473
- outputs = []
474
- for start in range(0, num_rows, max_rows):
475
- output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
476
- outputs.append(output)
477
-
478
- output = torch.cat(outputs, 0)
479
- return output, attn_probs
480
-
481
- def compute_attention_weights(
482
- self,
483
- x,
484
- scaling: float,
485
- self_attn_mask=None,
486
- self_attn_padding_mask=None,
487
- ):
488
- num_rows, num_cols, batch_size, embed_dim = x.size()
489
- q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
490
- k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
491
- q *= scaling
492
- if self_attn_padding_mask is not None:
493
- # Zero out any padded aligned positions - this is important since
494
- # we take a sum across the alignment axis.
495
- q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
496
-
497
- attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
498
-
499
- if self_attn_mask is not None:
500
- raise NotImplementedError
501
- # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
502
-
503
- if self_attn_padding_mask is not None:
504
- attn_weights = attn_weights.masked_fill(
505
- self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
506
- -10000,
507
- )
508
-
509
- return attn_weights
510
-
511
- def compute_attention_update(
512
- self,
513
- x,
514
- attn_probs,
515
- ):
516
- num_rows, num_cols, batch_size, embed_dim = x.size()
517
- v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
518
- context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
519
- context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
520
- output = self.out_proj(context)
521
- return output
522
-
523
- def forward(
524
- self,
525
- x,
526
- self_attn_mask=None,
527
- self_attn_padding_mask=None,
528
- ):
529
- num_rows, num_cols, batch_size, embed_dim = x.size()
530
- if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
531
- return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
532
- else:
533
- scaling = self.align_scaling(x)
534
- attn_weights = self.compute_attention_weights(
535
- x, scaling, self_attn_mask, self_attn_padding_mask
536
- )
537
- attn_probs = attn_weights.softmax(-1)
538
- attn_probs = self.dropout_module(attn_probs)
539
- output = self.compute_attention_update(x, attn_probs)
540
- return output, attn_probs
541
-
542
-
543
- class ColumnSelfAttention(nn.Module):
544
- """Compute self-attention over columns of a 2D input."""
545
-
546
- def __init__(
547
- self,
548
- embed_dim,
549
- num_heads,
550
- dropout=0.0,
551
- max_tokens_per_msa: int = 2 ** 16,
552
- ):
553
- super().__init__()
554
-
555
- self.num_heads = num_heads
556
- self.dropout = dropout
557
- self.head_dim = embed_dim // num_heads
558
- self.scaling = self.head_dim ** -0.5
559
- self.max_tokens_per_msa = max_tokens_per_msa
560
-
561
- self.k_proj = nn.Linear(embed_dim, embed_dim)
562
- self.v_proj = nn.Linear(embed_dim, embed_dim)
563
- self.q_proj = nn.Linear(embed_dim, embed_dim)
564
-
565
- self.out_proj = nn.Linear(embed_dim, embed_dim)
566
- self.dropout_module = nn.Dropout(dropout)
567
-
568
- def _batched_forward(
569
- self,
570
- x,
571
- self_attn_mask=None,
572
- self_attn_padding_mask=None,
573
- ):
574
- num_rows, num_cols, batch_size, embed_dim = x.size()
575
- max_cols = max(1, self.max_tokens_per_msa // num_rows)
576
- outputs = []
577
- attns = []
578
- for start in range(0, num_cols, max_cols):
579
- output, attn = self(
580
- x[:, start : start + max_cols],
581
- self_attn_mask=self_attn_mask,
582
- self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
583
- if self_attn_padding_mask is not None
584
- else None,
585
- )
586
- outputs.append(output)
587
- attns.append(attn)
588
- output = torch.cat(outputs, 1)
589
- attns = torch.cat(attns, 1)
590
- return output, attns
591
-
592
- def compute_attention_update(
593
- self,
594
- x,
595
- self_attn_mask=None,
596
- self_attn_padding_mask=None,
597
- ):
598
- num_rows, num_cols, batch_size, embed_dim = x.size()
599
- if num_rows == 1:
600
- # if there is only 1 position, this is equivalent and doesn't break with padding
601
- attn_probs = torch.ones(
602
- self.num_heads,
603
- num_cols,
604
- batch_size,
605
- num_rows,
606
- num_rows,
607
- device=x.device,
608
- dtype=x.dtype,
609
- )
610
- output = self.out_proj(self.v_proj(x))
611
- else:
612
- q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
613
- k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
614
- v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
615
- q *= self.scaling
616
-
617
- attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
618
-
619
- if self_attn_mask is not None:
620
- raise NotImplementedError
621
- if self_attn_padding_mask is not None:
622
- attn_weights = attn_weights.masked_fill(
623
- self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
624
- -10000,
625
- )
626
-
627
- attn_probs = attn_weights.softmax(-1)
628
- attn_probs = self.dropout_module(attn_probs)
629
- context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
630
- context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
631
- output = self.out_proj(context)
632
- return output, attn_probs
633
-
634
- def forward(
635
- self,
636
- x,
637
- self_attn_mask=None,
638
- self_attn_padding_mask=None,
639
- ):
640
- num_rows, num_cols, batch_size, embed_dim = x.size()
641
- # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
642
- if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
643
- return self._batched_forward(
644
- x,
645
- self_attn_mask,
646
- self_attn_padding_mask,
647
- )
648
- else:
649
- return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
650
-
651
-
652
- def utils_softmax(x, dim: int, onnx_trace: bool = False):
653
- if onnx_trace:
654
- return F.softmax(x.float(), dim=dim)
655
- else:
656
- return F.softmax(x, dim=dim, dtype=torch.float32)
657
-
658
-
659
- class FairseqIncrementalState(object):
660
- def __init__(self, *args, **kwargs):
661
- super().__init__(*args, **kwargs)
662
- self.init_incremental_state()
663
-
664
- def init_incremental_state(self):
665
- self._incremental_state_id = str(uuid.uuid4())
666
-
667
- def _get_full_incremental_state_key(self, key: str) -> str:
668
- return "{}.{}".format(self._incremental_state_id, key)
669
-
670
- def get_incremental_state(
671
- self,
672
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
673
- key: str,
674
- ) -> Optional[Dict[str, Optional[Tensor]]]:
675
- """Helper for getting incremental state for an nn.Module."""
676
- full_key = self._get_full_incremental_state_key(key)
677
- if incremental_state is None or full_key not in incremental_state:
678
- return None
679
- return incremental_state[full_key]
680
-
681
- def set_incremental_state(
682
- self,
683
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
684
- key: str,
685
- value: Dict[str, Optional[Tensor]],
686
- ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
687
- """Helper for setting incremental state for an nn.Module."""
688
- if incremental_state is not None:
689
- full_key = self._get_full_incremental_state_key(key)
690
- incremental_state[full_key] = value
691
- return incremental_state
692
-
693
-
694
- def with_incremental_state(cls):
695
- cls.__bases__ = (FairseqIncrementalState,) + tuple(
696
- b for b in cls.__bases__ if b != FairseqIncrementalState
697
- )
698
- return cls
699
-
700
-
701
- @with_incremental_state
702
- class LucaGPLMMultiheadAttention(nn.Module):
703
- def __init__(
704
- self,
705
- embed_dim,
706
- num_heads,
707
- kdim=None,
708
- vdim=None,
709
- dropout=0.0,
710
- bias=True,
711
- add_bias_kv: bool = False,
712
- add_zero_attn: bool = False,
713
- self_attention: bool = False,
714
- encoder_decoder_attention: bool = False,
715
- use_rotary_embeddings: bool = False,
716
- ):
717
- super().__init__()
718
- self.embed_dim = embed_dim
719
- self.kdim = kdim if kdim is not None else embed_dim
720
- self.vdim = vdim if vdim is not None else embed_dim
721
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
722
-
723
- self.num_heads = num_heads
724
- self.dropout = dropout
725
- self.head_dim = embed_dim // num_heads
726
- assert (
727
- self.head_dim * num_heads == self.embed_dim
728
- ), "embed_dim must be divisible by num_heads"
729
- self.scaling = self.head_dim**-0.5
730
-
731
- self.self_attention = self_attention
732
- self.encoder_decoder_attention = encoder_decoder_attention
733
-
734
- assert not self.self_attention or self.qkv_same_dim, (
735
- "Self-attention requires query, key and " "value to be of the same size"
736
- )
737
-
738
- self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
739
- self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
740
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
741
-
742
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
743
-
744
- if add_bias_kv:
745
- self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
746
- self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
747
- else:
748
- self.bias_k = self.bias_v = None
749
-
750
- self.add_zero_attn = add_zero_attn
751
-
752
- self.reset_parameters()
753
-
754
- self.onnx_trace = False
755
- self.rot_emb = None
756
- if use_rotary_embeddings:
757
- self.rot_emb = RotaryEmbedding(dim=self.head_dim)
758
-
759
- self.enable_torch_version = False
760
- if hasattr(F, "multi_head_attention_forward"):
761
- self.enable_torch_version = True
762
- else:
763
- self.enable_torch_version = False
764
-
765
- def prepare_for_onnx_export_(self):
766
- self.onnx_trace = True
767
-
768
- def reset_parameters(self):
769
- nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
770
- nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
771
- nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
772
-
773
- nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
774
- # nn.init.xavier_uniform_(self.out_proj.weight)
775
- if self.out_proj.bias is not None:
776
- nn.init.constant_(self.out_proj.bias, 0.0)
777
- if self.bias_k is not None:
778
- nn.init.xavier_normal_(self.bias_k)
779
- if self.bias_v is not None:
780
- nn.init.xavier_normal_(self.bias_v)
781
-
782
- def forward(
783
- self,
784
- query,
785
- key: Optional[Tensor],
786
- value: Optional[Tensor],
787
- key_padding_mask: Optional[Tensor] = None,
788
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
789
- need_weights: bool = True,
790
- static_kv: bool = False,
791
- attn_mask: Optional[Tensor] = None,
792
- before_softmax: bool = False,
793
- need_head_weights: bool = False,
794
- ) -> Tuple[Tensor, Optional[Tensor]]:
795
- if need_head_weights:
796
- need_weights = True
797
-
798
- tgt_len, bsz, embed_dim = query.size()
799
- assert embed_dim == self.embed_dim
800
- assert list(query.size()) == [tgt_len, bsz, embed_dim]
801
-
802
- if (
803
- not self.rot_emb
804
- and self.enable_torch_version
805
- and not self.onnx_trace
806
- and incremental_state is None
807
- and not static_kv
808
- # A workaround for quantization to work. Otherwise JIT compilation
809
- # treats bias in linear module as method.
810
- and not torch.jit.is_scripting()
811
- and not need_head_weights
812
- ):
813
- assert key is not None and value is not None
814
- return F.multi_head_attention_forward(
815
- query,
816
- key,
817
- value,
818
- self.embed_dim,
819
- self.num_heads,
820
- torch.empty([0]),
821
- torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
822
- self.bias_k,
823
- self.bias_v,
824
- self.add_zero_attn,
825
- self.dropout,
826
- self.out_proj.weight,
827
- self.out_proj.bias,
828
- self.training,
829
- key_padding_mask,
830
- need_weights,
831
- attn_mask,
832
- use_separate_proj_weight=True,
833
- q_proj_weight=self.q_proj.weight,
834
- k_proj_weight=self.k_proj.weight,
835
- v_proj_weight=self.v_proj.weight,
836
- )
837
- if incremental_state is not None:
838
- saved_state = self._get_input_buffer(incremental_state)
839
- if saved_state is not None and "prev_key" in saved_state:
840
- # previous time steps are cached - no need to recompute
841
- # key and value if they are static
842
- if static_kv:
843
- assert self.encoder_decoder_attention and not self.self_attention
844
- key = value = None
845
- else:
846
- saved_state = None
847
-
848
- if self.self_attention:
849
- q = self.q_proj(query)
850
- k = self.k_proj(query)
851
- v = self.v_proj(query)
852
- elif self.encoder_decoder_attention:
853
- # encoder-decoder attention
854
- q = self.q_proj(query)
855
- if key is None:
856
- assert value is None
857
- k = v = None
858
- else:
859
- k = self.k_proj(key)
860
- v = self.v_proj(key)
861
-
862
- else:
863
- assert key is not None and value is not None
864
- q = self.q_proj(query)
865
- k = self.k_proj(key)
866
- v = self.v_proj(value)
867
- q *= self.scaling
868
-
869
- if self.bias_k is not None:
870
- assert self.bias_v is not None
871
- k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
872
- v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
873
- if attn_mask is not None:
874
- attn_mask = torch.cat(
875
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
876
- )
877
- if key_padding_mask is not None:
878
- key_padding_mask = torch.cat(
879
- [
880
- key_padding_mask,
881
- key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
882
- ],
883
- dim=1,
884
- )
885
-
886
- q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
887
- if k is not None:
888
- k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
889
- if v is not None:
890
- v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
891
-
892
- if saved_state is not None:
893
- # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
894
- if "prev_key" in saved_state:
895
- _prev_key = saved_state["prev_key"]
896
- assert _prev_key is not None
897
- prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
898
- if static_kv:
899
- k = prev_key
900
- else:
901
- assert k is not None
902
- k = torch.cat([prev_key, k], dim=1)
903
- if "prev_value" in saved_state:
904
- _prev_value = saved_state["prev_value"]
905
- assert _prev_value is not None
906
- prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
907
- if static_kv:
908
- v = prev_value
909
- else:
910
- assert v is not None
911
- v = torch.cat([prev_value, v], dim=1)
912
- prev_key_padding_mask: Optional[Tensor] = None
913
- if "prev_key_padding_mask" in saved_state:
914
- prev_key_padding_mask = saved_state["prev_key_padding_mask"]
915
- assert k is not None and v is not None
916
- key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
917
- key_padding_mask=key_padding_mask,
918
- prev_key_padding_mask=prev_key_padding_mask,
919
- batch_size=bsz,
920
- src_len=k.size(1),
921
- static_kv=static_kv,
922
- )
923
-
924
- saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
925
- saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
926
- saved_state["prev_key_padding_mask"] = key_padding_mask
927
- # In this branch incremental_state is never None
928
- assert incremental_state is not None
929
- incremental_state = self._set_input_buffer(incremental_state, saved_state)
930
- assert k is not None
931
- src_len = k.size(1)
932
-
933
- # This is part of a workaround to get around fork/join parallelism
934
- # not supporting Optional types.
935
- if key_padding_mask is not None and key_padding_mask.dim() == 0:
936
- key_padding_mask = None
937
-
938
- if key_padding_mask is not None:
939
- assert key_padding_mask.size(0) == bsz
940
- assert key_padding_mask.size(1) == src_len
941
-
942
- if self.add_zero_attn:
943
- assert v is not None
944
- src_len += 1
945
- k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
946
- v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
947
- if attn_mask is not None:
948
- attn_mask = torch.cat(
949
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
950
- )
951
- if key_padding_mask is not None:
952
- key_padding_mask = torch.cat(
953
- [
954
- key_padding_mask,
955
- torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
956
- ],
957
- dim=1,
958
- )
959
-
960
- if self.rot_emb:
961
- q, k = self.rot_emb(q, k)
962
-
963
- attn_weights = torch.bmm(q, k.transpose(1, 2))
964
- attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
965
-
966
- assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
967
-
968
- if attn_mask is not None:
969
- attn_mask = attn_mask.unsqueeze(0)
970
- if self.onnx_trace:
971
- attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
972
- attn_weights += attn_mask
973
-
974
- if key_padding_mask is not None:
975
- # don't attend to padding symbols
976
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
977
- attn_weights = attn_weights.masked_fill(
978
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
979
- )
980
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
981
-
982
- if before_softmax:
983
- return attn_weights, v
984
-
985
- attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
986
- attn_weights = attn_weights_float.type_as(attn_weights)
987
- attn_probs = F.dropout(
988
- attn_weights_float.type_as(attn_weights),
989
- p=self.dropout,
990
- training=self.training,
991
- )
992
- assert v is not None
993
- attn = torch.bmm(attn_probs, v)
994
- assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
995
- if self.onnx_trace and attn.size(1) == 1:
996
- # when ONNX tracing a single decoder step (sequence length == 1)
997
- # the transpose is a no-op copy before view, thus unnecessary
998
- attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
999
- else:
1000
- attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
1001
- attn = self.out_proj(attn)
1002
- attn_weights: Optional[Tensor] = None
1003
- if need_weights:
1004
- attn_weights = attn_weights_float.view(
1005
- bsz, self.num_heads, tgt_len, src_len
1006
- ).type_as(attn).transpose(1, 0)
1007
- if not need_head_weights:
1008
- # average attention weights over heads
1009
- attn_weights = attn_weights.mean(dim=0)
1010
-
1011
- return attn, attn_weights
1012
-
1013
- @staticmethod
1014
- def _append_prev_key_padding_mask(
1015
- key_padding_mask: Optional[Tensor],
1016
- prev_key_padding_mask: Optional[Tensor],
1017
- batch_size: int,
1018
- src_len: int,
1019
- static_kv: bool,
1020
- ) -> Optional[Tensor]:
1021
- # saved key padding masks have shape (bsz, seq_len)
1022
- if prev_key_padding_mask is not None and static_kv:
1023
- new_key_padding_mask = prev_key_padding_mask
1024
- elif prev_key_padding_mask is not None and key_padding_mask is not None:
1025
- new_key_padding_mask = torch.cat(
1026
- [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
1027
- )
1028
- # During incremental decoding, as the padding token enters and
1029
- # leaves the frame, there will be a time when prev or current
1030
- # is None
1031
- elif prev_key_padding_mask is not None:
1032
- filler = torch.zeros(
1033
- (batch_size, src_len - prev_key_padding_mask.size(1)),
1034
- device=prev_key_padding_mask.device,
1035
- )
1036
- new_key_padding_mask = torch.cat(
1037
- [prev_key_padding_mask.float(), filler.float()], dim=1
1038
- )
1039
- elif key_padding_mask is not None:
1040
- filler = torch.zeros(
1041
- (batch_size, src_len - key_padding_mask.size(1)),
1042
- device=key_padding_mask.device,
1043
- )
1044
- new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
1045
- else:
1046
- new_key_padding_mask = prev_key_padding_mask
1047
- return new_key_padding_mask
1048
-
1049
- @torch.jit.export
1050
- def reorder_incremental_state(
1051
- self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1052
- ):
1053
- input_buffer = self._get_input_buffer(incremental_state)
1054
- if input_buffer is not None:
1055
- for k in input_buffer.keys():
1056
- input_buffer_k = input_buffer[k]
1057
- if input_buffer_k is not None:
1058
- if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
1059
- 0
1060
- ):
1061
- break
1062
- input_buffer[k] = input_buffer_k.index_select(0, new_order)
1063
- incremental_state = self._set_input_buffer(incremental_state, input_buffer)
1064
- return incremental_state
1065
-
1066
- def _get_input_buffer(
1067
- self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
1068
- ) -> Dict[str, Optional[Tensor]]:
1069
- result = self.get_incremental_state(incremental_state, "attn_state")
1070
- if result is not None:
1071
- return result
1072
- else:
1073
- empty_result: Dict[str, Optional[Tensor]] = {}
1074
- return empty_result
1075
-
1076
- def _set_input_buffer(
1077
- self,
1078
- incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1079
- buffer: Dict[str, Optional[Tensor]],
1080
- ):
1081
- return self.set_incremental_state(incremental_state, "attn_state", buffer)
1082
-
1083
- def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
1084
- return attn_weights
1085
-
1086
- def upgrade_state_dict_named(self, state_dict, name):
1087
- prefix = name + "." if name != "" else ""
1088
- items_to_add = {}
1089
- keys_to_remove = []
1090
- for k in state_dict.keys():
1091
- if k.endswith(prefix + "in_proj_weight"):
1092
- dim = int(state_dict[k].shape[0] / 3)
1093
- items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1094
- items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
1095
- items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
1096
-
1097
- keys_to_remove.append(k)
1098
-
1099
- k_bias = prefix + "in_proj_bias"
1100
- if k_bias in state_dict.keys():
1101
- dim = int(state_dict[k].shape[0] / 3)
1102
- items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
1103
- items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
1104
- items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
1105
-
1106
- keys_to_remove.append(prefix + "in_proj_bias")
1107
-
1108
- for k in keys_to_remove:
1109
- del state_dict[k]
1110
-
1111
- for key, value in items_to_add.items():
1112
- state_dict[key] = value
1113
-
1114
-
1115
- def rotate_half(x):
1116
- x1, x2 = x.chunk(2, dim=-1)
1117
- return torch.cat((-x2, x1), dim=-1)
1118
-
1119
-
1120
- def apply_rotary_pos_emb(x, cos, sin):
1121
- cos = cos[:, : x.shape[-2], :]
1122
- sin = sin[:, : x.shape[-2], :]
1123
-
1124
- return (x * cos) + (rotate_half(x) * sin)
1125
-
1126
-
1127
- class RotaryEmbedding(torch.nn.Module):
1128
- def __init__(self, dim: int, *_, **__):
1129
- super().__init__()
1130
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1131
- self.register_buffer("inv_freq", inv_freq)
1132
-
1133
- self._seq_len_cached = None
1134
- self._cos_cached = None
1135
- self._sin_cached = None
1136
-
1137
- def _update_cos_sin_tables(self, x, seq_dimension=1):
1138
- seq_len = x.shape[seq_dimension]
1139
-
1140
- if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1141
- self._seq_len_cached = seq_len
1142
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
1143
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1144
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1145
-
1146
- self._cos_cached = emb.cos()[None, :, :]
1147
- self._sin_cached = emb.sin()[None, :, :]
1148
-
1149
- return self._cos_cached, self._sin_cached
1150
-
1151
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1152
- self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
1153
-
1154
- return (
1155
- apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1156
- apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1157
- )
1158
-
1159
-
1160
-
1161
-
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ '''
4
+ @license: (C) Copyright 2021, Hey.
5
+ @author: Hey
6
+ @email: [email protected]
7
+ @tel: 137****6540
8
+ @datetime: 2023/7/24 10:01
9
+ @project: LucaOne
10
+ @file: modeling_gplm
11
+ @desc: LucaOne Model Detail
12
+ '''
13
+ import math
14
+ from typing import Dict, Optional, Sequence, Tuple, List, Union
15
+ import uuid
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import Tensor, nn
19
+ from torch.nn import Parameter
20
+
21
+
22
+ def gelu(x):
23
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
24
+
25
+
26
+ def symmetrize(x):
27
+ return x + x.transpose(-1, -2)
28
+
29
+
30
+ def apc(x):
31
+ a1 = x.sum(-1, keepdims=True)
32
+ a2 = x.sum(-2, keepdims=True)
33
+ a12 = x.sum((-1, -2), keepdims=True)
34
+
35
+ avg = a1 * a2
36
+ avg.div_(a12) # in-place to reduce memory
37
+ normalized = x - avg
38
+ return normalized
39
+
40
+
41
+ class LucaGPLM1LayerNorm(nn.Module):
42
+ def __init__(self, hidden_size, eps=1e-12, affine=True):
43
+ """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
44
+ super().__init__()
45
+ self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
46
+ self.eps = eps
47
+ self.affine = bool(affine)
48
+ if self.affine:
49
+ self.weight = nn.Parameter(torch.ones(hidden_size))
50
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
51
+ else:
52
+ self.weight, self.bias = None, None
53
+
54
+ def forward(self, x):
55
+ dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
56
+ means = x.mean(dims, keepdim=True)
57
+ x_zeromean = x - means
58
+ variances = x_zeromean.pow(2).mean(dims, keepdim=True)
59
+ x = x_zeromean / torch.sqrt(variances + self.eps)
60
+ if self.affine:
61
+ x = (self.weight * x) + self.bias
62
+ return x
63
+
64
+ from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
65
+
66
+ class LucaGPLMTransformerLayer(nn.Module):
67
+ """LucaGPLM Transformer layer block."""
68
+
69
+ def __init__(
70
+ self,
71
+ embed_dim,
72
+ ffn_embed_dim,
73
+ attention_heads,
74
+ add_bias_kv=True,
75
+ use_lucagplm1b_layer_norm=False,
76
+ use_rotary_embeddings: bool = False,
77
+ ):
78
+ '''
79
+ Transformer-Encoder
80
+ :param embed_dim: token embedding dim
81
+ :param ffn_embed_dim: fully connected layer dim
82
+ :param attention_heads: heads num
83
+ :param add_bias_kv: key-value layer add bias
84
+ :param use_lucagplm1b_layer_norm: whether to use lucagplm 1b layer norm
85
+ :param use_rotary_embeddings: whether to use rotary embedding
86
+ '''
87
+ super().__init__()
88
+ self.embed_dim = embed_dim
89
+ self.ffn_embed_dim = ffn_embed_dim
90
+ self.attention_heads = attention_heads
91
+ self.use_rotary_embeddings = use_rotary_embeddings
92
+ self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
93
+
94
+ def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
95
+ LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
96
+
97
+ # pre layer norm
98
+ self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
99
+
100
+ self.self_attn = LucaGPLMMultiheadAttention(
101
+ self.embed_dim,
102
+ self.attention_heads,
103
+ add_bias_kv=add_bias_kv,
104
+ add_zero_attn=False,
105
+ use_rotary_embeddings=self.use_rotary_embeddings,
106
+ )
107
+
108
+ # post layer norm
109
+ self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
110
+
111
+ # dimension increase by the fully connected layer
112
+ self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
113
+
114
+ # dimension reduction by the fully connected layer
115
+ self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
116
+
117
+ def forward(
118
+ self,
119
+ x,
120
+ self_attn_mask=None,
121
+ self_attn_padding_mask=None,
122
+ need_head_weights=False
123
+ ):
124
+ residual = x
125
+ x = self.pre_layer_norm(x)
126
+ x, attn = self.self_attn(
127
+ query=x,
128
+ key=x,
129
+ value=x,
130
+ key_padding_mask=self_attn_padding_mask,
131
+ need_weights=True,
132
+ need_head_weights=need_head_weights,
133
+ attn_mask=self_attn_mask,
134
+ )
135
+ x = residual + x
136
+
137
+ residual = x
138
+ x = self.post_layer_norm(x)
139
+ x = gelu(self.fc1(x))
140
+ x = self.fc2(x)
141
+ x = residual + x
142
+
143
+ return x, attn
144
+
145
+
146
+ class AxialTransformerLayer(nn.Module):
147
+ def __init__(
148
+ self,
149
+ embedding_dim: int = 768,
150
+ ffn_embedding_dim: int = 3072,
151
+ num_attention_heads: int = 8,
152
+ dropout: float = 0.1,
153
+ attention_dropout: float = 0.1,
154
+ activation_dropout: float = 0.1,
155
+ max_tokens_per_msa: int = 2**14,
156
+ ) -> None:
157
+ super().__init__()
158
+
159
+ # Initialize parameters
160
+ self.embedding_dim = embedding_dim
161
+ self.dropout_prob = dropout
162
+
163
+ row_self_attention = RowSelfAttention(
164
+ embedding_dim,
165
+ num_attention_heads,
166
+ dropout=dropout,
167
+ max_tokens_per_msa=max_tokens_per_msa,
168
+ )
169
+
170
+ column_self_attention = ColumnSelfAttention(
171
+ embedding_dim,
172
+ num_attention_heads,
173
+ dropout=dropout,
174
+ max_tokens_per_msa=max_tokens_per_msa,
175
+ )
176
+
177
+ feed_forward_layer = FeedForwardNetwork(
178
+ embedding_dim,
179
+ ffn_embedding_dim,
180
+ activation_dropout=activation_dropout,
181
+ max_tokens_per_msa=max_tokens_per_msa,
182
+ )
183
+
184
+ self.row_self_attention = self.build_residual(row_self_attention)
185
+ self.column_self_attention = self.build_residual(column_self_attention)
186
+ self.feed_forward_layer = self.build_residual(feed_forward_layer)
187
+
188
+ def build_residual(self, layer: nn.Module):
189
+ return NormalizedResidualBlock(
190
+ layer,
191
+ self.embedding_dim,
192
+ self.dropout_prob,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ self_attn_mask: Optional[torch.Tensor] = None,
199
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
200
+ need_head_weights: bool = False,
201
+ ):
202
+ x, row_attn = self.row_self_attention(
203
+ x,
204
+ self_attn_mask=self_attn_mask,
205
+ self_attn_padding_mask=self_attn_padding_mask,
206
+ )
207
+ x, column_attn = self.column_self_attention(
208
+ x,
209
+ self_attn_mask=self_attn_mask,
210
+ self_attn_padding_mask=self_attn_padding_mask,
211
+ )
212
+ x = self.feed_forward_layer(x)
213
+ if need_head_weights:
214
+ return x, column_attn, row_attn
215
+ else:
216
+ return x
217
+
218
+
219
+ class LearnedPositionalEmbedding(nn.Embedding):
220
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
221
+ if padding_idx is not None:
222
+ num_embeddings_ = num_embeddings + padding_idx + 1
223
+ else:
224
+ num_embeddings_ = num_embeddings
225
+ super().__init__(num_embeddings_, embedding_dim, padding_idx)
226
+ self.max_positions = num_embeddings
227
+
228
+ def forward(self, input: torch.Tensor):
229
+ """Input is expected to be of size [bsz x seqlen]."""
230
+ if input.size(1) > self.max_positions:
231
+ raise ValueError(
232
+ f"Sequence length {input.size(1)} above maximum "
233
+ f" sequence length of {self.max_positions}"
234
+ )
235
+ mask = input.ne(self.padding_idx).int()
236
+ positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
237
+ return F.embedding(
238
+ positions,
239
+ self.weight,
240
+ self.padding_idx,
241
+ self.max_norm,
242
+ self.norm_type,
243
+ self.scale_grad_by_freq,
244
+ self.sparse,
245
+ )
246
+
247
+
248
+ class SinusoidalPositionalEmbedding(nn.Module):
249
+ def __init__(self, embed_dim, padding_idx, learned=False):
250
+ super().__init__()
251
+ self.embed_dim = embed_dim
252
+ self.padding_idx = padding_idx
253
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
254
+ self.weights = None
255
+
256
+ def forward(self, x):
257
+ bsz, seq_len = x.shape
258
+ max_pos = self.padding_idx + 1 + seq_len
259
+ if self.weights is None or max_pos > self.weights.size(0):
260
+ self.weights = self.get_embedding(max_pos)
261
+ self.weights = self.weights.type_as(self._float_tensor)
262
+
263
+ positions = self.make_positions(x)
264
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
265
+
266
+ def make_positions(self, x):
267
+ mask = x.ne(self.padding_idx)
268
+ range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
269
+ positions = range_buf.expand_as(x)
270
+ return positions * mask.long() + self.padding_idx * (1 - mask.long())
271
+
272
+ def get_embedding(self, num_embeddings):
273
+ half_dim = self.embed_dim // 2
274
+ emb = math.log(10000) / (half_dim - 1)
275
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
276
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
277
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
278
+ if self.embed_dim % 2 == 1:
279
+ # zero pad
280
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
281
+ if self.padding_idx is not None:
282
+ emb[self.padding_idx, :] = 0
283
+ return emb
284
+
285
+
286
+ class RobertaLMHead(nn.Module):
287
+ def __init__(self, embed_dim, output_dim, weight):
288
+ super().__init__()
289
+ self.dense = nn.Linear(embed_dim, embed_dim)
290
+ self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
291
+ self.weight = weight
292
+ self.bias = nn.Parameter(torch.zeros(output_dim))
293
+
294
+ def forward(self, features):
295
+ x = self.dense(features)
296
+ x = gelu(x)
297
+ x = self.layer_norm(x)
298
+ # project back to size of vocabulary with bias
299
+ x = F.linear(x, self.weight) + self.bias
300
+ return x
301
+
302
+
303
+ class ContactPredictionHead(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_features: int,
307
+ prepend_bos: bool,
308
+ append_eos: bool,
309
+ bias=True,
310
+ eos_idx: Optional[int] = None,
311
+ ):
312
+ super().__init__()
313
+ self.in_features = in_features
314
+ self.prepend_bos = prepend_bos
315
+ self.append_eos = append_eos
316
+ if append_eos and eos_idx is None:
317
+ raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
318
+ self.eos_idx = eos_idx
319
+ self.regression = nn.Linear(in_features, 1, bias)
320
+ self.activation = nn.Sigmoid()
321
+
322
+ def forward(self, tokens, attentions):
323
+ # remove eos token attentions
324
+ if self.append_eos:
325
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
326
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
327
+ attentions = attentions * eos_mask[:, None, None, :, :]
328
+ attentions = attentions[..., :-1, :-1]
329
+ # remove cls token attentions
330
+ if self.prepend_bos:
331
+ attentions = attentions[..., 1:, 1:]
332
+ batch_size, layers, heads, seqlen, _ = attentions.size()
333
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
334
+
335
+ # features: B x C x T x T
336
+ attentions = attentions.to(
337
+ self.regression.weight.device
338
+ ) # attentions always float32, may need to convert to float16
339
+ attentions = apc(symmetrize(attentions))
340
+ attentions = attentions.permute(0, 2, 3, 1)
341
+ return self.activation(self.regression(attentions).squeeze(3))
342
+
343
+
344
+ class NormalizedResidualBlock(nn.Module):
345
+ def __init__(
346
+ self,
347
+ layer: nn.Module,
348
+ embedding_dim: int,
349
+ dropout: float = 0.1,
350
+ ):
351
+ super().__init__()
352
+ self.embedding_dim = embedding_dim
353
+
354
+ self.layer = layer
355
+ self.dropout_module = nn.Dropout(
356
+ dropout,
357
+ )
358
+ self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
359
+
360
+ def forward(self, x, *args, **kwargs):
361
+ residual = x
362
+ x = self.layer_norm(x)
363
+ outputs = self.layer(x, *args, **kwargs)
364
+ if isinstance(outputs, tuple):
365
+ x, *out = outputs
366
+ else:
367
+ x = outputs
368
+ out = None
369
+
370
+ x = self.dropout_module(x)
371
+ x = residual + x
372
+
373
+ if out is not None:
374
+ return (x,) + tuple(out)
375
+ else:
376
+ return x
377
+
378
+
379
+ class FeedForwardNetwork(nn.Module):
380
+ def __init__(
381
+ self,
382
+ embedding_dim: int,
383
+ ffn_embedding_dim: int,
384
+ activation_dropout: float = 0.1,
385
+ max_tokens_per_msa: int = 2**14,
386
+ ):
387
+ super().__init__()
388
+ self.embedding_dim = embedding_dim
389
+ self.ffn_embedding_dim = ffn_embedding_dim
390
+ self.max_tokens_per_msa = max_tokens_per_msa
391
+ self.activation_fn = nn.GELU()
392
+ self.activation_dropout_module = nn.Dropout(
393
+ activation_dropout,
394
+ )
395
+ self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
396
+ self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
397
+
398
+ def forward(self, x):
399
+ x = self.activation_fn(self.fc1(x))
400
+ x = self.activation_dropout_module(x)
401
+ x = self.fc2(x)
402
+ return x
403
+
404
+
405
+ class RowSelfAttention(nn.Module):
406
+ """Compute self-attention over rows of a 2D input."""
407
+
408
+ def __init__(
409
+ self,
410
+ embed_dim,
411
+ num_heads,
412
+ dropout=0.0,
413
+ max_tokens_per_msa: int = 2 ** 16,
414
+ ):
415
+ super().__init__()
416
+ self.num_heads = num_heads
417
+ self.dropout = dropout
418
+ self.head_dim = embed_dim // num_heads
419
+ self.scaling = self.head_dim ** -0.5
420
+ self.max_tokens_per_msa = max_tokens_per_msa
421
+ self.attn_shape = "hnij"
422
+
423
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
424
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
425
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
426
+
427
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
428
+ self.dropout_module = nn.Dropout(dropout)
429
+
430
+ def align_scaling(self, q):
431
+ num_rows = q.size(0)
432
+ return self.scaling / math.sqrt(num_rows)
433
+
434
+ def _batched_forward(
435
+ self,
436
+ x,
437
+ self_attn_mask=None,
438
+ self_attn_padding_mask=None,
439
+ ):
440
+ num_rows, num_cols, batch_size, embed_dim = x.size()
441
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
442
+ attns = 0
443
+ scaling = self.align_scaling(x)
444
+ for start in range(0, num_rows, max_rows):
445
+ attn_weights = self.compute_attention_weights(
446
+ x[start : start + max_rows],
447
+ scaling,
448
+ self_attn_mask=self_attn_mask,
449
+ self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
450
+ if self_attn_padding_mask is not None
451
+ else None,
452
+ )
453
+ attns += attn_weights
454
+ attn_probs = attns.softmax(-1)
455
+ attn_probs = self.dropout_module(attn_probs)
456
+
457
+ outputs = []
458
+ for start in range(0, num_rows, max_rows):
459
+ output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
460
+ outputs.append(output)
461
+
462
+ output = torch.cat(outputs, 0)
463
+ return output, attn_probs
464
+
465
+ def compute_attention_weights(
466
+ self,
467
+ x,
468
+ scaling: float,
469
+ self_attn_mask=None,
470
+ self_attn_padding_mask=None,
471
+ ):
472
+ num_rows, num_cols, batch_size, embed_dim = x.size()
473
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
474
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
475
+ q *= scaling
476
+ if self_attn_padding_mask is not None:
477
+ # Zero out any padded aligned positions - this is important since
478
+ # we take a sum across the alignment axis.
479
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
480
+
481
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
482
+
483
+ if self_attn_mask is not None:
484
+ raise NotImplementedError
485
+ # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
486
+
487
+ if self_attn_padding_mask is not None:
488
+ attn_weights = attn_weights.masked_fill(
489
+ self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
490
+ -10000,
491
+ )
492
+
493
+ return attn_weights
494
+
495
+ def compute_attention_update(
496
+ self,
497
+ x,
498
+ attn_probs,
499
+ ):
500
+ num_rows, num_cols, batch_size, embed_dim = x.size()
501
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
502
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
503
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
504
+ output = self.out_proj(context)
505
+ return output
506
+
507
+ def forward(
508
+ self,
509
+ x,
510
+ self_attn_mask=None,
511
+ self_attn_padding_mask=None,
512
+ ):
513
+ num_rows, num_cols, batch_size, embed_dim = x.size()
514
+ if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
515
+ return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
516
+ else:
517
+ scaling = self.align_scaling(x)
518
+ attn_weights = self.compute_attention_weights(
519
+ x, scaling, self_attn_mask, self_attn_padding_mask
520
+ )
521
+ attn_probs = attn_weights.softmax(-1)
522
+ attn_probs = self.dropout_module(attn_probs)
523
+ output = self.compute_attention_update(x, attn_probs)
524
+ return output, attn_probs
525
+
526
+
527
+ class ColumnSelfAttention(nn.Module):
528
+ """Compute self-attention over columns of a 2D input."""
529
+
530
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)


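# utils_softmax normalizes attention logits in float32 (or casts to float when
# tracing for ONNX), so reduced-precision logits do not lose accuracy in the
# softmax normalization.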
def utils_softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


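# FairseqIncrementalState gives each module instance a UUID-scoped namespace for
# its incremental-decoding cache, so several attention layers can share one
# incremental_state dict without their "attn_state" entries colliding.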
class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


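# with_incremental_state injects FairseqIncrementalState into a class's bases so
# the decorated module gains get_incremental_state/set_incremental_state without
# re-declaring the mixin on every subclass.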
def with_incremental_state(cls):
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls


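# Illustrative usage sketch (not executed): the multi-head attention below takes
# fairseq-style (tgt_len, batch, embed_dim) inputs; the sizes are assumptions
# chosen only for this example.
#
#   mha = LucaGPLMMultiheadAttention(embed_dim=768, num_heads=12,
#                                    self_attention=True,
#                                    use_rotary_embeddings=True)
#   x = torch.randn(128, 2, 768)           # (T, B, D)
#   out, attn_weights = mha(x, x, x)       # out: (T, B, D)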
@with_incremental_state
class LucaGPLMMultiheadAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))

        nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
        # nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            not self.rot_emb
            and self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        if self.rot_emb:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

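    # The helpers below maintain the incremental-decoding cache: previous key/value
    # tensors and their padding masks are stored per layer under the UUID-scoped
    # "attn_state" key and extended one decoding step at a time.
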
    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(0):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    @staticmethod
    def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
        # No-op hook; marked static to match the class-level call in forward().
        return attn_weights

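    # upgrade_state_dict_named splits legacy fused projections into per-projection
    # tensors, e.g. (hypothetical key names, for illustration only):
    #   "<prefix>in_proj_weight" of shape (3 * embed_dim, embed_dim) becomes
    #   q_proj.weight / k_proj.weight / v_proj.weight, each (embed_dim, embed_dim),
    # and a fused "<prefix>in_proj_bias" is split the same way.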
    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value


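# Rotary position embedding, as used below: for each position t the query/key
# features are rotated pairwise by angles t * inv_freq, i.e.
#   apply_rotary_pos_emb(x, cos, sin) = x * cos + rotate_half(x) * sin,
# which makes the q.k dot product depend only on the relative offset between
# positions.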
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, : x.shape[-2], :]
    sin = sin[:, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


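# Illustrative sketch (not executed): RotaryEmbedding rotates q and k along their
# second-to-last (sequence) dimension; the shapes are assumptions for illustration.
#
#   rope = RotaryEmbedding(dim=64)            # dim = head_dim
#   q = torch.randn(2 * 12, 128, 64)          # (batch * heads, seq_len, head_dim)
#   k = torch.randn(2 * 12, 128, 64)
#   q_rot, k_rot = rope(q, k)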
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int, *_, **__):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]

        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )