File size: 3,667 Bytes
35188e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from argparse import Namespace
from typing import Optional
import torch


def get_model(
        arch: str,
        patch_size: Optional[int] = None,
        training_method: Optional[str] = None,
        configs: Optional[Namespace] = None,
        **kwargs
):
    """Build the network selected by ``arch``.

    Args:
        arch: ``"maskformer"``, ``"resnet50"``, or any name containing
            ``"vit"`` (resolved inside ``networks.vision_transformer``).
        patch_size: Patch size for the ViT branch. Must be 16 when
            ``training_method`` is ``"deit"`` or ``"supervised"``. Not read
            by the other branches.
        training_method: For the ViT branch one of ``"dino"``, ``"deit"``,
            ``"supervised"``; for ``"resnet50"`` one of ``"mocov2"``,
            ``"swav"``, ``"supervised"``. The maskformer branch reads
            ``configs.training_method`` instead.
        configs: Namespace holding the MaskFormer hyper-parameters; required
            when ``arch == "maskformer"``.
        **kwargs: Accepted for call-site convenience and ignored.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``arch`` is unsupported, ``configs`` is missing for
            maskformer, ``patch_size`` is not 16 where required, or
            ``training_method`` is invalid for resnet50.
        NotImplementedError: If ``training_method`` is not supported for a
            ViT architecture.
    """
    if arch == "maskformer":
        if configs is None:
            raise ValueError("configs must be provided when arch is 'maskformer'.")

        from networks.maskformer.maskformer import MaskFormer
        model = MaskFormer(
            n_queries=configs.n_queries,
            n_decoder_layers=configs.n_decoder_layers,
            learnable_pixel_decoder=configs.learnable_pixel_decoder,
            lateral_connection=configs.lateral_connection,
            return_intermediate=configs.loss_every_decoder_layer,
            scale_factor=configs.scale_factor,
            abs_2d_pe_init=configs.abs_2d_pe_init,
            use_binary_classifier=configs.use_binary_classifier,
            arch=configs.arch,
            training_method=configs.training_method,
            patch_size=configs.patch_size
        )

        # Explicitly mark every encoder parameter as trainable.
        for param in model.encoder.parameters():
            param.requires_grad_(True)

    elif "vit" in arch:
        # Validate the arguments before touching the project-local imports
        # so that bad input fails fast with a descriptive error.
        if training_method not in ("dino", "deit", "supervised"):
            raise NotImplementedError(
                f"training_method {training_method!r} is not supported for {arch}. "
                "Choose from [dino, deit, supervised]"
            )
        if training_method in ("deit", "supervised") and patch_size != 16:
            raise ValueError(
                f"training_method {training_method!r} requires patch_size == 16, got {patch_size!r}."
            )

        import networks.vision_transformer as vits
        import networks.timm_deit as timm_deit

        if training_method == "dino":
            # The small ViT is registered under the "deit" name.
            arch = arch.replace("vit", "deit") if "small" in arch else arch
            model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
            load_model(model, arch, patch_size)

        elif training_method == "deit":
            model = timm_deit.deit_small_distilled_patch16_224(True)

        else:  # training_method == "supervised"
            # NOTE(review): this branch always builds the small backbone,
            # whatever ViT name was passed in `arch` - confirm that is intended.
            # NOTE(review): the checkpoint location is a hard-coded absolute path.
            # map_location="cpu" keeps loading independent of the device the
            # checkpoint was saved from; load_state_dict copies the values into
            # the model's own tensors afterwards.
            state_dict: dict = torch.load(
                "/users/gyungin/selfmask/networks/pretrained/deit_small_patch16_224-cd65a155.pth",
                map_location="cpu"
            )["model"]
            # classifier head, which is not used in our network
            for key in ("head.weight", "head.bias"):
                state_dict.pop(key, None)

            # Build the bare backbone directly. Going through the "dino" path
            # would first download the DINO weights, all of which the strict
            # load below overwrites anyway.
            model = vits.__dict__["deit_small"](patch_size=16, num_classes=0)
            model.load_state_dict(state_dict=state_dict, strict=True)

        print(f"{arch}_p{patch_size}_{training_method} is built.")

    elif arch == "resnet50":
        if training_method not in ("mocov2", "swav", "supervised"):
            raise ValueError(
                f"training_method {training_method!r} is not supported for resnet50. "
                "Choose from [mocov2, swav, supervised]"
            )
        from networks.resnet import ResNet50
        model = ResNet50(training_method)

    else:
        raise ValueError(f"{arch} is not supported arch. Choose from [maskformer, resnet50, vit, dino]")
    return model


def load_model(model, arch: str, patch_size: int) -> None:
    """Load the reference DINO checkpoint for ``arch``/``patch_size`` into ``model``.

    When the combination has a known reference checkpoint, its state dict is
    fetched with ``torch.hub.load_state_dict_from_url`` and loaded strictly
    into ``model``. Otherwise ``model`` is left untouched and a notice is
    printed.

    Args:
        model: Object exposing ``load_state_dict``; modified in place.
        arch: Architecture name, e.g. ``"deit_small"`` or ``"vit_base"``.
        patch_size: Patch size of the architecture (8 or 16 have checkpoints).
    """
    # (architecture, patch size, checkpoint path relative to the DINO bucket)
    reference_checkpoints = (
        ("deit_small", 16, "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"),
        # model used for visualizations in our paper
        ("deit_small", 8, "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"),
        ("vit_base", 16, "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"),
        ("vit_base", 8, "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"),
    )
    relative_path = next(
        (
            path
            for name, size, path in reference_checkpoints
            if arch == name and patch_size == size
        ),
        None,
    )

    if relative_path is None:
        print("There is no reference weights available for this model => We use random weights.")
        return

    print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
    weights = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + relative_path)
    model.load_state_dict(weights, strict=True)