KeXing
committed on
Commit
·
212111c
1
Parent(s):
5f1e767
Upload 26 files
Browse files- tape/__init__.py +21 -0
- tape/datasets.py +926 -0
- tape/errors.py +3 -0
- tape/main.py +269 -0
- tape/metrics.py +48 -0
- tape/models/__init__.py +28 -0
- tape/models/file_utils.py +353 -0
- tape/models/modeling_autoencoder.py +316 -0
- tape/models/modeling_bert.py +606 -0
- tape/models/modeling_bottleneck.py +150 -0
- tape/models/modeling_lstm.py +335 -0
- tape/models/modeling_onehot.py +155 -0
- tape/models/modeling_resnet.py +389 -0
- tape/models/modeling_trrosetta.py +336 -0
- tape/models/modeling_unirep.py +270 -0
- tape/models/modeling_utils.py +887 -0
- tape/optimization.py +209 -0
- tape/registry.py +254 -0
- tape/tokenizers.py +174 -0
- tape/training.py +659 -0
- tape/utils/__init__.py +23 -0
- tape/utils/_sampler.py +98 -0
- tape/utils/distributed_utils.py +170 -0
- tape/utils/setup_utils.py +135 -0
- tape/utils/utils.py +327 -0
- tape/visualization.py +124 -0
tape/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import datasets  # noqa: F401
from . import metrics  # noqa: F401
from .tokenizers import TAPETokenizer  # noqa: F401
from .models.modeling_utils import ProteinModel
from .models.modeling_utils import ProteinConfig

import sys
from pathlib import Path
import importlib
import pkgutil

__version__ = '0.4'


# Import all the models and configs: every ProteinModel / ProteinConfig
# subclass defined in any module under tape/models is re-exported at the
# package top level (e.g. tape.ProteinBertModel).
for _, module_name, _ in pkgutil.iter_modules([str(Path(__file__).parent / 'models')]):
    imported_module = importlib.import_module('.models.' + module_name, package=__name__)
    # Use a distinct variable for the attribute name: the original reused
    # `name`, shadowing the module name from the enclosing loop.
    for attr_name, attr in imported_module.__dict__.items():
        if isinstance(attr, type) and issubclass(attr, (ProteinModel, ProteinConfig)):
            setattr(sys.modules[__name__], attr_name, attr)
|
tape/datasets.py
ADDED
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
|
2 |
+
from copy import copy
|
3 |
+
from pathlib import Path
|
4 |
+
import pickle as pkl
|
5 |
+
import logging
|
6 |
+
import random
|
7 |
+
|
8 |
+
import lmdb
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from scipy.spatial.distance import pdist, squareform
|
14 |
+
|
15 |
+
from .tokenizers import TAPETokenizer
|
16 |
+
from .registry import registry
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
    """Instantiate the dataset class that matches ``data_file``'s type.

    Dispatches on suffix: ``.lmdb`` -> LMDBDataset, fasta-family suffixes
    -> FastaDataset, ``.json`` -> JSONDataset; a directory falls back to
    NPZDataset. Extra args/kwargs are forwarded to the dataset constructor.

    Raises:
        FileNotFoundError: if ``data_file`` does not exist.
        ValueError: if the suffix is not recognized.
    """
    path = Path(data_file)
    if not path.exists():
        raise FileNotFoundError(path)

    if path.suffix == '.lmdb':
        return LMDBDataset(path, *args, **kwargs)
    if path.suffix in {'.fasta', '.fna', '.ffn', '.faa', '.frn'}:
        return FastaDataset(path, *args, **kwargs)
    if path.suffix == '.json':
        return JSONDataset(path, *args, **kwargs)
    if path.is_dir():
        return NPZDataset(path, *args, **kwargs)
    raise ValueError(f"Unrecognized datafile type {path.suffix}")
35 |
+
|
36 |
+
|
37 |
+
def pad_sequences(sequences: Sequence, constant_value=0, dtype=None) -> np.ndarray:
|
38 |
+
batch_size = len(sequences)
|
39 |
+
shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
|
40 |
+
|
41 |
+
if dtype is None:
|
42 |
+
dtype = sequences[0].dtype
|
43 |
+
|
44 |
+
if isinstance(sequences[0], np.ndarray):
|
45 |
+
array = np.full(shape, constant_value, dtype=dtype)
|
46 |
+
elif isinstance(sequences[0], torch.Tensor):
|
47 |
+
array = torch.full(shape, constant_value, dtype=dtype)
|
48 |
+
|
49 |
+
for arr, seq in zip(array, sequences):
|
50 |
+
arrslice = tuple(slice(dim) for dim in seq.shape)
|
51 |
+
arr[arrslice] = seq
|
52 |
+
|
53 |
+
return array
|
54 |
+
|
55 |
+
|
56 |
+
class FastaDataset(Dataset):
    """Dataset backed by a fasta file.

    The whole file is parsed into memory up front.

    Args:
        data_file (Union[str, Path]): Path to fasta file.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False. (Records are currently always cached; the flag is
            kept for API compatibility with the other datasets.)
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):
        from Bio import SeqIO
        path = Path(data_file)
        if not path.exists():
            raise FileNotFoundError(path)

        # Eagerly parse every record; index access is then O(1).
        self._cache = list(SeqIO.parse(str(path), 'fasta'))
        self._in_memory = in_memory
        self._num_examples = len(self._cache)

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        record = self._cache[index]
        return {
            'id': record.id,
            'primary': str(record.seq),
            'protein_length': len(record.seq),
        }
113 |
+
|
114 |
+
|
115 |
+
class LMDBDataset(Dataset):
    """Dataset backed by an lmdb file of pickled records.

    The environment must hold a pickled integer under key ``b'num_examples'``
    and one pickled record per stringified index.

    Args:
        data_file (Union[str, Path]): Path to lmdb file.
        in_memory (bool, optional): If True, records are cached after first
            access. Default: False.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):
        path = Path(data_file)
        if not path.exists():
            raise FileNotFoundError(path)

        env = lmdb.open(str(path), max_readers=1, readonly=True,
                        lock=False, readahead=False, meminit=False)

        with env.begin(write=False) as txn:
            num_examples = pkl.loads(txn.get(b'num_examples'))

        self._env = env
        self._in_memory = in_memory
        self._num_examples = num_examples
        if in_memory:
            # Lazily-filled cache: one slot per example.
            self._cache = [None] * num_examples

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        if self._in_memory and self._cache[index] is not None:
            return self._cache[index]

        with self._env.begin(write=False) as txn:
            item = pkl.loads(txn.get(str(index).encode()))
        if 'id' not in item:
            item['id'] = str(index)
        if self._in_memory:
            self._cache[index] = item
        return item
162 |
+
|
163 |
+
|
164 |
+
class JSONDataset(Dataset):
    """Dataset backed by a JSON file holding a serialized list of records,
    where each record is a dictionary.

    Args:
        data_file (Union[str, Path]): Path to json file.
        in_memory (bool): Dummy variable to match API of other datasets.
    """

    def __init__(self, data_file: Union[str, Path], in_memory: bool = True):
        import json
        path = Path(data_file)
        if not path.exists():
            raise FileNotFoundError(path)

        records = json.loads(path.read_text())
        if not isinstance(records, list):
            raise TypeError(f"TAPE JSONDataset requires a json serialized list, "
                            f"received {type(records)}")
        self._records = records
        self._num_examples = len(records)

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        item = self._records[index]
        if not isinstance(item, dict):
            raise TypeError(f"Expected dataset to contain a list of dictionary "
                            f"records, received record of type {type(item)}")
        # Fall back to the index as an id when the record has none.
        item.setdefault('id', str(index))
        return item
200 |
+
|
201 |
+
|
202 |
+
class NPZDataset(Dataset):
    """Dataset backed by a directory of .npz files (one record per file).

    Args:
        data_file (Union[str, Path]): Path to directory of npz files.
        in_memory (bool): Dummy variable to match API of other datasets.
        split_files (Optional[Collection[str]]): If given, restrict the
            dataset to exactly these file names within the directory.

    Raises:
        FileNotFoundError: if the directory or any requested split file is
            missing, or if no .npz files are found.
        NotADirectoryError: if ``data_file`` is not a directory.
        ValueError: if ``split_files`` is empty.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = True,
                 split_files: Optional[Collection[str]] = None):
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)
        if not data_file.is_dir():
            raise NotADirectoryError(data_file)

        file_glob = data_file.glob('*.npz')
        if split_files is None:
            file_list = list(file_glob)
        else:
            split_files = set(split_files)
            if len(split_files) == 0:
                raise ValueError("Passed an empty split file set")

            file_list = [f for f in file_glob if f.name in split_files]
            if len(file_list) != len(split_files):
                num_missing = len(split_files) - len(file_list)
                raise FileNotFoundError(
                    f"{num_missing} specified split files not found in directory")

        if len(file_list) == 0:
            raise FileNotFoundError(f"No .npz files found in {data_file}")

        self._file_list = file_list

    def __len__(self) -> int:
        return len(self._file_list)

    def __getitem__(self, index: int):
        if not 0 <= index < len(self):
            raise IndexError(index)

        # np.load on an .npz returns a mapping of array names -> arrays;
        # materialize a plain dict so it can be mutated below. (The original
        # then checked `isinstance(item, dict)` — unreachable, removed.)
        item = dict(np.load(self._file_list[index]))
        if 'id' not in item:
            # Use the file name (without suffix) as the record id.
            item['id'] = self._file_list[index].stem
        return item
251 |
+
|
252 |
+
|
253 |
+
@registry.register_task('embed')
class EmbedDataset(Dataset):
    """Dataset used to featurize arbitrary sequence files for embedding.

    Args:
        data_file (Union[str, Path]): Any file type that dataset_factory
            understands; records must provide 'id' and 'primary' fields.
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
            Default: 'iupac'.
        in_memory (bool): Unused here; kept for API compatibility.
        convert_tokens_to_ids (bool): Unused here; kept for API compatibility.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False,
                 convert_tokens_to_ids: bool = True):
        super().__init__()

        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = dataset_factory(data_file)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        record = self.data[index]
        token_ids = self.tokenizer.encode(record['primary'])
        return record['id'], token_ids, np.ones_like(token_ids)

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        ids, tokens, input_mask = zip(*batch)
        return {'ids': list(ids),
                'input_ids': torch.from_numpy(pad_sequences(tokens)),
                'input_mask': torch.from_numpy(pad_sequences(input_mask))}  # type: ignore
283 |
+
|
284 |
+
|
285 |
+
@registry.register_task('masked_language_modeling')
class MaskedLanguageModelingDataset(Dataset):
    """Creates the Masked Language Modeling Pfam Dataset
    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
        tokenizer (Union[str, TAPETokenizer]): Tokenizer (or vocab name) used
            to encode sequences. Default: 'iupac'.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__()
        if split not in ('train', 'valid', 'holdout'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid', 'holdout']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'pfam/pfam_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        tokens = self.tokenizer.tokenize(item['primary'])
        tokens = self.tokenizer.add_special_tokens(tokens)
        masked_tokens, labels = self._apply_bert_mask(tokens)
        # Convert once (the original computed masked_token_ids twice).
        masked_token_ids = np.array(
            self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
        input_mask = np.ones_like(masked_token_ids)

        return masked_token_ids, input_mask, labels, item['clan'], item['family']

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        input_ids, input_mask, lm_label_ids, clan, family = tuple(zip(*batch))

        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # ignore_index is -1
        lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))
        # Computed but not returned; kept for parity with the original.
        clan = torch.LongTensor(clan)  # type: ignore
        family = torch.LongTensor(family)  # type: ignore

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': lm_label_ids}

    def _apply_bert_mask(self, tokens: List[str]) -> Tuple[List[str], List[int]]:
        """Apply the BERT 15% masking scheme; labels are -1 (ignored) except
        at masked positions, where they hold the original token id."""
        masked_tokens = copy(tokens)
        labels = np.zeros([len(tokens)], np.int64) - 1

        for i, token in enumerate(tokens):
            # Tokens begin and end with start_token and stop_token, ignore these.
            # Bug fix: the original used `pass`, which fell through and let
            # special tokens be masked despite the comment's stated intent.
            if token in (self.tokenizer.start_token, self.tokenizer.stop_token):
                continue

            prob = random.random()
            if prob < 0.15:
                prob /= 0.15
                labels[i] = self.tokenizer.convert_token_to_id(token)

                if prob < 0.8:
                    # 80% random change to mask token
                    token = self.tokenizer.mask_token
                elif prob < 0.9:
                    # 10% chance to change to random token
                    token = self.tokenizer.convert_id_to_token(
                        random.randint(0, self.tokenizer.vocab_size - 1))
                else:
                    # 10% chance to keep current token
                    pass

                masked_tokens[i] = token

        return masked_tokens, labels
372 |
+
|
373 |
+
|
374 |
+
@registry.register_task('beta_lactamase')
class BetaModelingDataset(MaskedLanguageModelingDataset):
    """Masked-LM dataset over the beta-lactamase fasta corpus.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
        in_memory (bool): Whether to load the full dataset into memory.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        # NOTE(review): super().__init__ also loads the pfam lmdb for this
        # split before self.data is replaced below — wasteful, and it requires
        # the pfam data to exist. Kept to preserve existing behavior.
        super().__init__(data_path, split, tokenizer, in_memory)
        data_path = Path(data_path)
        data_file = f'unilanguage/{split}_combined.fasta'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        input_ids, input_mask, lm_label_ids = tuple(zip(*batch))

        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # ignore_index is -1
        lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': lm_label_ids}

    def __getitem__(self, index):
        item = self.data[index]
        tokens = self.tokenizer.tokenize(item['primary'])
        tokens = self.tokenizer.add_special_tokens(tokens)
        masked_tokens, labels = self._apply_bert_mask(tokens)
        # Convert once (the original computed masked_token_ids twice).
        masked_token_ids = np.array(
            self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
        input_mask = np.ones_like(masked_token_ids)

        return masked_token_ids, input_mask, labels
412 |
+
|
413 |
+
|
414 |
+
@registry.register_task('unilanguage')
class UniModelingDataset(MaskedLanguageModelingDataset):
    """Masked-LM dataset over the PF00144 full-length labeled fasta file.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
        in_memory (bool): Whether to load the full dataset into memory.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        # NOTE(review): super().__init__ also loads the pfam lmdb for this
        # split before self.data is replaced below. Kept to preserve behavior.
        super().__init__(data_path, split, tokenizer, in_memory)
        data_path = Path(data_path)
        # Constant path (original used a needless f-string prefix). The same
        # file is used regardless of split.
        data_file = 'unilanguage/PF00144_full_length_sequences_labeled.fasta'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        input_ids, input_mask, lm_label_ids = tuple(zip(*batch))

        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # ignore_index is -1
        lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': lm_label_ids}

    def __getitem__(self, index):
        item = self.data[index]
        tokens = self.tokenizer.tokenize(item['primary'])
        tokens = self.tokenizer.add_special_tokens(tokens)
        masked_tokens, labels = self._apply_bert_mask(tokens)
        # Convert once (the original computed masked_token_ids twice).
        masked_token_ids = np.array(
            self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
        input_mask = np.ones_like(masked_token_ids)

        return masked_token_ids, input_mask, labels
452 |
+
|
453 |
+
|
454 |
+
@registry.register_task('language_modeling')
class LanguageModelingDataset(Dataset):
    """Creates the Language Modeling Pfam Dataset
    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__()
        if split not in ('train', 'valid', 'holdout'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid', 'holdout']")
        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = dataset_factory(
            Path(data_path) / f'pfam/pfam_{split}.lmdb', in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        entry = self.data[index]
        token_ids = self.tokenizer.encode(entry['primary'])
        return token_ids, np.ones_like(token_ids), entry['clan'], entry['family']

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        input_ids, input_mask, clan, family = tuple(zip(*batch))

        torch_inputs = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # Targets are the unmasked inputs padded with the loss ignore_index, -1.
        torch_labels = torch.from_numpy(pad_sequences(input_ids, -1))
        # Computed but not returned; kept for parity with the original.
        clan = torch.LongTensor(clan)  # type: ignore
        family = torch.LongTensor(family)  # type: ignore

        return {'input_ids': torch_inputs,
                'input_mask': input_mask,
                'targets': torch_labels}
505 |
+
|
506 |
+
|
507 |
+
@registry.register_task('fluorescence')
class FluorescenceDataset(Dataset):
    """Regression dataset mapping sequences to log-fluorescence values.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
        in_memory (bool): Whether to load the full dataset into memory.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        if split not in ('train', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. "
                             f"Must be one of ['train', 'valid', 'test']")
        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = dataset_factory(
            Path(data_path) / f'fluorescence/fluorescence_{split}.lmdb', in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        entry = self.data[index]
        token_ids = self.tokenizer.encode(entry['primary'])
        mask = np.ones_like(token_ids)
        return token_ids, mask, float(entry['log_fluorescence'][0])

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        seqs, masks, true_values = zip(*batch)
        targets = torch.FloatTensor(true_values)  # type: ignore
        return {'input_ids': torch.from_numpy(pad_sequences(seqs, 0)),
                'input_mask': torch.from_numpy(pad_sequences(masks, 0)),
                'targets': targets.unsqueeze(1)}
546 |
+
|
547 |
+
|
548 |
+
@registry.register_task('stability')
class StabilityDataset(Dataset):
    """Regression dataset mapping sequences to stability scores.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
        in_memory (bool): Whether to load the full dataset into memory.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        if split not in ('train', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. "
                             f"Must be one of ['train', 'valid', 'test']")
        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = dataset_factory(
            Path(data_path) / f'stability/stability_{split}.lmdb', in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        entry = self.data[index]
        token_ids = self.tokenizer.encode(entry['primary'])
        mask = np.ones_like(token_ids)
        return token_ids, mask, float(entry['stability_score'][0])

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        seqs, masks, true_values = zip(*batch)
        targets = torch.FloatTensor(true_values)  # type: ignore
        return {'input_ids': torch.from_numpy(pad_sequences(seqs, 0)),
                'input_mask': torch.from_numpy(pad_sequences(masks, 0)),
                'targets': targets.unsqueeze(1)}
588 |
+
|
589 |
+
|
590 |
+
@registry.register_task('remote_homology', num_labels=1195)
class RemoteHomologyDataset(Dataset):
    """Fold-classification dataset (remote homology, 1195 fold labels).

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test_fold_holdout',
            'test_family_holdout', 'test_superfamily_holdout'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
        in_memory (bool): Whether to load the full dataset into memory.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        if split not in ('train', 'valid', 'test_fold_holdout',
                         'test_family_holdout', 'test_superfamily_holdout'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'test_fold_holdout', "
                             f"'test_family_holdout', 'test_superfamily_holdout']")
        self.tokenizer = (TAPETokenizer(vocab=tokenizer)
                          if isinstance(tokenizer, str) else tokenizer)
        self.data = dataset_factory(
            Path(data_path) / f'remote_homology/remote_homology_{split}.lmdb',
            in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        entry = self.data[index]
        token_ids = self.tokenizer.encode(entry['primary'])
        mask = np.ones_like(token_ids)
        return token_ids, mask, entry['fold_label']

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        seqs, masks, fold_labels = zip(*batch)
        return {'input_ids': torch.from_numpy(pad_sequences(seqs, 0)),
                'input_mask': torch.from_numpy(pad_sequences(masks, 0)),
                'targets': torch.LongTensor(fold_labels)}  # type: ignore
630 |
+
|
631 |
+
|
632 |
+
@registry.register_task('contact_prediction')
class ProteinnetDataset(Dataset):
    """ProteinNet dataset for residue-residue contact prediction.

    Records provide a 'primary' sequence, a 'tertiary' coordinate array, and
    a 'valid_mask' marking residues usable for supervision.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'train_unfiltered', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer or vocab name.
            Default: 'iupac'.
        in_memory (bool): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'train_unfiltered', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'train_unfiltered', 'valid', 'test']")

        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'proteinnet/proteinnet_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        protein_length = len(item['primary'])
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)

        valid_mask = item['valid_mask']
        # Binary contact map: residue pairs whose pairwise distance is below
        # 8.0 are contacts (1), else 0.
        # NOTE(review): assumes item['tertiary'] is an (L, d) coordinate array
        # and valid_mask a boolean array of length L — confirm with pipeline.
        contact_map = np.less(squareform(pdist(item['tertiary'])), 8.0).astype(np.int64)

        yind, xind = np.indices(contact_map.shape)
        # A pair is invalid if either residue is not valid...
        invalid_mask = ~(valid_mask[:, None] & valid_mask[None, :])
        # ...or the residues are fewer than 6 sequence positions apart.
        invalid_mask |= np.abs(yind - xind) < 6
        # -1 marks entries to be ignored by the loss (see collate_fn padding).
        contact_map[invalid_mask] = -1

        return token_ids, input_mask, contact_map, protein_length

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        input_ids, input_mask, contact_labels, protein_length = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # Pad labels with -1 so padded positions are ignored like invalid pairs.
        contact_labels = torch.from_numpy(pad_sequences(contact_labels, -1))
        protein_length = torch.LongTensor(protein_length)  # type: ignore

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': contact_labels,
                'protein_length': protein_length}
683 |
+
|
684 |
+
|
685 |
+
@registry.register_task('secondary_structure', num_labels=3)
class SecondaryStructureDataset(Dataset):
    """Secondary-structure (3-class) sequence-labeling dataset.

    Loads sequences and per-residue 'ss3' labels from an LMDB file. Labels
    are padded with -1 at positions that should be ignored by the loss.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'casp12', 'ts115', 'cb513'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'casp12', "
                             f"'ts115', 'cb513']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'secondary_structure/secondary_structure_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        """Return (token_ids, input_mask, labels) for one example."""
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)

        # pad with -1s because of cls/sep tokens
        labels = np.asarray(item['ss3'], np.int64)
        labels = np.pad(labels, (1, 1), 'constant', constant_values=-1)

        return token_ids, input_mask, labels

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into batched tensors (labels padded with -1)."""
        input_ids, input_mask, ss_label = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        ss_label = torch.from_numpy(pad_sequences(ss_label, -1))

        output = {'input_ids': input_ids,
                  'input_mask': input_mask,
                  'targets': ss_label}

        return output
|
731 |
+
|
732 |
+
|
733 |
+
@registry.register_task('trrosetta')
class TRRosettaDataset(Dataset):
    """trRosetta structure-prediction dataset.

    Each example is a multiple sequence alignment (MSA) plus binned
    inter-residue geometries (distance `dist6d` and angles `omega6d`,
    `theta6d`, `phi6d`) loaded from per-protein .npz files.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False,
                 max_seqlen: int = 300):
        if split not in ('train', 'valid'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_path = data_path / 'trrosetta'
        # The split file lists which .npz files belong to this split
        split_files = (data_path / f'{split}_files.txt').read_text().split()
        self.data = NPZDataset(data_path / 'npz', in_memory, split_files=split_files)

        # Bin edges used to discretize the continuous geometries
        self._dist_bins = np.arange(2, 20.1, 0.5)
        self._dihedral_bins = (15 + np.arange(-180, 180, 15)) / 180 * np.pi
        self._planar_bins = (15 + np.arange(0, 180, 15)) / 180 * np.pi
        self._split = split
        self.max_seqlen = max_seqlen  # training-time crop length (<= 0 disables cropping)
        self.msa_cutoff = 0.8  # sequence-identity cutoff used for MSA reweighting
        self.penalty_coeff = 4.5  # ridge penalty for the DCA inverse covariance

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        """Return (msa, dist_bins, omega_bins, theta_bins, phi_bins) for one example."""
        item = self.data[index]

        msa = item['msa']
        dist = item['dist6d']
        omega = item['omega6d']
        theta = item['theta6d']
        phi = item['phi6d']

        if self._split == 'train':
            msa = self._subsample_msa(msa)
        elif self._split == 'valid':
            msa = msa[:20000]  # runs out of memory if msa is way too big
        msa, dist, omega, theta, phi = self._slice_long_sequences(
            msa, dist, omega, theta, phi)

        # dist == 0 marks residue pairs with no data
        mask = dist == 0

        dist_bins = np.digitize(dist, self._dist_bins)
        # +1 shifts angle bins so that bin 0 can be reserved for "no data"
        omega_bins = np.digitize(omega, self._dihedral_bins) + 1
        theta_bins = np.digitize(theta, self._dihedral_bins) + 1
        phi_bins = np.digitize(phi, self._planar_bins) + 1

        dist_bins[mask] = 0
        omega_bins[mask] = 0
        theta_bins[mask] = 0
        phi_bins[mask] = 0

        # Self-pairs on the diagonal are excluded from the distance loss
        dist_bins[np.diag_indices_from(dist_bins)] = -1

        # input_mask = np.ones_like(msa[0])

        return msa, dist_bins, omega_bins, theta_bins, phi_bins

    def _slice_long_sequences(self, msa, dist, omega, theta, phi):
        """Randomly crop examples longer than `max_seqlen` to a contiguous window."""
        seqlen = msa.shape[1]
        if self.max_seqlen > 0 and seqlen > self.max_seqlen:
            start = np.random.randint(seqlen - self.max_seqlen + 1)
            end = start + self.max_seqlen

            msa = msa[:, start:end]
            dist = dist[start:end, start:end]
            omega = omega[start:end, start:end]
            theta = theta[start:end, start:end]
            phi = phi[start:end, start:end]

        return msa, dist, omega, theta, phi

    def _subsample_msa(self, msa):
        """Randomly subsample MSA rows, always keeping the query (row 0)."""
        num_alignments, seqlen = msa.shape

        if num_alignments < 10:
            return msa

        # Draw the sample count log-uniformly, shifted down by 10
        num_sample = int(10 ** np.random.uniform(np.log10(num_alignments)) - 10)

        if num_sample <= 0:
            return msa[0][None, :]
        elif num_sample > 20000:
            num_sample = 20000

        indices = np.random.choice(
            msa.shape[0] - 1, size=num_sample, replace=False) + 1
        indices = np.pad(indices, [1, 0], 'constant')  # add the sequence back in
        return msa[indices]

    def collate_fn(self, batch):
        """Pad a batch and one-hot encode the MSAs (21 symbols)."""
        msa, dist_bins, omega_bins, theta_bins, phi_bins = tuple(zip(*batch))
        # features = pad_sequences([self.featurize(msa_) for msa_ in msa], 0)
        msa1hot = pad_sequences(
            [F.one_hot(torch.LongTensor(msa_), 21) for msa_ in msa], 0, torch.float)
        # input_mask = torch.FloatTensor(pad_sequences(input_mask, 0))
        dist_bins = torch.LongTensor(pad_sequences(dist_bins, -1))
        omega_bins = torch.LongTensor(pad_sequences(omega_bins, 0))
        theta_bins = torch.LongTensor(pad_sequences(theta_bins, 0))
        phi_bins = torch.LongTensor(pad_sequences(phi_bins, 0))

        return {'msa1hot': msa1hot,
                # 'input_mask': input_mask,
                'dist': dist_bins,
                'omega': omega_bins,
                'theta': theta_bins,
                'phi': phi_bins}

    def featurize(self, msa):
        """Build the (channels, L, L) trRosetta input features from an integer MSA."""
        msa = torch.LongTensor(msa)
        msa1hot = F.one_hot(msa, 21).float()

        seqlen = msa1hot.size(1)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        # Tile 1D features along both axes and concatenate with 2D features
        features = torch.cat((
            features_1d.unsqueeze(1).repeat(1, seqlen, 1),
            features_1d.unsqueeze(0).repeat(seqlen, 1, 1),
            features_2d), -1)

        features = features.permute(2, 0, 1)

        return features

    def reweight(self, msa1hot):
        """Down-weight sequences with many near-identical neighbors in the MSA."""
        # Reweight
        seqlen = msa1hot.size(1)
        id_min = seqlen * self.msa_cutoff
        # Pairwise identity counts between all alignment rows
        id_mtx = torch.tensordot(msa1hot, msa1hot, [[1, 2], [1, 2]])
        id_mask = id_mtx > id_min
        weights = 1.0 / id_mask.float().sum(-1)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        """Per-position features: query one-hot (20) + weighted PSSM + entropy (22)."""
        # 1D Features
        seqlen = msa1hot.size(1)
        f1d_seq = msa1hot[0, :, :20]

        # msa2pssm
        beff = weights.sum()
        f_i = (weights[:, None, None] * msa1hot).sum(0) / beff + 1e-9
        h_i = (-f_i * f_i.log()).sum(1, keepdims=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=1)

        f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
        f1d = f1d.view(seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        """Pairwise DCA features (441 coupling channels + 1 APC-corrected contact score)."""
        # 2D Features
        num_alignments = msa1hot.size(0)
        seqlen = msa1hot.size(1)
        num_symbols = 21
        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(seqlen, seqlen, 442, dtype=torch.float)
        else:
            # fast_dca

            # covariance
            x = msa1hot.view(num_alignments, seqlen * num_symbols)
            num_points = weights.sum() - weights.mean().sqrt()
            mean = (x * weights[:, None]).sum(0, keepdims=True) / num_points
            x = (x - mean) * weights[:, None].sqrt()
            cov = torch.matmul(x.transpose(-1, -2), x) / num_points

            # inverse covariance (ridge-regularized)
            reg = torch.eye(seqlen * num_symbols) * self.penalty_coeff / weights.sum().sqrt()
            cov_reg = cov + reg
            inv_cov = torch.inverse(cov_reg)

            x1 = inv_cov.view(seqlen, num_symbols, seqlen, num_symbols)
            x2 = x1.permute(0, 2, 1, 3)
            features = x2.reshape(seqlen, seqlen, num_symbols * num_symbols)

            # Coupling strength per pair (gap symbol excluded), diagonal zeroed
            x3 = (x1[:, :-1, :, :-1] ** 2).sum((1, 3)).sqrt() * (1 - torch.eye(seqlen))
            # Average-product correction (APC)
            apc = x3.sum(0, keepdims=True) * x3.sum(1, keepdims=True) / x3.sum()
            contacts = (x3 - apc) * (1 - torch.eye(seqlen))

            f2d_dca = torch.cat([features, contacts[:, :, None]], axis=2)

        return f2d_dca
|
tape/errors.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
class EarlyStopping(Exception):
    """Signal that training should halt because validation loss stopped improving."""
|
tape/main.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import argparse
|
5 |
+
import warnings
|
6 |
+
import inspect
|
7 |
+
|
8 |
+
|
9 |
+
try:
|
10 |
+
import apex # noqa: F401
|
11 |
+
APEX_FOUND = True
|
12 |
+
except ImportError:
|
13 |
+
APEX_FOUND = False
|
14 |
+
|
15 |
+
from .registry import registry
|
16 |
+
from . import training
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
CallbackList = typing.Sequence[typing.Callable]
|
20 |
+
OutputDict = typing.Dict[str, typing.List[typing.Any]]
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
warnings.filterwarnings( # Ignore pytorch warning about loss gathering
|
25 |
+
'ignore', message='Was asked to gather along dimension 0', module='torch.nn.parallel')
|
26 |
+
|
27 |
+
|
28 |
+
def create_base_parser() -> argparse.ArgumentParser:
    """Build the parent parser with options shared by all tape entry points.

    Created with add_help=False so it can be used via `parents=[...]` by the
    train/eval/embed parsers without a duplicate -h option.
    """
    parser = argparse.ArgumentParser(description='Parent parser for tape functions',
                                     add_help=False)
    parser.add_argument('model_type', help='Base model class to run')
    parser.add_argument('--model_config_file', default=None, type=utils.check_is_file,
                        help='Config file for model')
    parser.add_argument('--vocab_file', default=None,
                        help='Pretrained tokenizer vocab file')
    parser.add_argument('--output_dir', default='./results', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='CPU-only flag')
    parser.add_argument('--seed', default=42, type=int, help='Random seed to use')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank of process in distributed training. '
                             'Set by launch script.')
    parser.add_argument('--tokenizer', choices=['iupac', 'unirep'],
                        default='iupac', help='Tokenizes to use on the amino acid sequences')
    parser.add_argument('--num_workers', default=8, type=int,
                        help='Number of workers to use for multi-threaded data loading')
    parser.add_argument('--log_level', default=logging.INFO,
                        choices=['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR',
                                 logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR],
                        help="log level for the experiment")
    parser.add_argument('--debug', action='store_true', help='Run in debug mode')

    return parser
|
53 |
+
|
54 |
+
|
55 |
+
def create_train_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Extend `base_parser` with training-specific options (task, optimizer, logging)."""
    parser = argparse.ArgumentParser(description='Run Training on the TAPE datasets',
                                     parents=[base_parser])
    parser.add_argument('task', choices=list(registry.task_name_mapping.keys()),
                        help='TAPE Task to train/eval on')
    parser.add_argument('--learning_rate', default=1e-4, type=float,
                        help='Learning rate')
    parser.add_argument('--batch_size', default=1024, type=int,
                        help='Batch size')
    parser.add_argument('--data_dir', default='./data', type=utils.check_is_dir,
                        help='Directory from which to load task data')
    parser.add_argument('--num_train_epochs', default=10, type=int,
                        help='Number of training epochs')
    parser.add_argument('--num_steps_per_epoch', default=-1, type=int,
                        help='Number of steps per epoch')
    parser.add_argument('--num_log_iter', default=20, type=int,
                        help='Number of training steps per log iteration')
    parser.add_argument('--fp16', action='store_true', help='Whether to use fp16 weights')
    parser.add_argument('--warmup_steps', default=10000, type=int,
                        help='Number of learning rate warmup steps')
    parser.add_argument('--gradient_accumulation_steps', default=1, type=int,
                        help='Number of forward passes to make for each backwards pass')
    parser.add_argument('--loss_scale', default=0, type=int,
                        help='Loss scaling. Only used during fp16 training.')
    parser.add_argument('--max_grad_norm', default=1.0, type=float,
                        help='Maximum gradient norm')
    parser.add_argument('--exp_name', default=None, type=str,
                        help='Name to give to this experiment')
    parser.add_argument('--from_pretrained', default=None, type=str,
                        help='Directory containing config and pretrained model weights')
    parser.add_argument('--log_dir', default='./logs', type=str)
    parser.add_argument('--eval_freq', type=int, default=1,
                        help="Frequency of eval pass. A value <= 0 means the eval pass is "
                             "not run")
    parser.add_argument('--save_freq', default='improvement', type=utils.int_or_str,
                        help="How often to save the model during training. Either an integer "
                             "frequency or the string 'improvement'")
    parser.add_argument('--patience', default=-1, type=int,
                        help="How many epochs without improvement to wait before ending "
                             "training")
    parser.add_argument('--resume_from_checkpoint', action='store_true',
                        help="whether to resume training from the checkpoint")
    parser.add_argument('--val_check_frac', default=1.0, type=float,
                        help="Fraction of validation to check")
    return parser
|
100 |
+
|
101 |
+
|
102 |
+
def create_eval_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Extend `base_parser` with evaluation-specific options.

    Unlike training, `from_pretrained` is a required positional here.
    """
    parser = argparse.ArgumentParser(description='Run Eval on the TAPE Datasets',
                                     parents=[base_parser])
    parser.add_argument('task', choices=list(registry.task_name_mapping.keys()),
                        help='TAPE Task to train/eval on')
    parser.add_argument('from_pretrained', type=str,
                        help='Directory containing config and pretrained model weights')
    parser.add_argument('--batch_size', default=1024, type=int,
                        help='Batch size')
    parser.add_argument('--data_dir', default='./data', type=utils.check_is_dir,
                        help='Directory from which to load task data')
    parser.add_argument('--metrics', default=[],
                        help=f'Metrics to run on the result. '
                             f'Choices: {list(registry.metric_name_mapping.keys())}',
                        nargs='*')
    parser.add_argument('--split', default='test', type=str,
                        help='Which split to run on')
    return parser
|
120 |
+
|
121 |
+
|
122 |
+
def create_embed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Extend `base_parser` with options for embedding a file of proteins.

    Also pins the `task` attribute to 'embed' via parser defaults.
    """
    embed_parser = argparse.ArgumentParser(
        description='Embed a set of proteins with a pretrained model',
        parents=[base_parser])
    # Required positionals: input file, output file, pretrained model directory
    embed_parser.add_argument(
        'data_file', type=str,
        help='File containing set of proteins to embed')
    embed_parser.add_argument(
        'out_file', type=str,
        help='Name of output file')
    embed_parser.add_argument(
        'from_pretrained', type=str,
        help='Directory containing config and pretrained model weights')
    embed_parser.add_argument(
        '--batch_size', default=1024, type=int,
        help='Batch size')
    embed_parser.add_argument(
        '--full_sequence_embed', action='store_true',
        help='If true, saves an embedding at every amino acid position '
             'in the sequence. Note that this can take a large amount '
             'of disk space.')
    embed_parser.set_defaults(task='embed')
    return embed_parser
|
140 |
+
|
141 |
+
|
142 |
+
def create_distributed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Extend `base_parser` with options consumed by the distributed launch helper."""
    dist_parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
    # typing.Optional arguments for the launch helper
    dist_parser.add_argument(
        "--nnodes", type=int, default=1,
        help="The number of nodes to use for distributed "
             "training")
    dist_parser.add_argument(
        "--node_rank", type=int, default=0,
        help="The rank of the node for multi-node distributed "
             "training")
    dist_parser.add_argument(
        "--nproc_per_node", type=int, default=1,
        help="The number of processes to launch on each node, "
             "for GPU training, this is recommended to be set "
             "to the number of GPUs in your system so that "
             "each process can be bound to a single GPU.")
    dist_parser.add_argument(
        "--master_addr", default="127.0.0.1", type=str,
        help="Master node (rank 0)'s address, should be either "
             "the IP address or the hostname of node 0, for "
             "single node multi-proc training, the "
             "--master_addr can simply be 127.0.0.1")
    dist_parser.add_argument(
        "--master_port", default=29500, type=int,
        help="Master node (rank 0)'s free port that needs to "
             "be used for communciation during distributed "
             "training")
    return dist_parser
|
166 |
+
|
167 |
+
|
168 |
+
def create_model_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Attach a catch-all --model_args option that swallows all remaining CLI tokens."""
    model_parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
    model_parser.add_argument('--model_args', nargs=argparse.REMAINDER, default=None)
    return model_parser
|
172 |
+
|
173 |
+
def run_train(args: typing.Optional[argparse.Namespace] = None, env=None) -> None:
    """CLI entry point for training.

    Parses arguments when `args` is None, validates them, then forwards
    exactly the keyword arguments that `training.run_train` accepts.

    Raises:
        ValueError: if gradient_accumulation_steps < 1.
        ImportError: if fp16/distributed training is requested without apex.
        RuntimeError: if `training.run_train` expects an argument not present
            in the parsed namespace.
    """
    if env is not None:
        # NOTE(review): rebinds os.environ to a plain mapping passed from the
        # distributed launcher. os.environ.update(env) would preserve the
        # special environ object (putenv semantics) -- confirm before changing.
        os.environ = env

    if args is None:
        base_parser = create_base_parser()
        train_parser = create_train_parser(base_parser)
        model_parser = create_model_parser(train_parser)
        args = model_parser.parse_args()

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: "
            f"{args.gradient_accumulation_steps}, should be >= 1")

    if (args.fp16 or args.local_rank != -1) and not APEX_FOUND:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.")

    # Forward only the arguments that training.run_train actually declares
    arg_dict = vars(args)
    arg_names = inspect.getfullargspec(training.run_train).args

    missing = set(arg_names) - set(arg_dict.keys())
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    train_args = {name: arg_dict[name] for name in arg_names}

    training.run_train(**train_args)
