|
from transformers import PretrainedConfig, PreTrainedModel
|
|
import metl.models as models
|
|
from metl.main import *
|
|
|
|
class METLConfig(PretrainedConfig):
    """Hugging Face configuration object for the METL model wrapper.

    Carries the model-lookup tables as class attributes and records which
    METL model id is currently selected (``self.id``).
    """

    # Module-level lookup tables (brought into this module by the
    # ``metl.main`` star import, as far as this file shows), exposed on the
    # config class so that code holding a config can reach them.
    IDENT_UUID_MAP, UUID_URL_MAP = IDENT_UUID_MAP, UUID_URL_MAP

    model_type = "METL"

    def __init__(self, id: str = None, **kwargs):
        # Remember the selected model id (may be None), then let
        # PretrainedConfig consume every remaining keyword option.
        self.id = id
        super().__init__(**kwargs)
|
|
|
|
class METLModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a METL model and its encoder.

    A freshly constructed instance holds no weights: ``model`` and ``encoder``
    are ``None`` until :meth:`load_from_uuid` or :meth:`load_from_ident`
    populates them, so call one of those before :meth:`forward`.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        super().__init__(config)
        # Filled in later by load_from_uuid / load_from_ident.
        self.model = None
        self.encoder = None
        # NOTE(review): PreTrainedModel.__init__ likely already assigns this;
        # kept for explicitness.
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the wrapped METL model on ``X``.

        Args:
            X: Model input, passed through unchanged to ``self.model``.
                NOTE(review): presumably inputs produced by ``self.encoder``
                -- confirm.
            pdb_fn: Optional value forwarded to ``self.model`` as the
                ``pdb_fn`` keyword. When ``None`` (the default), the keyword
                is omitted from the call.

        Returns:
            Whatever ``self.model`` returns.
        """
        # Compare against the None default explicitly: truth-testing an
        # arbitrary pdb_fn value (e.g. an array) can raise or misroute.
        if pdb_fn is not None:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def _select_id(self, id, mapping, map_name):
        """Validate a user-supplied model id and make it the config's active id.

        Args:
            id: Candidate key, lower-cased before lookup. If falsy, the id
                already stored on ``self.config`` is kept unchanged.
            mapping: Lookup table that ``id`` must be a key of.
            map_name: Name of ``mapping``, used in the error message.

        Returns:
            The active id (``self.config.id``) after any update.

        Raises:
            AssertionError: if a truthy ``id`` is not a key of ``mapping``.
        """
        if id:
            id = id.lower()
            # NOTE: ``assert`` is stripped under ``python -O``; the caller's
            # subsequent ``mapping[...]`` lookup then raises KeyError instead.
            assert id in mapping, f"ID given does not reference a valid METL model in the {map_name}"
            # Persist the selection on the config object.
            self.config.id = id
        return self.config.id

    def load_from_uuid(self, id):
        """Load the METL model and encoder registered under a UUID.

        Args:
            id: Key into ``UUID_URL_MAP`` (lower-cased before lookup). If
                falsy, the id already stored in ``self.config.id`` is used.

        Raises:
            AssertionError: if a truthy ``id`` is not a key of ``UUID_URL_MAP``.
            KeyError: if ``id`` is falsy and ``self.config.id`` is not a key of
                ``UUID_URL_MAP`` (for example, still ``None``).
        """
        # NOTE(review): lower-casing assumes UUID_URL_MAP keys are lower-case
        # -- confirm, otherwise valid mixed-case UUIDs will be rejected.
        active_id = self._select_id(id, self.config.UUID_URL_MAP, "UUID_URL_MAP")
        # NOTE(review): the map *value* (presumably a URL, given the map's
        # name), not the UUID itself, is passed to get_from_uuid -- confirm
        # this matches get_from_uuid's expected argument.
        self.model, self.encoder = get_from_uuid(self.config.UUID_URL_MAP[active_id])

    def load_from_indent(self, id):
        """Load the METL model and encoder registered under a model identifier.

        The misspelled name is kept for existing callers; ``load_from_ident``
        is an alias for this method.

        Args:
            id: Key into ``IDENT_UUID_MAP`` (lower-cased before lookup). If
                falsy, the id already stored in ``self.config.id`` is used.

        Raises:
            AssertionError: if a truthy ``id`` is not a key of ``IDENT_UUID_MAP``.
            KeyError: if ``id`` is falsy and ``self.config.id`` is not a key of
                ``IDENT_UUID_MAP`` (for example, still ``None``).
        """
        active_id = self._select_id(id, self.config.IDENT_UUID_MAP, "IDENT_UUID_MAP")
        # NOTE(review): the map *value* (presumably a UUID, given the map's
        # name), not the identifier itself, is passed to get_from_ident --
        # confirm this matches get_from_ident's expected argument.
        self.model, self.encoder = get_from_ident(self.config.IDENT_UUID_MAP[active_id])

    # Correctly spelled alias for load_from_indent.
    load_from_ident = load_from_indent
|
|
|