File size: 2,551 Bytes
35188e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os

import torch
import torch.nn as nn
from .resnet_backbone import ResNetBackbone


class ResNet50(nn.Module):
    """ResNet-50 backbone (optionally dilated) initialised from pretrained weights.

    The wrapped ``ResNetBackbone`` is built with ``pretrained=None`` and its
    weights are then copied in from the checkpoint selected by ``weight_type``.

    Args:
        weight_type: which weights to load: ``"supervised"`` (torchvision's
            pretrained ``resnet50``), ``"mocov2"`` or ``"swav"`` (checkpoints
            read from ``pretrained_dir``).
        use_dilated_resnet: if True, build the ``resnet50_dilated8`` backbone
            variant instead of ``resnet50``.

    Raises:
        ValueError: if ``weight_type`` is unsupported, or the checkpoint does
            not match the backbone's state dict (entry count or tensor shape).
    """

    # Directory holding the MoCo v2 / SwAV checkpoint files. Override on the
    # class (or on an instance before loading) to use a different location.
    pretrained_dir = "/users/gyungin/sos/networks/pretrained"

    def __init__(
            self,
            weight_type: str = "supervised",
            use_dilated_resnet: bool = True
    ):
        super().__init__()
        backbone_name = "resnet50_dilated8" if use_dilated_resnet else "resnet50"
        self.network = ResNetBackbone(backbone=backbone_name, pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type)

    def _pretrained_state_dict(self, training_method: str) -> dict:
        """Return the pretrained state dict for *training_method* without head entries.

        Entry order of the source checkpoint is preserved, because
        ``_load_pretrained`` matches tensors by position.

        Raises:
            ValueError: if *training_method* is not supported.
        """
        if training_method == "mocov2":
            path = os.path.join(self.pretrained_dir, "moco_v2_800ep_pretrain.pth.tar")
            # map_location="cpu": GPU-saved checkpoints must also load on CPU-only hosts.
            state_dict = torch.load(path, map_location="cpu")["state_dict"]
            excluded = ("fc.0", "fc.2")
        elif training_method == "swav":
            path = os.path.join(self.pretrained_dir, "swav_800ep_pretrain.pth.tar")
            state_dict = torch.load(path, map_location="cpu")
            excluded = ("projection_head", "prototypes")
        elif training_method == "supervised":
            # NOTE: weights come from an instantiated torchvision model rather
            # than a raw .pth file; the original author observed that the raw
            # file lacked num_batches_tracked entries (reason not investigated).
            from torchvision.models.resnet import resnet50
            state_dict = resnet50(pretrained=True, progress=True).state_dict()
            excluded = ("fc.weight", "fc.bias")
        else:
            raise ValueError(
                f"Unsupported training method {training_method!r}; "
                f"expected one of 'supervised', 'mocov2', 'swav'."
            )
        # Drop entries that are not part of the backbone (classifier/projection heads).
        return {k: v for k, v in state_dict.items() if not any(w in k for w in excluded)}

    def _load_pretrained(self, training_method: str) -> None:
        """Copy the pretrained weights for *training_method* into ``self.network``.

        Tensors are matched by position, not by name: the i-th entry of the
        filtered checkpoint is copied into the i-th entry of the backbone's
        state dict. Both must therefore list their tensors in the same order.

        Raises:
            ValueError: if the method is unsupported, or the entry counts or
                any pair of tensor shapes differ.
        """
        curr_state_dict = self.network.state_dict()
        state_dict = self._pretrained_state_dict(training_method)

        # Explicit check (not `assert`): zip() would silently truncate under -O.
        if len(curr_state_dict) != len(state_dict):
            raise ValueError(f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}")
        for (k_curr, dst), (k, src) in zip(curr_state_dict.items(), state_dict.items()):
            # copy_ broadcasts, so a mismatched tensor could load silently; reject it.
            if dst.shape != src.shape:
                raise ValueError(
                    f"shape mismatch: {k_curr} {tuple(dst.shape)} != {k} {tuple(src.shape)}"
                )
            # state_dict() tensors share storage with the module's parameters/buffers.
            dst.copy_(src)
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} intialised with {training_method} weights is loaded.")

    def forward(self, x):
        """Run the wrapped backbone on *x* and return its output unchanged."""
        return self.network(x)


if __name__ == "__main__":
    # Manual smoke check: build the backbone initialised with MoCo v2 weights.
    resnet = ResNet50(weight_type="mocov2")