|
202 |
+
|
203 |
+
|
204 |
+
def run_eval(args: typing.Optional[argparse.Namespace] = None) -> typing.Dict[str, float]:
    """CLI entry point for evaluation.

    Parses arguments when `args` is None and forwards the subset that
    `training.run_eval` declares. Returns the metric dictionary it produces.

    Raises:
        ValueError: if no pretrained model is given or distributed eval is requested.
        RuntimeError: if `training.run_eval` expects an argument not present
            in the parsed namespace.
    """
    if args is None:
        base_parser = create_base_parser()
        parser = create_eval_parser(base_parser)
        parser = create_model_parser(parser)
        args = parser.parse_args()

    # from_pretrained can only be None when a Namespace is passed in directly
    if args.from_pretrained is None:
        raise ValueError("Must specify pretrained model")
    if args.local_rank != -1:
        raise ValueError("TAPE does not support distributed validation pass")

    arg_dict = vars(args)
    arg_names = inspect.getfullargspec(training.run_eval).args

    missing = set(arg_names) - set(arg_dict.keys())
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    eval_args = {name: arg_dict[name] for name in arg_names}

    return training.run_eval(**eval_args)
|
225 |
+
|
226 |
+
|
227 |
+
def run_embed(args: typing.Optional[argparse.Namespace] = None) -> None:
    """CLI entry point for embedding proteins with a pretrained model.

    Parses arguments when `args` is None and forwards the subset that
    `training.run_embed` declares.

    Raises:
        ValueError: if no pretrained model is given or distributed mode is requested.
        RuntimeError: if `training.run_embed` expects an argument not present
            in the parsed namespace.
    """
    if args is None:
        base_parser = create_base_parser()
        parser = create_embed_parser(base_parser)
        parser = create_model_parser(parser)
        args = parser.parse_args()
    if args.from_pretrained is None:
        raise ValueError("Must specify pretrained model")
    if args.local_rank != -1:
        raise ValueError("TAPE does not support distributed validation pass")

    arg_dict = vars(args)
    arg_names = inspect.getfullargspec(training.run_embed).args

    missing = set(arg_names) - set(arg_dict.keys())
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    embed_args = {name: arg_dict[name] for name in arg_names}

    training.run_embed(**embed_args)
|
247 |
+
|
248 |
+
|
249 |
+
def run_train_distributed(args: typing.Optional[argparse.Namespace] = None) -> None:
    """Runs distributed training via multiprocessing.

    Parses the distributed + training argument set when `args` is None,
    resolves the experiment name once in the parent process, and launches
    `run_train` in a process group.
    """
    if args is None:
        base_parser = create_base_parser()
        distributed_parser = create_distributed_parser(base_parser)
        distributed_train_parser = create_train_parser(distributed_parser)
        parser = create_model_parser(distributed_train_parser)
        args = parser.parse_args()

    # Define the experiment name here, instead of dealing with barriers and communication
    # when getting the experiment name
    exp_name = utils.get_expname(args.exp_name, args.task, args.model_type)
    args.exp_name = exp_name
    utils.launch_process_group(
        run_train, args, args.nproc_per_node, args.nnodes,
        args.node_rank, args.master_addr, args.master_port)
|
266 |
+
|
267 |
+
|
268 |
+
if __name__ == '__main__':
|
269 |
+
run_train_distributed()
|
tape/metrics.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, Union
|
2 |
+
import numpy as np
|
3 |
+
import scipy.stats
|
4 |
+
|
5 |
+
from .registry import registry
|
6 |
+
|
7 |
+
|
8 |
+
@registry.register_metric('mse')
|
9 |
+
def mean_squared_error(target: Sequence[float],
                       prediction: Sequence[float]) -> float:
    """Return the mean squared error between target and prediction.

    Args:
        target: ground-truth values.
        prediction: predicted values, same length as `target`.

    Returns:
        Mean of squared element-wise differences, as a builtin float
        (the original returned np.float64 despite the `-> float` annotation).
    """
    target_array = np.asarray(target, dtype=float)
    prediction_array = np.asarray(prediction, dtype=float)
    return float(np.mean(np.square(target_array - prediction_array)))
|
14 |
+
|
15 |
+
|
16 |
+
@registry.register_metric('mae')
|
17 |
+
def mean_absolute_error(target: Sequence[float],
                        prediction: Sequence[float]) -> float:
    """Return the mean absolute error between target and prediction."""
    residual = np.asarray(target) - np.asarray(prediction)
    return np.mean(np.abs(residual))
|
22 |
+
|
23 |
+
|
24 |
+
@registry.register_metric('spearmanr')
|
25 |
+
def spearmanr(target: Sequence[float],
              prediction: Sequence[float]) -> float:
    """Return the Spearman rank correlation between target and prediction."""
    result = scipy.stats.spearmanr(np.asarray(target), np.asarray(prediction))
    return result.correlation
|
30 |
+
|
31 |
+
|
32 |
+
@registry.register_metric('accuracy')
|
33 |
+
def accuracy(target: Union[Sequence[int], Sequence[Sequence[int]]],
             prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float:
    """Return classification accuracy from raw scores.

    Two cases, distinguished by the first target element:
      * scalar labels: `prediction[i]` is a vector of class scores; accuracy is
        the fraction where argmax matches the label.
      * per-position labels: `prediction[i]` is a (length, classes) score array;
        positions labeled -1 are ignored (padding / masked positions).

    Fix vs. original: `isinstance(target[0], int)` is False for numpy integer
    scalars (e.g. np.int64), which silently routed scalar-label inputs into
    the sequence branch; numpy integers are now accepted as scalar labels.
    """
    if isinstance(target[0], (int, np.integer)):
        # non-sequence case
        return np.mean(np.asarray(target) == np.asarray(prediction).argmax(-1))
    else:
        correct = 0
        total = 0
        for label, score in zip(target, prediction):
            label_array = np.asarray(label)
            pred_array = np.asarray(score).argmax(-1)
            mask = label_array != -1  # ignore padded positions
            is_correct = label_array[mask] == pred_array[mask]
            correct += is_correct.sum()
            total += is_correct.size
        return correct / total
|
tape/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from .modeling_utils import ProteinConfig # noqa: F401
|
2 |
+
# from .modeling_utils import ProteinModel # noqa: F401
|
3 |
+
|
4 |
+
# from .modeling_bert import ProteinBertModel # noqa: F401
|
5 |
+
# from .modeling_bert import ProteinBertForMaskedLM # noqa: F401
|
6 |
+
# from .modeling_bert import ProteinBertForValuePrediction # noqa: F401
|
7 |
+
# from .modeling_bert import ProteinBertForSequenceClassification # noqa: F401
|
8 |
+
# from .modeling_bert import ProteinBertForSequenceToSequenceClassification # noqa: F401
|
9 |
+
# # TODO: ProteinBertForContactPrediction
|
10 |
+
# from .modeling_resnet import ProteinResNetModel # noqa: F401
|
11 |
+
# from .modeling_resnet import ProteinResNetForMaskedLM # noqa: F401
|
12 |
+
# from .modeling_resnet import ProteinResNetForValuePrediction # noqa: F401
|
13 |
+
# from .modeling_resnet import ProteinResNetForSequenceClassification # noqa: F401
|
14 |
+
# from .modeling_resnet import ProteinResNetForSequenceToSequenceClassification # noqa: F401
|
15 |
+
# # TODO: ProteinResNetForContactPrediction
|
16 |
+
# # TODO: ProteinLSTM*
|
17 |
+
# from .modeling_unirep import UniRepModel # noqa: F401
|
18 |
+
# from .modeling_unirep import UniRepForLM # noqa: F401
|
19 |
+
# from .modeling_unirep import UniRepForValuePrediction # noqa: F401
|
20 |
+
# from .modeling_unirep import UniRepForSequenceClassification # noqa: F401
|
21 |
+
# from .modeling_unirep import UniRepForSequenceToSequenceClassification # noqa: F401
|
22 |
+
# # TODO: UniRepForContactPrediction
|
23 |
+
# # TODO: Bepler*
|
24 |
+
# from .modeling_onehot import OneHotModel # noqa: F401
|
25 |
+
# from .modeling_onehot import OneHotForValuePrediction # noqa: F401
|
26 |
+
# from .modeling_onehot import OneHotForSequenceClassification # noqa: F401
|
27 |
+
# from .modeling_onehot import OneHotForSequenceToSequenceClassification # noqa: F401
|
28 |
+
# TODO: OneHotForContactPrediction
|
tape/models/file_utils.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for working with the local dataset cache.
|
3 |
+
This file is adapted from the huggingface transformers library at
|
4 |
+
https://github.com/huggingface/transformers, which in turn is adapted from the AllenNLP
|
5 |
+
library at https://github.com/allenai/allennlp
|
6 |
+
Copyright by the AllenNLP authors.
|
7 |
+
Note - this file goes to effort to support Python 2, but the rest of this repository does not.
|
8 |
+
"""
|
9 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
10 |
+
|
11 |
+
import typing
|
12 |
+
import sys
|
13 |
+
import json
|
14 |
+
import logging
|
15 |
+
import os
|
16 |
+
import tempfile
|
17 |
+
import fnmatch
|
18 |
+
from io import open
|
19 |
+
|
20 |
+
import boto3
|
21 |
+
import requests
|
22 |
+
from botocore.exceptions import ClientError
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
from contextlib import contextmanager
|
26 |
+
from functools import partial, wraps
|
27 |
+
from hashlib import sha256
|
28 |
+
|
29 |
+
from filelock import FileLock
|
30 |
+
# from tqdm.auto import tqdm
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
33 |
+
|
34 |
+
|
35 |
+
try:
|
36 |
+
from torch.hub import _get_torch_home
|
37 |
+
torch_cache_home = _get_torch_home()
|
38 |
+
except ImportError:
|
39 |
+
torch_cache_home = os.path.expanduser(
|
40 |
+
os.getenv('TORCH_HOME', os.path.join(
|
41 |
+
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
42 |
+
default_cache_path = os.path.join(torch_cache_home, 'protein_models')
|
43 |
+
|
44 |
+
try:
|
45 |
+
from urllib.parse import urlparse
|
46 |
+
except ImportError:
|
47 |
+
from urlparse import urlparse # type: ignore
|
48 |
+
|
49 |
+
try:
|
50 |
+
from pathlib import Path
|
51 |
+
PYTORCH_PRETRAINED_BERT_CACHE: typing.Union[str, Path] = Path(
|
52 |
+
os.getenv('PROTEIN_MODELS_CACHE', os.getenv(
|
53 |
+
'PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
|
54 |
+
except (AttributeError, ImportError):
|
55 |
+
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PROTEIN_MODELS_CACHE',
|
56 |
+
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
57 |
+
default_cache_path))
|
58 |
+
|
59 |
+
PROTEIN_MODELS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
60 |
+
|
61 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
62 |
+
|
63 |
+
|
64 |
+
def get_cache():
    """Return the directory used to cache downloaded pretrained models."""
    return PROTEIN_MODELS_CACHE
|
66 |
+
|
67 |
+
|
68 |
+
def get_etag(url):
    """Return the ETag for *url* (S3 or HTTP), or None if unavailable.

    S3 URLs are resolved through boto3; everything else issues an HTTP
    HEAD request. Any connection error results in None.
    """
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        etag = None
        try:
            response = requests.head(url, allow_redirects=True)
        except EnvironmentError:
            response = None
        if response is not None and response.status_code == 200:
            etag = response.headers.get("ETag")

    # Python 2 returns byte strings from headers; normalize to unicode.
    if etag is not None and sys.version_info[0] == 2:
        etag = etag.decode('utf-8')

    return etag
|
86 |
+
|
87 |
+
|
88 |
+
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    filename = sha256(url.encode('utf-8')).hexdigest()
    if etag:
        filename = '{}.{}'.format(filename, sha256(etag.encode('utf-8')).hexdigest())
    return filename
|
104 |
+
|
105 |
+
|
106 |
+
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'
    # Both the cached file and its sidecar metadata must be present.
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise EnvironmentError("file {} not found".format(required))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']
|
130 |
+
|
131 |
+
|
132 |
+
def cached_path(url_or_filename, force_download=False, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to
            (overwrite the default cache dir).
        force_download: if True, re-dowload the file even if it's
            already cached in the cache dir.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3:
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ('http', 'https', 's3'):
        # Remote resource: resolve through the download cache.
        return get_from_cache(url_or_filename, cache_dir, force_download)
    if os.path.exists(url_or_filename):
        # Existing local file: return it untouched.
        return url_or_filename
    if scheme == '':
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    # Anything else (unsupported scheme) is an error.
    raise ValueError("unable to parse {} as a URL or as a local path".format(
        url_or_filename))
|
169 |
+
|
170 |
+
|
171 |
+
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    # urlparse keeps a leading '/' on the key; strip exactly one.
    key = parsed.path
    if key.startswith("/"):
        key = key[1:]
    return parsed.netloc, key
|
182 |
+
|
183 |
+
|
184 |
+
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Translate S3 404s into a familiar file-not-found error.
            code = int(exc.response["Error"]["Code"])
            if code == 404:
                raise EnvironmentError("file {} not found".format(url))
            raise

    return wrapper
|
201 |
+
|
202 |
+
|
203 |
+
@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    bucket_name, s3_path = split_s3_path(url)
    s3_object = boto3.resource("s3").Object(bucket_name, s3_path)
    return s3_object.e_tag
|
210 |
+
|
211 |
+
|
212 |
+
@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    bucket_name, s3_path = split_s3_path(url)
    boto3.resource("s3").Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
218 |
+
|
219 |
+
|
220 |
+
def http_get(url, temp_file):
    """Stream *url* into the open file object *temp_file*, with a progress bar."""
    response = requests.get(url, stream=True)
    length = response.headers.get('Content-Length')
    progress = tqdm(unit="B", total=int(length) if length is not None else None)
    for chunk in response.iter_content(chunk_size=1024):
        # Keep-alive chunks arrive empty; skip them.
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
|
230 |
+
|
231 |
+
|
232 |
+
def get_from_cache(url, cache_dir=None, force_download=False, resume_download=False):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: http(s) or s3 URL to fetch.
        cache_dir: directory to cache into (defaults to PROTEIN_MODELS_CACHE).
        force_download: if True, re-download even if the file is cached.
        resume_download: if True, append to a partial ``.incomplete`` file
            instead of starting from scratch.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Reuse get_etag instead of duplicating the S3/HTTP ETag logic inline.
    etag = get_etag(url)
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if os.path.exists(cache_path) and etag is None:
        return cache_path

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
        matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    # From now on, etag is not None
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir,
                                        delete=False)

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not in cache or force_download=True, download to %s",
                        url, temp_file.name)

            # BUG FIX: s3:// URLs must be fetched with s3_get; the previous
            # code unconditionally used http_get, which cannot download from
            # S3 even though the ETag logic above supports s3:// URLs.
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # BUG FIX: flush buffered writes before renaming the still-open
            # file, otherwise the cached copy can be truncated.
            temp_file.flush()

            logger.info("storing %s in cache at %s", url, cache_path)
            os.replace(temp_file.name, cache_path)

            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                json.dump(meta, meta_file)

    return cache_path
|
tape/models/modeling_autoencoder.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from .modeling_utils import ProteinConfig
|
8 |
+
from .modeling_utils import ProteinModel
|
9 |
+
from .modeling_utils import get_activation_fn
|
10 |
+
from .modeling_utils import MLMHead
|
11 |
+
from .modeling_utils import LayerNorm
|
12 |
+
from .modeling_utils import ValuePredictionHead
|
13 |
+
from .modeling_utils import SequenceClassificationHead
|
14 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
15 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
16 |
+
from ..registry import registry
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
21 |
+
RESNET_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
22 |
+
|
23 |
+
|
24 |
+
class ProteinAEConfig(ProteinConfig):
    """Configuration for the autoencoder ResNet models in this module.

    Holds vocabulary/embedding sizes, the residual stack depth, and the
    fixed input length (`max_size`) and bottleneck width (`latent_size`)
    used by the autoencoder.
    """

    pretrained_config_archive_map = RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 30,
                 hidden_size: int = 512,
                 num_hidden_layers: int = 30,
                 hidden_act: str = "gelu",
                 hidden_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 layer_norm_eps: float = 1e-12,
                 temporal_pooling: str = 'attention',
                 freeze_embedding: bool = False,
                 max_size: int = 3000,
                 latent_size: int = 1024,
                 **kwargs):
        super().__init__(**kwargs)
        # Token vocabulary size for the embedding layer.
        self.vocab_size = vocab_size
        # Number of ProteinResNetBlock layers in encoder and decoder each.
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        # Activation name resolved via get_activation_fn (e.g. "gelu").
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        # Stddev for normal-initialized weights (see _init_weights).
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # How per-position features are pooled over time -- not used in this
        # file; presumably consumed elsewhere, TODO confirm.
        self.temporal_pooling = temporal_pooling
        # If True, downstream heads run the backbone in eval mode.
        self.freeze_embedding = freeze_embedding
        # Inputs are truncated/padded to exactly this many tokens.
        self.max_size = max_size
        # Width of the autoencoder bottleneck vector.
        self.latent_size = latent_size
|
52 |
+
|
53 |
+
|
54 |
+
class MaskedConv1d(nn.Conv1d):
    """Conv1d that zeroes masked (padded) positions before convolving."""

    def forward(self, x, input_mask=None):
        """Convolve *x*; positions where *input_mask* is 0 are zeroed first."""
        masked = x if input_mask is None else x * input_mask
        return super().forward(masked)
|
60 |
+
|
61 |
+
|
62 |
+
class ProteinResNetLayerNorm(nn.Module):
    """LayerNorm for (batch, channels, length) tensors.

    Transposes to channels-last, normalizes over channels, transposes back.
    """

    def __init__(self, config):
        super().__init__()
        self.norm = LayerNorm(config.hidden_size)

    def forward(self, x):
        channels_last = x.transpose(1, 2)
        normalized = self.norm(channels_last)
        return normalized.transpose(1, 2)
|
70 |
+
|
71 |
+
|
72 |
+
class ProteinResNetBlock(nn.Module):
    """Residual block: two masked 3-wide convolutions with layer norm."""

    def __init__(self, config):
        super().__init__()
        self.conv1 = MaskedConv1d(
            config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
        self.bn1 = ProteinResNetLayerNorm(config)
        self.conv2 = MaskedConv1d(
            config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
        self.bn2 = ProteinResNetLayerNorm(config)
        self.activation_fn = get_activation_fn(config.hidden_act)

    def forward(self, x, input_mask=None):
        """Return activation(norm(conv(act(norm(conv(x))))) + x)."""
        residual = x
        out = self.activation_fn(self.bn1(self.conv1(x, input_mask)))
        out = self.bn2(self.conv2(out, input_mask))
        # Residual connection, then a final nonlinearity.
        out = self.activation_fn(out + residual)
        return out
|
100 |
+
|
101 |
+
|
102 |
+
class ProteinResNetEmbeddings(nn.Module):
    """Token embeddings plus fixed sinusoidal position embeddings,
    followed by layer norm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, embed_dim, padding_idx=0)
        # Sinusoid frequencies: 1 / 10000^(2i/d) for each even dimension.
        inverse_frequency = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
        self.register_buffer('inverse_frequency', inverse_frequency)

        self.layer_norm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        token_embeddings = self.word_embeddings(input_ids)

        # Positions count DOWN from seq_length-1 to 0 (reversed order).
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length - 1, -1, -1.0,
            dtype=token_embeddings.dtype,
            device=token_embeddings.device)
        # Outer product of positions and frequencies -> (seq, dim/2).
        sinusoidal_input = torch.ger(position_ids, self.inverse_frequency)
        position_embeddings = torch.cat(
            [sinusoidal_input.sin(), sinusoidal_input.cos()], -1).unsqueeze(0)

        combined = token_embeddings + position_embeddings
        return self.dropout(self.layer_norm(combined))
