add bert prefix
automodel.py (+6, -4)
@@ -52,8 +52,9 @@ class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
         # Load the state_dict
         state_dict = load_file(archive_file)
 
-        #
-
+        # add missing bert prefix
+        state_dict = {f'bert.{key}': value for key, value in state_dict.items()}
+
         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
 
         if len(missing_keys) > 0:
@@ -131,8 +132,9 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
         # Load the state_dict
         state_dict = load_file(archive_file)
 
-        #
-
+        # add missing bert prefix
+        state_dict = {f'bert.{key}': value for key, value in state_dict.items()}
+
         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
 
         if len(missing_keys) > 0:
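For context, both hunks apply the same pattern: the checkpoint (loaded here with what appears to be safetensors' load_file) was saved from a bare BERT encoder, so its keys carry no module prefix, while the wrapper classes hold the encoder under a bert attribute and therefore expect keys like bert.embeddings.word_embeddings.weight. Prefixing the keys before load_state_dict(strict=False) lets the encoder weights land in place, leaving only the task head as missing keys. A minimal, self-contained sketch of the idea (the Wrapper module and Linear stand-ins below are illustrative, not code from this repository):

import torch.nn as nn

# Hypothetical wrapper: the encoder lives under the attribute `bert`,
# so its parameters are named `bert.weight`, `bert.bias`, etc.
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(4, 4)        # stand-in for the BERT encoder
        self.classifier = nn.Linear(4, 2)  # extra head, absent from the checkpoint

# A checkpoint saved from the bare encoder has un-prefixed keys ('weight', 'bias').
encoder_state = nn.Linear(4, 4).state_dict()

# Add the missing `bert.` prefix so the keys match the wrapper's parameter names.
prefixed = {f'bert.{key}': value for key, value in encoder_state.items()}

model = Wrapper()
missing, unexpected = model.load_state_dict(prefixed, strict=False)
print(missing)     # ['classifier.weight', 'classifier.bias'] - head stays randomly initialized
print(unexpected)  # [] - every encoder tensor found its prefixed name

Without the prefix, every checkpoint key would be reported as unexpected and every model parameter as missing, so the encoder would silently keep its random initialization; the one-line dict comprehension in this commit is what prevents that.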