"""Hugging Face ``transformers`` wrapper classes for METL models."""

from typing import Optional

from transformers import PretrainedConfig, PreTrainedModel

import metl.models as models
from metl.main import *


class METLConfig(PretrainedConfig):
    """Configuration object for :class:`METLModel`.

    Attributes:
        id: Identifier of the selected METL checkpoint, or ``None`` when no
            checkpoint has been selected yet.
    """

    # Lookup tables brought into scope by the ``metl.main`` star import and
    # re-exposed on the config so the model can reach them via ``self.config``.
    IDENT_UUID_MAP = IDENT_UUID_MAP
    UUID_URL_MAP = UUID_URL_MAP
    model_type = "METL"

    def __init__(
        self,
        id: Optional[str] = None,
        **kwargs,
    ):
        """Store the checkpoint identifier and forward the rest to the base class.

        Args:
            id: Optional checkpoint identifier to remember on the config.
            **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
        """
        # ``id`` is assigned before the base initializer runs, as in the
        # original code.
        self.id = id
        super().__init__(**kwargs)


class METLModel(PreTrainedModel):
    """``PreTrainedModel`` shell that delegates to a METL model loaded on demand.

    ``model`` and ``encoder`` start out as ``None``; one of the ``load_from_*``
    methods must be called to populate them before :meth:`forward` is used.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        """Create an empty wrapper; no METL weights are loaded here.

        Args:
            config: Configuration instance kept on ``self.config``.
        """
        super().__init__(config)
        self.model = None
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the wrapped METL model on ``X``.

        Args:
            X: Input passed straight through to the wrapped model.
            pdb_fn: Optional value forwarded as the ``pdb_fn`` keyword. It is
                only forwarded when truthy.

        Returns:
            Whatever the wrapped model returns.
        """
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def _select_id(self, id, table, table_name):
        """Validate ``id`` against ``table`` and record it on the config.

        A falsy ``id`` is ignored so that the id already stored on the config
        is used by the caller instead.

        Args:
            id: Candidate identifier, or a falsy value to keep the current one.
            table: Mapping whose keys are the accepted identifiers.
            table_name: Name of ``table`` used in the error message.

        Raises:
            AssertionError: If ``id`` is truthy but not a key of ``table``.
        """
        if not id:
            return
        id = id.lower()
        if id not in table:
            # Raised explicitly (instead of via ``assert``) so the check is not
            # stripped when Python runs with -O; the exception type callers
            # see is unchanged.
            raise AssertionError(
                f"ID given does not reference a valid METL model in the {table_name}"
            )
        self.config.id = id

    def load_from_uuid(self, id):
        """Load ``self.model`` and ``self.encoder`` for an id in ``UUID_URL_MAP``.

        Args:
            id: Identifier to select (lower-cased before lookup). When falsy,
                the id already stored on ``self.config`` is used.

        Raises:
            AssertionError: If ``id`` is given but is not in ``UUID_URL_MAP``.
            KeyError: If no ``id`` is given and ``self.config.id`` is not a key
                of ``UUID_URL_MAP``.
        """
        self._select_id(id, self.config.UUID_URL_MAP, "UUID_URL_MAP")
        # NOTE(review): the mapped value (not the id itself) is handed to
        # get_from_uuid -- confirm this matches what the helper expects.
        self.model, self.encoder = get_from_uuid(
            self.config.UUID_URL_MAP[self.config.id]
        )

    def load_from_indent(self, id):
        """Load ``self.model`` and ``self.encoder`` for an id in ``IDENT_UUID_MAP``.

        The method name is kept as-is for existing callers;
        ``load_from_ident`` is an alias with the intended spelling.

        Args:
            id: Identifier to select (lower-cased before lookup). When falsy,
                the id already stored on ``self.config`` is used.

        Raises:
            AssertionError: If ``id`` is given but is not in ``IDENT_UUID_MAP``.
            KeyError: If no ``id`` is given and ``self.config.id`` is not a key
                of ``IDENT_UUID_MAP``.
        """
        self._select_id(id, self.config.IDENT_UUID_MAP, "IDENT_UUID_MAP")
        # NOTE(review): the mapped value (not the id itself) is handed to
        # get_from_ident -- confirm this matches what the helper expects.
        self.model, self.encoder = get_from_ident(
            self.config.IDENT_UUID_MAP[self.config.id]
        )

    # Correctly spelled alias; the original misspelled name stays available.
    load_from_ident = load_from_indent