|
131 |
+
|
132 |
+
|
133 |
+
class ResNetEncoder(nn.Module):
    """Autoencoder over position-wise features: a ResNet encoder with
    periodic average pooling, a linear bottleneck, and a ResNet decoder
    with periodic upsampling.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.encoder = nn.ModuleList(
            [ProteinResNetBlock(config) for _ in range(config.num_hidden_layers)])

        self.decoder = nn.ModuleList(
            [ProteinResNetBlock(config) for _ in range(config.num_hidden_layers)])

        # NOTE(review): 93 and 94 appear hard-coded to the temporal length
        # that max_size=3000 shrinks to under the avg-pool schedule below;
        # they will be wrong for other max_size / num_hidden_layers values,
        # and 93 (input) vs 94 (output) differ -- TODO confirm intent.
        self.bottleneck1 = nn.Linear(93*config.hidden_size, config.latent_size)
        self.bottleneck2 = nn.Linear(config.latent_size, 94*config.hidden_size)

    def forward(self, hidden_states, input_mask=None):
        # NOTE(review): input_mask is accepted but never passed to the
        # blocks here -- confirm whether masking is intended.
        for i, layer_module in enumerate(self.encoder):
            hidden_states = layer_module(hidden_states)
            # Halve the temporal resolution after every 5th block (except block 0).
            if i != 0 and i % 5 == 0:
                hidden_states = nn.functional.avg_pool1d(hidden_states, 2, stride=2)

        bs = hidden_states.shape[0]
        # Flatten (batch, hidden, time) and project to the latent vector.
        latents = self.bottleneck1(hidden_states.reshape(bs, -1))
        # Project back and unflatten; the decoder starts from 94 time steps.
        hidden_states = self.bottleneck2(latents).reshape(bs, -1, 94)


        for i, layer_module in enumerate(self.decoder):
            # Double the temporal resolution before every 5th block (except block 0).
            if i != 0 and i % 5 == 0:
                hidden_states = nn.functional.interpolate(hidden_states, scale_factor=2)
            hidden_states = layer_module(hidden_states)

        # Trim any upsampling overshoot back to the fixed input length.
        hidden_states = hidden_states[:,:,:self.config.max_size]
        outputs = (hidden_states, latents)

        return outputs
|
168 |
+
|
169 |
+
|
170 |
+
class ProteinAEAbstractModel(ProteinModel):
    """Abstract base: handles weight initialization and provides the
    pretrained-model download/load interface for the autoencoder models.
    """
    config_class = ProteinAEConfig
    base_model_prefix = "ae"

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        """Initialize weights according to the module's layer type."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv1d):
            # He initialization, suited to ReLU-family activations.
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
|
192 |
+
|
193 |
+
|
194 |
+
@registry.register_task_model('embed', 'autoencoder')
class ProteinResNetModel(ProteinAEAbstractModel):
    """Autoencoder backbone: embeds tokens, pads/truncates sequences to a
    fixed length, and runs the ResNet encoder/decoder.

    Returns ``(sequence_output, latents)`` where ``sequence_output`` is
    ``(batch, max_size, hidden_size)`` and ``latents`` is the bottleneck
    code produced by the encoder.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = ProteinResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None):
        # Force input (and mask) to exactly config.max_size tokens.
        pre_pad_shape = input_ids.shape[1]
        if pre_pad_shape >= self.config.max_size:
            input_ids = input_ids[:, :self.config.max_size]
            # Idiom fix: was `not input_mask is None` (PEP 8).
            if input_mask is not None:
                input_mask = input_mask[:, :self.config.max_size]
        else:
            input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
            if input_mask is not None:
                input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
        assert input_ids.shape[1] == self.config.max_size

        # Only materialize an explicit mask when some positions are padded.
        if input_mask is not None and torch.any(input_mask != 1):
            extended_input_mask = input_mask.unsqueeze(2)
            # fp16 compatibility
            extended_input_mask = extended_input_mask.to(
                dtype=next(self.parameters()).dtype)
        else:
            extended_input_mask = None

        embedding_output = self.embeddings(input_ids)
        # Conv layers expect (batch, channels, length).
        embedding_output = embedding_output.transpose(1, 2)
        if extended_input_mask is not None:
            extended_input_mask = extended_input_mask.transpose(1, 2)
        sequence_output, pooled_output = self.encoder(embedding_output, extended_input_mask)
        sequence_output = sequence_output.transpose(1, 2).contiguous()
        return sequence_output, pooled_output
|
234 |
+
|
235 |
+
@registry.register_task_model('beta_lactamase', 'autoencoder')
@registry.register_task_model('language_modeling', 'autoencoder')
class ProteinResNetForMaskedLM(ProteinAEAbstractModel):
    """Autoencoder backbone with a masked-language-modeling head."""

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.mlm = MLMHead(
            config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
            ignore_index=-1)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.mlm.decoder,
                                   self.resnet.embeddings.word_embeddings)

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        # Remember the original length so predictions can be trimmed back.
        pre_pad_shape = input_ids.shape[1]
        if targets is not None:
            # The backbone truncates inputs to max_size; match the targets.
            targets = targets[:, :self.config.max_size]

        backbone_outputs = self.resnet(input_ids, input_mask=input_mask)
        sequence_output, latents = backbone_outputs
        # Score only the original (non-padded) positions; append the latent code.
        outputs = self.mlm(sequence_output[:, :pre_pad_shape, :], targets) + (latents,)
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
269 |
+
|
270 |
+
|
271 |
+
@registry.register_task_model('fluorescence', 'autoencoder')
@registry.register_task_model('stability', 'autoencoder')
class ProteinResNetForValuePrediction(ProteinAEAbstractModel):
    """Autoencoder backbone with a scalar value-prediction head."""

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            # Keep the backbone in eval mode when embeddings are frozen.
            self.resnet.train(False)

        backbone_outputs = self.resnet(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = backbone_outputs[:2]
        # Predict from the latent code; pass through any extra outputs.
        outputs = self.predict(pooled_output, targets) + backbone_outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
293 |
+
|
294 |
+
|
295 |
+
@registry.register_task_model('remote_homology', 'autoencoder')
class ProteinResNetForSequenceClassification(ProteinAEAbstractModel):
    """Autoencoder backbone with a sequence-level classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.classify = SequenceClassificationHead(config.hidden_size, config.num_labels)
        self.freeze_embedding = config.freeze_embedding

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            # Keep the backbone in eval mode when embeddings are frozen.
            self.resnet.train(False)

        backbone_outputs = self.resnet(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = backbone_outputs[:2]
        # Classify from the latent code; pass through any extra outputs.
        outputs = self.classify(pooled_output, targets) + backbone_outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
tape/models/modeling_bert.py
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
# Modified by Roshan Rao
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""PyTorch BERT model. """
|
18 |
+
|
19 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import math
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
from torch.utils.checkpoint import checkpoint
|
27 |
+
|
28 |
+
from .modeling_utils import ProteinConfig
|
29 |
+
from .modeling_utils import ProteinModel
|
30 |
+
from .modeling_utils import prune_linear_layer
|
31 |
+
from .modeling_utils import get_activation_fn
|
32 |
+
from .modeling_utils import LayerNorm
|
33 |
+
from .modeling_utils import MLMHead
|
34 |
+
from .modeling_utils import ValuePredictionHead
|
35 |
+
from .modeling_utils import SequenceClassificationHead
|
36 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
37 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
38 |
+
from ..registry import registry
|
39 |
+
|
40 |
+
logger = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
|
43 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
44 |
+
'bert-base': URL_PREFIX + "bert-base-pytorch_model.bin",
|
45 |
+
}
|
46 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
47 |
+
'bert-base': URL_PREFIX + "bert-base-config.json"
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
class ProteinBertConfig(ProteinConfig):
    r"""
    :class:`~pytorch_transformers.ProteinBertConfig` is the configuration class to store the
    configuration of a `ProteinBertModel`.


    Arguments:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in
            `ProteinBertModel`.
        hidden_size: Size of the encoder layers and the pooler layer.
        num_hidden_layers: Number of hidden layers in the ProteinBert encoder.
        num_attention_heads: Number of attention heads for each attention layer in
            the ProteinBert encoder.
        intermediate_size: The size of the "intermediate" (i.e., feed-forward)
            layer in the ProteinBert encoder.
        hidden_act: The non-linear activation function (function or string) in the
            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
        hidden_dropout_prob: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
        max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `ProteinBertModel`.
        initializer_range: The stddev of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
        temporal_pooling: Strategy used by ProteinBertPooler to collapse the
            sequence dimension: 'mean', 'max', 'concat', 'topmax',
            'light_attention', or any other value for BERT-style first-token
            pooling.
        freeze_embedding: If True, downstream task models call
            ``self.bert.train(False)`` during their forward pass.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 30,
                 hidden_size: int = 768,
                 num_hidden_layers: int = 12,
                 num_attention_heads: int = 12,
                 intermediate_size: int = 3072,
                 hidden_act: str = "gelu",
                 hidden_dropout_prob: float = 0.1,
                 attention_probs_dropout_prob: float = 0.1,
                 max_position_embeddings: int = 8096,
                 type_vocab_size: int = 2,
                 initializer_range: float = 0.02,
                 layer_norm_eps: float = 1e-12,
                 temporal_pooling: str = 'attention',
                 freeze_embedding: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # Non-standard extensions relative to upstream BERT configs:
        self.temporal_pooling = temporal_pooling
        self.freeze_embedding = freeze_embedding
|
114 |
+
|
115 |
+
|
116 |
+
class ProteinBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super().__init__()
        # padding_idx=0: token id 0 is the pad token and keeps a zero embedding.
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be
        # able to load any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Sum token, position and segment embeddings, then LayerNorm + dropout.

        Args:
            input_ids: (batch, seq) integer token indices.
            token_type_ids: optional segment ids; defaults to all zeros.
            position_ids: optional positions; defaults to 0..seq_length-1.

        Returns:
            (batch, seq, hidden_size) embedding tensor.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
148 |
+
|
149 |
+
|
150 |
+
class ProteinBertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Compute attention.

        Args:
            hidden_states: (batch, seq, hidden_size) activations.
            attention_mask: additive mask broadcastable to the attention scores
                (0 for attended positions, large negative for masked ones).

        Returns:
            (context_layer,) or (context_layer, attention_probs) when
            ``output_attentions`` is set.
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in
        # ProteinBertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original ProteinBert paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) \
            if self.output_attentions else (context_layer,)
        return outputs
|
206 |
+
|
207 |
+
|
208 |
+
class ProteinBertSelfOutput(nn.Module):
    """Project attention output, apply dropout, then add & LayerNorm the residual."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        # Capitalized attribute name kept for TensorFlow checkpoint compatibility.
        self.LayerNorm = LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection around the attention sublayer.
        return self.LayerNorm(projected + input_tensor)
|
220 |
+
|
221 |
+
|
222 |
+
class ProteinBertAttention(nn.Module):
    """Full attention sublayer: self-attention plus residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = ProteinBertSelfAttention(config)
        self.output = ProteinBertSelfOutput(config)

    def prune_heads(self, heads):
        """Remove the given attention heads from this layer in place.

        Args:
            heads: iterable of head indices to drop.
        """
        if len(heads) == 0:
            return
        # Build a flat boolean mask over (heads * head_size) rows: 0 marks the
        # rows that belong to pruned heads.
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        # The output projection is pruned along its input dimension.
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads

    def forward(self, input_tensor, attention_mask):
        self_outputs = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
|
250 |
+
|
251 |
+
|
252 |
+
class ProteinBertIntermediate(nn.Module):
    """Feed-forward expansion (hidden_size -> intermediate_size) with activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # The config may supply either an activation name or a callable.
        if isinstance(activation, str):
            activation = get_activation_fn(activation)
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
265 |
+
|
266 |
+
|
267 |
+
class ProteinBertOutput(nn.Module):
    """Feed-forward contraction (intermediate_size -> hidden_size) with residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # Capitalized attribute name kept for TensorFlow checkpoint compatibility.
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection around the feed-forward sublayer.
        return self.LayerNorm(contracted + input_tensor)
|
279 |
+
|
280 |
+
|
281 |
+
class ProteinBertLayer(nn.Module):
    """One transformer encoder block: self-attention followed by feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.attention = ProteinBertAttention(config)
        self.intermediate = ProteinBertIntermediate(config)
        self.output = ProteinBertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attn_outputs = self.attention(hidden_states, attention_mask)
        attn_output = attn_outputs[0]
        ffn_output = self.output(self.intermediate(attn_output), attn_output)
        # Propagate attention probabilities when the attention module emits them.
        return (ffn_output,) + attn_outputs[1:]
|
295 |
+
|
296 |
+
|
297 |
+
class ProteinBertEncoder(nn.Module):
    """Stack of ProteinBertLayer blocks with optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [ProteinBertLayer(config) for _ in range(config.num_hidden_layers)])

    def run_function(self, start, chunk_size):
        # Build a closure that runs layers [start, start + chunk_size). The
        # closure is handed to torch.utils.checkpoint.checkpoint so the chunk's
        # activations are recomputed in backward instead of stored.
        def custom_forward(hidden_states, attention_mask):
            all_hidden_states = ()
            all_attentions = ()
            chunk_slice = slice(start, start + chunk_size)
            for layer in self.layer[chunk_slice]:
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                layer_outputs = layer(hidden_states, attention_mask)
                hidden_states = layer_outputs[0]

                if self.output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = (hidden_states,)
            if self.output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if self.output_attentions:
                outputs = outputs + (all_attentions,)
            return outputs

        return custom_forward

    def forward(self, hidden_states, attention_mask, chunks=None):
        """Run every layer over the sequence.

        Args:
            hidden_states: (batch, seq, hidden) input activations.
            attention_mask: additive attention mask broadcastable to the scores.
            chunks: if an int, checkpoint the stack in that many segments to
                trade compute for activation memory; otherwise run normally.

        Returns:
            (hidden_states, (all_hidden_states), (all_attentions)) where the
            optional entries follow the output_* config flags.
        """
        all_hidden_states = ()
        all_attentions = ()

        if chunks is not None:
            assert isinstance(chunks, int)
            # Ceil-divide the layer stack into `chunks` checkpointed segments.
            chunk_size = (len(self.layer) + chunks - 1) // chunks
            for start in range(0, len(self.layer), chunk_size):
                outputs = checkpoint(self.run_function(start, chunk_size),
                                     hidden_states, attention_mask)
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + outputs[1]
                if self.output_attentions:
                    all_attentions = all_attentions + outputs[-1]
                hidden_states = outputs[0]
        else:
            for i, layer_module in enumerate(self.layer):
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

                layer_outputs = layer_module(hidden_states, attention_mask)
                hidden_states = layer_outputs[0]

                if self.output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

            # Add last layer
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = (hidden_states,)
            if self.output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if self.output_attentions:
                outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
|
366 |
+
|
367 |
+
|
368 |
+
class ProteinBertPooler(nn.Module):
    """Reduce per-residue hidden states to a single vector per sequence.

    ``config.temporal_pooling`` selects the strategy: 'mean', 'max', 'concat',
    'topmax', 'light_attention'; any other value falls back to BERT-style
    pooling of the first token through a dense+tanh layer.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.temporal_pooling = config.temporal_pooling
        # Light-attention pooling parameters: two depth-5 1D convs over the
        # sequence axis plus an output MLP.
        self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size / 2), 5, padding=2)
        self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size / 2), 5, padding=2)
        self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        """Pool (batch, seq, hidden) -> (batch, hidden) per temporal_pooling."""
        if self.temporal_pooling == 'mean':
            return hidden_states.mean(dim=1)
        if self.temporal_pooling == 'max':
            # Bug fix: Tensor.max(dim=...) returns a (values, indices)
            # namedtuple; downstream heads expect a plain tensor.
            return hidden_states.max(dim=1).values
        if self.temporal_pooling == 'concat':
            flat = hidden_states.reshape(hidden_states.shape[0], -1)
            # NOTE(review): 2048 is a hard-coded flattened width; F.pad raises
            # if flat.shape[1] exceeds it — confirm upstream sequence limits.
            return torch.nn.functional.pad(flat, (0, 2048 - flat.shape[1]))
        if self.temporal_pooling == 'topmax':
            # Mean of the 5 largest activations per feature channel
            # (requires sequence length >= 5).
            val, _ = torch.topk(hidden_states, k=5, dim=1)
            return val.mean(dim=1)
        if self.temporal_pooling == 'light_attention':
            channels_first = hidden_states.permute(0, 2, 1)  # (batch, hidden, seq)
            a = self._la_w1(channels_first).softmax(dim=-1)
            v = self._la_w2(channels_first)
            v_max = v.max(dim=-1).values
            v_sum = (a * v).sum(dim=-1)
            return self._la_mlp(torch.cat([v_max, v_sum], dim=1))

        # Default: BERT-style pooling on the first token's hidden state.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
403 |
+
|
404 |
+
|
405 |
+
class ProteinBertAbstractModel(ProteinModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = ProteinBertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Truncated-normal-style init, stddev from the config.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear biases are zeroed regardless of the weight init above.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
422 |
+
|
423 |
+
|
424 |
+
@registry.register_task_model('embed', 'transformer')
class ProteinBertModel(ProteinBertAbstractModel):
    """Bare BERT encoder: embeddings -> transformer stack -> pooler."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = ProteinBertEmbeddings(config)
        self.encoder = ProteinBertEncoder(config)
        self.pooler = ProteinBertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        # Swap in a resized word-embedding matrix (used when the vocab changes).
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class ProteinModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self,
                input_ids,
                input_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: (batch, seq) token indices.
            input_mask: (batch, seq) 1/0 mask; defaults to all ones.

        Returns:
            (sequence_output, pooled_output, (hidden_states), (attentions)).
        """
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = input_mask.unsqueeze(1).unsqueeze(2)

        # Since input_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=torch.float32)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       chunks=None)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
482 |
+
|
483 |
+
|
484 |
+
@registry.register_task_model('masked_language_modeling', 'transformer')
class ProteinBertForMaskedLM(ProteinBertAbstractModel):
    """BERT encoder with a masked-language-modeling head (tied embeddings)."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ProteinBertModel(config)
        self.mlm = MLMHead(
            config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
            ignore_index=-1)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.mlm.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        """Returns ((mlm_loss,) prediction_scores, (hidden_states), (attentions))."""
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        # Bug fix: append the optional trailing (hidden_states, attentions)
        # entries, i.e. outputs[2:], instead of re-appending outputs[:2]
        # (sequence/pooled output). This matches every other task head in this
        # file and the output contract documented below.
        outputs = self.mlm(sequence_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
517 |
+
|
518 |
+
|
519 |
+
@registry.register_task_model('fluorescence', 'transformer')
@registry.register_task_model('stability', 'transformer')
class ProteinBertForValuePrediction(ProteinBertAbstractModel):
    """BERT encoder with a scalar regression head (fluorescence/stability tasks)."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ProteinBertModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Returns ((loss,) prediction_scores, (hidden_states), (attentions))."""
        if self.freeze_embedding:
            # NOTE(review): train(False) only switches the encoder to eval mode
            # (disables dropout); it does NOT stop gradients from flowing into
            # the encoder weights — confirm this is the intended "freeze".
            self.bert.train(False)
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
540 |
+
|
541 |
+
|
542 |
+
@registry.register_task_model('remote_homology', 'transformer')
class ProteinBertForSequenceClassification(ProteinBertAbstractModel):
    """BERT encoder with a whole-sequence classification head (remote homology)."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ProteinBertModel(config)
        self.classify = SequenceClassificationHead(
            config.hidden_size, config.num_labels)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Returns ((loss,) prediction_scores, (hidden_states), (attentions))."""
        if self.freeze_embedding:
            # NOTE(review): train(False) only disables dropout (eval mode);
            # it does not stop gradients into the encoder — confirm intent.
            self.bert.train(False)
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]

        outputs = self.classify(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
564 |
+
|
565 |
+
|
566 |
+
@registry.register_task_model('secondary_structure', 'transformer')
class ProteinBertForSequenceToSequenceClassification(ProteinBertAbstractModel):
    """BERT encoder with a per-residue classification head (secondary structure)."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ProteinBertModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Returns ((loss,) prediction_scores, (hidden_states), (attentions))."""
        bert_outputs = self.bert(input_ids, input_mask=input_mask)
        sequence_output = bert_outputs[0]
        # Classify every residue; keep any trailing hidden states / attentions.
        return self.classify(sequence_output, targets) + bert_outputs[2:]
|
586 |
+
|
587 |
+
|
588 |
+
@registry.register_task_model('contact_prediction', 'transformer')
class ProteinBertForContactPrediction(ProteinBertAbstractModel):
    """BERT encoder with a pairwise residue-residue contact prediction head."""

    def __init__(self, config):
        super().__init__(config)

        self.bert = ProteinBertModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        """Predict contacts.

        Args:
            input_ids: (batch, seq) token indices.
            protein_length: per-sequence true lengths consumed by the pairwise head.
            input_mask: optional (batch, seq) 1/0 mask.
            targets: optional contact labels; -1 positions are ignored.
        """
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
tape/models/modeling_bottleneck.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from tape import ProteinModel, ProteinConfig
|
5 |
+
from tape.models.modeling_utils import SequenceToSequenceClassificationHead
|
6 |
+
from tape.registry import registry
|
7 |
+
from .modeling_utils import LayerNorm, MLMHead
|
8 |
+
from .modeling_bert import ProteinBertModel, ProteinBertConfig
|
9 |
+
from .modeling_lstm import ProteinLSTMModel, ProteinLSTMConfig
|
10 |
+
from .modeling_resnet import ProteinResNetModel, ProteinResNetConfig
|
11 |
+
|
12 |
+
|
13 |
+
class BottleneckConfig(ProteinConfig):
    """Configuration for the bottleneck autoencoder models.

    Args:
        hidden_size: width of the linear bottleneck embedding.
        max_size: fixed sequence length inputs are padded/truncated to.
        backend_name: encoder backbone: 'resnet', 'transformer' or 'lstm'.
    """
    def __init__(self,
                 hidden_size: int = 1024,
                 max_size: int = 300,
                 backend_name: str = 'resnet',
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.max_size = max_size
        self.backend_name = backend_name
|
23 |
+
|
24 |
+
|
25 |
+
class BottleneckAbstractModel(ProteinModel):
    """ All your models will inherit from this one - it's used to define the
        config_class of the model set and also to define the base_model_prefix.
        This is used to allow easy loading/saving into different models.
    """
    config_class = BottleneckConfig
    base_model_prefix = 'bottleneck'

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        """ Initialize the weights """
        # NOTE(review): BottleneckConfig.__init__ does not set
        # initializer_range; it must come from ProteinConfig or kwargs for the
        # accesses below to work — confirm.
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
|
53 |
+
|
54 |
+
@registry.register_task_model('embed', 'bottleneck')
class ProteinBottleneckModel(BottleneckAbstractModel):
    """Autoencoder-style bottleneck over a protein encoder backbone.

    Pads/truncates input to ``config.max_size``, encodes it with the chosen
    backbone ('resnet', 'transformer' or 'lstm'), squeezes the flattened
    sequence representation through a fixed-size linear bottleneck, expands it
    back, and returns (sequence_output, bottleneck_embedding, ...).
    """

    def __init__(self, config):
        super().__init__(config)
        # Note: `config` is deliberately rebound to the backbone's own config
        # below; the bottleneck config stays reachable via self.config.
        if config.backend_name == 'resnet':
            config = ProteinResNetConfig()
            self.backbone1 = ProteinResNetModel(config)
        elif config.backend_name == 'transformer':
            config = ProteinBertConfig()
            self.backbone1 = ProteinBertModel(config)
        elif config.backend_name == 'lstm':
            config = ProteinLSTMConfig(hidden_size=256)
            self.backbone1 = ProteinLSTMModel(config)
            # The LSTM output is bidirectional, so its feature width is doubled.
            config.hidden_size = config.hidden_size * 2
        else:
            raise ValueError(
                f"Unknown backend_name: {config.backend_name!r} "
                "(expected 'resnet', 'transformer' or 'lstm')")
        self.linear1 = nn.Linear(self.config.max_size * config.hidden_size,
                                 self.config.hidden_size)
        self.linear2 = nn.Linear(self.config.hidden_size,
                                 self.config.max_size * config.hidden_size)

    def forward(self, input_ids, input_mask=None):
        pre_pad_shape = input_ids.shape[1]
        if pre_pad_shape >= self.config.max_size:
            # Truncate overlong sequences to the fixed bottleneck length.
            input_ids = input_ids[:, :self.config.max_size]
            if input_mask is not None:
                input_mask = input_mask[:, :self.config.max_size]
        else:
            # Right-pad short sequences up to the fixed length.
            input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
            if input_mask is not None:
                input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
        assert input_ids.shape[1] == self.config.max_size

        output = self.backbone1(input_ids, input_mask)
        sequence_output = output[0]
        pre_shape = sequence_output.shape
        # Flatten (batch, max_size, hidden) -> bottleneck -> expand back.
        embeddings = self.linear1(sequence_output.reshape(sequence_output.shape[0], -1))
        sequence_output = self.linear2(embeddings).reshape(*pre_shape)
        # Strip padding so downstream heads see the original sequence length.
        sequence_output = sequence_output[:, :pre_pad_shape]
        outputs = (sequence_output, embeddings) + output[2:]
        return outputs
|
94 |
+
|
95 |
+
@registry.register_task_model('beta_lactamase', 'bottleneck')
@registry.register_task_model('masked_language_modeling', 'bottleneck')
@registry.register_task_model('language_modeling', 'bottleneck')
class ProteinBottleneckForPretraining(BottleneckAbstractModel):
    """Bottleneck encoder with a language-modeling head for pretraining.

    For the 'resnet'/'transformer' backends an MLMHead computes the masked-LM
    scores/loss; for 'lstm' a linear decoder predicts each token from the two
    LSTM directions' shifted outputs.
    """

    def __init__(self, config):
        super().__init__(config)
        self.backbone1 = ProteinBottleneckModel(config)

        # `config` is rebound to the backbone config to size the LM head.
        if config.backend_name == 'resnet':
            config = ProteinResNetConfig()
            self.backbone2 = MLMHead(config.hidden_size, config.vocab_size, config.hidden_act,
                                     config.layer_norm_eps, ignore_index=-1)
        elif config.backend_name == 'transformer':
            config = ProteinBertConfig()
            self.backbone2 = MLMHead(config.hidden_size, config.vocab_size, config.hidden_act,
                                     config.layer_norm_eps, ignore_index=-1)
        elif config.backend_name == 'lstm':
            config = ProteinLSTMConfig(hidden_size=256)
            self.backbone2 = nn.Linear(config.hidden_size, config.vocab_size)
            config.hidden_size = config.hidden_size * 2
        else:
            raise ValueError(
                f"Unknown backend_name: {config.backend_name!r} "
                "(expected 'resnet', 'transformer' or 'lstm')")

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        # Bug fix: guard for targets=None (inference) — the original sliced
        # `targets` unconditionally and crashed on overlong unlabeled input.
        if targets is not None and input_ids.shape[1] > self.config.max_size:
            # Keep targets aligned with the truncation done inside the backbone.
            targets = targets[:, :self.config.max_size]

        outputs = self.backbone1(input_ids, input_mask)
        sequence_output = outputs[0]
        if self.config.backend_name in ('resnet', 'transformer'):
            # MLMHead returns ((loss,) prediction_scores); keep trailing extras.
            outputs = self.backbone2(sequence_output, targets) + outputs[2:]
        elif self.config.backend_name == 'lstm':
            sequence_output, pooled_output = outputs[:2]

            # Split the bidirectional features and shift each direction so
            # every position is predicted from context only.
            forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
            forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
            reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
            prediction_scores = \
                self.backbone2(forward_prediction) + self.backbone2(reverse_prediction)
            prediction_scores = prediction_scores.contiguous()

            # add hidden states if they are here
            outputs = (prediction_scores,) + outputs[2:]

            if targets is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                # NOTE(review): 30 is a hard-coded vocab size — confirm it
                # matches the tokenizer / backbone config.vocab_size.
                lm_loss = loss_fct(
                    prediction_scores.view(-1, 30), targets.view(-1))
                outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
tape/models/modeling_lstm.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import typing
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from .modeling_utils import ProteinConfig
|
8 |
+
from .modeling_utils import ProteinModel
|
9 |
+
from .modeling_utils import ValuePredictionHead
|
10 |
+
from .modeling_utils import SequenceClassificationHead
|
11 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
12 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
13 |
+
from ..registry import registry
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
|
19 |
+
LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
20 |
+
LSTM_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
21 |
+
|
22 |
+
|
23 |
+
class ProteinLSTMConfig(ProteinConfig):
    """Hyperparameters for the protein LSTM model family.

    Args:
        vocab_size: number of tokens in the amino-acid vocabulary.
        input_size: embedding dimension fed to the first LSTM layer.
        hidden_size: hidden dimension of each LSTM layer (per direction).
        num_hidden_layers: number of stacked LSTM layers per direction.
        hidden_dropout_prob: dropout applied between stacked layers.
        initializer_range: stddev of the normal weight initializer.
        temporal_pooling: pooling strategy used by ``ProteinLSTMPooler``.
        freeze_embedding: if True, task models run the encoder in eval mode.
    """
    pretrained_config_archive_map = LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 30,
                 input_size: int = 128,
                 hidden_size: int = 1024,
                 num_hidden_layers: int = 3,
                 hidden_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 temporal_pooling: str = 'attention',
                 freeze_embedding: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        # architecture
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        # regularization / initialization
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        # downstream behavior
        self.temporal_pooling = temporal_pooling
        self.freeze_embedding = freeze_embedding
|
45 |
+
|
46 |
+
|
47 |
+
class ProteinLSTMLayer(nn.Module):
    """A single batch-first LSTM preceded by dropout on its inputs."""

    def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs):
        dropped = self.dropout(inputs)
        # Keep cuDNN weights contiguous after any device/state-dict moves.
        self.lstm.flatten_parameters()
        return self.lstm(dropped)
|
58 |
+
|
59 |
+
|
60 |
+
class ProteinLSTMPooler(nn.Module):
    """Pools per-layer LSTM states into a single fixed-size vector.

    The strategy is chosen by ``config.temporal_pooling``; any unrecognized
    value (including the default 'attention') falls through to a learned
    scalar reweighting over the stacked layer states followed by dense+tanh.
    """

    def __init__(self, config):
        super().__init__()
        # default path: one weight per (forward + reverse) layer state
        self.scalar_reweighting = nn.Linear(2 * config.num_hidden_layers, 1)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.temporal_pooling = config.temporal_pooling
        # light-attention pooling: two convs (attention / values) + MLP
        self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
        self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
        self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        if self.temporal_pooling == 'mean':
            return hidden_states.mean(dim=1)
        if self.temporal_pooling == 'max':
            # BUGFIX: Tensor.max(dim=...) returns a (values, indices) namedtuple;
            # downstream heads expect a plain tensor, so return .values.
            return hidden_states.max(dim=1).values
        if self.temporal_pooling == 'concat':
            _temp = hidden_states.reshape(hidden_states.shape[0], -1)
            # NOTE(review): 2048 is a hard-coded pad target; inputs flattening
            # to more than 2048 features will make F.pad raise -- confirm
            # upstream guarantees this bound.
            return torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))
        if self.temporal_pooling == 'topmax':
            val, _ = torch.topk(hidden_states, k=5, dim=1)
            return val.mean(dim=1)
        if self.temporal_pooling == 'light_attention':
            _temp = hidden_states.permute(0, 2, 1)
            a = self._la_w1(_temp).softmax(dim=-1)
            v = self._la_w2(_temp)
            v_max = v.max(dim=-1).values
            v_sum = (a * v).sum(dim=-1)
            return self._la_mlp(torch.cat([v_max, v_sum], dim=1))

        # default: learned scalar reweighting over stacked layer states
        pooled_output = self.scalar_reweighting(hidden_states).squeeze(2)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
96 |
+
|
97 |
+
|
98 |
+
class ProteinLSTMEncoder(nn.Module):
|
99 |
+
|
100 |
+
def __init__(self, config: ProteinLSTMConfig):
|
101 |
+
super().__init__()
|
102 |
+
forward_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
|
103 |
+
reverse_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
|
104 |
+
for _ in range(config.num_hidden_layers - 1):
|
105 |
+
forward_lstm.append(ProteinLSTMLayer(
|
106 |
+
config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
|
107 |
+
reverse_lstm.append(ProteinLSTMLayer(
|
108 |
+
config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
|
109 |
+
self.forward_lstm = nn.ModuleList(forward_lstm)
|
110 |
+
self.reverse_lstm = nn.ModuleList(reverse_lstm)
|
111 |
+
self.output_hidden_states = config.output_hidden_states
|
112 |
+
|
113 |
+
def forward(self, inputs, input_mask=None):
|
114 |
+
all_forward_pooled = ()
|
115 |
+
all_reverse_pooled = ()
|
116 |
+
all_hidden_states = (inputs,)
|
117 |
+
forward_output = inputs
|
118 |
+
for layer in self.forward_lstm:
|
119 |
+
forward_output, forward_pooled = layer(forward_output)
|
120 |
+
all_forward_pooled = all_forward_pooled + (forward_pooled[0],)
|
121 |
+
all_hidden_states = all_hidden_states + (forward_output,)
|
122 |
+
|
123 |
+
reversed_sequence = self.reverse_sequence(inputs, input_mask)
|
124 |
+
reverse_output = reversed_sequence
|
125 |
+
for layer in self.reverse_lstm:
|
126 |
+
reverse_output, reverse_pooled = layer(reverse_output)
|
127 |
+
all_reverse_pooled = all_reverse_pooled + (reverse_pooled[0],)
|
128 |
+
all_hidden_states = all_hidden_states + (reverse_output,)
|
129 |
+
reverse_output = self.reverse_sequence(reverse_output, input_mask)
|
130 |
+
|
131 |
+
output = torch.cat((forward_output, reverse_output), dim=2)
|
132 |
+
|
133 |
+
pooled = all_forward_pooled + all_reverse_pooled
|
134 |
+
pooled = torch.stack(pooled, 3).squeeze(0)
|
135 |
+
outputs = (output, pooled)
|
136 |
+
if self.output_hidden_states:
|
137 |
+
outputs = outputs + (all_hidden_states,)
|
138 |
+
|
139 |
+
return outputs # sequence_embedding, pooled_embedding, (hidden_states)
|
140 |
+
|
141 |
+
def reverse_sequence(self, sequence, input_mask):
|
142 |
+
if input_mask is None:
|
143 |
+
idx = torch.arange(sequence.size(1) - 1, -1, -1)
|
144 |
+
reversed_sequence = sequence.index_select(1, idx, device=sequence.device)
|
145 |
+
else:
|
146 |
+
sequence_lengths = input_mask.sum(1)
|
147 |
+
reversed_sequence = []
|
148 |
+
for seq, seqlen in zip(sequence, sequence_lengths):
|
149 |
+
idx = torch.arange(seqlen - 1, -1, -1, device=seq.device)
|
150 |
+
seq = seq.index_select(0, idx)
|
151 |
+
seq = F.pad(seq, [0, 0, 0, sequence.size(1) - seqlen])
|
152 |
+
reversed_sequence.append(seq)
|
153 |
+
reversed_sequence = torch.stack(reversed_sequence, 0)
|
154 |
+
return reversed_sequence
|
155 |
+
|
156 |
+
|
157 |
+
class ProteinLSTMAbstractModel(ProteinModel):
    """Base class tying LSTM models to their config and checkpoint archive."""

    config_class = ProteinLSTMConfig
    pretrained_model_archive_map = LSTM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "lstm"

    def _init_weights(self, module):
        """Normal-initialize Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
|
169 |
+
|
170 |
+
|
171 |
+
@registry.register_task_model('embed', 'lstm')
class ProteinLSTMModel(ProteinLSTMAbstractModel):
    """Bare LSTM encoder producing per-residue and pooled embeddings."""

    def __init__(self, config: ProteinLSTMConfig):
        super().__init__(config)
        self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
        self.encoder = ProteinLSTMEncoder(config)
        self.pooler = ProteinLSTMPooler(config)
        self.output_hidden_states = config.output_hidden_states
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        mask = input_mask if input_mask is not None else torch.ones_like(input_ids)

        embeddings = self.embed_matrix(input_ids)
        encoder_outputs = self.encoder(embeddings, input_mask=mask)
        pooled = self.pooler(encoder_outputs[1])

        # sequence_output, pooled_output, (hidden_states)
        return (encoder_outputs[0], pooled) + encoder_outputs[2:]
|
194 |
+
|
195 |
+
|
196 |
+
@registry.register_task_model('language_modeling', 'lstm')
class ProteinLSTMForLM(ProteinLSTMAbstractModel):
    """Bidirectional language-modeling head on top of the LSTM encoder.

    The forward/reverse halves of the sequence output are shifted by one
    position so each residue is predicted only from its past (respectively
    future) context.
    """

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):

        outputs = self.lstm(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]

        # Split the bidirectional features and shift each direction one step.
        forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
        forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
        reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
        prediction_scores = \
            self.feedforward(forward_prediction) + self.feedforward(reverse_prediction)
        prediction_scores = prediction_scores.contiguous()

        # BUGFIX: append the optional trailing encoder outputs (hidden states),
        # i.e. outputs[2:].  The original used outputs[:2], which re-appended
        # the sequence/pooled embeddings -- contradicting its own comment and
        # the convention of every sibling model in this package.
        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), targets.view(-1))
            outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs
|
234 |
+
|
235 |
+
|
236 |
+
@registry.register_task_model('fluorescence', 'lstm')
@registry.register_task_model('stability', 'lstm')
class ProteinLSTMForValuePrediction(ProteinLSTMAbstractModel):
    """LSTM encoder followed by a scalar regression head."""

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            # NOTE(review): train(False) only disables dropout-style behavior;
            # gradients still flow into the encoder -- confirm intent.
            self.lstm.train(False)

        encoder_outputs = self.lstm(input_ids, input_mask=input_mask)
        pooled = encoder_outputs[1]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.predict(pooled, targets) + encoder_outputs[2:]
|
258 |
+
|
259 |
+
|
260 |
+
@registry.register_task_model('remote_homology', 'lstm')
class ProteinLSTMForSequenceClassification(ProteinLSTMAbstractModel):
    """LSTM encoder followed by a whole-sequence classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        self.classify = SequenceClassificationHead(
            config.hidden_size, config.num_labels)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            # NOTE(review): train(False) only toggles train/eval mode; it does
            # not stop gradient flow -- confirm intent.
            self.lstm.train(False)

        encoder_outputs = self.lstm(input_ids, input_mask=input_mask)
        pooled = encoder_outputs[1]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.classify(pooled, targets) + encoder_outputs[2:]
|
282 |
+
|
283 |
+
|
284 |
+
@registry.register_task_model('secondary_structure', 'lstm')
class ProteinLSTMForSequenceToSequenceClassification(ProteinLSTMAbstractModel):
    """LSTM encoder with a per-residue classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        # sequence output concatenates both directions, hence hidden_size * 2
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size * 2, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_outputs = self.lstm(input_ids, input_mask=input_mask)

        amino_acid_class_scores = self.classify(encoder_outputs[0].contiguous())

        # scores first, then any trailing encoder extras (hidden states)
        outputs = (amino_acid_class_scores,) + encoder_outputs[2:]

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            classification_loss = loss_fct(
                amino_acid_class_scores.view(-1, self.config.num_labels),
                targets.view(-1))
            outputs = (classification_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs
|
315 |
+
|
316 |
+
|
317 |
+
@registry.register_task_model('contact_prediction', 'lstm')
class ProteinLSTMForContactPrediction(ProteinLSTMAbstractModel):
    """LSTM encoder with a pairwise residue-contact prediction head."""

    def __init__(self, config):
        super().__init__(config)

        self.lstm = ProteinLSTMModel(config)
        # NOTE(review): the encoder's sequence output has 2 * hidden_size
        # features (forward + reverse concat) while the head is sized with
        # hidden_size -- verify the head's expected input width.
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        encoder_outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output = encoder_outputs[0]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.predict(sequence_output, protein_length, targets) + encoder_outputs[2:]
|
tape/models/modeling_onehot.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import typing
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from .modeling_utils import ProteinConfig
|
8 |
+
from .modeling_utils import ProteinModel
|
9 |
+
from .modeling_utils import ValuePredictionHead
|
10 |
+
from .modeling_utils import SequenceClassificationHead
|
11 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
12 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
13 |
+
from ..registry import registry
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class ProteinOneHotConfig(ProteinConfig):
    """Configuration for the one-hot (bag-of-words) baseline models.

    Args:
        vocab_size: number of tokens in the amino-acid vocabulary (required).
        initializer_range: stddev of the normal weight initializer.
        use_evolutionary: whether evolutionary features are expected.
    """
    pretrained_config_archive_map: typing.Dict[str, str] = {}

    def __init__(self,
                 vocab_size: int,
                 initializer_range: float = 0.02,
                 use_evolutionary: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.initializer_range = initializer_range
        self.use_evolutionary = use_evolutionary
|
30 |
+
|
31 |
+
|
32 |
+
class ProteinOneHotAbstractModel(ProteinModel):
    """Base class tying one-hot models to their config (no pretrained archive)."""

    config_class = ProteinOneHotConfig
    pretrained_model_archive_map: typing.Dict[str, str] = {}
    base_model_prefix = "onehot"

    def _init_weights(self, module):
        """Normal-initialize Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
|
44 |
+
|
45 |
+
|
46 |
+
class ProteinOneHotModel(ProteinOneHotAbstractModel):
    """One-hot featurizer; the pooled output is a masked bag-of-words mean."""

    def __init__(self, config: ProteinOneHotConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size

        # Registered solely so outputs can be cast to the model's dtype
        # (fp16 support) -- there is no cheaper way to detect fp16 vs fp32.
        self.register_buffer('_buffer', torch.tensor([0.]))

    def forward(self, input_ids, input_mask=None):
        mask = torch.ones_like(input_ids) if input_mask is None else input_mask

        one_hot = F.one_hot(input_ids, num_classes=self.vocab_size)
        # fp16 compatibility: match the buffer's dtype
        one_hot = one_hot.type_as(self._buffer)
        mask = mask.unsqueeze(2).type_as(one_hot)
        # bag-of-words over unmasked positions
        pooled = (one_hot * mask).sum(1) / mask.sum(1)

        return (one_hot, pooled)
|
70 |
+
|
71 |
+
|
72 |
+
@registry.register_task_model('fluorescence', 'onehot')
@registry.register_task_model('stability', 'onehot')
class ProteinOneHotForValuePrediction(ProteinOneHotAbstractModel):
    """One-hot bag-of-words baseline for scalar regression tasks."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.predict = ValuePredictionHead(config.vocab_size)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_outputs = self.onehot(input_ids, input_mask=input_mask)
        pooled = encoder_outputs[1]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.predict(pooled, targets) + encoder_outputs[2:]
|
92 |
+
|
93 |
+
|
94 |
+
@registry.register_task_model('remote_homology', 'onehot')
class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):
    """One-hot bag-of-words baseline for whole-sequence classification."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_outputs = self.onehot(input_ids, input_mask=input_mask)
        pooled = encoder_outputs[1]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.classify(pooled, targets) + encoder_outputs[2:]
|
113 |
+
|
114 |
+
|
115 |
+
@registry.register_task_model('secondary_structure', 'onehot')
class ProteinOneHotForSequenceToSequenceClassification(ProteinOneHotAbstractModel):
    """One-hot baseline for per-residue classification."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.vocab_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_outputs = self.onehot(input_ids, input_mask=input_mask)
        sequence_output = encoder_outputs[0]

        # head returns ((loss,) scores); keep any trailing encoder extras
        return self.classify(sequence_output, targets) + encoder_outputs[2:]
|
135 |
+
|
136 |
+
|
137 |
+
@registry.register_task_model('contact_prediction', 'onehot')
class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):
    """One-hot baseline for pairwise residue-contact prediction."""

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # BUGFIX: ProteinOneHotConfig defines no `hidden_size` attribute, so
        # the original `config.hidden_size` raised AttributeError at
        # construction.  The one-hot sequence output has vocab_size features,
        # matching every other one-hot head in this module.
        self.predict = PairwiseContactPredictionHead(config.vocab_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):

        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
|
tape/models/modeling_resnet.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .modeling_utils import ProteinConfig
|
7 |
+
from .modeling_utils import ProteinModel
|
8 |
+
from .modeling_utils import get_activation_fn
|
9 |
+
from .modeling_utils import MLMHead
|
10 |
+
from .modeling_utils import LayerNorm
|
11 |
+
from .modeling_utils import ValuePredictionHead
|
12 |
+
from .modeling_utils import SequenceClassificationHead
|
13 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
14 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
15 |
+
from ..registry import registry
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
20 |
+
RESNET_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
|
21 |
+
|
22 |
+
|
23 |
+
class ProteinResNetConfig(ProteinConfig):
    """Hyperparameters for the protein ResNet model family.

    Args:
        vocab_size: number of tokens in the amino-acid vocabulary.
        hidden_size: channel width of every residual block.
        num_hidden_layers: number of residual blocks.
        hidden_act: activation name resolved via ``get_activation_fn``.
        hidden_dropout_prob: dropout applied in the embedding layer.
        initializer_range: stddev of the normal weight initializer.
        layer_norm_eps: epsilon used by LayerNorm layers.
        temporal_pooling: pooling strategy used by ``ProteinResNetPooler``.
        freeze_embedding: if True, task models run the encoder in eval mode.
    """
    pretrained_config_archive_map = RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 30,
                 hidden_size: int = 512,
                 num_hidden_layers: int = 30,
                 hidden_act: str = "gelu",
                 hidden_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 layer_norm_eps: float = 1e-12,
                 temporal_pooling: str = 'attention',
                 freeze_embedding: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        # architecture
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_act = hidden_act
        # regularization / initialization
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # downstream behavior
        self.temporal_pooling = temporal_pooling
        self.freeze_embedding = freeze_embedding
|
47 |
+
|
48 |
+
|
49 |
+
class MaskedConv1d(nn.Conv1d):
    """Conv1d that zeroes out masked positions of the input before convolving."""

    def forward(self, x, input_mask=None):
        if input_mask is None:
            return super().forward(x)
        return super().forward(x * input_mask)
|
55 |
+
|
56 |
+
|
57 |
+
class ProteinResNetLayerNorm(nn.Module):
    """LayerNorm over the channel dim of (batch, channels, length) tensors."""

    def __init__(self, config):
        super().__init__()
        self.norm = LayerNorm(config.hidden_size)

    def forward(self, x):
        # move channels last, normalize, move them back
        x = x.transpose(1, 2)
        x = self.norm(x)
        return x.transpose(1, 2)
|
65 |
+
|
66 |
+
|
67 |
+
class ProteinResNetBlock(nn.Module):
    """Residual block: two masked width-3 convolutions with layer norm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.conv1 = MaskedConv1d(width, width, 3, padding=1, bias=False)
        # "bn" names kept for state-dict compatibility; these are layer norms.
        self.bn1 = ProteinResNetLayerNorm(config)
        self.conv2 = MaskedConv1d(width, width, 3, padding=1, bias=False)
        self.bn2 = ProteinResNetLayerNorm(config)
        self.activation_fn = get_activation_fn(config.hidden_act)

    def forward(self, x, input_mask=None):
        residual = x

        out = self.activation_fn(self.bn1(self.conv1(x, input_mask)))
        out = self.bn2(self.conv2(out, input_mask))

        return self.activation_fn(out + residual)
|
95 |
+
|
96 |
+
|
97 |
+
class ProteinResNetEmbeddings(nn.Module):
    """Word embeddings plus sinusoidal position embeddings with LayerNorm/dropout."""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, embed_dim, padding_idx=0)
        inverse_frequency = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
        self.register_buffer('inverse_frequency', inverse_frequency)

        self.layer_norm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        token_embed = self.word_embeddings(input_ids)

        # Positions count DOWN from (seq_len - 1) to 0, i.e. distance from the
        # sequence end rather than the start.
        positions = torch.arange(
            input_ids.size(1) - 1, -1, -1.0,
            dtype=token_embed.dtype,
            device=token_embed.device)
        sinusoid = torch.ger(positions, self.inverse_frequency)
        pos_embed = torch.cat([sinusoid.sin(), sinusoid.cos()], -1).unsqueeze(0)

        combined = token_embed + pos_embed
        return self.dropout(self.layer_norm(combined))
|
126 |
+
|
127 |
+
|
128 |
+
class ProteinResNetPooler(nn.Module):
    """Pools per-position hidden states into one fixed-size vector.

    The strategy is chosen by ``config.temporal_pooling``; any unrecognized
    value (including the default 'attention') falls through to a learned
    attention-weighted mean followed by dense+tanh.
    """

    def __init__(self, config):
        super().__init__()
        self.attention_weights = nn.Linear(config.hidden_size, 1)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.temporal_pooling = config.temporal_pooling
        # light-attention pooling: two convs (attention / values) + MLP
        self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
        self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
        self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, mask=None):
        if self.temporal_pooling == 'mean':
            return hidden_states.mean(dim=1)
        if self.temporal_pooling == 'max':
            # BUGFIX: Tensor.max(dim=...) returns a (values, indices) namedtuple;
            # downstream heads expect a plain tensor, so return .values.
            return hidden_states.max(dim=1).values
        if self.temporal_pooling == 'concat':
            _temp = hidden_states.reshape(hidden_states.shape[0], -1)
            # NOTE(review): 2048 is a hard-coded pad target; longer inputs make
            # F.pad raise -- confirm upstream guarantees this bound.
            return torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))
        if self.temporal_pooling == 'meanmax':
            # BUGFIX: the original concatenated a tensor with the namedtuple
            # returned by max(dim=1), which raises at runtime; use .values.
            # NOTE(review): torch.cat along the default dim 0 stacks mean/max
            # along the batch axis; dim=-1 may have been intended -- confirm
            # against callers before changing.
            _mean = hidden_states.mean(dim=1)
            _max = hidden_states.max(dim=1).values
            return torch.cat([_mean, _max])
        if self.temporal_pooling == 'topmax':
            val, _ = torch.topk(hidden_states, k=5, dim=1)
            return val.mean(dim=1)
        if self.temporal_pooling == 'light_attention':
            _temp = hidden_states.permute(0, 2, 1)
            a = self._la_w1(_temp).softmax(dim=-1)
            v = self._la_w2(_temp)
            v_max = v.max(dim=-1).values
            v_sum = (a * v).sum(dim=-1)
            return self._la_mlp(torch.cat([v_max, v_sum], dim=1))

        # default: attention-weighted mean over positions
        attention_scores = self.attention_weights(hidden_states)
        if mask is not None:
            attention_scores += -10000. * (1 - mask)
        # NOTE(review): scores are (B, L, 1), so softmax over the last dim
        # normalizes a size-1 dimension and yields all-ones weights; dim=1 may
        # have been intended -- confirm before changing (checkpoints depend on
        # the current behavior).
        attention_weights = torch.softmax(attention_scores, -1)
        weighted_mean_embedding = torch.matmul(
            hidden_states.transpose(1, 2), attention_weights).squeeze(2)
        pooled_output = self.dense(weighted_mean_embedding)
        pooled_output = self.activation(pooled_output)
        return pooled_output
|
173 |
+
|
174 |
+
|
175 |
+
class ResNetEncoder(nn.Module):
    """Stack of residual conv blocks; optionally collects per-layer states."""

    def __init__(self, config):
        super().__init__()
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            ProteinResNetBlock(config) for _ in range(config.num_hidden_layers))

    def forward(self, hidden_states, input_mask=None):
        """Run every block in sequence.

        Returns ``(final_states,)`` or, when ``output_hidden_states`` is set,
        ``(final_states, (state_before_block_0, ..., final_states))``.
        """
        collected = []
        for block in self.layer:
            if self.output_hidden_states:
                collected.append(hidden_states)
            hidden_states = block(hidden_states, input_mask)

        if self.output_hidden_states:
            collected.append(hidden_states)
            return (hidden_states, tuple(collected))
        return (hidden_states,)
|
198 |
+
|
199 |
+
|
200 |
+
class ProteinResNetAbstractModel(ProteinModel):
    """Base class for ResNet protein models: shared config/archive wiring,
    weight initialization, and the pretrained-model loading interface.
    """
    config_class = ProteinResNetConfig
    pretrained_model_archive_map = RESNET_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "resnet"

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        """Normal-init embedding/linear weights, Kaiming-init convs,
        zero all biases."""
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
|
225 |
+
|
226 |
+
|
227 |
+
@registry.register_task_model('embed', 'resnet')
class ProteinResNetModel(ProteinResNetAbstractModel):
    """Bare ResNet encoder producing per-residue and pooled embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = ProteinResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)
        self.pooler = ProteinResNetPooler(config)
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        """Embed, encode, and pool.

        Returns ``(sequence_output, pooled_output)`` plus the encoder's
        optional hidden-state tuple.
        """
        # Build a float (batch, seq, 1) mask only when some positions are
        # actually padded; a trivially all-ones mask is skipped entirely.
        extended_input_mask = None
        if input_mask is not None and torch.any(input_mask != 1):
            extended_input_mask = input_mask.unsqueeze(2).to(
                dtype=next(self.parameters()).dtype)  # fp16 compatibility

        # The conv encoder works channels-first, so transpose in and out.
        embedding_output = self.embeddings(input_ids).transpose(1, 2)
        if extended_input_mask is not None:
            extended_input_mask = extended_input_mask.transpose(1, 2)
        encoder_outputs = self.encoder(embedding_output, extended_input_mask)
        sequence_output = encoder_outputs[0].transpose(1, 2).contiguous()
        if extended_input_mask is not None:
            extended_input_mask = extended_input_mask.transpose(1, 2)
        pooled_output = self.pooler(sequence_output, extended_input_mask)

        # sequence_output, pooled_output, (hidden_states)
        return (sequence_output, pooled_output) + encoder_outputs[1:]
|
265 |
+
|
266 |
+
|
267 |
+
@registry.register_task_model('masked_language_modeling', 'resnet')
class ProteinResNetForMaskedLM(ProteinResNetAbstractModel):
    """ResNet encoder with a masked-language-modeling head.

    The MLM decoder's weights are tied to (or cloned from) the input word
    embeddings via ``tie_weights``.
    """

    def __init__(self, config):
        super().__init__(config)

        self.resnet = ProteinResNetModel(config)
        self.mlm = MLMHead(
            config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
            ignore_index=-1)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.mlm.decoder,
                                   self.resnet.embeddings.word_embeddings)

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.resnet(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        # BUG FIX: append the trailing encoder outputs (optional hidden
        # states) rather than re-appending (sequence_output, pooled_output).
        # The original used ``outputs[:2]`` here, inconsistent with every
        # other head in this file and with the comment below.
        outputs = self.mlm(sequence_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs
|
299 |
+
|
300 |
+
|
301 |
+
@registry.register_task_model('fluorescence', 'resnet')
@registry.register_task_model('stability', 'resnet')
class ProteinResNetForValuePrediction(ProteinResNetAbstractModel):
    """ResNet encoder with a scalar regression head (fluorescence/stability)."""

    def __init__(self, config):
        super().__init__(config)
        self.resnet = ProteinResNetModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        # Keep the encoder in eval mode when its weights are treated as
        # frozen (disables dropout / norm statistics updates).
        if self.freeze_embedding:
            self.resnet.train(False)

        encoder_out = self.resnet(input_ids, input_mask=input_mask)
        pooled_output = encoder_out[1]
        # (loss), value_prediction, (hidden_states)
        return self.predict(pooled_output, targets) + encoder_out[2:]
|
323 |
+
|
324 |
+
|
325 |
+
@registry.register_task_model('remote_homology', 'resnet')
class ProteinResNetForSequenceClassification(ProteinResNetAbstractModel):
    """ResNet encoder with a whole-sequence classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.resnet = ProteinResNetModel(config)
        self.classify = SequenceClassificationHead(config.hidden_size, config.num_labels)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        # Keep the encoder in eval mode when its weights are treated as frozen.
        if self.freeze_embedding:
            self.resnet.train(False)

        encoder_out = self.resnet(input_ids, input_mask=input_mask)
        pooled_output = encoder_out[1]
        # (loss), class_logits, (hidden_states)
        return self.classify(pooled_output, targets) + encoder_out[2:]
|
347 |
+
|
348 |
+
|
349 |
+
@registry.register_task_model('secondary_structure', 'resnet')
class ProteinResNetForSequenceToSequenceClassification(ProteinResNetAbstractModel):
    """ResNet encoder with a per-residue classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.resnet = ProteinResNetModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_out = self.resnet(input_ids, input_mask=input_mask)
        sequence_output = encoder_out[0]
        # (loss), per_residue_logits, (hidden_states)
        return self.classify(sequence_output, targets) + encoder_out[2:]
|
369 |
+
|
370 |
+
|
371 |
+
@registry.register_task_model('contact_prediction', 'resnet')
class ProteinResNetForContactPrediction(ProteinResNetAbstractModel):
    """ResNet encoder with a pairwise contact-prediction head."""

    def __init__(self, config):
        super().__init__(config)
        self.resnet = ProteinResNetModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        encoder_out = self.resnet(input_ids, input_mask=input_mask)
        sequence_output = encoder_out[0]
        # (loss), contact_predictions, (hidden_states)
        return self.predict(sequence_output, protein_length, targets) + encoder_out[2:]
|
tape/models/modeling_trrosetta.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from ..registry import registry
|
5 |
+
from .modeling_utils import ProteinConfig
|
6 |
+
from .modeling_utils import ProteinModel
|
7 |
+
|
8 |
+
URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
|
9 |
+
TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
10 |
+
'xaa': URL_PREFIX + "trRosetta-xaa-pytorch_model.bin",
|
11 |
+
'xab': URL_PREFIX + "trRosetta-xab-pytorch_model.bin",
|
12 |
+
'xac': URL_PREFIX + "trRosetta-xac-pytorch_model.bin",
|
13 |
+
'xad': URL_PREFIX + "trRosetta-xad-pytorch_model.bin",
|
14 |
+
'xae': URL_PREFIX + "trRosetta-xae-pytorch_model.bin",
|
15 |
+
}
|
16 |
+
TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
17 |
+
'xaa': URL_PREFIX + "trRosetta-xaa-config.json",
|
18 |
+
'xab': URL_PREFIX + "trRosetta-xab-config.json",
|
19 |
+
'xac': URL_PREFIX + "trRosetta-xac-config.json",
|
20 |
+
'xad': URL_PREFIX + "trRosetta-xad-config.json",
|
21 |
+
'xae': URL_PREFIX + "trRosetta-xae-config.json",
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
class TRRosettaConfig(ProteinConfig):
    """Configuration for the trRosetta distance/orientation predictor.

    Args:
        num_features: channel width of the 2D residual trunk.
        kernel_size: conv kernel size inside each residual block.
        num_layers: number of dilated residual blocks in the trunk.
        dropout: dropout probability inside the trunk.
        msa_cutoff: sequence-identity fraction used as the threshold when
            reweighting MSA sequences (see MSAFeatureExtractor.reweight).
        penalty_coeff: regularization coefficient applied before inverting
            the covariance matrix (see extract_features_2d).
        initializer_range: stddev for normal weight initialization.
    """

    pretrained_config_archive_map = TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 num_features: int = 64,
                 kernel_size: int = 3,
                 num_layers: int = 61,
                 dropout: float = 0.15,
                 msa_cutoff: float = 0.8,
                 penalty_coeff: float = 4.5,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.msa_cutoff = msa_cutoff
        self.penalty_coeff = penalty_coeff
        self.initializer_range = initializer_range
|
46 |
+
|
47 |
+
|
48 |
+
class MSAFeatureExtractor(nn.Module):
    """Turns a one-hot MSA into the pairwise input features for trRosetta.

    Combines per-position (1D) profile features, tiled over rows and
    columns, with pairwise (2D) coevolution features derived from a
    regularized inverse covariance computation.
    """

    def __init__(self, config: TRRosettaConfig):
        super().__init__()
        self.msa_cutoff = config.msa_cutoff
        self.penalty_coeff = config.penalty_coeff

    def forward(self, msa1hot):
        """Compute the (batch, 526, seq, seq) feature map.

        Args:
            msa1hot: one-hot MSA tensor; indexing below implies the layout
                (batch, num_alignments, seq_len, num_symbols) — TODO confirm.
        """
        # Convert to float, then potentially back to half
        # These transforms aren't well suited to half-precision
        initial_type = msa1hot.dtype

        msa1hot = msa1hot.float()
        seqlen = msa1hot.size(2)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        # Tile the 1D features along rows and columns so each (i, j) pair
        # carries both residues' profiles next to the pairwise features.
        left = features_1d.unsqueeze(2).repeat(1, 1, seqlen, 1)
        right = features_1d.unsqueeze(1).repeat(1, seqlen, 1, 1)
        features = torch.cat((left, right, features_2d), -1)
        features = features.type(initial_type)
        # Channels-first layout for the downstream Conv2d trunk.
        features = features.permute(0, 3, 1, 2)
        features = features.contiguous()
        return features

    def reweight(self, msa1hot, eps=1e-9):
        """Down-weight redundant sequences: each sequence's weight is the
        reciprocal of the number of sequences at least ``msa_cutoff``
        identical to it (including itself)."""
        # Reweight
        seqlen = msa1hot.size(2)
        id_min = seqlen * self.msa_cutoff
        # Pairwise identity counts between all sequences in each MSA.
        id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
        id_mask = id_mtx > id_min
        weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        """Per-position features: query one-hot (20) + weighted PSSM (21)
        + per-position entropy (1) = 42 channels."""
        # 1D Features
        f1d_seq = msa1hot[:, 0, :, :20]
        batch_size = msa1hot.size(0)
        seqlen = msa1hot.size(2)

        # msa2pssm
        beff = weights.sum()
        f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff + 1e-9
        # Shannon entropy of the weighted profile at each position.
        h_i = (-f_i * f_i.log()).sum(2, keepdims=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=2)
        f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
        f1d = f1d.view(batch_size, seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        """Pairwise coevolution features (441 coupling channels + 1
        APC-corrected contact score = 442 channels) via a regularized
        inverse-covariance ("fast DCA") computation."""
        # 2D Features
        batch_size = msa1hot.size(0)
        num_alignments = msa1hot.size(1)
        seqlen = msa1hot.size(2)
        num_symbols = 21

        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(
                batch_size, seqlen, seqlen, 442,
                dtype=torch.float,
                device=msa1hot.device)
            return f2d_dca

        # compute fast_dca
        # covariance
        x = msa1hot.view(batch_size, num_alignments, seqlen * num_symbols)
        num_points = weights.sum(1) - weights.mean(1).sqrt()
        mean = (x * weights.unsqueeze(2)).sum(1, keepdims=True) / num_points[:, None, None]
        x = (x - mean) * weights[:, :, None].sqrt()
        cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]

        # inverse covariance
        # Ridge term scaled by penalty_coeff / sqrt(effective sequence count).
        reg = torch.eye(seqlen * num_symbols,
                        device=weights.device,
                        dtype=weights.dtype)[None]
        reg = reg * self.penalty_coeff / weights.sum(1, keepdims=True).sqrt().unsqueeze(2)
        cov_reg = cov + reg
        inv_cov = torch.stack([torch.inverse(cr) for cr in cov_reg.unbind(0)], 0)

        x1 = inv_cov.view(batch_size, seqlen, num_symbols, seqlen, num_symbols)
        x2 = x1.permute(0, 1, 3, 2, 4)
        features = x2.reshape(batch_size, seqlen, seqlen, num_symbols * num_symbols)

        # Frobenius norm of each (i, j) coupling block (gap symbol excluded),
        # with the diagonal zeroed.
        x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * (
            1 - torch.eye(seqlen, device=weights.device, dtype=weights.dtype)[None])
        # Average-product correction (APC) to remove background coupling.
        apc = x3.sum(1, keepdims=True) * x3.sum(2, keepdims=True) / x3.sum(
            (1, 2), keepdims=True)
        contacts = (x3 - apc) * (1 - torch.eye(
            seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))

        f2d_dca = torch.cat([features, contacts[:, :, :, None]], axis=3)
        return f2d_dca

    @property
    def feature_size(self) -> int:
        # 2 * 42 tiled 1D channels + 442 pairwise channels = 526.
        return 526
|
147 |
+
|
148 |
+
|
149 |
+
class DilatedResidualBlock(nn.Module):
    """Two dilated convs with instance norm, ELU and dropout, plus a
    residual skip connection."""

    def __init__(self, num_features: int, kernel_size: int, dilation: int, dropout: float):
        super().__init__()
        pad = self._get_padding(kernel_size, dilation)
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size,
                               padding=pad, dilation=dilation)
        self.norm1 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv1 = nn.ELU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size,
                               padding=pad, dilation=dilation)
        self.norm2 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv2 = nn.ELU(inplace=True)
        self.apply(self._init_weights)
        # Zero-init the second norm's scale so each block initially acts as
        # ELU(identity) on its input.
        nn.init.constant_(self.norm2.weight, 0)

    def _get_padding(self, kernel_size: int, dilation: int) -> int:
        # "Same" padding for an odd kernel at the given dilation.
        effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        return (effective - 1) // 2

    def _init_weights(self, module):
        """Kaiming-init conv weights; zero conv biases."""
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, features):
        residual = features
        out = self.dropout(self.actv1(self.norm1(self.conv1(features))))
        out = self.norm2(self.conv2(out))
        return self.actv2(out + residual)
|
189 |
+
|
190 |
+
|
191 |
+
class TRRosettaAbstractModel(ProteinModel):
    """Base class wiring TRRosetta config/archive maps and weight init."""

    config_class = TRRosettaConfig
    base_model_prefix = 'trrosetta'
    pretrained_model_archive_map = TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)

    def _init_weights(self, module):
        """Normal-init linears, Kaiming-init convs (biases zeroed), and
        zero-init each residual block's second norm scale."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            else:
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, DilatedResidualBlock):
            nn.init.constant_(module.norm2.weight, 0)
|
212 |
+
|
213 |
+
|
214 |
+
class TRRosettaPredictor(TRRosettaAbstractModel):
    """Dilated 2D-ResNet trunk with output heads over pairwise geometries.

    Consumes a (batch, 526, seq, seq) feature map and emits per-pair
    probability distributions for distance (37 bins), theta/omega (25 bins
    each) and phi (13 bins); when targets are supplied, also returns the
    summed cross-entropy loss and per-target loss metrics.
    """

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        # Stem: project the 526 pairwise channels down to the trunk width.
        layers = [
            nn.Conv2d(526, config.num_features, 1),
            nn.InstanceNorm2d(config.num_features, affine=True, eps=1e-6),
            nn.ELU(),
            nn.Dropout(config.dropout)]

        # Residual trunk; dilation cycles through 1, 2, 4, 8, 16.
        dilation = 1
        for _ in range(config.num_layers):
            block = DilatedResidualBlock(
                config.num_features, config.kernel_size, dilation, config.dropout)
            layers.append(block)

            dilation *= 2
            if dilation > 16:
                dilation = 1

        self.resnet = nn.Sequential(*layers)
        # 1x1-conv output heads; predict_bb (beta-strand pairing) is defined
        # but not used in forward below.
        self.predict_theta = nn.Conv2d(config.num_features, 25, 1)
        self.predict_phi = nn.Conv2d(config.num_features, 13, 1)
        self.predict_dist = nn.Conv2d(config.num_features, 37, 1)
        self.predict_bb = nn.Conv2d(config.num_features, 3, 1)
        self.predict_omega = nn.Conv2d(config.num_features, 25, 1)

        self.init_weights()

    def init_weights(self):
        # Base init, then zero the head weights so every head starts from a
        # uniform predicted distribution.
        self.apply(self._init_weights)
        nn.init.constant_(self.predict_theta.weight, 0)
        nn.init.constant_(self.predict_phi.weight, 0)
        nn.init.constant_(self.predict_dist.weight, 0)
        nn.init.constant_(self.predict_bb.weight, 0)
        nn.init.constant_(self.predict_omega.weight, 0)

    def forward(self,
                features,
                theta=None,
                phi=None,
                dist=None,
                omega=None):
        """Returns ``(probs,)`` or ``((total_loss, metrics), probs)`` when
        any of the target tensors is provided."""
        batch_size = features.size(0)
        seqlen = features.size(2)
        embedding = self.resnet(features)

        # Theta and phi are directional, so they are predicted from the
        # unsymmetrized embedding.
        # anglegrams for theta
        logits_theta = self.predict_theta(embedding)

        # anglegrams for phi
        logits_phi = self.predict_phi(embedding)

        # symmetrize
        sym_embedding = 0.5 * (embedding + embedding.transpose(-1, -2))

        # distograms
        logits_dist = self.predict_dist(sym_embedding)

        # beta-strand pairings (not used)
        # logits_bb = self.predict_bb(sym_embedding)

        # anglegrams for omega
        logits_omega = self.predict_omega(sym_embedding)

        # Move class bins to the last dimension: (batch, seq, seq, bins).
        logits_dist = logits_dist.permute(0, 2, 3, 1).contiguous()
        logits_theta = logits_theta.permute(0, 2, 3, 1).contiguous()
        logits_omega = logits_omega.permute(0, 2, 3, 1).contiguous()
        logits_phi = logits_phi.permute(0, 2, 3, 1).contiguous()

        probs = {}
        probs['p_dist'] = nn.Softmax(-1)(logits_dist)
        probs['p_theta'] = nn.Softmax(-1)(logits_theta)
        probs['p_omega'] = nn.Softmax(-1)(logits_omega)
        probs['p_phi'] = nn.Softmax(-1)(logits_phi)
        outputs = (probs,)

        metrics = {}
        total_loss = 0

        # Each supplied target adds its cross-entropy to the total.
        # NOTE(review): dist ignores label -1, while the angle targets ignore
        # label 0 — presumably bin 0 marks "no contact" there; confirm
        # against the dataset encoding.
        if dist is not None:
            logits_dist = logits_dist.reshape(batch_size * seqlen * seqlen, 37)
            loss_dist = nn.CrossEntropyLoss(ignore_index=-1)(logits_dist, dist.view(-1))
            metrics['dist'] = loss_dist
            total_loss += loss_dist
        if theta is not None:
            logits_theta = logits_theta.reshape(batch_size * seqlen * seqlen, 25)
            loss_theta = nn.CrossEntropyLoss(ignore_index=0)(logits_theta, theta.view(-1))
            metrics['theta'] = loss_theta
            total_loss += loss_theta
        if omega is not None:
            logits_omega = logits_omega.reshape(batch_size * seqlen * seqlen, 25)
            loss_omega = nn.CrossEntropyLoss(ignore_index=0)(logits_omega, omega.view(-1))
            metrics['omega'] = loss_omega
            total_loss += loss_omega
        if phi is not None:
            logits_phi = logits_phi.reshape(batch_size * seqlen * seqlen, 13)
            loss_phi = nn.CrossEntropyLoss(ignore_index=0)(logits_phi, phi.view(-1))
            metrics['phi'] = loss_phi
            total_loss += loss_phi

        if len(metrics) > 0:
            outputs = ((total_loss, metrics),) + outputs

        return outputs
|
319 |
+
|
320 |
+
|
321 |
+
@registry.register_task_model('trrosetta', 'trrosetta')
class TRRosetta(TRRosettaAbstractModel):
    """End-to-end trRosetta: MSA featurization followed by the 2D trunk."""

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        self.extract_features = MSAFeatureExtractor(config)
        self.trrosetta = TRRosettaPredictor(config)

    def forward(self, msa1hot, theta=None, phi=None, dist=None, omega=None):
        pairwise_features = self.extract_features(msa1hot)
        return self.trrosetta(pairwise_features, theta, phi, dist, omega)
|
tape/models/modeling_unirep.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import typing
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn.utils import weight_norm
|
6 |
+
|
7 |
+
from .modeling_utils import ProteinConfig
|
8 |
+
from .modeling_utils import ProteinModel
|
9 |
+
from .modeling_utils import ValuePredictionHead
|
10 |
+
from .modeling_utils import SequenceClassificationHead
|
11 |
+
from .modeling_utils import SequenceToSequenceClassificationHead
|
12 |
+
from .modeling_utils import PairwiseContactPredictionHead
|
13 |
+
from ..registry import registry
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
|
19 |
+
UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {
|
20 |
+
'babbler-1900': URL_PREFIX + 'unirep-base-config.json'}
|
21 |
+
UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {
|
22 |
+
'babbler-1900': URL_PREFIX + 'unirep-base-pytorch_model.bin'}
|
23 |
+
|
24 |
+
|
25 |
+
class UniRepConfig(ProteinConfig):
    """Configuration for the UniRep mLSTM (babbler-1900) model.

    Args:
        vocab_size: size of the amino-acid token vocabulary.
        input_size: embedding dimension fed into the mLSTM.
        hidden_size: mLSTM hidden/cell state dimension.
        hidden_dropout_prob: dropout probability (NOTE(review): stored but
            not read by any module visible in this file — confirm usage).
        layer_norm_eps: layer-norm epsilon (stored; usage not visible here).
        initializer_range: stddev for normal weight initialization.
    """
    pretrained_config_archive_map = UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 26,
                 input_size: int = 10,
                 hidden_size: int = 1900,
                 hidden_dropout_prob: float = 0.1,
                 layer_norm_eps: float = 1e-12,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
|
43 |
+
|
44 |
+
|
45 |
+
class mLSTMCell(nn.Module):
    """Single step of a multiplicative LSTM cell (weight-normalized)."""

    def __init__(self, config):
        super().__init__()
        gate_size = config.hidden_size * 4
        # Multiplicative path: m = (W_mx x) * (W_mh h_prev)
        self.wmx = weight_norm(
            nn.Linear(config.input_size, config.hidden_size, bias=False))
        self.wmh = weight_norm(
            nn.Linear(config.hidden_size, config.hidden_size, bias=False))
        # Gate projections from the input and the multiplicative state.
        self.wx = weight_norm(
            nn.Linear(config.input_size, gate_size, bias=False))
        self.wh = weight_norm(
            nn.Linear(config.hidden_size, gate_size, bias=True))

    def forward(self, inputs, state):
        h_prev, c_prev = state
        m = self.wmx(inputs) * self.wmh(h_prev)
        gates = self.wx(inputs) + self.wh(m)
        i_gate, f_gate, o_gate, update = torch.chunk(gates, 4, 1)
        c = torch.sigmoid(f_gate) * c_prev + torch.sigmoid(i_gate) * torch.tanh(update)
        h = torch.sigmoid(o_gate) * torch.tanh(c)
        return h, c
|
71 |
+
|
72 |
+
|
73 |
+
class mLSTM(nn.Module):
    """Unrolled mLSTM over a (batch, seq, input) tensor with optional masking."""

    def __init__(self, config):
        super().__init__()
        self.mlstm_cell = mLSTMCell(config)
        self.hidden_size = config.hidden_size

    def forward(self, inputs, state=None, mask=None):
        batch_size, seqlen = inputs.size(0), inputs.size(1)

        # Default to an all-ones mask; promote (batch, seq) masks to (batch, seq, 1).
        if mask is None:
            mask = torch.ones(batch_size, seqlen, 1,
                              dtype=inputs.dtype, device=inputs.device)
        elif mask.dim() == 2:
            mask = mask.unsqueeze(2)

        if state is None:
            zeros = torch.zeros(batch_size, self.hidden_size,
                                dtype=inputs.dtype, device=inputs.device)
            state = (zeros, zeros)

        per_step = []
        for t in range(seqlen):
            new_h, new_c = self.mlstm_cell(inputs[:, t, :], state)
            step_mask = mask[:, t]
            # Carry the previous state through masked (padding) positions.
            new_h = step_mask * new_h + (1 - step_mask) * state[0]
            new_c = step_mask * new_c + (1 - step_mask) * state[1]
            state = (new_h, new_c)
            per_step.append(new_h)

        return torch.stack(per_step, 1), state
|
106 |
+
|
107 |
+
|
108 |
+
class UniRepAbstractModel(ProteinModel):
    """Base class wiring UniRep config/archive maps and weight init."""

    config_class = UniRepConfig
    pretrained_model_archive_map = UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "unirep"

    def _init_weights(self, module):
        """Normal-init embedding/linear weights; zero linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
120 |
+
|
121 |
+
|
122 |
+
@registry.register_task_model('embed', 'unirep')
class UniRepModel(UniRepAbstractModel):
    """Bare UniRep mLSTM encoder: per-residue states plus a pooled vector."""

    def __init__(self, config: UniRepConfig):
        super().__init__(config)
        self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
        self.encoder = mLSTM(config)
        self.output_hidden_states = config.output_hidden_states
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)
        # Cast the mask to the parameter dtype for fp16 compatibility.
        input_mask = input_mask.to(dtype=next(self.parameters()).dtype)

        embeddings = self.embed_matrix(input_ids)
        sequence_output, final_state = self.encoder(embeddings, mask=input_mask)
        # Pooled representation: final hidden and cell states concatenated
        # along the feature dimension -> (batch, 2 * hidden_size).
        pooled_output = torch.cat(final_state, 1)

        return (sequence_output, pooled_output)
|
147 |
+
|
148 |
+
|
149 |
+
@registry.register_task_model('language_modeling', 'unirep')
class UniRepForLM(UniRepAbstractModel):
    """UniRep with a next-token language-modeling head.

    The head predicts over ``vocab_size - 1`` symbols; prediction at
    position t is scored against the token at position t + 1.
    """
    # TODO: Fix this for UniRep - UniRep changes the size of the targets

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size - 1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.feedforward(sequence_output)

        # add hidden states if they are here
        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            # Shift so token t predicts token t + 1.
            targets = targets[:, 1:]
            prediction_scores = prediction_scores[:, :-1]
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            # BUG FIX: the head emits vocab_size - 1 logits, so flatten with
            # vocab_size - 1 (the original used vocab_size, a guaranteed
            # runtime shape error). Use reshape, not view, because the
            # [:, :-1] / [:, 1:] slices are non-contiguous.
            lm_loss = loss_fct(
                prediction_scores.reshape(-1, self.config.vocab_size - 1),
                targets.reshape(-1))
            outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs
|
184 |
+
|
185 |
+
|
186 |
+
@registry.register_task_model('fluorescence', 'unirep')
@registry.register_task_model('stability', 'unirep')
class UniRepForValuePrediction(UniRepAbstractModel):
    """UniRep encoder with a scalar regression head.

    The head consumes the pooled (h, c) concatenation, hence the
    ``2 * hidden_size`` input width.
    """

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        self.predict = ValuePredictionHead(config.hidden_size * 2)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        encoder_out = self.unirep(input_ids, input_mask=input_mask)
        pooled_output = encoder_out[1]
        # (loss), value_prediction, (hidden_states)
        return self.predict(pooled_output, targets) + encoder_out[2:]
|
206 |
+
|
207 |
+
|
208 |
+
@registry.register_task_model('remote_homology', 'unirep')
class UniRepForSequenceClassification(UniRepAbstractModel):
    """UniRep encoder topped with a whole-sequence classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        # Pooled representation concatenates hidden states, hence the
        # doubled input width.
        self.classify = SequenceClassificationHead(
            config.hidden_size * 2, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Classify the full sequence from its pooled representation."""
        encoder_outputs = self.unirep(input_ids, input_mask=input_mask)
        pooled = encoder_outputs[1]
        extras = encoder_outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return self.classify(pooled, targets) + extras
|
228 |
+
|
229 |
+
|
230 |
+
@registry.register_task_model('secondary_structure', 'unirep')
class UniRepForSequenceToSequenceClassification(UniRepAbstractModel):
    """UniRep encoder topped with a per-position classification head."""

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Predict a label at every sequence position."""
        encoder_outputs = self.unirep(input_ids, input_mask=input_mask)
        per_position = encoder_outputs[0]
        extras = encoder_outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return self.classify(per_position, targets) + extras
|
250 |
+
|
251 |
+
|
252 |
+
@registry.register_task_model('contact_prediction', 'unirep')
class UniRepForContactPrediction(UniRepAbstractModel):
    """UniRep encoder topped with a pairwise residue-contact head."""

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        """Predict residue-residue contacts from per-position features."""
        encoder_outputs = self.unirep(input_ids, input_mask=input_mask)
        per_position = encoder_outputs[0]
        extras = encoder_outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return self.predict(per_position, protein_length, targets) + extras
|
tape/models/modeling_utils.py
ADDED
@@ -0,0 +1,887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
# Modified by Roshan Rao
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""PyTorch Protein models."""
|
18 |
+
from __future__ import (absolute_import, division, print_function,
|
19 |
+
unicode_literals)
|
20 |
+
import typing
|
21 |
+
import copy
|
22 |
+
import json
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
from io import open
|
26 |
+
import math
|
27 |
+
from torch.nn.utils.weight_norm import weight_norm
|
28 |
+
|
29 |
+
import torch
|
30 |
+
from torch import nn
|
31 |
+
import torch.nn.functional as F
|
32 |
+
|
33 |
+
from .file_utils import cached_path
|
34 |
+
|
35 |
+
CONFIG_NAME = "config.json"
|
36 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
|
41 |
+
class ProteinConfig(object):
    """ Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods
        for loading/downloading/saving configurations.

        Class attributes (overridden by derived classes):
            - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names`
                (string) as keys and `url` (string) of associated pretrained model
                configurations as values.

        Parameters:
            ``finetuning_task``: string, default `None`. Name of the task used to fine-tune
                the model.
            ``num_labels``: integer, default `2`. Number of classes to use when the model is
                a classification model (sequences/tokens)
            ``output_attentions``: boolean, default `False`. Should the model returns
                attentions weights.
            ``output_hidden_states``: string, default `False`. Should the model returns all
                hidden-states.
            ``torchscript``: string, default `False`. Is the model used with Torchscript.
    """
    # Shortcut-name -> URL map; derived config classes override this.
    pretrained_config_archive_map: typing.Dict[str, str] = {}

    def __init__(self, **kwargs):
        # Pop the options common to every configuration; subclasses consume
        # the remaining kwargs. Unrecognized keys are silently ignored here.
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)

    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
            can be re-loaded using the :func:`~ProteinConfig.from_pretrained`
            class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the " \
                                              "model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~ProteinConfig`
             (or a derived class) from a pre-trained model configuration.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model configuration to
                  load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing a configuration file saved using the
                  :func:`~ProteinConfig.save_pretrained` method,
                  e.g.: ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`,
                  e.g.: ``./my_model_directory/configuration.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict:
                key/value pairs with which to update the configuration object after loading.

                - The values in kwargs of any keys which are configuration attributes will
                  be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* configuration
                  attributes is controlled by the `return_unused_kwargs` keyword parameter.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final configuration object.
                - If True, then this functions returns a tuple `(config, unused_kwargs)`
                  where `unused_kwargs` is a dictionary consisting of the key/value pairs
                  whose keys are not configuration attributes: ie the part of kwargs which
                  has not been used to update `config` and is otherwise ignored.

        Examples::

            # We can't instantiate directly the base class `ProteinConfig` so let's
              show the examples on a derived class: ProteinBertConfig
            # Download configuration from S3 and cache.
            config = ProteinBertConfig.from_pretrained('bert-base-uncased')
            # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = ProteinBertConfig.from_pretrained('./test/saved_model/')
            config = ProteinBertConfig.from_pretrained(
                './test/saved_model/my_configuration.json')
            config = ProteinBertConfig.from_pretrained(
                'bert-base-uncased', output_attention=True, foo=False)
            assert config.output_attention == True
            config, unused_kwargs = BertConfig.from_pretrained(
                'bert-base-uncased', output_attention=True,
                foo=False, return_unused_kwargs=True)
            assert config.output_attention == True
            assert unused_kwargs == {'foo': False}

        """
        cache_dir = kwargs.pop('cache_dir', None)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

        # Resolve the argument to a concrete config-file location:
        # shortcut name -> archive URL, directory -> CONFIG_NAME inside it,
        # otherwise treat it as a direct path/URL.
        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        else:
            config_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                logger.error("Couldn't reach server at '{}' to download pretrained model "
                             "configuration file.".format(config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file))
            # NOTE(review): failures return None instead of raising; callers
            # must check for None.
            return None
        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))

        # Load config
        config = cls.from_json_file(resolved_config_file)

        # Update config with kwargs if needed: kwargs matching existing
        # attributes override the loaded values and are consumed; the rest
        # stay in `kwargs` for the return_unused_kwargs path.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", config)
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `Config` from a Python dictionary of parameters."""
        # -1 is a legacy placeholder accepted by subclass constructors;
        # every attribute is overwritten from json_object below.
        # NOTE(review): the base-class __init__ simply ignores this kwarg.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __eq__(self, other):
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable (Python 3 sets __hash__ to None).
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
|
223 |
+
|
224 |
+
|
225 |
+
class ProteinModel(nn.Module):
|
226 |
+
r""" Base class for all models.
|
227 |
+
|
228 |
+
:class:`~ProteinModel` takes care of storing the configuration of
|
229 |
+
the models and handles methods for loading/downloading/saving models as well as a
|
230 |
+
few methods commons to all models to (i) resize the input embeddings and (ii) prune
|
231 |
+
heads in the self-attention heads.
|
232 |
+
|
233 |
+
Class attributes (overridden by derived classes):
|
234 |
+
- ``config_class``: a class derived from :class:`~ProteinConfig`
|
235 |
+
to use as configuration class for this model architecture.
|
236 |
+
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names`
|
237 |
+
(string) as keys and `url` (string) of associated pretrained weights as values.
|
238 |
+
|
239 |
+
- ``base_model_prefix``: a string indicating the attribute associated to the
|
240 |
+
base model in derived classes of the same architecture adding modules on top
|
241 |
+
of the base model.
|
242 |
+
"""
|
243 |
+
config_class: typing.Type[ProteinConfig] = ProteinConfig
|
244 |
+
pretrained_model_archive_map: typing.Dict[str, str] = {}
|
245 |
+
base_model_prefix = ""
|
246 |
+
|
247 |
+
def __init__(self, config, *inputs, **kwargs):
|
248 |
+
super().__init__()
|
249 |
+
if not isinstance(config, ProteinConfig):
|
250 |
+
raise ValueError(
|
251 |
+
"Parameter config in `{}(config)` should be an instance of class "
|
252 |
+
"`ProteinConfig`. To create a model from a pretrained model use "
|
253 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
254 |
+
self.__class__.__name__, self.__class__.__name__
|
255 |
+
))
|
256 |
+
# Save config in model
|
257 |
+
self.config = config
|
258 |
+
|
259 |
+
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
260 |
+
""" Build a resized Embedding Module from a provided token Embedding Module.
|
261 |
+
Increasing the size will add newly initialized vectors at the end
|
262 |
+
Reducing the size will remove vectors from the end
|
263 |
+
|
264 |
+
Args:
|
265 |
+
new_num_tokens: (`optional`) int
|
266 |
+
New number of tokens in the embedding matrix.
|
267 |
+
Increasing the size will add newly initialized vectors at the end
|
268 |
+
Reducing the size will remove vectors from the end
|
269 |
+
If not provided or None: return the provided token Embedding Module.
|
270 |
+
Return: ``torch.nn.Embeddings``
|
271 |
+
Pointer to the resized Embedding Module or the old Embedding Module if
|
272 |
+
new_num_tokens is None
|
273 |
+
"""
|
274 |
+
if new_num_tokens is None:
|
275 |
+
return old_embeddings
|
276 |
+
|
277 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
278 |
+
if old_num_tokens == new_num_tokens:
|
279 |
+
return old_embeddings
|
280 |
+
|
281 |
+
# Build new embeddings
|
282 |
+
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
283 |
+
new_embeddings.to(old_embeddings.weight.device)
|
284 |
+
|
285 |
+
# initialize all new embeddings (in particular added tokens)
|
286 |
+
self.init_weights(new_embeddings)
|
287 |
+
|
288 |
+
# Copy word embeddings from the previous weights
|
289 |
+
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
290 |
+
new_embeddings.weight.data[:num_tokens_to_copy, :] = \
|
291 |
+
old_embeddings.weight.data[:num_tokens_to_copy, :]
|
292 |
+
|
293 |
+
return new_embeddings
|
294 |
+
|
295 |
+
def _tie_or_clone_weights(self, first_module, second_module):
|
296 |
+
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
297 |
+
"""
|
298 |
+
if self.config.torchscript:
|
299 |
+
first_module.weight = nn.Parameter(second_module.weight.clone())
|
300 |
+
else:
|
301 |
+
first_module.weight = second_module.weight
|
302 |
+
|
303 |
+
def resize_token_embeddings(self, new_num_tokens=None):
|
304 |
+
""" Resize input token embeddings matrix of the model if
|
305 |
+
new_num_tokens != config.vocab_size. Take care of tying weights embeddings
|
306 |
+
afterwards if the model class has a `tie_weights()` method.
|
307 |
+
|
308 |
+
Arguments:
|
309 |
+
|
310 |
+
new_num_tokens: (`optional`) int:
|
311 |
+
New number of tokens in the embedding matrix. Increasing the size will add
|
312 |
+
newly initialized vectors at the end. Reducing the size will remove vectors
|
313 |
+
from the end. If not provided or None: does nothing and just returns a
|
314 |
+
pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
315 |
+
|
316 |
+
Return: ``torch.nn.Embeddings``
|
317 |
+
Pointer to the input tokens Embeddings Module of the model
|
318 |
+
"""
|
319 |
+
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
320 |
+
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
321 |
+
if new_num_tokens is None:
|
322 |
+
return model_embeds
|
323 |
+
|
324 |
+
# Update base model and current model config
|
325 |
+
self.config.vocab_size = new_num_tokens
|
326 |
+
base_model.vocab_size = new_num_tokens
|
327 |
+
|
328 |
+
# Tie weights again if needed
|
329 |
+
if hasattr(self, 'tie_weights'):
|
330 |
+
self.tie_weights()
|
331 |
+
|
332 |
+
return model_embeds
|
333 |
+
|
334 |
+
def init_weights(self):
|
335 |
+
""" Initialize and prunes weights if needed. """
|
336 |
+
# Initialize weights
|
337 |
+
self.apply(self._init_weights)
|
338 |
+
|
339 |
+
# Prune heads if needed
|
340 |
+
if getattr(self.config, 'pruned_heads', False):
|
341 |
+
self.prune_heads(self.config.pruned_heads)
|
342 |
+
|
343 |
+
def prune_heads(self, heads_to_prune):
|
344 |
+
""" Prunes heads of the base model.
|
345 |
+
|
346 |
+
Arguments:
|
347 |
+
|
348 |
+
heads_to_prune: dict with keys being selected layer indices (`int`) and
|
349 |
+
associated values being the list of heads to prune in said layer
|
350 |
+
(list of `int`).
|
351 |
+
"""
|
352 |
+
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
353 |
+
base_model._prune_heads(heads_to_prune)
|
354 |
+
|
355 |
+
def save_pretrained(self, save_directory):
|
356 |
+
""" Save a model and its configuration file to a directory, so that it
|
357 |
+
can be re-loaded using the `:func:`~ProteinModel.from_pretrained`
|
358 |
+
` class method.
|
359 |
+
"""
|
360 |
+
assert os.path.isdir(save_directory), "Saving path should be a directory where "\
|
361 |
+
"the model and configuration can be saved"
|
362 |
+
|
363 |
+
# Only save the model it-self if we are using distributed training
|
364 |
+
model_to_save = self.module if hasattr(self, 'module') else self
|
365 |
+
|
366 |
+
# Save configuration file
|
367 |
+
model_to_save.config.save_pretrained(save_directory)
|
368 |
+
|
369 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
370 |
+
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
371 |
+
|
372 |
+
torch.save(model_to_save.state_dict(), output_model_file)
|
373 |
+
|
374 |
+
@classmethod
|
375 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
376 |
+
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
377 |
+
|
378 |
+
The model is set in evaluation mode by default using ``model.eval()``
|
379 |
+
(Dropout modules are deactivated)
|
380 |
+
To train the model, you should first set it back in training mode with ``model.train()``
|
381 |
+
|
382 |
+
The warning ``Weights from XXX not initialized from pretrained model`` means that
|
383 |
+
the weights of XXX do not come pre-trained with the rest of the model.
|
384 |
+
It is up to you to train those weights with a downstream fine-tuning task.
|
385 |
+
|
386 |
+
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used
|
387 |
+
by YYY, therefore those weights are discarded.
|
388 |
+
|
389 |
+
Parameters:
|
390 |
+
pretrained_model_name_or_path: either:
|
391 |
+
|
392 |
+
- a string with the `shortcut name` of a pre-trained model to load from cache
|
393 |
+
or download, e.g.: ``bert-base-uncased``.
|
394 |
+
- a path to a `directory` containing model weights saved using
|
395 |
+
:func:`~ProteinModel.save_pretrained`,
|
396 |
+
e.g.: ``./my_model_directory/``.
|
397 |
+
|
398 |
+
model_args: (`optional`) Sequence of positional arguments:
|
399 |
+
All remaning positional arguments will be passed to the underlying model's
|
400 |
+
``__init__`` method
|
401 |
+
|
402 |
+
config: (`optional`) instance of a class derived from
|
403 |
+
:class:`~ProteinConfig`: Configuration for the model to
|
404 |
+
use instead of an automatically loaded configuation. Configuration can be
|
405 |
+
automatically loaded when:
|
406 |
+
|
407 |
+
- the model is a model provided by the library (loaded with the
|
408 |
+
``shortcut-name`` string of a pretrained model), or
|
409 |
+
- the model was saved using
|
410 |
+
:func:`~ProteinModel.save_pretrained` and is reloaded
|
411 |
+
by suppling the save directory.
|
412 |
+
- the model is loaded by suppling a local directory as
|
413 |
+
``pretrained_model_name_or_path`` and a configuration JSON file named
|
414 |
+
`config.json` is found in the directory.
|
415 |
+
|
416 |
+
state_dict: (`optional`) dict:
|
417 |
+
an optional state dictionnary for the model to use instead of a state
|
418 |
+
dictionary loaded from saved weights file. This option can be used if you
|
419 |
+
want to create a model from a pretrained configuration but load your own
|
420 |
+
weights. In this case though, you should check if using
|
421 |
+
:func:`~ProteinModel.save_pretrained` and
|
422 |
+
:func:`~ProteinModel.from_pretrained` is not a
|
423 |
+
simpler option.
|
424 |
+
|
425 |
+
cache_dir: (`optional`) string:
|
426 |
+
Path to a directory in which a downloaded pre-trained model
|
427 |
+
configuration should be cached if the standard cache should not be used.
|
428 |
+
|
429 |
+
force_download: (`optional`) boolean, default False:
|
430 |
+
Force to (re-)download the model weights and configuration files and override
|
431 |
+
the cached versions if they exists.
|
432 |
+
|
433 |
+
resume_download: (`optional`) boolean, default False:
|
434 |
+
Do not delete incompletely recieved file. Attempt to resume the download if
|
435 |
+
such a file exists.
|
436 |
+
|
437 |
+
output_loading_info: (`optional`) boolean:
|
438 |
+
Set to ``True`` to also return a dictionnary containing missing keys,
|
439 |
+
unexpected keys and error messages.
|
440 |
+
|
441 |
+
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
442 |
+
Can be used to update the configuration object (after it being loaded) and
|
443 |
+
initiate the model. (e.g. ``output_attention=True``). Behave differently
|
444 |
+
depending on whether a `config` is provided or automatically loaded:
|
445 |
+
|
446 |
+
- If a configuration is provided with ``config``, ``**kwarg
|
447 |
+
directly passed to the underlying model's ``__init__`` method (we assume
|
448 |
+
all relevant updates to the configuration have already been done)
|
449 |
+
- If a configuration is not provided, ``kwargs`` will be first passed to the
|
450 |
+
configuration class initialization function
|
451 |
+
(:func:`~ProteinConfig.from_pretrained`). Each key of
|
452 |
+
``kwargs`` that corresponds to a configuration attribute will be used to
|
453 |
+
override said attribute with the supplied ``kwargs`` value. Remaining keys
|
454 |
+
that do not correspond to any configuration attribute will be passed to the
|
455 |
+
underlying model's ``__init__`` function.
|
456 |
+
|
457 |
+
Examples::
|
458 |
+
|
459 |
+
# Download model and configuration from S3 and cache.
|
460 |
+
model = ProteinBertModel.from_pretrained('bert-base-uncased')
|
461 |
+
# E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
462 |
+
model = ProteinBertModel.from_pretrained('./test/saved_model/')
|
463 |
+
# Update configuration during loading
|
464 |
+
model = ProteinBertModel.from_pretrained('bert-base-uncased', output_attention=True)
|
465 |
+
assert model.config.output_attention == True
|
466 |
+
|
467 |
+
"""
|
468 |
+
config = kwargs.pop('config', None)
|
469 |
+
state_dict = kwargs.pop('state_dict', None)
|
470 |
+
cache_dir = kwargs.pop('cache_dir', None)
|
471 |
+
output_loading_info = kwargs.pop('output_loading_info', False)
|
472 |
+
|
473 |
+
force_download = kwargs.pop("force_download", False)
|
474 |
+
kwargs.pop("resume_download", False)
|
475 |
+
|
476 |
+
# Load config
|
477 |
+
if config is None:
|
478 |
+
config, model_kwargs = cls.config_class.from_pretrained(
|
479 |
+
pretrained_model_name_or_path, *model_args,
|
480 |
+
cache_dir=cache_dir, return_unused_kwargs=True,
|
481 |
+
# force_download=force_download,
|
482 |
+
# resume_download=resume_download,
|
483 |
+
**kwargs
|
484 |
+
)
|
485 |
+
else:
|
486 |
+
model_kwargs = kwargs
|
487 |
+
|
488 |
+
# Load model
|
489 |
+
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
490 |
+
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
491 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
492 |
+
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
493 |
+
else:
|
494 |
+
archive_file = pretrained_model_name_or_path
|
495 |
+
# redirect to the cache, if necessary
|
496 |
+
try:
|
497 |
+
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir,
|
498 |
+
force_download=force_download)
|
499 |
+
except EnvironmentError:
|
500 |
+
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
501 |
+
logger.error(
|
502 |
+
"Couldn't reach server at '{}' to download pretrained weights.".format(
|
503 |
+
archive_file))
|
504 |
+
else:
|
505 |
+
logger.error(
|
506 |
+
"Model name '{}' was not found in model name list ({}). "
|
507 |
+
"We assumed '{}' was a path or url but couldn't find any file "
|
508 |
+
"associated to this path or url.".format(
|
509 |
+
pretrained_model_name_or_path,
|
510 |
+
', '.join(cls.pretrained_model_archive_map.keys()),
|
511 |
+
archive_file))
|
512 |
+
return None
|
513 |
+
if resolved_archive_file == archive_file:
|
514 |
+
logger.info("loading weights file {}".format(archive_file))
|
515 |
+
else:
|
516 |
+
logger.info("loading weights file {} from cache at {}".format(
|
517 |
+
archive_file, resolved_archive_file))
|
518 |
+
|
519 |
+
# Instantiate model.
|
520 |
+
model = cls(config, *model_args, **model_kwargs)
|
521 |
+
|
522 |
+
if state_dict is None:
|
523 |
+
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
524 |
+
|
525 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
526 |
+
old_keys = []
|
527 |
+
new_keys = []
|
528 |
+
for key in state_dict.keys():
|
529 |
+
new_key = None
|
530 |
+
if 'gamma' in key:
|
531 |
+
new_key = key.replace('gamma', 'weight')
|
532 |
+
if 'beta' in key:
|
533 |
+
new_key = key.replace('beta', 'bias')
|
534 |
+
if new_key:
|
535 |
+
old_keys.append(key)
|
536 |
+
new_keys.append(new_key)
|
537 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
538 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
539 |
+
|
540 |
+
# Load from a PyTorch state_dict
|
541 |
+
missing_keys = []
|
542 |
+
unexpected_keys = []
|
543 |
+
error_msgs = []
|
544 |
+
# copy state_dict so _load_from_state_dict can modify it
|
545 |
+
metadata = getattr(state_dict, '_metadata', None)
|
546 |
+
state_dict = state_dict.copy()
|
547 |
+
if metadata is not None:
|
548 |
+
state_dict._metadata = metadata
|
549 |
+
|
550 |
+
def load(module, prefix=''):
|
551 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
552 |
+
module._load_from_state_dict(
|
553 |
+
state_dict, prefix, local_metadata, True, missing_keys,
|
554 |
+
unexpected_keys, error_msgs)
|
555 |
+
for name, child in module._modules.items():
|
556 |
+
if child is not None:
|
557 |
+
load(child, prefix + name + '.')
|
558 |
+
|
559 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
560 |
+
start_prefix = ''
|
561 |
+
model_to_load = model
|
562 |
+
if cls.base_model_prefix not in (None, ''):
|
563 |
+
if not hasattr(model, cls.base_model_prefix) and \
|
564 |
+
any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
|
565 |
+
start_prefix = cls.base_model_prefix + '.'
|
566 |
+
if hasattr(model, cls.base_model_prefix) and \
|
567 |
+
not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
|
568 |
+
model_to_load = getattr(model, cls.base_model_prefix)
|
569 |
+
|
570 |
+
load(model_to_load, prefix=start_prefix)
|
571 |
+
if len(missing_keys) > 0:
|
572 |
+
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
573 |
+
model.__class__.__name__, missing_keys))
|
574 |
+
if len(unexpected_keys) > 0:
|
575 |
+
logger.info("Weights from pretrained model not used in {}: {}".format(
|
576 |
+
model.__class__.__name__, unexpected_keys))
|
577 |
+
if len(error_msgs) > 0:
|
578 |
+
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
579 |
+
model.__class__.__name__, "\n\t".join(error_msgs)))
|
580 |
+
|
581 |
+
if hasattr(model, 'tie_weights'):
|
582 |
+
model.tie_weights() # make sure word embedding weights are still tied
|
583 |
+
|
584 |
+
# Set model in evaluation mode to desactivate DropOut modules by default
|
585 |
+
model.eval()
|
586 |
+
|
587 |
+
if output_loading_info:
|
588 |
+
loading_info = {
|
589 |
+
"missing_keys": missing_keys,
|
590 |
+
"unexpected_keys": unexpected_keys,
|
591 |
+
"error_msgs": error_msgs}
|
592 |
+
return model, loading_info
|
593 |
+
|
594 |
+
return model
|
595 |
+
|
596 |
+
|
597 |
+
def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

        Args:
            layer (nn.Linear): the layer to prune.
            index (torch.LongTensor): indices of rows (dim=0) or columns (dim=1)
                of the weight matrix to keep.
            dim (int): dimension of the weight matrix along which to prune.
                dim=0 prunes output features (and the bias); dim=1 prunes input
                features (bias is kept whole).

        Returns:
            nn.Linear: a new, smaller layer on the same device as ``layer``.
    """
    index = index.to(layer.weight.device)
    # Detach the selected slice so it carries no autograd history.
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            # Pruning input features leaves the output (and hence the bias) intact.
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(
        new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    # Temporarily disable requires_grad so the in-place copy_ is not recorded
    # by autograd, then re-enable it so the pruned layer remains trainable.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
+
|
622 |
+
|
623 |
+
def accuracy(logits, labels, ignore_index: int = -100):
    """Return the fraction of positions where ``argmax(logits)`` equals ``labels``.

    Positions whose label equals ``ignore_index`` are excluded from both the
    numerator and the denominator. Computed under ``no_grad``.
    """
    with torch.no_grad():
        keep = labels != ignore_index
        preds = logits.float().argmax(-1)
        num_correct = ((preds == labels) & keep).sum().float()
        return num_correct / keep.sum().float()
+
|
630 |
+
|
631 |
+
def gelu(x):
    """Exact GELU activation computed via the Gaussian error function.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    # Phi(x), the standard normal CDF, expressed with erf.
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf
+
|
640 |
+
|
641 |
+
def swish(x):
    """Swish activation: x * sigmoid(x) (also known as SiLU)."""
    return torch.sigmoid(x) * x
+
|
644 |
+
|
645 |
+
def get_activation_fn(name: str) -> typing.Callable:
    """Look up an activation function by name.

    Supported names: ``'gelu'``, ``'relu'``, ``'swish'``.

    Raises:
        ValueError: if ``name`` is not one of the supported activations.
    """
    if name == 'relu':
        return torch.nn.functional.relu
    if name == 'gelu':
        return gelu
    if name == 'swish':
        return swish
    raise ValueError(f"Unrecognized activation fn: {name}")
+
|
655 |
+
|
656 |
+
# Prefer NVIDIA apex's fused LayerNorm kernel when available; otherwise fall
# back to a pure-PyTorch implementation with the same interface.
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm  # type: ignore
except (ImportError, AttributeError):
    logger.info("Better speed can be achieved with apex installed from "
                "https://www.github.com/nvidia/apex .")

    class LayerNorm(nn.Module):  # type: ignore
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            # Normalize over the last dimension using the biased variance
            # (mean of squared deviations, no Bessel correction), as in TF.
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
+
|
677 |
+
|
678 |
+
class SimpleMLP(nn.Module):
    """Two-layer perceptron with weight-normalized linear layers.

    Structure: Linear(in -> hid) -> ReLU -> Dropout -> Linear(hid -> out).
    """

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 dropout: float = 0.):
        super().__init__()
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None),
        ]
        # Keep the attribute name 'main' so state-dict keys stay stable.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
+
|
695 |
+
|
696 |
+
class SimpleConv(nn.Module):
    """Two-layer 1D convolutional network over a sequence of feature vectors.

    Input/output are (batch, seq_len, features); channels are transposed
    internally because Conv1d expects (batch, channels, seq_len).
    """

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 dropout: float = 0.):
        super().__init__()
        # Keep the attribute name 'main' so state-dict keys stay stable.
        self.main = nn.Sequential(
            nn.BatchNorm1d(in_dim),
            weight_norm(nn.Conv1d(in_dim, hid_dim, 5, padding=2), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=True),
            weight_norm(nn.Conv1d(hid_dim, out_dim, 3, padding=1), dim=None),
        )

    def forward(self, x):
        # (B, L, C) -> (B, C, L) for the conv stack, then back.
        features = x.transpose(1, 2)
        features = self.main(features)
        return features.transpose(1, 2).contiguous()
+
|
717 |
+
|
718 |
+
class Accuracy(nn.Module):
    """Module wrapper around the functional ``accuracy`` metric.

    Stores an ``ignore_index`` so the module can be dropped into code that
    expects a callable metric taking only (inputs, target).
    """

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, target):
        # Delegate to the functional implementation with the stored ignore value.
        return accuracy(inputs, target, ignore_index=self.ignore_index)
+
|
727 |
+
|
728 |
+
class PredictionHeadTransform(nn.Module):
    """BERT-style prediction-head transform: Dense -> activation -> LayerNorm.

    ``hidden_act`` may be either the name of a registered activation
    ('gelu', 'relu', 'swish') or an arbitrary callable.
    """

    def __init__(self,
                 hidden_size: int,
                 hidden_act: typing.Union[str, typing.Callable] = 'gelu',
                 layer_norm_eps: float = 1e-12):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # Resolve string names to callables once, at construction time.
        self.transform_act_fn = (
            get_activation_fn(hidden_act) if isinstance(hidden_act, str) else hidden_act)
        self.LayerNorm = LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
+
|
748 |
+
|
749 |
+
class MLMHead(nn.Module):
    """Masked language modeling head: transform -> vocabulary projection.

    The decoder weight is created without a bias so that it can be tied to the
    input embedding matrix elsewhere; a separate per-token output bias is kept.
    """

    def __init__(self,
                 hidden_size: int,
                 vocab_size: int,
                 hidden_act: typing.Union[str, typing.Callable] = 'gelu',
                 layer_norm_eps: float = 1e-12,
                 ignore_index: int = -100):
        super().__init__()
        self.transform = PredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(data=torch.zeros(vocab_size))  # type: ignore
        self.vocab_size = vocab_size
        self._ignore_index = ignore_index

    def forward(self, hidden_states, targets=None):
        """Return ``(prediction_scores,)`` or ``((loss, metrics), prediction_scores)``.

        ``targets`` positions equal to ``ignore_index`` are excluded from the loss.
        """
        hidden_states = self.transform(hidden_states)
        # Bias is added explicitly (not inside the Linear) to allow weight tying.
        hidden_states = self.decoder(hidden_states) + self.bias
        outputs = (hidden_states,)
        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
            masked_lm_loss = loss_fct(
                hidden_states.reshape(-1, self.vocab_size), targets.reshape(-1))
            # exp(mean cross-entropy) is the batch perplexity.
            metrics = {'perplexity': torch.exp(masked_lm_loss)}
            loss_and_metrics = (masked_lm_loss, metrics)
            outputs = (loss_and_metrics,) + outputs
        return outputs  # (loss), prediction_scores
+
|
780 |
+
|
781 |
+
class ValuePredictionHead(nn.Module):
    """Regression head mapping a pooled representation to a single scalar value."""

    def __init__(self, hidden_size: int, dropout: float = 0.):
        super().__init__()
        self.value_prediction = SimpleMLP(hidden_size, 512, 1, dropout)

    def forward(self, pooled_output, targets=None):
        """Return ``(value_pred,)`` or ``(mse_loss, value_pred)`` when targets given."""
        value_pred = self.value_prediction(pooled_output)
        if targets is None:
            return (value_pred,)
        value_pred_loss = nn.MSELoss()(value_pred, targets)
        return (value_pred_loss, value_pred)
+
|
796 |
+
|
797 |
+
class SequenceClassificationHead(nn.Module):
    """Whole-sequence classification head: MLP over the pooled representation."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classify = SimpleMLP(hidden_size, 512, num_labels)

    def forward(self, pooled_output, targets=None):
        """Return ``(logits,)`` or ``((loss, metrics), logits)`` when targets given."""
        logits = self.classify(pooled_output)
        if targets is None:
            return (logits,)
        classification_loss = nn.CrossEntropyLoss()(logits, targets)
        metrics = {'accuracy': accuracy(logits, targets)}
        return ((classification_loss, metrics), logits)
+
|
815 |
+
|
816 |
+
class SequenceToSequenceClassificationHead(nn.Module):
    """Per-position classification head (e.g. secondary structure prediction).

    Applies a small convolutional classifier over each sequence position and,
    when targets are given, computes a position-wise cross-entropy loss that
    skips positions labeled with ``ignore_index``.
    """

    def __init__(self,
                 hidden_size: int,
                 num_labels: int,
                 ignore_index: int = -100):
        super().__init__()
        # Convolutional classifier mixes local context along the sequence.
        self.classify = SimpleConv(
            hidden_size, 512, num_labels)
        self.num_labels = num_labels
        self._ignore_index = ignore_index

    def forward(self, sequence_output, targets=None):
        """Return ``(sequence_logits,)`` or ``((loss, metrics), sequence_logits)``."""
        sequence_logits = self.classify(sequence_output)
        outputs = (sequence_logits,)
        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
            # Flatten (batch, seq, labels) -> (batch * seq, labels) for the loss.
            classification_loss = loss_fct(
                sequence_logits.view(-1, self.num_labels), targets.view(-1))
            acc_fct = Accuracy(ignore_index=self._ignore_index)
            metrics = {'accuracy':
                       acc_fct(sequence_logits.view(-1, self.num_labels), targets.view(-1))}
            loss_and_metrics = (classification_loss, metrics)
            outputs = (loss_and_metrics,) + outputs
        return outputs  # (loss), sequence_logits
+
|
842 |
+
|
843 |
+
class PairwiseContactPredictionHead(nn.Module):
    """Residue-residue contact prediction head.

    Builds pairwise features from per-position representations and classifies
    each (i, j) pair as contact / no-contact.
    """

    def __init__(self, hidden_size: int, ignore_index=-100):
        super().__init__()
        # Pairwise features are [elementwise product, elementwise difference],
        # hence the 2 * hidden_size input width; 2 output classes.
        self.predict = nn.Sequential(
            nn.Dropout(), nn.Linear(2 * hidden_size, 2))
        self._ignore_index = ignore_index

    def forward(self, inputs, sequence_lengths, targets=None):
        """Return ``(prediction,)`` or ``((loss, metrics), prediction)``.

        # assumes inputs is (batch, seq_len, hidden) with start/stop tokens at
        # positions 0 and -1 — TODO confirm against the callers.
        """
        # Broadcast to all (i, j) pairs: shapes become (B, L, L, H).
        prod = inputs[:, :, None, :] * inputs[:, None, :, :]
        diff = inputs[:, :, None, :] - inputs[:, None, :, :]
        pairwise_features = torch.cat((prod, diff), -1)
        prediction = self.predict(pairwise_features)
        # Symmetrize so that contact(i, j) == contact(j, i).
        prediction = (prediction + prediction.transpose(1, 2)) / 2
        prediction = prediction[:, 1:-1, 1:-1].contiguous()  # remove start/stop tokens
        outputs = (prediction,)

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
            contact_loss = loss_fct(
                prediction.view(-1, 2), targets.view(-1))
            metrics = {'precision_at_l5':
                       self.compute_precision_at_l5(sequence_lengths, prediction, targets)}
            loss_and_metrics = (contact_loss, metrics)
            outputs = (loss_and_metrics,) + outputs

        return outputs

    def compute_precision_at_l5(self, sequence_lengths, prediction, labels):
        """Precision of the top L/5 most confident predictions per sequence,
        restricted to pairs separated by at least 6 residues."""
        with torch.no_grad():
            valid_mask = labels != self._ignore_index
            seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device)
            # NOTE(review): torch.meshgrid without indexing= warns on newer
            # torch versions; this relies on the legacy 'ij' behavior — confirm.
            x_ind, y_ind = torch.meshgrid(seqpos, seqpos)
            # Keep only medium/long-range pairs (separation >= 6).
            valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0)
            # Probability of the 'contact' class at each pair.
            probs = F.softmax(prediction, 3)[:, :, :, 1]
            valid_mask = valid_mask.type_as(probs)
            correct = 0
            total = 0
            for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask):
                # Zero out invalid pairs, then take the L/5 most confident ones.
                masked_prob = (prob * mask).view(-1)
                most_likely = masked_prob.topk(length // 5, sorted=False)
                selected = label.view(-1).gather(0, most_likely.indices)
                correct += selected.sum().float()
                total += selected.numel()
            return correct / total
|
tape/optimization.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Modifications by Roshan Rao
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch optimization for BERT model."""
|
17 |
+
|
18 |
+
import logging
|
19 |
+
import math
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch.optim import Optimizer # type: ignore
|
23 |
+
from torch.optim.lr_scheduler import LambdaLR
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
class ConstantLRSchedule(LambdaLR):
    """Schedule that keeps the optimizer's base learning rate unchanged."""

    def __init__(self, optimizer, last_epoch=-1):
        # The multiplier is always 1.0, i.e. the base LR is used as-is.
        super().__init__(optimizer, lambda step: 1.0, last_epoch=last_epoch)
+
|
36 |
+
class WarmupConstantSchedule(LambdaLR):
    """Linear warmup, then constant.

    The LR multiplier grows linearly from 0 to 1 over ``warmup_steps`` training
    steps and stays at 1 afterwards.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        # After warmup the multiplier is pinned at 1.
        if step >= self.warmup_steps:
            return 1.
        return float(step) / float(max(1.0, self.warmup_steps))
+
|
51 |
+
|
52 |
+
class WarmupLinearSchedule(LambdaLR):
    """Linear warmup followed by linear decay.

    The multiplier rises from 0 to 1 over ``warmup_steps``, then decays
    linearly back to 0 over the remaining ``t_total - warmup_steps`` steps.
    """

    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        remaining = float(self.t_total - step)
        decay_span = float(max(1.0, self.t_total - self.warmup_steps))
        # Clamp at 0 once step passes t_total.
        return max(0.0, remaining / decay_span)
+
|
70 |
+
|
71 |
+
class WarmupCosineSchedule(LambdaLR):
    """Linear warmup followed by cosine decay.

    The multiplier rises from 0 to 1 over ``warmup_steps``, then follows a
    cosine curve from 1 down to 0 over the remaining ``t_total - warmup_steps``
    steps. With ``cycles`` different from the default 0.5 the multiplier simply
    follows the cosine function after warmup.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # Fraction of the post-warmup phase that has elapsed.
        progress = float(step - self.warmup_steps) / float(
            max(1, self.t_total - self.warmup_steps))
        cosine = math.cos(math.pi * float(self.cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1. + cosine))
+
|
93 |
+
|
94 |
+
class WarmupCosineWithHardRestartsSchedule(LambdaLR):
    """Linear warmup, then ``cycles`` cosine decays with hard restarts.

    The multiplier rises from 0 to 1 over ``warmup_steps``; afterwards it
    repeatedly decays from 1 to 0 along a cosine curve, restarting at 1 at the
    beginning of each cycle. Past ``t_total`` the multiplier is 0.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        # Fraction of the post-warmup phase that has elapsed.
        progress = float(step - self.warmup_steps) / float(
            max(1, self.t_total - self.warmup_steps))
        if progress >= 1.0:
            return 0.0
        # The modulo restarts the cosine at the top of every cycle.
        phase = (float(self.cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1. + math.cos(math.pi * phase)))
+
|
118 |
+
|
119 |
+
class AdamW(Optimizer):
    """ Implements Adam algorithm with weight decay fix.

    Parameters:
        lr (float): learning rate. Default 1e-3.
        betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
        eps (float): Adams epsilon. Default: 1e-6
        weight_decay (float): Weight decay. Default: 0.0
        correct_bias (bool): can be set to False to avoid correcting bias in Adam
            (e.g. like in Bert TF repository). Default True.
    """
    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.0,
                 correct_bias=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        correct_bias=correct_bias)
        super(AdamW, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, '
                                       'please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # NOTE: the keyword forms (alpha=/value=) replace the positional
                # Tensor.add_(scalar, tensor) overloads, which were deprecated
                # and later removed from PyTorch; results are identical.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state['step']
                    bias_correction2 = 1.0 - beta2 ** state['step']
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group['weight_decay'] > 0.0:
                    p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss
|
tape/registry.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Type, Callable, Optional, Union
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from .models.modeling_utils import ProteinModel
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
PathType = Union[str, Path]
|
7 |
+
|
8 |
+
|
9 |
+
def convert_model_args(model_args):
    """Parse a list of ``'key=value'`` strings into a dict.

    Values are coerced to ``int`` when possible, else ``float``, else kept as
    ``str``. Splitting uses ``maxsplit=1`` so values may themselves contain
    ``'='`` (e.g. ``'path=a=b'``). Fixes the original bare ``except:`` clauses,
    which would also have swallowed KeyboardInterrupt/SystemExit.

    Args:
        model_args: iterable of ``'key=value'`` strings.

    Returns:
        dict mapping each key to its coerced value.

    Raises:
        ValueError: if an entry contains no ``'='`` separator.
    """
    parsed = {}
    for entry in model_args:
        key, value = entry.split("=", 1)
        # Try the narrowest type first; fall through on ValueError only.
        for cast in (int, float):
            try:
                value = cast(value)
                break
            except ValueError:
                continue
        parsed[key] = value
    return parsed
+
|
23 |
+
|
24 |
+
class TAPETaskSpec:
    """
    Attributes
    ----------
    name (str):
        The name of the TAPE task
    dataset (Type[Dataset]):
        The dataset used in the TAPE task
    num_labels (int):
        number of labels used if this is a classification task
    models (Dict[str, ProteinModel]):
        The set of models that can be used for this task. Default: {}.
    """

    def __init__(self,
                 name: str,
                 dataset: Type[Dataset],
                 num_labels: int = -1,
                 models: Optional[Dict[str, Type[ProteinModel]]] = None):
        self.name = name
        self.dataset = dataset
        self.num_labels = num_labels
        self.models = {} if models is None else models

    def register_model(self, model_name: str, model_cls: Optional[Type[ProteinModel]] = None):
        """Register a model class under ``model_name``.

        Can be called directly with a class, or without one to obtain a
        decorator: ``@task_spec.register_model('transformer')``.
        """
        if model_cls is None:
            # Decorator form: defer registration until the class is supplied.
            return lambda cls: self.register_model(model_name, cls)
        if model_name in self.models:
            raise KeyError(
                f"A model with name '{model_name}' is already registered for this task")
        self.models[model_name] = model_cls
        return model_cls

    def get_model(self, model_name: str) -> Type[ProteinModel]:
        """Look up a registered model class; raises KeyError if absent."""
        return self.models[model_name]
+
|
61 |
+
|
62 |
+
class Registry:
|
63 |
+
r"""Class for registry object which acts as the
|
64 |
+
central repository for TAPE."""
|
65 |
+
|
66 |
+
task_name_mapping: Dict[str, TAPETaskSpec] = {}
|
67 |
+
metric_name_mapping: Dict[str, Callable] = {}
|
68 |
+
|
69 |
+
    @classmethod
    def register_task(cls,
                      task_name: str,
                      num_labels: int = -1,
                      dataset: Optional[Type[Dataset]] = None,
                      models: Optional[Dict[str, Type[ProteinModel]]] = None):
        """ Register a new TAPE task. This creates a new TAPETaskSpec.

        Args:

            task_name (str): The name of the TAPE task.
            num_labels (int): Number of labels used if this is a classification task. If this
                is not a classification task, simply leave the default as -1.
            dataset (Type[Dataset]): The dataset used in the TAPE task.
            models (Optional[Dict[str, ProteinModel]]): The set of models that can be used for
                this task. If you do not pass this argument, you can register models to the task
                later by using `registry.register_task_model`. Default: {}.

        Examples:

        There are two ways of registering a new task. First, one can define the task by simply
        declaring all the components, and then calling the register method, like so:

            class SecondaryStructureDataset(Dataset):
                ...

            class ProteinBertForSequenceToSequenceClassification():
                ...

            registry.register_task(
                'secondary_structure', 3, SecondaryStructureDataset,
                {'transformer': ProteinBertForSequenceToSequenceClassification})

        This will register a new task, 'secondary_structure', with a single model. More models
        can be added with `registry.register_task_model`. Alternatively, this can be used as a
        decorator:

            @registry.register_task('secondary_structure', 3)
            class SecondaryStructureDataset(Dataset):
                ...

            @registry.register_task_model('secondary_structure', 'transformer')
            class ProteinBertForSequenceToSequenceClassification():
                ...

        These two pieces of code are exactly equivalent, in terms of the resulting registry
        state.

        """
        if dataset is not None:
            if models is None:
                models = {}
            task_spec = TAPETaskSpec(task_name, dataset, num_labels, models)
            # Return the dataset so the decorator form hands the class back unchanged.
            return cls.register_task_spec(task_name, task_spec).dataset
        else:
            # Called without a dataset -> behave as a decorator factory.
            return lambda dataset: cls.register_task(task_name, num_labels, dataset, models)
125 |
+
|
126 |
+
@classmethod
|
127 |
+
def register_task_spec(cls, task_name: str, task_spec: Optional[TAPETaskSpec] = None):
|
128 |
+
""" Registers a task_spec directly. If you find it easier to actually create a
|
129 |
+
TAPETaskSpec manually, and then register it, feel free to use this method,
|
130 |
+
but otherwise it is likely easier to use `registry.register_task`.
|
131 |
+
"""
|
132 |
+
if task_spec is not None:
|
133 |
+
if task_name in cls.task_name_mapping:
|
134 |
+
raise KeyError(f"A task with name '{task_name}' is already registered")
|
135 |
+
cls.task_name_mapping[task_name] = task_spec
|
136 |
+
return task_spec
|
137 |
+
else:
|
138 |
+
return lambda task_spec: cls.register_task_spec(task_name, task_spec)
|
139 |
+
|
140 |
+
    @classmethod
    def register_task_model(cls,
                            task_name: str,
                            model_name: str,
                            model_cls: Optional[Type[ProteinModel]] = None):
        r"""Register a specific model to a task with the provided model name.
        The task must already be in the registry - you cannot register a
        model to an unregistered task.

        Args:
            task_name (str): Name of task to which to register the model.
            model_name (str): Name of model to use when registering task, this
                is the name that you will use to refer to the model on the
                command line.
            model_cls (Type[ProteinModel]): The model to register.

        Examples:

        As with `registry.register_task`, this can both be used as a regular
        python function, and as a decorator. For example this:

            class ProteinBertForSequenceToSequenceClassification():
                ...
            registry.register_task_model(
                'secondary_structure', 'transformer',
                ProteinBertForSequenceToSequenceClassification)

        and as a decorator:

            @registry.register_task_model('secondary_structure', 'transformer')
            class ProteinBertForSequenceToSequenceClassification():
                ...

        are both equivalent.
        """
        if task_name not in cls.task_name_mapping:
            raise KeyError(
                f"Tried to register a task model for an unregistered task: {task_name}. "
                f"Make sure to register the task {task_name} first.")
        # Delegate to the task spec; supports both direct-call and decorator forms.
        return cls.task_name_mapping[task_name].register_model(model_name, model_cls)
180 |
+
|
181 |
+
@classmethod
|
182 |
+
def register_metric(cls, name: str) -> Callable[[Callable], Callable]:
|
183 |
+
r"""Register a metric to registry with key 'name'
|
184 |
+
|
185 |
+
Args:
|
186 |
+
name: Key with which the metric will be registered.
|
187 |
+
|
188 |
+
Usage::
|
189 |
+
from tape.registry import registry
|
190 |
+
|
191 |
+
@registry.register_metric('mse')
|
192 |
+
def mean_squred_error(inputs, outputs):
|
193 |
+
...
|
194 |
+
"""
|
195 |
+
|
196 |
+
def wrap(fn: Callable) -> Callable:
|
197 |
+
assert callable(fn), "All metrics must be callable"
|
198 |
+
cls.metric_name_mapping[name] = fn
|
199 |
+
return fn
|
200 |
+
|
201 |
+
return wrap
|
202 |
+
|
203 |
+
@classmethod
|
204 |
+
def get_task_spec(cls, name: str) -> TAPETaskSpec:
|
205 |
+
return cls.task_name_mapping[name]
|
206 |
+
|
207 |
+
@classmethod
|
208 |
+
def get_metric(cls, name: str) -> Callable:
|
209 |
+
return cls.metric_name_mapping[name]
|
210 |
+
|
211 |
+
@classmethod
|
212 |
+
def get_task_model(cls,
|
213 |
+
model_name: str,
|
214 |
+
task_name: str,
|
215 |
+
config_file: Optional[PathType] = None,
|
216 |
+
load_dir: Optional[PathType] = None,
|
217 |
+
model_args = None) -> ProteinModel:
|
218 |
+
""" Create a TAPE task model, either from scratch or from a pretrained model.
|
219 |
+
This is mostly a helper function that evaluates the if statements in a
|
220 |
+
sensible order if you pass all three of the arguments.
|
221 |
+
Args:
|
222 |
+
model_name (str): Which type of model to create (e.g. transformer, unirep, ...)
|
223 |
+
task_name (str): The TAPE task for which to create a model
|
224 |
+
config_file (str, optional): A json config file that specifies hyperparameters
|
225 |
+
load_dir (str, optional): A save directory for a pretrained model
|
226 |
+
Returns:
|
227 |
+
model (ProteinModel): A TAPE task model
|
228 |
+
"""
|
229 |
+
task_spec = registry.get_task_spec(task_name)
|
230 |
+
model_cls = task_spec.get_model(model_name)
|
231 |
+
|
232 |
+
if load_dir is not None:
|
233 |
+
model = model_cls.from_pretrained(load_dir, num_labels=task_spec.num_labels)
|
234 |
+
else:
|
235 |
+
config_class = model_cls.config_class
|
236 |
+
if config_file is not None:
|
237 |
+
config = config_class.from_json_file(config_file)
|
238 |
+
else:
|
239 |
+
config = config_class()
|
240 |
+
|
241 |
+
if model_args:
|
242 |
+
model_args = convert_model_args(model_args)
|
243 |
+
for k,v in model_args.items():
|
244 |
+
if k in config.__dict__ and type(config.__dict__[k])==type(v):
|
245 |
+
setattr(config, k, v)
|
246 |
+
else:
|
247 |
+
raise ValueError(f"model arg {k} not in config or of the same type as default")
|
248 |
+
|
249 |
+
config.num_labels = task_spec.num_labels
|
250 |
+
model = model_cls(config)
|
251 |
+
return model
|
252 |
+
|
253 |
+
|
254 |
+
# Module-level singleton through which tasks, models, and metrics are registered.
registry = Registry()
|
tape/tokenizers.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
# Mapping from IUPAC three-letter amino-acid codes to one-letter codes.
IUPAC_CODES = OrderedDict([
    ('Ala', 'A'),
    ('Asx', 'B'),
    ('Cys', 'C'),
    ('Asp', 'D'),
    ('Glu', 'E'),
    ('Phe', 'F'),
    ('Gly', 'G'),
    ('His', 'H'),
    ('Ile', 'I'),
    ('Lys', 'K'),
    ('Leu', 'L'),
    ('Met', 'M'),
    ('Asn', 'N'),
    ('Pro', 'P'),
    ('Gln', 'Q'),
    ('Arg', 'R'),
    ('Ser', 'S'),
    ('Thr', 'T'),
    ('Sec', 'U'),
    ('Val', 'V'),
    ('Trp', 'W'),
    ('Xaa', 'X'),
    ('Tyr', 'Y'),
    ('Glx', 'Z')])

# Default vocabulary: 5 special tokens followed by amino-acid letters
# (A-Z without J), 30 entries total. Supports masking via <mask>.
IUPAC_VOCAB = OrderedDict([
    ("<pad>", 0),
    ("<mask>", 1),
    ("<cls>", 2),
    ("<sep>", 3),
    ("<unk>", 4),
    ("A", 5),
    ("B", 6),
    ("C", 7),
    ("D", 8),
    ("E", 9),
    ("F", 10),
    ("G", 11),
    ("H", 12),
    ("I", 13),
    ("K", 14),
    ("L", 15),
    ("M", 16),
    ("N", 17),
    ("O", 18),
    ("P", 19),
    ("Q", 20),
    ("R", 21),
    ("S", 22),
    ("T", 23),
    ("U", 24),
    ("V", 25),
    ("W", 26),
    ("X", 27),
    ("Y", 28),
    ("Z", 29)])

# UniRep (mLSTM) vocabulary. Note the ambiguity codes X, Z, B, J all share
# id 23, so id->token conversion is not exactly invertible for this vocab.
# There is no <mask> token: masking is unsupported for UniRep.
UNIREP_VOCAB = OrderedDict([
    ("<pad>", 0),
    ("M", 1),
    ("R", 2),
    ("H", 3),
    ("K", 4),
    ("D", 5),
    ("E", 6),
    ("S", 7),
    ("T", 8),
    ("N", 9),
    ("Q", 10),
    ("C", 11),
    ("U", 12),
    ("G", 13),
    ("P", 14),
    ("A", 15),
    ("V", 16),
    ("I", 17),
    ("F", 18),
    ("Y", 19),
    ("W", 20),
    ("L", 21),
    ("O", 22),
    ("X", 23),
    ("Z", 23),
    ("B", 23),
    ("J", 23),
    ("<cls>", 24),
    ("<sep>", 25)])


class TAPETokenizer():
    r"""TAPE Tokenizer. Can use different vocabs depending on the model.

    Args:
        vocab (str): Which vocabulary to use: 'iupac' (default) or 'unirep'.

    Raises:
        ValueError: If ``vocab`` is not one of the supported names.
    """

    def __init__(self, vocab: str = 'iupac'):
        if vocab == 'iupac':
            self.vocab = IUPAC_VOCAB
        elif vocab == 'unirep':
            self.vocab = UNIREP_VOCAB
        else:
            # Bug fix: previously an unrecognized vocab name left `self.vocab`
            # unset and the constructor crashed later with an opaque
            # AttributeError. Fail fast with a clear message instead.
            raise ValueError(
                f"Unrecognized vocab: '{vocab}'. Expected 'iupac' or 'unirep'.")
        # `tokens` preserves insertion order, so list position == token id for
        # vocabs with unique ids (true for iupac; unirep duplicates id 23).
        self.tokens = list(self.vocab.keys())
        self._vocab_type = vocab
        assert self.start_token in self.vocab and self.stop_token in self.vocab

    @property
    def vocab_size(self) -> int:
        """Number of entries in the active vocabulary."""
        return len(self.vocab)

    @property
    def start_token(self) -> str:
        """Sequence-start token prepended by :meth:`add_special_tokens`."""
        return "<cls>"

    @property
    def stop_token(self) -> str:
        """Sequence-end token appended by :meth:`add_special_tokens`."""
        return "<sep>"

    @property
    def mask_token(self) -> str:
        """Mask token; raises for vocabs (unirep) that have none."""
        if "<mask>" in self.vocab:
            return "<mask>"
        else:
            raise RuntimeError(f"{self._vocab_type} vocab does not support masking")

    def tokenize(self, text: str) -> List[str]:
        """Split a sequence string into single-character tokens."""
        return [x for x in text]

    def convert_token_to_id(self, token: str) -> int:
        """ Converts a token (str/unicode) in an id using the vocab. """
        try:
            return self.vocab[token]
        except KeyError:
            raise KeyError(f"Unrecognized token: '{token}'")

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        """Convert a list of tokens to their vocabulary ids."""
        return [self.convert_token_to_id(token) for token in tokens]

    def convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (string/unicode) using the vocab.

        NOTE: maps by insertion-order position, so for vocabs with duplicate
        ids (unirep's X/Z/B/J all share 23) this is not an exact inverse of
        :meth:`convert_token_to_id`.
        """
        try:
            return self.tokens[index]
        except IndexError:
            raise IndexError(f"Unrecognized index: '{index}'")

    def convert_ids_to_tokens(self, indices: List[int]) -> List[str]:
        """Convert a list of vocabulary ids back to tokens."""
        return [self.convert_id_to_token(id_) for id_ in indices]

    def convert_tokens_to_string(self, tokens: str) -> str:
        """ Converts a sequence of tokens (string) in a single string. """
        return ''.join(tokens)

    def add_special_tokens(self, token_ids: List[str]) -> List[str]:
        """
        Adds special tokens to the a sequence for sequence classification tasks.
        A BERT sequence has the following format: [CLS] X [SEP]
        """
        cls_token = [self.start_token]
        sep_token = [self.stop_token]
        return cls_token + token_ids + sep_token

    def encode(self, text: str) -> np.ndarray:
        """Tokenize ``text``, add start/stop tokens, and return int64 ids."""
        tokens = self.tokenize(text)
        tokens = self.add_special_tokens(tokens)
        token_ids = self.convert_tokens_to_ids(tokens)
        return np.array(token_ids, np.int64)

    @classmethod
    def from_pretrained(cls, **kwargs):
        """Compatibility shim: returns a default-constructed tokenizer."""
        return cls()
|
tape/training.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
from timeit import default_timer as timer
|
5 |
+
import json
|
6 |
+
from pathlib import Path
|
7 |
+
import inspect
|
8 |
+
import pickle as pkl
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.optim as optim
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from .optimization import WarmupLinearSchedule
|
16 |
+
|
17 |
+
from . import utils
|
18 |
+
from . import errors
|
19 |
+
from . import visualization
|
20 |
+
from .registry import registry
|
21 |
+
from .models.modeling_utils import ProteinModel
|
22 |
+
|
23 |
+
try:
|
24 |
+
from apex import amp
|
25 |
+
import amp_C
|
26 |
+
import apex_C
|
27 |
+
from apex.amp import _amp_state
|
28 |
+
from apex.parallel.distributed import flat_dist_call
|
29 |
+
from apex.parallel.distributed import DistributedDataParallel as DDP
|
30 |
+
APEX_FOUND = True
|
31 |
+
except ImportError:
|
32 |
+
APEX_FOUND = False
|
33 |
+
|
34 |
+
logger = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
MetricsDict = typing.Dict[str, float]
|
37 |
+
LossAndMetrics = typing.Tuple[float, MetricsDict]
|
38 |
+
OutputDict = typing.Dict[str, typing.Any]
|
39 |
+
|
40 |
+
|
41 |
+
class ForwardRunner:
    """Thin wrapper around a ProteinModel that runs forward passes.

    Handles batch filtering (only tensors matching the model's forward()
    signature are passed through), device placement, and unpacking of the
    model's (loss, metrics) outputs, including multi-GPU averaging.
    """

    def __init__(self,
                 model: ProteinModel,
                 device: torch.device = torch.device('cuda:0'),
                 n_gpu: int = 1,
                 fp16: bool = False,
                 local_rank: int = -1):

        self.model = model
        self.device = device
        self.n_gpu = n_gpu
        self.fp16 = fp16
        self.local_rank = local_rank

        # Cache the model's forward() argument names (dropping `self`) so
        # batches can be filtered down to exactly what the model accepts.
        arg_names = inspect.getfullargspec(model.forward).args[1:]
        self._forward_arg_keys = arg_names
        assert 'input_ids' in self._forward_arg_keys

    def initialize_distributed_model(self):
        """Wrap the model for distributed / multi-GPU execution."""
        if self.local_rank == -1:
            if self.n_gpu > 1:
                self.model = nn.DataParallel(self.model)
            return
        if self.fp16:
            # apex handles gradient sync in fp16 mode; just broadcast the
            # initial weights from rank 0 so all replicas start identical.
            flat_dist_call([param.data for param in self.model.parameters()],
                           torch.distributed.broadcast, (0,))
        else:
            self.model = DDP(self.model)

    def forward(self,
                batch: typing.Dict[str, torch.Tensor],
                return_outputs: bool = False,
                no_loss: bool = False):
        # Keep only the tensors the model's forward() accepts. Datasets may
        # carry extras (e.g. PSSMs / MSAs) that only some models consume.
        filtered = {key: val for key, val in batch.items()
                    if key in self._forward_arg_keys}
        if self.device.type == 'cuda':
            filtered = {key: val.cuda(device=self.device, non_blocking=True)
                        for key, val in filtered.items()}

        outputs = self.model(**filtered)

        if no_loss:
            return outputs

        first = outputs[0]
        if isinstance(first, tuple):
            loss, metrics = first  # model returned (loss, metrics)
        else:
            loss, metrics = first, {}  # loss only, no metrics

        if self.n_gpu > 1:
            # DataParallel gathers per-GPU scalars into vectors; average them
            # back down to single scalars.
            loss = loss.mean()
            metrics = {key: val.mean() for key, val in metrics.items()}

        return (loss, metrics, outputs) if return_outputs else (loss, metrics)

    def train(self):
        """Put the model in training mode. Returns self for chaining."""
        self.model.train()
        return self

    def eval(self):
        """Put the model in eval mode. Returns self for chaining."""
        self.model.eval()
        return self
|
114 |
+
|
115 |
+
|
116 |
+
class BackwardRunner(ForwardRunner):
    """ForwardRunner extended with optimization: backward passes, gradient
    accumulation/clipping, fp16 (apex amp) support, LR scheduling, and
    checkpoint save/resume.
    """

    def __init__(self,
                 model: ProteinModel,
                 optimizer: optim.Optimizer,  # type: ignore
                 gradient_accumulation_steps: int = 1,
                 device: torch.device = torch.device('cuda:0'),
                 n_gpu: int = 1,
                 fp16: bool = False,
                 local_rank: int = -1,
                 max_grad_norm: float = 1.0,
                 warmup_steps: int = 0,
                 num_train_optimization_steps: int = 1000000):

        super().__init__(model, device, n_gpu, fp16, local_rank)
        self.optimizer = optimizer
        self.max_grad_norm = max_grad_norm
        self._global_step = 0
        self._local_rank = local_rank
        # Overflow flag used by the fused fp16 allreduce path. Bug fix: only
        # allocate the CUDA buffer when CUDA is present so this class can be
        # constructed on CPU-only hosts (the buffer is only touched in the
        # fp16-distributed path, which requires CUDA anyway).
        self._overflow_buf = (torch.cuda.IntTensor([0])  # type: ignore
                              if torch.cuda.is_available() else None)
        self.gradient_accumulation_steps = gradient_accumulation_steps
        # In distributed fp16 we divide by accumulation steps during the fused
        # allreduce instead of before backward.
        self._delay_accumulation = fp16 and local_rank != -1

        self.scheduler = WarmupLinearSchedule(
            self.optimizer, warmup_steps, num_train_optimization_steps)

    def initialize_fp16(self):
        """Wrap model/optimizer with apex amp (O2) when fp16 is enabled."""
        if self.fp16:
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level="O2", loss_scale="dynamic",
                master_weights=True)
            _amp_state.loss_scalers[0]._loss_scale = 2 ** 20

    def resume_from_checkpoint(self, checkpoint_dir: str) -> int:
        """Restore optimizer/scheduler (and amp) state; return next epoch id."""
        checkpoint = torch.load(
            os.path.join(checkpoint_dir, 'checkpoint.bin'), map_location=self.device)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.fp16:
            # apex pattern: force lazy master-weight init, reload optimizer
            # state on top of it, then copy saved master params and amp state.
            self.optimizer._lazy_init_maybe_master_weights()
            self.optimizer._amp_stash.lazy_init_called = True
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved in zip(
                    amp.master_params(self.optimizer), checkpoint['master params']):
                param.data.copy_(saved.data)
            amp.load_state_dict(checkpoint['amp'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        return start_epoch

    def save_state(self, save_directory: typing.Union[str, Path], epoch_id: int):
        """Save model weights plus optimizer/scheduler/amp state to a directory."""
        save_directory = Path(save_directory)
        if not save_directory.exists():
            save_directory.mkdir()
        else:
            assert save_directory.is_dir(), "Save path should be a directory"
        # Unwrap DataParallel/DDP if present.
        model_to_save = getattr(self.model, 'module', self.model)
        model_to_save.save_pretrained(save_directory)
        optimizer_state: typing.Dict[str, typing.Any] = {
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epoch': epoch_id}
        if APEX_FOUND:
            optimizer_state['master params'] = list(amp.master_params(self.optimizer))
            try:
                optimizer_state['amp'] = amp.state_dict()
            except AttributeError:
                # amp not initialized (fp32 run with apex installed).
                pass
        torch.save(optimizer_state, save_directory / 'checkpoint.bin')

    def backward(self, loss) -> None:
        """Backpropagate, scaling for gradient accumulation and fp16."""
        if not self._delay_accumulation:
            loss = loss / self.gradient_accumulation_steps
        if self.fp16:
            with amp.scale_loss(loss, self.optimizer,
                                delay_overflow_check=self._delay_accumulation) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def step(self) -> None:
        """Clip gradients and take one optimizer (+ scheduler) step."""
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        if self._local_rank == -1:
            self._step()
        elif not self.fp16:
            # TODO: Can you do this allreduce after accumulation also?
            self._step()
        else:
            self._step_distributed_fp16()

    def _step(self) -> None:
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()  # type: ignore
        self._global_step += 1

    def _step_distributed_fp16(self) -> None:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        scaler = _amp_state.loss_scalers[0]
        master_grads = [p.grad for p in amp.master_params(self.optimizer) if p.grad is not None]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16
        flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        self._overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536,
            self._overflow_buf,
            [master_grads, allreduced_views],
            scaler.loss_scale() / (
                torch.distributed.get_world_size() * self.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        self._overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536,
            self._overflow_buf,
            [allreduced_views, master_grads],
            1. / scaler.loss_scale())
        # 5. update loss scale
        scaler = _amp_state.loss_scalers[0]
        old_overflow_buf = scaler._overflow_buf
        scaler._overflow_buf = self._overflow_buf
        had_overflow = scaler.update_scale()
        # Bug fix: was `scaler._overfloat_buf` (typo), which silently created a
        # new attribute instead of restoring the scaler's own overflow buffer.
        scaler._overflow_buf = old_overflow_buf
        # 6. call optimizer step function
        if had_overflow == 0:
            self._step()
        else:
            # Overflow detected, print message and clear gradients
            logger.info(f"Gradient overflow. Skipping step, reducing loss scale to "
                        f"{scaler.loss_scale()}")
            if _amp_state.opt_properties.master_weights:
                for param in self.optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
            for param in self.model.parameters():
                param.grad = None

    @property
    def global_step(self) -> int:
        """Number of optimizer steps taken so far."""
        return self._global_step
|
262 |
+
|
263 |
+
|
264 |
+
def run_train_epoch(epoch_id: int,
                    train_loader: DataLoader,
                    runner: BackwardRunner,
                    viz: typing.Optional[visualization.TAPEVisualizer] = None,
                    num_log_iter: int = 20,
                    gradient_accumulation_steps: int = 1,
                    num_steps_per_epoch: int = -1) -> LossAndMetrics:
    """Run one training epoch.

    Args:
        epoch_id: Zero-based epoch index (used for logging only).
        train_loader: DataLoader yielding training batches.
        runner: BackwardRunner owning model/optimizer/scheduler.
        viz: Optional visualizer; a no-op DummyVisualizer is used when None.
        num_log_iter: Log smoothed metrics every this many optimizer steps.
        gradient_accumulation_steps: Forward/backward passes per optimizer step.
        num_steps_per_epoch: If >= 0, cap the number of batches processed this
            epoch; -1 means iterate the full loader.

    Returns:
        Tuple of (final smoothed loss, final smoothed metrics dict).
    """
    if viz is None:
        viz = visualization.DummyVisualizer()
    smoothing = 1 - 1 / num_log_iter
    accumulator = utils.MetricsAccumulator(smoothing)

    torch.set_grad_enabled(True)
    runner.train()

    def make_log_str(step: int, time: float) -> str:
        # Fractional epoch position plus smoothed loss/metrics and current LR.
        ep_percent = epoch_id + step / len(train_loader)
        if runner.scheduler is not None:
            curr_lr = runner.scheduler.get_lr()[0]  # type: ignore
        else:
            curr_lr = runner.optimizer.param_groups[0]['lr']

        print_str = []
        print_str.append(f"[Ep: {ep_percent:.2f}]")
        print_str.append(f"[Iter: {runner.global_step}]")
        print_str.append(f"[Time: {time:5.2f}s]")
        print_str.append(f"[Loss: {accumulator.loss():.5g}]")

        for name, value in accumulator.metrics().items():
            print_str.append(f"[{name.capitalize()}: {value:.5g}]")

        print_str.append(f"[LR: {curr_lr:.5g}]")
        return ''.join(print_str)

    start_t = timer()
    for step, batch in enumerate(train_loader):
        loss, metrics = runner.forward(batch)  # type: ignore
        runner.backward(loss)
        accumulator.update(loss, metrics, step=False)
        if (step + 1) % gradient_accumulation_steps == 0:
            runner.step()
            viz.log_metrics(accumulator.step(), "train", runner.global_step)
            if runner.global_step % num_log_iter == 0:
                end_t = timer()
                logger.info(make_log_str(step, end_t - start_t))
                start_t = end_t
        # Bug fix: use >= so that exactly `num_steps_per_epoch` batches are
        # processed (the previous `>` always ran one extra batch).
        if num_steps_per_epoch != -1 and (step + 1) >= num_steps_per_epoch:
            break

    final_print_str = f"Train: [Loss: {accumulator.final_loss():.5g}]"
    for name, value in accumulator.final_metrics().items():
        final_print_str += f"[{name.capitalize()}: {value:.5g}]"
    logger.info(final_print_str)
    return accumulator.final_loss(), accumulator.final_metrics()
|
318 |
+
|
319 |
+
|
320 |
+
def run_valid_epoch(epoch_id: int,
                    valid_loader: DataLoader,
                    runner: ForwardRunner,
                    viz: typing.Optional[visualization.TAPEVisualizer] = None,
                    is_master: bool = True,
                    val_check_frac: float = 1.0) -> typing.Tuple[float, typing.Dict[str, float]]:
    """Run one validation epoch (no gradients) and return reduced loss/metrics.

    Args:
        epoch_id: Epoch index; used as the logging step when the runner has
            no `global_step` attribute.
        valid_loader: DataLoader yielding validation batches.
        runner: ForwardRunner wrapping the model.
        viz: Optional visualizer to log the metrics to.
        is_master: Whether this process should display the progress bar.
        val_check_frac: Fraction of the validation set to evaluate (1.0 = all).

    Returns:
        Tuple of (eval loss, metrics dict). Values are reduced across
        processes when distributed; the loss is also stored in metrics['loss'].
    """
    num_batches = len(valid_loader)
    num_batches_to_run = int(num_batches * val_check_frac)
    accumulator = utils.MetricsAccumulator()

    torch.set_grad_enabled(False)
    runner.eval()

    for idx, batch in enumerate(tqdm(valid_loader, desc='Running Eval', total=num_batches_to_run,
                                     disable=not is_master, leave=False)):
        loss, metrics = runner.forward(batch)  # type: ignore
        accumulator.update(loss, metrics)
        # Bug fix: stop once `num_batches_to_run` batches have been processed
        # (the previous `idx > num_batches_to_run` ran up to two extra batches).
        if idx + 1 >= num_batches_to_run:
            break

    # Reduce loss across all processes if multiprocessing
    eval_loss = utils.reduce_scalar(accumulator.final_loss())
    metrics = {name: utils.reduce_scalar(value)
               for name, value in accumulator.final_metrics().items()}

    print_str = f"Evaluation: [Loss: {eval_loss:.5g}]"
    for name, value in metrics.items():
        print_str += f"[{name.capitalize()}: {value:.5g}]"

    metrics['loss'] = eval_loss
    if viz is not None:
        viz.log_metrics(metrics, "val", getattr(runner, 'global_step', epoch_id))

    logger.info(print_str)

    return eval_loss, metrics
|
356 |
+
|
357 |
+
|
358 |
+
def _get_outputs_to_save(batch, outputs):
|
359 |
+
targets = batch['targets'].cpu().numpy()
|
360 |
+
outputs = outputs.cpu().numpy()
|
361 |
+
protein_length = batch['protein_length'].sum(1).cpu().numpy()
|
362 |
+
|
363 |
+
reshaped_output = []
|
364 |
+
for target, output, plength in zip(targets, outputs, protein_length):
|
365 |
+
output_slices = tuple(slice(1, plength - 1) if dim == protein_length.max() else
|
366 |
+
slice(0, dim) for dim in output.shape)
|
367 |
+
output = output[output_slices]
|
368 |
+
target = target[output_slices]
|
369 |
+
|
370 |
+
reshaped_output.append((target, output))
|
371 |
+
reshaped_output
|
372 |
+
|
373 |
+
|
374 |
+
def run_eval_epoch(eval_loader: DataLoader,
                   runner: ForwardRunner,
                   is_master: bool = True) -> typing.List[typing.Dict[str, typing.Any]]:
    """Run a full pass over ``eval_loader`` collecting predictions and targets.

    Gradients are disabled and the model is switched to eval mode. The model's
    second output is treated as the prediction tensor; one dict per sequence
    is returned.
    """
    torch.set_grad_enabled(False)
    runner.eval()

    collected: typing.List[typing.Dict[str, typing.Any]] = []

    progress = tqdm(eval_loader, desc='Evaluation', total=len(eval_loader),
                    disable=not is_master)
    for batch in progress:
        _, _, outputs = runner.forward(batch, return_outputs=True)  # type: ignore
        batch_preds = outputs[1].cpu().numpy()
        batch_targets = batch['targets'].cpu().numpy()
        collected.extend({'prediction': pred, 'target': tgt}
                         for pred, tgt in zip(batch_preds, batch_targets))

    return collected
|
391 |
+
|
392 |
+
|
393 |
+
def run_train(model_type: str,
|
394 |
+
task: str,
|
395 |
+
learning_rate: float = 1e-4,
|
396 |
+
batch_size: int = 1024,
|
397 |
+
num_train_epochs: int = 10,
|
398 |
+
num_log_iter: int = 20,
|
399 |
+
fp16: bool = False,
|
400 |
+
warmup_steps: int = 10000,
|
401 |
+
gradient_accumulation_steps: int = 1,
|
402 |
+
loss_scale: int = 0,
|
403 |
+
max_grad_norm: float = 1.0,
|
404 |
+
exp_name: typing.Optional[str] = None,
|
405 |
+
from_pretrained: typing.Optional[str] = None,
|
406 |
+
log_dir: str = './logs',
|
407 |
+
eval_freq: int = 1,
|
408 |
+
save_freq: typing.Union[int, str] = 1,
|
409 |
+
model_config_file: typing.Optional[str] = None,
|
410 |
+
data_dir: str = './data',
|
411 |
+
output_dir: str = './results',
|
412 |
+
no_cuda: bool = False,
|
413 |
+
seed: int = 42,
|
414 |
+
local_rank: int = -1,
|
415 |
+
tokenizer: str = 'iupac',
|
416 |
+
num_workers: int = 8,
|
417 |
+
debug: bool = False,
|
418 |
+
log_level: typing.Union[str, int] = logging.INFO,
|
419 |
+
patience: int = -1,
|
420 |
+
resume_from_checkpoint: bool = False,
|
421 |
+
model_args = None,
|
422 |
+
num_steps_per_epoch: int = -1,
|
423 |
+
val_check_frac: float = 1.0) -> None:
|
424 |
+
|
425 |
+
# SETUP AND LOGGING CODE #
|
426 |
+
input_args = locals()
|
427 |
+
device, n_gpu, is_master = utils.setup_distributed(
|
428 |
+
local_rank, no_cuda)
|
429 |
+
|
430 |
+
exp_dir = utils.get_expname(exp_name, task, model_type)
|
431 |
+
save_path = Path(output_dir) / exp_dir
|
432 |
+
|
433 |
+
if is_master:
|
434 |
+
# save all the hidden parameters.
|
435 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
436 |
+
with (save_path / 'args.json').open('w') as f:
|
437 |
+
json.dump(input_args, f)
|
438 |
+
|
439 |
+
utils.barrier_if_distributed()
|
440 |
+
utils.setup_logging(local_rank, save_path, log_level)
|
441 |
+
utils.set_random_seeds(seed, n_gpu)
|
442 |
+
|
443 |
+
train_dataset = utils.setup_dataset(task, data_dir, 'train', tokenizer)
|
444 |
+
valid_dataset = utils.setup_dataset(task, data_dir, 'valid', tokenizer)
|
445 |
+
train_loader = utils.setup_loader(
|
446 |
+
train_dataset, batch_size, local_rank, n_gpu,
|
447 |
+
gradient_accumulation_steps, num_workers)
|
448 |
+
valid_loader = utils.setup_loader(
|
449 |
+
valid_dataset, batch_size, local_rank, n_gpu,
|
450 |
+
gradient_accumulation_steps, num_workers)
|
451 |
+
|
452 |
+
num_train_optimization_steps = utils.get_num_train_optimization_steps(
|
453 |
+
train_dataset, batch_size, num_train_epochs)
|
454 |
+
|
455 |
+
model = registry.get_task_model(model_type, task, model_config_file, from_pretrained, model_args)
|
456 |
+
model = model.to(device)
|
457 |
+
optimizer = utils.setup_optimizer(model, learning_rate)
|
458 |
+
viz = visualization.get(log_dir, exp_dir, local_rank, debug=debug)
|
459 |
+
viz.log_config(input_args)
|
460 |
+
viz.log_config(model.config.to_dict())
|
461 |
+
viz.watch(model)
|
462 |
+
|
463 |
+
logger.info(
|
464 |
+
f"device: {device} "
|
465 |
+
f"n_gpu: {n_gpu}, "
|
466 |
+
f"distributed_training: {local_rank != -1}, "
|
467 |
+
f"16-bits training: {fp16}")
|
468 |
+
|
469 |
+
runner = BackwardRunner(
|
470 |
+
model, optimizer, gradient_accumulation_steps, device, n_gpu,
|
471 |
+
fp16, local_rank, max_grad_norm, warmup_steps, num_train_optimization_steps)
|
472 |
+
|
473 |
+
runner.initialize_fp16()
|
474 |
+
if resume_from_checkpoint:
|
475 |
+
assert from_pretrained is not None
|
476 |
+
start_epoch = runner.resume_from_checkpoint(from_pretrained)
|
477 |
+
else:
|
478 |
+
start_epoch = 0
|
479 |
+
runner.initialize_distributed_model()
|
480 |
+
|
481 |
+
num_train_optimization_steps = utils.get_num_train_optimization_steps(
|
482 |
+
train_dataset, batch_size, num_train_epochs)
|
483 |
+
is_master = local_rank in (-1, 0)
|
484 |
+
|
485 |
+
if isinstance(save_freq, str) and save_freq != 'improvement':
|
486 |
+
raise ValueError(
|
487 |
+
f"Only recongized string value for save_freq is 'improvement'"
|
488 |
+
f", received: {save_freq}")
|
489 |
+
|
490 |
+
if save_freq == 'improvement' and eval_freq <= 0:
|
491 |
+
raise ValueError("Cannot set save_freq to 'improvement' and eval_freq < 0")
|
492 |
+
|
493 |
+
num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
494 |
+
logger.info("***** Running training *****")
|
495 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
496 |
+
logger.info(" Batch size = %d", batch_size)
|
497 |
+
logger.info(" Num epochs = %d", num_train_epochs)
|
498 |
+
logger.info(" Num train steps = %d", num_train_optimization_steps)
|
499 |
+
logger.info(" Num parameters = %d", num_trainable_parameters)
|
500 |
+
|
501 |
+
best_val_loss = float('inf')
|
502 |
+
num_evals_no_improvement = 0
|
503 |
+
|
504 |
+
def do_save(epoch_id: int, num_evals_no_improvement: int) -> bool:
|
505 |
+
if not is_master:
|
506 |
+
return False
|
507 |
+
if isinstance(save_freq, int):
|
508 |
+
return ((epoch_id + 1) % save_freq == 0) or ((epoch_id + 1) == num_train_epochs)
|
509 |
+
else:
|
510 |
+
return num_evals_no_improvement == 0
|
511 |
+
|
512 |
+
utils.barrier_if_distributed()
|
513 |
+
|
514 |
+
# ACTUAL TRAIN/EVAL LOOP #
|
515 |
+
with utils.wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation_steps):
|
516 |
+
for epoch_id in range(start_epoch, num_train_epochs):
|
517 |
+
run_train_epoch(epoch_id, train_loader, runner,
|
518 |
+
viz, num_log_iter, gradient_accumulation_steps, num_steps_per_epoch)
|
519 |
+
if eval_freq > 0 and (epoch_id + 1) % eval_freq == 0:
|
520 |
+
val_loss, _ = run_valid_epoch(epoch_id, valid_loader, runner, viz, is_master, val_check_frac)
|
521 |
+
if val_loss < best_val_loss:
|
522 |
+
best_val_loss = val_loss
|
523 |
+
num_evals_no_improvement = 0
|
524 |
+
else:
|
525 |
+
num_evals_no_improvement += 1
|
526 |
+
|
527 |
+
# Save trained model
|
528 |
+
if do_save(epoch_id, num_evals_no_improvement):
|
529 |
+
logger.info("** ** * Saving trained model ** ** * ")
|
530 |
+
# Only save the model itself
|
531 |
+
runner.save_state(save_path, epoch_id)
|
532 |
+
logger.info(f"Saving model checkpoint to {save_path}")
|
533 |
+
|
534 |
+
utils.barrier_if_distributed()
|
535 |
+
if patience > 0 and num_evals_no_improvement >= patience:
|
536 |
+
logger.info(f"Finished training at epoch {epoch_id} because no "
|
537 |
+
f"improvement for {num_evals_no_improvement} epochs.")
|
538 |
+
logger.log(35, f"Best Val Loss: {best_val_loss}")
|
539 |
+
if local_rank != -1:
|
540 |
+
# If you're distributed, raise this error. It sends a signal to
|
541 |
+
# the master process which lets it kill other processes and terminate
|
542 |
+
# without actually reporting an error. See utils/distributed_utils.py
|
543 |
+
# for the signal handling code.
|
544 |
+
raise errors.EarlyStopping
|
545 |
+
else:
|
546 |
+
break
|
547 |
+
logger.info(f"Finished training after {num_train_epochs} epochs.")
|
548 |
+
if best_val_loss != float('inf'):
|
549 |
+
logger.log(35, f"Best Val Loss: {best_val_loss}")
|
550 |
+
|
551 |
+
|
552 |
+
def run_eval(model_type: str,
             task: str,
             from_pretrained: str,
             split: str = 'test',
             batch_size: int = 1024,
             model_config_file: typing.Optional[str] = None,
             data_dir: str = './data',
             no_cuda: bool = False,
             seed: int = 42,
             tokenizer: str = 'iupac',
             num_workers: int = 8,
             debug: bool = False,
             metrics: typing.Tuple[str, ...] = (),
             log_level: typing.Union[str, int] = logging.INFO) -> typing.Dict[str, float]:
    """Evaluate a pretrained model on one split of a TAPE task.

    Loads the model from ``from_pretrained``, runs a forward pass over the
    requested split, computes each named metric over the collected targets and
    predictions, pickles ``(metrics, outputs)`` to
    ``<from_pretrained>/results.pkl``, and returns the metric dictionary.

    Args:
        model_type: Registered model architecture name.
        task: Registered task name.
        from_pretrained: Directory with pretrained weights; results.pkl is
            also written here.
        split: Dataset split to evaluate (default ``'test'``).
        batch_size: Total evaluation batch size.
        model_config_file: Optional model config override file.
        data_dir: Root directory containing task data.
        no_cuda: Force CPU evaluation.
        seed: Random seed for reproducibility.
        tokenizer: Tokenizer vocabulary name.
        num_workers: DataLoader worker processes.
        debug: Unused in this function.  NOTE(review): presumably kept for
            CLI symmetry with run_train — confirm before removing.
        metrics: Names of registered metric functions to compute.
        log_level: Logging verbosity (name or numeric level).

    Returns:
        Mapping from metric name to computed value.
    """
    local_rank = -1  # TAPE does not support torch.distributed.launch for evaluation
    device, n_gpu, is_master = utils.setup_distributed(local_rank, no_cuda)
    utils.setup_logging(local_rank, save_path=None, log_level=log_level)
    utils.set_random_seeds(seed, n_gpu)

    pretrained_dir = Path(from_pretrained)

    logger.info(
        f"device: {device} "
        f"n_gpu: {n_gpu}")

    model = registry.get_task_model(model_type, task, model_config_file, from_pretrained)
    model = model.to(device)

    # Forward-only runner: no optimizer / backward pass needed for evaluation.
    runner = ForwardRunner(model, device, n_gpu)
    runner.initialize_distributed_model()
    valid_dataset = utils.setup_dataset(task, data_dir, split, tokenizer)
    valid_loader = utils.setup_loader(
        valid_dataset, batch_size, local_rank, n_gpu,
        1, num_workers)

    # Resolve metric names to callables up front so typos fail fast.
    metric_functions = [registry.get_metric(name) for name in metrics]
    save_outputs = run_eval_epoch(valid_loader, runner, is_master)
    # Each output element carries a 'target' and a 'prediction' entry.
    target = [el['target'] for el in save_outputs]
    prediction = [el['prediction'] for el in save_outputs]

    metrics_to_save = {name: metric(target, prediction)
                       for name, metric in zip(metrics, metric_functions)}
    logger.info(''.join(f'{name}: {val}' for name, val in metrics_to_save.items()))

    # Persist both the summary metrics and the raw per-example outputs.
    with (pretrained_dir / 'results.pkl').open('wb') as f:
        pkl.dump((metrics_to_save, save_outputs), f)

    return metrics_to_save
|
601 |
+
|
602 |
+
|
603 |
+
def run_embed(model_type: str,
              data_file: str,
              out_file: str,
              from_pretrained: str,
              batch_size: int = 1024,
              model_config_file: typing.Optional[str] = None,
              full_sequence_embed: bool = False,
              no_cuda: bool = False,
              seed: int = 42,
              tokenizer: str = 'iupac',
              num_workers: int = 8,
              log_level: typing.Union[str, int] = logging.INFO) -> None:
    """Embed a file of protein sequences with a pretrained model.

    Writes one entry per protein id into ``out_file`` (an incrementally
    written .npz archive).  Each entry contains the model's 'pooled'
    embedding plus either an average-pooled sequence embedding ('avg') or,
    when ``full_sequence_embed`` is set, the full per-residue embedding
    ('seq').

    Args:
        model_type: Registered model architecture name.
        data_file: Input file of sequences for the 'embed' task dataset.
        out_file: Output .npz path.
        from_pretrained: Directory with pretrained weights.
        batch_size: Total embedding batch size.
        model_config_file: Optional model config override file.
        full_sequence_embed: Save per-residue embeddings instead of the
            sequence average.
        no_cuda: Force CPU execution.
        seed: Random seed for reproducibility.
        tokenizer: Tokenizer vocabulary name.
        num_workers: DataLoader worker processes.
        log_level: Logging verbosity (name or numeric level).
    """
    local_rank = -1  # TAPE does not support torch.distributed.launch for embedding
    device, n_gpu, is_master = utils.setup_distributed(local_rank, no_cuda)
    utils.setup_logging(local_rank, save_path=None, log_level=log_level)
    utils.set_random_seeds(seed, n_gpu)

    logger.info(
        f"device: {device} "
        f"n_gpu: {n_gpu}")

    task_spec = registry.get_task_spec('embed')
    model = registry.get_task_model(
        model_type, task_spec.name, model_config_file, from_pretrained)
    model = model.to(device)
    runner = ForwardRunner(model, device, n_gpu)
    runner.initialize_distributed_model()
    # Inference only: switch to eval mode and disable autograd globally.
    runner.eval()
    torch.set_grad_enabled(False)

    dataset = task_spec.dataset(data_file, tokenizer=tokenizer)  # type: ignore
    valid_loader = utils.setup_loader(dataset, batch_size, local_rank, n_gpu, 1, num_workers)

    # IncrementalNPZ lets us stream results to disk without holding every
    # embedding in memory at once.
    with utils.IncrementalNPZ(out_file) as npzfile:
        with utils.wrap_cuda_oom_error(local_rank, batch_size, n_gpu):
            for batch in tqdm(valid_loader, total=len(valid_loader)):
                outputs = runner.forward(batch, no_loss=True)
                ids = batch['ids']
                sequence_embed = outputs[0]
                pooled_embed = outputs[1]
                # input_mask is 1 for real tokens, so the row sum is the
                # unpadded sequence length.
                sequence_lengths = batch['input_mask'].sum(1)
                sequence_embed = sequence_embed.cpu().numpy()
                pooled_embed = pooled_embed.cpu().numpy()
                sequence_lengths = sequence_lengths.cpu().numpy()

                for seqembed, poolembed, length, protein_id in zip(
                        sequence_embed, pooled_embed, sequence_lengths, ids):
                    # Trim padding positions before saving/averaging.
                    seqembed = seqembed[:length]
                    arrays = {'pooled': poolembed}
                    if not full_sequence_embed:
                        # avgpool across the sequence
                        arrays['avg'] = seqembed.mean(0)
                    else:
                        arrays['seq'] = seqembed
                    to_save = {protein_id: arrays}
                    npzfile.savez(**to_save)
|
tape/utils/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import int_or_str # noqa: F401
|
2 |
+
from .utils import check_is_file # noqa: F401
|
3 |
+
from .utils import check_is_dir # noqa: F401
|
4 |
+
from .utils import path_to_datetime # noqa: F401
|
5 |
+
from .utils import get_expname # noqa: F401
|
6 |
+
from .utils import get_effective_num_gpus # noqa: F401
|
7 |
+
from .utils import get_effective_batch_size # noqa: F401
|
8 |
+
from .utils import get_num_train_optimization_steps # noqa: F401
|
9 |
+
from .utils import set_random_seeds # noqa: F401
|
10 |
+
from .utils import MetricsAccumulator # noqa: F401
|
11 |
+
from .utils import wrap_cuda_oom_error # noqa: F401
|
12 |
+
from .utils import write_lmdb # noqa: F401
|
13 |
+
from .utils import IncrementalNPZ # noqa: F401
|
14 |
+
|
15 |
+
from .setup_utils import setup_logging # noqa: F401
|
16 |
+
from .setup_utils import setup_optimizer # noqa: F401
|
17 |
+
from .setup_utils import setup_dataset # noqa: F401
|
18 |
+
from .setup_utils import setup_loader # noqa: F401
|
19 |
+
from .setup_utils import setup_distributed # noqa: F401
|
20 |
+
|
21 |
+
from .distributed_utils import barrier_if_distributed # noqa: F401
|
22 |
+
from .distributed_utils import reduce_scalar # noqa: F401
|
23 |
+
from .distributed_utils import launch_process_group # noqa: F401
|
tape/utils/_sampler.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implementation of a bucketed data sampler from PyTorch-NLP.
|
2 |
+
Modified by Roshan Rao.
|
3 |
+
|
4 |
+
See https://github.com/PetrochukM/PyTorch-NLP/
|
5 |
+
"""
|
6 |
+
import typing
|
7 |
+
import math
|
8 |
+
import operator
|
9 |
+
from torch.utils.data.sampler import Sampler
|
10 |
+
from torch.utils.data.sampler import BatchSampler
|
11 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
12 |
+
|
13 |
+
|
14 |
+
class SortedSampler(Sampler):
    """ Samples elements sequentially, always in the same order.
    Args:
        dataset (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.
        indices (iterable of int, optional): If given, only these dataset indices are
            sampled (in sorted-key order); otherwise the whole dataset is sampled.
    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self,
                 dataset,
                 sort_key: typing.Callable[[int], typing.Any],
                 indices: typing.Optional[typing.Iterable[int]] = None):
        super().__init__(dataset)
        self.dataset = dataset
        self.sort_key = sort_key
        if indices is None:
            # BUG FIX: the original used `map(sort_key, dataset)`, which yields
            # bare key values rather than (index, key) pairs, so the
            # itemgetter(1) sort and the `for i, _ in ...` unpack below crashed
            # (even the docstring example raised).  Enumerate to build pairs.
            sort_keys: typing.Iterable = ((i, sort_key(x)) for i, x in enumerate(dataset))
        else:
            sort_keys = ((i, sort_key(self.dataset[i])) for i in indices)
        self.sorted_indices = [i for i, _ in sorted(sort_keys, key=operator.itemgetter(1))]

    def __iter__(self):
        return iter(self.sorted_indices)

    def __len__(self):
        # BUG FIX: length must reflect what __iter__ actually yields.  The
        # original returned len(self.dataset) even when `indices` restricted
        # sampling to a subset, desynchronizing len(sampler) from iteration.
        return len(self.sorted_indices)
|
43 |
+
|
44 |
+
|
45 |
+
class BucketBatchSampler(BatchSampler):
    """ `BucketBatchSampler` toggles between `sampler` batches and sorted batches.
    Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted
    and vice versa. Provides ~10-25 percent speedup.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular
        libraries like ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together
        examples with a similar size length to reduce the padding required for each batch
        while maintaining some noise through bucketing.

    Args:
        sampler (torch.data.utils.sampler.Sampler):
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size
            would be less than `batch_size`.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.
    Example:
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 drop_last,
                 sort_key,
                 dataset,
                 bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.dataset = dataset
        # Each bucket holds several batches' worth of indices (capped at the
        # sampler length), which are then sorted by key before batching.
        bucket_size = min(batch_size * bucket_size_multiplier, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        # Within each bucket: sort by key, cut into batches, then shuffle the
        # batch order so training still sees batches in random order.
        for bucket_indices in self.bucket_sampler:
            within_bucket = SortedSampler(self.dataset, self.sort_key, indices=bucket_indices)
            candidate_batches = list(
                BatchSampler(within_bucket, self.batch_size, self.drop_last))
            yield from SubsetRandomSampler(candidate_batches)

    def __len__(self):
        n_batches = len(self.sampler) / self.batch_size
        return int(n_batches) if self.drop_last else math.ceil(n_batches)
|
tape/utils/distributed_utils.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import multiprocessing as mp
|
5 |
+
import sys
|
6 |
+
import signal
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from torch.multiprocessing import _prctl_pr_set_pdeathsig # type: ignore
|
11 |
+
|
12 |
+
from ..errors import EarlyStopping
|
13 |
+
|
14 |
+
|
15 |
+
def reduce_scalar(scalar: float) -> float:
    """Average a python scalar across all distributed workers.

    Outside of an initialized distributed context the value is returned
    unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return scalar
    buffer = torch.cuda.FloatTensor([scalar])  # type: ignore
    dist.all_reduce(buffer)
    buffer /= dist.get_world_size()
    return buffer.item()
|
22 |
+
|
23 |
+
|
24 |
+
def barrier_if_distributed() -> None:
    """Raises a barrier if in a distributed context, otherwise does nothing."""
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.barrier()
|
28 |
+
|
29 |
+
|
30 |
+
def _wrap(fn, kwargs, error_queue):
    """Entry point for a spawned worker process: run ``fn(**kwargs)`` and
    translate its outcome into the exit protocol ProcessContext expects.

    Exit protocol:
        * KeyboardInterrupt -> silent return (parent initiated the shutdown)
        * EarlyStopping     -> exit code signal.SIGUSR1 (treated as success)
        * other exceptions  -> formatted traceback pushed onto ``error_queue``,
                               then exit code 1
    """
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(**kwargs)
    except KeyboardInterrupt:
        pass  # SIGINT; Killed by parent, do nothing
    except EarlyStopping:
        sys.exit(signal.SIGUSR1)  # tape early stop exception
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)
|
48 |
+
|
49 |
+
|
50 |
+
class ProcessContext:
    """Tracks a group of worker processes (spawned via ``_wrap``) together
    with their per-process error queues, and joins them while surfacing the
    first failure's traceback in the parent."""

    def __init__(self, processes, error_queues):
        # error_queues[i] receives a traceback string if processes[i] fails.
        self.error_queues = error_queues
        self.processes = processes
        # Map each process's sentinel (a waitable handle) back to its index
        # so mp.connection.wait results can be resolved to a process.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        # PIDs of all managed worker processes.
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this process context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = mp.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )
        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break
        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0
        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode == signal.SIGUSR1:
                # EarlyStopping exit from _wrap: treated as a clean finish.
                return True
            elif exitcode < 0:
                # Negative exit codes mean the process was killed by a signal.
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        # Re-raise the child's traceback (as text) in the parent.
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
|
123 |
+
|
124 |
+
|
125 |
+
def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    """Spawn ``num_processes`` worker processes running ``func(args=..., env=...)``
    with the environment variables torch.distributed expects
    (MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK/LOCAL_RANK) set per process.

    Args:
        func: Callable invoked in each worker as ``func(args=args, env=env)``.
        args: Namespace passed through to the workers; its ``local_rank``
            attribute is set per worker before the process is started.
        num_processes: Worker processes to launch on this node.
        num_nodes: Total number of nodes in the job.
        node_rank: Rank of this node among all nodes.
        master_addr: Address of the rank-0 coordinator.
        master_port: Port of the rank-0 coordinator.
        join: If True, block until every worker exits (raising on failure);
            if False, return the ProcessContext immediately.
        daemon: Whether to mark worker processes as daemons.

    Returns:
        A ProcessContext when ``join`` is False, otherwise None after all
        workers have finished.
    """
    # world size in terms of number of processes
    dist_world_size = num_processes * num_nodes

    # set PyTorch distributed related environmental variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = master_addr
    current_env["MASTER_PORT"] = str(master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    if 'OMP_NUM_THREADS' not in os.environ and num_processes > 1:
        # Avoid oversubscribing CPU threads when several workers share a node.
        current_env["OMP_NUM_THREADS"] = str(4)

    error_queues = []
    processes = []

    for local_rank in range(num_processes):
        # each process's rank
        dist_rank = num_processes * node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)
        # NOTE(review): `args` and `current_env` are mutated in-place each
        # iteration; each child appears to snapshot them at start() — confirm
        # this holds for the configured multiprocessing start method.
        args.local_rank = local_rank

        error_queue: mp.SimpleQueue[Exception] = mp.SimpleQueue()
        kwargs = {'args': args, 'env': current_env}
        process = mp.Process(
            target=_wrap,
            args=(func, kwargs, error_queue),
            daemon=daemon)
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    process_context = ProcessContext(processes, error_queues)
    if not join:
        return process_context

    # Block until every worker has exited; ProcessContext.join raises if any
    # worker failed.
    while not process_context.join():
        pass
