Sifal commited on
Commit
64cb177
·
verified ·
1 Parent(s): 32552bc

Update automodel.py

Browse files
Files changed (1) hide show
  1. automodel.py +8 -2
automodel.py CHANGED
@@ -41,11 +41,14 @@ class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
41
  # this gets a fresh init model
42
  model = cls(config, *inputs, **kwargs)
43
 
 
 
 
44
  # Download the model file
45
  archive_file = hf_hub_download(
46
  repo_id=pretrained_checkpoint,
47
  filename="model.safetensors",
48
- **kwargs
49
  )
50
 
51
  # Load the state_dict
@@ -120,11 +123,14 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
120
  # this gets a fresh init model
121
  model = cls(config, *inputs, **kwargs)
122
 
 
 
 
123
  # Download the model file
124
  archive_file = hf_hub_download(
125
  repo_id=pretrained_checkpoint,
126
  filename="model.safetensors",
127
- **kwargs
128
  )
129
 
130
  # Load the state_dict
 
41
  # this gets a fresh init model
42
  model = cls(config, *inputs, **kwargs)
43
 
44
+ hf_kwargs = kwargs.copy()
45
+ hf_kwargs.pop('_from_auto', None) # Remove '_from_auto' if present
46
+
47
  # Download the model file
48
  archive_file = hf_hub_download(
49
  repo_id=pretrained_checkpoint,
50
  filename="model.safetensors",
51
+ **hf_kwargs # Pass filtered kwargs
52
  )
53
 
54
  # Load the state_dict
 
123
  # this gets a fresh init model
124
  model = cls(config, *inputs, **kwargs)
125
 
126
+ hf_kwargs = kwargs.copy()
127
+ hf_kwargs.pop('_from_auto', None) # Remove '_from_auto' if present
128
+
129
  # Download the model file
130
  archive_file = hf_hub_download(
131
  repo_id=pretrained_checkpoint,
132
  filename="model.safetensors",
133
+ **hf_kwargs # Pass filtered kwargs
134
  )
135
 
136
  # Load the state_dict