selfmask / networks /resnet.py
noelshin's picture
Add application file
35188e4
import os
import torch
import torch.nn as nn
from .resnet_backbone import ResNetBackbone
class ResNet50(nn.Module):
    """ResNet-50 backbone initialised from pretrained weights.

    Builds a ``ResNetBackbone`` ("resnet50" or "resnet50_dilated8") with
    ``pretrained=None`` and then copies one of three pretrained weight sets
    into it: "supervised" (torchvision), "mocov2" or "swav" (checkpoint files).
    """

    # Weight sources understood by ``_load_pretrained``.
    _SUPPORTED_WEIGHTS = ("mocov2", "swav", "supervised")
    # Directory holding the MoCo v2 / SwAV checkpoint files. This is the path the
    # code was originally hard-coded to; override it via ``pretrained_dir``.
    _DEFAULT_PRETRAINED_DIR = "/users/gyungin/sos/networks/pretrained"

    def __init__(
            self,
            weight_type: str = "supervised",
            use_dilated_resnet: bool = True,
            pretrained_dir: str = _DEFAULT_PRETRAINED_DIR,
    ):
        """Build the backbone and load pretrained weights into it.

        Args:
            weight_type: "supervised", "mocov2" or "swav".
            use_dilated_resnet: if True build "resnet50_dilated8", else "resnet50".
            pretrained_dir: directory containing "moco_v2_800ep_pretrain.pth.tar"
                and "swav_800ep_pretrain.pth.tar"; unused for "supervised".

        Raises:
            ValueError: if ``weight_type`` is unsupported, or the pretrained
                weights do not line up with the backbone's state dict.
        """
        super(ResNet50, self).__init__()
        backbone_name = "resnet50_dilated8" if use_dilated_resnet else "resnet50"
        self.network = ResNetBackbone(backbone=backbone_name, pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type, pretrained_dir)

    @staticmethod
    def _drop_keys(state_dict: dict, patterns: tuple) -> dict:
        """Return a copy of ``state_dict`` without keys containing any of ``patterns``.

        Matching is by substring, and the remaining keys keep their original order.
        """
        return {k: v for k, v in state_dict.items() if not any(p in k for p in patterns)}

    @staticmethod
    def _copy_by_order(dst: dict, src: dict) -> None:
        """Copy tensors from ``src`` into ``dst`` in place, pairing entries by position.

        Keys are NOT compared: the i-th tensor of ``src`` goes into the i-th
        tensor of ``dst``. This assumes both state dicts enumerate the same
        layers in the same order, while their key prefixes may differ.

        Raises:
            ValueError: on a differing number of entries or any shape mismatch.
        """
        if len(dst) != len(src):
            raise ValueError(f"# layers are different: {len(dst)} != {len(src)}")
        with torch.no_grad():
            for (k_dst, t_dst), (k_src, t_src) in zip(dst.items(), src.items()):
                # ``copy_`` broadcasts, so a wrong-shaped tensor could otherwise
                # be copied silently; reject it explicitly.
                if t_dst.shape != t_src.shape:
                    raise ValueError(
                        f"shape mismatch: {k_dst} {tuple(t_dst.shape)} != "
                        f"{k_src} {tuple(t_src.shape)}"
                    )
                t_dst.copy_(t_src)

    def _load_pretrained(self, training_method: str, pretrained_dir: str = _DEFAULT_PRETRAINED_DIR) -> None:
        """Load ``training_method`` weights into ``self.network`` in place.

        Args:
            training_method: "mocov2", "swav" or "supervised".
            pretrained_dir: directory with the MoCo v2 / SwAV checkpoint files.

        Raises:
            ValueError: for an unsupported ``training_method`` or mismatched weights.
        """
        if training_method not in self._SUPPORTED_WEIGHTS:
            # An unknown name used to fall through every branch and fail later
            # with an UnboundLocalError on ``state_dict``.
            raise ValueError(
                f"Unsupported weight type {training_method!r}; "
                f"expected one of {self._SUPPORTED_WEIGHTS}."
            )

        if training_method == "mocov2":
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(
                os.path.join(pretrained_dir, "moco_v2_800ep_pretrain.pth.tar"),
                map_location="cpu",
            )
            # "fc.0"/"fc.2" are presumably the MoCo v2 MLP head, which the backbone lacks.
            state_dict = self._drop_keys(checkpoint["state_dict"], ("fc.0", "fc.2"))
        elif training_method == "swav":
            checkpoint = torch.load(
                os.path.join(pretrained_dir, "swav_800ep_pretrain.pth.tar"),
                map_location="cpu",
            )
            state_dict = self._drop_keys(checkpoint, ("projection_head", "prototypes"))
        else:  # "supervised"
            from torchvision.models.resnet import resnet50
            resnet50_supervised = resnet50(pretrained=True, progress=True)
            # The classification layer is not part of the backbone.
            state_dict = self._drop_keys(resnet50_supervised.state_dict(), ("fc.weight", "fc.bias"))

        # state_dict() tensors share storage with the module's parameters/buffers,
        # so copying into them updates self.network.
        self._copy_by_order(self.network.state_dict(), state_dict)
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} intialised with {training_method} weights is loaded.")

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        return self.network(x)
if __name__ == '__main__':
    # Smoke test: construct the (dilated) backbone initialised with MoCo v2 weights.
    model = ResNet50(weight_type="mocov2")