ClinicalMosaic / automodel.py
Sifal's picture
rm deprecated assertion
fd870aa verified
raw
history blame
8.5 kB
import logging
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from transformers import BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .bert_layers_mosa import BertModel
# Module-level logger; the from_pretrained loaders in this module report
# checkpoint key mismatches through it.
logger = logging.getLogger(__name__)
class ClinicalMosaicForEmbeddingGeneration(BertPreTrainedModel):
    """Clinical Mosaic encoder without a pooling layer, used to produce embeddings.

    ``forward`` calls the wrapped ``BertModel``'s ``embeddings`` and
    ``encoder`` sub-modules directly and returns the encoder output.
    """

    # Hub options a caller (e.g. ``AutoModel.from_pretrained``) may pass via
    # ``**kwargs``; ``from_pretrained`` forwards them to the Hub downloads.
    _HUB_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "force_download",
        "local_files_only",
    )

    def __init__(self, config, **kwargs):
        """Build the encoder.

        Args:
            config (BertConfig): The configuration for the BERT model.
            **kwargs: Accepted so ``from_pretrained`` can pass its extra
                keyword arguments through; ignored here.
        """
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        # this resets the weights
        self.post_init()

    @classmethod
    def from_pretrained(
        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
    ):
        """Build the model and load encoder weights from a Hub checkpoint.

        Args:
            pretrained_checkpoint (str): Hugging Face Hub repo id containing
                ``model.safetensors`` (and ``config.json`` when ``config`` is None).
            state_dict (dict, optional): Pre-loaded weights. When given, they
                are used instead of downloading ``model.safetensors``.
            config (BertConfig, optional): Model configuration. When None it is
                loaded from ``pretrained_checkpoint`` via ``cls.config_class``.
            *inputs: Forwarded to the constructor.
            **kwargs: Forwarded to the constructor. ``revision``, ``cache_dir``,
                ``token``, ``force_download`` and ``local_files_only`` are also
                forwarded to the Hub download calls.

        Returns:
            The loaded model, switched to evaluation mode.
        """
        hub_kwargs = {
            name: kwargs[name]
            for name in cls._HUB_KWARGS
            if kwargs.get(name) is not None
        }
        if config is None:
            # The constructor cannot work without a config; read it from the
            # same repo as the weights.
            config = cls.config_class.from_pretrained(
                pretrained_checkpoint, **hub_kwargs
            )

        # this gets a fresh init model
        model = cls(config, *inputs, **kwargs)

        if state_dict is None:
            # Download (or reuse the cached copy of) the safetensors weights.
            archive_file = hf_hub_download(
                repo_id=pretrained_checkpoint,
                filename="model.safetensors",
                **hub_kwargs,
            )
            state_dict = load_file(archive_file)

        # The checkpoint keys lack the ``bert.`` prefix this wrapper uses; add
        # it, but leave already-prefixed keys alone so a full-model state dict
        # is not turned into ``bert.bert.*``.
        state_dict = {
            key if key.startswith("bert.") else f"bert.{key}": value
            for key, value in state_dict.items()
        }
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            logger.warning(
                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
            )
            logger.warning(f"the number of which is equal to {len(missing_keys)}")
        if unexpected_keys:
            logger.warning(
                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
            )
            logger.warning(f"the number of which is equal to {len(unexpected_keys)}")

        # transformers' from_pretrained hands back models in eval mode; without
        # this any dropout layers stay active and embeddings vary run to run.
        model.eval()
        return model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        subset_mask: Optional[torch.Tensor] = None,
        output_all_encoded_layers: bool = True,
    ) -> torch.Tensor:
        """Embed the inputs and run them through the encoder.

        Args:
            input_ids: Token ids.
            attention_mask: Mask of tokens to attend to. When None, an all-ones
                mask shaped like ``input_ids`` is used (attend to every token).
            token_type_ids: Segment ids, passed to the embeddings module.
            position_ids: Position ids, passed to the embeddings module.
            subset_mask: Passed through to the encoder unchanged.
            output_all_encoded_layers: Passed through to the encoder.

        Returns:
            Whatever ``self.bert.encoder`` returns. NOTE(review): presumably a
            list of per-layer hidden states when ``output_all_encoded_layers``
            is True; confirm against ``bert_layers_mosa``.
        """
        if attention_mask is None and input_ids is not None:
            # The mask goes straight to the encoder (not through BertModel's
            # forward), so supply the "attend to everything" default here.
            attention_mask = torch.ones_like(input_ids)

        embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
        encoder_outputs_all = self.bert.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask,
        )
        return encoder_outputs_all
