Sifal commited on
Commit
6d6b31c
·
verified ·
1 Parent(s): 7970e16

add bert prefix

Browse files
Files changed (1) hide show
  1. automodel.py +6 -4
automodel.py CHANGED
@@ -52,8 +52,9 @@ class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
52
  # Load the state_dict
53
  state_dict = load_file(archive_file)
54
 
55
- # remove `model` prefix to avoid error
56
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
 
57
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
58
 
59
  if len(missing_keys) > 0:
@@ -131,8 +132,9 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
131
  # Load the state_dict
132
  state_dict = load_file(archive_file)
133
 
134
- # remove `model` prefix to avoid error
135
- consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
 
136
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
137
 
138
  if len(missing_keys) > 0:
 
52
  # Load the state_dict
53
  state_dict = load_file(archive_file)
54
 
55
+ # add missing bert prefix
56
+ state_dict = {f'bert.{key}': value for key, value in state_dict.items()}
57
+
58
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
59
 
60
  if len(missing_keys) > 0:
 
132
  # Load the state_dict
133
  state_dict = load_file(archive_file)
134
 
135
+ # add missing bert prefix
136
+ state_dict = {f'bert.{key}': value for key, value in state_dict.items()}
137
+
138
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
139
 
140
  if len(missing_keys) > 0: