from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os


class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize(), .encode(), and .decode() to be forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        """
        Save the wrapped tokenizer to a directory.
        """
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        """
        Reload the wrapped tokenizer from a directory.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
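

# Illustrative usage sketch (kept in comments so nothing runs on import): because
# __getattr__ and __call__ delegate to the wrapped AutoTokenizer, a FusOnTokenizer
# can be used like any Hugging Face tokenizer. The amino-acid sequence below is an
# arbitrary placeholder, not data from this project.
#
#     tokenizer = FusOnTokenizer()
#     batch = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt", padding=True)
#     input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]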


class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path=None, mlm_head=False):
        if ckpt_path is not None:
            self.load_model(ckpt_path, mlm_head)
        else:
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)

        self.n_layers = self.count_encoder_layers()

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
        """
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's __call__ method.
        """
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """
        Freezes all parameters in the model.
        """
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        # Checkpoints loaded without the MLM head (AutoModel) expose the encoder
        # directly rather than nesting it under .esm.
        encoder = self.model.esm.encoder if hasattr(self.model, "esm") else self.model.encoder
        for i, layer in enumerate(encoder.layer):
            if (self.n_layers - i) <= n_unfrozen_layers:
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)
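
    # Typical partial fine-tuning pattern (illustrative sketch in comments only;
    # the layer count is an arbitrary choice, not a recommendation from this project):
    #
    #     model = FusOnpLM()
    #     model.freeze_model()
    #     model.unfreeze_last_n_layers(3, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True)
    #     trainable = [p for p in model.parameters() if p.requires_grad]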

    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.

        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True

    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
        """
        encoder = self.model.esm.encoder if hasattr(self.model, "esm") else self.model.encoder
        return len(encoder.layer)

    def save_model(self, save_directory, optimizer=None):
        """
        Save the model and tokenizer (and optionally the optimizer state) to a directory.
        """
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

        if optimizer is not None:
            optimizer_path = os.path.join(save_directory, "optimizer.pt")
            torch.save(optimizer.state_dict(), optimizer_path)

    def load_model(self, load_directory, mlm_head):
        """
        Load the model (with or without the MLM head) and its tokenizer from a directory.
        """
        if mlm_head:
            self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
        else:
            self.model = AutoModel.from_pretrained(load_directory)
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
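

# Minimal end-to-end sketch, guarded so it only runs when the module is executed
# directly. It assumes network access to the Hugging Face hub for the default
# ESM-2 checkpoint; the sequence is an arbitrary placeholder and the unfrozen
# layer count is illustrative, not a recommendation from this project.
if __name__ == "__main__":
    model = FusOnpLM()   # wraps an ESM-2 masked-LM model and a FusOnTokenizer
    model.eval()         # delegated to the underlying model via __getattr__

    # Partial fine-tuning setup: freeze everything, then unfreeze the Q/K/V
    # projections of the last two encoder layers.
    model.freeze_model()
    model.unfreeze_last_n_layers(2)

    # Tokenize a sequence and run a forward pass to get per-residue embeddings
    # from the final hidden layer.
    inputs = model.tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    embeddings = outputs.hidden_states[-1]  # (batch, sequence_length, hidden_size)
    print(embeddings.shape)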