Spaces:
Sleeping
Sleeping
yinuozhang
commited on
Update modeling_metalatte.py
Browse files- modeling_metalatte.py +21 -1
modeling_metalatte.py
CHANGED
@@ -14,7 +14,7 @@ import gc
|
|
14 |
from torch.optim.lr_scheduler import _LRScheduler
|
15 |
from transformers import EsmModel, PreTrainedModel
|
16 |
from configuration import MetaLATTEConfig
|
17 |
-
|
18 |
seed_everything(42)
|
19 |
|
20 |
class GELU(nn.Module):
|
@@ -218,6 +218,26 @@ class MultitaskProteinModel(PreTrainedModel):
|
|
218 |
|
219 |
# Initialize weights and apply final processing
|
220 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
def forward(self, input_ids, attention_mask=None):
|
223 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
|
|
14 |
from torch.optim.lr_scheduler import _LRScheduler
|
15 |
from transformers import EsmModel, PreTrainedModel
|
16 |
from configuration import MetaLATTEConfig
|
17 |
+
from urllib.parse import urljoin
|
18 |
seed_everything(42)
|
19 |
|
20 |
class GELU(nn.Module):
|
|
|
218 |
|
219 |
# Initialize weights and apply final processing
|
220 |
self.post_init()
|
221 |
+
|
222 |
+
@classmethod
|
223 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
224 |
+
config = kwargs.pop("config", None)
|
225 |
+
if config is None:
|
226 |
+
config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
|
227 |
+
|
228 |
+
model = cls(config)
|
229 |
+
#state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
|
230 |
+
try:
|
231 |
+
state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
|
232 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
233 |
+
state_dict_url,
|
234 |
+
map_location=torch.device('cpu')
|
235 |
+
)['state_dict']
|
236 |
+
model.load_state_dict(state_dict, strict=False)
|
237 |
+
except Exception as e:
|
238 |
+
raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
|
239 |
+
|
240 |
+
return model
|
241 |
|
242 |
def forward(self, input_ids, attention_mask=None):
|
243 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|