Spaces:
Runtime error
Runtime error
File size: 5,531 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import torch
import torch.nn.functional as F
import numpy as np
from scipy.io import loadmat
def init_spixel_grid(args, b_train=True, ratio=1, downsize=16, device="cuda"):
    """Build a per-pixel (x, y) coordinate grid for a square crop.

    Args:
        args: configuration object; only ``args.crop_size`` is read and is
            used as both the grid height and width.
        b_train, ratio, downsize: accepted for call-site compatibility but
            not used by this function.
        device: device the returned tensor is moved to. Defaults to
            ``"cuda"``, matching the previous hard-coded ``.cuda()`` call.

    Returns:
        float32 tensor of shape (1, 2, crop_size, crop_size). Channel 0 holds
        each pixel's column (x / width) index and channel 1 its row
        (y / height) index.
    """
    size = args.crop_size
    rows, cols = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
    # Order channels as (x, y) = (column, row) and add a leading batch axis.
    coords = np.stack([cols, rows])[np.newaxis].astype(np.float32)
    return torch.from_numpy(coords).to(device)
def label2one_hot_torch(labels, C=14):
    """Turn a tensor of class labels into a one-hot float tensor.

    Args:
        labels (tensor): segmentation labels of shape (B, 1, H, W). Values are
            cast to ``torch.long`` and used as channel indices.
        C (int): number of classes, i.e. the size of the one-hot channel axis.

    Returns:
        tensor: float32 tensor of shape (B, C, H, W) holding 1.0 in each
        pixel's class channel and 0.0 elsewhere.
    """
    batch, _, height, width = labels.shape
    # Scatter requires long indices; the buffer mirrors labels' dtype/device.
    class_index = labels.long()
    encoded = torch.zeros(
        (batch, C, height, width), dtype=labels.dtype, device=labels.device
    )
    encoded.scatter_(1, class_index, 1)
    return encoded.float()
