import config
from transformers import AutoModelForMaskedLM, AutoTokenizer

class MembraneTokenizer:
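    """Thin wrapper around the ESM-2 tokenizer that delegates unknown
    attributes and calls to the underlying Hugging Face tokenizer."""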
    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    
    def __getattr__(self, name):
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)
    
    def save_tokenizer(self, save_dir):
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)

class MembraneMLM:
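    """Wraps an ESM-2 masked-language model and its tokenizer, with helpers
    to freeze the backbone, unfreeze the attention projections of the last
    few encoder layers, and save/load both artifacts together."""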
    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    
    def __getattr__(self, name):
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def freeze_model(self):
        # Disable parameter updates for all layers
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_n_layers(self):
        # Count the encoder layers in the ESM backbone
        model_layers = len(self.model.esm.encoder.layer)

        # Re-enable parameter updates for the query/key/value projections
        # of the last config.ESM_LAYERS encoder layers
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                attention = layer.attention.self
                for module in (attention.query, attention.key, attention.value):
                    for param in module.parameters():
                        param.requires_grad = True
    
    def forward(self, **inputs):
        return self.model(**inputs)

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        # Reload with the masked-LM head so a saved model round-trips intact
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
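

# Minimal usage sketch (not part of the original module). It assumes
# config.ESM_MODEL_PATH points to a valid ESM-2 checkpoint such as
# "facebook/esm2_t6_8M_UR50D" and config.ESM_LAYERS is a small integer.
if __name__ == "__main__":
    import torch

    tokenizer = MembraneTokenizer()
    mlm = MembraneMLM()

    # Freeze the full backbone, then unfreeze the Q/K/V projections of the
    # last config.ESM_LAYERS encoder layers for lightweight fine-tuning.
    mlm.freeze_model()
    mlm.unfreeze_n_layers()

    # Tokenize a toy protein sequence and run a forward pass.
    inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")
    with torch.no_grad():
        outputs = mlm(**inputs)
    print(outputs.logits.shape)  # (batch, sequence_length, vocab_size)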