File size: 1,462 Bytes
bb8b1b9 3fd5dd8 bb8b1b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from transformers import PretrainedConfig, PreTrainedModel
import metl.models as models
from metl.main import *
class METLConfig(PretrainedConfig):
    """Configuration object paired with ``METLModel``.

    Stores the id of the METL model to load and exposes, as class
    attributes, the two lookup tables that the module's
    ``from metl.main import *`` brings into scope.
    """

    model_type = "METL"

    # Class-level aliases of the module-level lookup tables, so code holding
    # only a config instance can reach them as ``config.<NAME>``.
    UUID_URL_MAP = UUID_URL_MAP
    IDENT_UUID_MAP = IDENT_UUID_MAP

    def __init__(self, id: str = None, **kwargs):
        """Record the model id, then hand the remaining options to the base class.

        Args:
            id: Key identifying the METL model to load; ``None`` when no
                model has been selected yet.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Assigned before delegating, matching the original statement order.
        self.id = id
        super().__init__(**kwargs)
class METLModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates to a METL model.

    ``__init__`` leaves ``self.model`` and ``self.encoder`` as ``None``.
    They are filled in by :meth:`load_from_uuid` or :meth:`load_from_ident`
    from the pair returned by ``get_from_uuid`` / ``get_from_ident``.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        """Create an empty wrapper; no underlying model is loaded yet.

        Args:
            config: Configuration whose ``id`` and lookup tables are used by
                the ``load_from_*`` methods.
        """
        super().__init__(config)
        self.model = None
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the wrapped model on ``X``.

        Args:
            X: Input passed straight through to the wrapped model.
            pdb_fn: Optional value forwarded as the ``pdb_fn`` keyword
                argument; it is only passed along when truthy.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            # Fail with an actionable message instead of
            # "'NoneType' object is not callable".
            raise RuntimeError(
                "No METL model is loaded; call load_from_uuid() or "
                "load_from_ident() before forward()."
            )
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def _resolve_key(self, model_id, mapping, map_name):
        """Return the key of ``mapping`` that ``model_id`` refers to.

        A falsy ``model_id`` falls back to ``self.config.id``. The key is
        tried exactly as given first, so keys containing uppercase
        characters stay reachable, and then in lowercased form, which is
        the only form the lookup used to try.

        Raises:
            ValueError: If no id is available or it is not in ``mapping``.
        """
        key = model_id if model_id else self.config.id
        if not key:
            raise ValueError(
                "No model id was given and config.id is not set."
            )
        if key in mapping:
            return key
        lowered = key.lower()
        if lowered in mapping:
            return lowered
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``.
        raise ValueError(
            f"ID given does not reference a valid METL model in the {map_name}"
        )

    def load_from_uuid(self, id):
        """Load the model and encoder selected by a key of ``UUID_URL_MAP``.

        Args:
            id: Key to look up. When falsy, ``self.config.id`` is used.

        Raises:
            ValueError: If no id is available or it is not a key of
                ``self.config.UUID_URL_MAP``.
        """
        uuid_url_map = self.config.UUID_URL_MAP
        self.config.id = self._resolve_key(id, uuid_url_map, "UUID_URL_MAP")
        # NOTE(review): the mapped value (UUID_URL_MAP[id]), not the id
        # itself, is handed to get_from_uuid -- confirm this is the argument
        # metl.main expects.
        self.model, self.encoder = get_from_uuid(uuid_url_map[self.config.id])

    def load_from_ident(self, id):
        """Load the model and encoder selected by a key of ``IDENT_UUID_MAP``.

        Args:
            id: Key to look up. When falsy, ``self.config.id`` is used.

        Raises:
            ValueError: If no id is available or it is not a key of
                ``self.config.IDENT_UUID_MAP``.
        """
        ident_uuid_map = self.config.IDENT_UUID_MAP
        self.config.id = self._resolve_key(id, ident_uuid_map, "IDENT_UUID_MAP")
        # NOTE(review): the mapped value (IDENT_UUID_MAP[id]), not the
        # identifier itself, is handed to get_from_ident -- confirm this is
        # the argument metl.main expects.
        self.model, self.encoder = get_from_ident(ident_uuid_map[self.config.id])

    def load_from_indent(self, id):
        """Misspelt original name of :meth:`load_from_ident`.

        Kept so existing callers continue to work.
        """
        return self.load_from_ident(id)
|