yinuozhang committed
Commit 54b0181 · verified · 1 Parent(s): 7c0621c

Update modeling_metalatte.py

Files changed (1)
  1. modeling_metalatte.py +21 -1
modeling_metalatte.py CHANGED
@@ -14,7 +14,7 @@ import gc
 from torch.optim.lr_scheduler import _LRScheduler
 from transformers import EsmModel, PreTrainedModel
 from configuration import MetaLATTEConfig
-
+from urllib.parse import urljoin
 seed_everything(42)
 
 class GELU(nn.Module):
@@ -218,6 +218,26 @@ class MultitaskProteinModel(PreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        config = kwargs.pop("config", None)
+        if config is None:
+            config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
+
+        model = cls(config)
+        # state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
+        try:
+            state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
+            state_dict = torch.hub.load_state_dict_from_url(
+                state_dict_url,
+                map_location=torch.device('cpu')
+            )['state_dict']
+            model.load_state_dict(state_dict, strict=False)
+        except Exception as e:
+            raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
+
+        return model
 
     def forward(self, input_ids, attention_mask=None):
         outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
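
For reference, a minimal usage sketch of the loader added in this commit. This is a sketch under assumptions, not part of the commit: the repo id below is hypothetical, and it presumes modeling_metalatte.py and configuration.py are importable from the working directory.

# Minimal usage sketch (assumptions: the repo id is hypothetical, and
# modeling_metalatte.py / configuration.py sit on the Python path).
from modeling_metalatte import MultitaskProteinModel

# The overridden from_pretrained builds the checkpoint URL with urljoin,
# downloads pytorch_model.bin via torch.hub.load_state_dict_from_url, and
# loads the checkpoint's 'state_dict' entry with strict=False.
model = MultitaskProteinModel.from_pretrained("some-user/MetaLATTE")  # hypothetical repo id
model.eval()

Note that strict=False tolerates missing or unexpected keys, so a checkpoint/architecture mismatch would surface at inference time rather than as an error at load time.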