import logging
from typing import Optional

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from transformers import BertPreTrainedModel

from bert_layers_mosa import BertModel

logger = logging.getLogger(__name__)


class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):

    def __init__(self, config, add_pooling_layer=False):
        """
        Initializes MosaicBertForEmbeddingGeneration.

        Args:
            config (BertConfig): The configuration for the BERT model.
            add_pooling_layer (bool, optional): Whether to add a pooling layer.
                Defaults to False.
        """
        super().__init__(config)
        assert (
            config.num_hidden_layers >= config.num_embedding_layers
        ), "num_hidden_layers should be greater than or equal to num_embedding_layers"
        self.config = config
        self.strategy = config.strategy
        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
        # post_init() (re)initializes the weights according to the config
        self.post_init()

    @classmethod
    def from_pretrained(
        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
    ):
        """Load from a pre-trained checkpoint."""
        # Build a freshly initialized model from the config ...
        model = cls(config, *inputs, **kwargs)
        # ... then load the checkpoint weights into it.
        state_dict = torch.load(pretrained_checkpoint)
        # Remove the `model.` prefix so keys match this module's parameter names.
        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if len(missing_keys) > 0:
            logger.warning(
                f"Found {len(missing_keys)} missing keys in the checkpoint: "
                f"{', '.join(missing_keys)}"
            )
        if len(unexpected_keys) > 0:
            logger.warning(
                f"Found {len(unexpected_keys)} unexpected keys in the checkpoint: "
                f"{', '.join(unexpected_keys)}"
            )
        return model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        subset_mask: Optional[torch.Tensor] = None,
        output_all_encoded_layers: bool = True,
    ) -> torch.Tensor:
        """Runs the embedding and encoder layers and returns the encoder hidden states."""
        embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
        encoder_outputs_all = self.bert.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask,
        )
        # Encoder hidden states (all encoded layers when output_all_encoded_layers=True)
        return encoder_outputs_all
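

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the config/tokenizer path
    # and checkpoint filename below are placeholders, and the config is assumed to
    # define the extra `strategy` and `num_embedding_layers` fields that __init__ reads.
    from transformers import AutoTokenizer, BertConfig

    config = BertConfig.from_pretrained("path/to/mosaic-bert-config")  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained("path/to/mosaic-bert-config")  # placeholder path

    model = MosaicBertForEmbeddingGeneration.from_pretrained(
        "path/to/checkpoint.pt",  # placeholder checkpoint
        config=config,
    )
    model.eval()

    batch = tokenizer(["An example sentence."], return_tensors="pt", padding=True)
    with torch.no_grad():
        hidden_states = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    # `hidden_states` holds the encoder outputs; any pooling into a single sentence
    # embedding (e.g. mean over tokens) is left to the caller.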