Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
from .resnet_backbone import ResNetBackbone | |
class ResNet50(nn.Module):
    """ResNet-50 backbone whose weights are replaced by a pretrained checkpoint.

    The network is a ``ResNetBackbone`` named ``"resnet50"`` (or
    ``"resnet50_dilated8"`` when ``use_dilated_resnet`` is True), built with
    ``pretrained=None``. Its parameters and buffers are then overwritten with
    MoCo v2, SwAV or torchvision ImageNet-supervised ResNet-50 weights.

    Args:
        weight_type: ``"mocov2"``, ``"swav"`` or ``"supervised"``.
        use_dilated_resnet: Select the dilated backbone variant.
        pretrained_dir: Directory holding ``moco_v2_800ep_pretrain.pth.tar`` and
            ``swav_800ep_pretrain.pth.tar``. It is only read for ``"mocov2"`` and
            ``"swav"``. Defaults to the original author's path.

    Attributes:
        network: The wrapped ``ResNetBackbone``.
        n_embs: ``network.num_features``. This is presumably the output channel
            count; confirm in ``ResNetBackbone``.
        use_dilated_resnet: The flag passed to the constructor.

    Raises:
        ValueError: If ``weight_type`` is unknown, or if the checkpoint does not
            line up with the backbone's state dict.
    """

    # Where the self-supervised checkpoints lived on the original author's
    # machine. Pass ``pretrained_dir`` to load them from anywhere else.
    DEFAULT_PRETRAINED_DIR = "/users/gyungin/sos/networks/pretrained"

    def __init__(
        self,
        weight_type: str = "supervised",
        use_dilated_resnet: bool = True,
        *,
        pretrained_dir: str = DEFAULT_PRETRAINED_DIR,
    ):
        super().__init__()
        backbone = f"resnet50{'_dilated8' if use_dilated_resnet else ''}"
        self.network = ResNetBackbone(backbone=backbone, pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type, pretrained_dir=pretrained_dir)

    def _load_pretrained(
        self, training_method: str, pretrained_dir: str = DEFAULT_PRETRAINED_DIR
    ) -> None:
        """Overwrite ``self.network``'s parameters/buffers with pretrained ones.

        Entries are matched *by position*, not by name, because checkpoint key
        prefixes differ from the backbone's. After head weights are dropped, the
        i-th checkpoint entry is copied into the i-th entry of
        ``self.network.state_dict()``.

        Args:
            training_method: ``"mocov2"``, ``"swav"`` or ``"supervised"``.
            pretrained_dir: Directory containing the MoCo v2 / SwAV checkpoints.

        Raises:
            ValueError: On an unknown ``training_method``, or if the entry count
                or any tensor shape differs between checkpoint and backbone.
        """
        curr_state_dict = self.network.state_dict()
        if training_method == "mocov2":
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(
                os.path.join(pretrained_dir, "moco_v2_800ep_pretrain.pth.tar"),
                map_location="cpu",
            )
            state_dict = checkpoint["state_dict"]
            # "fc.0"/"fc.2" keys belong to an MLP head the backbone lacks
            # (presumably the MoCo projection head; confirm).
            drop = ("fc.0", "fc.2")
        elif training_method == "swav":
            state_dict = torch.load(
                os.path.join(pretrained_dir, "swav_800ep_pretrain.pth.tar"),
                map_location="cpu",
            )
            drop = ("projection_head", "prototypes")
        elif training_method == "supervised":
            from torchvision.models.resnet import resnet50

            try:
                # torchvision >= 0.13. This equals the legacy pretrained=True.
                model = resnet50(weights="IMAGENET1K_V1")
            except TypeError:
                # Older torchvision has no `weights` keyword.
                model = resnet50(pretrained=True)
            state_dict = model.state_dict()
            # The classifier is not part of the backbone.
            drop = ("fc.weight", "fc.bias")
        else:
            raise ValueError(
                f"Unknown weight type {training_method!r}; "
                "expected 'mocov2', 'swav' or 'supervised'."
            )

        state_dict = {
            k: v for k, v in state_dict.items() if not any(w in k for w in drop)
        }
        if len(curr_state_dict) != len(state_dict):
            raise ValueError(
                f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
            )

        # Validate every shape before copying anything. This way a misaligned
        # checkpoint cannot leave the network half-overwritten, and copy_ cannot
        # silently broadcast a wrong-shaped tensor.
        pairs = list(zip(curr_state_dict.items(), state_dict.items()))
        for (k_curr, dst), (k, src) in pairs:
            if dst.shape != src.shape:
                raise ValueError(
                    f"shape mismatch: {k_curr} {tuple(dst.shape)} "
                    f"vs checkpoint {k} {tuple(src.shape)}"
                )
        # state_dict() tensors share storage with the module's parameters and
        # buffers, so copy_ updates the network in place.
        with torch.no_grad():
            for (_, dst), (_, src) in pairs:
                dst.copy_(src)
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} initialised with {training_method} weights is loaded.")

    def forward(self, x):
        """Run the wrapped backbone on ``x``.

        The output format is defined by ``ResNetBackbone`` (TODO confirm).
        """
        return self.network(x)
if __name__ == "__main__":
    # Smoke run: build the (default, dilated) backbone initialised with MoCo v2 weights.
    model = ResNet50(weight_type="mocov2")