import torch
import torch.nn.functional as F

import numpy as np
from scipy.io import loadmat

def init_spixel_grid(args, b_train=True, ratio=1, downsize=16):
    # Build a per-pixel (x, y) coordinate grid for a crop_size x crop_size image.
    curr_img_height = args.crop_size
    curr_img_width = args.crop_size

    # pixel coord
    all_h_coords = np.arange(0, curr_img_height, 1)
    all_w_coords = np.arange(0, curr_img_width, 1)
    curr_pxl_coord = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij'))

    # reorder so channel 0 is the x (width) index and channel 1 is the y (height) index
    coord_tensor = np.concatenate([curr_pxl_coord[1:2, :, :], curr_pxl_coord[:1, :, :]])

    # shape: (1, 2, H, W)
    all_XY_feat = (torch.from_numpy(
        np.tile(coord_tensor, (1, 1, 1, 1)).astype(np.float32)).cuda())

    return all_XY_feat
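
# Usage sketch (illustration, not part of the original repo): init_spixel_grid only
# reads `args.crop_size` in this version, so a SimpleNamespace stands in for the real
# argparse namespace. Requires a CUDA device, since the grid is moved to the GPU.
def _example_init_spixel_grid():
    from types import SimpleNamespace
    args = SimpleNamespace(crop_size=208)
    xy_feat = init_spixel_grid(args)
    # xy_feat holds (x, y) pixel coordinates with shape (1, 2, 208, 208)
    assert xy_feat.shape == (1, 2, 208, 208)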

def label2one_hot_torch(labels, C=14):
    """Converts an integer label tensor to a one-hot tensor.

    Args:
      labels (tensor): segmentation label
      C (int): number of classes in labels

    Returns:
      target (tensor): one-hot encoding of the input label

    Shape:
      labels: (B, 1, H, W)
      target: (B, C, H, W)
    """
    b, _, h, w = labels.shape
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long, device=labels.device)
    target = one_hot.scatter_(1, labels.long(), 1)  # scatter_ requires a long index

    return target.type(torch.float32)
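
# Minimal sketch (not part of the original file) showing the expected input/output
# shapes of label2one_hot_torch on a toy 2 x 2 label map with C=3 classes.
def _example_label2one_hot():
    labels = torch.tensor([[[[0, 1],
                             [2, 1]]]])           # (B=1, 1, H=2, W=2)
    one_hot = label2one_hot_torch(labels, C=3)    # (1, 3, 2, 2), float32
    assert one_hot.shape == (1, 3, 2, 2)
    assert one_hot[0, 1, 0, 1] == 1               # pixel (0, 1) carries class 1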

# 150-entry color palette, tiled so that label ids above 150 still index a valid color
colors = loadmat('data/color150.mat')['colors']
colors = np.concatenate((colors, colors, colors, colors))

def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    ar = np.asanyarray(ar).flatten()

    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            ret = ar
        else:
            ret = (ar,)
            if return_index:
                ret += (np.empty(0, np.intp),)
            if return_inverse:
                ret += (np.empty(0, np.intp),)
            if return_counts:
                ret += (np.empty(0, np.intp),)
        return ret
    if optional_indices:
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()
        aux = ar
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        ret = aux[flag]
    else:
        ret = (aux[flag],)
        if return_index:
            ret += (perm[flag],)
        if return_inverse:
            iflag = np.cumsum(flag) - 1
            inv_idx = np.empty(ar.shape, dtype=np.intp)
            inv_idx[perm] = iflag
            ret += (inv_idx,)
        if return_counts:
            idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
            ret += (np.diff(idx),)
    return ret
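
# Small check (illustration only, not from the original repo): this local unique
# mirrors np.unique, including the optional index/inverse/count returns.
def _example_unique():
    vals, counts = unique(np.array([3, 1, 3, 2, 1, 3]), return_counts=True)
    assert list(vals) == [1, 2, 3] and list(counts) == [2, 1, 3]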

def colorEncode(labelmap, mode='RGB'):
    # Map each label id in `labelmap` to its palette color and return an (H, W, 3) uint8 image.
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    for label in unique(labelmap):
        if label < 0:
            continue
        labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
            np.tile(colors[label],
                    (labelmap.shape[0], labelmap.shape[1], 1))

    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    else:
        return labelmap_rgb
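
# Quick illustration (assumed, not from the original repo): colorEncode turns an
# (H, W) integer label map into an (H, W, 3) uint8 color image using the palette above.
def _example_colorEncode():
    labelmap = np.zeros((4, 4), dtype=np.int64)
    labelmap[:, 2:] = 5                      # two regions: label 0 and label 5
    rgb = colorEncode(labelmap)              # (4, 4, 3) uint8 image
    assert rgb.shape == (4, 4, 3) and rgb.dtype == np.uint8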

def get_edges(sp_label, sp_num):
    # For each image in the batch, build an sp_num x sp_num adjacency matrix N over
    # superpixels (N[i, j] = 1 iff superpixels i and j are neighbors, 0 otherwise)
    # and return the corresponding edge index (a 2 x E tensor) per batch element.
    # Nonzero entries of the shifted differences below mark boundaries between
    # adjacent superpixels along the vertical, horizontal and diagonal directions.
    top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :]
    left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:]
    top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:]
    top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1]
    n_affs = []
    edge_indices = []
    for i in range(sp_label.shape[0]):
        # change to torch.ones below to include self-loop in graph
        n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).cuda()
        # top/bottom
        top_i = top[i].squeeze()
        x, y = torch.nonzero(top_i, as_tuple = True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x+1, y].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # left/right
        left_i = left[i].squeeze()
        x, y = torch.nonzero(left_i, as_tuple = True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x, y+1].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # top left
        top_left_i = top_left[i].squeeze()
        x, y = torch.nonzero(top_left_i, as_tuple = True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x+1, y+1].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # top right
        top_right_i = top_right[i].squeeze()
        x, y = torch.nonzero(top_right_i, as_tuple = True)
        sp1 = sp_label[i, :, x, y+1].squeeze().long()
        sp2 = sp_label[i, :, x+1, y].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        n_affs.append(n_aff)
        edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True))
        edge_indices.append(edge_index.cuda())
    return edge_indices
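
# Hedged usage sketch (not in the original repo): a 2 x 2 grid of four superpixels
# makes every pair of superpixels touch either along an edge or at the center corner,
# so the returned edge index covers all off-diagonal pairs. Requires a CUDA device.
def _example_get_edges():
    sp_label = torch.tensor([[[[0, 0, 1, 1],
                               [0, 0, 1, 1],
                               [2, 2, 3, 3],
                               [2, 2, 3, 3]]]], dtype=torch.float32).cuda()
    edge_indices = get_edges(sp_label, sp_num=4)
    # edge_indices[0] is a (2, E) tensor of symmetric superpixel adjacencies
    assert edge_indices[0].shape[0] == 2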


def draw_color_seg(seg):
    # Colorize a batch of label maps (B, 1, H, W) into RGB tensors (B, 3, H, W) in [0, 1].
    seg = seg.detach().cpu().numpy()
    color_ = []
    for i in range(seg.shape[0]):
        colori = colorEncode(seg[i].squeeze())
        colori = torch.from_numpy(colori / 255.0).float().permute(2, 0, 1)
        color_.append(colori)
    color_ = torch.stack(color_)
    return color_
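
# Illustrative use (assumed, not from the original repo): color a batch of predicted
# segmentation maps for visualization, e.g. before logging them as images.
def _example_draw_color_seg():
    seg = torch.randint(0, 14, (2, 1, 32, 32))    # (B, 1, H, W) integer labels
    color = draw_color_seg(seg)                   # (B, 3, H, W) floats in [0, 1]
    assert color.shape == (2, 3, 32, 32)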