Update modeling.py
Browse files- modeling.py +15 -0
modeling.py
CHANGED
@@ -120,6 +120,21 @@ class MARModel(PreTrainedModel):
|
|
120 |
# state_dict = torch.load(safetensors_path, map_location='cpu')
|
121 |
# model.model.load_state_dict(state_dict)
|
122 |
# return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
|
125 |
def save_pretrained(self, save_directory):
|
|
|
120 |
# state_dict = torch.load(safetensors_path, map_location='cpu')
|
121 |
# model.model.load_state_dict(state_dict)
|
122 |
# return model
|
123 |
+
@classmethod
|
124 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
125 |
+
config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
126 |
+
|
127 |
+
model = cls(config)
|
128 |
+
|
129 |
+
safetensors_path = hf_hub_download(
|
130 |
+
repo_id=pretrained_model_name_or_path,
|
131 |
+
filename="model.safetensors"
|
132 |
+
)
|
133 |
+
state_dict = load_file(safetensors_path)
|
134 |
+
|
135 |
+
model.model.load_state_dict(state_dict)
|
136 |
+
|
137 |
+
return model
|
138 |
|
139 |
|
140 |
def save_pretrained(self, save_directory):
|