# namedmask / utils.py
# noelshin — "fix state dict loading error" (commit 6d6f3c6)
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
from networks import deeplabv3plus_resnet50
from networks import convert_to_separable_conv, set_bn_momentum
def get_network() -> torch.nn.Module:
    """Build the NamedMask DeepLabV3+ (ResNet-50) model with VOC2012 weights.

    The model is created with 21 output classes and no pretrained backbone.
    Its classifier is passed through ``convert_to_separable_conv``, and its
    backbone batch-norm momentum is set to 0.01 before the checkpoint is loaded.
    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url`` and
    loaded with ``strict=True``, so every key must match exactly.

    Returns:
        The model with the downloaded weights loaded.

    NOTE(review): ``map_location`` only controls where the checkpoint tensors
    are materialised. ``load_state_dict`` copies them into the existing
    parameters, so the returned module is not moved to CUDA here. This function
    also never calls ``.eval()``; callers presumably do both — confirm at call sites.
    """
    # Decide where the downloaded checkpoint tensors should land.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_url = "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"

    # Recreate the exact architecture the checkpoint was saved from.
    model = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    convert_to_separable_conv(model.classifier)
    set_bn_momentum(model.backbone, momentum=0.01)

    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=target_device)
    model.load_state_dict(weights, strict=True)
    return model
# PASCAL VOC 2012 colour palette: label id -> [R, G, B].
# Ids 0-20 are the VOC classes (0 is background); 255 is rendered white.
_VOC2012_PALETTE = {
    0: [0, 0, 0],
    1: [128, 0, 0],
    2: [0, 128, 0],
    3: [128, 128, 0],
    4: [0, 0, 128],
    5: [128, 0, 128],
    6: [0, 128, 128],
    7: [128, 128, 128],
    8: [64, 0, 0],
    9: [192, 0, 0],
    10: [64, 128, 0],
    11: [192, 128, 0],
    12: [64, 0, 128],
    13: [192, 0, 128],
    14: [64, 128, 128],
    15: [192, 128, 128],
    16: [0, 64, 0],
    17: [128, 64, 0],
    18: [0, 192, 0],
    19: [128, 192, 0],
    20: [0, 64, 128],
    255: [255, 255, 255],
}


def colourise_mask(mask: np.ndarray) -> np.ndarray:
    """Render a 2-D label mask as an RGB image using the VOC 2012 palette.

    Args:
        mask: 2-D array of label ids. Ids 0-20 and 255 have palette entries.

    Returns:
        A ``(H, W, 3)`` ``uint8`` array in which each pixel holds the colour of
        its label. An empty mask yields an empty image of matching size.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label id with no palette entry.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(mask.shape) != 2:
        raise ValueError(mask.shape)

    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)

    # np.unique avoids copying the whole mask into a Python set; tolist()
    # yields plain Python scalars for the dict lookups below.
    labels = np.unique(mask).tolist()

    # Check every label before painting so an unknown id produces the
    # intended message. The previous try/except caught IndexError, which a
    # dict lookup never raises. KeyError is kept because it is the type
    # callers actually received before.
    unknown = [label for label in labels if label not in _VOC2012_PALETTE]
    if unknown:
        raise KeyError(f"No colour is found for a label id: {unknown[0]}")

    for label in labels:
        grid[mask == label] = _VOC2012_PALETTE[label]
    return grid