# METL / huggingface_wrapper.py
# Uploaded by jgpeters ("Upload model", commit 3fd5dd8, verified); 1.46 kB.
from transformers import PretrainedConfig, PreTrainedModel
import metl.models as models
from metl.main import *
class METLConfig(PretrainedConfig):
    """Hugging Face configuration for the METL model wrapper.

    Carries the ``IDENT_UUID_MAP`` and ``UUID_URL_MAP`` lookup tables
    (brought in from ``metl.main``) as class attributes, and stores the
    currently selected model identifier on the instance as ``id``.
    """

    model_type = "METL"

    # Lookup tables from metl.main, exposed on the config class so code
    # holding a config instance can reach them as ``config.<MAP>``.
    IDENT_UUID_MAP = IDENT_UUID_MAP
    UUID_URL_MAP = UUID_URL_MAP

    def __init__(self, id: str = None, **kwargs):
        """Build the configuration.

        Args:
            id: Key of the selected METL model in the lookup tables, or
                ``None`` when no model has been chosen yet.
            **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
        """
        # Set before the base initialiser runs (same order as before).
        self.id = id
        super().__init__(**kwargs)
class InvalidMETLIdError(ValueError, AssertionError):
    """Raised when a given identifier is not a key of the relevant METL map.

    Subclasses ``AssertionError`` as well as ``ValueError`` so that callers
    written against the earlier ``assert``-based checks keep working, while
    the check itself is no longer removed when Python runs with ``-O``.
    """


class METLModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a METL model.

    ``model`` and ``encoder`` start out as ``None``; call ``load_from_uuid``
    or ``load_from_ident`` to populate them before calling ``forward``.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        super().__init__(config)
        # Filled in by load_from_uuid / load_from_ident.
        self.model = None
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the wrapped METL model on ``X``.

        Args:
            X: Input passed straight through to the wrapped model.
            pdb_fn: Optional value forwarded as the ``pdb_fn`` keyword
                argument; it is only passed when truthy.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No METL model loaded; call load_from_uuid() or "
                "load_from_ident() before forward()"
            )
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def _resolve_from_map(self, model_id, mapping, map_name):
        """Validate ``model_id``, record it on the config, and return its map entry.

        Args:
            model_id: Identifier to look up. When falsy, the identifier already
                stored in ``self.config.id`` is used instead.
            mapping: Lookup table (from ``self.config``) to resolve against.
            map_name: Name of ``mapping``, used in error messages.

        Returns:
            ``mapping[self.config.id]`` after ``self.config.id`` is updated.

        Raises:
            InvalidMETLIdError: ``model_id`` is given but is not a key of ``mapping``.
            KeyError: ``model_id`` is falsy and ``self.config.id`` is not a key
                of ``mapping`` (e.g. it was never set).
        """
        if model_id:
            # Matching uses the lower-cased ID (the maps are presumably keyed
            # in lower case).
            model_id = model_id.lower()
            if model_id not in mapping:
                raise InvalidMETLIdError(
                    f"ID given does not reference a valid METL model in the {map_name}"
                )
            self.config.id = model_id
        try:
            return mapping[self.config.id]
        except KeyError:
            raise KeyError(
                f"No ID given and config.id={self.config.id!r} does not "
                f"reference a valid METL model in the {map_name}"
            ) from None

    def load_from_uuid(self, id):
        """Load the model and encoder identified by a key of ``UUID_URL_MAP``.

        Args:
            id: Key into ``config.UUID_URL_MAP``. If falsy, ``config.id`` is
                used instead.

        Raises:
            InvalidMETLIdError: ``id`` is not a key of ``UUID_URL_MAP``.
            KeyError: no ``id`` is given and ``config.id`` is not a valid key.
        """
        model_url = self._resolve_from_map(id, self.config.UUID_URL_MAP, "UUID_URL_MAP")
        # The mapped value (presumably a download URL) is what get_from_uuid receives.
        self.model, self.encoder = get_from_uuid(model_url)

    def load_from_ident(self, id):
        """Load the model and encoder identified by a key of ``IDENT_UUID_MAP``.

        Args:
            id: Key into ``config.IDENT_UUID_MAP``. If falsy, ``config.id`` is
                used instead.

        Raises:
            InvalidMETLIdError: ``id`` is not a key of ``IDENT_UUID_MAP``.
            KeyError: no ``id`` is given and ``config.id`` is not a valid key.
        """
        model_uuid = self._resolve_from_map(id, self.config.IDENT_UUID_MAP, "IDENT_UUID_MAP")
        # The mapped value (presumably a UUID) is what get_from_ident receives.
        self.model, self.encoder = get_from_ident(model_uuid)

    # Original, misspelled name kept as an alias for backward compatibility.
    load_from_indent = load_from_ident