import os

import torch
import yaml
from transformers import AlbertConfig, AlbertModel


class CustomAlbert(AlbertModel):
    """``AlbertModel`` whose forward pass returns only the final hidden states."""

    def forward(self, *args, **kwargs):
        """Run ``AlbertModel.forward`` and return just its last hidden state.

        Accepts exactly the same arguments as ``AlbertModel.forward``.
        """
        outputs = super().forward(*args, **kwargs)
        # With return_dict=False (passed explicitly or via config.return_dict)
        # the parent returns a plain tuple whose first element is the sequence
        # output; attribute access would raise AttributeError in that case.
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs.last_hidden_state


def load_plbert(wights_path, config_path):
    """Build a ``CustomAlbert`` from a YAML config and load weights into it.

    Args:
        wights_path: Path to a file readable by ``torch.load`` that holds the
            model's ``state_dict``. (The misspelled name is kept so callers
            passing it by keyword keep working.)
        config_path: Path to a YAML file whose ``model_params`` mapping is
            passed as keyword arguments to ``AlbertConfig``.

    Returns:
        The constructed ``CustomAlbert`` with the checkpoint weights loaded.

    Raises:
        KeyError: If the YAML config has no ``model_params`` key.
    """
    # Context manager closes the config file promptly instead of leaving the
    # handle for the garbage collector.
    with open(config_path) as config_file:
        plbert_config = yaml.safe_load(config_file)

    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources. Tensors are mapped to CPU while loading.
    state_dict = torch.load(wights_path, map_location='cpu')
    # strict=False: keys present in only one of checkpoint/model are skipped
    # silently rather than raising.
    bert.load_state_dict(state_dict, strict=False)
    return bert