clarify classifier warning
Browse files- automodel.py +14 -4
automodel.py
CHANGED
@@ -137,13 +137,23 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
|
|
137 |
|
138 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
139 |
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
logger.warning(
|
142 |
-
f"
|
|
|
143 |
)
|
144 |
|
145 |
-
logger.warning(f"the number of which is equal to {len(missing_keys)}")
|
146 |
-
|
147 |
if len(unexpected_keys) > 0:
|
148 |
logger.warning(
|
149 |
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|
|
|
137 |
|
138 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
139 |
|
140 |
+
# Calculate classifier parameters
|
141 |
+
num_classifier_params = config.hidden_size * config.num_labels + config.num_labels
|
142 |
+
classifier_keys = {"classifier.weight", "classifier.bias"}
|
143 |
+
|
144 |
+
# Check if only the classification layer is missing
|
145 |
+
if set(missing_keys) == classifier_keys:
|
146 |
+
print(
|
147 |
+
f"Checkpoint does not contain the classification layer "
|
148 |
+
f"({config.hidden_size}x{config.num_labels} + {config.num_labels} = {num_classifier_params} params). "
|
149 |
+
"It will be randomly initialized."
|
150 |
+
)
|
151 |
+
elif len(missing_keys) > 0:
|
152 |
logger.warning(
|
153 |
+
f"Checkpoint is missing {len(missing_keys)} parameters, including possibly critical ones: "
|
154 |
+
f"{', '.join(missing_keys)}"
|
155 |
)
|
156 |
|
|
|
|
|
157 |
if len(unexpected_keys) > 0:
|
158 |
logger.warning(
|
159 |
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|