# Color palette indexed by label id: the array stored under the 'colors' key
# of data/color150.mat, stacked four times along the first axis. Presumably
# the repetition lets label ids beyond the file's palette size still map to
# a color; confirm against callers. The path is resolved relative to the
# current working directory, and the file is read at import time.
colors = np.concatenate([loadmat('data/color150.mat')['colors']] * 4)
def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Find the sorted unique elements of a flattened array.

    A local re-implementation of ``np.unique`` without ``axis`` support.

    Args:
        ar: array-like input. It is flattened into a copy first, so the
            caller's array is never modified.
        return_index: also return the index of the first occurrence of each
            unique value in the flattened input.
        return_inverse: also return the indices that rebuild the flattened
            input from the unique values.
        return_counts: also return how many times each unique value occurs.

    Returns:
        The sorted unique values alone when no optional output is requested.
        Otherwise a tuple ``(unique, [index], [inverse], [counts])`` holding
        only the requested optional arrays, in that order.
    """
    ar = np.asanyarray(ar).flatten()
    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            return ar
        ret = (ar,)
        # Index/inverse arrays must be integer-typed. ``np.bool`` (used
        # previously) is absent in NumPy 1.24-1.26 and, where it exists,
        # produces a bool dtype that is wrong for index arrays.
        if return_index:
            ret += (np.empty(0, np.intp),)
        if return_inverse:
            ret += (np.empty(0, np.intp),)
        if return_counts:
            ret += (np.empty(0, np.intp),)
        return ret

    if optional_indices:
        # A stable sort makes ``perm[flag]`` point at the first occurrence.
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()  # safe: ``flatten`` returned a private copy
        aux = ar
    # True at the first element of every run of equal sorted values.
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        return aux[flag]
    ret = (aux[flag],)
    if return_index:
        ret += (perm[flag],)
    if return_inverse:
        iflag = np.cumsum(flag) - 1
        inv_idx = np.empty(ar.shape, dtype=np.intp)
        inv_idx[perm] = iflag
        ret += (inv_idx,)
    if return_counts:
        # Run boundaries plus the total length; their differences are counts.
        idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
        ret += (np.diff(idx),)
    return ret
def colorEncode(labelmap, mode='RGB'):
    """Paint a 2-D label map with colors from the module-level palette.

    Every non-negative label ``k`` becomes ``colors[k]``; pixels with a
    negative label stay black (0, 0, 0).

    Args:
        labelmap: numpy array of shape (H, W); values are cast to ``int``.
        mode: ``'BGR'`` reverses the channel order of the result; any other
            value returns RGB.

    Returns:
        np.ndarray of shape (H, W, 3) and dtype uint8.

    Raises:
        IndexError: if a label is not a valid index into ``colors``.
    """
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    # A single vectorized palette lookup replaces the old loop, which built a
    # full tiled (H, W, 3) image for every distinct label. Masks of different
    # labels never overlap, so direct assignment equals the old accumulation.
    # NOTE(review): assumes ``colors`` holds uint8 RGB rows (N, 3); confirm.
    valid = labelmap >= 0
    labelmap_rgb[valid] = colors[labelmap[valid]]
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb
def get_edges(sp_label, sp_num, device="cuda"):
    """Build superpixel adjacency edge lists for a batch of label maps.

    Two superpixels are neighbours when pixels carrying their labels touch
    vertically, horizontally, or along either diagonal (8-connectivity).
    For each image, an ``sp_num x sp_num`` adjacency matrix is filled and
    converted to an edge list.

    Args:
        sp_label: tensor of shape (B, 1, H, W) holding one superpixel id per
            pixel; only channel 0 is read. Ids are cast to ``long`` and must
            lie in ``[0, sp_num)``.
        sp_num: number of superpixels, i.e. the adjacency matrix size.
        device: device of the adjacency matrices and returned edge indices.
            Defaults to ``"cuda"``, matching the previous hard-coded
            ``.cuda()`` calls.

    Returns:
        list of B long tensors of shape (2, E). Column ``k`` is an edge
        ``(src, dst)``. Every adjacency appears in both directions, ordered
        row-major by ``(src, dst)``. No self-loops are included; calling
        ``n_aff.fill_diagonal_(1)`` before extraction would add them.
    """
    head, tail, full = slice(None, -1), slice(1, None), slice(None)
    # (first, second) index pairs selecting each pixel and one neighbour.
    neighbour_pairs = (
        ((head, full), (tail, full)),  # (r, c) vs (r + 1, c)
        ((full, head), (full, tail)),  # (r, c) vs (r, c + 1)
        ((head, head), (tail, tail)),  # (r, c) vs (r + 1, c + 1)
        ((head, tail), (tail, head)),  # (r, c + 1) vs (r + 1, c)
    )
    edge_indices = []
    # Index the channel explicitly instead of using .squeeze(), which used to
    # collapse spatial dims when H == 1 or W == 2 and crash the unpacking.
    for labels in sp_label[:, 0]:
        n_aff = torch.zeros(sp_num, sp_num, device=device)
        for first, second in neighbour_pairs:
            a, b = labels[first], labels[second]
            # Compare raw values before casting to long, as the original
            # tested the label difference for non-zero.
            boundary = a != b
            sp1 = a[boundary].long().to(n_aff.device)
            sp2 = b[boundary].long().to(n_aff.device)
            n_aff[sp1, sp2] = 1
            n_aff[sp2, sp1] = 1
        edge_indices.append(torch.stack(torch.nonzero(n_aff, as_tuple=True)))
    return edge_indices
def draw_color_seg(seg):
    """Render a batch of label maps as color images.

    Args:
        seg: tensor whose first dimension indexes samples; each sample is
            moved to the CPU, squeezed, and colored with ``colorEncode``.

    Returns:
        float32 tensor of shape (B, 3, H, W) with values scaled into [0, 1].
    """
    seg_np = seg.detach().cpu().numpy()
    images = [
        torch.from_numpy(colorEncode(seg_np[k].squeeze()) / 255.0)
        .float()
        .permute(2, 0, 1)
        for k in range(seg_np.shape[0])
    ]
    return torch.stack(images)
|