Sifal commited on
Commit
4932186
·
verified ·
1 Parent(s): 46cd48a

clarify classifier warning

Browse files
Files changed (1) hide show
  1. automodel.py +14 -4
automodel.py CHANGED
@@ -137,13 +137,23 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
137
 
138
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
139
 
140
- if len(missing_keys) > 0:
 
 
 
 
 
 
 
 
 
 
 
141
  logger.warning(
142
- f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
 
143
  )
144
 
145
- logger.warning(f"the number of which is equal to {len(missing_keys)}")
146
-
147
  if len(unexpected_keys) > 0:
148
  logger.warning(
149
  f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
 
137
 
138
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
139
 
140
+ # Calculate classifier parameters
141
+ num_classifier_params = config.hidden_size * config.num_labels + config.num_labels
142
+ classifier_keys = {"classifier.weight", "classifier.bias"}
143
+
144
+ # Check if only the classification layer is missing
145
+ if set(missing_keys) == classifier_keys:
146
+ print(
147
+ f"Checkpoint does not contain the classification layer "
148
+ f"({config.hidden_size}x{config.num_labels} + {config.num_labels} = {num_classifier_params} params). "
149
+ "It will be randomly initialized."
150
+ )
151
+ elif len(missing_keys) > 0:
152
  logger.warning(
153
+ f"Checkpoint is missing {len(missing_keys)} parameters, including possibly critical ones: "
154
+ f"{', '.join(missing_keys)}"
155
  )
156
 
 
 
157
  if len(unexpected_keys) > 0:
158
  logger.warning(
159
  f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",