# ClinicalMosaic / bert_embeddings.py
import logging
from typing import Optional

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from transformers import BertPreTrainedModel

from bert_layers_mosa import BertModel

logger = logging.getLogger(__name__)
class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):

    def __init__(self, config, add_pooling_layer=False):
        """Initialize MosaicBertForEmbeddingGeneration.

        Args:
            config (BertConfig): The configuration for the BERT model.
            add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
        """
        super().__init__(config)
        assert (
            config.num_hidden_layers >= config.num_embedding_layers
        ), "num_hidden_layers should be greater than or equal to num_embedding_layers"
        self.config = config
        self.strategy = config.strategy
        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        # post_init() (re-)initializes the weights; they are overwritten when a pre-trained checkpoint is loaded.
        self.post_init()

    @classmethod
    def from_pretrained(
        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
    ):
        """Load from a pre-trained checkpoint."""
        # Build a freshly initialized model ...
        model = cls(config, *inputs, **kwargs)
        # ... then load the pre-trained weights into it, unless a state_dict was passed explicitly.
        if state_dict is None:
            state_dict = torch.load(pretrained_checkpoint, map_location="cpu")
        # Remove the `model.` prefix (left by the training wrapper) to avoid key mismatches.
        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(missing_keys) > 0:
            logger.warning(
                f"Found {len(missing_keys)} missing key(s) in the checkpoint: {', '.join(missing_keys)}"
            )
        if len(unexpected_keys) > 0:
            logger.warning(
                f"Found {len(unexpected_keys)} unexpected key(s) in the checkpoint: {', '.join(unexpected_keys)}"
            )
        return model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        subset_mask: Optional[torch.Tensor] = None,
        output_all_encoded_layers: bool = True,
    ) -> torch.Tensor:
        # Embed the input tokens, then run them through the encoder.
        embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
        encoder_outputs_all = self.bert.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask,
        )
        # Hidden states from the encoder (all layers when output_all_encoded_layers=True).
        return encoder_outputs_all
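# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of how this class can be used: build a config, load a
# checkpoint with `from_pretrained`, and call the model to get token-level hidden
# states. The config source, checkpoint path, and the fallback values filled in for
# `num_embedding_layers` / `strategy` are assumptions for illustration only;
# substitute the actual ClinicalMosaic artifacts.
if __name__ == "__main__":
    from transformers import BertConfig

    # Hypothetical config source; replace with the real ClinicalMosaic config.
    config = BertConfig.from_pretrained("bert-base-uncased")
    # These custom fields are required by __init__; the real config defines them.
    config.num_embedding_layers = getattr(config, "num_embedding_layers", config.num_hidden_layers)
    config.strategy = getattr(config, "strategy", "last")  # placeholder value

    # Hypothetical checkpoint path; replace with the real weights file.
    model = MosaicBertForEmbeddingGeneration.from_pretrained(
        "path/to/checkpoint.pt", config=config
    )
    model.eval()

    # Dummy batch: 2 sequences of 16 token ids, with every position attended to.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        hidden_states = model(input_ids=input_ids, attention_mask=attention_mask)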