Spaces:
Sleeping
Sleeping
File size: 6,963 Bytes
c20a1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
import torch.nn.functional as F
import numpy as np
class PathIndex:
    """Precomputed pixel paths from each source pixel to its neighbours.

    Search directions ``(dy, dx)`` cover the half-disc ``dy >= 0`` strictly
    inside ``radius`` (``dy == 0`` only with ``dx > 0``).  For every direction
    the pixels lying within distance 1 of the straight segment from the
    origin to the destination form one path.

    Attributes:
        radius: the radius passed to the constructor.
        radius_floor: largest ``|dy|`` / ``|dx|`` any direction can have,
            i.e. ``ceil(radius) - 1``.
        search_paths: list of integer arrays, one per distinct path length,
            each shaped ``(n_paths, path_length, 2)`` holding ``(dy, dx)``
            offsets; the destination offset comes first in every path.
        search_dst: ``(n_directions, 2)`` array of destination offsets, in the
            same order as the paths in ``search_paths``.
        path_indices: like ``search_paths`` but with every offset replaced by
            the flattened pixel indices it addresses for ``default_size``.
        src_indices: flattened indices of the source pixels.
        dst_indices: flattened indices of the destination pixels, one row per
            direction.
    """

    def __init__(self, radius, default_size):
        """Build the offset paths for ``radius`` and index them for ``default_size``.

        Args:
            radius: search radius; may be fractional.
            default_size: ``(height, width)`` of the grid to index.

        Raises:
            ValueError: if ``radius`` is too small to contain any direction.
        """
        self.radius = radius
        self.radius_floor = int(np.ceil(radius) - 1)
        self.search_paths, self.search_dst = self.get_search_paths_dst(self.radius)
        self.path_indices, self.src_indices, self.dst_indices = self.get_path_indices(default_size)

    def get_search_paths_dst(self, max_radius=5):
        """Enumerate search directions and the pixel path leading to each.

        Args:
            max_radius: directions must satisfy ``dx**2 + dy**2 < max_radius**2``.

        Returns:
            ``(path_list_by_length, path_destinations)``: the paths grouped by
            ascending length (one array per length) and the destination
            offsets concatenated in that same order.

        Raises:
            ValueError: if no direction fits inside ``max_radius``.
        """
        # Integer bound for range(); identical to max_radius for integer radii.
        reach = int(np.ceil(max_radius))
        radius_sq = max_radius ** 2

        # Horizontal directions first (dy == 0, dx > 0), then every dy >= 1.
        search_dirs = [(0, x) for x in range(1, reach)]
        search_dirs.extend(
            (y, x)
            for y in range(1, reach)
            for x in range(-reach + 1, reach)
            if x * x + y * y < radius_sq
        )
        if not search_dirs:
            raise ValueError(
                "radius %r is too small: no search direction fits inside it" % (max_radius,)
            )

        paths_by_length = {}
        for dir_y, dir_x in search_dirs:
            length_sq = dir_y ** 2 + dir_x ** 2
            min_y, max_y = sorted((0, dir_y))
            min_x, max_x = sorted((0, dir_x))

            # Keep the bounding-box pixels whose squared distance to the line
            # through the origin and (dir_y, dir_x) is below 1.
            path_coords = [
                [y, x]
                for y in range(min_y, max_y + 1)
                for x in range(min_x, max_x + 1)
                if (dir_y * x - dir_x * y) ** 2 / length_sq < 1
            ]
            # Farthest pixel (the destination itself) first, origin last.
            path_coords.sort(key=lambda coord: -abs(coord[0]) - abs(coord[1]))
            paths_by_length.setdefault(len(path_coords), []).append(path_coords)

        path_list_by_length = [
            np.asarray(paths_by_length[length]) for length in sorted(paths_by_length)
        ]
        path_destinations = np.concatenate([p[:, 0] for p in path_list_by_length], axis=0)
        return path_list_by_length, path_destinations

    def get_path_indices(self, size):
        """Translate the offset paths into flattened pixel indices.

        Args:
            size: ``(height, width)`` of the grid.

        Returns:
            ``(path_indices, src_indices, dst_indices)``.  ``path_indices`` has
            one array per path length shaped
            ``(n_paths, path_length, n_sources)``; ``src_indices`` is the 1-D
            array of source-pixel indices; ``dst_indices`` stacks, per
            direction, the destination index of every source pixel.
        """
        height, width = size[0], size[1]
        margin = self.radius_floor
        full_indices = np.reshape(np.arange(0, height * width, dtype=np.int64), (height, width))

        # Sources are cropped so that every offset stays inside the grid:
        # `margin` columns on both sides and `margin` rows at the bottom.
        cropped_height = height - margin
        cropped_width = width - 2 * margin

        def shifted(dy, dx):
            # Indices of the source window translated by (dy, dx).
            window = full_indices[dy:dy + cropped_height,
                                  margin + dx:margin + dx + cropped_width]
            return np.reshape(window, [-1])

        path_indices = [
            np.array([[shifted(dy, dx) for dy, dx in path] for path in paths])
            for paths in self.search_paths
        ]

        src_indices = np.reshape(
            full_indices[:cropped_height, margin:margin + cropped_width], -1
        )
        dst_indices = np.concatenate([p[:, 0] for p in path_indices], axis=0)
        return path_indices, src_indices, dst_indices
