# MeMDLM/benchmarks/MLM/model.py
import config
from transformers import AutoModelForMaskedLM, AutoTokenizer

class MembraneTokenizer:
    """Thin wrapper around the ESM-2 tokenizer that proxies attribute access."""

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped tokenizer
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_dir):
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        # Reload the tokenizer from the given directory
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
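
# A minimal usage sketch for the wrapper above (hypothetical sequence; assumes
# config.ESM_MODEL_PATH names a valid ESM-2 checkpoint such as
# "facebook/esm2_t33_650M_UR50D"):
#
#   tokenizer = MembraneTokenizer()
#   batch = tokenizer(["MKTLLLTLVVV"], return_tensors="pt", padding=True)
#   # batch["input_ids"] has shape (1, len(sequence) + 2) with BOS/EOS added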

class MembraneMLM:
    """Wrapper around an ESM-2 masked language model with helpers for
    freezing and selectively unfreezing encoder layers."""

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped model
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def freeze_model(self):
        # Disable parameter updates for all layers
        for param in self.model.parameters():
            param.requires_grad = False
    def unfreeze_n_layers(self):
        # Count the encoder layers so we can index from the end
        model_layers = len(self.model.esm.encoder.layer)
        # Enable parameter updates for the query/key/value projections of the
        # last config.ESM_LAYERS encoder layers
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                attention = layer.attention.self
                for module in (attention.query, attention.key, attention.value):
                    for param in module.parameters():
                        param.requires_grad = True
    def forward(self, **inputs):
        return self.model(**inputs)

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
    def load_model(self, load_dir):
        # Reload with the MLM head intact; plain AutoModel would drop the LM head
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
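
# A minimal smoke test, run only when this file is executed directly. The
# printed counts are illustrative; config.ESM_MODEL_PATH and config.ESM_LAYERS
# come from the local config module and are assumed to be set (e.g. an ESM-2
# checkpoint name and a small integer such as 3).
if __name__ == "__main__":
    mlm = MembraneMLM()
    mlm.freeze_model()
    mlm.unfreeze_n_layers()

    # After freeze + selective unfreeze, only the Q/K/V projections of the
    # last config.ESM_LAYERS encoder layers should remain trainable.
    trainable = sum(p.numel() for p in mlm.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in mlm.model.parameters())
    print(f"trainable parameters: {trainable:,} / {total:,}")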