Yuanfei committed on
Commit a03d44f · verified · 1 Parent(s): 3639473

Upload LucaGPLM

Files changed (3)
  1. alphabet.py +44 -9
  2. file_operator.py +230 -230
  3. modeling_gplm.py +1210 -1210
alphabet.py CHANGED
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 # encoding: utf-8
 
-import sys
+import os
+import json
 import itertools
 from typing import Sequence, List
-
+from transformers import PreTrainedTokenizer
 from .batch_converter import BatchConverter
 
 gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']
@@ -21,7 +22,7 @@ gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
 class Alphabet(object):
     def __init__(
             self,
-            standard_toks: Sequence[str],
+            standard_toks: Sequence[str] = gene_prot_standard_toks,
             prepend_toks: Sequence[str] = gene_prot_prepend_toks,
             append_toks: Sequence[str] = gene_prot_append_toks,
             prepend_bos: bool = True,
@@ -156,9 +157,43 @@ class Alphabet(object):
     def encode(self, text):
         return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
 
-
-if __name__ == "__main__":
-    alphabet = Alphabet.from_predefined("gene_prot")
-    from src.utils import gene_seq_replace
-    print(alphabet.encode(gene_seq_replace("gttgtttggtagctaggagcctgactacatggcttcaaggctaaatggccacaggtgcccaggctatttggcttgctggaggcttcattcat")))
-
+class AlphabetTokenizer(PreTrainedTokenizer):
+    def __init__(
+            self,
+            alphabet: Alphabet = Alphabet(),
+            **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.alphabet = alphabet
+        self.pad_token = '[PAD]'
+        self.cls_token = '[CLS]'
+        self.sep_token = '[SEP]'
+        self.mask_token = '[MASK]'
+        self.unk_token = '[UNK]'
+
+    def _tokenize(self, text: str):
+        # Use your Alphabet class's tokenize method
+        return self.alphabet.tokenize(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        # Use the Alphabet class's get_idx method
+        return [self.alphabet.get_idx(token) for token in tokens]
+
+    def convert_ids_to_tokens(self, ids):
+        # Use the Alphabet class's get_tok method
+        return [self.alphabet.get_tok(index) for index in ids]
+
+    def save_vocabulary(self, save_directory, filename_prefix=None):
+        # Save the tokenizer vocabulary, required by Hugging Face
+        vocab_file = os.path.join(save_directory, (filename_prefix or "") + "vocab.json")
+        with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+            json.dump(self.alphabet.to_dict(), vocab_writer, ensure_ascii=False)
+        return (vocab_file,)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        # Add special tokens to input ids, if required
+        cls_token = [self.alphabet.cls_idx]
+        sep_token = [self.alphabet.eos_idx]
+        if token_ids_1:
+            return cls_token + token_ids_0 + sep_token + token_ids_1 + sep_token
+        return cls_token + token_ids_0 + sep_token
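
A minimal usage sketch of the AlphabetTokenizer added above (illustrative, not part of the commit). It assumes this alphabet.py is importable and that Alphabet provides the tokenize, get_idx, get_tok and to_dict methods the new class calls; whether transformers' PreTrainedTokenizer base class additionally requires get_vocab/vocab_size overrides is not checked here.

# Sketch only: exercising the tokenizer wrapper introduced in this commit.
from alphabet import Alphabet, AlphabetTokenizer   # hypothetical import path

tokenizer = AlphabetTokenizer(alphabet=Alphabet())        # defaults to the gene/protein vocabulary
tokens = tokenizer._tokenize("MSEQ")                      # delegates to Alphabet.tokenize
ids = tokenizer.convert_tokens_to_ids(tokens)             # delegates to Alphabet.get_idx
inputs = tokenizer.build_inputs_with_special_tokens(ids)  # [CLS] ... [SEP]
print(tokenizer.convert_ids_to_tokens(inputs))            # back to tokens via Alphabet.get_tok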
file_operator.py CHANGED
@@ -1,230 +1,230 @@
#!/usr/bin/env python
# encoding: utf-8

import csv, sys
import io, textwrap, itertools
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
csv.field_size_limit(sys.maxsize)


common_nucleotide_set = {'A', 'T', 'C', 'G', 'U', 'N'}

# not {'O', 'U', 'Z', 'J', 'B'}
# Common amino acids
common_amino_acid_set = {'R', 'X', 'S', 'G', 'W', 'I', 'Q', 'A', 'T', 'V', 'K', 'Y', 'C', 'N', 'L', 'F', 'D', 'M', 'P', 'H', 'E'}


def clean_seq(protein_id, seq):
    seq = seq.upper()
    new_seq = ""
    has_invalid_char = False
    invalid_char_set = set()
    for ch in seq:
        if 'A' <= ch <= 'Z' and ch not in ['J']:
            new_seq += ch
        else:
            invalid_char_set.add(ch)
            has_invalid_char = True
    if has_invalid_char:
        print("id: %s. Seq: %s" % (protein_id, seq))
        print("invalid char set:", invalid_char_set)
    return new_seq


def file_reader(filename, header=True, header_filter=True):
    if filename.endswith(".fa") or filename.endswith(".fas") or filename.endswith(".fasta"):
        return fasta_reader(filename)
    elif filename.endswith(".csv"):
        return csv_reader(filename, header=True, header_filter=True)
    elif filename.endswith(".tsv"):
        return tsv_reader(filename, header=True, header_filter=True)
    else:
        return txt_reader(filename, header=header, header_filter=header_filter)


def txt_reader(handle, header=True, header_filter=True):
    '''
    text reader, suitable for large files
    :param handle:
    :param header:
    :param header_filter: whether to drop the header row from the returned results
    :return:
    '''
    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
    try:
        cnt = 0
        for line in handle:
            cnt += 1
            if header and header_filter and cnt == 1:
                continue
            yield line.strip()
    except Exception as e:
        raise StopIteration
    finally:
        if not handle.closed:
            handle.close()


def tsv_reader(handle, header=True, header_filter=True):
    '''
    tsv reader, suitable for large files
    :param handle:
    :param header:
    :param header_filter: whether to drop the header row from the returned results
    :return:
    '''
    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
    try:
        reader = csv.reader(handle, delimiter="\t")
        cnt = 0
        for row in reader:
            cnt += 1
            if header and header_filter and cnt == 1:
                continue
            yield row
    except Exception as e:
        raise StopIteration
    finally:
        if not handle.closed:
            handle.close()


def csv_reader(handle, header=True, header_filter=True):
    '''
    csv reader, suitable for large files
    :param handle:
    :param header:
    :param header_filter: whether to drop the header row from the returned results
    :return:
    '''
    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
    try:
        # data = csv.reader((line.replace('\0','') for line in data_initial), delimiter=",")
        # reader = csv.reader(handle)
        reader = csv.reader((line.replace('\0', '') for line in handle))
        cnt = 0
        for row in reader:
            cnt += 1
            if header and header_filter and cnt == 1:
                continue
            yield row
    except Exception as e:
        raise StopIteration
    finally:
        if not handle.closed:
            handle.close()


def txt_writer(dataset, handle, header=None):
    '''
    txt writer
    :param dataset: data
    :param handle: output file path
    :param header: header
    :return:
    '''
    '''
    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
    try:
        if header:
            if isinstance(header, list):
                handle.write(",".join(header) + "\n")
            else:
                handle.write(header + "\n")
            print("header: %s" %header)
        for row in dataset:
            handle.write(str(row) + "\n")
    except Exception as e:
        raise e
    finally:
        if not handle.closed:
            handle.close()
    '''
    with open(handle, "w") as wfp:
        if header:
            if isinstance(header, list):
                wfp.write(",".join(header) + "\n")
            else:
                wfp.write(header + "\n")
        for row in dataset:
            wfp.write(str(row) + "\n")


def csv_writer(dataset, handle, header):
    '''
    csv writer, suitable for large files
    :param dataset: data
    :param handle: file
    :param header: header
    :return:
    '''
    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'w')
    try:
        writer = csv.writer(handle)
        if header:
            writer.writerow(header)
        for row in dataset:
            writer.writerow(row)
    except Exception as e:
        raise e
    finally:
        if not handle.closed:
            handle.close()


def fasta_reader(handle, width=None):
    """
    Reads a FASTA file, yielding header, sequence pairs for each sequence recovered; suitable for large files
    args:
        :handle (str, pathlib.Path, or file pointer) - fasta to read from
        :width (int or None) - formats the sequence to have max `width` character per line.
                               If <= 0, processed as None. If None, there is no max width.
    yields:
        :(header, sequence) tuples
    returns:
        :None
    """
    FASTA_STOP_CODON = "*"

    handle = handle if isinstance(handle, io.TextIOWrapper) else open(handle, 'r')
    width = width if isinstance(width, int) and width > 0 else None
    try:
        header = None
        for is_header, group in itertools.groupby(handle, lambda line: line.startswith(">")):
            if is_header:
                header = group.__next__().strip()
            else:
                seq = ''.join(line.strip() for line in group).strip().rstrip(FASTA_STOP_CODON)
                if width is not None:
                    seq = textwrap.fill(seq, width)
                yield header, seq
    except Exception as e:
        raise StopIteration
    finally:
        if not handle.closed:
            handle.close()


def write_fasta(filepath, sequences):
    '''
    write fasta file
    :param filepath: savepath
    :param sequences: fasta sequence(each item: [id, seq])
    :return:
    '''

    if sequences:
        with open(filepath, "w") as output_handle:
            if len(sequences[0]) > 1 and isinstance(sequences[0][0], str):
                for row in sequences:
                    protein_id = row[0]
                    seq = row[1]
                    sequence = SeqRecord(Seq(seq, None), id=protein_id[1:] if protein_id and protein_id[0] == ">" else protein_id, description="")
                    SeqIO.write(sequence, output_handle, "fasta")
            else:
                for sequence in sequences:
                    SeqIO.write(sequence, output_handle, "fasta")
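
The readers above are generators: given a path string they open the file themselves and close it once exhausted, so large files can be streamed row by row. A short usage sketch (illustrative only; module path and file names are hypothetical):

# Sketch only: stream a FASTA file, clean each sequence, and write a CSV.
from file_operator import fasta_reader, csv_writer, clean_seq

rows = []
for header, seq in fasta_reader("proteins.fasta"):   # yields (">id description", "SEQ...")
    seq_id = header[1:].split()[0]                   # drop the leading ">" and the description
    rows.append([seq_id, clean_seq(seq_id, seq)])    # keep only A-Z characters (except 'J')

csv_writer(rows, "proteins_clean.csv", header=["seq_id", "seq"])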
modeling_gplm.py CHANGED
@@ -1,1210 +1,1210 @@
1
- #!/usr/bin/env python
2
- # encoding: utf-8
3
-
4
- import math
5
- from typing import Dict, Optional, Sequence, Tuple, List, Union
6
- import uuid
7
- import torch
8
- import torch.nn.functional as F
9
- from torch import Tensor, nn
10
- from torch.nn import Parameter
11
-
12
-
13
- def gelu(x):
14
- """Implementation of the gelu activation function.
15
- OpenAI GPT's gelu: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
16
- """
17
- return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18
-
19
-
20
- def symmetrize(x):
21
- "Make layer symmetric in final two dimensions, used for contact prediction."
22
- return x + x.transpose(-1, -2)
23
-
24
-
25
- def apc(x):
26
- "Perform average product correct, used for contact prediction."
27
- a1 = x.sum(-1, keepdims=True)
28
- a2 = x.sum(-2, keepdims=True)
29
- a12 = x.sum((-1, -2), keepdims=True)
30
-
31
- avg = a1 * a2
32
- avg.div_(a12) # in-place to reduce memory
33
- normalized = x - avg
34
- return normalized
35
-
36
-
37
- class LucaGPLM1LayerNorm(nn.Module):
38
- def __init__(self, hidden_size, eps=1e-12, affine=True):
39
- """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
40
- super().__init__()
41
- self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
42
- self.eps = eps
43
- self.affine = bool(affine)
44
- if self.affine:
45
- self.weight = nn.Parameter(torch.ones(hidden_size))
46
- self.bias = nn.Parameter(torch.zeros(hidden_size))
47
- else:
48
- self.weight, self.bias = None, None
49
-
50
- def forward(self, x):
51
- dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
52
- means = x.mean(dims, keepdim=True)
53
- x_zeromean = x - means
54
- variances = x_zeromean.pow(2).mean(dims, keepdim=True)
55
- x = x_zeromean / torch.sqrt(variances + self.eps)
56
- if self.affine:
57
- x = (self.weight * x) + self.bias
58
- return x
59
-
60
-
61
- from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
62
-
63
- class LucaGPLMTransformerLayer(nn.Module):
64
- """LucaGPLM Transformer layer block."""
65
-
66
- def __init__(
67
- self,
68
- embed_dim,
69
- ffn_embed_dim,
70
- attention_heads,
71
- add_bias_kv=True,
72
- use_lucagplm1b_layer_norm=False,
73
- use_rotary_embeddings: bool = False,
74
- ):
75
- '''
76
- Transformer encoder layer
77
- :param embed_dim: token embedding dim
78
- :param ffn_embed_dim: fully connected layer dim
79
- :param attention_heads: heads num
80
- :param add_bias_kv: key-value layer add bias
81
- :param use_lucagplm1b_layer_norm: whether to use lucagplm 1b layer norm
82
- :param use_rotary_embeddings: whether to use rotary embedding
83
- '''
84
- super().__init__()
85
- self.embed_dim = embed_dim
86
- self.ffn_embed_dim = ffn_embed_dim
87
- self.attention_heads = attention_heads
88
- self.use_rotary_embeddings = use_rotary_embeddings
89
- self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
90
-
91
- def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
92
- LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
93
-
94
- # pre layer norm
95
- self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
96
-
97
- self.self_attn = LucaGPLMMultiheadAttention(
98
- self.embed_dim,
99
- self.attention_heads,
100
- add_bias_kv=add_bias_kv,
101
- add_zero_attn=False,
102
- use_rotary_embeddings=self.use_rotary_embeddings,
103
- )
104
-
105
- # post layer norm
106
- self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
107
-
108
- # dimension increase by the fully connected layer
109
- self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
110
-
111
- # dimension reduction by the fully connected layer
112
- self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
113
-
114
- def forward(
115
- self,
116
- x,
117
- self_attn_mask=None,
118
- self_attn_padding_mask=None,
119
- need_head_weights=False
120
- ):
121
- residual = x
122
- x = self.pre_layer_norm(x)
123
- x, attn = self.self_attn(
124
- query=x,
125
- key=x,
126
- value=x,
127
- key_padding_mask=self_attn_padding_mask,
128
- need_weights=True,
129
- need_head_weights=need_head_weights,
130
- attn_mask=self_attn_mask,
131
- )
132
- x = residual + x
133
-
134
- residual = x
135
- x = self.post_layer_norm(x)
136
- x = gelu(self.fc1(x))
137
- x = self.fc2(x)
138
- x = residual + x
139
-
140
- return x, attn
141
-
142
-
143
- class AxialTransformerLayer(nn.Module):
144
- """Implements an Axial MSA Transformer block."""
145
- def __init__(
146
- self,
147
- embedding_dim: int = 768,
148
- ffn_embedding_dim: int = 3072,
149
- num_attention_heads: int = 8,
150
- dropout: float = 0.1,
151
- attention_dropout: float = 0.1,
152
- activation_dropout: float = 0.1,
153
- max_tokens_per_msa: int = 2**14,
154
- ) -> None:
155
- super().__init__()
156
-
157
- # Initialize parameters
158
- self.embedding_dim = embedding_dim
159
- self.dropout_prob = dropout
160
-
161
- row_self_attention = RowSelfAttention(
162
- embedding_dim,
163
- num_attention_heads,
164
- dropout=dropout,
165
- max_tokens_per_msa=max_tokens_per_msa,
166
- )
167
-
168
- column_self_attention = ColumnSelfAttention(
169
- embedding_dim,
170
- num_attention_heads,
171
- dropout=dropout,
172
- max_tokens_per_msa=max_tokens_per_msa,
173
- )
174
-
175
- feed_forward_layer = FeedForwardNetwork(
176
- embedding_dim,
177
- ffn_embedding_dim,
178
- activation_dropout=activation_dropout,
179
- max_tokens_per_msa=max_tokens_per_msa,
180
- )
181
-
182
- self.row_self_attention = self.build_residual(row_self_attention)
183
- self.column_self_attention = self.build_residual(column_self_attention)
184
- self.feed_forward_layer = self.build_residual(feed_forward_layer)
185
-
186
- def build_residual(self, layer: nn.Module):
187
- return NormalizedResidualBlock(
188
- layer,
189
- self.embedding_dim,
190
- self.dropout_prob,
191
- )
192
-
193
- def forward(
194
- self,
195
- x: torch.Tensor,
196
- self_attn_mask: Optional[torch.Tensor] = None,
197
- self_attn_padding_mask: Optional[torch.Tensor] = None,
198
- need_head_weights: bool = False,
199
- ):
200
- """
201
- LayerNorm is applied either before or after the self-attention/ffn
202
- modules similar to the original Transformer implementation.
203
- """
204
- x, row_attn = self.row_self_attention(
205
- x,
206
- self_attn_mask=self_attn_mask,
207
- self_attn_padding_mask=self_attn_padding_mask,
208
- )
209
- x, column_attn = self.column_self_attention(
210
- x,
211
- self_attn_mask=self_attn_mask,
212
- self_attn_padding_mask=self_attn_padding_mask,
213
- )
214
- x = self.feed_forward_layer(x)
215
- if need_head_weights:
216
- return x, column_attn, row_attn
217
- else:
218
- return x
219
-
220
-
221
- class LearnedPositionalEmbedding(nn.Embedding):
222
- """
223
- This module learns positional embeddings up to a fixed maximum size.
224
- Padding ids are ignored by either offsetting based on padding_idx
225
- or by setting padding_idx to None and ensuring that the appropriate
226
- position ids are passed to the forward function.
227
- """
228
-
229
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
230
- if padding_idx is not None:
231
- num_embeddings_ = num_embeddings + padding_idx + 1
232
- else:
233
- num_embeddings_ = num_embeddings
234
- super().__init__(num_embeddings_, embedding_dim, padding_idx)
235
- self.max_positions = num_embeddings
236
-
237
- def forward(self, input: torch.Tensor):
238
- """Input is expected to be of size [bsz x seqlen]."""
239
- if input.size(1) > self.max_positions:
240
- raise ValueError(
241
- f"Sequence length {input.size(1)} above maximum "
242
- f" sequence length of {self.max_positions}"
243
- )
244
- mask = input.ne(self.padding_idx).int()
245
- positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
246
- return F.embedding(
247
- positions,
248
- self.weight,
249
- self.padding_idx,
250
- self.max_norm,
251
- self.norm_type,
252
- self.scale_grad_by_freq,
253
- self.sparse,
254
- )
255
-
256
-
257
- class SinusoidalPositionalEmbedding(nn.Module):
258
- def __init__(self, embed_dim, padding_idx, learned=False):
259
- super().__init__()
260
- self.embed_dim = embed_dim
261
- self.padding_idx = padding_idx
262
- self.register_buffer("_float_tensor", torch.FloatTensor(1))
263
- self.weights = None
264
-
265
- def forward(self, x):
266
- bsz, seq_len = x.shape
267
- max_pos = self.padding_idx + 1 + seq_len
268
- if self.weights is None or max_pos > self.weights.size(0):
269
- self.weights = self.get_embedding(max_pos)
270
- self.weights = self.weights.type_as(self._float_tensor)
271
-
272
- positions = self.make_positions(x)
273
- return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
274
-
275
- def make_positions(self, x):
276
- mask = x.ne(self.padding_idx)
277
- range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
278
- positions = range_buf.expand_as(x)
279
- return positions * mask.long() + self.padding_idx * (1 - mask.long())
280
-
281
- def get_embedding(self, num_embeddings):
282
- half_dim = self.embed_dim // 2
283
- emb = math.log(10000) / (half_dim - 1)
284
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
285
- emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
286
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
287
- if self.embed_dim % 2 == 1:
288
- # zero pad
289
- emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
290
- if self.padding_idx is not None:
291
- emb[self.padding_idx, :] = 0
292
- return emb
293
-
294
-
295
- class RobertaLMHead(nn.Module):
296
- """Head for masked language modeling."""
297
-
298
- def __init__(self, embed_dim, output_dim, weight):
299
- super().__init__()
300
- self.dense = nn.Linear(embed_dim, embed_dim)
301
- self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
302
- self.weight = weight
303
- self.bias = nn.Parameter(torch.zeros(output_dim))
304
-
305
- def forward(self, features):
306
- x = self.dense(features)
307
- x = gelu(x)
308
- x = self.layer_norm(x)
309
- # project back to size of vocabulary with bias
310
- x = F.linear(x, self.weight) + self.bias
311
- return x
312
-
313
-
314
- class ContactPredictionHead(nn.Module):
315
- """Performs symmetrization, apc, and computes a logistic regression on the output features"""
316
-
317
- def __init__(
318
- self,
319
- in_features: int,
320
- prepend_bos: bool,
321
- append_eos: bool,
322
- bias=True,
323
- eos_idx: Optional[int] = None,
324
- ):
325
- super().__init__()
326
- self.in_features = in_features
327
- self.prepend_bos = prepend_bos
328
- self.append_eos = append_eos
329
- if append_eos and eos_idx is None:
330
- raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
331
- self.eos_idx = eos_idx
332
- self.regression = nn.Linear(in_features, 1, bias)
333
- self.activation = nn.Sigmoid()
334
-
335
- def forward(self, tokens, attentions):
336
- # remove eos token attentions
337
- if self.append_eos:
338
- eos_mask = tokens.ne(self.eos_idx).to(attentions)
339
- eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
340
- attentions = attentions * eos_mask[:, None, None, :, :]
341
- attentions = attentions[..., :-1, :-1]
342
- # remove cls token attentions
343
- if self.prepend_bos:
344
- attentions = attentions[..., 1:, 1:]
345
- batch_size, layers, heads, seqlen, _ = attentions.size()
346
- attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
347
-
348
- # features: B x C x T x T
349
- attentions = attentions.to(
350
- self.regression.weight.device
351
- ) # attentions always float32, may need to convert to float16
352
- attentions = apc(symmetrize(attentions))
353
- attentions = attentions.permute(0, 2, 3, 1)
354
- return self.activation(self.regression(attentions).squeeze(3))
355
-
356
-
357
- class NormalizedResidualBlock(nn.Module):
358
- def __init__(
359
- self,
360
- layer: nn.Module,
361
- embedding_dim: int,
362
- dropout: float = 0.1,
363
- ):
364
- super().__init__()
365
- self.embedding_dim = embedding_dim
366
-
367
- self.layer = layer
368
- self.dropout_module = nn.Dropout(
369
- dropout,
370
- )
371
- self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
372
-
373
- def forward(self, x, *args, **kwargs):
374
- residual = x
375
- x = self.layer_norm(x)
376
- outputs = self.layer(x, *args, **kwargs)
377
- if isinstance(outputs, tuple):
378
- x, *out = outputs
379
- else:
380
- x = outputs
381
- out = None
382
-
383
- x = self.dropout_module(x)
384
- x = residual + x
385
-
386
- if out is not None:
387
- return (x,) + tuple(out)
388
- else:
389
- return x
390
-
391
-
392
- class FeedForwardNetwork(nn.Module):
393
- def __init__(
394
- self,
395
- embedding_dim: int,
396
- ffn_embedding_dim: int,
397
- activation_dropout: float = 0.1,
398
- max_tokens_per_msa: int = 2**14,
399
- ):
400
- super().__init__()
401
- self.embedding_dim = embedding_dim
402
- self.ffn_embedding_dim = ffn_embedding_dim
403
- self.max_tokens_per_msa = max_tokens_per_msa
404
- self.activation_fn = nn.GELU()
405
- self.activation_dropout_module = nn.Dropout(
406
- activation_dropout,
407
- )
408
- self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
409
- self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
410
-
411
- def forward(self, x):
412
- x = self.activation_fn(self.fc1(x))
413
- x = self.activation_dropout_module(x)
414
- x = self.fc2(x)
415
- return x
416
-
417
-
418
- class RowSelfAttention(nn.Module):
419
- """Compute self-attention over rows of a 2D input."""
420
-
421
- def __init__(
422
- self,
423
- embed_dim,
424
- num_heads,
425
- dropout=0.0,
426
- max_tokens_per_msa: int = 2 ** 16,
427
- ):
428
- super().__init__()
429
- self.num_heads = num_heads
430
- self.dropout = dropout
431
- self.head_dim = embed_dim // num_heads
432
- self.scaling = self.head_dim ** -0.5
433
- self.max_tokens_per_msa = max_tokens_per_msa
434
- self.attn_shape = "hnij"
435
-
436
- self.k_proj = nn.Linear(embed_dim, embed_dim)
437
- self.v_proj = nn.Linear(embed_dim, embed_dim)
438
- self.q_proj = nn.Linear(embed_dim, embed_dim)
439
-
440
- self.out_proj = nn.Linear(embed_dim, embed_dim)
441
- self.dropout_module = nn.Dropout(dropout)
442
-
443
- def align_scaling(self, q):
444
- num_rows = q.size(0)
445
- return self.scaling / math.sqrt(num_rows)
446
-
447
- def _batched_forward(
448
- self,
449
- x,
450
- self_attn_mask=None,
451
- self_attn_padding_mask=None,
452
- ):
453
- num_rows, num_cols, batch_size, embed_dim = x.size()
454
- max_rows = max(1, self.max_tokens_per_msa // num_cols)
455
- attns = 0
456
- scaling = self.align_scaling(x)
457
- for start in range(0, num_rows, max_rows):
458
- attn_weights = self.compute_attention_weights(
459
- x[start : start + max_rows],
460
- scaling,
461
- self_attn_mask=self_attn_mask,
462
- self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
463
- if self_attn_padding_mask is not None
464
- else None,
465
- )
466
- attns += attn_weights
467
- attn_probs = attns.softmax(-1)
468
- attn_probs = self.dropout_module(attn_probs)
469
-
470
- outputs = []
471
- for start in range(0, num_rows, max_rows):
472
- output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
473
- outputs.append(output)
474
-
475
- output = torch.cat(outputs, 0)
476
- return output, attn_probs
477
-
478
- def compute_attention_weights(
479
- self,
480
- x,
481
- scaling: float,
482
- self_attn_mask=None,
483
- self_attn_padding_mask=None,
484
- ):
485
- num_rows, num_cols, batch_size, embed_dim = x.size()
486
- q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
487
- k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
488
- q *= scaling
489
- if self_attn_padding_mask is not None:
490
- # Zero out any padded aligned positions - this is important since
491
- # we take a sum across the alignment axis.
492
- q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
493
-
494
- attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
495
-
496
- if self_attn_mask is not None:
497
- raise NotImplementedError
498
- # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
499
-
500
- if self_attn_padding_mask is not None:
501
- attn_weights = attn_weights.masked_fill(
502
- self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
503
- -10000,
504
- )
505
-
506
- return attn_weights
507
-
508
- def compute_attention_update(
509
- self,
510
- x,
511
- attn_probs,
512
- ):
513
- num_rows, num_cols, batch_size, embed_dim = x.size()
514
- v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
515
- context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
516
- context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
517
- output = self.out_proj(context)
518
- return output
519
-
520
- def forward(
521
- self,
522
- x,
523
- self_attn_mask=None,
524
- self_attn_padding_mask=None,
525
- ):
526
- num_rows, num_cols, batch_size, embed_dim = x.size()
527
- if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
528
- return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
529
- else:
530
- scaling = self.align_scaling(x)
531
- attn_weights = self.compute_attention_weights(
532
- x, scaling, self_attn_mask, self_attn_padding_mask
533
- )
534
- attn_probs = attn_weights.softmax(-1)
535
- attn_probs = self.dropout_module(attn_probs)
536
- output = self.compute_attention_update(x, attn_probs)
537
- return output, attn_probs
538
-
539
-
540
- class ColumnSelfAttention(nn.Module):
541
- """Compute self-attention over columns of a 2D input."""
542
-
543
- def __init__(
544
- self,
545
- embed_dim,
546
- num_heads,
547
- dropout=0.0,
548
- max_tokens_per_msa: int = 2 ** 16,
549
- ):
550
- super().__init__()
551
-
552
- self.num_heads = num_heads
553
- self.dropout = dropout
554
- self.head_dim = embed_dim // num_heads
555
- self.scaling = self.head_dim ** -0.5
556
- self.max_tokens_per_msa = max_tokens_per_msa
557
-
558
- self.k_proj = nn.Linear(embed_dim, embed_dim)
559
- self.v_proj = nn.Linear(embed_dim, embed_dim)
560
- self.q_proj = nn.Linear(embed_dim, embed_dim)
561
-
562
- self.out_proj = nn.Linear(embed_dim, embed_dim)
563
- self.dropout_module = nn.Dropout(dropout)
564
-
565
- def _batched_forward(
566
- self,
567
- x,
568
- self_attn_mask=None,
569
- self_attn_padding_mask=None,
570
- ):
571
- num_rows, num_cols, batch_size, embed_dim = x.size()
572
- max_cols = max(1, self.max_tokens_per_msa // num_rows)
573
- outputs = []
574
- attns = []
575
- for start in range(0, num_cols, max_cols):
576
- output, attn = self(
577
- x[:, start : start + max_cols],
578
- self_attn_mask=self_attn_mask,
579
- self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
580
- if self_attn_padding_mask is not None
581
- else None,
582
- )
583
- outputs.append(output)
584
- attns.append(attn)
585
- output = torch.cat(outputs, 1)
586
- attns = torch.cat(attns, 1)
587
- return output, attns
588
-
589
- def compute_attention_update(
590
- self,
591
- x,
592
- self_attn_mask=None,
593
- self_attn_padding_mask=None,
594
- ):
595
- num_rows, num_cols, batch_size, embed_dim = x.size()
596
- if num_rows == 1:
597
- # if there is only 1 position, this is equivalent and doesn't break with padding
598
- attn_probs = torch.ones(
599
- self.num_heads,
600
- num_cols,
601
- batch_size,
602
- num_rows,
603
- num_rows,
604
- device=x.device,
605
- dtype=x.dtype,
606
- )
607
- output = self.out_proj(self.v_proj(x))
608
- else:
609
- q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
610
- k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
611
- v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
612
- q *= self.scaling
613
-
614
- attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
615
-
616
- if self_attn_mask is not None:
617
- raise NotImplementedError
618
- if self_attn_padding_mask is not None:
619
- attn_weights = attn_weights.masked_fill(
620
- self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
621
- -10000,
622
- )
623
-
624
- attn_probs = attn_weights.softmax(-1)
625
- attn_probs = self.dropout_module(attn_probs)
626
- context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
627
- context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
628
- output = self.out_proj(context)
629
- return output, attn_probs
630
-
631
- def forward(
632
- self,
633
- x,
634
- self_attn_mask=None,
635
- self_attn_padding_mask=None,
636
- ):
637
- num_rows, num_cols, batch_size, embed_dim = x.size()
638
- # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
639
- if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
640
- return self._batched_forward(
641
- x,
642
- self_attn_mask,
643
- self_attn_padding_mask,
644
- )
645
- else:
646
- return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
647
-
648
-
649
- def utils_softmax(x, dim: int, onnx_trace: bool = False):
650
- if onnx_trace:
651
- return F.softmax(x.float(), dim=dim)
652
- else:
653
- return F.softmax(x, dim=dim, dtype=torch.float32)
654
-
655
-
656
- class FairseqIncrementalState(object):
657
- def __init__(self, *args, **kwargs):
658
- super().__init__(*args, **kwargs)
659
- self.init_incremental_state()
660
-
661
- def init_incremental_state(self):
662
- self._incremental_state_id = str(uuid.uuid4())
663
-
664
- def _get_full_incremental_state_key(self, key: str) -> str:
665
- return "{}.{}".format(self._incremental_state_id, key)
666
-
667
- def get_incremental_state(
668
- self,
669
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
670
- key: str,
671
- ) -> Optional[Dict[str, Optional[Tensor]]]:
672
- """Helper for getting incremental state for an nn.Module."""
673
- full_key = self._get_full_incremental_state_key(key)
674
- if incremental_state is None or full_key not in incremental_state:
675
- return None
676
- return incremental_state[full_key]
677
-
678
- def set_incremental_state(
679
- self,
680
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
681
- key: str,
682
- value: Dict[str, Optional[Tensor]],
683
- ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
684
- """Helper for setting incremental state for an nn.Module."""
685
- if incremental_state is not None:
686
- full_key = self._get_full_incremental_state_key(key)
687
- incremental_state[full_key] = value
688
- return incremental_state
689
-
690
-
691
- def with_incremental_state(cls):
692
- cls.__bases__ = (FairseqIncrementalState,) + tuple(
693
- b for b in cls.__bases__ if b != FairseqIncrementalState
694
- )
695
- return cls
696
-
697
-
698
- @with_incremental_state
699
- class LucaGPLMMultiheadAttention(nn.Module):
700
- """Multi-headed attention.
701
-
702
- See "Attention Is All You Need" for more details.
703
- """
704
-
705
- def __init__(
706
- self,
707
- embed_dim,
708
- num_heads,
709
- kdim=None,
710
- vdim=None,
711
- dropout=0.0,
712
- bias=True,
713
- add_bias_kv: bool = False,
714
- add_zero_attn: bool = False,
715
- self_attention: bool = False,
716
- encoder_decoder_attention: bool = False,
717
- use_rotary_embeddings: bool = False,
718
- ):
719
- super().__init__()
720
- self.embed_dim = embed_dim
721
- self.kdim = kdim if kdim is not None else embed_dim
722
- self.vdim = vdim if vdim is not None else embed_dim
723
- self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
724
-
725
- self.num_heads = num_heads
726
- self.dropout = dropout
727
- self.head_dim = embed_dim // num_heads
728
- assert (
729
- self.head_dim * num_heads == self.embed_dim
730
- ), "embed_dim must be divisible by num_heads"
731
- self.scaling = self.head_dim**-0.5
732
-
733
- self.self_attention = self_attention
734
- self.encoder_decoder_attention = encoder_decoder_attention
735
-
736
- assert not self.self_attention or self.qkv_same_dim, (
737
- "Self-attention requires query, key and " "value to be of the same size"
738
- )
739
-
740
- self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
741
- self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
742
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
743
-
744
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
745
-
746
- if add_bias_kv:
747
- self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
748
- self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
749
- else:
750
- self.bias_k = self.bias_v = None
751
-
752
- self.add_zero_attn = add_zero_attn
753
-
754
- self.reset_parameters()
755
-
756
- self.onnx_trace = False
757
- self.rot_emb = None
758
- if use_rotary_embeddings:
759
- self.rot_emb = RotaryEmbedding(dim=self.head_dim)
760
-
761
- self.enable_torch_version = False
762
- if hasattr(F, "multi_head_attention_forward"):
763
- self.enable_torch_version = True
764
- else:
765
- self.enable_torch_version = False
766
-
767
- def prepare_for_onnx_export_(self):
768
- self.onnx_trace = True
769
-
770
- def reset_parameters(self):
771
- '''
772
- if self.qkv_same_dim:
773
- # Empirically observed the convergence to be much better with
774
- # the scaled initialization
775
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
776
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
777
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
778
- else:
779
- nn.init.xavier_uniform_(self.k_proj.weight)
780
- nn.init.xavier_uniform_(self.v_proj.weight)
781
- nn.init.xavier_uniform_(self.q_proj.weight)
782
- '''
783
- nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
784
- nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
785
- nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
786
-
787
- nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
788
- # nn.init.xavier_uniform_(self.out_proj.weight)
789
- if self.out_proj.bias is not None:
790
- nn.init.constant_(self.out_proj.bias, 0.0)
791
- if self.bias_k is not None:
792
- nn.init.xavier_normal_(self.bias_k)
793
- if self.bias_v is not None:
794
- nn.init.xavier_normal_(self.bias_v)
795
-
796
- def forward(
797
- self,
798
- query,
799
- key: Optional[Tensor],
800
- value: Optional[Tensor],
801
- key_padding_mask: Optional[Tensor] = None,
802
- incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
803
- need_weights: bool = True,
804
- static_kv: bool = False,
805
- attn_mask: Optional[Tensor] = None,
806
- before_softmax: bool = False,
807
- need_head_weights: bool = False,
808
- ) -> Tuple[Tensor, Optional[Tensor]]:
809
- """Input shape: Time x Batch x Channel
810
-
811
- Args:
812
- key_padding_mask (ByteTensor, optional): mask to exclude
813
- keys that are pads, of shape `(batch, src_len)`, where
814
- padding elements are indicated by 1s.
815
- need_weights (bool, optional): return the attention weights,
816
- averaged over heads (default: False).
817
- attn_mask (ByteTensor, optional): typically used to
818
- implement causal attention, where the mask prevents the
819
- attention from looking forward in time (default: None).
820
- before_softmax (bool, optional): return the raw attention
821
- weights and values before the attention softmax.
822
- need_head_weights (bool, optional): return the attention
823
- weights for each head. Implies *need_weights*. Default:
824
- return the average attention weights over all heads.
825
- """
826
- if need_head_weights:
827
- need_weights = True
828
-
829
- tgt_len, bsz, embed_dim = query.size()
830
- assert embed_dim == self.embed_dim
831
- assert list(query.size()) == [tgt_len, bsz, embed_dim]
832
-
833
- if (
834
- not self.rot_emb
835
- and self.enable_torch_version
836
- and not self.onnx_trace
837
- and incremental_state is None
838
- and not static_kv
839
- # A workaround for quantization to work. Otherwise JIT compilation
840
- # treats bias in linear module as method.
841
- and not torch.jit.is_scripting()
842
- and not need_head_weights
843
- ):
844
- assert key is not None and value is not None
845
- return F.multi_head_attention_forward(
846
- query,
847
- key,
848
- value,
849
- self.embed_dim,
850
- self.num_heads,
851
- torch.empty([0]),
852
- torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
853
- self.bias_k,
854
- self.bias_v,
855
- self.add_zero_attn,
856
- self.dropout,
857
- self.out_proj.weight,
858
- self.out_proj.bias,
859
- self.training,
860
- key_padding_mask,
861
- need_weights,
862
- attn_mask,
863
- use_separate_proj_weight=True,
864
- q_proj_weight=self.q_proj.weight,
865
- k_proj_weight=self.k_proj.weight,
866
- v_proj_weight=self.v_proj.weight,
867
- )
868
- if incremental_state is not None:
869
- saved_state = self._get_input_buffer(incremental_state)
870
- if saved_state is not None and "prev_key" in saved_state:
871
- # previous time steps are cached - no need to recompute
872
- # key and value if they are static
873
- if static_kv:
874
- assert self.encoder_decoder_attention and not self.self_attention
875
- key = value = None
876
- else:
877
- saved_state = None
878
-
879
- if self.self_attention:
880
- q = self.q_proj(query)
881
- k = self.k_proj(query)
882
- v = self.v_proj(query)
883
- elif self.encoder_decoder_attention:
884
- # encoder-decoder attention
885
- q = self.q_proj(query)
886
- if key is None:
887
- assert value is None
888
- k = v = None
889
- else:
890
- k = self.k_proj(key)
891
- v = self.v_proj(key)
892
-
893
- else:
894
- assert key is not None and value is not None
895
- q = self.q_proj(query)
896
- k = self.k_proj(key)
897
- v = self.v_proj(value)
898
- q *= self.scaling
899
-
900
- if self.bias_k is not None:
901
- assert self.bias_v is not None
902
- k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
903
- v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
904
- if attn_mask is not None:
905
- attn_mask = torch.cat(
906
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
907
- )
908
- if key_padding_mask is not None:
909
- key_padding_mask = torch.cat(
910
- [
911
- key_padding_mask,
912
- key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
913
- ],
914
- dim=1,
915
- )
916
-
917
- q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
918
- if k is not None:
919
- k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
920
- if v is not None:
921
- v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
922
-
923
- if saved_state is not None:
924
- # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
925
- if "prev_key" in saved_state:
926
- _prev_key = saved_state["prev_key"]
927
- assert _prev_key is not None
928
- prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
929
- if static_kv:
930
- k = prev_key
931
- else:
932
- assert k is not None
933
- k = torch.cat([prev_key, k], dim=1)
934
- if "prev_value" in saved_state:
935
- _prev_value = saved_state["prev_value"]
936
- assert _prev_value is not None
937
- prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
938
- if static_kv:
939
- v = prev_value
940
- else:
941
- assert v is not None
942
- v = torch.cat([prev_value, v], dim=1)
943
- prev_key_padding_mask: Optional[Tensor] = None
944
- if "prev_key_padding_mask" in saved_state:
945
- prev_key_padding_mask = saved_state["prev_key_padding_mask"]
946
- assert k is not None and v is not None
947
- key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
948
- key_padding_mask=key_padding_mask,
949
- prev_key_padding_mask=prev_key_padding_mask,
950
- batch_size=bsz,
951
- src_len=k.size(1),
952
- static_kv=static_kv,
953
- )
954
-
955
- saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
956
- saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
957
- saved_state["prev_key_padding_mask"] = key_padding_mask
958
- # In this branch incremental_state is never None
959
- assert incremental_state is not None
960
- incremental_state = self._set_input_buffer(incremental_state, saved_state)
961
- assert k is not None
962
- src_len = k.size(1)
963
-
964
- # This is part of a workaround to get around fork/join parallelism
965
- # not supporting Optional types.
966
- if key_padding_mask is not None and key_padding_mask.dim() == 0:
967
- key_padding_mask = None
968
-
969
- if key_padding_mask is not None:
970
- assert key_padding_mask.size(0) == bsz
971
- assert key_padding_mask.size(1) == src_len
972
-
973
- if self.add_zero_attn:
974
- assert v is not None
975
- src_len += 1
976
- k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
977
- v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
978
- if attn_mask is not None:
979
- attn_mask = torch.cat(
980
- [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
981
- )
982
- if key_padding_mask is not None:
983
- key_padding_mask = torch.cat(
984
- [
985
- key_padding_mask,
986
- torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
987
- ],
988
- dim=1,
989
- )
990
-
991
- if self.rot_emb:
992
- q, k = self.rot_emb(q, k)
993
-
994
- attn_weights = torch.bmm(q, k.transpose(1, 2))
995
- attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
996
-
997
- assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
998
-
999
- if attn_mask is not None:
1000
- attn_mask = attn_mask.unsqueeze(0)
1001
- if self.onnx_trace:
1002
- attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
1003
- attn_weights += attn_mask
1004
-
1005
- if key_padding_mask is not None:
1006
- # don't attend to padding symbols
1007
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
1008
- attn_weights = attn_weights.masked_fill(
1009
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
1010
- )
1011
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
1012
-
1013
- if before_softmax:
1014
- return attn_weights, v
1015
-
1016
- attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
1017
- attn_weights = attn_weights_float.type_as(attn_weights)
1018
- attn_probs = F.dropout(
1019
- attn_weights_float.type_as(attn_weights),
1020
- p=self.dropout,
1021
- training=self.training,
1022
- )
1023
- assert v is not None
1024
- attn = torch.bmm(attn_probs, v)
1025
- assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
1026
- if self.onnx_trace and attn.size(1) == 1:
1027
- # when ONNX tracing a single decoder step (sequence length == 1)
1028
- # the transpose is a no-op copy before view, thus unnecessary
1029
- attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
1030
- else:
1031
- attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
1032
- attn = self.out_proj(attn)
1033
- attn_weights: Optional[Tensor] = None
1034
- if need_weights:
1035
- attn_weights = attn_weights_float.view(
1036
- bsz, self.num_heads, tgt_len, src_len
1037
- ).type_as(attn).transpose(1, 0)
1038
- if not need_head_weights:
1039
- # average attention weights over heads
1040
- attn_weights = attn_weights.mean(dim=0)
1041
-
1042
- return attn, attn_weights
1043
-
1044
- @staticmethod
1045
- def _append_prev_key_padding_mask(
1046
- key_padding_mask: Optional[Tensor],
1047
- prev_key_padding_mask: Optional[Tensor],
1048
- batch_size: int,
1049
- src_len: int,
1050
- static_kv: bool,
1051
- ) -> Optional[Tensor]:
1052
- # saved key padding masks have shape (bsz, seq_len)
1053
- if prev_key_padding_mask is not None and static_kv:
1054
- new_key_padding_mask = prev_key_padding_mask
1055
- elif prev_key_padding_mask is not None and key_padding_mask is not None:
1056
- new_key_padding_mask = torch.cat(
1057
- [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
1058
- )
1059
- # During incremental decoding, as the padding token enters and
1060
- # leaves the frame, there will be a time when prev or current
1061
- # is None
1062
- elif prev_key_padding_mask is not None:
1063
- filler = torch.zeros(
1064
- (batch_size, src_len - prev_key_padding_mask.size(1)),
1065
- device=prev_key_padding_mask.device,
1066
- )
1067
- new_key_padding_mask = torch.cat(
1068
- [prev_key_padding_mask.float(), filler.float()], dim=1
1069
- )
1070
- elif key_padding_mask is not None:
1071
- filler = torch.zeros(
1072
- (batch_size, src_len - key_padding_mask.size(1)),
1073
- device=key_padding_mask.device,
1074
- )
1075
- new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
1076
- else:
1077
- new_key_padding_mask = prev_key_padding_mask
1078
- return new_key_padding_mask
1079
-
1080
- @torch.jit.export
1081
- def reorder_incremental_state(
1082
- self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1083
- ):
1084
- """Reorder buffered internal state (for incremental generation)."""
1085
- input_buffer = self._get_input_buffer(incremental_state)
1086
- if input_buffer is not None:
1087
- for k in input_buffer.keys():
1088
- input_buffer_k = input_buffer[k]
1089
- if input_buffer_k is not None:
1090
- if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
1091
- 0
1092
- ):
1093
- break
1094
- input_buffer[k] = input_buffer_k.index_select(0, new_order)
1095
- incremental_state = self._set_input_buffer(incremental_state, input_buffer)
1096
- return incremental_state
1097
-
1098
- def _get_input_buffer(
1099
- self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
1100
- ) -> Dict[str, Optional[Tensor]]:
1101
- result = self.get_incremental_state(incremental_state, "attn_state")
1102
- if result is not None:
1103
- return result
1104
- else:
1105
- empty_result: Dict[str, Optional[Tensor]] = {}
1106
- return empty_result
1107
-
1108
- def _set_input_buffer(
1109
- self,
1110
- incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1111
- buffer: Dict[str, Optional[Tensor]],
1112
- ):
1113
- return self.set_incremental_state(incremental_state, "attn_state", buffer)
1114
-
1115
- def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
1116
- return attn_weights
1117
-
1118
- def upgrade_state_dict_named(self, state_dict, name):
1119
- prefix = name + "." if name != "" else ""
1120
- items_to_add = {}
1121
- keys_to_remove = []
1122
- for k in state_dict.keys():
1123
- if k.endswith(prefix + "in_proj_weight"):
1124
- # in_proj_weight used to be q + k + v with same dimensions
1125
- dim = int(state_dict[k].shape[0] / 3)
1126
- items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1127
- items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
1128
- items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
1129
-
1130
- keys_to_remove.append(k)
1131
-
1132
- k_bias = prefix + "in_proj_bias"
1133
- if k_bias in state_dict.keys():
1134
- dim = int(state_dict[k].shape[0] / 3)
1135
- items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
1136
- items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
1137
- items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
1138
-
1139
- keys_to_remove.append(prefix + "in_proj_bias")
1140
-
1141
- for k in keys_to_remove:
1142
- del state_dict[k]
1143
-
1144
- for key, value in items_to_add.items():
1145
- state_dict[key] = value
1146
-
1147
-
1148
- def rotate_half(x):
1149
- x1, x2 = x.chunk(2, dim=-1)
1150
- return torch.cat((-x2, x1), dim=-1)
1151
-
1152
-
1153
- def apply_rotary_pos_emb(x, cos, sin):
1154
- cos = cos[:, : x.shape[-2], :]
1155
- sin = sin[:, : x.shape[-2], :]
1156
-
1157
- return (x * cos) + (rotate_half(x) * sin)
1158
-
1159
-
1160
- class RotaryEmbedding(torch.nn.Module):
1161
- """
1162
- The rotary position embeddings from RoFormer_ (Su et. al).
1163
- A crucial insight from the method is that the query and keys are
1164
- transformed by rotation matrices which depend on the relative positions.
1165
- Other implementations are available in the Rotary Transformer repo_ and in
1166
- GPT-NeoX_, GPT-NeoX was an inspiration
1167
- .. _RoFormer: https://arxiv.org/abs/2104.09864
1168
- .. _repo: https://github.com/ZhuiyiTechnology/roformer
1169
- .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
1170
- .. warning: Please note that this embedding is not registered on purpose, as it is transformative
1171
- (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
1172
- """
1173
-
1174
- def __init__(self, dim: int, *_, **__):
1175
- super().__init__()
1176
- # Generate and save the inverse frequency buffer (non trainable)
1177
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1178
- self.register_buffer("inv_freq", inv_freq)
1179
-
1180
- self._seq_len_cached = None
1181
- self._cos_cached = None
1182
- self._sin_cached = None
1183
-
1184
- def _update_cos_sin_tables(self, x, seq_dimension=1):
1185
- seq_len = x.shape[seq_dimension]
1186
-
1187
- # Reset the tables if the sequence length has changed,
1188
- # or if we're on a new device (possibly due to tracing for instance)
1189
- if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1190
- self._seq_len_cached = seq_len
1191
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
1192
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1193
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1194
-
1195
- self._cos_cached = emb.cos()[None, :, :]
1196
- self._sin_cached = emb.sin()[None, :, :]
1197
-
1198
- return self._cos_cached, self._sin_cached
1199
-
1200
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1201
- self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
1202
-
1203
- return (
1204
- apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1205
- apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1206
- )
1207
-
1208
-
1209
-
1210
-
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ import math
5
+ from typing import Dict, Optional, Sequence, Tuple, List, Union
6
+ import uuid
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor, nn
10
+ from torch.nn import Parameter
11
+
12
+
13
+ def gelu(x):
14
+ """Implementation of the gelu activation function.
15
+ OpenAI GPT's gelu: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
16
+ """
17
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18
+
19
+
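A quick sanity sketch for the erf-based GELU above (values illustrative; assumes this file is importable as modeling_gplm):

    import math
    import torch
    import torch.nn.functional as F
    from modeling_gplm import gelu  # assumption: this file is on the path as modeling_gplm

    x = torch.linspace(-3.0, 3.0, steps=7)
    # The erf form coincides with PyTorch's exact GELU.
    assert torch.allclose(gelu(x), F.gelu(x), atol=1e-6)
    # The tanh approximation quoted in the docstring deviates only slightly.
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    print((gelu(x) - approx).abs().max())  # small, on the order of 1e-3 or below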
20
+ def symmetrize(x):
21
+ "Make layer symmetric in final two dimensions, used for contact prediction."
22
+ return x + x.transpose(-1, -2)
23
+
24
+
25
+ def apc(x):
26
+ "Perform average product correct, used for contact prediction."
27
+ a1 = x.sum(-1, keepdims=True)
28
+ a2 = x.sum(-2, keepdims=True)
29
+ a12 = x.sum((-1, -2), keepdims=True)
30
+
31
+ avg = a1 * a2
32
+ avg.div_(a12) # in-place to reduce memory
33
+ normalized = x - avg
34
+ return normalized
35
+
36
+
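A toy sketch of how symmetrize and apc are combined into contact features (shapes illustrative; module name assumed):

    import torch
    from modeling_gplm import symmetrize, apc  # assumption: module name

    attn = torch.rand(1, 1, 5, 5)            # B x C x T x T attention map
    sym = symmetrize(attn)                   # sym[..., i, j] == sym[..., j, i]
    assert torch.allclose(sym, sym.transpose(-1, -2))
    corrected = apc(sym)                     # average product correction, same shape
    assert corrected.shape == attn.shape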
37
+ class LucaGPLM1LayerNorm(nn.Module):
38
+ def __init__(self, hidden_size, eps=1e-12, affine=True):
39
+ """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
40
+ super().__init__()
41
+ self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
42
+ self.eps = eps
43
+ self.affine = bool(affine)
44
+ if self.affine:
45
+ self.weight = nn.Parameter(torch.ones(hidden_size))
46
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
47
+ else:
48
+ self.weight, self.bias = None, None
49
+
50
+ def forward(self, x):
51
+ dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
52
+ means = x.mean(dims, keepdim=True)
53
+ x_zeromean = x - means
54
+ variances = x_zeromean.pow(2).mean(dims, keepdim=True)
55
+ x = x_zeromean / torch.sqrt(variances + self.eps)
56
+ if self.affine:
57
+ x = (self.weight * x) + self.bias
58
+ return x
59
+
60
+ from torch.nn import LayerNorm as LucaGPLM1bLayerNorm
61
+
62
+
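The TF-style norm above matches torch.nn.LayerNorm (which LucaGPLM1bLayerNorm aliases) when both use the same eps and the default ones/zeros affine parameters; a small sketch, with the module name assumed:

    import torch
    from torch import nn
    from modeling_gplm import LucaGPLM1LayerNorm  # assumption: module name

    x = torch.randn(2, 7, 16)
    ln_tf = LucaGPLM1LayerNorm(16, eps=1e-5)
    ln_pt = nn.LayerNorm(16, eps=1e-5)
    assert torch.allclose(ln_tf(x), ln_pt(x), atol=1e-5)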
63
+ class LucaGPLMTransformerLayer(nn.Module):
64
+ """LucaGPLM Transformer layer block."""
65
+
66
+ def __init__(
67
+ self,
68
+ embed_dim,
69
+ ffn_embed_dim,
70
+ attention_heads,
71
+ add_bias_kv=True,
72
+ use_lucagplm1b_layer_norm=False,
73
+ use_rotary_embeddings: bool = False,
74
+ ):
75
+ '''
76
+ Tramsformer-Encoder 层
77
+ :param embed_dim: token embedding dim
78
+ :param ffn_embed_dim: fully connected layer dim
79
+ :param attention_heads: heads num
80
+ :param add_bias_kv: whether to add bias to the key/value projections
81
+ :param use_lucagplm1b_layer_norm: whether to use lucagplm 1b layer norm
82
+ :param use_rotary_embeddings: whether to use rotary embedding
83
+ '''
84
+ super().__init__()
85
+ self.embed_dim = embed_dim
86
+ self.ffn_embed_dim = ffn_embed_dim
87
+ self.attention_heads = attention_heads
88
+ self.use_rotary_embeddings = use_rotary_embeddings
89
+ self._init_submodules(add_bias_kv, use_lucagplm1b_layer_norm)
90
+
91
+ def _init_submodules(self, add_bias_kv, use_lucagplm1b_layer_norm):
92
+ LucaGPLMLayerNorm = LucaGPLM1bLayerNorm if use_lucagplm1b_layer_norm else LucaGPLM1LayerNorm
93
+
94
+ # pre layer norm
95
+ self.pre_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
96
+
97
+ self.self_attn = LucaGPLMMultiheadAttention(
98
+ self.embed_dim,
99
+ self.attention_heads,
100
+ add_bias_kv=add_bias_kv,
101
+ add_zero_attn=False,
102
+ use_rotary_embeddings=self.use_rotary_embeddings,
103
+ )
104
+
105
+ # post layer norm
106
+ self.post_layer_norm = LucaGPLMLayerNorm(self.embed_dim)
107
+
108
+ # dimension increase by the fully connected layer
109
+ self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
110
+
111
+ # dimension reduction by the fully connected layer
112
+ self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
113
+
114
+ def forward(
115
+ self,
116
+ x,
117
+ self_attn_mask=None,
118
+ self_attn_padding_mask=None,
119
+ need_head_weights=False
120
+ ):
121
+ residual = x
122
+ x = self.pre_layer_norm(x)
123
+ x, attn = self.self_attn(
124
+ query=x,
125
+ key=x,
126
+ value=x,
127
+ key_padding_mask=self_attn_padding_mask,
128
+ need_weights=True,
129
+ need_head_weights=need_head_weights,
130
+ attn_mask=self_attn_mask,
131
+ )
132
+ x = residual + x
133
+
134
+ residual = x
135
+ x = self.post_layer_norm(x)
136
+ x = gelu(self.fc1(x))
137
+ x = self.fc2(x)
138
+ x = residual + x
139
+
140
+ return x, attn
141
+
142
+
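A forward-shape sketch for one encoder layer; the layer expects [seq_len, batch, embed_dim] input (hyper-parameters illustrative, module name assumed):

    import torch
    from modeling_gplm import LucaGPLMTransformerLayer  # assumption: module name

    layer = LucaGPLMTransformerLayer(
        embed_dim=64,
        ffn_embed_dim=128,
        attention_heads=4,
        add_bias_kv=False,
        use_lucagplm1b_layer_norm=True,
        use_rotary_embeddings=True,
    )
    x = torch.randn(10, 2, 64)        # T x B x E
    out, attn = layer(x)
    assert out.shape == (10, 2, 64)
    assert attn.shape == (2, 10, 10)  # head-averaged attention: B x T x T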
143
+ class AxialTransformerLayer(nn.Module):
144
+ """Implements an Axial MSA Transformer block."""
145
+ def __init__(
146
+ self,
147
+ embedding_dim: int = 768,
148
+ ffn_embedding_dim: int = 3072,
149
+ num_attention_heads: int = 8,
150
+ dropout: float = 0.1,
151
+ attention_dropout: float = 0.1,
152
+ activation_dropout: float = 0.1,
153
+ max_tokens_per_msa: int = 2**14,
154
+ ) -> None:
155
+ super().__init__()
156
+
157
+ # Initialize parameters
158
+ self.embedding_dim = embedding_dim
159
+ self.dropout_prob = dropout
160
+
161
+ row_self_attention = RowSelfAttention(
162
+ embedding_dim,
163
+ num_attention_heads,
164
+ dropout=dropout,
165
+ max_tokens_per_msa=max_tokens_per_msa,
166
+ )
167
+
168
+ column_self_attention = ColumnSelfAttention(
169
+ embedding_dim,
170
+ num_attention_heads,
171
+ dropout=dropout,
172
+ max_tokens_per_msa=max_tokens_per_msa,
173
+ )
174
+
175
+ feed_forward_layer = FeedForwardNetwork(
176
+ embedding_dim,
177
+ ffn_embedding_dim,
178
+ activation_dropout=activation_dropout,
179
+ max_tokens_per_msa=max_tokens_per_msa,
180
+ )
181
+
182
+ self.row_self_attention = self.build_residual(row_self_attention)
183
+ self.column_self_attention = self.build_residual(column_self_attention)
184
+ self.feed_forward_layer = self.build_residual(feed_forward_layer)
185
+
186
+ def build_residual(self, layer: nn.Module):
187
+ return NormalizedResidualBlock(
188
+ layer,
189
+ self.embedding_dim,
190
+ self.dropout_prob,
191
+ )
192
+
193
+ def forward(
194
+ self,
195
+ x: torch.Tensor,
196
+ self_attn_mask: Optional[torch.Tensor] = None,
197
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
198
+ need_head_weights: bool = False,
199
+ ):
200
+ """
201
+ LayerNorm is applied either before or after the self-attention/ffn
202
+ modules similar to the original Transformer implementation.
203
+ """
204
+ x, row_attn = self.row_self_attention(
205
+ x,
206
+ self_attn_mask=self_attn_mask,
207
+ self_attn_padding_mask=self_attn_padding_mask,
208
+ )
209
+ x, column_attn = self.column_self_attention(
210
+ x,
211
+ self_attn_mask=self_attn_mask,
212
+ self_attn_padding_mask=self_attn_padding_mask,
213
+ )
214
+ x = self.feed_forward_layer(x)
215
+ if need_head_weights:
216
+ return x, column_attn, row_attn
217
+ else:
218
+ return x
219
+
220
+
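A shape sketch for the axial block; it operates on MSA-shaped input [rows, cols, batch, embed] (dims illustrative, module name assumed):

    import torch
    from modeling_gplm import AxialTransformerLayer  # assumption: module name

    block = AxialTransformerLayer(embedding_dim=64, ffn_embedding_dim=128, num_attention_heads=4)
    msa = torch.randn(3, 12, 2, 64)   # R x C x B x E
    out = block(msa)
    assert out.shape == msa.shape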
221
+ class LearnedPositionalEmbedding(nn.Embedding):
222
+ """
223
+ This module learns positional embeddings up to a fixed maximum size.
224
+ Padding ids are ignored by either offsetting based on padding_idx
225
+ or by setting padding_idx to None and ensuring that the appropriate
226
+ position ids are passed to the forward function.
227
+ """
228
+
229
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
230
+ if padding_idx is not None:
231
+ num_embeddings_ = num_embeddings + padding_idx + 1
232
+ else:
233
+ num_embeddings_ = num_embeddings
234
+ super().__init__(num_embeddings_, embedding_dim, padding_idx)
235
+ self.max_positions = num_embeddings
236
+
237
+ def forward(self, input: torch.Tensor):
238
+ """Input is expected to be of size [bsz x seqlen]."""
239
+ if input.size(1) > self.max_positions:
240
+ raise ValueError(
241
+ f"Sequence length {input.size(1)} above maximum "
242
+ f" sequence length of {self.max_positions}"
243
+ )
244
+ mask = input.ne(self.padding_idx).int()
245
+ positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
246
+ return F.embedding(
247
+ positions,
248
+ self.weight,
249
+ self.padding_idx,
250
+ self.max_norm,
251
+ self.norm_type,
252
+ self.scale_grad_by_freq,
253
+ self.sparse,
254
+ )
255
+
256
+
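Position ids are derived from the padding mask, so padded positions receive the padding embedding; a small sketch (token values illustrative, module name assumed):

    import torch
    from modeling_gplm import LearnedPositionalEmbedding  # assumption: module name

    pos_emb = LearnedPositionalEmbedding(num_embeddings=10, embedding_dim=8, padding_idx=1)
    tokens = torch.tensor([[5, 6, 7, 1, 1]])   # token id 1 is padding here
    out = pos_emb(tokens)
    assert out.shape == (1, 5, 8)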
257
+ class SinusoidalPositionalEmbedding(nn.Module):
258
+ def __init__(self, embed_dim, padding_idx, learned=False):
259
+ super().__init__()
260
+ self.embed_dim = embed_dim
261
+ self.padding_idx = padding_idx
262
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
263
+ self.weights = None
264
+
265
+ def forward(self, x):
266
+ bsz, seq_len = x.shape
267
+ max_pos = self.padding_idx + 1 + seq_len
268
+ if self.weights is None or max_pos > self.weights.size(0):
269
+ self.weights = self.get_embedding(max_pos)
270
+ self.weights = self.weights.type_as(self._float_tensor)
271
+
272
+ positions = self.make_positions(x)
273
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
274
+
275
+ def make_positions(self, x):
276
+ mask = x.ne(self.padding_idx)
277
+ range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
278
+ positions = range_buf.expand_as(x)
279
+ return positions * mask.long() + self.padding_idx * (1 - mask.long())
280
+
281
+ def get_embedding(self, num_embeddings):
282
+ half_dim = self.embed_dim // 2
283
+ emb = math.log(10000) / (half_dim - 1)
284
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
285
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
286
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
287
+ if self.embed_dim % 2 == 1:
288
+ # zero pad
289
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
290
+ if self.padding_idx is not None:
291
+ emb[self.padding_idx, :] = 0
292
+ return emb
293
+
294
+
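The sinusoidal table is built lazily on first use and the output is detached (non-trainable); a small sketch (module name assumed):

    import torch
    from modeling_gplm import SinusoidalPositionalEmbedding  # assumption: module name

    pos_emb = SinusoidalPositionalEmbedding(embed_dim=8, padding_idx=1)
    tokens = torch.tensor([[5, 6, 7, 1]])      # token id 1 is padding here
    out = pos_emb(tokens)
    assert out.shape == (1, 4, 8)
    assert not out.requires_grad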
295
+ class RobertaLMHead(nn.Module):
296
+ """Head for masked language modeling."""
297
+
298
+ def __init__(self, embed_dim, output_dim, weight):
299
+ super().__init__()
300
+ self.dense = nn.Linear(embed_dim, embed_dim)
301
+ self.layer_norm = LucaGPLM1bLayerNorm(embed_dim)
302
+ self.weight = weight
303
+ self.bias = nn.Parameter(torch.zeros(output_dim))
304
+
305
+ def forward(self, features):
306
+ x = self.dense(features)
307
+ x = gelu(x)
308
+ x = self.layer_norm(x)
309
+ # project back to size of vocabulary with bias
310
+ x = F.linear(x, self.weight) + self.bias
311
+ return x
312
+
313
+
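The LM head is meant to share its projection weight with the input token embedding; a weight-tying sketch (sizes illustrative, module name assumed):

    import torch
    from torch import nn
    from modeling_gplm import RobertaLMHead  # assumption: module name

    vocab_size, embed_dim = 40, 16
    embed_tokens = nn.Embedding(vocab_size, embed_dim)
    lm_head = RobertaLMHead(embed_dim=embed_dim, output_dim=vocab_size, weight=embed_tokens.weight)
    logits = lm_head(torch.randn(2, 7, embed_dim))
    assert logits.shape == (2, 7, vocab_size)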
314
+ class ContactPredictionHead(nn.Module):
315
+ """Performs symmetrization, apc, and computes a logistic regression on the output features"""
316
+
317
+ def __init__(
318
+ self,
319
+ in_features: int,
320
+ prepend_bos: bool,
321
+ append_eos: bool,
322
+ bias=True,
323
+ eos_idx: Optional[int] = None,
324
+ ):
325
+ super().__init__()
326
+ self.in_features = in_features
327
+ self.prepend_bos = prepend_bos
328
+ self.append_eos = append_eos
329
+ if append_eos and eos_idx is None:
330
+ raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
331
+ self.eos_idx = eos_idx
332
+ self.regression = nn.Linear(in_features, 1, bias)
333
+ self.activation = nn.Sigmoid()
334
+
335
+ def forward(self, tokens, attentions):
336
+ # remove eos token attentions
337
+ if self.append_eos:
338
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
339
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
340
+ attentions = attentions * eos_mask[:, None, None, :, :]
341
+ attentions = attentions[..., :-1, :-1]
342
+ # remove cls token attentions
343
+ if self.prepend_bos:
344
+ attentions = attentions[..., 1:, 1:]
345
+ batch_size, layers, heads, seqlen, _ = attentions.size()
346
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
347
+
348
+ # features: B x C x T x T
349
+ attentions = attentions.to(
350
+ self.regression.weight.device
351
+ ) # attentions always float32, may need to convert to float16
352
+ attentions = apc(symmetrize(attentions))
353
+ attentions = attentions.permute(0, 2, 3, 1)
354
+ return self.activation(self.regression(attentions).squeeze(3))
355
+
356
+
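The contact head consumes stacked attention maps of shape [batch, layers, heads, T, T] together with the tokens so that the bos/eos columns can be stripped; a shape sketch (ids and sizes illustrative, module name assumed):

    import torch
    from modeling_gplm import ContactPredictionHead  # assumption: module name

    head = ContactPredictionHead(in_features=2 * 4, prepend_bos=True, append_eos=True, eos_idx=2)
    tokens = torch.randint(low=4, high=20, size=(1, 10))
    tokens[:, -1] = 2                            # eos at the last position
    attentions = torch.rand(1, 2, 4, 10, 10)     # B x layers x heads x T x T
    contacts = head(tokens, attentions)
    assert contacts.shape == (1, 8, 8)           # bos and eos columns removed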
357
+ class NormalizedResidualBlock(nn.Module):
358
+ def __init__(
359
+ self,
360
+ layer: nn.Module,
361
+ embedding_dim: int,
362
+ dropout: float = 0.1,
363
+ ):
364
+ super().__init__()
365
+ self.embedding_dim = embedding_dim
366
+
367
+ self.layer = layer
368
+ self.dropout_module = nn.Dropout(
369
+ dropout,
370
+ )
371
+ self.layer_norm = LucaGPLM1bLayerNorm(self.embedding_dim)
372
+
373
+ def forward(self, x, *args, **kwargs):
374
+ residual = x
375
+ x = self.layer_norm(x)
376
+ outputs = self.layer(x, *args, **kwargs)
377
+ if isinstance(outputs, tuple):
378
+ x, *out = outputs
379
+ else:
380
+ x = outputs
381
+ out = None
382
+
383
+ x = self.dropout_module(x)
384
+ x = residual + x
385
+
386
+ if out is not None:
387
+ return (x,) + tuple(out)
388
+ else:
389
+ return x
390
+
391
+
392
+ class FeedForwardNetwork(nn.Module):
393
+ def __init__(
394
+ self,
395
+ embedding_dim: int,
396
+ ffn_embedding_dim: int,
397
+ activation_dropout: float = 0.1,
398
+ max_tokens_per_msa: int = 2**14,
399
+ ):
400
+ super().__init__()
401
+ self.embedding_dim = embedding_dim
402
+ self.ffn_embedding_dim = ffn_embedding_dim
403
+ self.max_tokens_per_msa = max_tokens_per_msa
404
+ self.activation_fn = nn.GELU()
405
+ self.activation_dropout_module = nn.Dropout(
406
+ activation_dropout,
407
+ )
408
+ self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
409
+ self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
410
+
411
+ def forward(self, x):
412
+ x = self.activation_fn(self.fc1(x))
413
+ x = self.activation_dropout_module(x)
414
+ x = self.fc2(x)
415
+ return x
416
+
417
+
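NormalizedResidualBlock wraps a sub-layer with pre-LayerNorm, dropout and a residual connection; here it wraps the feed-forward network above (dims illustrative, module name assumed):

    import torch
    from modeling_gplm import FeedForwardNetwork, NormalizedResidualBlock  # assumption: module name

    ffn = FeedForwardNetwork(embedding_dim=32, ffn_embedding_dim=64)
    block = NormalizedResidualBlock(ffn, embedding_dim=32, dropout=0.0)
    x = torch.randn(5, 2, 32)
    out = block(x)
    assert out.shape == x.shape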
418
+ class RowSelfAttention(nn.Module):
419
+ """Compute self-attention over rows of a 2D input."""
420
+
421
+ def __init__(
422
+ self,
423
+ embed_dim,
424
+ num_heads,
425
+ dropout=0.0,
426
+ max_tokens_per_msa: int = 2 ** 16,
427
+ ):
428
+ super().__init__()
429
+ self.num_heads = num_heads
430
+ self.dropout = dropout
431
+ self.head_dim = embed_dim // num_heads
432
+ self.scaling = self.head_dim ** -0.5
433
+ self.max_tokens_per_msa = max_tokens_per_msa
434
+ self.attn_shape = "hnij"
435
+
436
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
437
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
438
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
439
+
440
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
441
+ self.dropout_module = nn.Dropout(dropout)
442
+
443
+ def align_scaling(self, q):
444
+ num_rows = q.size(0)
445
+ return self.scaling / math.sqrt(num_rows)
446
+
447
+ def _batched_forward(
448
+ self,
449
+ x,
450
+ self_attn_mask=None,
451
+ self_attn_padding_mask=None,
452
+ ):
453
+ num_rows, num_cols, batch_size, embed_dim = x.size()
454
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
455
+ attns = 0
456
+ scaling = self.align_scaling(x)
457
+ for start in range(0, num_rows, max_rows):
458
+ attn_weights = self.compute_attention_weights(
459
+ x[start : start + max_rows],
460
+ scaling,
461
+ self_attn_mask=self_attn_mask,
462
+ self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
463
+ if self_attn_padding_mask is not None
464
+ else None,
465
+ )
466
+ attns += attn_weights
467
+ attn_probs = attns.softmax(-1)
468
+ attn_probs = self.dropout_module(attn_probs)
469
+
470
+ outputs = []
471
+ for start in range(0, num_rows, max_rows):
472
+ output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
473
+ outputs.append(output)
474
+
475
+ output = torch.cat(outputs, 0)
476
+ return output, attn_probs
477
+
478
+ def compute_attention_weights(
479
+ self,
480
+ x,
481
+ scaling: float,
482
+ self_attn_mask=None,
483
+ self_attn_padding_mask=None,
484
+ ):
485
+ num_rows, num_cols, batch_size, embed_dim = x.size()
486
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
487
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
488
+ q *= scaling
489
+ if self_attn_padding_mask is not None:
490
+ # Zero out any padded aligned positions - this is important since
491
+ # we take a sum across the alignment axis.
492
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
493
+
494
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
495
+
496
+ if self_attn_mask is not None:
497
+ raise NotImplementedError
498
+ # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
499
+
500
+ if self_attn_padding_mask is not None:
501
+ attn_weights = attn_weights.masked_fill(
502
+ self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
503
+ -10000,
504
+ )
505
+
506
+ return attn_weights
507
+
508
+ def compute_attention_update(
509
+ self,
510
+ x,
511
+ attn_probs,
512
+ ):
513
+ num_rows, num_cols, batch_size, embed_dim = x.size()
514
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
515
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
516
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
517
+ output = self.out_proj(context)
518
+ return output
519
+
520
+ def forward(
521
+ self,
522
+ x,
523
+ self_attn_mask=None,
524
+ self_attn_padding_mask=None,
525
+ ):
526
+ num_rows, num_cols, batch_size, embed_dim = x.size()
527
+ if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
528
+ return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
529
+ else:
530
+ scaling = self.align_scaling(x)
531
+ attn_weights = self.compute_attention_weights(
532
+ x, scaling, self_attn_mask, self_attn_padding_mask
533
+ )
534
+ attn_probs = attn_weights.softmax(-1)
535
+ attn_probs = self.dropout_module(attn_probs)
536
+ output = self.compute_attention_update(x, attn_probs)
537
+ return output, attn_probs
538
+
539
+
540
+ class ColumnSelfAttention(nn.Module):
541
+ """Compute self-attention over columns of a 2D input."""
542
+
543
+ def __init__(
544
+ self,
545
+ embed_dim,
546
+ num_heads,
547
+ dropout=0.0,
548
+ max_tokens_per_msa: int = 2 ** 16,
549
+ ):
550
+ super().__init__()
551
+
552
+ self.num_heads = num_heads
553
+ self.dropout = dropout
554
+ self.head_dim = embed_dim // num_heads
555
+ self.scaling = self.head_dim ** -0.5
556
+ self.max_tokens_per_msa = max_tokens_per_msa
557
+
558
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
559
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
560
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
561
+
562
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
563
+ self.dropout_module = nn.Dropout(dropout)
564
+
565
+ def _batched_forward(
566
+ self,
567
+ x,
568
+ self_attn_mask=None,
569
+ self_attn_padding_mask=None,
570
+ ):
571
+ num_rows, num_cols, batch_size, embed_dim = x.size()
572
+ max_cols = max(1, self.max_tokens_per_msa // num_rows)
573
+ outputs = []
574
+ attns = []
575
+ for start in range(0, num_cols, max_cols):
576
+ output, attn = self(
577
+ x[:, start : start + max_cols],
578
+ self_attn_mask=self_attn_mask,
579
+ self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
580
+ if self_attn_padding_mask is not None
581
+ else None,
582
+ )
583
+ outputs.append(output)
584
+ attns.append(attn)
585
+ output = torch.cat(outputs, 1)
586
+ attns = torch.cat(attns, 1)
587
+ return output, attns
588
+
589
+ def compute_attention_update(
590
+ self,
591
+ x,
592
+ self_attn_mask=None,
593
+ self_attn_padding_mask=None,
594
+ ):
595
+ num_rows, num_cols, batch_size, embed_dim = x.size()
596
+ if num_rows == 1:
597
+ # if there is only 1 position, this is equivalent and doesn't break with padding
598
+ attn_probs = torch.ones(
599
+ self.num_heads,
600
+ num_cols,
601
+ batch_size,
602
+ num_rows,
603
+ num_rows,
604
+ device=x.device,
605
+ dtype=x.dtype,
606
+ )
607
+ output = self.out_proj(self.v_proj(x))
608
+ else:
609
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
610
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
611
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
612
+ q *= self.scaling
613
+
614
+ attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
615
+
616
+ if self_attn_mask is not None:
617
+ raise NotImplementedError
618
+ if self_attn_padding_mask is not None:
619
+ attn_weights = attn_weights.masked_fill(
620
+ self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
621
+ -10000,
622
+ )
623
+
624
+ attn_probs = attn_weights.softmax(-1)
625
+ attn_probs = self.dropout_module(attn_probs)
626
+ context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
627
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
628
+ output = self.out_proj(context)
629
+ return output, attn_probs
630
+
631
+ def forward(
632
+ self,
633
+ x,
634
+ self_attn_mask=None,
635
+ self_attn_padding_mask=None,
636
+ ):
637
+ num_rows, num_cols, batch_size, embed_dim = x.size()
638
+ # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
639
+ if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
640
+ return self._batched_forward(
641
+ x,
642
+ self_attn_mask,
643
+ self_attn_padding_mask,
644
+ )
645
+ else:
646
+ return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
647
+
648
+
649
+ def utils_softmax(x, dim: int, onnx_trace: bool = False):
650
+ if onnx_trace:
651
+ return F.softmax(x.float(), dim=dim)
652
+ else:
653
+ return F.softmax(x, dim=dim, dtype=torch.float32)
654
+
655
+
656
+ class FairseqIncrementalState(object):
657
+ def __init__(self, *args, **kwargs):
658
+ super().__init__(*args, **kwargs)
659
+ self.init_incremental_state()
660
+
661
+ def init_incremental_state(self):
662
+ self._incremental_state_id = str(uuid.uuid4())
663
+
664
+ def _get_full_incremental_state_key(self, key: str) -> str:
665
+ return "{}.{}".format(self._incremental_state_id, key)
666
+
667
+ def get_incremental_state(
668
+ self,
669
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
670
+ key: str,
671
+ ) -> Optional[Dict[str, Optional[Tensor]]]:
672
+ """Helper for getting incremental state for an nn.Module."""
673
+ full_key = self._get_full_incremental_state_key(key)
674
+ if incremental_state is None or full_key not in incremental_state:
675
+ return None
676
+ return incremental_state[full_key]
677
+
678
+ def set_incremental_state(
679
+ self,
680
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
681
+ key: str,
682
+ value: Dict[str, Optional[Tensor]],
683
+ ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
684
+ """Helper for setting incremental state for an nn.Module."""
685
+ if incremental_state is not None:
686
+ full_key = self._get_full_incremental_state_key(key)
687
+ incremental_state[full_key] = value
688
+ return incremental_state
689
+
690
+
691
+ def with_incremental_state(cls):
692
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(
693
+ b for b in cls.__bases__ if b != FairseqIncrementalState
694
+ )
695
+ return cls
696
+
697
+
698
+ @with_incremental_state
699
+ class LucaGPLMMultiheadAttention(nn.Module):
700
+ """Multi-headed attention.
701
+
702
+ See "Attention Is All You Need" for more details.
703
+ """
704
+
705
+ def __init__(
706
+ self,
707
+ embed_dim,
708
+ num_heads,
709
+ kdim=None,
710
+ vdim=None,
711
+ dropout=0.0,
712
+ bias=True,
713
+ add_bias_kv: bool = False,
714
+ add_zero_attn: bool = False,
715
+ self_attention: bool = False,
716
+ encoder_decoder_attention: bool = False,
717
+ use_rotary_embeddings: bool = False,
718
+ ):
719
+ super().__init__()
720
+ self.embed_dim = embed_dim
721
+ self.kdim = kdim if kdim is not None else embed_dim
722
+ self.vdim = vdim if vdim is not None else embed_dim
723
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
724
+
725
+ self.num_heads = num_heads
726
+ self.dropout = dropout
727
+ self.head_dim = embed_dim // num_heads
728
+ assert (
729
+ self.head_dim * num_heads == self.embed_dim
730
+ ), "embed_dim must be divisible by num_heads"
731
+ self.scaling = self.head_dim**-0.5
732
+
733
+ self.self_attention = self_attention
734
+ self.encoder_decoder_attention = encoder_decoder_attention
735
+
736
+ assert not self.self_attention or self.qkv_same_dim, (
737
+ "Self-attention requires query, key and " "value to be of the same size"
738
+ )
739
+
740
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
741
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
742
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
743
+
744
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
745
+
746
+ if add_bias_kv:
747
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
748
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
749
+ else:
750
+ self.bias_k = self.bias_v = None
751
+
752
+ self.add_zero_attn = add_zero_attn
753
+
754
+ self.reset_parameters()
755
+
756
+ self.onnx_trace = False
757
+ self.rot_emb = None
758
+ if use_rotary_embeddings:
759
+ self.rot_emb = RotaryEmbedding(dim=self.head_dim)
760
+
761
+ self.enable_torch_version = False
762
+ if hasattr(F, "multi_head_attention_forward"):
763
+ self.enable_torch_version = True
764
+ else:
765
+ self.enable_torch_version = False
766
+
767
+ def prepare_for_onnx_export_(self):
768
+ self.onnx_trace = True
769
+
770
+ def reset_parameters(self):
771
+ '''
772
+ if self.qkv_same_dim:
773
+ # Empirically observed the convergence to be much better with
774
+ # the scaled initialization
775
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
776
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
777
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
778
+ else:
779
+ nn.init.xavier_uniform_(self.k_proj.weight)
780
+ nn.init.xavier_uniform_(self.v_proj.weight)
781
+ nn.init.xavier_uniform_(self.q_proj.weight)
782
+ '''
783
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=nn.init.calculate_gain("relu"))
784
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=nn.init.calculate_gain("relu"))
785
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=nn.init.calculate_gain("relu"))
786
+
787
+ nn.init.xavier_uniform_(self.out_proj.weight, gain=nn.init.calculate_gain("relu"))
788
+ # nn.init.xavier_uniform_(self.out_proj.weight)
789
+ if self.out_proj.bias is not None:
790
+ nn.init.constant_(self.out_proj.bias, 0.0)
791
+ if self.bias_k is not None:
792
+ nn.init.xavier_normal_(self.bias_k)
793
+ if self.bias_v is not None:
794
+ nn.init.xavier_normal_(self.bias_v)
795
+
796
+ def forward(
797
+ self,
798
+ query,
799
+ key: Optional[Tensor],
800
+ value: Optional[Tensor],
801
+ key_padding_mask: Optional[Tensor] = None,
802
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
803
+ need_weights: bool = True,
804
+ static_kv: bool = False,
805
+ attn_mask: Optional[Tensor] = None,
806
+ before_softmax: bool = False,
807
+ need_head_weights: bool = False,
808
+ ) -> Tuple[Tensor, Optional[Tensor]]:
809
+ """Input shape: Time x Batch x Channel
810
+
811
+ Args:
812
+ key_padding_mask (ByteTensor, optional): mask to exclude
813
+ keys that are pads, of shape `(batch, src_len)`, where
814
+ padding elements are indicated by 1s.
815
+ need_weights (bool, optional): return the attention weights,
816
+ averaged over heads (default: False).
817
+ attn_mask (ByteTensor, optional): typically used to
818
+ implement causal attention, where the mask prevents the
819
+ attention from looking forward in time (default: None).
820
+ before_softmax (bool, optional): return the raw attention
821
+ weights and values before the attention softmax.
822
+ need_head_weights (bool, optional): return the attention
823
+ weights for each head. Implies *need_weights*. Default:
824
+ return the average attention weights over all heads.
825
+ """
826
+ if need_head_weights:
827
+ need_weights = True
828
+
829
+ tgt_len, bsz, embed_dim = query.size()
830
+ assert embed_dim == self.embed_dim
831
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
832
+
833
+ if (
834
+ not self.rot_emb
835
+ and self.enable_torch_version
836
+ and not self.onnx_trace
837
+ and incremental_state is None
838
+ and not static_kv
839
+ # A workaround for quantization to work. Otherwise JIT compilation
840
+ # treats bias in linear module as method.
841
+ and not torch.jit.is_scripting()
842
+ and not need_head_weights
843
+ ):
844
+ assert key is not None and value is not None
845
+ return F.multi_head_attention_forward(
846
+ query,
847
+ key,
848
+ value,
849
+ self.embed_dim,
850
+ self.num_heads,
851
+ torch.empty([0]),
852
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
853
+ self.bias_k,
854
+ self.bias_v,
855
+ self.add_zero_attn,
856
+ self.dropout,
857
+ self.out_proj.weight,
858
+ self.out_proj.bias,
859
+ self.training,
860
+ key_padding_mask,
861
+ need_weights,
862
+ attn_mask,
863
+ use_separate_proj_weight=True,
864
+ q_proj_weight=self.q_proj.weight,
865
+ k_proj_weight=self.k_proj.weight,
866
+ v_proj_weight=self.v_proj.weight,
867
+ )
868
+ if incremental_state is not None:
869
+ saved_state = self._get_input_buffer(incremental_state)
870
+ if saved_state is not None and "prev_key" in saved_state:
871
+ # previous time steps are cached - no need to recompute
872
+ # key and value if they are static
873
+ if static_kv:
874
+ assert self.encoder_decoder_attention and not self.self_attention
875
+ key = value = None
876
+ else:
877
+ saved_state = None
878
+
879
+ if self.self_attention:
880
+ q = self.q_proj(query)
881
+ k = self.k_proj(query)
882
+ v = self.v_proj(query)
883
+ elif self.encoder_decoder_attention:
884
+ # encoder-decoder attention
885
+ q = self.q_proj(query)
886
+ if key is None:
887
+ assert value is None
888
+ k = v = None
889
+ else:
890
+ k = self.k_proj(key)
891
+ v = self.v_proj(key)
892
+
893
+ else:
894
+ assert key is not None and value is not None
895
+ q = self.q_proj(query)
896
+ k = self.k_proj(key)
897
+ v = self.v_proj(value)
898
+ q *= self.scaling
899
+
900
+ if self.bias_k is not None:
901
+ assert self.bias_v is not None
902
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
903
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
904
+ if attn_mask is not None:
905
+ attn_mask = torch.cat(
906
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
907
+ )
908
+ if key_padding_mask is not None:
909
+ key_padding_mask = torch.cat(
910
+ [
911
+ key_padding_mask,
912
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
913
+ ],
914
+ dim=1,
915
+ )
916
+
917
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
918
+ if k is not None:
919
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
920
+ if v is not None:
921
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
922
+
923
+ if saved_state is not None:
924
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
925
+ if "prev_key" in saved_state:
926
+ _prev_key = saved_state["prev_key"]
927
+ assert _prev_key is not None
928
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
929
+ if static_kv:
930
+ k = prev_key
931
+ else:
932
+ assert k is not None
933
+ k = torch.cat([prev_key, k], dim=1)
934
+ if "prev_value" in saved_state:
935
+ _prev_value = saved_state["prev_value"]
936
+ assert _prev_value is not None
937
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
938
+ if static_kv:
939
+ v = prev_value
940
+ else:
941
+ assert v is not None
942
+ v = torch.cat([prev_value, v], dim=1)
943
+ prev_key_padding_mask: Optional[Tensor] = None
944
+ if "prev_key_padding_mask" in saved_state:
945
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
946
+ assert k is not None and v is not None
947
+ key_padding_mask = LucaGPLMMultiheadAttention._append_prev_key_padding_mask(
948
+ key_padding_mask=key_padding_mask,
949
+ prev_key_padding_mask=prev_key_padding_mask,
950
+ batch_size=bsz,
951
+ src_len=k.size(1),
952
+ static_kv=static_kv,
953
+ )
954
+
955
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
956
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
957
+ saved_state["prev_key_padding_mask"] = key_padding_mask
958
+ # In this branch incremental_state is never None
959
+ assert incremental_state is not None
960
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
961
+ assert k is not None
962
+ src_len = k.size(1)
963
+
964
+ # This is part of a workaround to get around fork/join parallelism
965
+ # not supporting Optional types.
966
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
967
+ key_padding_mask = None
968
+
969
+ if key_padding_mask is not None:
970
+ assert key_padding_mask.size(0) == bsz
971
+ assert key_padding_mask.size(1) == src_len
972
+
973
+ if self.add_zero_attn:
974
+ assert v is not None
975
+ src_len += 1
976
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
977
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
978
+ if attn_mask is not None:
979
+ attn_mask = torch.cat(
980
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
981
+ )
982
+ if key_padding_mask is not None:
983
+ key_padding_mask = torch.cat(
984
+ [
985
+ key_padding_mask,
986
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
987
+ ],
988
+ dim=1,
989
+ )
990
+
991
+ if self.rot_emb:
992
+ q, k = self.rot_emb(q, k)
993
+
994
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
995
+ attn_weights = LucaGPLMMultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
996
+
997
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
998
+
999
+ if attn_mask is not None:
1000
+ attn_mask = attn_mask.unsqueeze(0)
1001
+ if self.onnx_trace:
1002
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
1003
+ attn_weights += attn_mask
1004
+
1005
+ if key_padding_mask is not None:
1006
+ # don't attend to padding symbols
1007
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
1008
+ attn_weights = attn_weights.masked_fill(
1009
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
1010
+ )
1011
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
1012
+
1013
+ if before_softmax:
1014
+ return attn_weights, v
1015
+
1016
+ attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
1017
+ attn_weights = attn_weights_float.type_as(attn_weights)
1018
+ attn_probs = F.dropout(
1019
+ attn_weights_float.type_as(attn_weights),
1020
+ p=self.dropout,
1021
+ training=self.training,
1022
+ )
1023
+ assert v is not None
1024
+ attn = torch.bmm(attn_probs, v)
1025
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
1026
+ if self.onnx_trace and attn.size(1) == 1:
1027
+ # when ONNX tracing a single decoder step (sequence length == 1)
1028
+ # the transpose is a no-op copy before view, thus unnecessary
1029
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
1030
+ else:
1031
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
1032
+ attn = self.out_proj(attn)
1033
+ attn_weights: Optional[Tensor] = None
1034
+ if need_weights:
1035
+ attn_weights = attn_weights_float.view(
1036
+ bsz, self.num_heads, tgt_len, src_len
1037
+ ).type_as(attn).transpose(1, 0)
1038
+ if not need_head_weights:
1039
+ # average attention weights over heads
1040
+ attn_weights = attn_weights.mean(dim=0)
1041
+
1042
+ return attn, attn_weights
1043
+
1044
+ @staticmethod
1045
+ def _append_prev_key_padding_mask(
1046
+ key_padding_mask: Optional[Tensor],
1047
+ prev_key_padding_mask: Optional[Tensor],
1048
+ batch_size: int,
1049
+ src_len: int,
1050
+ static_kv: bool,
1051
+ ) -> Optional[Tensor]:
1052
+ # saved key padding masks have shape (bsz, seq_len)
1053
+ if prev_key_padding_mask is not None and static_kv:
1054
+ new_key_padding_mask = prev_key_padding_mask
1055
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
1056
+ new_key_padding_mask = torch.cat(
1057
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
1058
+ )
1059
+ # During incremental decoding, as the padding token enters and
1060
+ # leaves the frame, there will be a time when prev or current
1061
+ # is None
1062
+ elif prev_key_padding_mask is not None:
1063
+ filler = torch.zeros(
1064
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
1065
+ device=prev_key_padding_mask.device,
1066
+ )
1067
+ new_key_padding_mask = torch.cat(
1068
+ [prev_key_padding_mask.float(), filler.float()], dim=1
1069
+ )
1070
+ elif key_padding_mask is not None:
1071
+ filler = torch.zeros(
1072
+ (batch_size, src_len - key_padding_mask.size(1)),
1073
+ device=key_padding_mask.device,
1074
+ )
1075
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
1076
+ else:
1077
+ new_key_padding_mask = prev_key_padding_mask
1078
+ return new_key_padding_mask
1079
+
1080
+ @torch.jit.export
1081
+ def reorder_incremental_state(
1082
+ self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
1083
+ ):
1084
+ """Reorder buffered internal state (for incremental generation)."""
1085
+ input_buffer = self._get_input_buffer(incremental_state)
1086
+ if input_buffer is not None:
1087
+ for k in input_buffer.keys():
1088
+ input_buffer_k = input_buffer[k]
1089
+ if input_buffer_k is not None:
1090
+ if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
1091
+ 0
1092
+ ):
1093
+ break
1094
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
1095
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
1096
+ return incremental_state
1097
+
1098
+ def _get_input_buffer(
1099
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
1100
+ ) -> Dict[str, Optional[Tensor]]:
1101
+ result = self.get_incremental_state(incremental_state, "attn_state")
1102
+ if result is not None:
1103
+ return result
1104
+ else:
1105
+ empty_result: Dict[str, Optional[Tensor]] = {}
1106
+ return empty_result
1107
+
1108
+ def _set_input_buffer(
1109
+ self,
1110
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1111
+ buffer: Dict[str, Optional[Tensor]],
1112
+ ):
1113
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
1114
+
1115
+ def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
1116
+ return attn_weights
1117
+
1118
+ def upgrade_state_dict_named(self, state_dict, name):
1119
+ prefix = name + "." if name != "" else ""
1120
+ items_to_add = {}
1121
+ keys_to_remove = []
1122
+ for k in state_dict.keys():
1123
+ if k.endswith(prefix + "in_proj_weight"):
1124
+ # in_proj_weight used to be q + k + v with same dimensions
1125
+ dim = int(state_dict[k].shape[0] / 3)
1126
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
1127
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
1128
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
1129
+
1130
+ keys_to_remove.append(k)
1131
+
1132
+ k_bias = prefix + "in_proj_bias"
1133
+ if k_bias in state_dict.keys():
1134
+ dim = int(state_dict[k].shape[0] / 3)
1135
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
1136
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
1137
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
1138
+
1139
+ keys_to_remove.append(prefix + "in_proj_bias")
1140
+
1141
+ for k in keys_to_remove:
1142
+ del state_dict[k]
1143
+
1144
+ for key, value in items_to_add.items():
1145
+ state_dict[key] = value
1146
+
1147
+
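A stand-alone self-attention call; inputs follow the documented [time, batch, channel] layout and per-head weights can be requested (sizes illustrative, module name assumed):

    import torch
    from modeling_gplm import LucaGPLMMultiheadAttention  # assumption: module name

    attn = LucaGPLMMultiheadAttention(embed_dim=32, num_heads=4, self_attention=True)
    x = torch.randn(9, 2, 32)                    # T x B x C
    pad = torch.zeros(2, 9, dtype=torch.bool)    # no padded keys
    out, weights = attn(x, x, x, key_padding_mask=pad, need_head_weights=True)
    assert out.shape == (9, 2, 32)
    assert weights.shape == (4, 2, 9, 9)         # heads x batch x tgt x src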
1148
+ def rotate_half(x):
1149
+ x1, x2 = x.chunk(2, dim=-1)
1150
+ return torch.cat((-x2, x1), dim=-1)
1151
+
1152
+
1153
+ def apply_rotary_pos_emb(x, cos, sin):
1154
+ cos = cos[:, : x.shape[-2], :]
1155
+ sin = sin[:, : x.shape[-2], :]
1156
+
1157
+ return (x * cos) + (rotate_half(x) * sin)
1158
+
1159
+
1160
+ class RotaryEmbedding(torch.nn.Module):
1161
+ """
1162
+ The rotary position embeddings from RoFormer_ (Su et al.).
1163
+ A crucial insight from the method is that the query and keys are
1164
+ transformed by rotation matrices which depend on the relative positions.
1165
+ Other implementations are available in the Rotary Transformer repo_ and in
1166
+ GPT-NeoX_; GPT-NeoX was an inspiration.
1167
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
1168
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
1169
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
1170
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
1171
+ (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
1172
+ """
1173
+
1174
+ def __init__(self, dim: int, *_, **__):
1175
+ super().__init__()
1176
+ # Generate and save the inverse frequency buffer (non trainable)
1177
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
1178
+ self.register_buffer("inv_freq", inv_freq)
1179
+
1180
+ self._seq_len_cached = None
1181
+ self._cos_cached = None
1182
+ self._sin_cached = None
1183
+
1184
+ def _update_cos_sin_tables(self, x, seq_dimension=1):
1185
+ seq_len = x.shape[seq_dimension]
1186
+
1187
+ # Reset the tables if the sequence length has changed,
1188
+ # or if we're on a new device (possibly due to tracing for instance)
1189
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
1190
+ self._seq_len_cached = seq_len
1191
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
1192
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1193
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1194
+
1195
+ self._cos_cached = emb.cos()[None, :, :]
1196
+ self._sin_cached = emb.sin()[None, :, :]
1197
+
1198
+ return self._cos_cached, self._sin_cached
1199
+
1200
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1201
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
1202
+
1203
+ return (
1204
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1205
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1206
+ )
1207
+
1208
+
1209
+
1210
+
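Finally, a shape sketch for the rotary embedding; q and k here are laid out as (batch*heads, seq, head_dim), matching how the attention module calls it (sizes illustrative, module name assumed):

    import torch
    from modeling_gplm import RotaryEmbedding  # assumption: module name

    rot = RotaryEmbedding(dim=16)
    q = torch.randn(8, 12, 16)     # (batch*heads) x seq x head_dim
    k = torch.randn(8, 12, 16)
    q_rot, k_rot = rot(q, k)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape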