jadechoghari commited on
Commit
ac08085
·
verified ·
1 Parent(s): c0c9440

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +15 -0
modeling.py CHANGED
@@ -120,6 +120,21 @@ class MARModel(PreTrainedModel):
120
  # state_dict = torch.load(safetensors_path, map_location='cpu')
121
  # model.model.load_state_dict(state_dict)
122
  # return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
 
125
  def save_pretrained(self, save_directory):
 
120
  # state_dict = torch.load(safetensors_path, map_location='cpu')
121
  # model.model.load_state_dict(state_dict)
122
  # return model
123
+ @classmethod
124
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
125
+ config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
126
+
127
+ model = cls(config)
128
+
129
+ safetensors_path = hf_hub_download(
130
+ repo_id=pretrained_model_name_or_path,
131
+ filename="model.safetensors"
132
+ )
133
+ state_dict = load_file(safetensors_path)
134
+
135
+ model.model.load_state_dict(state_dict)
136
+
137
+ return model
138
 
139
 
140
  def save_pretrained(self, save_directory):