File size: 1,462 Bytes
bb8b1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fd5dd8
bb8b1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from transformers import PretrainedConfig, PreTrainedModel
import metl.models as models
from metl.main import *

class METLConfig(PretrainedConfig):
    """Configuration object for :class:`METLModel`.

    Besides the standard ``PretrainedConfig`` fields, it carries ``id``: the
    identifier of the METL model to load.  Which lookup table ``id`` belongs
    to depends on the ``METLModel`` loader that consumes it.
    """

    model_type = "METL"

    # Lookup tables exposed on the config so the model can validate IDs.
    # Presumably provided by the ``from metl.main import *`` star import.
    UUID_URL_MAP = UUID_URL_MAP
    IDENT_UUID_MAP = IDENT_UUID_MAP

    def __init__(self, id: str = None, **kwargs):
        # Record the model identifier before the base class consumes kwargs.
        self.id = id
        super().__init__(**kwargs)

class METLModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a pretrained METL network.

    ``__init__`` only records the config.  The wrapped network (``self.model``)
    and its data encoder (``self.encoder``) stay ``None`` until
    :meth:`load_from_uuid` or :meth:`load_from_ident` is called.  Those loaders
    also store the resolved identifier on ``config.id``.
    """

    config_class = METLConfig

    class InvalidIdError(ValueError, AssertionError):
        """Raised when a model ID is not a key of the relevant lookup map.

        Also derives from ``AssertionError`` so callers written against the
        earlier ``assert``-based validation keep catching it.
        """

    def __init__(self, config: METLConfig):
        super().__init__(config)
        # Both are populated by load_from_uuid / load_from_ident.
        self.model = None
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the loaded METL network on ``X``.

        Args:
            X: Input passed straight through to the wrapped network.
                Presumably this is data produced by ``self.encoder`` (confirm).
            pdb_fn: Optional value forwarded as the ``pdb_fn`` keyword.  It is
                left out of the call when falsy.  Presumably a PDB file name
                used by structure-aware METL models (confirm).

        Returns:
            Whatever the wrapped network returns.
        """
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def _resolve_id(self, id, table, table_name):
        """Return the key of ``table`` named by ``id``, falling back to ``config.id``.

        An exact match wins.  Otherwise keys are compared case-insensitively,
        so mixed-case keys and lowercase-normalized inputs both resolve.

        Args:
            id: Requested identifier.  If falsy, ``self.config.id`` is used.
            table: Mapping whose keys are the valid identifiers.
            table_name: Name of ``table``, used in the error message.

        Raises:
            METLModel.InvalidIdError: if no key of ``table`` matches.
        """
        candidate = id if id else self.config.id
        if candidate:
            if candidate in table:
                return candidate
            folded = candidate.lower()
            for key in table:
                if key.lower() == folded:
                    return key
        # An explicit raise (unlike ``assert``) survives ``python -O``.
        raise self.InvalidIdError(
            f"ID {candidate!r} does not reference a valid METL model in the {table_name}"
        )

    def load_from_uuid(self, id):
        """Load the METL model registered under a UUID.

        Args:
            id: A key of ``config.UUID_URL_MAP``.  An exact match is tried
                first, then a case-insensitive one.  If falsy, the
                previously stored ``config.id`` is used instead.

        Raises:
            METLModel.InvalidIdError: if the ID is not a key of UUID_URL_MAP.
        """
        self.config.id = self._resolve_id(id, self.config.UUID_URL_MAP, "UUID_URL_MAP")
        # NOTE(review): get_from_uuid now receives the UUID itself, matching its
        # name.  The previous code passed the mapped URL (UUID_URL_MAP[id]).
        self.model, self.encoder = get_from_uuid(self.config.id)

    def load_from_ident(self, id):
        """Load the METL model registered under a human-readable identifier.

        Args:
            id: A key of ``config.IDENT_UUID_MAP``, matched case-insensitively.
                If falsy, the previously stored ``config.id`` is used instead.

        Raises:
            METLModel.InvalidIdError: if the ID is not a key of IDENT_UUID_MAP.
        """
        self.config.id = self._resolve_id(id, self.config.IDENT_UUID_MAP, "IDENT_UUID_MAP")
        # NOTE(review): get_from_ident now receives the identifier itself,
        # matching its name.  The previous code passed the mapped UUID
        # (IDENT_UUID_MAP[id]).
        self.model, self.encoder = get_from_ident(self.config.id)

    # Backward-compatible alias for the original, misspelled method name.
    load_from_indent = load_from_ident