Spaces:
Build error
Build error
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" | |
import os | |
import sys | |
import numpy as np | |
import torch | |
try: | |
from urllib import urlretrieve | |
except ImportError: | |
from urllib.request import urlretrieve | |
def load_url(url, model_dir='./pretrained', map_location=None):
    """Download a checkpoint if it is not cached yet, then load it with torch.

    The file is cached under ``model_dir`` using the last ``/``-separated
    component of ``url`` as its name. Later calls with the same URL reuse
    the cached copy instead of downloading again.

    Args:
        url: Location of the checkpoint; anything ``urlretrieve`` accepts.
        model_dir: Cache directory, created (including parents) if missing.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        Any error from the download (e.g. ``urllib.error.URLError``) or from
        ``torch.load``. A failed download leaves no file in the cache.
    """
    # exist_ok avoids a race when several processes create the dir at once.
    os.makedirs(model_dir, exist_ok=True)
    # NOTE(review): the name keeps any query string (e.g. "?dl=1"); confirm
    # the URLs passed here have none.
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download to a side file and rename only once it is complete. An
        # interrupted or failed transfer then never leaves a truncated file
        # that later calls would mistake for a valid cache hit.
        partial_file = cached_file + '.part'
        try:
            urlretrieve(url, partial_file)
            os.replace(partial_file, cached_file)
        finally:
            # After a successful rename this file is already gone; after a
            # failure it is discarded here.
            if os.path.exists(partial_file):
                os.remove(partial_file)
    # torch.load is pickle-based: only load checkpoints from trusted sources.
    return torch.load(cached_file, map_location=map_location)
def color_encode(labelmap, colors, mode='RGB'):
    """Render a 2-D label map as a 3-channel uint8 colour image.

    Args:
        labelmap: 2-D array of class ids; values are cast to ``int``.
            Pixels with a negative id are left black (0, 0, 0).
        colors: Palette indexable by class id, where ``colors[i]`` is the
            3-component colour of class ``i``. This can be an ``(N, 3)``
            array, a list of triples, or a dict. Colour values are converted
            to uint8 on assignment.
        mode: ``'BGR'`` reverses the channel order of the result; any other
            value returns RGB.

    Returns:
        ``np.uint8`` array of shape ``(H, W, 3)``. In ``'BGR'`` mode this is
        a channel-reversed view of the RGB buffer.

    Raises:
        IndexError / KeyError: a non-negative id present in ``labelmap`` is
        missing from ``colors``.
    """
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    for label in np.unique(labelmap):
        if label < 0:
            continue
        # Each pixel carries exactly one label, so the masks are disjoint and
        # assignment gives the same result as accumulating. It also avoids a
        # full-image np.tile per label. It further works for non-uint8
        # palettes, where an in-place `uint8 += int64` would raise a
        # same_kind casting error.
        labelmap_rgb[labelmap == label] = colors[label]
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb