"""Thin wrappers around a pretrained ESM-2 tokenizer and masked language model."""

import config
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
class MembraneTokenizer:
    """Wraps the ESM-2 tokenizer and delegates unknown attributes to it."""

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not defined here to the wrapped tokenizer.
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_dir):
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        # Restore a tokenizer previously written with save_tokenizer().
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)

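# Usage sketch for MembraneTokenizer (assumptions: config.ESM_MODEL_PATH points
# at an ESM-2 checkpoint such as "facebook/esm2_t6_8M_UR50D"; the sequence and
# the save path below are arbitrary examples, not project fixtures):
#
#   tok = MembraneTokenizer()
#   batch = tok("MKTLLLTLVVVTIVCLDLGYT", return_tensors="pt")
#   tok.save_tokenizer("checkpoints/tokenizer")
#   tok.load_tokenizer("checkpoints/tokenizer")
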
class MembraneMLM:
    """Wraps an ESM-2 masked language model and its matching tokenizer."""

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not defined here to the wrapped model.
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    def freeze_model(self):
        # Freeze every parameter; selected layers can be re-enabled afterwards.
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_n_layers(self):
        # Re-enable gradients for the query/key/value attention projections of
        # the top config.ESM_LAYERS encoder layers only.
        model_layers = len(self.model.esm.encoder.layer)
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                attn = layer.attention.self
                for proj in (attn.query, attn.key, attn.value):
                    for param in proj.parameters():
                        param.requires_grad = True
    def forward(self, **inputs):
        return self.model(**inputs)

    def save_model(self, save_dir):
        # Save model and tokenizer together so they can be reloaded as a pair.
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
    def load_model(self, load_dir):
        # Reload with the masked-LM head attached; plain AutoModel would drop it.
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
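
if __name__ == "__main__":
    # Smoke-test sketch. Assumes config.ESM_MODEL_PATH names an ESM-2 checkpoint
    # (e.g. "facebook/esm2_t6_8M_UR50D") and config.ESM_LAYERS is a small int
    # such as 2; both are assumptions about the project's config module, and the
    # peptide below is an arbitrary example sequence.
    mlm = MembraneMLM()
    mlm.freeze_model()
    mlm.unfreeze_n_layers()

    # Confirm only the top attention projections remain trainable.
    trainable = sum(p.numel() for p in mlm.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in mlm.model.parameters())
    print(f"trainable parameters: {trainable:,} / {total:,}")

    tok = MembraneTokenizer()
    inputs = tok("MKTLLLTLVVVTIVCLDLGYT", return_tensors="pt")
    with torch.no_grad():
        outputs = mlm(**inputs)
    # Logits over the amino-acid vocabulary, one row per token (incl. specials).
    print(outputs.logits.shape)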