|
from typing import Dict, List, Tuple, Union |
|
import numpy as np |
|
import torch |
|
from networks import deeplabv3plus_resnet50 |
|
from networks import convert_to_separable_conv, set_bn_momentum |
|
|
|
|
|
def get_network() -> torch.nn.Module:
    """Build a DeepLabV3+ (ResNet-50) network and load the NamedMask VOC2012 weights.

    The network is created with 21 output classes and no pretrained backbone,
    then adjusted by the project's ``convert_to_separable_conv`` (applied to the
    classifier head) and ``set_bn_momentum`` (applied to the backbone with
    momentum 0.01). These must be applied before loading, so that the module
    structure matches the checkpoint that is then loaded with ``strict=True``.

    Returns:
        The network with the downloaded checkpoint loaded into it.
    """
    # Tensors in the downloaded checkpoint are placed on CUDA when it is available.
    # NOTE(review): the returned network is NOT moved to that device;
    # load_state_dict copies the values into the network's existing parameters.
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    checkpoint_url = "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"

    model = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    convert_to_separable_conv(model.classifier)
    set_bn_momentum(model.backbone, momentum=0.01)

    # torch.hub downloads (and caches) the checkpoint file.
    weights = torch.hub.load_state_dict_from_url(
        checkpoint_url,
        map_location=torch.device(target_device),
    )
    model.load_state_dict(weights, strict=True)
    return model
|
|
|
|
|
def colourise_mask(
    mask: np.ndarray,
) -> np.ndarray:
    """Convert a 2-D label mask into an RGB image using the PASCAL VOC 2012 palette.

    Label ids 0-20 map to the VOC 2012 class colours below. Label id 255
    (presumably the VOC ignore/boundary label) is drawn white.

    Args:
        mask: 2-D array of integer label ids with shape ``(H, W)``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` holding the RGB colour of each pixel.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label id that has no palette entry.
    """
    # Raise explicitly instead of using `assert`, which is stripped under `python -O`.
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2-D mask, got shape {mask.shape}")

    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)

    voc2012_palette = {
        0: [0, 0, 0],
        1: [128, 0, 0],
        2: [0, 128, 0],
        3: [128, 128, 0],
        4: [0, 0, 128],
        5: [128, 0, 128],
        6: [0, 128, 128],
        7: [128, 128, 128],
        8: [64, 0, 0],
        9: [192, 0, 0],
        10: [64, 128, 0],
        11: [192, 128, 0],
        12: [64, 0, 128],
        13: [192, 0, 128],
        14: [64, 128, 128],
        15: [192, 128, 128],
        16: [0, 64, 0],
        17: [128, 64, 0],
        18: [0, 192, 0],
        19: [128, 192, 0],
        20: [0, 64, 128],
        255: [255, 255, 255],
    }

    # Paint one colour per distinct label; each label's pixels are disjoint,
    # so the iteration order does not matter.
    for label in np.unique(mask):
        try:
            colour = voc2012_palette[label]
        except KeyError as err:
            # Dict lookups raise KeyError (not IndexError); re-raise with a clear message.
            raise KeyError(f"No colour is found for a label id: {label}") from err
        grid[mask == label] = np.array(colour, dtype=np.uint8)
    return grid