def edge_to_affinity(edge, paths_indices):
    """Turn an edge map into path affinities (1 - max edge value on each path).

    Args:
        edge: tensor whose first dimension is kept and whose remaining
            dimensions are flattened into one pixel axis.
        paths_indices: sequence of 3-D integer index arrays or tensors, each
            shaped ``(n_paths, path_length, n_sources)`` and holding indices
            into the flattened pixel axis.  NumPy arrays are converted on the
            fly; the sequence passed in is not modified.

    Returns:
        Tensor shaped ``(edge.size(0), total_paths, n_sources)`` on the same
        device as ``edge``, with the groups concatenated along dim 1 in the
        order given.
    """
    edge = edge.view(edge.size(0), -1)

    aff_list = []
    for ind in paths_indices:
        if isinstance(ind, np.ndarray):
            ind = torch.from_numpy(ind)
        # Follow the edge map's device instead of assuming CUDA, so the
        # function also works on CPU-only hosts.
        ind = ind.to(edge.device, non_blocking=True)

        dist = torch.index_select(edge, dim=-1, index=ind.reshape(-1))
        dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2))
        # Pooling over the whole path_length axis yields the strongest edge
        # response met along each path.
        aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2)
        aff_list.append(aff)

    return torch.cat(aff_list, dim=1)
def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices):
    """Scatter sparse pairwise affinities into a dense symmetric matrix.

    Args:
        affinity_sparse: tensor of affinities; flattened, it must line up with
            ``ind_to`` flattened row by row.
        ind_from: 1-D NumPy integer array of source vertex indices.
        ind_to: 2-D NumPy integer array; each row holds, for one direction,
            the destination vertex of every entry of ``ind_from``.
        n_vertices: number of vertices; the result is
            ``(n_vertices, n_vertices)``.

    Returns:
        Dense float32 tensor on the device of ``affinity_sparse`` containing
        each affinity at ``(from, to)`` and ``(to, from)`` plus ones on the
        diagonal.  Entries addressed more than once are summed.
    """
    target_device = affinity_sparse.device

    ind_from = torch.from_numpy(ind_from).long()
    ind_to = torch.from_numpy(ind_to).long()

    # The index tensors live on the CPU, so assemble everything there first.
    values = affinity_sparse.reshape(-1).cpu().float()
    # Tile the sources once per row of ind_to so both index lists align.
    ind_from = ind_from.repeat(ind_to.size(0)).view(-1)
    ind_to = ind_to.reshape(-1)

    diagonal = torch.arange(0, n_vertices).long()
    all_indices = torch.cat(
        [
            torch.stack([ind_from, ind_to]),
            torch.stack([diagonal, diagonal]),
            torch.stack([ind_to, ind_from]),
        ],
        dim=1,
    )
    all_values = torch.cat([values, torch.ones(n_vertices, dtype=values.dtype), values])

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    # constructor; the explicit size keeps the shape independent of which
    # indices happen to be present.
    affinity_dense = torch.sparse_coo_tensor(
        all_indices, all_values, size=(n_vertices, n_vertices)
    ).to_dense()
    return affinity_dense.to(target_device)
def to_transition_matrix(affinity_dense, beta, times):
    """Build a column-normalised transition matrix and square it repeatedly.

    Args:
        affinity_dense: square affinity matrix.
        beta: exponent applied element-wise before normalising.
        times: number of successive squarings; the result equals the
            normalised matrix raised to the power ``2 ** times``.

    Returns:
        The resulting transition matrix.
    """
    sharpened = affinity_dense.pow(beta)
    # Every column is divided by its own sum.
    column_totals = sharpened.sum(dim=0, keepdim=True)
    walk = sharpened / column_totals

    for _ in range(times):
        walk = walk @ walk
    return walk
