from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os


class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize(), .encode(), and .save_pretrained() to be forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
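
# A minimal usage sketch for FusOnTokenizer (illustration only; the sequence below is an
# arbitrary placeholder, not part of any dataset). The wrapper behaves like the underlying
# ESM-2 tokenizer, so its outputs can be fed straight into FusOnpLM:
#
#   tokenizer = FusOnTokenizer()
#   batch = tokenizer(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], return_tensors="pt", padding=True)
#   # batch["input_ids"] and batch["attention_mask"] are ready for a model forward pass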


class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path=None, mlm_head=False):
        if ckpt_path is not None:
            # Load a fine-tuned checkpoint
            self.load_model(ckpt_path, mlm_head)
        else:
            # Load the pre-trained model and tokenizer
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)
        self.n_layers = self.count_encoder_layers()
    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
        """
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's __call__ method.
        """
        return self.model(*args, **kwargs)
    def freeze_model(self):
        """
        Freezes all parameters in the model.
        """
        for param in self.model.parameters():
            param.requires_grad = False
    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if (self.n_layers - i) <= n_unfrozen_layers:  # Only the last n layers
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)
    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.

        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True
    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
        """
        # Checkpoints loaded with AutoModelForMaskedLM expose the encoder under .esm;
        # checkpoints loaded with AutoModel (no MLM head) expose it directly.
        encoder = self.model.esm.encoder if hasattr(self.model, "esm") else self.model.encoder
        return len(encoder.layer)
    def save_model(self, save_directory, optimizer=None):
        # Save the model and tokenizer
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

        # If an optimizer is provided, save its state dict
        if optimizer is not None:
            optimizer_path = os.path.join(save_directory, "optimizer.pt")
            torch.save(optimizer.state_dict(), optimizer_path)
    def load_model(self, load_directory, mlm_head):
        # Load a checkpoint of the model either with or without an MLM head
        if mlm_head:
            self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
        else:
            self.model = AutoModel.from_pretrained(load_directory)
        # Load the tokenizer from the same directory
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
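

# Minimal usage sketch (illustration only; the real training loop lives elsewhere in the repo).
# The sequence and the number of unfrozen layers below are arbitrary examples, not the settings
# used to train FusOn-pLM.
if __name__ == "__main__":
    model = FusOnpLM()                 # base ESM-2-650M with its MLM head
    model.freeze_model()               # freeze every parameter ...
    model.unfreeze_last_n_layers(2)    # ... then unfreeze Q/K/V in the last 2 encoder layers

    inputs = model.tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    outputs = model(**inputs)          # forwarded to the underlying masked-LM model
    print(outputs.logits.shape)        # (batch, sequence length, vocabulary size)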