class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
    """Bert Model transformer with a sequence classification/regression head.
    This head is just a linear layer on top of the pooled output.
    """

    # Hub options a caller (e.g. ``AutoModel.from_pretrained``) may pass via
    # ``**kwargs``; ``from_pretrained`` forwards them to the Hub downloads.
    _HUB_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "force_download",
        "local_files_only",
    )

    def __init__(self, config, **kwargs):
        """Build the encoder (with pooler), dropout and linear classifier.

        Args:
            config (BertConfig): Model configuration; ``num_labels``,
                ``hidden_size`` and ``classifier_dropout`` (falling back to
                ``hidden_dropout_prob``) size the head.
            **kwargs: Accepted so ``from_pretrained`` can pass its extra
                keyword arguments through; ignored here.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=True)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # this resets the weights
        self.post_init()

    @classmethod
    def from_pretrained(
        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
    ):
        """Build the model and load encoder weights from a Hub checkpoint.

        The pooler and classifier are left randomly initialized when the
        checkpoint does not contain them.

        Args:
            pretrained_checkpoint (str): Hugging Face Hub repo id containing
                ``model.safetensors`` (and ``config.json`` when ``config`` is None).
            state_dict (dict, optional): Pre-loaded weights. When given, they
                are used instead of downloading ``model.safetensors``.
            config (BertConfig, optional): Model configuration. When None it is
                loaded from ``pretrained_checkpoint`` via ``cls.config_class``.
            *inputs: Forwarded to the constructor.
            **kwargs: Forwarded to the constructor. ``revision``, ``cache_dir``,
                ``token``, ``force_download`` and ``local_files_only`` are also
                forwarded to the Hub download calls.

        Returns:
            The loaded model, switched to evaluation mode (call ``.train()``
            before fine-tuning with a manual loop).
        """
        hub_kwargs = {
            name: kwargs[name]
            for name in cls._HUB_KWARGS
            if kwargs.get(name) is not None
        }
        if config is None:
            # The constructor cannot work without a config; read it from the
            # same repo as the weights.
            config = cls.config_class.from_pretrained(
                pretrained_checkpoint, **hub_kwargs
            )

        # this gets a fresh init model
        model = cls(config, *inputs, **kwargs)

        if state_dict is None:
            # Download (or reuse the cached copy of) the safetensors weights.
            archive_file = hf_hub_download(
                repo_id=pretrained_checkpoint,
                filename="model.safetensors",
                **hub_kwargs,
            )
            state_dict = load_file(archive_file)

        # The checkpoint keys lack the ``bert.`` prefix this wrapper uses; add
        # it, but leave keys that already belong to this model's top-level
        # modules alone so a full-model state dict loads as-is.
        state_dict = {
            key if key.startswith(("bert.", "classifier.")) else f"bert.{key}": value
            for key, value in state_dict.items()
        }
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        # Task-head parameters a pre-training checkpoint may legitimately lack.
        head_keys = {
            "classifier.weight",
            "classifier.bias",
            "bert.pooler.dense.weight",
            "bert.pooler.dense.bias",
        }
        missing = set(missing_keys)
        if missing and missing <= head_keys:
            num_classifier_params = (
                config.hidden_size * config.num_labels + config.num_labels
            )
            logger.warning(
                f"Checkpoint does not contain these head parameters: "
                f"{', '.join(sorted(missing))} (the classification layer alone is "
                f"{config.hidden_size}x{config.num_labels} + {config.num_labels} = "
                f"{num_classifier_params} params). "
                "They will be randomly initialized."
            )
        elif missing_keys:
            logger.warning(
                f"Checkpoint is missing {len(missing_keys)} parameters, including possibly critical ones: "
                f"{', '.join(missing_keys)}"
            )
        if unexpected_keys:
            logger.warning(
                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
            )
            logger.warning(f"the number of which is equal to {len(unexpected_keys)}")

        # Match transformers' from_pretrained, which returns models in eval mode.
        model.eval()
        return model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify (or regress on) the pooled output of the encoder.

        When ``labels`` is given and ``config.problem_type`` is unset, the
        problem type is inferred and stored on ``self.config``:

        - ``num_labels == 1`` gives ``"regression"`` (MSE loss);
        - integer labels with ``num_labels > 1`` give
          ``"single_label_classification"`` (cross-entropy loss);
        - anything else gives ``"multi_label_classification"``
          (BCE-with-logits loss).

        Returns:
            If ``return_dict`` (defaulting to ``config.use_return_dict``) is
            falsy, a tuple ``(loss, logits, *outputs[2:])`` with ``loss``
            omitted when no labels were given. Otherwise a
            ``SequenceClassifierOutput``. NOTE: ``hidden_states`` and
            ``attentions`` are always None there, regardless of
            ``output_hidden_states`` / ``output_attentions``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 1 of the encoder output holds the pooled representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )