import torch
import torch.hub
import metl.models as models
from metl.encode import DataEncoder, Encoding

UUID_URL_MAP = {
    # global source models
    "D72M9aEp": "https://zenodo.org/records/11051645/files/METL-G-20M-1D-D72M9aEp.pt?download=1",
    "Nr9zCKpR": "https://zenodo.org/records/11051645/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1",
    "auKdzzwX": "https://zenodo.org/records/11051645/files/METL-G-50M-1D-auKdzzwX.pt?download=1",
    "6PSAzdfv": "https://zenodo.org/records/11051645/files/METL-G-50M-3D-6PSAzdfv.pt?download=1",
    # local source models
    "8gMPQJy4": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1",
    "Hr4GNHws": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1",
    "8iFoiYw2": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1",
    "kt5DdWTa": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1",
    "DMfkjVzT": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1",
    "epegcFiH": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1",
    "kS3rUS7h": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1",
    "X7w83g6S": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1",
    "UKebCQGz": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1",
    "2rr8V4th": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1",
    "PREhfC22": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1",
    "9ASvszux": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1",
    "HscFFkAb": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1",
    "H48oiNZN": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1",
    # METL-Bind source models
    "K6mw24Rg": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1",
    "Bo5wn2SG": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1",
    # finetuned models from the GFP design experiment
    "YoQkzoLD": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1",
    "PEkeRuxb": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1",
}

IDENT_UUID_MAP = {
    # the keys should be all lowercase
    "metl-g-20m-1d": "D72M9aEp",
    "metl-g-20m-3d": "Nr9zCKpR",
    "metl-g-50m-1d": "auKdzzwX",
    "metl-g-50m-3d": "6PSAzdfv",
    # GFP local source models
    "metl-l-2m-1d-gfp": "8gMPQJy4",
    "metl-l-2m-3d-gfp": "Hr4GNHws",
    # DLG4 local source models
    "metl-l-2m-1d-dlg4": "8iFoiYw2",
    "metl-l-2m-3d-dlg4": "kt5DdWTa",
    # GB1 local source models
    "metl-l-2m-1d-gb1": "DMfkjVzT",
    "metl-l-2m-3d-gb1": "epegcFiH",
    # GRB2 local source models
    "metl-l-2m-1d-grb2": "kS3rUS7h",
    "metl-l-2m-3d-grb2": "X7w83g6S",
    # Pab1 local source models
    "metl-l-2m-1d-pab1": "UKebCQGz",
    "metl-l-2m-3d-pab1": "2rr8V4th",
    # TEM-1 local source models
    "metl-l-2m-1d-tem-1": "PREhfC22",
    "metl-l-2m-3d-tem-1": "9ASvszux",
    # Ube4b local source models
    "metl-l-2m-1d-ube4b": "HscFFkAb",
    "metl-l-2m-3d-ube4b": "H48oiNZN",
    # METL-Bind for GB1
    "metl-bind-2m-3d-gb1-standard": "K6mw24Rg",
    "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG",
    # GFP design models, giving them an ident
    "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD",
    "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb",
}


def download_checkpoint(uuid):
    """Download the checkpoint for the given UUID from Zenodo and return its state dict and hyperparameters."""
    ckpt = torch.hub.load_state_dict_from_url(UUID_URL_MAP[uuid],
                                              map_location="cpu", file_name=f"{uuid}.pt")
    state_dict = ckpt["state_dict"]
    hyper_parameters = ckpt["hyper_parameters"]
    return state_dict, hyper_parameters


def _get_data_encoding(hparams):
    """Determine the data encoding the model expects based on its hyperparameters."""
    if "encoding" in hparams and hparams["encoding"] == "int_seqs":
        encoding = Encoding.INT_SEQS
    elif "encoding" in hparams and hparams["encoding"] == "one_hot":
        encoding = Encoding.ONE_HOT
    elif (("encoding" in hparams and hparams["encoding"] == "auto") or "encoding" not in hparams) and \
            hparams["model_name"] in ["transformer_encoder"]:
        encoding = Encoding.INT_SEQS
    else:
        raise ValueError("Detected unsupported encoding in hyperparameters")
    return encoding


def load_model_and_data_encoder(state_dict, hparams):
    """Instantiate the model from its hyperparameters, load the weights, and pair it with a matching DataEncoder."""
    model = models.Model[hparams["model_name"]].cls(**hparams)
    model.load_state_dict(state_dict)
    data_encoder = DataEncoder(_get_data_encoding(hparams))
    return model, data_encoder


def get_from_uuid(uuid):
    """Load a pretrained model and its data encoder given a model UUID."""
    if uuid in UUID_URL_MAP:
        state_dict, hparams = download_checkpoint(uuid)
        return load_model_and_data_encoder(state_dict, hparams)
    else:
        raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")


def get_from_ident(ident):
    """Load a pretrained model and its data encoder given a human-readable identifier (case-insensitive)."""
    ident = ident.lower()
    if ident in IDENT_UUID_MAP:
        state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
        return load_model_and_data_encoder(state_dict, hparams)
    else:
        raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")


def get_from_checkpoint(ckpt_fn):
    """Load a model and its data encoder from a local checkpoint file."""
    ckpt = torch.load(ckpt_fn, map_location="cpu")
    state_dict = ckpt["state_dict"]
    hyper_parameters = ckpt["hyper_parameters"]
    return load_model_and_data_encoder(state_dict, hyper_parameters)
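

# Example usage sketch, not part of the original module: load a global source model by
# identifier and run it on a couple of toy sequences. The encode_sequences call and the
# toy sequences are assumptions for illustration only; downloading the checkpoint
# requires network access to Zenodo.
if __name__ == "__main__":
    model, data_encoder = get_from_ident("metl-g-20m-1d")
    model.eval()

    # toy sequences; real inputs should match the sequence length the model expects
    encoded_seqs = data_encoder.encode_sequences(["SMART", "MAGIC"])  # assumed DataEncoder method

    with torch.no_grad():
        predictions = model(torch.tensor(encoded_seqs))
    print(predictions)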