import config
import torch
import torch.nn as nn
from pretrained_models import load_esm2_model
from transformers import AutoModelForMaskedLM, AutoTokenizer


class MembraneTokenizer:
    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate attribute lookups to the wrapped tokenizer
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_dir):
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        # Reload the tokenizer from disk (was mistakenly calling save_pretrained)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)


class MembraneMLM:
    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate attribute lookups to the wrapped model
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def freeze_model(self):
        # Disable parameter updates for all layers
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_n_layers(self):
        # Count the encoder layers
        model_layers = len(self.model.esm.encoder.layer)
        # Re-enable parameter updates for the query/key/value projections
        # in the last config.ESM_LAYERS encoder layers
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                for module in layer.attention.self.value.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, **inputs):
        return self.model(**inputs)

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        # Reload with the masked-LM head (AutoModel would drop it)
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
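

# --- Usage sketch (illustration only) ---
# A minimal example of how these wrappers might be driven, assuming
# config.ESM_MODEL_PATH points to an ESM-2 checkpoint (e.g. a local copy of
# "facebook/esm2_t12_35M_UR50D") and config.ESM_LAYERS is a small integer
# such as 3. The protein sequence below is a made-up placeholder, not data
# from this project.
if __name__ == "__main__":
    mlm = MembraneMLM()

    # Freeze everything, then re-enable gradients for the Q/K/V projections
    # of the last config.ESM_LAYERS encoder layers.
    mlm.freeze_model()
    mlm.unfreeze_n_layers()

    # Tokenize a placeholder sequence and run a forward pass through the
    # masked-LM head via the wrapper's __call__ delegation.
    inputs = mlm.tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    with torch.no_grad():
        outputs = mlm(**inputs)
    print(outputs.logits.shape)  # (batch, sequence_length, vocab_size)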