|
tape/utils/setup_utils.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions to help setup the model, optimizer, distributed compute, etc.
|
2 |
+
"""
|
3 |
+
import typing
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from torch.utils.data import DataLoader, RandomSampler, Dataset
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from ..optimization import AdamW
|
13 |
+
|
14 |
+
from ..registry import registry
|
15 |
+
|
16 |
+
from .utils import get_effective_batch_size
|
17 |
+
from ._sampler import BucketBatchSampler
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def setup_logging(local_rank: int,
                  save_path: typing.Optional[Path] = None,
                  log_level: typing.Optional[typing.Union[str, int]] = None) -> None:
    """Configure the root logger for (possibly distributed) training.

    Args:
        local_rank: Rank of this process (-1 when not distributed). Non-master
            ranks (anything other than -1 or 0) are quieted to WARN or above.
        save_path: If provided, also log to ``save_path / 'log'``.
        log_level: Logging level as a name (e.g. ``'info'``) or a numeric
            value. Defaults to INFO.

    Raises:
        TypeError: If ``log_level`` is not None, a str, or an int.
    """
    if log_level is None:
        level = logging.INFO
    elif isinstance(log_level, str):
        level = getattr(logging, log_level.upper())
    elif isinstance(log_level, int):
        level = log_level
    else:
        # BUG FIX: previously an unexpected type left `level` unbound,
        # causing a confusing NameError below. Fail loudly instead.
        # (Annotation also fixed: the default of None was outside the
        # declared Union[str, int].)
        raise TypeError(f"Unrecognized log_level type: {type(log_level)}")

    # Quiet non-master processes so distributed runs don't interleave output.
    if local_rank not in (-1, 0):
        level = max(level, logging.WARN)

    root_logger = logging.getLogger()
    root_logger.setLevel(level)

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%y/%m/%d %H:%M:%S")

    # Only attach a console handler once; repeated calls must not duplicate
    # output lines.
    if not root_logger.hasHandlers():
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)

    if save_path is not None:
        file_handler = logging.FileHandler(save_path / 'log')
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)
|
53 |
+
|
54 |
+
|
55 |
+
def setup_optimizer(model,
                    learning_rate: float):
    """Create the AdamW optimizer for the given model with the specified learning rate. Based on
    creation in the pytorch_transformers repository.

    Weight decay (0.01) is applied only to parameters whose names do not
    contain any of the bias / LayerNorm markers; those get decay 0.0.

    Args:
        model (PreTrainedModel): The model for which to create an optimizer
        learning_rate (float): Default learning rate to use when creating the optimizer

    Returns:
        optimizer (AdamW): An AdamW optimizer

    """
    no_decay_markers = ('bias', 'gamma', 'beta', 'LayerNorm')

    def _decays(name: str) -> bool:
        # Parameters matching any marker are exempt from weight decay.
        return not any(marker in name for marker in no_decay_markers)

    decay_params = []
    exempt_params = []
    for name, param in model.named_parameters():
        if _decays(name):
            decay_params.append(param)
        else:
            exempt_params.append(param)

    grouped_parameters = [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": exempt_params, "weight_decay": 0.0},
    ]
    return AdamW(grouped_parameters, lr=learning_rate)
|
87 |
+
|
88 |
+
|
89 |
+
def setup_dataset(task: str,
                  data_dir: typing.Union[str, Path],
                  split: str,
                  tokenizer: str) -> Dataset:
    """Look up ``task`` in the registry and instantiate its dataset for the
    given data directory, split, and tokenizer."""
    spec = registry.get_task_spec(task)
    dataset = spec.dataset(data_dir, split, tokenizer)  # type: ignore
    return dataset
|
95 |
+
|
96 |
+
|
97 |
+
def setup_loader(dataset: Dataset,
                 batch_size: int,
                 local_rank: int,
                 n_gpu: int,
                 gradient_accumulation_steps: int,
                 num_workers: int) -> DataLoader:
    """Build a DataLoader whose batches are bucketed by primary-sequence
    length (via BucketBatchSampler) to cut down padding.

    The requested ``batch_size`` is first scaled down by gradient accumulation
    and the effective gpu count, then multiplied back by ``n_gpu`` for the
    per-loader batch size.
    """
    if local_rank == -1:
        sampler = RandomSampler(dataset)
    else:
        sampler = DistributedSampler(dataset)
    per_step_batch = get_effective_batch_size(
        batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
    # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
    batch_sampler = BucketBatchSampler(
        sampler, per_step_batch, False, lambda x: len(x[0]), dataset)

    return DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=dataset.collate_fn,  # type: ignore
        batch_sampler=batch_sampler)
|
117 |
+
|
118 |
+
|
119 |
+
def setup_distributed(local_rank: int,
                      no_cuda: bool) -> typing.Tuple[torch.device, int, bool]:
    """Choose the torch device and gpu count, initializing the NCCL process
    group when running distributed.

    Returns:
        (device, n_gpu, is_master) where is_master is True for rank -1 or 0.
    """
    if local_rank != -1 and not no_cuda:
        # Distributed: one process per GPU, pinned to local_rank.
        torch.cuda.set_device(local_rank)
        device: torch.device = torch.device("cuda", local_rank)
        n_gpu = 1
        dist.init_process_group(backend="nccl")
    elif no_cuda or not torch.cuda.is_available():
        device = torch.device("cpu")
        n_gpu = 1
    else:
        # Single-process DataParallel across all visible GPUs.
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()

    return device, n_gpu, local_rank in (-1, 0)
|
tape/utils/utils.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import random
|
3 |
+
from pathlib import Path
|
4 |
+
import logging
|
5 |
+
from time import strftime, gmtime
|
6 |
+
from datetime import datetime
|
7 |
+
import os
|
8 |
+
import argparse
|
9 |
+
import contextlib
|
10 |
+
from collections import defaultdict
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
import torch.distributed as dist
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
FloatOrTensor = typing.Union[float, torch.Tensor]
|
19 |
+
|
20 |
+
|
21 |
+
def int_or_str(arg: str) -> typing.Union[int, str]:
    """argparse ``type=`` helper: parse ``arg`` as int when possible,
    otherwise return the string unchanged."""
    try:
        value: typing.Union[int, str] = int(arg)
    except ValueError:
        value = arg
    return value
|
26 |
+
|
27 |
+
|
28 |
+
def check_is_file(file_path: str) -> str:
    """argparse ``type=`` check: accept None or an existing file path,
    otherwise raise ArgumentTypeError."""
    if file_path is not None and not os.path.isfile(file_path):
        raise argparse.ArgumentTypeError(f"File path: {file_path} is not a valid file")
    return file_path
|
33 |
+
|
34 |
+
|
35 |
+
def check_is_dir(dir_path: str) -> str:
    """argparse ``type=`` check: accept None or an existing directory path,
    otherwise raise ArgumentTypeError."""
    if dir_path is not None and not os.path.isdir(dir_path):
        raise argparse.ArgumentTypeError(f"Directory path: {dir_path} is not a valid directory")
    return dir_path
|
40 |
+
|
41 |
+
|
42 |
+
def path_to_datetime(path: Path) -> datetime:
    """Recover the timestamp encoded in an experiment directory name.

    Names look like ``<yy-mm-dd-HH-MM-SS>_<suffix>`` or the deprecated
    ``<yy-mm-dd-HH:MM:SS>_<suffix>``.  Unparseable names map to
    ``datetime(1, 1, 1)`` so they sort before any real timestamp.
    """
    stamp = path.name.split('_')[0]
    parts = stamp.split('-')
    if len(parts) == 6:
        year, month, day, hour, minute, second = parts
    elif len(parts) == 4:
        # Deprecated format: time fields joined by ':' in the last component.
        year, month, day, time_str = parts
        time_parts = time_str.split(':')
        if len(time_parts) != 3:
            return datetime(1, 1, 1)
        hour, minute, second = time_parts
    else:
        return datetime(1, 1, 1)

    return datetime(int(year), int(month), int(day),
                    int(hour), int(minute), int(second))
|
58 |
+
|
59 |
+
|
60 |
+
def get_expname(exp_name: typing.Optional[str],
                task: typing.Optional[str] = None,
                model_type: typing.Optional[str] = None) -> str:
    """Return ``exp_name`` if given; otherwise generate
    ``<task>_<model_type>_<utc timestamp>_<6-digit random>``."""
    if exp_name is not None:
        return exp_name
    stamp = strftime("%y-%m-%d-%H-%M-%S", gmtime())
    suffix = random.randint(0, int(1e6))
    return f"{task}_{model_type}_{stamp}_{suffix:0>6d}"
|
67 |
+
|
68 |
+
|
69 |
+
def set_random_seeds(seed: int, n_gpu: int) -> None:
    """Seed the python, numpy, and torch RNGs (plus every CUDA device when
    ``n_gpu > 0``) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)  # type: ignore
