Spaces:
Runtime error
Runtime error
from argparse import Namespace | |
from typing import Optional | |
import torch | |
# Local copy of the supervised DeiT-S/16 checkpoint (author's machine).
_SUPERVISED_DEIT_S16_PATH = "/users/gyungin/selfmask/networks/pretrained/deit_small_patch16_224-cd65a155.pth"
# Official download location of the same checkpoint file, used when the local copy is absent.
_SUPERVISED_DEIT_S16_URL = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"


def _load_supervised_deit_small_state_dict() -> dict:
    """Return the supervised DeiT-S/16 weights with the classifier head removed.

    Reads the local checkpoint if it exists, otherwise downloads (and caches)
    the same file via ``torch.hub``. Tensors are mapped to CPU so loading also
    works on hosts without a GPU.
    """
    import os

    if os.path.isfile(_SUPERVISED_DEIT_S16_PATH):
        checkpoint = torch.load(_SUPERVISED_DEIT_S16_PATH, map_location="cpu")
    else:
        checkpoint = torch.hub.load_state_dict_from_url(url=_SUPERVISED_DEIT_S16_URL, map_location="cpu")
    state_dict: dict = checkpoint["model"]
    # classifier head, which is not used in our network
    for k in ("head.weight", "head.bias"):
        state_dict.pop(k, None)
    return state_dict


def get_model(
    arch: str,
    patch_size: Optional[int] = None,
    training_method: Optional[str] = None,
    configs: Optional[Namespace] = None,
    **kwargs
):
    """Build a network by architecture name.

    Args:
        arch: ``"maskformer"``, ``"resnet50"``, or any name containing ``"vit"``
            (e.g. ``"vit_small"``, ``"vit_base"``).
        patch_size: ViT patch size. Unused for ``"maskformer"`` (read from
            ``configs``) and ``"resnet50"``.
        training_method: pretrained weights to use. For ViTs one of ``"dino"``,
            ``"deit"``, ``"supervised"``; for ``"resnet50"`` one of ``"mocov2"``,
            ``"swav"``, ``"supervised"``. Unused for ``"maskformer"``.
        configs: required for ``"maskformer"``; supplies every constructor option.
        **kwargs: accepted for call-site compatibility; currently unused.

    Returns:
        The constructed model.

    Raises:
        AssertionError: ``configs`` missing for maskformer, patch size other
            than 16 for ``"deit"``/``"supervised"``, or an unsupported
            ResNet-50 training method.
        NotImplementedError: unknown ``training_method`` for a ViT.
        ValueError: unsupported ``arch``.
    """
    if arch == "maskformer":
        assert configs is not None, "configs must be provided to build maskformer"
        from networks.maskformer.maskformer import MaskFormer
        model = MaskFormer(
            n_queries=configs.n_queries,
            n_decoder_layers=configs.n_decoder_layers,
            learnable_pixel_decoder=configs.learnable_pixel_decoder,
            lateral_connection=configs.lateral_connection,
            return_intermediate=configs.loss_every_decoder_layer,
            scale_factor=configs.scale_factor,
            abs_2d_pe_init=configs.abs_2d_pe_init,
            use_binary_classifier=configs.use_binary_classifier,
            arch=configs.arch,
            training_method=configs.training_method,
            patch_size=configs.patch_size
        )
        # Make every encoder parameter trainable.
        for p in model.encoder.parameters():
            p.requires_grad_(True)

    elif "vit" in arch:
        import networks.vision_transformer as vits
        import networks.timm_deit as timm_deit
        if training_method == "dino":
            # ViT-S is looked up as "deit_small" (the name load_model keys on);
            # other sizes keep their "vit_*" name.
            arch = arch.replace("vit", "deit") if arch.find("small") != -1 else arch
            model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
            load_model(model, arch, patch_size)
        elif training_method == "deit":
            assert patch_size == 16, f"deit weights require patch_size=16, got {patch_size}"
            model = timm_deit.deit_small_distilled_patch16_224(True)
        elif training_method == "supervised":
            assert patch_size == 16, f"supervised weights require patch_size=16, got {patch_size}"
            state_dict = _load_supervised_deit_small_state_dict()
            # Same architecture the "dino" branch builds for vit_small/16. The strict
            # load overwrites every entry, so there is no need to fetch DINO weights first.
            # NOTE(review): always builds ViT-S regardless of `arch`.
            model = vits.__dict__["deit_small"](patch_size=16, num_classes=0)
            model.load_state_dict(state_dict=state_dict, strict=True)
        else:
            raise NotImplementedError(f"Unsupported training_method for vit: {training_method}")
        print(f"{arch}_p{patch_size}_{training_method} is built.")

    elif arch == "resnet50":
        from networks.resnet import ResNet50
        assert training_method in ["mocov2", "swav", "supervised"], \
            f"Unsupported training_method for resnet50: {training_method}"
        model = ResNet50(training_method)

    else:
        raise ValueError(f"{arch} is not supported arch. Choose from [maskformer, resnet50, vit_*]")
    return model
def load_model(model, arch: str, patch_size: int) -> None:
    """Load the reference DINO weights into ``model`` when available.

    Args:
        model: module whose ``load_state_dict`` receives the weights (strict).
        arch: backbone name; only ``"deit_small"`` and ``"vit_base"`` have
            reference weights here.
        patch_size: ViT patch size; 8 and 16 are supported.

    If no reference checkpoint matches ``(arch, patch_size)``, ``model`` is left
    untouched (random initialisation) and a message is printed.
    """
    # (arch, patch_size) -> checkpoint path relative to the DINO download root.
    checkpoints = {
        ("deit_small", 16): "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        # model used for visualizations in our paper
        ("deit_small", 8): "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth",
        ("vit_base", 16): "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
        ("vit_base", 8): "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
    }
    url = checkpoints.get((arch, patch_size))
    if url is None:
        print("There is no reference weights available for this model => We use random weights.")
        return

    print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
    # map_location="cpu" lets the checkpoint deserialize on hosts without CUDA.
    state_dict = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/dino/" + url, map_location="cpu"
    )
    model.load_state_dict(state_dict, strict=True)