a96123155 committed
Commit 58bf737 · 1 Parent(s): 31c9faa
Files changed (50)
  1. .DS_Store +0 -0
  2. esm/.DS_Store +0 -0
  3. esm/__init__.py +17 -0
  4. esm/__pycache__/__init__.cpython-36.pyc +0 -0
  5. esm/__pycache__/__init__.cpython-39.pyc +0 -0
  6. esm/__pycache__/axial_attention.cpython-36.pyc +0 -0
  7. esm/__pycache__/axial_attention.cpython-39.pyc +0 -0
  8. esm/__pycache__/constants.cpython-36.pyc +0 -0
  9. esm/__pycache__/constants.cpython-39.pyc +0 -0
  10. esm/__pycache__/data.cpython-36.pyc +0 -0
  11. esm/__pycache__/data.cpython-39.pyc +0 -0
  12. esm/__pycache__/data_protein.cpython-36.pyc +0 -0
  13. esm/__pycache__/model.cpython-36.pyc +0 -0
  14. esm/__pycache__/model.cpython-39.pyc +0 -0
  15. esm/__pycache__/modules.cpython-36.pyc +0 -0
  16. esm/__pycache__/modules.cpython-39.pyc +0 -0
  17. esm/__pycache__/multihead_attention.cpython-36.pyc +0 -0
  18. esm/__pycache__/multihead_attention.cpython-39.pyc +0 -0
  19. esm/__pycache__/pretrained.cpython-36.pyc +0 -0
  20. esm/__pycache__/pretrained.cpython-39.pyc +0 -0
  21. esm/__pycache__/rotary_embedding.cpython-36.pyc +0 -0
  22. esm/__pycache__/rotary_embedding.cpython-39.pyc +0 -0
  23. esm/__pycache__/version.cpython-36.pyc +0 -0
  24. esm/__pycache__/version.cpython-39.pyc +0 -0
  25. esm/axial_attention.py +239 -0
  26. esm/constants.py +14 -0
  27. esm/data.py +524 -0
  28. esm/data_supervised.py +524 -0
  29. esm/model/__pycache__/esm1.cpython-36.pyc +0 -0
  30. esm/model/__pycache__/esm1.cpython-39.pyc +0 -0
  31. esm/model/__pycache__/esm2.cpython-36.pyc +0 -0
  32. esm/model/__pycache__/esm2.cpython-39.pyc +0 -0
  33. esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc +0 -0
  34. esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc +0 -0
  35. esm/model/__pycache__/esm2_supervised.cpython-39.pyc +0 -0
  36. esm/model/__pycache__/msa_transformer.cpython-36.pyc +0 -0
  37. esm/model/__pycache__/msa_transformer.cpython-39.pyc +0 -0
  38. esm/model/esm1.py +203 -0
  39. esm/model/esm2.py +163 -0
  40. esm/model/esm2_only_secondarystructure.py +179 -0
  41. esm/model/esm2_secondarystructure.py +179 -0
  42. esm/model/esm2_supervised.py +174 -0
  43. esm/model/msa_transformer.py +238 -0
  44. esm/modules.py +419 -0
  45. esm/multihead_attention.py +506 -0
  46. esm/pretrained.py +378 -0
  47. esm/rotary_embedding.py +69 -0
  48. esm/version.py +6 -0
  49. model.pt +3 -0
  50. requirements.txt +4 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
esm/.DS_Store ADDED
Binary file (6.15 kB). View file
 
esm/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .version import version as __version__ # noqa
6
+
7
+ from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
8
+ from .model.esm1 import ProteinBertModel # noqa
9
+ from .model.esm2 import ESM2 # noqa
10
+ from .model.msa_transformer import MSATransformer # noqa
11
+ from . import pretrained # noqa
12
+
13
+ # from .version import version as __version__ # noqa
14
+
15
+ # from .data import Alphabet, BatchConverter, FastaBatchedDataset # noqa
16
+ # from .model import ProteinBertModel, MSATransformer, ESM2 # noqa
17
+ # from . import pretrained # noqa
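A minimal sketch (not part of the commit) of what this __init__.py exposes at the package top level, assuming the repository is importable as `esm`; all names come directly from the imports above.

    import esm

    alphabet = esm.Alphabet.from_architecture("ESM-1b")   # Alphabet re-exported from esm.data
    print(len(alphabet), alphabet.padding_idx, alphabet.mask_idx, alphabet.eos_idx)
    print(alphabet.encode("MKTAYIAK"))                     # per-residue token ids
    batch_converter = alphabet.get_batch_converter()       # returns a BatchConverter instance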
esm/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (480 Bytes). View file
 
esm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (474 Bytes). View file
 
esm/__pycache__/axial_attention.cpython-36.pyc ADDED
Binary file (5.43 kB). View file
 
esm/__pycache__/axial_attention.cpython-39.pyc ADDED
Binary file (5.41 kB). View file
 
esm/__pycache__/constants.cpython-36.pyc ADDED
Binary file (355 Bytes). View file
 
esm/__pycache__/constants.cpython-39.pyc ADDED
Binary file (323 Bytes). View file
 
esm/__pycache__/data.cpython-36.pyc ADDED
Binary file (16.4 kB). View file
 
esm/__pycache__/data.cpython-39.pyc ADDED
Binary file (16.2 kB). View file
 
esm/__pycache__/data_protein.cpython-36.pyc ADDED
Binary file (16.4 kB). View file
 
esm/__pycache__/model.cpython-36.pyc ADDED
Binary file (12.4 kB). View file
 
esm/__pycache__/model.cpython-39.pyc ADDED
Binary file (9.63 kB). View file
 
esm/__pycache__/modules.cpython-36.pyc ADDED
Binary file (12.7 kB). View file
 
esm/__pycache__/modules.cpython-39.pyc ADDED
Binary file (12.9 kB). View file
 
esm/__pycache__/multihead_attention.cpython-36.pyc ADDED
Binary file (11.8 kB). View file
 
esm/__pycache__/multihead_attention.cpython-39.pyc ADDED
Binary file (11.9 kB). View file
 
esm/__pycache__/pretrained.cpython-36.pyc ADDED
Binary file (14.3 kB). View file
 
esm/__pycache__/pretrained.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
esm/__pycache__/rotary_embedding.cpython-36.pyc ADDED
Binary file (2.73 kB). View file
 
esm/__pycache__/rotary_embedding.cpython-39.pyc ADDED
Binary file (2.71 kB). View file
 
esm/__pycache__/version.cpython-36.pyc ADDED
Binary file (176 Bytes). View file
 
esm/__pycache__/version.cpython-39.pyc ADDED
Binary file (170 Bytes). View file
 
esm/axial_attention.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class RowSelfAttention(nn.Module):
12
+ """Compute self-attention over rows of a 2D input."""
13
+
14
+ def __init__(
15
+ self,
16
+ embed_dim,
17
+ num_heads,
18
+ dropout=0.0,
19
+ max_tokens_per_msa: int = 2 ** 16,
20
+ ):
21
+ super().__init__()
22
+ self.num_heads = num_heads
23
+ self.dropout = dropout
24
+ self.head_dim = embed_dim // num_heads
25
+ self.scaling = self.head_dim ** -0.5
26
+ self.max_tokens_per_msa = max_tokens_per_msa
27
+ self.attn_shape = "hnij"
28
+
29
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
30
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
31
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
32
+
33
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
34
+ self.dropout_module = nn.Dropout(dropout)
35
+
36
+ def align_scaling(self, q):
37
+ num_rows = q.size(0)
38
+ return self.scaling / math.sqrt(num_rows)
39
+
40
+ def _batched_forward(
41
+ self,
42
+ x,
43
+ self_attn_mask=None,
44
+ self_attn_padding_mask=None,
45
+ ):
46
+ num_rows, num_cols, batch_size, embed_dim = x.size()
47
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
48
+ attns = 0
49
+ scaling = self.align_scaling(x)
50
+ for start in range(0, num_rows, max_rows):
51
+ attn_weights = self.compute_attention_weights(
52
+ x[start : start + max_rows],
53
+ scaling,
54
+ self_attn_mask=self_attn_mask,
55
+ self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
56
+ if self_attn_padding_mask is not None
57
+ else None,
58
+ )
59
+ attns += attn_weights
60
+ attn_probs = attns.softmax(-1)
61
+ attn_probs = self.dropout_module(attn_probs)
62
+
63
+ outputs = []
64
+ for start in range(0, num_rows, max_rows):
65
+ output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
66
+ outputs.append(output)
67
+
68
+ output = torch.cat(outputs, 0)
69
+ return output, attn_probs
70
+
71
+ def compute_attention_weights(
72
+ self,
73
+ x,
74
+ scaling: float,
75
+ self_attn_mask=None,
76
+ self_attn_padding_mask=None,
77
+ ):
78
+ num_rows, num_cols, batch_size, embed_dim = x.size()
79
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
80
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
81
+ q *= scaling
82
+ if self_attn_padding_mask is not None:
83
+ # Zero out any padded aligned positions - this is important since
84
+ # we take a sum across the alignment axis.
85
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
86
+
87
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
88
+
89
+ if self_attn_mask is not None:
90
+ raise NotImplementedError
91
+ # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
92
+
93
+ if self_attn_padding_mask is not None:
94
+ attn_weights = attn_weights.masked_fill(
95
+ self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
96
+ -10000,
97
+ )
98
+
99
+ return attn_weights
100
+
101
+ def compute_attention_update(
102
+ self,
103
+ x,
104
+ attn_probs,
105
+ ):
106
+ num_rows, num_cols, batch_size, embed_dim = x.size()
107
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
108
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
109
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
110
+ output = self.out_proj(context)
111
+ return output
112
+
113
+ def forward(
114
+ self,
115
+ x,
116
+ self_attn_mask=None,
117
+ self_attn_padding_mask=None,
118
+ ):
119
+ num_rows, num_cols, batch_size, embed_dim = x.size()
120
+ if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
121
+ return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
122
+ else:
123
+ scaling = self.align_scaling(x)
124
+ attn_weights = self.compute_attention_weights(
125
+ x, scaling, self_attn_mask, self_attn_padding_mask
126
+ )
127
+ attn_probs = attn_weights.softmax(-1)
128
+ attn_probs = self.dropout_module(attn_probs)
129
+ output = self.compute_attention_update(x, attn_probs)
130
+ return output, attn_probs
131
+
132
+
133
+ class ColumnSelfAttention(nn.Module):
134
+ """Compute self-attention over columns of a 2D input."""
135
+
136
+ def __init__(
137
+ self,
138
+ embed_dim,
139
+ num_heads,
140
+ dropout=0.0,
141
+ max_tokens_per_msa: int = 2 ** 16,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.num_heads = num_heads
146
+ self.dropout = dropout
147
+ self.head_dim = embed_dim // num_heads
148
+ self.scaling = self.head_dim ** -0.5
149
+ self.max_tokens_per_msa = max_tokens_per_msa
150
+
151
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
152
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
153
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
154
+
155
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
156
+ self.dropout_module = nn.Dropout(dropout)
157
+
158
+ def _batched_forward(
159
+ self,
160
+ x,
161
+ self_attn_mask=None,
162
+ self_attn_padding_mask=None,
163
+ ):
164
+ num_rows, num_cols, batch_size, embed_dim = x.size()
165
+ max_cols = max(1, self.max_tokens_per_msa // num_rows)
166
+ outputs = []
167
+ attns = []
168
+ for start in range(0, num_cols, max_cols):
169
+ output, attn = self(
170
+ x[:, start : start + max_cols],
171
+ self_attn_mask=self_attn_mask,
172
+ self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
173
+ if self_attn_padding_mask is not None
174
+ else None,
175
+ )
176
+ outputs.append(output)
177
+ attns.append(attn)
178
+ output = torch.cat(outputs, 1)
179
+ attns = torch.cat(attns, 1)
180
+ return output, attns
181
+
182
+ def compute_attention_update(
183
+ self,
184
+ x,
185
+ self_attn_mask=None,
186
+ self_attn_padding_mask=None,
187
+ ):
188
+ num_rows, num_cols, batch_size, embed_dim = x.size()
189
+ if num_rows == 1:
190
+ # if there is only 1 position, this is equivalent and doesn't break with padding
191
+ attn_probs = torch.ones(
192
+ self.num_heads,
193
+ num_cols,
194
+ batch_size,
195
+ num_rows,
196
+ num_rows,
197
+ device=x.device,
198
+ dtype=x.dtype,
199
+ )
200
+ output = self.out_proj(self.v_proj(x))
201
+ else:
202
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
203
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
204
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
205
+ q *= self.scaling
206
+
207
+ attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
208
+
209
+ if self_attn_mask is not None:
210
+ raise NotImplementedError
211
+ if self_attn_padding_mask is not None:
212
+ attn_weights = attn_weights.masked_fill(
213
+ self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
214
+ -10000,
215
+ )
216
+
217
+ attn_probs = attn_weights.softmax(-1)
218
+ attn_probs = self.dropout_module(attn_probs)
219
+ context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
220
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
221
+ output = self.out_proj(context)
222
+ return output, attn_probs
223
+
224
+ def forward(
225
+ self,
226
+ x,
227
+ self_attn_mask=None,
228
+ self_attn_padding_mask=None,
229
+ ):
230
+ num_rows, num_cols, batch_size, embed_dim = x.size()
231
+ # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
232
+ if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
233
+ return self._batched_forward(
234
+ x,
235
+ self_attn_mask,
236
+ self_attn_padding_mask,
237
+ )
238
+ else:
239
+ return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
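A hypothetical smoke test (not part of the commit) for the two modules above; the input follows the (num_rows, num_cols, batch_size, embed_dim) layout that both forward methods expect, and the attention shapes in the comments are read off the einsum patterns.

    import torch
    from esm.axial_attention import RowSelfAttention, ColumnSelfAttention

    x = torch.randn(8, 32, 1, 64)                  # 8 alignment rows, 32 columns, batch 1, embed_dim 64
    row_attn = RowSelfAttention(embed_dim=64, num_heads=4)
    col_attn = ColumnSelfAttention(embed_dim=64, num_heads=4)

    row_out, row_probs = row_attn(x)               # row_probs: (heads, batch, cols, cols)
    col_out, col_probs = col_attn(x)               # col_probs: (heads, cols, batch, rows, rows)
    assert row_out.shape == col_out.shape == x.shape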
esm/constants.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # fmt: off
7
+ proteinseq_toks = {
8
+ 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
9
+ }
10
+
11
+ rnaseq_toks = {
12
+ 'toks': ['A', 'G', 'T', 'C']
13
+ }
14
+ # fmt: on
esm/data.py ADDED
@@ -0,0 +1,524 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import os
8
+ from typing import Sequence, Tuple, List, Union
9
+ import pickle
10
+ import re
11
+ import shutil
12
+ import torch
13
+ from pathlib import Path
14
+ from .constants import proteinseq_toks, rnaseq_toks
15
+ import math
16
+ import random
17
+ from copy import deepcopy
18
+
19
+ RawMSA = Sequence[Tuple[str, str]]
20
+
21
+
22
+ class Alphabet(object):
23
+ def __init__(
24
+ self,
25
+ standard_toks: Sequence[str],
26
+ prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"), # "<null_0>",
27
+ append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"), #
28
+ prepend_bos: bool = True,
29
+ append_eos: bool = True,
30
+ use_msa: bool = False,
31
+ mask_prob: float = 0.15, ###---
32
+ ):
33
+ self.mask_prob = mask_prob ###---
34
+ self.standard_toks = list(standard_toks)
35
+ self.prepend_toks = list(prepend_toks)
36
+ self.append_toks = list(append_toks)
37
+ self.prepend_bos = prepend_bos
38
+ self.append_eos = append_eos
39
+ self.use_msa = use_msa
40
+
41
+ self.all_toks = list(self.prepend_toks)
42
+ self.all_toks.extend(self.standard_toks)
43
+ # for i in range((8 - (len(self.all_toks) % 8)) % 8):
44
+ # self.all_toks.append(f"<null_{i + 1}>")
45
+ self.all_toks.extend(self.append_toks)
46
+
47
+ self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
48
+ # print(self.tok_to_idx)
49
+ self.unk_idx = self.tok_to_idx["<unk>"]
50
+ self.padding_idx = self.get_idx("<pad>")
51
+ self.cls_idx = self.get_idx("<cls>")
52
+ self.mask_idx = self.get_idx("<mask>")
53
+ self.eos_idx = self.get_idx("<eos>")
54
+ self.all_special_tokens = ['<eos>', '<pad>', '<mask>'] # , '<unk>', '<cls>'
55
+ self.unique_no_split_tokens = self.all_toks
56
+
57
+ def __len__(self):
58
+ return len(self.all_toks)
59
+
60
+ def get_idx(self, tok):
61
+ return self.tok_to_idx.get(tok, self.unk_idx)
62
+
63
+ def get_tok(self, ind):
64
+ return self.all_toks[ind]
65
+
66
+ def to_dict(self):
67
+ return self.tok_to_idx.copy()
68
+
69
+ def get_batch_converter(self):
70
+ if self.use_msa:
71
+ return MSABatchConverter(self)
72
+ else:
73
+ return BatchConverter(self)
74
+
75
+ @classmethod
76
+ def from_architecture(cls, name: str) -> "Alphabet":
77
+ if name in ("ESM-1", "protein_bert_base"):
78
+ standard_toks = proteinseq_toks["toks"]
79
+ prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
80
+ append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
81
+ prepend_bos = True
82
+ append_eos = False
83
+ use_msa = False
84
+ elif name in ("ESM-1b", "roberta_large"):
85
+ standard_toks = proteinseq_toks["toks"] ###---rnaseq
86
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
87
+ append_toks = ("<mask>",)
88
+ prepend_bos = True
89
+ append_eos = True
90
+ use_msa = False
91
+ elif name in ("MSA Transformer", "msa_transformer"):
92
+ standard_toks = proteinseq_toks["toks"]
93
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
94
+ append_toks = ("<mask>",)
95
+ prepend_bos = True
96
+ append_eos = False
97
+ use_msa = True
98
+ else:
99
+ raise ValueError("Unknown architecture selected")
100
+ return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
101
+
102
+ def _tokenize(self, text) -> List[str]:
103
+ return text.split()
104
+
105
+ def tokenize(self, text, **kwargs) -> List[str]:
106
+ """
107
+ Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
108
+ Converts a string into a sequence of tokens, using the tokenizer.
109
+
110
+ Args:
111
+ text (:obj:`str`):
112
+ The sequence to be encoded.
113
+
114
+ Returns:
115
+ :obj:`List[str]`: The list of tokens.
116
+ """
117
+
118
+ def split_on_token(tok, text):
119
+ result = []
120
+ split_text = text.split(tok)
121
+ for i, sub_text in enumerate(split_text):
122
+ # AddedToken can control whitespace stripping around them.
123
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
124
+ # Cf. https://github.com/huggingface/transformers/pull/2778
125
+ # and https://github.com/huggingface/transformers/issues/3788
126
+ # We strip left and right by default
127
+ if i < len(split_text) - 1:
128
+ sub_text = sub_text.rstrip()
129
+ if i > 0:
130
+ sub_text = sub_text.lstrip()
131
+
132
+ if i == 0 and not sub_text:
133
+ result.append(tok)
134
+ elif i == len(split_text) - 1:
135
+ if sub_text:
136
+ result.append(sub_text)
137
+ else:
138
+ pass
139
+ else:
140
+ if sub_text:
141
+ result.append(sub_text)
142
+ result.append(tok)
143
+ return result
144
+
145
+ def split_on_tokens(tok_list, text):
146
+ if not text.strip():
147
+ return []
148
+
149
+ tokenized_text = []
150
+ text_list = [text]
151
+ for tok in tok_list:
152
+ tokenized_text = []
153
+ for sub_text in text_list:
154
+ if sub_text not in self.unique_no_split_tokens:
155
+ tokenized_text.extend(split_on_token(tok, sub_text))
156
+ else:
157
+ tokenized_text.append(sub_text)
158
+ text_list = tokenized_text
159
+
160
+ return list(
161
+ itertools.chain.from_iterable(
162
+ (
163
+ self._tokenize(token)
164
+ if token not in self.unique_no_split_tokens
165
+ else [token]
166
+ for token in tokenized_text
167
+ )
168
+ )
169
+ )
170
+
171
+ no_split_token = self.unique_no_split_tokens
172
+ tokenized_text = split_on_tokens(no_split_token, text)
173
+ return tokenized_text
174
+
175
+ def encode(self, text):
176
+ return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
177
+
178
+ class FastaBatchedDataset(object):
179
+ def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15):
180
+ self.sequence_labels = list(sequence_labels)
181
+ self.sequence_strs = list(sequence_strs)
182
+ self.mask_prob = mask_prob
183
+
184
+ @classmethod
185
+ def from_file(cls, fasta_file, mask_prob = 0.15):
186
+ sequence_labels, sequence_strs = [], []
187
+ cur_seq_label = None
188
+ buf = []
189
+
190
+ def _flush_current_seq():
191
+ nonlocal cur_seq_label, buf
192
+ if cur_seq_label is None:
193
+ return
194
+ sequence_labels.append(cur_seq_label)
195
+ sequence_strs.append("".join(buf))
196
+ cur_seq_label = None
197
+ buf = []
198
+
199
+ with open(fasta_file, "r") as infile:
200
+ for line_idx, line in enumerate(infile):
201
+ if line.startswith(">"): # label line
202
+ _flush_current_seq()
203
+ line = line[1:].strip()
204
+ if len(line) > 0:
205
+ cur_seq_label = line
206
+ else:
207
+ cur_seq_label = f"seqnum{line_idx:09d}"
208
+ else: # sequence line
209
+ buf.append(line.strip())
210
+
211
+ _flush_current_seq()
212
+
213
+ assert len(set(sequence_labels)) == len(
214
+ sequence_labels
215
+ ), "Found duplicate sequence labels"
216
+
217
+ return cls(sequence_labels, sequence_strs, mask_prob)
218
+
219
+ def __len__(self):
220
+ return len(self.sequence_labels)
221
+
222
+ def mask_sequence(self, seq): ###---
223
+ length = len(seq)
224
+ # print(self.mask_prob)
225
+ max_length = math.ceil(length * self.mask_prob)
226
+ rand = random.sample(range(0, length), max_length)
227
+ res = ''.join(['<mask>' if idx in rand else ele for idx, ele in enumerate(seq)])
228
+ #print(seq, rand, res)
229
+ return rand, res
230
+
231
+ def __getitem__(self, idx):
232
+ sequence_str = self.sequence_strs[idx]
233
+ sequence_label = self.sequence_labels[idx]
234
+ masked_indices, masked_sequence_str = self.mask_sequence(sequence_str)
235
+ return sequence_label, sequence_str, masked_sequence_str, masked_indices
236
+
237
+ def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
238
+ sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
239
+ sizes.sort()
240
+ batches = []
241
+ buf = []
242
+ max_len = 0
243
+
244
+ def _flush_current_buf():
245
+ nonlocal max_len, buf
246
+ if len(buf) == 0:
247
+ return
248
+ batches.append(buf)
249
+ buf = []
250
+ max_len = 0
251
+
252
+ for sz, i in sizes:
253
+ sz += extra_toks_per_seq
254
+ if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
255
+ _flush_current_buf()
256
+ max_len = max(max_len, sz)
257
+ buf.append(i)
258
+
259
+ _flush_current_buf()
260
+ return batches
261
+
262
+ class BatchConverter(object):
263
+ """Callable to convert an unprocessed (labels + strings) batch to a
264
+ processed (labels + tensor) batch.
265
+ """
266
+
267
+ def __init__(self, alphabet):
268
+ self.alphabet = alphabet
269
+
270
+ def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
271
+ # RoBERTa uses an eos token, while ESM-1 does not.
272
+ batch_size = len(raw_batch)
273
+ batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch)
274
+
275
+ masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list] ###---
276
+ seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] ###---
277
+ # print('====', seq_str_list)
278
+ # print('----', masked_seq_str_list)
279
+ # print('++++', masked_seq_encoded_list)
280
+ # print('****', seq_encoded_list)
281
+
282
+ max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list)
283
+ tokens = torch.empty(
284
+ (
285
+ batch_size,
286
+ max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
287
+ ),
288
+ dtype=torch.int64,
289
+ )
290
+ tokens.fill_(self.alphabet.padding_idx)
291
+ masked_tokens = deepcopy(tokens)
292
+
293
+ labels = []
294
+ strs, masked_strs = [], []
295
+ masked_indices = []
296
+ # print('=================')
297
+ for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate(
298
+ zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list) ###---
299
+ ):
300
+ labels.append(label)
301
+ strs.append(seq_str)
302
+ masked_strs.append(masked_seq_str)
303
+ masked_indices.append(indices_mask)
304
+
305
+ if self.alphabet.prepend_bos:
306
+ tokens[i, 0] = self.alphabet.cls_idx
307
+ masked_tokens[i, 0] = self.alphabet.cls_idx
308
+
309
+ seq = torch.tensor(seq_encoded, dtype=torch.int64)
310
+ masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64)
311
+ # print(tokens, masked_tokens)
312
+ tokens[
313
+ i,
314
+ int(self.alphabet.prepend_bos) : len(seq_encoded)
315
+ + int(self.alphabet.prepend_bos),
316
+ ] = seq
317
+
318
+ masked_tokens[
319
+ i,
320
+ int(self.alphabet.prepend_bos) : len(masked_seq_encoded)
321
+ + int(self.alphabet.prepend_bos),
322
+ ] = masked_seq
323
+ # print(tokens, masked_tokens)
324
+ if self.alphabet.append_eos:
325
+ tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
326
+ masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
327
+ # print(tokens, masked_tokens)
328
+ return labels, strs, masked_strs, tokens, masked_tokens, masked_indices
329
+
330
+
331
+ class MSABatchConverter(BatchConverter):
332
+ def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
333
+ if isinstance(inputs[0][0], str):
334
+ # Input is a single MSA
335
+ raw_batch: Sequence[RawMSA] = [inputs] # type: ignore
336
+ else:
337
+ raw_batch = inputs # type: ignore
338
+
339
+ batch_size = len(raw_batch)
340
+ max_alignments = max(len(msa) for msa in raw_batch)
341
+ max_seqlen = max(len(msa[0][1]) for msa in raw_batch)
342
+
343
+ tokens = torch.empty(
344
+ (
345
+ batch_size,
346
+ max_alignments,
347
+ max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
348
+ ),
349
+ dtype=torch.int64,
350
+ )
351
+ tokens.fill_(self.alphabet.padding_idx)
352
+ labels = []
353
+ strs = []
354
+
355
+ for i, msa in enumerate(raw_batch):
356
+ msa_seqlens = set(len(seq) for _, seq in msa)
357
+ if not len(msa_seqlens) == 1:
358
+ raise RuntimeError(
359
+ "Received unaligned sequences for input to MSA, all sequence "
360
+ "lengths must be equal."
361
+ )
362
+ msa_labels, msa_strs, msa_tokens = super().__call__(msa)
363
+ labels.append(msa_labels)
364
+ strs.append(msa_strs)
365
+ tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
366
+
367
+ return labels, strs, tokens
368
+
369
+
370
+ def read_fasta(
371
+ path,
372
+ keep_gaps=True,
373
+ keep_insertions=True,
374
+ to_upper=False,
375
+ ):
376
+ with open(path, "r") as f:
377
+ for result in read_alignment_lines(
378
+ f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
379
+ ):
380
+ yield result
381
+
382
+
383
+ def read_alignment_lines(
384
+ lines,
385
+ keep_gaps=True,
386
+ keep_insertions=True,
387
+ to_upper=False,
388
+ ):
389
+ seq = desc = None
390
+
391
+ def parse(s):
392
+ if not keep_gaps:
393
+ s = re.sub("-", "", s)
394
+ if not keep_insertions:
395
+ s = re.sub("[a-z]", "", s)
396
+ return s.upper() if to_upper else s
397
+
398
+ for line in lines:
399
+ # Line may be empty if seq % file_line_width == 0
400
+ if len(line) > 0 and line[0] == ">":
401
+ if seq is not None:
402
+ yield desc, parse(seq)
403
+ desc = line.strip()
404
+ seq = ""
405
+ else:
406
+ assert isinstance(seq, str)
407
+ seq += line.strip()
408
+ assert isinstance(seq, str) and isinstance(desc, str)
409
+ yield desc, parse(seq)
410
+
411
+
412
+ class ESMStructuralSplitDataset(torch.utils.data.Dataset):
413
+ """
414
+ Structural Split Dataset as described in section A.10 of the supplement of our paper.
415
+ https://doi.org/10.1101/622803
416
+
417
+ We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
418
+ generated on January 23, 2020.
419
+
420
+ For each SCOPe domain:
421
+ - We extract the sequence from the corresponding PDB file
422
+ - We extract the 3D coordinates of the Carbon beta atoms, aligning them
423
+ to the sequence. We put NaN where Cb atoms are missing.
424
+ - From the 3D coordinates, we calculate a pairwise distance map, based
425
+ on L2 distance
426
+ - We use DSSP to generate secondary structure labels for the corresponding
427
+ PDB file. This is also aligned to the sequence. We put - where SSP
428
+ labels are missing.
429
+
430
+ For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
431
+ we have split the data into 5 partitions for cross validation. These are provided
432
+ in a downloaded splits folder, in the format:
433
+ splits/{split_level}/{cv_partition}/{train|valid}.txt
434
+ where train is the partition and valid is the concatenation of the remaining 4.
435
+
436
+ For each SCOPe domain, we provide a pkl dump that contains:
437
+ - seq : The domain sequence, stored as an L-length string
438
+ - ssp : The secondary structure labels, stored as an L-length string
439
+ - dist : The distance map, stored as an LxL numpy array
440
+ - coords : The 3D coordinates, stored as an Lx3 numpy array
441
+
442
+ """
443
+
444
+ base_folder = "structural-data"
445
+ file_list = [
446
+ # url tar filename filename MD5 Hash
447
+ (
448
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
449
+ "splits.tar.gz",
450
+ "splits",
451
+ "456fe1c7f22c9d3d8dfe9735da52411d",
452
+ ),
453
+ (
454
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
455
+ "pkl.tar.gz",
456
+ "pkl",
457
+ "644ea91e56066c750cd50101d390f5db",
458
+ ),
459
+ ]
460
+
461
+ def __init__(
462
+ self,
463
+ split_level,
464
+ cv_partition,
465
+ split,
466
+ root_path=os.path.expanduser("~/.cache/torch/data/esm"),
467
+ download=False,
468
+ ):
469
+ super().__init__()
470
+ assert split in [
471
+ "train",
472
+ "valid",
473
+ ], "train_valid must be 'train' or 'valid'"
474
+ self.root_path = root_path
475
+ self.base_path = os.path.join(self.root_path, self.base_folder)
476
+
477
+ # check if root path has what you need or else download it
478
+ if download:
479
+ self.download()
480
+
481
+ self.split_file = os.path.join(
482
+ self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
483
+ )
484
+ self.pkl_dir = os.path.join(self.base_path, "pkl")
485
+ self.names = []
486
+ with open(self.split_file) as f:
487
+ self.names = f.read().splitlines()
488
+
489
+ def __len__(self):
490
+ return len(self.names)
491
+
492
+ def _check_exists(self) -> bool:
493
+ for (_, _, filename, _) in self.file_list:
494
+ fpath = os.path.join(self.base_path, filename)
495
+ if not os.path.exists(fpath) or not os.path.isdir(fpath):
496
+ return False
497
+ return True
498
+
499
+ def download(self):
500
+
501
+ if self._check_exists():
502
+ print("Files already downloaded and verified")
503
+ return
504
+
505
+ from torchvision.datasets.utils import download_url
506
+
507
+ for url, tar_filename, filename, md5_hash in self.file_list:
508
+ download_path = os.path.join(self.base_path, tar_filename)
509
+ download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
510
+ shutil.unpack_archive(download_path, self.base_path)
511
+
512
+ def __getitem__(self, idx):
513
+ """
514
+ Returns a dict with the following entries
515
+ - seq : Str (domain sequence)
516
+ - ssp : Str (SSP labels)
517
+ - dist : np.array (distance map)
518
+ - coords : np.array (3D coordinates)
519
+ """
520
+ name = self.names[idx]
521
+ pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
522
+ with open(pkl_fname, "rb") as f:
523
+ obj = pickle.load(f)
524
+ return obj
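A hypothetical end-to-end sketch (not part of the commit) of the masking pipeline added in this file: FastaBatchedDataset.__getitem__ now also returns a masked copy of each sequence, and BatchConverter returns masked strings, masked tokens, and the masked index lists alongside the usual outputs. The sequences below are made up.

    import esm

    alphabet = esm.Alphabet.from_architecture("ESM-1b")
    batch_converter = alphabet.get_batch_converter()

    dataset = esm.FastaBatchedDataset(["seq_a", "seq_b"], ["MKTAYIAKQR", "GAVLIMCFYW"], mask_prob=0.15)
    raw_batch = [dataset[i] for i in range(len(dataset))]   # (label, seq, masked_seq, masked_indices)

    labels, strs, masked_strs, tokens, masked_tokens, masked_indices = batch_converter(raw_batch)
    print(tokens.shape)            # (batch, max_len + bos + eos)
    print(masked_tokens[0])        # same sequence with <mask> ids at the masked positions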
esm/data_supervised.py ADDED
@@ -0,0 +1,524 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import os
8
+ from typing import Sequence, Tuple, List, Union
9
+ import pickle
10
+ import re
11
+ import shutil
12
+ import torch
13
+ from pathlib import Path
14
+ from .constants import proteinseq_toks, rnaseq_toks
15
+ import math
16
+ import random
17
+ from copy import deepcopy
18
+
19
+ RawMSA = Sequence[Tuple[str, str]]
20
+
21
+
22
+ class Alphabet(object):
23
+ def __init__(
24
+ self,
25
+ standard_toks: Sequence[str],
26
+ prepend_toks: Sequence[str] = ("<pad>", "<eos>", "<unk>"), # "<null_0>",
27
+ append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"), #
28
+ prepend_bos: bool = True,
29
+ append_eos: bool = True,
30
+ use_msa: bool = False,
31
+ mask_prob: float = 0.15, ###---
32
+ ):
33
+ self.mask_prob = mask_prob ###---
34
+ self.standard_toks = list(standard_toks)
35
+ self.prepend_toks = list(prepend_toks)
36
+ self.append_toks = list(append_toks)
37
+ self.prepend_bos = prepend_bos
38
+ self.append_eos = append_eos
39
+ self.use_msa = use_msa
40
+
41
+ self.all_toks = list(self.prepend_toks)
42
+ self.all_toks.extend(self.standard_toks)
43
+ # for i in range((8 - (len(self.all_toks) % 8)) % 8):
44
+ # self.all_toks.append(f"<null_{i + 1}>")
45
+ self.all_toks.extend(self.append_toks)
46
+
47
+ self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
48
+ # print(self.tok_to_idx)
49
+ self.unk_idx = self.tok_to_idx["<unk>"]
50
+ self.padding_idx = self.get_idx("<pad>")
51
+ self.cls_idx = self.get_idx("<cls>")
52
+ self.mask_idx = self.get_idx("<mask>")
53
+ self.eos_idx = self.get_idx("<eos>")
54
+ self.all_special_tokens = ['<eos>', '<pad>', '<mask>'] # , '<unk>', '<cls>'
55
+ self.unique_no_split_tokens = self.all_toks
56
+
57
+ def __len__(self):
58
+ return len(self.all_toks)
59
+
60
+ def get_idx(self, tok):
61
+ return self.tok_to_idx.get(tok, self.unk_idx)
62
+
63
+ def get_tok(self, ind):
64
+ return self.all_toks[ind]
65
+
66
+ def to_dict(self):
67
+ return self.tok_to_idx.copy()
68
+
69
+ def get_batch_converter(self):
70
+ if self.use_msa:
71
+ return MSABatchConverter(self)
72
+ else:
73
+ return BatchConverter(self)
74
+
75
+ @classmethod
76
+ def from_architecture(cls, name: str) -> "Alphabet":
77
+ if name in ("ESM-1", "protein_bert_base"):
78
+ standard_toks = proteinseq_toks["toks"]
79
+ prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
80
+ append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
81
+ prepend_bos = True
82
+ append_eos = False
83
+ use_msa = False
84
+ elif name in ("ESM-1b", "roberta_large"):
85
+ standard_toks = proteinseq_toks["toks"] ###---rnaseq
86
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
87
+ append_toks = ("<mask>",)
88
+ prepend_bos = True
89
+ append_eos = True
90
+ use_msa = False
91
+ elif name in ("MSA Transformer", "msa_transformer"):
92
+ standard_toks = proteinseq_toks["toks"]
93
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
94
+ append_toks = ("<mask>",)
95
+ prepend_bos = True
96
+ append_eos = False
97
+ use_msa = True
98
+ else:
99
+ raise ValueError("Unknown architecture selected")
100
+ return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
101
+
102
+ def _tokenize(self, text) -> List[str]:
103
+ return text.split()
104
+
105
+ def tokenize(self, text, **kwargs) -> List[str]:
106
+ """
107
+ Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
108
+ Converts a string in a sequence of tokens, using the tokenizer.
109
+
110
+ Args:
111
+ text (:obj:`str`):
112
+ The sequence to be encoded.
113
+
114
+ Returns:
115
+ :obj:`List[str]`: The list of tokens.
116
+ """
117
+
118
+ def split_on_token(tok, text):
119
+ result = []
120
+ split_text = text.split(tok)
121
+ for i, sub_text in enumerate(split_text):
122
+ # AddedToken can control whitespace stripping around them.
123
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
124
+ # Cf. https://github.com/huggingface/transformers/pull/2778
125
+ # and https://github.com/huggingface/transformers/issues/3788
126
+ # We strip left and right by default
127
+ if i < len(split_text) - 1:
128
+ sub_text = sub_text.rstrip()
129
+ if i > 0:
130
+ sub_text = sub_text.lstrip()
131
+
132
+ if i == 0 and not sub_text:
133
+ result.append(tok)
134
+ elif i == len(split_text) - 1:
135
+ if sub_text:
136
+ result.append(sub_text)
137
+ else:
138
+ pass
139
+ else:
140
+ if sub_text:
141
+ result.append(sub_text)
142
+ result.append(tok)
143
+ return result
144
+
145
+ def split_on_tokens(tok_list, text):
146
+ if not text.strip():
147
+ return []
148
+
149
+ tokenized_text = []
150
+ text_list = [text]
151
+ for tok in tok_list:
152
+ tokenized_text = []
153
+ for sub_text in text_list:
154
+ if sub_text not in self.unique_no_split_tokens:
155
+ tokenized_text.extend(split_on_token(tok, sub_text))
156
+ else:
157
+ tokenized_text.append(sub_text)
158
+ text_list = tokenized_text
159
+
160
+ return list(
161
+ itertools.chain.from_iterable(
162
+ (
163
+ self._tokenize(token)
164
+ if token not in self.unique_no_split_tokens
165
+ else [token]
166
+ for token in tokenized_text
167
+ )
168
+ )
169
+ )
170
+
171
+ no_split_token = self.unique_no_split_tokens
172
+ tokenized_text = split_on_tokens(no_split_token, text)
173
+ return tokenized_text
174
+
175
+ def encode(self, text):
176
+ return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
177
+
178
+ class FastaBatchedDataset(object):
179
+ def __init__(self, sequence_labels, sequence_strs, mask_prob = 0.15):
180
+ self.sequence_labels = list(sequence_labels)
181
+ self.sequence_strs = list(sequence_strs)
182
+ self.mask_prob = mask_prob
183
+
184
+ @classmethod
185
+ def from_file(cls, fasta_file, mask_prob = 0.15):
186
+ sequence_labels, sequence_strs = [], []
187
+ cur_seq_label = None
188
+ buf = []
189
+
190
+ def _flush_current_seq():
191
+ nonlocal cur_seq_label, buf
192
+ if cur_seq_label is None:
193
+ return
194
+ sequence_labels.append(cur_seq_label)
195
+ sequence_strs.append("".join(buf))
196
+ cur_seq_label = None
197
+ buf = []
198
+
199
+ with open(fasta_file, "r") as infile:
200
+ for line_idx, line in enumerate(infile):
201
+ if line.startswith(">"): # label line
202
+ _flush_current_seq()
203
+ line = line[1:].strip()
204
+ if len(line) > 0:
205
+ cur_seq_label = line
206
+ else:
207
+ cur_seq_label = f"seqnum{line_idx:09d}"
208
+ else: # sequence line
209
+ buf.append(line.strip())
210
+
211
+ _flush_current_seq()
212
+
213
+ assert len(set(sequence_labels)) == len(
214
+ sequence_labels
215
+ ), "Found duplicate sequence labels"
216
+
217
+ return cls(sequence_labels, sequence_strs, mask_prob)
218
+
219
+ def __len__(self):
220
+ return len(self.sequence_labels)
221
+
222
+ def mask_sequence(self, seq): ###---
223
+ length = len(seq)
224
+ # print(self.mask_prob)
225
+ max_length = math.ceil(length * self.mask_prob)
226
+ rand = random.sample(range(0, length), max_length)
227
+ res = ''.join(['<mask>' if idx in rand else ele for idx, ele in enumerate(seq)])
228
+ #print(seq, rand, res)
229
+ return rand, res
230
+
231
+ def __getitem__(self, idx):
232
+ sequence_str = self.sequence_strs[idx]
233
+ sequence_label = self.sequence_labels[idx]
234
+ masked_indices, masked_sequence_str = self.mask_sequence(sequence_str)
235
+ return sequence_label, sequence_str, masked_sequence_str, masked_indices
236
+
237
+ def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
238
+ sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
239
+ sizes.sort()
240
+ batches = []
241
+ buf = []
242
+ max_len = 0
243
+
244
+ def _flush_current_buf():
245
+ nonlocal max_len, buf
246
+ if len(buf) == 0:
247
+ return
248
+ batches.append(buf)
249
+ buf = []
250
+ max_len = 0
251
+
252
+ for sz, i in sizes:
253
+ sz += extra_toks_per_seq
254
+ if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
255
+ _flush_current_buf()
256
+ max_len = max(max_len, sz)
257
+ buf.append(i)
258
+
259
+ _flush_current_buf()
260
+ return batches
261
+
262
+ class BatchConverter(object):
263
+ """Callable to convert an unprocessed (labels + strings) batch to a
264
+ processed (labels + tensor) batch.
265
+ """
266
+
267
+ def __init__(self, alphabet):
268
+ self.alphabet = alphabet
269
+
270
+ def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
271
+ # RoBERTa uses an eos token, while ESM-1 does not.
272
+ batch_size = len(raw_batch)
273
+ batch_labels, seq_str_list, masked_seq_str_list, masked_indices_list = zip(*raw_batch)
274
+
275
+ masked_seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in masked_seq_str_list] ###---
276
+ seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list] ###---
277
+ # print('====', seq_str_list)
278
+ # print('----', masked_seq_str_list)
279
+ # print('++++', masked_seq_encoded_list)
280
+ # print('****', seq_encoded_list)
281
+
282
+ max_len = max(len(seq_encoded) for seq_encoded in masked_seq_encoded_list)
283
+ tokens = torch.empty(
284
+ (
285
+ batch_size,
286
+ max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
287
+ ),
288
+ dtype=torch.int64,
289
+ )
290
+ tokens.fill_(self.alphabet.padding_idx)
291
+ masked_tokens = deepcopy(tokens)
292
+
293
+ labels = []
294
+ strs, masked_strs = [], []
295
+ masked_indices = []
296
+ # print('=================')
297
+ for i, (label, seq_str, masked_seq_str, seq_encoded, masked_seq_encoded, indices_mask) in enumerate(
298
+ zip(batch_labels, seq_str_list, masked_seq_str_list, seq_encoded_list, masked_seq_encoded_list, masked_indices_list) ###---
299
+ ):
300
+ labels.append(label)
301
+ strs.append(seq_str)
302
+ masked_strs.append(masked_seq_str)
303
+ masked_indices.append(indices_mask)
304
+
305
+ if self.alphabet.prepend_bos:
306
+ tokens[i, 0] = self.alphabet.cls_idx
307
+ masked_tokens[i, 0] = self.alphabet.cls_idx
308
+
309
+ seq = torch.tensor(seq_encoded, dtype=torch.int64)
310
+ masked_seq = torch.tensor(masked_seq_encoded, dtype=torch.int64)
311
+ # print(tokens, masked_tokens)
312
+ tokens[
313
+ i,
314
+ int(self.alphabet.prepend_bos) : len(seq_encoded)
315
+ + int(self.alphabet.prepend_bos),
316
+ ] = seq
317
+
318
+ masked_tokens[
319
+ i,
320
+ int(self.alphabet.prepend_bos) : len(masked_seq_encoded)
321
+ + int(self.alphabet.prepend_bos),
322
+ ] = masked_seq
323
+ # print(tokens, masked_tokens)
324
+ if self.alphabet.append_eos:
325
+ tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
326
+ masked_tokens[i, len(masked_seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
327
+ # print(tokens, masked_tokens)
328
+ return labels, strs, masked_strs, tokens, masked_tokens, masked_indices
329
+
330
+
331
+ class MSABatchConverter(BatchConverter):
332
+ def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
333
+ if isinstance(inputs[0][0], str):
334
+ # Input is a single MSA
335
+ raw_batch: Sequence[RawMSA] = [inputs] # type: ignore
336
+ else:
337
+ raw_batch = inputs # type: ignore
338
+
339
+ batch_size = len(raw_batch)
340
+ max_alignments = max(len(msa) for msa in raw_batch)
341
+ max_seqlen = max(len(msa[0][1]) for msa in raw_batch)
342
+
343
+ tokens = torch.empty(
344
+ (
345
+ batch_size,
346
+ max_alignments,
347
+ max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
348
+ ),
349
+ dtype=torch.int64,
350
+ )
351
+ tokens.fill_(self.alphabet.padding_idx)
352
+ labels = []
353
+ strs = []
354
+
355
+ for i, msa in enumerate(raw_batch):
356
+ msa_seqlens = set(len(seq) for _, seq in msa)
357
+ if not len(msa_seqlens) == 1:
358
+ raise RuntimeError(
359
+ "Received unaligned sequences for input to MSA, all sequence "
360
+ "lengths must be equal."
361
+ )
362
+ msa_labels, msa_strs, msa_tokens = super().__call__(msa)
363
+ labels.append(msa_labels)
364
+ strs.append(msa_strs)
365
+ tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
366
+
367
+ return labels, strs, tokens
368
+
369
+
370
+ def read_fasta(
371
+ path,
372
+ keep_gaps=True,
373
+ keep_insertions=True,
374
+ to_upper=False,
375
+ ):
376
+ with open(path, "r") as f:
377
+ for result in read_alignment_lines(
378
+ f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
379
+ ):
380
+ yield result
381
+
382
+
383
+ def read_alignment_lines(
384
+ lines,
385
+ keep_gaps=True,
386
+ keep_insertions=True,
387
+ to_upper=False,
388
+ ):
389
+ seq = desc = None
390
+
391
+ def parse(s):
392
+ if not keep_gaps:
393
+ s = re.sub("-", "", s)
394
+ if not keep_insertions:
395
+ s = re.sub("[a-z]", "", s)
396
+ return s.upper() if to_upper else s
397
+
398
+ for line in lines:
399
+ # Line may be empty if seq % file_line_width == 0
400
+ if len(line) > 0 and line[0] == ">":
401
+ if seq is not None:
402
+ yield desc, parse(seq)
403
+ desc = line.strip()
404
+ seq = ""
405
+ else:
406
+ assert isinstance(seq, str)
407
+ seq += line.strip()
408
+ assert isinstance(seq, str) and isinstance(desc, str)
409
+ yield desc, parse(seq)
410
+
411
+
412
+ class ESMStructuralSplitDataset(torch.utils.data.Dataset):
413
+ """
414
+ Structural Split Dataset as described in section A.10 of the supplement of our paper.
415
+ https://doi.org/10.1101/622803
416
+
417
+ We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
418
+ generated on January 23, 2020.
419
+
420
+ For each SCOPe domain:
421
+ - We extract the sequence from the corresponding PDB file
422
+ - We extract the 3D coordinates of the Carbon beta atoms, aligning them
423
+ to the sequence. We put NaN where Cb atoms are missing.
424
+ - From the 3D coordinates, we calculate a pairwise distance map, based
425
+ on L2 distance
426
+ - We use DSSP to generate secondary structure labels for the corresponding
427
+ PDB file. This is also aligned to the sequence. We put - where SSP
428
+ labels are missing.
429
+
430
+ For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
431
+ we have split the data into 5 partitions for cross validation. These are provided
432
+ in a downloaded splits folder, in the format:
433
+ splits/{split_level}/{cv_partition}/{train|valid}.txt
434
+ where train is the partition and valid is the concatentation of the remaining 4.
435
+
436
+ For each SCOPe domain, we provide a pkl dump that contains:
437
+ - seq : The domain sequence, stored as an L-length string
438
+ - ssp : The secondary structure labels, stored as an L-length string
439
+ - dist : The distance map, stored as an LxL numpy array
440
+ - coords : The 3D coordinates, stored as an Lx3 numpy array
441
+
442
+ """
443
+
444
+ base_folder = "structural-data"
445
+ file_list = [
446
+ # url tar filename filename MD5 Hash
447
+ (
448
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
449
+ "splits.tar.gz",
450
+ "splits",
451
+ "456fe1c7f22c9d3d8dfe9735da52411d",
452
+ ),
453
+ (
454
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
455
+ "pkl.tar.gz",
456
+ "pkl",
457
+ "644ea91e56066c750cd50101d390f5db",
458
+ ),
459
+ ]
460
+
461
+ def __init__(
462
+ self,
463
+ split_level,
464
+ cv_partition,
465
+ split,
466
+ root_path=os.path.expanduser("~/.cache/torch/data/esm"),
467
+ download=False,
468
+ ):
469
+ super().__init__()
470
+ assert split in [
471
+ "train",
472
+ "valid",
473
+ ], "train_valid must be 'train' or 'valid'"
474
+ self.root_path = root_path
475
+ self.base_path = os.path.join(self.root_path, self.base_folder)
476
+
477
+ # check if root path has what you need or else download it
478
+ if download:
479
+ self.download()
480
+
481
+ self.split_file = os.path.join(
482
+ self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
483
+ )
484
+ self.pkl_dir = os.path.join(self.base_path, "pkl")
485
+ self.names = []
486
+ with open(self.split_file) as f:
487
+ self.names = f.read().splitlines()
488
+
489
+ def __len__(self):
490
+ return len(self.names)
491
+
492
+ def _check_exists(self) -> bool:
493
+ for (_, _, filename, _) in self.file_list:
494
+ fpath = os.path.join(self.base_path, filename)
495
+ if not os.path.exists(fpath) or not os.path.isdir(fpath):
496
+ return False
497
+ return True
498
+
499
+ def download(self):
500
+
501
+ if self._check_exists():
502
+ print("Files already downloaded and verified")
503
+ return
504
+
505
+ from torchvision.datasets.utils import download_url
506
+
507
+ for url, tar_filename, filename, md5_hash in self.file_list:
508
+ download_path = os.path.join(self.base_path, tar_filename)
509
+ download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
510
+ shutil.unpack_archive(download_path, self.base_path)
511
+
512
+ def __getitem__(self, idx):
513
+ """
514
+ Returns a dict with the following entires
515
+ - seq : Str (domain sequence)
516
+ - ssp : Str (SSP labels)
517
+ - dist : np.array (distance map)
518
+ - coords : np.array (3D coordinates)
519
+ """
520
+ name = self.names[idx]
521
+ pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
522
+ with open(pkl_fname, "rb") as f:
523
+ obj = pickle.load(f)
524
+ return obj
esm/model/__pycache__/esm1.cpython-36.pyc ADDED
Binary file (5.18 kB). View file
 
esm/model/__pycache__/esm1.cpython-39.pyc ADDED
Binary file (5.19 kB). View file
 
esm/model/__pycache__/esm2.cpython-36.pyc ADDED
Binary file (3.51 kB). View file
 
esm/model/__pycache__/esm2.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
esm/model/__pycache__/esm2_only_secondarystructure.cpython-39.pyc ADDED
Binary file (4.54 kB). View file
 
esm/model/__pycache__/esm2_secondarystructure.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
esm/model/__pycache__/esm2_supervised.cpython-39.pyc ADDED
Binary file (4.55 kB). View file
 
esm/model/__pycache__/msa_transformer.cpython-36.pyc ADDED
Binary file (5.46 kB). View file
 
esm/model/__pycache__/msa_transformer.cpython-39.pyc ADDED
Binary file (5.51 kB). View file
 
esm/model/esm1.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from ..modules import (
13
+ TransformerLayer,
14
+ LearnedPositionalEmbedding,
15
+ SinusoidalPositionalEmbedding,
16
+ RobertaLMHead,
17
+ ESM1bLayerNorm,
18
+ ContactPredictionHead,
19
+ )
20
+
21
+
22
+ class ProteinBertModel(nn.Module):
23
+ @classmethod
24
+ def add_args(cls, parser):
25
+ parser.add_argument(
26
+ "--num_layers", default=36, type=int, metavar="N", help="number of layers"
27
+ )
28
+ parser.add_argument(
29
+ "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
30
+ )
31
+ parser.add_argument(
32
+ "--logit_bias", action="store_true", help="whether to apply bias to logits"
33
+ )
34
+ parser.add_argument(
35
+ "--ffn_embed_dim",
36
+ default=5120,
37
+ type=int,
38
+ metavar="N",
39
+ help="embedding dimension for FFN",
40
+ )
41
+ parser.add_argument(
42
+ "--attention_heads",
43
+ default=20,
44
+ type=int,
45
+ metavar="N",
46
+ help="number of attention heads",
47
+ )
48
+
49
+ def __init__(self, args, alphabet):
50
+ super().__init__()
51
+ self.args = args
52
+ self.alphabet_size = len(alphabet)
53
+ self.padding_idx = alphabet.padding_idx
54
+ self.mask_idx = alphabet.mask_idx
55
+ self.cls_idx = alphabet.cls_idx
56
+ self.eos_idx = alphabet.eos_idx
57
+ self.prepend_bos = alphabet.prepend_bos
58
+ self.append_eos = alphabet.append_eos
59
+ self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
60
+ if self.args.arch == "roberta_large":
61
+ self.model_version = "ESM-1b"
62
+ self._init_submodules_esm1b()
63
+ else:
64
+ self.model_version = "ESM-1"
65
+ self._init_submodules_esm1()
66
+
67
+ def _init_submodules_common(self):
68
+ self.embed_tokens = nn.Embedding(
69
+ self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
70
+ )
71
+ self.layers = nn.ModuleList(
72
+ [
73
+ TransformerLayer(
74
+ self.args.embed_dim,
75
+ self.args.ffn_embed_dim,
76
+ self.args.attention_heads,
77
+ add_bias_kv=(self.model_version != "ESM-1b"),
78
+ use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
79
+ )
80
+ for _ in range(self.args.layers)
81
+ ]
82
+ )
83
+
84
+ self.contact_head = ContactPredictionHead(
85
+ self.args.layers * self.args.attention_heads,
86
+ self.prepend_bos,
87
+ self.append_eos,
88
+ eos_idx=self.eos_idx,
89
+ )
90
+
91
+ def _init_submodules_esm1b(self):
92
+ self._init_submodules_common()
93
+ self.embed_scale = 1
94
+ self.embed_positions = LearnedPositionalEmbedding(
95
+ self.args.max_positions, self.args.embed_dim, self.padding_idx
96
+ )
97
+ self.emb_layer_norm_before = (
98
+ ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
99
+ )
100
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
101
+ self.lm_head = RobertaLMHead(
102
+ embed_dim=self.args.embed_dim,
103
+ output_dim=self.alphabet_size,
104
+ weight=self.embed_tokens.weight,
105
+ )
106
+
107
+ def _init_submodules_esm1(self):
108
+ self._init_submodules_common()
109
+ self.embed_scale = math.sqrt(self.args.embed_dim)
110
+ self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
111
+ self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
112
+ self.embed_out_bias = None
113
+ if self.args.final_bias:
114
+ self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))
115
+
116
+ def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False):
117
+ if return_contacts:
118
+ need_head_weights = True
119
+
120
+ assert tokens.ndim == 2
121
+ padding_mask = tokens.eq(self.padding_idx) # B, T
122
+
123
+ x = self.embed_scale * self.embed_tokens(tokens)
124
+
125
+ if getattr(self.args, "token_dropout", False):
126
+ x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
127
+ # x: B x T x C
128
+ mask_ratio_train = 0.15 * 0.8
129
+ src_lengths = (~padding_mask).sum(-1)
130
+ mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
131
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
132
+
133
+ x = x + self.embed_positions(tokens)
134
+
135
+ if self.model_version == "ESM-1b":
136
+ if self.emb_layer_norm_before:
137
+ x = self.emb_layer_norm_before(x)
138
+ if padding_mask is not None:
139
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
140
+
141
+ repr_layers = set(repr_layers)
142
+ hidden_representations = {}
143
+ if 0 in repr_layers:
144
+ hidden_representations[0] = x
145
+
146
+ if need_head_weights:
147
+ attn_weights = []
148
+
149
+ # (B, T, E) => (T, B, E)
150
+ x = x.transpose(0, 1)
151
+
152
+ if not padding_mask.any():
153
+ padding_mask = None
154
+
155
+ for layer_idx, layer in enumerate(self.layers):
156
+ x, attn = layer(
157
+ x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
158
+ )
159
+ if (layer_idx + 1) in repr_layers:
160
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
161
+ if need_head_weights:
162
+ # (H, B, T, T) => (B, H, T, T)
163
+ attn_weights.append(attn.transpose(1, 0))
164
+
165
+ if self.model_version == "ESM-1b":
166
+ x = self.emb_layer_norm_after(x)
167
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
168
+
169
+ # last hidden representation should have layer norm applied
170
+ if (layer_idx + 1) in repr_layers:
171
+ hidden_representations[layer_idx + 1] = x
172
+ x = self.lm_head(x)
173
+ else:
174
+ x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
175
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
176
+
177
+ if return_representation:
178
+ result = {"logits": x, "representations": hidden_representations}
179
+ else:
180
+ result = {"logits": x}
181
+ if need_head_weights:
182
+ # attentions: B x L x H x T x T
183
+ attentions = torch.stack(attn_weights, 1)
184
+ if self.model_version == "ESM-1":
185
+ # ESM-1 models have an additional null-token for attention, which we remove
186
+ attentions = attentions[..., :-1]
187
+ if padding_mask is not None:
188
+ attention_mask = 1 - padding_mask.type_as(attentions)
189
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
190
+ attentions = attentions * attention_mask[:, None, None, :, :]
191
+ result["attentions"] = attentions
192
+ if return_contacts:
193
+ contacts = self.contact_head(tokens, attentions)
194
+ result["contacts"] = contacts
195
+
196
+ return result
197
+
198
+ def predict_contacts(self, tokens):
199
+ return self(tokens, return_contacts=True)["contacts"]
200
+
201
+ @property
202
+ def num_layers(self):
203
+ return self.args.layers
esm/model/esm2.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import esm
11
+ from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
12
+
13
+
14
+ class ESM2(nn.Module):
15
+ def __init__(
16
+ self,
17
+ num_layers: int = 33,
18
+ embed_dim: int = 1280,
19
+ attention_heads: int = 20,
20
+ alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
21
+ token_dropout: bool = True,
22
+ ):
23
+ super().__init__()
24
+ self.num_layers = num_layers
25
+ self.embed_dim = embed_dim
26
+ self.attention_heads = attention_heads
27
+ if not isinstance(alphabet, esm.data.Alphabet):
28
+ alphabet = esm.data.Alphabet.from_architecture(alphabet)
29
+ self.alphabet = alphabet
30
+ self.alphabet_size = len(alphabet)
31
+ self.padding_idx = alphabet.padding_idx
32
+ self.mask_idx = alphabet.mask_idx
33
+ self.cls_idx = alphabet.cls_idx
34
+ self.eos_idx = alphabet.eos_idx
35
+ self.prepend_bos = alphabet.prepend_bos
36
+ self.append_eos = alphabet.append_eos
37
+ self.token_dropout = token_dropout
38
+
39
+ self._init_submodules()
40
+
41
+ def _init_submodules(self):
42
+ self.embed_scale = 1
43
+ self.embed_tokens = nn.Embedding(
44
+ self.alphabet_size,
45
+ self.embed_dim,
46
+ padding_idx=self.padding_idx,
47
+ )
48
+
49
+ self.layers = nn.ModuleList(
50
+ [
51
+ TransformerLayer(
52
+ self.embed_dim,
53
+ 4 * self.embed_dim,
54
+ self.attention_heads,
55
+ add_bias_kv=False,
56
+ use_esm1b_layer_norm=True,
57
+ use_rotary_embeddings=True,
58
+ )
59
+ for _ in range(self.num_layers)
60
+ ]
61
+ )
62
+
63
+ self.contact_head = ContactPredictionHead(
64
+ self.num_layers * self.attention_heads,
65
+ self.prepend_bos,
66
+ self.append_eos,
67
+ eos_idx=self.eos_idx,
68
+ )
69
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
70
+
71
+ self.lm_head = RobertaLMHead(
72
+ embed_dim=self.embed_dim,
73
+ output_dim=self.alphabet_size,
74
+ weight=self.embed_tokens.weight,
75
+ )
76
+
77
+ def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, return_representation=False):
78
+ if return_contacts:
79
+ need_head_weights = True
80
+
81
+ assert tokens.ndim == 2
82
+ padding_mask = tokens.eq(self.padding_idx) # B, T
83
+
84
+ x = self.embed_scale * self.embed_tokens(tokens)
85
+
86
+ if self.token_dropout:
87
+ x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
88
+ # x: B x T x C
89
+ mask_ratio_train = 0.15 * 0.8
90
+ src_lengths = (~padding_mask).sum(-1)
91
+ mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
92
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
93
+
94
+ if padding_mask is not None:
95
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
96
+
97
+ repr_layers = set(repr_layers)
98
+ hidden_representations = {}
99
+ if 0 in repr_layers:
100
+ hidden_representations[0] = x
101
+
102
+ if need_head_weights:
103
+ attn_weights = []
104
+
105
+ # (B, T, E) => (T, B, E)
106
+ x = x.transpose(0, 1)
107
+
108
+ if not padding_mask.any():
109
+ padding_mask = None
110
+
111
+ for layer_idx, layer in enumerate(self.layers):
112
+ x, attn = layer(
113
+ x,
114
+ self_attn_padding_mask=padding_mask,
115
+ need_head_weights=need_head_weights,
116
+ )
117
+ if (layer_idx + 1) in repr_layers:
118
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
119
+ if need_head_weights:
120
+ # (H, B, T, T) => (B, H, T, T)
121
+ attn_weights.append(attn.transpose(1, 0))
122
+ # print(x.shape) # 73, 2, 1280
123
+ x = self.emb_layer_norm_after(x)
124
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
125
+
126
+ # last hidden representation should have layer norm applied
127
+ if (layer_idx + 1) in repr_layers:
128
+ hidden_representations[layer_idx + 1] = x
129
+ x = self.lm_head(x)
130
+
131
+ if return_representation:
132
+ result = {"logits": x, "representations": hidden_representations}
133
+ else:
134
+ result = {"logits": x}
135
+ if need_head_weights:
136
+ # attentions: B x L x H x T x T
137
+ attentions = torch.stack(attn_weights, 1)
138
+ if padding_mask is not None:
139
+ attention_mask = 1 - padding_mask.type_as(attentions)
140
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
141
+ attentions = attentions * attention_mask[:, None, None, :, :]
142
+ result["attentions"] = attentions
143
+ if return_contacts:
144
+ attentions_symm, contacts = self.contact_head(tokens, attentions)
145
+ result["contacts"] = contacts
146
+ result["attentions_symm"] = attentions_symm
147
+
148
+ return result
149
+
150
+ def predict_contacts(self, tokens):
151
+ return self(tokens, return_contacts=True)["contacts"]
152
+
153
+ def predict_symmetric_attentions(self, tokens):
154
+ return self(tokens, return_contacts=True)["attentions_symm"]
155
+
156
+ def predict_attentions(self, tokens):
157
+ return self(tokens, need_head_weights=True)["attentions"]
158
+
159
+ def predict_representations(self, tokens):
160
+ return self(tokens, return_representation=True)['representations']
161
+
162
+ def predict_logits(self, tokens):
163
+ return self(tokens)['logits']
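For orientation, a minimal usage sketch of the ESM2 class added above. The batch-converter call and the example sequence are assumptions (mirroring upstream ESM); the model here is randomly initialized unless a checkpoint is loaded.

import torch
import esm
from esm.model.esm2 import ESM2

alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
model = ESM2(num_layers=33, embed_dim=1280, attention_heads=20, alphabet=alphabet)
model.eval()

batch_converter = alphabet.get_batch_converter()  # assumed to exist as in upstream esm.data
_, _, tokens = batch_converter([("seq1", "MKTAYIAKQR")])  # tokens: B x T LongTensor

with torch.no_grad():
    out = model(tokens, repr_layers=[33], return_contacts=True, return_representation=True)
embeddings = out["representations"][33]  # B x T x 1280, taken after the final layer norm
contacts = out["contacts"]               # B x L x L contact probabilities (BOS/EOS stripped)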
esm/model/esm2_only_secondarystructure.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import esm
11
+ from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
12
+
13
+
14
+ class ESM2(nn.Module):
15
+ def __init__(
16
+ self,
17
+ num_layers: int = 33,
18
+ embed_dim: int = 1280,
19
+ attention_heads: int = 20,
20
+ alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
21
+ token_dropout: bool = True,
22
+ ):
23
+ super().__init__()
24
+ self.num_layers = num_layers
25
+ self.embed_dim = embed_dim
26
+ self.attention_heads = attention_heads
27
+ if not isinstance(alphabet, esm.data.Alphabet):
28
+ alphabet = esm.data.Alphabet.from_architecture(alphabet)
29
+ self.alphabet = alphabet
30
+ self.alphabet_size = len(alphabet)
31
+ self.padding_idx = alphabet.padding_idx
32
+ self.mask_idx = alphabet.mask_idx
33
+ self.cls_idx = alphabet.cls_idx
34
+ self.eos_idx = alphabet.eos_idx
35
+ self.prepend_bos = alphabet.prepend_bos
36
+ self.append_eos = alphabet.append_eos
37
+ self.token_dropout = token_dropout
38
+
39
+ self._init_submodules()
40
+
41
+ def _init_submodules(self):
42
+ self.embed_scale = 1
43
+ self.embed_tokens = nn.Embedding(
44
+ self.alphabet_size,
45
+ self.embed_dim,
46
+ padding_idx=self.padding_idx,
47
+ )
48
+
49
+ self.layers = nn.ModuleList(
50
+ [
51
+ TransformerLayer(
52
+ self.embed_dim,
53
+ 4 * self.embed_dim,
54
+ self.attention_heads,
55
+ add_bias_kv=False,
56
+ use_esm1b_layer_norm=True,
57
+ use_rotary_embeddings=True,
58
+ )
59
+ for _ in range(self.num_layers)
60
+ ]
61
+ )
62
+
63
+ self.contact_head = ContactPredictionHead(
64
+ self.num_layers * self.attention_heads,
65
+ self.prepend_bos,
66
+ self.append_eos,
67
+ eos_idx=self.eos_idx,
68
+ )
69
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
70
+
71
+ self.lm_head = RobertaLMHead(
72
+ embed_dim=self.embed_dim,
73
+ output_dim=self.alphabet_size,
74
+ weight=self.embed_tokens.weight,
75
+ )
76
+ # self.supervised_linear = nn.Linear(self.embed_dim, 1)
77
+ self.structure_linear = nn.Linear(self.embed_dim, 3)
78
+ def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm=False, return_attentions=False):
79
+ if return_contacts:
80
+ need_head_weights = True
81
+
82
+ assert tokens.ndim == 2
83
+ padding_mask = tokens.eq(self.padding_idx) # B, T
84
+
85
+ x = self.embed_scale * self.embed_tokens(tokens)
86
+
87
+ if self.token_dropout:
88
+ x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
89
+ #print(f'tokens = {tokens}')
90
+ #print(f'self.mask_idx = {self.mask_idx}')
91
+ #print('x.shape = ', x.shape)
92
+ # x: B x T x C
93
+ mask_ratio_train = 0.15 * 0.8
94
+ src_lengths = (~padding_mask).sum(-1)
95
+ #print(f'mask_ratio_train = {mask_ratio_train}')
96
+ #print(f'padding_mask = {padding_mask}')
97
+ #print(f'src_lengths = {src_lengths}')
98
+ mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
99
+ #print('mask_ratio_observed = ',mask_ratio_observed)
100
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
101
+ #print(f'x.shape = {x.shape}:\n', x)
102
+ if padding_mask is not None:
103
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
104
+ #print(f'x.shape = {x.shape}:\n', x)
105
+ repr_layers = set(repr_layers)
106
+ hidden_representations = {}
107
+ if 0 in repr_layers:
108
+ hidden_representations[0] = x
109
+
110
+ if need_head_weights:
111
+ attn_weights = []
112
+
113
+ # (B, T, E) => (T, B, E)
114
+ x = x.transpose(0, 1)
115
+
116
+ if not padding_mask.any():
117
+ padding_mask = None
118
+
119
+ for layer_idx, layer in enumerate(self.layers):
120
+ x, attn = layer(
121
+ x,
122
+ self_attn_padding_mask=padding_mask,
123
+ need_head_weights=need_head_weights,
124
+ )
125
+ if (layer_idx + 1) in repr_layers:
126
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
127
+ if need_head_weights:
128
+ # (H, B, T, T) => (B, H, T, T)
129
+ attn_weights.append(attn.transpose(1, 0))
130
+ # print(x.shape) # 73, 2, 1280
131
+ x = self.emb_layer_norm_after(x)
132
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
133
+
134
+ # last hidden representation should have layer norm applied
135
+ if (layer_idx + 1) in repr_layers:
136
+ hidden_representations[layer_idx + 1] = x
137
+ # x_supervised = self.supervised_linear(x[:,0,:])
138
+ x_structure = self.structure_linear(x)
139
+ x = self.lm_head(x)
140
+
141
+ if return_representation:
142
+ result = {"logits": x, "logits_structure": x_structure, "representations": hidden_representations}
143
+ else:
144
+ result = {"logits": x, "logits_structure": x_structure}
145
+ if need_head_weights:
146
+ # attentions: B x L x H x T x T
147
+ attentions = torch.stack(attn_weights, 1)
148
+ if padding_mask is not None:
149
+ attention_mask = 1 - padding_mask.type_as(attentions)
150
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
151
+ attentions = attentions * attention_mask[:, None, None, :, :]
152
+ if return_attentions: result["attentions"] = attentions
153
+ if return_contacts:
154
+ attentions_symm, contacts = self.contact_head(tokens, attentions)
155
+ result["contacts"] = contacts
156
+ if return_attentions_symm: result["attentions_symm"] = attentions_symm
157
+
158
+ return result
159
+
160
+ def predict_contacts(self, tokens):
161
+ return self(tokens, return_contacts=True)["contacts"]
162
+
163
+ def predict_symmetric_attentions(self, tokens):
164
+ return self(tokens, return_contacts=True)["attentions_symm"]
165
+
166
+ def predict_attentions(self, tokens):
167
+ return self(tokens, need_head_weights=True)["attentions"]
168
+
169
+ def predict_representations(self, tokens):
170
+ return self(tokens, return_representation=True)['representations']
171
+
172
+ def predict_logits(self, tokens):
173
+ return self(tokens)['logits']
174
+
175
+ # def predict_logits_supervised(self, tokens):
176
+ # return self(tokens)['logits_supervised']
177
+
178
+ def predict_logits_structure(self, tokens):
179
+ return self(tokens)['logits_structure']
esm/model/esm2_secondarystructure.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import esm
11
+ from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
12
+ # This code defines a PyTorch model named ESM2 that inherits from nn.Module. __init__ sets hyperparameters such as num_layers, embed_dim, and attention_heads, and builds the submodules: the embed_tokens Embedding layer, a stack of Transformer layers, the ContactPredictionHead contact_head, and the linear heads lm_head, supervised_linear, and structure_linear. The forward pass, defined in forward, takes a token tensor tokens and returns the predicted logits plus additional outputs.
13
+
14
+ class ESM2(nn.Module):
15
+ def __init__(
16
+ self,
17
+ num_layers: int = 33,
18
+ embed_dim: int = 1280,
19
+ attention_heads: int = 20,
20
+ alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
21
+ token_dropout: bool = True,
22
+ ):
23
+ super().__init__()
24
+ self.num_layers = num_layers
25
+ self.embed_dim = embed_dim
26
+ self.attention_heads = attention_heads
27
+ if not isinstance(alphabet, esm.data.Alphabet):
28
+ alphabet = esm.data.Alphabet.from_architecture(alphabet)
29
+ self.alphabet = alphabet
30
+ self.alphabet_size = len(alphabet)
31
+ self.padding_idx = alphabet.padding_idx
32
+ self.mask_idx = alphabet.mask_idx
33
+ self.cls_idx = alphabet.cls_idx
34
+ self.eos_idx = alphabet.eos_idx
35
+ self.prepend_bos = alphabet.prepend_bos
36
+ self.append_eos = alphabet.append_eos
37
+ self.token_dropout = token_dropout
38
+
39
+ self._init_submodules()
40
+
41
+ def _init_submodules(self):
42
+ self.embed_scale = 1
43
+ self.embed_tokens = nn.Embedding(
44
+ self.alphabet_size,
45
+ self.embed_dim,
46
+ padding_idx=self.padding_idx,
47
+ )
48
+
49
+ self.layers = nn.ModuleList(
50
+ [
51
+ TransformerLayer(
52
+ self.embed_dim,
53
+ 4 * self.embed_dim,
54
+ self.attention_heads,
55
+ add_bias_kv=False,
56
+ use_esm1b_layer_norm=True,
57
+ use_rotary_embeddings=True,
58
+ )
59
+ for _ in range(self.num_layers)
60
+ ]
61
+ )
62
+
63
+ self.contact_head = ContactPredictionHead(
64
+ self.num_layers * self.attention_heads,
65
+ self.prepend_bos,
66
+ self.append_eos,
67
+ eos_idx=self.eos_idx,
68
+ )
69
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
70
+
71
+ self.lm_head = RobertaLMHead(
72
+ embed_dim=self.embed_dim,
73
+ output_dim=self.alphabet_size,
74
+ weight=self.embed_tokens.weight,
75
+ )
76
+ self.supervised_linear = nn.Linear(self.embed_dim, 1)
77
+ self.structure_linear = nn.Linear(self.embed_dim, 3)
78
+ def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm=False, return_attentions=False):
79
+ if return_contacts:
80
+ need_head_weights = True
81
+
82
+ assert tokens.ndim == 2
83
+ padding_mask = tokens.eq(self.padding_idx) # B, T
84
+
85
+ x = self.embed_scale * self.embed_tokens(tokens)
86
+
87
+ if self.token_dropout:
88
+ x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
89
+ #print(f'tokens = {tokens}')
90
+ #print(f'self.mask_idx = {self.mask_idx}')
91
+ #print('x.shape = ', x.shape)
92
+ # x: B x T x C
93
+ mask_ratio_train = 0.15 * 0.8
94
+ src_lengths = (~padding_mask).sum(-1)
95
+ #print(f'mask_ratio_train = {mask_ratio_train}')
96
+ #print(f'padding_mask = {padding_mask}')
97
+ #print(f'src_lengths = {src_lengths}')
98
+ mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
99
+ #print('mask_ratio_observed = ',mask_ratio_observed)
100
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
101
+ #print(f'x.shape = {x.shape}:\n', x)
102
+ if padding_mask is not None:
103
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
104
+ #print(f'x.shape = {x.shape}:\n', x)
105
+ repr_layers = set(repr_layers)
106
+ hidden_representations = {}
107
+ if 0 in repr_layers:
108
+ hidden_representations[0] = x
109
+
110
+ if need_head_weights:
111
+ attn_weights = []
112
+
113
+ # (B, T, E) => (T, B, E)
114
+ x = x.transpose(0, 1)
115
+
116
+ if not padding_mask.any():
117
+ padding_mask = None
118
+
119
+ for layer_idx, layer in enumerate(self.layers):
120
+ x, attn = layer(
121
+ x,
122
+ self_attn_padding_mask=padding_mask,
123
+ need_head_weights=need_head_weights,
124
+ )
125
+ if (layer_idx + 1) in repr_layers:
126
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
127
+ if need_head_weights:
128
+ # (H, B, T, T) => (B, H, T, T)
129
+ attn_weights.append(attn.transpose(1, 0))
130
+ # print(x.shape) # 73, 2, 1280
131
+ x = self.emb_layer_norm_after(x)
132
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
133
+
134
+ # last hidden representation should have layer norm applied
135
+ if (layer_idx + 1) in repr_layers:
136
+ hidden_representations[layer_idx + 1] = x
137
+ x_supervised = self.supervised_linear(x[:,0,:])
138
+ x_structure = self.structure_linear(x)
139
+ x = self.lm_head(x)
140
+
141
+ if return_representation:
142
+ result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure, "representations": hidden_representations}
143
+ else:
144
+ result = {"logits": x, "logits_supervised": x_supervised, "logits_structure": x_structure}
145
+ if need_head_weights:
146
+ # attentions: B x L x H x T x T
147
+ attentions = torch.stack(attn_weights, 1)
148
+ if padding_mask is not None:
149
+ attention_mask = 1 - padding_mask.type_as(attentions)
150
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
151
+ attentions = attentions * attention_mask[:, None, None, :, :]
152
+ if return_attentions: result["attentions"] = attentions
153
+ if return_contacts:
154
+ attentions_symm, contacts = self.contact_head(tokens, attentions)
155
+ result["contacts"] = contacts
156
+ if return_attentions_symm: result["attentions_symm"] = attentions_symm
157
+
158
+ return result
159
+
160
+ def predict_contacts(self, tokens):
161
+ return self(tokens, return_contacts=True)["contacts"]
162
+
163
+ def predict_symmetric_attentions(self, tokens):
164
+ return self(tokens, return_contacts=True)["attentions_symm"]
165
+
166
+ def predict_attentions(self, tokens):
167
+ return self(tokens, need_head_weights=True)["attentions"]
168
+
169
+ def predict_representations(self, tokens):
170
+ return self(tokens, return_representation=True)['representations']
171
+
172
+ def predict_logits(self, tokens):
173
+ return self(tokens)['logits']
174
+
175
+ def predict_logits_supervised(self, tokens):
176
+ return self(tokens)['logits_supervised']
177
+
178
+ def predict_logits_structure(self, tokens):
179
+ return self(tokens)['logits_structure']
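The two extra heads above work at different granularities: supervised_linear reads only x[:, 0, :] (the prepended BOS/CLS position) and yields one scalar logit per sequence, while structure_linear is applied at every position to give 3-class secondary-structure logits. A hedged shape sketch, assuming model is an instance of the ESM2 class defined in this file and tokens is a B x T token tensor as in the earlier example:

out = model(tokens)                   # defaults also compute representations and contacts
seq_logit = out["logits_supervised"]  # B x 1, sequence-level logit from the CLS position
ss3_logits = out["logits_structure"]  # B x T x 3, per-residue secondary-structure logits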
esm/model/esm2_supervised.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ import esm
11
+ from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
12
+
13
+
14
+ class ESM2(nn.Module):
15
+ def __init__(
16
+ self,
17
+ num_layers: int = 33,
18
+ embed_dim: int = 1280,
19
+ attention_heads: int = 20,
20
+ alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
21
+ token_dropout: bool = True,
22
+ ):
23
+ super().__init__()
24
+ self.num_layers = num_layers
25
+ self.embed_dim = embed_dim
26
+ self.attention_heads = attention_heads
27
+ if not isinstance(alphabet, esm.data.Alphabet):
28
+ alphabet = esm.data.Alphabet.from_architecture(alphabet)
29
+ self.alphabet = alphabet
30
+ self.alphabet_size = len(alphabet)
31
+ self.padding_idx = alphabet.padding_idx
32
+ self.mask_idx = alphabet.mask_idx
33
+ self.cls_idx = alphabet.cls_idx
34
+ self.eos_idx = alphabet.eos_idx
35
+ self.prepend_bos = alphabet.prepend_bos
36
+ self.append_eos = alphabet.append_eos
37
+ self.token_dropout = token_dropout
38
+
39
+ self._init_submodules()
40
+
41
+ def _init_submodules(self):
42
+ self.embed_scale = 1
43
+ self.embed_tokens = nn.Embedding(
44
+ self.alphabet_size,
45
+ self.embed_dim,
46
+ padding_idx=self.padding_idx,
47
+ )
48
+
49
+ self.layers = nn.ModuleList(
50
+ [
51
+ TransformerLayer(
52
+ self.embed_dim,
53
+ 4 * self.embed_dim,
54
+ self.attention_heads,
55
+ add_bias_kv=False,
56
+ use_esm1b_layer_norm=True,
57
+ use_rotary_embeddings=True,
58
+ )
59
+ for _ in range(self.num_layers)
60
+ ]
61
+ )
62
+
63
+ self.contact_head = ContactPredictionHead(
64
+ self.num_layers * self.attention_heads,
65
+ self.prepend_bos,
66
+ self.append_eos,
67
+ eos_idx=self.eos_idx,
68
+ )
69
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
70
+
71
+ self.lm_head = RobertaLMHead(
72
+ embed_dim=self.embed_dim,
73
+ output_dim=self.alphabet_size,
74
+ weight=self.embed_tokens.weight,
75
+ )
76
+ self.supervised_linear = nn.Linear(self.embed_dim, 1)
77
+ def forward(self, tokens, repr_layers=[], need_head_weights=True, return_contacts=True, return_representation=True, return_attentions_symm = False, return_attentions = False):
78
+ if return_contacts:
79
+ need_head_weights = True
80
+
81
+ assert tokens.ndim == 2
82
+ padding_mask = tokens.eq(self.padding_idx) # B, T
83
+
84
+ x = self.embed_scale * self.embed_tokens(tokens)
85
+
86
+ if self.token_dropout:
87
+ x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
88
+ #print(f'tokens = {tokens}')
89
+ #print(f'self.mask_idx = {self.mask_idx}')
90
+ #print('x.shape = ', x.shape)
91
+ # x: B x T x C
92
+ mask_ratio_train = 0.15 * 0.8
93
+ src_lengths = (~padding_mask).sum(-1)
94
+ #print(f'mask_ratio_train = {mask_ratio_train}')
95
+ #print(f'padding_mask = {padding_mask}')
96
+ #print(f'src_lengths = {src_lengths}')
97
+ mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
98
+ #print('mask_ratio_observed = ',mask_ratio_observed)
99
+ x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
100
+ #print(f'x.shape = {x.shape}:\n', x)
101
+ if padding_mask is not None:
102
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
103
+ #print(f'x.shape = {x.shape}:\n', x)
104
+ repr_layers = set(repr_layers)
105
+ hidden_representations = {}
106
+ if 0 in repr_layers:
107
+ hidden_representations[0] = x
108
+
109
+ if need_head_weights:
110
+ attn_weights = []
111
+
112
+ # (B, T, E) => (T, B, E)
113
+ x = x.transpose(0, 1)
114
+
115
+ if not padding_mask.any():
116
+ padding_mask = None
117
+
118
+ for layer_idx, layer in enumerate(self.layers):
119
+ x, attn = layer(
120
+ x,
121
+ self_attn_padding_mask=padding_mask,
122
+ need_head_weights=need_head_weights,
123
+ )
124
+ if (layer_idx + 1) in repr_layers:
125
+ hidden_representations[layer_idx + 1] = x.transpose(0, 1)
126
+ if need_head_weights:
127
+ # (H, B, T, T) => (B, H, T, T)
128
+ attn_weights.append(attn.transpose(1, 0))
129
+ # print(x.shape) # 73, 2, 1280
130
+ x = self.emb_layer_norm_after(x)
131
+ x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
132
+
133
+ # last hidden representation should have layer norm applied
134
+ if (layer_idx + 1) in repr_layers:
135
+ hidden_representations[layer_idx + 1] = x
136
+ x_supervised = self.supervised_linear(x[:,0,:])
137
+ x = self.lm_head(x)
138
+
139
+ if return_representation:
140
+ result = {"logits": x, "logits_supervised": x_supervised, "representations": hidden_representations}
141
+ else:
142
+ result = {"logits": x, "logits_supervised": x_supervised}
143
+ if need_head_weights:
144
+ # attentions: B x L x H x T x T
145
+ attentions = torch.stack(attn_weights, 1)
146
+ if padding_mask is not None:
147
+ attention_mask = 1 - padding_mask.type_as(attentions)
148
+ attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
149
+ attentions = attentions * attention_mask[:, None, None, :, :]
150
+ if return_attentions: result["attentions"] = attentions
151
+ if return_contacts:
152
+ attentions_symm, contacts = self.contact_head(tokens, attentions)
153
+ result["contacts"] = contacts
154
+ if return_attentions_symm: result["attentions_symm"] = attentions_symm
155
+
156
+ return result
157
+
158
+ def predict_contacts(self, tokens):
159
+ return self(tokens, return_contacts=True)["contacts"]
160
+
161
+ def predict_symmetric_attentions(self, tokens):
162
+ return self(tokens, return_contacts=True)["attentions_symm"]
163
+
164
+ def predict_attentions(self, tokens):
165
+ return self(tokens, need_head_weights=True)["attentions"]
166
+
167
+ def predict_representations(self, tokens):
168
+ return self(tokens, return_representation=True)['representations']
169
+
170
+ def predict_logits(self, tokens):
171
+ return self(tokens)['logits']
172
+
173
+ def predict_logits_supervised(self, tokens):
174
+ return self(tokens)['logits_supervised']
esm/model/msa_transformer.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..modules import (
10
+ AxialTransformerLayer,
11
+ LearnedPositionalEmbedding,
12
+ RobertaLMHead,
13
+ ESM1bLayerNorm,
14
+ ContactPredictionHead,
15
+ )
16
+
17
+ from ..axial_attention import RowSelfAttention, ColumnSelfAttention
18
+
19
+
20
+
21
+ class MSATransformer(nn.Module):
22
+ @classmethod
23
+ def add_args(cls, parser):
24
+ # fmt: off
25
+ parser.add_argument(
26
+ "--num_layers",
27
+ default=12,
28
+ type=int,
29
+ metavar="N",
30
+ help="number of layers"
31
+ )
32
+ parser.add_argument(
33
+ "--embed_dim",
34
+ default=768,
35
+ type=int,
36
+ metavar="N",
37
+ help="embedding dimension"
38
+ )
39
+ parser.add_argument(
40
+ "--logit_bias",
41
+ action="store_true",
42
+ help="whether to apply bias to logits"
43
+ )
44
+ parser.add_argument(
45
+ "--ffn_embed_dim",
46
+ default=3072,
47
+ type=int,
48
+ metavar="N",
49
+ help="embedding dimension for FFN",
50
+ )
51
+ parser.add_argument(
52
+ "--attention_heads",
53
+ default=12,
54
+ type=int,
55
+ metavar="N",
56
+ help="number of attention heads",
57
+ )
58
+ parser.add_argument(
59
+ "--dropout",
60
+ default=0.1,
61
+ type=float,
62
+ help="Dropout to apply."
63
+ )
64
+ parser.add_argument(
65
+ "--attention_dropout",
66
+ default=0.1,
67
+ type=float,
68
+ help="Dropout to apply."
69
+ )
70
+ parser.add_argument(
71
+ "--activation_dropout",
72
+ default=0.1,
73
+ type=float,
74
+ help="Dropout to apply."
75
+ )
76
+ parser.add_argument(
77
+ "--max_tokens_per_msa",
78
+ default=2 ** 14,
79
+ type=int,
80
+ help=(
81
+ "Used during inference to batch attention computations in a single "
82
+ "forward pass. This allows increased input sizes with less memory."
83
+ ),
84
+ )
85
+ # fmt: on
86
+
87
+ def __init__(self, args, alphabet):
88
+ super().__init__()
89
+ self.args = args
90
+ self.alphabet_size = len(alphabet)
91
+ self.padding_idx = alphabet.padding_idx
92
+ self.mask_idx = alphabet.mask_idx
93
+ self.cls_idx = alphabet.cls_idx
94
+ self.eos_idx = alphabet.eos_idx
95
+ self.prepend_bos = alphabet.prepend_bos
96
+ self.append_eos = alphabet.append_eos
97
+
98
+ self.embed_tokens = nn.Embedding(
99
+ self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
100
+ )
101
+
102
+ if getattr(self.args, "embed_positions_msa", False):
103
+ emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
104
+ self.msa_position_embedding = nn.Parameter(
105
+ 0.01 * torch.randn(1, 1024, 1, emb_dim),
106
+ requires_grad=True,
107
+ )
108
+ else:
109
+ self.register_parameter("msa_position_embedding", None)
110
+
111
+ self.dropout_module = nn.Dropout(self.args.dropout)
112
+ self.layers = nn.ModuleList(
113
+ [
114
+ AxialTransformerLayer(
115
+ self.args.embed_dim,
116
+ self.args.ffn_embed_dim,
117
+ self.args.attention_heads,
118
+ self.args.dropout,
119
+ self.args.attention_dropout,
120
+ self.args.activation_dropout,
121
+ getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
122
+ )
123
+ for _ in range(self.args.layers)
124
+ ]
125
+ )
126
+
127
+ self.contact_head = ContactPredictionHead(
128
+ self.args.layers * self.args.attention_heads,
129
+ self.prepend_bos,
130
+ self.append_eos,
131
+ eos_idx=self.eos_idx,
132
+ )
133
+ self.embed_positions = LearnedPositionalEmbedding(
134
+ self.args.max_positions,
135
+ self.args.embed_dim,
136
+ self.padding_idx,
137
+ )
138
+ self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
139
+ self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
140
+ self.lm_head = RobertaLMHead(
141
+ embed_dim=self.args.embed_dim,
142
+ output_dim=self.alphabet_size,
143
+ weight=self.embed_tokens.weight,
144
+ )
145
+
146
+ def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
147
+ if return_contacts:
148
+ need_head_weights = True
149
+
150
+ assert tokens.ndim == 3
151
+ batch_size, num_alignments, seqlen = tokens.size()
152
+ padding_mask = tokens.eq(self.padding_idx) # B, R, C
153
+ if not padding_mask.any():
154
+ padding_mask = None
155
+
156
+ x = self.embed_tokens(tokens)
157
+ x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
158
+ if self.msa_position_embedding is not None:
159
+ if x.size(1) > 1024:
160
+ raise RuntimeError(
161
+ "Using model with MSA position embedding trained on maximum MSA "
162
+ f"depth of 1024, but received {x.size(1)} alignments."
163
+ )
164
+ x += self.msa_position_embedding[:, :num_alignments]
165
+
166
+ x = self.emb_layer_norm_before(x)
167
+
168
+ x = self.dropout_module(x)
169
+
170
+ if padding_mask is not None:
171
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
172
+
173
+ repr_layers = set(repr_layers)
174
+ hidden_representations = {}
175
+ if 0 in repr_layers:
176
+ hidden_representations[0] = x
177
+
178
+ if need_head_weights:
179
+ row_attn_weights = []
180
+ col_attn_weights = []
181
+
182
+ # B x R x C x D -> R x C x B x D
183
+ x = x.permute(1, 2, 0, 3)
184
+
185
+ for layer_idx, layer in enumerate(self.layers):
186
+ x = layer(
187
+ x,
188
+ self_attn_padding_mask=padding_mask,
189
+ need_head_weights=need_head_weights,
190
+ )
191
+ if need_head_weights:
192
+ x, col_attn, row_attn = x
193
+ # H x C x B x R x R -> B x H x C x R x R
194
+ col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
195
+ # H x B x C x C -> B x H x C x C
196
+ row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
197
+ if (layer_idx + 1) in repr_layers:
198
+ hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)
199
+
200
+ x = self.emb_layer_norm_after(x)
201
+ x = x.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D
202
+
203
+ # last hidden representation should have layer norm applied
204
+ if (layer_idx + 1) in repr_layers:
205
+ hidden_representations[layer_idx + 1] = x
206
+ x = self.lm_head(x)
207
+
208
+ result = {"logits": x, "representations": hidden_representations}
209
+ if need_head_weights:
210
+ # col_attentions: B x L x H x C x R x R
211
+ col_attentions = torch.stack(col_attn_weights, 1)
212
+ # row_attentions: B x L x H x C x C
213
+ row_attentions = torch.stack(row_attn_weights, 1)
214
+ result["col_attentions"] = col_attentions
215
+ result["row_attentions"] = row_attentions
216
+ if return_contacts:
217
+ contacts = self.contact_head(tokens, row_attentions)
218
+ result["contacts"] = contacts
219
+
220
+ return result
221
+
222
+ def predict_contacts(self, tokens):
223
+ return self(tokens, return_contacts=True)["contacts"]
224
+
225
+ @property
226
+ def num_layers(self):
227
+ return self.args.layers
228
+
229
+ def max_tokens_per_msa_(self, value: int) -> None:
230
+ """The MSA Transformer automatically batches attention computations when
231
+ gradients are disabled to allow you to pass in larger MSAs at test time than
232
+ you can fit in GPU memory. By default this occurs when more than 2^14 tokens
233
+ are passed in the input MSA. You can set this value to infinity to disable
234
+ this behavior.
235
+ """
236
+ for module in self.modules():
237
+ if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
238
+ module.max_tokens_per_msa = value
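As the max_tokens_per_msa_ docstring above explains, attention is chunked at inference once an MSA exceeds 2^14 tokens; a hedged sketch of adjusting that threshold, where msa_model is an MSATransformer instance and msa_tokens is a B x R x C token tensor from an MSA batch converter (both placeholders):

import math
import torch

msa_model.max_tokens_per_msa_(2 ** 16)      # allow larger attention chunks
# msa_model.max_tokens_per_msa_(math.inf)   # or disable chunking entirely
with torch.no_grad():
    out = msa_model(msa_tokens, repr_layers=[12], return_contacts=True)
contacts = out["contacts"]                  # B x L x L, from symmetrized, APC-corrected row attentions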
esm/modules.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .multihead_attention import MultiheadAttention # noqa
14
+ from .axial_attention import ColumnSelfAttention, RowSelfAttention
15
+
16
+
17
+ def gelu(x):
18
+ """Implementation of the gelu activation function.
19
+ For information: OpenAI GPT's gelu is slightly different
20
+ (and gives slightly different results):
21
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
22
+ """
23
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
24
+
25
+
26
+ def symmetrize(x):
27
+ "Make layer symmetric in final two dimensions, used for contact prediction."
28
+ return x + x.transpose(-1, -2)
29
+
30
+
31
+ def apc(x):
32
+ "Perform average product correct, used for contact prediction."
33
+ a1 = x.sum(-1, keepdims=True)
34
+ a2 = x.sum(-2, keepdims=True)
35
+ a12 = x.sum((-1, -2), keepdims=True)
36
+
37
+ avg = a1 * a2
38
+ avg.div_(a12) # in-place to reduce memory
39
+ normalized = x - avg
40
+ return normalized
41
+
42
+
43
+ class ESM1LayerNorm(nn.Module):
44
+ def __init__(self, hidden_size, eps=1e-12, affine=True):
45
+ """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
46
+ super().__init__()
47
+ self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
48
+ self.eps = eps
49
+ self.affine = bool(affine)
50
+ if self.affine:
51
+ self.weight = nn.Parameter(torch.ones(hidden_size))
52
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
53
+ else:
54
+ self.weight, self.bias = None, None
55
+
56
+ def forward(self, x):
57
+ dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
58
+ means = x.mean(dims, keepdim=True)
59
+ x_zeromean = x - means
60
+ variances = x_zeromean.pow(2).mean(dims, keepdim=True)
61
+ x = x_zeromean / torch.sqrt(variances + self.eps)
62
+ if self.affine:
63
+ x = (self.weight * x) + self.bias
64
+ return x
65
+
66
+
67
+ try:
68
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
69
+
70
+ class ESM1bLayerNorm(_FusedLayerNorm):
71
+ @torch.jit.unused
72
+ def forward(self, x):
73
+ if not x.is_cuda:
74
+ return super().forward(x)
75
+ else:
76
+ with torch.cuda.device(x.device):
77
+ return super().forward(x)
78
+
79
+ except ImportError:
80
+ from torch.nn import LayerNorm as ESM1bLayerNorm
81
+
82
+
83
+ class TransformerLayer(nn.Module):
84
+ """Transformer layer block."""
85
+
86
+ def __init__(
87
+ self,
88
+ embed_dim,
89
+ ffn_embed_dim,
90
+ attention_heads,
91
+ add_bias_kv=True,
92
+ use_esm1b_layer_norm=False,
93
+ use_rotary_embeddings: bool = False,
94
+ ):
95
+ super().__init__()
96
+ self.embed_dim = embed_dim
97
+ self.ffn_embed_dim = ffn_embed_dim
98
+ self.attention_heads = attention_heads
99
+ self.use_rotary_embeddings = use_rotary_embeddings
100
+ self._init_submodules(add_bias_kv, use_esm1b_layer_norm)
101
+
102
+ def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
103
+ BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm
104
+
105
+ self.self_attn = MultiheadAttention(
106
+ self.embed_dim,
107
+ self.attention_heads,
108
+ add_bias_kv=add_bias_kv,
109
+ add_zero_attn=False,
110
+ use_rotary_embeddings=self.use_rotary_embeddings,
111
+ )
112
+ self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)
113
+
114
+ self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
115
+ self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
116
+
117
+ self.final_layer_norm = BertLayerNorm(self.embed_dim)
118
+
119
+ def forward(
120
+ self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
121
+ ):
122
+ residual = x
123
+ x = self.self_attn_layer_norm(x)
124
+ x, attn = self.self_attn(
125
+ query=x,
126
+ key=x,
127
+ value=x,
128
+ key_padding_mask=self_attn_padding_mask,
129
+ need_weights=True,
130
+ need_head_weights=need_head_weights,
131
+ attn_mask=self_attn_mask,
132
+ )
133
+ x = residual + x
134
+
135
+ residual = x
136
+ x = self.final_layer_norm(x)
137
+ x = gelu(self.fc1(x))
138
+ x = self.fc2(x)
139
+ x = residual + x
140
+ #print(f'------{attn.half().dtype}-----')
141
+
142
+ return x, attn
143
+
144
+
145
+ class AxialTransformerLayer(nn.Module):
146
+ """Implements an Axial MSA Transformer block."""
147
+
148
+ def __init__(
149
+ self,
150
+ embedding_dim: int = 768,
151
+ ffn_embedding_dim: int = 3072,
152
+ num_attention_heads: int = 8,
153
+ dropout: float = 0.1,
154
+ attention_dropout: float = 0.1,
155
+ activation_dropout: float = 0.1,
156
+ max_tokens_per_msa: int = 2**14,
157
+ ) -> None:
158
+ super().__init__()
159
+
160
+ # Initialize parameters
161
+ self.embedding_dim = embedding_dim
162
+ self.dropout_prob = dropout
163
+
164
+ row_self_attention = RowSelfAttention(
165
+ embedding_dim,
166
+ num_attention_heads,
167
+ dropout=dropout,
168
+ max_tokens_per_msa=max_tokens_per_msa,
169
+ )
170
+
171
+ column_self_attention = ColumnSelfAttention(
172
+ embedding_dim,
173
+ num_attention_heads,
174
+ dropout=dropout,
175
+ max_tokens_per_msa=max_tokens_per_msa,
176
+ )
177
+
178
+ feed_forward_layer = FeedForwardNetwork(
179
+ embedding_dim,
180
+ ffn_embedding_dim,
181
+ activation_dropout=activation_dropout,
182
+ max_tokens_per_msa=max_tokens_per_msa,
183
+ )
184
+
185
+ self.row_self_attention = self.build_residual(row_self_attention)
186
+ self.column_self_attention = self.build_residual(column_self_attention)
187
+ self.feed_forward_layer = self.build_residual(feed_forward_layer)
188
+
189
+ def build_residual(self, layer: nn.Module):
190
+ return NormalizedResidualBlock(
191
+ layer,
192
+ self.embedding_dim,
193
+ self.dropout_prob,
194
+ )
195
+
196
+ def forward(
197
+ self,
198
+ x: torch.Tensor,
199
+ self_attn_mask: Optional[torch.Tensor] = None,
200
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
201
+ need_head_weights: bool = False,
202
+ ):
203
+ """
204
+ LayerNorm is applied either before or after the self-attention/ffn
205
+ modules similar to the original Transformer implementation.
206
+ """
207
+ x, row_attn = self.row_self_attention(
208
+ x,
209
+ self_attn_mask=self_attn_mask,
210
+ self_attn_padding_mask=self_attn_padding_mask,
211
+ )
212
+ x, column_attn = self.column_self_attention(
213
+ x,
214
+ self_attn_mask=self_attn_mask,
215
+ self_attn_padding_mask=self_attn_padding_mask,
216
+ )
217
+ x = self.feed_forward_layer(x)
218
+ if need_head_weights:
219
+ return x, column_attn, row_attn
220
+ else:
221
+ return x
222
+
223
+
224
+ class LearnedPositionalEmbedding(nn.Embedding):
225
+ """
226
+ This module learns positional embeddings up to a fixed maximum size.
227
+ Padding ids are ignored by either offsetting based on padding_idx
228
+ or by setting padding_idx to None and ensuring that the appropriate
229
+ position ids are passed to the forward function.
230
+ """
231
+
232
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
233
+ if padding_idx is not None:
234
+ num_embeddings_ = num_embeddings + padding_idx + 1
235
+ else:
236
+ num_embeddings_ = num_embeddings
237
+ super().__init__(num_embeddings_, embedding_dim, padding_idx)
238
+ self.max_positions = num_embeddings
239
+
240
+ def forward(self, input: torch.Tensor):
241
+ """Input is expected to be of size [bsz x seqlen]."""
242
+ if input.size(1) > self.max_positions:
243
+ raise ValueError(
244
+ f"Sequence length {input.size(1)} above maximum "
245
+ f" sequence length of {self.max_positions}"
246
+ )
247
+ mask = input.ne(self.padding_idx).int()
248
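+ # cumulative count of non-pad tokens gives 1-based positions; adding padding_idx offsets them so that pads themselves map to padding_idx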
+ positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
249
+ return F.embedding(
250
+ positions,
251
+ self.weight,
252
+ self.padding_idx,
253
+ self.max_norm,
254
+ self.norm_type,
255
+ self.scale_grad_by_freq,
256
+ self.sparse,
257
+ )
258
+
259
+
260
+ class SinusoidalPositionalEmbedding(nn.Module):
261
+ def __init__(self, embed_dim, padding_idx, learned=False):
262
+ super().__init__()
263
+ self.embed_dim = embed_dim
264
+ self.padding_idx = padding_idx
265
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
266
+ self.weights = None
267
+
268
+ def forward(self, x):
269
+ bsz, seq_len = x.shape
270
+ max_pos = self.padding_idx + 1 + seq_len
271
+ if self.weights is None or max_pos > self.weights.size(0):
272
+ self.weights = self.get_embedding(max_pos)
273
+ self.weights = self.weights.type_as(self._float_tensor)
274
+
275
+ positions = self.make_positions(x)
276
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
277
+
278
+ def make_positions(self, x):
279
+ mask = x.ne(self.padding_idx)
280
+ range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
281
+ positions = range_buf.expand_as(x)
282
+ return positions * mask.long() + self.padding_idx * (1 - mask.long())
283
+
284
+ def get_embedding(self, num_embeddings):
285
+ half_dim = self.embed_dim // 2
286
+ emb = math.log(10000) / (half_dim - 1)
287
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
288
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
289
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
290
+ if self.embed_dim % 2 == 1:
291
+ # zero pad
292
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
293
+ if self.padding_idx is not None:
294
+ emb[self.padding_idx, :] = 0
295
+ return emb
296
+
297
+
298
+ class RobertaLMHead(nn.Module):
299
+ """Head for masked language modeling."""
300
+
301
+ def __init__(self, embed_dim, output_dim, weight):
302
+ super().__init__()
303
+ self.dense = nn.Linear(embed_dim, embed_dim)
304
+ self.layer_norm = ESM1bLayerNorm(embed_dim)
305
+ self.weight = weight
306
+ self.bias = nn.Parameter(torch.zeros(output_dim))
307
+
308
+ def forward(self, features):
309
+ x = self.dense(features)
310
+ x = gelu(x)
311
+ x = self.layer_norm(x)
312
+ # project back to size of vocabulary with bias
313
+ x = F.linear(x, self.weight) + self.bias
314
+ return x
315
+
316
+
317
+ class ContactPredictionHead(nn.Module):
318
+ """Performs symmetrization, apc, and computes a logistic regression on the output features"""
319
+
320
+ def __init__(
321
+ self,
322
+ in_features: int,
323
+ prepend_bos: bool,
324
+ append_eos: bool,
325
+ bias=True,
326
+ eos_idx: Optional[int] = None,
327
+ ):
328
+ super().__init__()
329
+ self.in_features = in_features
330
+ self.prepend_bos = prepend_bos
331
+ self.append_eos = append_eos
332
+ if append_eos and eos_idx is None:
333
+ raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
334
+ self.eos_idx = eos_idx
335
+ self.regression = nn.Linear(in_features, 1, bias)
336
+ self.activation = nn.Sigmoid()
337
+
338
+ def forward(self, tokens, attentions):
339
+ # remove eos token attentions
340
+ if self.append_eos:
341
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
342
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
343
+ attentions = attentions * eos_mask[:, None, None, :, :]
344
+ attentions = attentions[..., :-1, :-1]
345
+ # remove cls token attentions
346
+ if self.prepend_bos:
347
+ attentions = attentions[..., 1:, 1:]
348
+ batch_size, layers, heads, seqlen, _ = attentions.size()
349
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
350
+
351
+ # features: B x C x T x T
352
+ attentions = attentions.to(
353
+ self.regression.weight.device
354
+ ) # attentions always float32, may need to convert to float16
355
+ attentions = apc(symmetrize(attentions))
356
+ attentions = attentions.permute(0, 2, 3, 1)
357
+ #print(f'----------{attentions.dtype, attentions.float().dtype}----')
358
+ return attentions.sum(dim=-1), self.activation(self.regression(attentions).squeeze(3))
359
+
360
+
361
+ class NormalizedResidualBlock(nn.Module):
362
+ def __init__(
363
+ self,
364
+ layer: nn.Module,
365
+ embedding_dim: int,
366
+ dropout: float = 0.1,
367
+ ):
368
+ super().__init__()
369
+ self.embedding_dim = embedding_dim
370
+
371
+ self.layer = layer
372
+ self.dropout_module = nn.Dropout(
373
+ dropout,
374
+ )
375
+ self.layer_norm = ESM1bLayerNorm(self.embedding_dim)
376
+
377
+ def forward(self, x, *args, **kwargs):
378
+ residual = x
379
+ x = self.layer_norm(x)
380
+ outputs = self.layer(x, *args, **kwargs)
381
+ if isinstance(outputs, tuple):
382
+ x, *out = outputs
383
+ else:
384
+ x = outputs
385
+ out = None
386
+
387
+ x = self.dropout_module(x)
388
+ x = residual + x
389
+
390
+ if out is not None:
391
+ return (x,) + tuple(out)
392
+ else:
393
+ return x
394
+
395
+
396
+ class FeedForwardNetwork(nn.Module):
397
+ def __init__(
398
+ self,
399
+ embedding_dim: int,
400
+ ffn_embedding_dim: int,
401
+ activation_dropout: float = 0.1,
402
+ max_tokens_per_msa: int = 2**14,
403
+ ):
404
+ super().__init__()
405
+ self.embedding_dim = embedding_dim
406
+ self.ffn_embedding_dim = ffn_embedding_dim
407
+ self.max_tokens_per_msa = max_tokens_per_msa
408
+ self.activation_fn = nn.GELU()
409
+ self.activation_dropout_module = nn.Dropout(
410
+ activation_dropout,
411
+ )
412
+ self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
413
+ self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
414
+
415
+ def forward(self, x):
416
+ x = self.activation_fn(self.fc1(x))
417
+ x = self.activation_dropout_module(x)
418
+ x = self.fc2(x)
419
+ return x
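To make the contact-prediction helpers above concrete, a toy sketch of symmetrize and apc (shapes illustrative only):

import torch
from esm.modules import symmetrize, apc

attn = torch.rand(1, 4, 8, 8)  # B x (layers * heads) x T x T attention maps
sym = symmetrize(attn)         # equal to its own transpose over the last two dims
corrected = apc(sym)           # average product correction removes background co-variation
print(corrected.shape)         # torch.Size([1, 4, 8, 8])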
esm/multihead_attention.py ADDED
@@ -0,0 +1,506 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from torch.nn import Parameter
13
+ from esm.rotary_embedding import RotaryEmbedding
14
+
15
+ import uuid
16
+
17
+
18
+ def utils_softmax(x, dim: int, onnx_trace: bool = False):
19
+ if onnx_trace:
20
+ return F.softmax(x.float(), dim=dim)
21
+ else:
22
+ return F.softmax(x, dim=dim, dtype=torch.float32)
23
+
24
+
25
+ class FairseqIncrementalState(object):
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+ self.init_incremental_state()
29
+
30
+ def init_incremental_state(self):
31
+ self._incremental_state_id = str(uuid.uuid4())
32
+
33
+ def _get_full_incremental_state_key(self, key: str) -> str:
34
+ return "{}.{}".format(self._incremental_state_id, key)
35
+
36
+ def get_incremental_state(
37
+ self,
38
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
39
+ key: str,
40
+ ) -> Optional[Dict[str, Optional[Tensor]]]:
41
+ """Helper for getting incremental state for an nn.Module."""
42
+ full_key = self._get_full_incremental_state_key(key)
43
+ if incremental_state is None or full_key not in incremental_state:
44
+ return None
45
+ return incremental_state[full_key]
46
+
47
+ def set_incremental_state(
48
+ self,
49
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
50
+ key: str,
51
+ value: Dict[str, Optional[Tensor]],
52
+ ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
53
+ """Helper for setting incremental state for an nn.Module."""
54
+ if incremental_state is not None:
55
+ full_key = self._get_full_incremental_state_key(key)
56
+ incremental_state[full_key] = value
57
+ return incremental_state
58
+
59
+
60
+ def with_incremental_state(cls):
61
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(
62
+ b for b in cls.__bases__ if b != FairseqIncrementalState
63
+ )
64
+ return cls
65
+
66
+
67
+ @with_incremental_state
68
+ class MultiheadAttention(nn.Module):
69
+ """Multi-headed attention.
70
+ See "Attention Is All You Need" for more details.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ embed_dim,
76
+ num_heads,
77
+ kdim=None,
78
+ vdim=None,
79
+ dropout=0.0,
80
+ bias=True,
81
+ add_bias_kv: bool = False,
82
+ add_zero_attn: bool = False,
83
+ self_attention: bool = False,
84
+ encoder_decoder_attention: bool = False,
85
+ use_rotary_embeddings: bool = False,
86
+ ):
87
+ super().__init__()
88
+ self.embed_dim = embed_dim
89
+ self.kdim = kdim if kdim is not None else embed_dim
90
+ self.vdim = vdim if vdim is not None else embed_dim
91
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
92
+
93
+ self.num_heads = num_heads
94
+ self.dropout = dropout
95
+ self.head_dim = embed_dim // num_heads
96
+ assert (
97
+ self.head_dim * num_heads == self.embed_dim
98
+ ), "embed_dim must be divisible by num_heads"
99
+ self.scaling = self.head_dim**-0.5
100
+
101
+ self.self_attention = self_attention
102
+ self.encoder_decoder_attention = encoder_decoder_attention
103
+
104
+ assert not self.self_attention or self.qkv_same_dim, (
105
+ "Self-attention requires query, key and " "value to be of the same size"
106
+ )
107
+
108
+ self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
109
+ self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
110
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
111
+
112
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
113
+
114
+ if add_bias_kv:
115
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
116
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
117
+ else:
118
+ self.bias_k = self.bias_v = None
119
+
120
+ self.add_zero_attn = add_zero_attn
121
+
122
+ self.reset_parameters()
123
+
124
+ self.onnx_trace = False
125
+ self.rot_emb = None
126
+ if use_rotary_embeddings:
127
+ self.rot_emb = RotaryEmbedding(dim=self.head_dim)
128
+
129
+ self.enable_torch_version = False
130
+ if hasattr(F, "multi_head_attention_forward"):
131
+ self.enable_torch_version = True
132
+ else:
133
+ self.enable_torch_version = False
134
+
135
+ def prepare_for_onnx_export_(self):
136
+ self.onnx_trace = True
137
+
138
+ def reset_parameters(self):
139
+ if self.qkv_same_dim:
140
+ # Empirically observed the convergence to be much better with
141
+ # the scaled initialization
142
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
143
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
144
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
145
+ else:
146
+ nn.init.xavier_uniform_(self.k_proj.weight)
147
+ nn.init.xavier_uniform_(self.v_proj.weight)
148
+ nn.init.xavier_uniform_(self.q_proj.weight)
149
+
150
+ nn.init.xavier_uniform_(self.out_proj.weight)
151
+ if self.out_proj.bias is not None:
152
+ nn.init.constant_(self.out_proj.bias, 0.0)
153
+ if self.bias_k is not None:
154
+ nn.init.xavier_normal_(self.bias_k)
155
+ if self.bias_v is not None:
156
+ nn.init.xavier_normal_(self.bias_v)
157
+
158
+ def forward(
159
+ self,
160
+ query,
161
+ key: Optional[Tensor],
162
+ value: Optional[Tensor],
163
+ key_padding_mask: Optional[Tensor] = None,
164
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
165
+ need_weights: bool = True,
166
+ static_kv: bool = False,
167
+ attn_mask: Optional[Tensor] = None,
168
+ before_softmax: bool = False,
169
+ need_head_weights: bool = False,
170
+ ) -> Tuple[Tensor, Optional[Tensor]]:
171
+ """Input shape: Time x Batch x Channel
172
+ Args:
173
+ key_padding_mask (ByteTensor, optional): mask to exclude
174
+ keys that are pads, of shape `(batch, src_len)`, where
175
+ padding elements are indicated by 1s.
176
+ need_weights (bool, optional): return the attention weights,
177
+ averaged over heads (default: False).
178
+ attn_mask (ByteTensor, optional): typically used to
179
+ implement causal attention, where the mask prevents the
180
+ attention from looking forward in time (default: None).
181
+ before_softmax (bool, optional): return the raw attention
182
+ weights and values before the attention softmax.
183
+ need_head_weights (bool, optional): return the attention
184
+ weights for each head. Implies *need_weights*. Default:
185
+ return the average attention weights over all heads.
186
+ """
187
+ if need_head_weights:
188
+ need_weights = True
189
+
190
+ tgt_len, bsz, embed_dim = query.size()
191
+ assert embed_dim == self.embed_dim
192
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
193
+
194
+ if (
195
+ not self.rot_emb
196
+ and self.enable_torch_version
197
+ and not self.onnx_trace
198
+ and incremental_state is None
199
+ and not static_kv
200
+ # A workaround for quantization to work. Otherwise JIT compilation
201
+ # treats bias in linear module as method.
202
+ and not torch.jit.is_scripting()
203
+ and not need_head_weights
204
+ ):
205
+ assert key is not None and value is not None
206
+ return F.multi_head_attention_forward(
207
+ query,
208
+ key,
209
+ value,
210
+ self.embed_dim,
211
+ self.num_heads,
212
+ torch.empty([0]),
213
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
214
+ self.bias_k,
215
+ self.bias_v,
216
+ self.add_zero_attn,
217
+ self.dropout,
218
+ self.out_proj.weight,
219
+ self.out_proj.bias,
220
+ self.training,
221
+ key_padding_mask,
222
+ need_weights,
223
+ attn_mask,
224
+ use_separate_proj_weight=True,
225
+ q_proj_weight=self.q_proj.weight,
226
+ k_proj_weight=self.k_proj.weight,
227
+ v_proj_weight=self.v_proj.weight,
228
+ )
229
+ if incremental_state is not None:
230
+ saved_state = self._get_input_buffer(incremental_state)
231
+ if saved_state is not None and "prev_key" in saved_state:
232
+ # previous time steps are cached - no need to recompute
233
+ # key and value if they are static
234
+ if static_kv:
235
+ assert self.encoder_decoder_attention and not self.self_attention
236
+ key = value = None
237
+ else:
238
+ saved_state = None
239
+
240
+ if self.self_attention:
241
+ q = self.q_proj(query)
242
+ k = self.k_proj(query)
243
+ v = self.v_proj(query)
244
+ elif self.encoder_decoder_attention:
245
+ # encoder-decoder attention
246
+ q = self.q_proj(query)
247
+ if key is None:
248
+ assert value is None
249
+ k = v = None
250
+ else:
251
+ k = self.k_proj(key)
252
+ v = self.v_proj(key)
253
+
254
+ else:
255
+ assert key is not None and value is not None
256
+ q = self.q_proj(query)
257
+ k = self.k_proj(key)
258
+ v = self.v_proj(value)
259
+ q *= self.scaling
260
+
261
+ if self.bias_k is not None:
262
+ assert self.bias_v is not None
263
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
264
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
265
+ if attn_mask is not None:
266
+ attn_mask = torch.cat(
267
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
268
+ )
269
+ if key_padding_mask is not None:
270
+ key_padding_mask = torch.cat(
271
+ [
272
+ key_padding_mask,
273
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
274
+ ],
275
+ dim=1,
276
+ )
277
+
278
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
279
+ if k is not None:
280
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
281
+ if v is not None:
282
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
283
+
284
+ if saved_state is not None:
285
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
286
+ if "prev_key" in saved_state:
287
+ _prev_key = saved_state["prev_key"]
288
+ assert _prev_key is not None
289
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
290
+ if static_kv:
291
+ k = prev_key
292
+ else:
293
+ assert k is not None
294
+ k = torch.cat([prev_key, k], dim=1)
295
+ if "prev_value" in saved_state:
296
+ _prev_value = saved_state["prev_value"]
297
+ assert _prev_value is not None
298
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
299
+ if static_kv:
300
+ v = prev_value
301
+ else:
302
+ assert v is not None
303
+ v = torch.cat([prev_value, v], dim=1)
304
+ prev_key_padding_mask: Optional[Tensor] = None
305
+ if "prev_key_padding_mask" in saved_state:
306
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
307
+ assert k is not None and v is not None
308
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
309
+ key_padding_mask=key_padding_mask,
310
+ prev_key_padding_mask=prev_key_padding_mask,
311
+ batch_size=bsz,
312
+ src_len=k.size(1),
313
+ static_kv=static_kv,
314
+ )
315
+
316
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
317
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
318
+ saved_state["prev_key_padding_mask"] = key_padding_mask
319
+ # In this branch incremental_state is never None
320
+ assert incremental_state is not None
321
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
322
+ assert k is not None
323
+ src_len = k.size(1)
324
+
325
+ # This is part of a workaround to get around fork/join parallelism
326
+ # not supporting Optional types.
327
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
328
+ key_padding_mask = None
329
+
330
+ if key_padding_mask is not None:
331
+ assert key_padding_mask.size(0) == bsz
332
+ assert key_padding_mask.size(1) == src_len
333
+
334
+ if self.add_zero_attn:
335
+ assert v is not None
336
+ src_len += 1
337
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
338
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
339
+ if attn_mask is not None:
340
+ attn_mask = torch.cat(
341
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
342
+ )
343
+ if key_padding_mask is not None:
344
+ key_padding_mask = torch.cat(
345
+ [
346
+ key_padding_mask,
347
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
348
+ ],
349
+ dim=1,
350
+ )
351
+
352
+ if self.rot_emb:
353
+ q, k = self.rot_emb(q, k)
354
+
355
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
356
+ attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
357
+
358
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
359
+
360
+ if attn_mask is not None:
361
+ attn_mask = attn_mask.unsqueeze(0)
362
+ if self.onnx_trace:
363
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
364
+ attn_weights += attn_mask
365
+
366
+ if key_padding_mask is not None:
367
+ # don't attend to padding symbols
368
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
369
+ attn_weights = attn_weights.masked_fill(
370
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
371
+ )
372
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
373
+
374
+ if before_softmax:
375
+ return attn_weights, v
376
+
377
+ attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
378
+ attn_weights = attn_weights_float.type_as(attn_weights)
379
+ attn_probs = F.dropout(
380
+ attn_weights_float.type_as(attn_weights),
381
+ p=self.dropout,
382
+ training=self.training,
383
+ )
384
+ assert v is not None
385
+ attn = torch.bmm(attn_probs, v)
386
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
387
+ if self.onnx_trace and attn.size(1) == 1:
388
+ # when ONNX tracing a single decoder step (sequence length == 1)
389
+ # the transpose is a no-op copy before view, thus unnecessary
390
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
391
+ else:
392
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
393
+ attn = self.out_proj(attn)
394
+ attn_weights: Optional[Tensor] = None
395
+ if need_weights:
396
+ attn_weights = attn_weights_float.view(
397
+ bsz, self.num_heads, tgt_len, src_len
398
+ ).type_as(attn).transpose(1, 0)
399
+ if not need_head_weights:
400
+ # average attention weights over heads
401
+ attn_weights = attn_weights.mean(dim=0)
402
+
403
+ return attn, attn_weights
404
+
405
+ @staticmethod
406
+ def _append_prev_key_padding_mask(
407
+ key_padding_mask: Optional[Tensor],
408
+ prev_key_padding_mask: Optional[Tensor],
409
+ batch_size: int,
410
+ src_len: int,
411
+ static_kv: bool,
412
+ ) -> Optional[Tensor]:
413
+ # saved key padding masks have shape (bsz, seq_len)
414
+ if prev_key_padding_mask is not None and static_kv:
415
+ new_key_padding_mask = prev_key_padding_mask
416
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
417
+ new_key_padding_mask = torch.cat(
418
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
419
+ )
420
+ # During incremental decoding, as the padding token enters and
421
+ # leaves the frame, there will be a time when prev or current
422
+ # is None
423
+ elif prev_key_padding_mask is not None:
424
+ filler = torch.zeros(
425
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
426
+ device=prev_key_padding_mask.device,
427
+ )
428
+ new_key_padding_mask = torch.cat(
429
+ [prev_key_padding_mask.float(), filler.float()], dim=1
430
+ )
431
+ elif key_padding_mask is not None:
432
+ filler = torch.zeros(
433
+ (batch_size, src_len - key_padding_mask.size(1)),
434
+ device=key_padding_mask.device,
435
+ )
436
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
437
+ else:
438
+ new_key_padding_mask = prev_key_padding_mask
439
+ return new_key_padding_mask
440
+
441
+ @torch.jit.export
442
+ def reorder_incremental_state(
443
+ self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
444
+ ):
445
+ """Reorder buffered internal state (for incremental generation)."""
446
+ input_buffer = self._get_input_buffer(incremental_state)
447
+ if input_buffer is not None:
448
+ for k in input_buffer.keys():
449
+ input_buffer_k = input_buffer[k]
450
+ if input_buffer_k is not None:
451
+ if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
452
+ 0
453
+ ):
454
+ break
455
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
456
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
457
+ return incremental_state
458
+
459
+ def _get_input_buffer(
460
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
461
+ ) -> Dict[str, Optional[Tensor]]:
462
+ result = self.get_incremental_state(incremental_state, "attn_state")
463
+ if result is not None:
464
+ return result
465
+ else:
466
+ empty_result: Dict[str, Optional[Tensor]] = {}
467
+ return empty_result
468
+
469
+ def _set_input_buffer(
470
+ self,
471
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
472
+ buffer: Dict[str, Optional[Tensor]],
473
+ ):
474
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
475
+
476
+ def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
477
+ return attn_weights
478
+
479
+ def upgrade_state_dict_named(self, state_dict, name):
480
+ prefix = name + "." if name != "" else ""
481
+ items_to_add = {}
482
+ keys_to_remove = []
483
+ for k in state_dict.keys():
484
+ if k.endswith(prefix + "in_proj_weight"):
485
+ # in_proj_weight used to be q + k + v with same dimensions
486
+ dim = int(state_dict[k].shape[0] / 3)
487
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
488
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
489
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
490
+
491
+ keys_to_remove.append(k)
492
+
493
+ k_bias = prefix + "in_proj_bias"
494
+ if k_bias in state_dict.keys():
495
+ dim = int(state_dict[k].shape[0] / 3)
496
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
497
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
498
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
499
+
500
+ keys_to_remove.append(prefix + "in_proj_bias")
501
+
502
+ for k in keys_to_remove:
503
+ del state_dict[k]
504
+
505
+ for key, value in items_to_add.items():
506
+ state_dict[key] = value
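For orientation, a short usage sketch of the module defined above; it is not part of the file, and the sizes are arbitrary. Inputs follow the "Time x Batch x Channel" layout documented in forward(), and with need_head_weights=True the per-head attention maps are returned instead of the head average.

import torch

mha = MultiheadAttention(
    embed_dim=64, num_heads=4, self_attention=True, use_rotary_embeddings=True
)
x = torch.randn(8, 2, 64)                  # (tgt_len, bsz, embed_dim)
pad = torch.zeros(2, 8, dtype=torch.bool)  # (bsz, src_len); True would mark padding
out, attn = mha(x, x, x, key_padding_mask=pad, need_head_weights=True)
assert out.shape == (8, 2, 64)
assert attn.shape == (4, 2, 8, 8)          # (num_heads, bsz, tgt_len, src_len)

Note that the rotary path (and need_head_weights) bypasses the fused F.multi_head_attention_forward fast path, so the pure-PyTorch branch above is the one exercised here.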
esm/pretrained.py ADDED
@@ -0,0 +1,378 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+ import urllib
+ import warnings
+ from argparse import Namespace
+ from pathlib import Path
+
+ import torch
+
+ import esm
+ from esm.model.esm2 import ESM2
+
+
+ def _has_regression_weights(model_name):
+     """Return whether we expect / require regression weights;
+     Right now that is all models except ESM-1v and ESM-IF"""
+     return not ("esm1v" in model_name or "esm_if" in model_name)
+
+
+ def load_model_and_alphabet(model_name):
+     if model_name.endswith(".pt"):  # treat as filepath
+         return load_model_and_alphabet_local(model_name)
+     else:
+         return load_model_and_alphabet_hub(model_name)
+
+
+ def load_hub_workaround(url):
+     try:
+         data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
+     except RuntimeError:
+         # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
+         fn = Path(url).name
+         data = torch.load(
+             f"{torch.hub.get_dir()}/checkpoints/{fn}",
+             map_location="cpu",
+         )
+     except urllib.error.HTTPError as e:
+         raise Exception(f"Could not load {url}, check if you specified a correct model name?")
+     return data
+
+
+ def load_regression_hub(model_name):
+     url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
+     regression_data = load_hub_workaround(url)
+     return regression_data
+
+
+ def _download_model_and_regression_data(model_name):
+     url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
+     model_data = load_hub_workaround(url)
+     if _has_regression_weights(model_name):
+         regression_data = load_regression_hub(model_name)
+     else:
+         regression_data = None
+     return model_data, regression_data
+
+
+ def load_model_and_alphabet_hub(model_name):
+     model_data, regression_data = _download_model_and_regression_data(model_name)
+     return load_model_and_alphabet_core(model_name, model_data, regression_data)
+
+
+ def load_model_and_alphabet_local(model_location):
+     """Load from local path. The regression weights need to be co-located"""
+     model_location = Path(model_location)
+     model_data = torch.load(str(model_location), map_location="cpu")
+     model_name = model_location.stem
+     if _has_regression_weights(model_name):
+         regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt"
+         regression_data = torch.load(regression_location, map_location="cpu")
+     else:
+         regression_data = None
+     return load_model_and_alphabet_core(model_name, model_data, regression_data)
+
+
+ def has_emb_layer_norm_before(model_state):
+     """Determine whether layer norm needs to be applied before the encoder"""
+     return any(k.startswith("emb_layer_norm_before") for k, param in model_state.items())
+
+
+ def _load_model_and_alphabet_core_v1(model_data):
+     import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here
+
+     alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)
+
+     if model_data["args"].arch == "roberta_large":
+         # upgrade state dict
+         pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
+         prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
+         prs2 = lambda s: "".join(
+             s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
+         )
+         model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
+         model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()}
+         model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
+         model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
+         model_type = esm.ProteinBertModel
+
+     elif model_data["args"].arch == "protein_bert_base":
+
+         # upgrade state dict
+         pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
+         prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
+         model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
+         model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
+         model_type = esm.ProteinBertModel
+     elif model_data["args"].arch == "msa_transformer":
+
+         # upgrade state dict
+         pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
+         prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
+         prs2 = lambda s: "".join(
+             s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
+         )
+         prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
+         model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
+         model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
+         if model_args.get("embed_positions_msa", False):
+             emb_dim = model_state["msa_position_embedding"].size(-1)
+             model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1
+
+         model_type = esm.MSATransformer
+
+     elif "invariant_gvp" in model_data["args"].arch:
+         import esm.inverse_folding
+
+         model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel
+         model_args = vars(model_data["args"])  # convert Namespace -> dict
+
+         def update_name(s):
+             # Map the module names in checkpoints trained with internal code to
+             # the updated module names in open source code
+             s = s.replace("W_v", "embed_graph.embed_node")
+             s = s.replace("W_e", "embed_graph.embed_edge")
+             s = s.replace("embed_scores.0", "embed_confidence")
+             s = s.replace("embed_score.", "embed_graph.embed_confidence.")
+             s = s.replace("seq_logits_projection.", "")
+             s = s.replace("embed_ingraham_features", "embed_dihedrals")
+             s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output")
+             s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features")
+             return s
+
+         model_state = {
+             update_name(sname): svalue
+             for sname, svalue in model_data["model"].items()
+             if "version" not in sname
+         }
+
+     else:
+         raise ValueError("Unknown architecture selected")
+
+     model = model_type(
+         Namespace(**model_args),
+         alphabet,
+     )
+
+     return model, alphabet, model_state
+
+
+ def _load_model_and_alphabet_core_v2(model_data):
+     def upgrade_state_dict(state_dict):
+         """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
+         prefixes = ["encoder.sentence_encoder.", "encoder."]
+         pattern = re.compile("^" + "|".join(prefixes))
+         state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
+         return state_dict
+
+     cfg = model_data["cfg"]["model"]
+     state_dict = model_data["model"]
+     state_dict = upgrade_state_dict(state_dict)
+     alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
+     model = ESM2(
+         num_layers=cfg.encoder_layers,
+         embed_dim=cfg.encoder_embed_dim,
+         attention_heads=cfg.encoder_attention_heads,
+         alphabet=alphabet,
+         token_dropout=cfg.token_dropout,
+     )
+     return model, alphabet, state_dict
+
+
+ def load_model_and_alphabet_core(model_name, model_data, regression_data=None):
+     if regression_data is not None:
+         model_data["model"].update(regression_data["model"])
+
+     if model_name.startswith("esm2"):
+         model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data)
+     else:
+         model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data)
+
+     expected_keys = set(model.state_dict().keys())
+     found_keys = set(model_state.keys())
+
+     if regression_data is None:
+         expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
+         error_msgs = []
+         missing = (expected_keys - found_keys) - expected_missing
+         if missing:
+             error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
+         unexpected = found_keys - expected_keys
+         if unexpected:
+             error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")
+
+         if error_msgs:
+             raise RuntimeError(
+                 "Error(s) in loading state_dict for {}:\n\t{}".format(
+                     model.__class__.__name__, "\n\t".join(error_msgs)
+                 )
+             )
+         if expected_missing - found_keys:
+             warnings.warn(
+                 "Regression weights not found, predicting contacts will not produce correct results."
+             )
+
+     model.load_state_dict(model_state, strict=regression_data is not None)
+
+     return model, alphabet
+
+
+ def esm1_t34_670M_UR50S():
+     """34 layer transformer model with 670M params, trained on Uniref50 Sparse.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1_t34_670M_UR50S")
+
+
+ def esm1_t34_670M_UR50D():
+     """34 layer transformer model with 670M params, trained on Uniref50 Dense.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1_t34_670M_UR50D")
+
+
+ def esm1_t34_670M_UR100():
+     """34 layer transformer model with 670M params, trained on Uniref100.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1_t34_670M_UR100")
+
+
+ def esm1_t12_85M_UR50S():
+     """12 layer transformer model with 85M params, trained on Uniref50 Sparse.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1_t12_85M_UR50S")
+
+
+ def esm1_t6_43M_UR50S():
+     """6 layer transformer model with 43M params, trained on Uniref50 Sparse.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1_t6_43M_UR50S")
+
+
+ def esm1b_t33_650M_UR50S():
+     """33 layer transformer model with 650M params, trained on Uniref50 Sparse.
+     This is our best performing model, which will be described in a future publication.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S")
+
+
+ def esm_msa1_t12_100M_UR50S():
+     warnings.warn(
+         "This model had a minor bug in the positional embeddings, "
+         "please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()",
+     )
+     return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S")
+
+
+ def esm_msa1b_t12_100M_UR50S():
+     return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S")
+
+
+ def esm1v_t33_650M_UR90S():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 1 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")
+
+
+ def esm1v_t33_650M_UR90S_1():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 1 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")
+
+
+ def esm1v_t33_650M_UR90S_2():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 2 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2")
+
+
+ def esm1v_t33_650M_UR90S_3():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 3 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3")
+
+
+ def esm1v_t33_650M_UR90S_4():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 4 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4")
+
+
+ def esm1v_t33_650M_UR90S_5():
+     """33 layer transformer model with 650M params, trained on Uniref90.
+     This is model 5 of a 5 model ensemble.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5")
+
+
+ def esm_if1_gvp4_t16_142M_UR50():
+     """Inverse folding model with 142M params, with 4 GVP-GNN layers, 8
+     Transformer encoder layers, and 8 Transformer decoder layers, trained on
+     CATH structures and 12 million alphafold2 predicted structures from UniRef50
+     sequences.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50")
+
+
+ def esm2_t6_8M_UR50D():
+     """6 layer ESM-2 model with 8M params, trained on UniRef50.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t6_8M_UR50D")
+
+
+ def esm2_t12_35M_UR50D():
+     """12 layer ESM-2 model with 35M params, trained on UniRef50.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t12_35M_UR50D")
+
+
+ def esm2_t30_150M_UR50D():
+     """30 layer ESM-2 model with 150M params, trained on UniRef50.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t30_150M_UR50D")
+
+
+ def esm2_t33_650M_UR50D():
+     """33 layer ESM-2 model with 650M params, trained on UniRef50.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t33_650M_UR50D")
+
+
+ def esm2_t36_3B_UR50D():
+     """36 layer ESM-2 model with 3B params, trained on UniRef50.
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t36_3B_UR50D")
+
+
+ def esm2_t48_15B_UR50D():
+     """48 layer ESM-2 model with 15B params, trained on UniRef50.
+     If you have OOM while loading this model, please refer to README
+     on how to employ FSDP and ZeRO CPU offloading
+     Returns a tuple of (Model, Alphabet).
+     """
+     return load_model_and_alphabet_hub("esm2_t48_15B_UR50D")
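As a usage reference, the loaders above follow the usual fair-esm pattern: a factory (or load_model_and_alphabet) returns the model together with its Alphabet, whose batch converter tokenizes raw sequences. A rough sketch mirroring upstream fair-esm usage; the sequence is an arbitrary toy example, and repr_layers=[33] simply matches the 33-layer ESM-2 650M model requested here:

import torch
import esm

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # downloads or reads the torch.hub cache
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
labels, strs, tokens = batch_converter(data)

with torch.no_grad():
    out = model(tokens, repr_layers=[33])
embeddings = out["representations"][33]  # (batch, tokens incl. BOS/EOS, embed_dim)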
esm/rotary_embedding.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Tuple
+
+ import torch
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin):
+     cos = cos[:, : x.shape[-2], :]
+     sin = sin[:, : x.shape[-2], :]
+
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(torch.nn.Module):
+     """
+     The rotary position embeddings from RoFormer_ (Su et. al).
+     A crucial insight from the method is that the query and keys are
+     transformed by rotation matrices which depend on the relative positions.
+     Other implementations are available in the Rotary Transformer repo_ and in
+     GPT-NeoX_, GPT-NeoX was an inspiration
+     .. _RoFormer: https://arxiv.org/abs/2104.09864
+     .. _repo: https://github.com/ZhuiyiTechnology/roformer
+     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+     .. warning: Please note that this embedding is not registered on purpose, as it is transformative
+         (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
+     """
+
+     def __init__(self, dim: int, *_, **__):
+         super().__init__()
+         # Generate and save the inverse frequency buffer (non trainable)
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+
+         self._seq_len_cached = None
+         self._cos_cached = None
+         self._sin_cached = None
+
+     def _update_cos_sin_tables(self, x, seq_dimension=1):
+         seq_len = x.shape[seq_dimension]
+
+         # Reset the tables if the sequence length has changed,
+         # or if we're on a new device (possibly due to tracing for instance)
+         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+             self._seq_len_cached = seq_len
+             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+             self._cos_cached = emb.cos()[None, :, :]
+             self._sin_cached = emb.sin()[None, :, :]
+
+         return self._cos_cached, self._sin_cached
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+
+         return (
+             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+         )
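A small sketch of how this class is consumed by MultiheadAttention above: q and k arrive shaped (bsz * num_heads, seq_len, head_dim), and the rotation changes per-position phase but not vector norms. The sizes below are arbitrary.

import torch

rot = RotaryEmbedding(dim=16)   # dim should match head_dim
q = torch.randn(6, 10, 16)      # (bsz * num_heads, seq_len, head_dim)
k = torch.randn(6, 10, 16)
q_rot, k_rot = rot(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
# Rotation preserves norms, so only relative phase information
# enters the q-k attention logits.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)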
esm/version.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ version = "1.0.2"
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37cd3362c097a7b9283088d705b1f820dd3628597d9f1cf17d21921e956b117a
+ size 4911833
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ biopython==1.81
+ torch==2.0.1
+ numpy
+ pandas
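The pins above plus the bundled checkpoint are enough for a quick smoke test. A hedged sketch only: the exact contents of model.pt are not documented in this commit, so it is merely inspected here rather than assumed to be a full ESM-style checkpoint suitable for esm.pretrained.load_model_and_alphabet_local.

# Assumes the pinned requirements are installed, e.g. `pip install -r requirements.txt`
import torch

checkpoint = torch.load("model.pt", map_location="cpu")
if isinstance(checkpoint, dict):
    # e.g. 'model' / 'args' / 'cfg' for ESM-style checkpoints, or a plain state_dict
    print(sorted(checkpoint)[:10])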