|
75 |
+
|
76 |
+
|
77 |
+
def get_effective_num_gpus(local_rank: int, n_gpu: int) -> int:
    """Total GPUs in the job: the distributed world size when running
    distributed, otherwise the local ``n_gpu``."""
    if local_rank != -1:
        return dist.get_world_size()
    return n_gpu
|
83 |
+
|
84 |
+
|
85 |
+
def get_effective_batch_size(batch_size: int,
                             local_rank: int,
                             n_gpu: int,
                             gradient_accumulation_steps: int = 1) -> int:
    """Per-forward-pass batch size: the requested batch size divided by the
    gradient accumulation steps and then by the effective gpu count."""
    return int(
        batch_size
        / gradient_accumulation_steps
        / get_effective_num_gpus(local_rank, n_gpu))
|
93 |
+
|
94 |
+
|
95 |
+
def get_num_train_optimization_steps(dataset: Dataset,
                                     batch_size: int,
                                     num_train_epochs: int) -> int:
    """Total optimizer steps over training:
    ``(len(dataset) / batch_size) * num_train_epochs`` truncated to int."""
    steps = len(dataset) / batch_size * num_train_epochs
    return int(steps)
|
99 |
+
|
100 |
+
|
101 |
+
class MetricsAccumulator:
    """Tracks a running loss and named metrics during training.

    Values are accumulated over gradient-accumulation sub-steps via
    ``update`` and folded into both an exponentially-smoothed average and a
    running total on every ``step``.
    """

    def __init__(self, smoothing: float = 0.95):
        self._smoothing = smoothing
        self._smoothloss: typing.Optional[float] = None
        self._totalloss = 0.
        self._smoothmetrics: typing.Dict[str, float] = {}
        self._totalmetrics: typing.Dict[str, float] = defaultdict(lambda: 0.0)
        self._nupdates = 0
        self._reset_accumulators()

    def _reset_accumulators(self) -> None:
        # Per-step scratch space, cleared after every `step`.
        self._loss_tmp = 0.
        self._metricstmp: typing.Dict[str, float] = defaultdict(lambda: 0.0)
        self._nacc_steps = 0

    @staticmethod
    def _as_float(value: FloatOrTensor) -> float:
        return value.item() if isinstance(value, torch.Tensor) else value

    def update(self,
               loss: FloatOrTensor,
               metrics: typing.Dict[str, FloatOrTensor],
               step: bool = True) -> None:
        """Record one sub-step of loss/metrics; fold into the running
        averages immediately unless ``step=False``."""
        self._loss_tmp += self._as_float(loss)
        for name, value in metrics.items():
            self._metricstmp[name] += self._as_float(value)
        self._nacc_steps += 1

        if step:
            self.step()

    def step(self) -> typing.Dict[str, float]:
        """Average the accumulated sub-steps, update the smoothed and total
        statistics, reset the accumulators, and return this step's means
        (including 'loss')."""
        mean_loss = self._loss_tmp / self._nacc_steps
        mean_metrics = {name: total / self._nacc_steps
                        for name, total in self._metricstmp.items()}

        if self._smoothloss is None:
            self._smoothloss = mean_loss
        else:
            self._smoothloss = (self._smoothloss * self._smoothing
                                + (1 - self._smoothing) * mean_loss)
        self._totalloss += mean_loss

        for name, value in mean_metrics.items():
            if name in self._smoothmetrics:
                previous = self._smoothmetrics[name]
                smoothed = previous * self._smoothing + value * (1 - self._smoothing)
            else:
                smoothed = value

            self._smoothmetrics[name] = smoothed
            self._totalmetrics[name] += value

        self._nupdates += 1
        self._reset_accumulators()

        mean_metrics['loss'] = mean_loss
        return mean_metrics

    def loss(self) -> float:
        """Exponentially-smoothed loss; raises if no step has occurred."""
        if self._smoothloss is None:
            raise RuntimeError("Trying to get the loss without any updates")
        return self._smoothloss

    def metrics(self) -> typing.Dict[str, float]:
        """Exponentially-smoothed metrics; raises if no step has occurred."""
        if self._nupdates == 0:
            raise RuntimeError("Trying to get metrics without any updates")
        return dict(self._smoothmetrics)

    def final_loss(self) -> float:
        """Mean per-step loss over the whole run."""
        return self._totalloss / self._nupdates

    def final_metrics(self) -> typing.Dict[str, float]:
        """Mean per-step metrics over the whole run."""
        return {name: value / self._nupdates
                for name, value in self._totalmetrics.items()}
|
179 |
+
|
180 |
+
|
181 |
+
class wrap_cuda_oom_error(contextlib.ContextDecorator):
    """A context manager that wraps the Cuda OOM message so that you get some more helpful
    context as to what you can/should change. Can also be used as a decorator.

    Examples:
        1) As a context manager:

            with wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation):
                loss = model.forward(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        2) As a decorator:

            @wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation)
            def run_train_epoch(args):
                ...
                <code to run training epoch>
                ...
    """

    def __init__(self,
                 local_rank: int,
                 batch_size: int,
                 n_gpu: int = 1,
                 gradient_accumulation_steps: typing.Optional[int] = None):
        # Hyperparameters echoed back to the user in the wrapped error message.
        self._local_rank = local_rank
        self._batch_size = batch_size
        self._n_gpu = n_gpu
        self._gradient_accumulation_steps = gradient_accumulation_steps

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Rewrap a CUDA OOM error with hyperparameter context; let everything else propagate."""
        exc_args = exc_value.args if exc_value is not None else None
        # Guard with isinstance: the original `in` test raised TypeError when the
        # first exception arg was not a string (e.g. an errno int).
        if exc_args and isinstance(exc_args[0], str) and 'CUDA out of memory' in exc_args[0]:
            eff_ngpu = get_effective_num_gpus(self._local_rank, self._n_gpu)
            if self._gradient_accumulation_steps is not None:
                eff_batch_size = get_effective_batch_size(
                    self._batch_size, self._local_rank, self._n_gpu,
                    self._gradient_accumulation_steps)
                message = (f"CUDA out of memory. Reduce batch size or increase "
                           f"gradient_accumulation_steps to divide each batch over more "
                           f"forward passes.\n\n"
                           f"\tHyperparameters:\n"
                           f"\t\tbatch_size per backward-pass: {self._batch_size}\n"
                           f"\t\tgradient_accumulation_steps: "
                           f"{self._gradient_accumulation_steps}\n"
                           f"\t\tn_gpu: {eff_ngpu}\n"
                           f"\t\tbatch_size per (gpu * forward-pass): "
                           f"{eff_batch_size}")
            else:
                eff_batch_size = get_effective_batch_size(
                    self._batch_size, self._local_rank, self._n_gpu)
                message = (f"CUDA out of memory. Reduce batch size to fit each "
                           f"iteration in memory.\n\n"
                           f"\tHyperparameters:\n"
                           f"\t\tbatch_size per forward-pass: {self._batch_size}\n"
                           f"\t\tn_gpu: {eff_ngpu}\n"
                           f"\t\tbatch_size per (gpu * forward-pass): "
                           f"{eff_batch_size}")
            # Chain the original OOM error so its traceback is not lost.
            raise RuntimeError(message) from exc_value
        # Returning False re-raises any non-OOM exception unchanged.
        return False
|
246 |
+
|
247 |
+
|
248 |
+
def write_lmdb(filename: str, iterable: typing.Iterable, map_size: int = 2 ** 20):
    """Utility for writing a dataset to an LMDB file.

    Args:
        filename (str): Output filename to write to
        iterable (Iterable): An iterable dataset to write to. Entries must be pickleable.
        map_size (int, optional): Maximum allowable size of database in bytes. Required by LMDB.
            You will likely have to increase this. Default: 1MB.
    """
    import lmdb
    import pickle as pkl
    env = lmdb.open(filename, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            # Track the count explicitly: the original used `i + 1` after the
            # loop, which raised NameError for an empty iterable.
            num_examples = 0
            for i, entry in enumerate(iterable):
                txn.put(str(i).encode(), pkl.dumps(entry))
                num_examples = i + 1
            txn.put(b'num_examples', pkl.dumps(num_examples))
    finally:
        # Ensure the environment is closed even if serialization fails.
        env.close()
|
266 |
+
|
267 |
+
|
268 |
+
class IncrementalNPZ(object):
    # Modified npz that allows incremental saving, from https://stackoverflow.com/questions/22712292/how-to-use-numpy-savez-in-a-loop-for-save-more-than-one-array # noqa: E501
    def __init__(self, file):
        """Open `file` (path or file object) as an appendable, uncompressed .npz archive."""
        import tempfile
        import zipfile
        import os

        if isinstance(file, str):
            if not file.endswith('.npz'):
                file = file + '.npz'

        compression = zipfile.ZIP_STORED

        # mode="a" allows appending new arrays on each savez() call.
        zip_handle = self.zipfile_factory(file, mode="a", compression=compression)

        # Stage arrays in a temporary file on disk, before writing to zip.
        fd, tmpfile = tempfile.mkstemp(suffix='-numpy.npy')
        os.close(fd)

        self.tmpfile = tmpfile
        self.zip = zip_handle
        # Counter for auto-named positional arrays (arr_0, arr_1, ...).
        self._i = 0

    def zipfile_factory(self, *args, **kwargs):
        """Create a ZipFile with Zip64 extensions enabled (needed for archives > 2GB)."""
        import zipfile
        # The original guarded on sys.version_info >= (2, 5), which is always
        # true on any supported interpreter.
        kwargs['allowZip64'] = True
        return zipfile.ZipFile(*args, **kwargs)

    def savez(self, *args, **kwds):
        """Append positional and keyword arrays to the archive, numpy.savez-style.

        Positional arrays are stored as arr_0, arr_1, ... (numbering continues
        across calls); keyword arrays are stored under their keyword name.

        Raises:
            ValueError: If a positional array's auto-generated name collides
                with an explicit keyword name.
        """
        import os
        import numpy.lib.format as fmt

        namedict = kwds
        for val in args:
            key = 'arr_%d' % self._i
            if key in namedict.keys():
                raise ValueError("Cannot use un-named variables and keyword %s" % key)
            namedict[key] = val
            self._i += 1

        try:
            for key, val in namedict.items():
                fname = key + '.npy'
                # Bug fix: the original opened self.tmpfile twice in a row
                # (a bare open() immediately shadowed by the `with` open),
                # leaking one file handle per array written.
                with open(self.tmpfile, 'wb') as fid:
                    fmt.write_array(fid, np.asanyarray(val), allow_pickle=True)
                self.zip.write(self.tmpfile, arcname=fname)
        finally:
            os.remove(self.tmpfile)

    def close(self):
        self.zip.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
tape/visualization.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from pathlib import Path
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from tensorboardX import SummaryWriter
|
9 |
+
|
10 |
+
try:
|
11 |
+
import wandb
|
12 |
+
WANDB_FOUND = True
|
13 |
+
except ImportError:
|
14 |
+
WANDB_FOUND = False
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class TAPEVisualizer(ABC):
    """Abstract interface for logging TAPE training runs.

    Concrete subclasses log experiment configuration, model graphs, and
    per-step metrics to a particular backend.
    """

    @abstractmethod
    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        raise NotImplementedError

    @abstractmethod
    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Record the experiment's configuration dictionary."""
        raise NotImplementedError

    @abstractmethod
    def watch(self, model: nn.Module) -> None:
        """Register a model so the backend can track it (e.g. gradients)."""
        raise NotImplementedError

    @abstractmethod
    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Record scalar metrics for a given data split at a training step."""
        raise NotImplementedError
|
40 |
+
|
41 |
+
|
42 |
+
class DummyVisualizer(TAPEVisualizer):
    """No-op visualizer used on non-master processes so only one rank logs."""

    def __init__(self,
                 log_dir: typing.Union[str, Path] = '',
                 exp_name: str = '',
                 debug: bool = False):
        pass

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Discard the config without logging."""
        pass

    def watch(self, model: nn.Module) -> None:
        """Discard the model without logging."""
        pass

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Discard the metrics without logging."""
        pass
|
62 |
+
|
63 |
+
|
64 |
+
class TBVisualizer(TAPEVisualizer):
    """Visualizer that writes scalar metrics to tensorboard via tensorboardX.

    Config logging and model watching are unsupported by this backend and
    emit a warning instead.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        log_dir = Path(log_dir) / exp_name
        logger.info(f"tensorboard file at: {log_dir}")
        self.logger = SummaryWriter(log_dir=str(log_dir))

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        # logger.warn is deprecated; use logger.warning
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def watch(self, model: nn.Module) -> None:
        logger.warning("Cannot watch models when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log each metric as a scalar under the tag '<split>/<name>'."""
        for name, value in metrics_dict.items():
            self.logger.add_scalar(split + "/" + name, value, step)
|
85 |
+
|
86 |
+
|
87 |
+
class WandBVisualizer(TAPEVisualizer):
    """Visualizer that logs configs, models, and metrics to Weights & Biases.

    Raises:
        ImportError: If the wandb package is not installed.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        if not WANDB_FOUND:
            raise ImportError("wandb module not available")
        # NOTE(review): dryrun/WANDB_PROJECT handling was disabled upstream
        # (dead commented-out code removed); `debug` is currently unused here.
        wandb.init(dir=log_dir, name=exp_name)

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        wandb.config.update(config)

    def watch(self, model: nn.Module):
        # Track the model's gradients and parameters in the wandb dashboard.
        wandb.watch(model)

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log metrics under 'Split Name' keys, e.g. 'Train Loss'."""
        wandb.log({f"{split.capitalize()} {name.capitalize()}": value
                   for name, value in metrics_dict.items()}, step=step)
|
113 |
+
|
114 |
+
|
115 |
+
def get(log_dir: typing.Union[str, Path],
        exp_name: str,
        local_rank: int,
        debug: bool = False) -> TAPEVisualizer:
    """Select the appropriate visualizer for this process.

    Non-master ranks get a no-op DummyVisualizer; the master rank gets
    wandb when it is installed, otherwise tensorboard.
    """
    if local_rank not in (-1, 0):
        visualizer_cls: typing.Type[TAPEVisualizer] = DummyVisualizer
    elif WANDB_FOUND:
        visualizer_cls = WandBVisualizer
    else:
        visualizer_cls = TBVisualizer
    return visualizer_cls(log_dir, exp_name, debug)
|