def propagate_to_edge(x, edge, radius=5, beta=10, exp_times=8):
    """Spread ``x`` over the pixel grid with a random walk guided by ``edge``.

    Args:
        x: tensor whose last two dimensions are ``(height, width)``.
        edge: edge map; presumably ``(height, width)`` with values in
            ``[0, 1]`` -- verify against the caller.
        radius: neighbourhood radius handed to ``PathIndex``; also the amount
            of padding added on the left, right and bottom.
        beta: exponent forwarded to ``to_transition_matrix``.
        exp_times: number of squarings forwarded to ``to_transition_matrix``.

    Returns:
        Tensor shaped ``(n, 1, height, width)`` where ``n`` is the number of
        ``height x width`` maps contained in ``x``.
    """
    height, width = x.shape[-2:]
    n_pixels = height * width
    padded_h = height + radius
    padded_w = width + radius * 2

    path_index = PathIndex(radius=radius, default_size=(padded_h, padded_w))

    # The padding is filled with 1.0 (the value used for a full edge).
    edge_padded = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0)
    sparse_aff = edge_to_affinity(edge_padded.unsqueeze(0), path_index.path_indices)

    dense_aff = affinity_sparse2dense(
        sparse_aff,
        path_index.src_indices,
        path_index.dst_indices,
        padded_h * padded_w,
    )

    # Strip the padded rows/columns from both the source and destination axes.
    interior = (slice(None, -radius), slice(radius, -radius))
    dense_aff = dense_aff.view(padded_h, padded_w, padded_h, padded_w)[interior + interior]
    dense_aff = dense_aff.reshape(n_pixels, n_pixels)

    trans_mat = to_transition_matrix(dense_aff, beta=beta, times=exp_times)

    # Scale the input by (1 - edge) before propagating.
    masked = x.view(-1, height, width) * (1 - edge)
    rw = torch.matmul(masked.view(-1, n_pixels), trans_mat)
    return rw.view(rw.size(0), 1, height, width)
class GetAffinityLabelFromIndices():
    """Derive pairwise affinity labels from a segmentation map.

    For every (source, destination) pixel pair described by the stored
    indices, the call reports whether both pixels share the background label,
    share a foreground label, or carry different labels.
    """

    def __init__(self, indices_from, indices_to, num_classes=21):
        """Store the pixel-pair indices.

        Args:
            indices_from: flat indices of the source pixels.
            indices_to: flat indices of the destination pixels, one row per
                direction, aligned column-wise with ``indices_from``.
            num_classes: labels ``>= num_classes`` are treated as invalid and
                excluded from every output.  The default of 21 keeps the
                previous hard-coded bound (presumably the PASCAL VOC label
                count, with larger values used as an ignore marker --
                verify against the dataset).
        """
        self.indices_from = indices_from
        self.indices_to = indices_to
        self.num_classes = num_classes

    def __call__(self, segm_map):
        """Compute the affinity labels for ``segm_map``.

        Args:
            segm_map: array of integer labels; it is flattened before indexing.

        Returns:
            Tuple ``(bg_pos, fg_pos, neg)`` of float32 tensors holding 1.0
            where a pair is, respectively, same-label with label 0,
            same-label with label > 0, or different-label; pairs involving an
            invalid label are 0 everywhere.
        """
        flat_labels = np.reshape(segm_map, -1)
        # Leading axis added so the sources broadcast against every direction row.
        label_from = np.expand_dims(flat_labels[self.indices_from], axis=0)
        label_to = flat_labels[self.indices_to]

        valid = np.logical_and(
            np.less(label_from, self.num_classes),
            np.less(label_to, self.num_classes),
        )
        same = np.equal(label_from, label_to)
        positive = np.logical_and(same, valid)

        bg_pos = np.logical_and(positive, np.equal(label_from, 0)).astype(np.float32)
        fg_pos = np.logical_and(positive, np.greater(label_from, 0)).astype(np.float32)
        neg = np.logical_and(np.logical_not(same), valid).astype(np.float32)

        return torch.from_numpy(bg_pos), torch.from_numpy(fg_pos), torch.from_numpy(neg)
|