# Path-index construction and edge-guided random-walk affinity propagation.
import torch
import torch.nn.functional as F
import numpy as np
class PathIndex:
    """Precomputes, for a given radius and image size, the pixel-index paths
    connecting each pixel to its neighbours within that radius.

    Attributes:
        search_paths: list of arrays grouped by path length; each array has
            shape (num_paths, path_len, 2) holding (y, x) offsets, sorted so
            the path endpoint comes first.
        search_dst: (total_paths, 2) array of endpoint offsets.
        path_indices / src_indices / dst_indices: flat pixel indices into an
            image of `default_size`, cropped so every path stays in bounds.
    """

    def __init__(self, radius, default_size):
        self.radius = radius
        self.radius_floor = int(np.ceil(radius) - 1)
        self.search_paths, self.search_dst = self.get_search_paths_dst(self.radius)
        self.path_indices, self.src_indices, self.dst_indices = self.get_path_indices(default_size)

    def get_search_paths_dst(self, max_radius=5):
        """Enumerate displacement vectors within `max_radius` (upper half-plane
        plus the positive x-axis) and rasterize the straight-line path to each."""
        directions = [(0, dx) for dx in range(1, max_radius)]
        for dy in range(1, max_radius):
            directions.extend(
                (dy, dx)
                for dx in range(-max_radius + 1, max_radius)
                if dx * dx + dy * dy < max_radius ** 2
            )

        # Bucket rasterized paths by their length so each bucket stacks into
        # one rectangular array.
        grouped = [[] for _ in range(max_radius * 4)]
        for dy, dx in directions:
            norm_sq = dy * dy + dx * dx
            y_lo, y_hi = sorted((0, dy))
            x_lo, x_hi = sorted((0, dx))
            # Keep cells whose squared perpendicular distance to the segment
            # (0,0)->(dy,dx) is below 1 (a thick digital line).
            coords = [
                [y, x]
                for y in range(y_lo, y_hi + 1)
                for x in range(x_lo, x_hi + 1)
                if (dy * x - dx * y) ** 2 / norm_sq < 1
            ]
            # Farthest cell first, so index 0 is always the path endpoint.
            coords.sort(key=lambda c: -abs(c[0]) - abs(c[1]))
            grouped[len(coords)].append(coords)

        paths_by_length = [np.asarray(bucket) for bucket in grouped if bucket]
        destinations = np.concatenate([p[:, 0] for p in paths_by_length], axis=0)
        return paths_by_length, destinations

    def get_path_indices(self, size):
        """Translate the (y, x) path offsets into flat pixel indices for an
        image of shape `size`, cropping so shifted windows never run off-grid."""
        grid = np.arange(size[0] * size[1], dtype=np.int64).reshape(size[0], size[1])
        rf = self.radius_floor
        crop_h = size[0] - rf
        crop_w = size[1] - 2 * rf

        path_indices = []
        for group in self.search_paths:
            per_path = []
            for path in group:
                per_step = [
                    grid[dy:dy + crop_h, rf + dx:rf + dx + crop_w].reshape(-1)
                    for dy, dx in path
                ]
                per_path.append(per_step)
            path_indices.append(np.array(per_path))

        src_indices = grid[:crop_h, rf:rf + crop_w].reshape(-1)
        dst_indices = np.concatenate([group[:, 0] for group in path_indices], axis=0)
        return path_indices, src_indices, dst_indices
def edge_to_affinity(edge, paths_indices):
    """Convert an edge map into pairwise affinities along precomputed paths.

    The affinity between a pixel pair is 1 minus the strongest edge response
    anywhere on the path connecting them: a single strong edge on the path
    kills the affinity.

    Args:
        edge: tensor of shape (batch, H, W) (or anything flattenable to
            (batch, H*W)) with edge strengths, presumably in [0, 1].
        paths_indices: list of index arrays/tensors, each of shape
            (num_paths, path_len, num_positions). NumPy entries are converted
            to tensors in place (so the conversion is cached for the caller).

    Returns:
        Tensor of shape (batch, total_paths, num_positions) with affinities.
    """
    aff_list = []
    edge = edge.view(edge.size(0), -1)

    for i in range(len(paths_indices)):
        if isinstance(paths_indices[i], np.ndarray):
            paths_indices[i] = torch.from_numpy(paths_indices[i])
        # Generalized from a hard-coded .cuda(): follow the edge tensor's
        # device so the function also works on CPU.
        paths_indices[i] = paths_indices[i].to(edge.device, non_blocking=True)

    for ind in paths_indices:
        ind_flat = ind.view(-1)
        dist = torch.index_select(edge, dim=-1, index=ind_flat)
        dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2))
        # Max-pool over the path dimension picks the strongest edge crossed.
        aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2)
        aff_list.append(aff)

    aff_cat = torch.cat(aff_list, dim=1)
    return aff_cat
def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices):
    """Scatter sparse pairwise affinities into a dense symmetric matrix.

    Args:
        affinity_sparse: tensor of affinity values; flattened to match
            `ind_from` repeated over `ind_to`'s first dimension.
        ind_from: numpy array of source vertex indices (length num_positions).
        ind_to: numpy array of destination vertex indices, shape
            (num_paths, num_positions).
        n_vertices: total number of graph vertices (matrix side length).

    Returns:
        Dense (n_vertices, n_vertices) tensor, symmetric with unit diagonal,
        on the same device `affinity_sparse` came in on.
    """
    # Remember the caller's device; previously this hard-coded .cuda(),
    # which made CPU-only use impossible.
    device = affinity_sparse.device
    ind_from = torch.from_numpy(ind_from)
    ind_to = torch.from_numpy(ind_to)

    # Sparse construction happens on CPU.
    affinity_sparse = affinity_sparse.view(-1).cpu()
    ind_from = ind_from.repeat(ind_to.size(0)).view(-1)
    ind_to = ind_to.view(-1)

    indices = torch.stack([ind_from, ind_to])
    indices_tp = torch.stack([ind_to, ind_from])  # transpose entries -> symmetry
    diag = torch.arange(0, n_vertices).long()
    indices_id = torch.stack([diag, diag])  # self-affinity of 1 on the diagonal

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor;
    # the explicit size avoids relying on the max index for the shape.
    affinity_dense = torch.sparse_coo_tensor(
        torch.cat([indices, indices_id, indices_tp], dim=1),
        torch.cat([affinity_sparse, torch.ones([n_vertices]), affinity_sparse]),
        size=(n_vertices, n_vertices),
    ).to_dense().to(device)

    return affinity_dense
def to_transition_matrix(affinity_dense, beta, times):
    """Turn a dense affinity matrix into a random-walk transition matrix.

    Affinities are sharpened with exponent `beta`, columns are normalized to
    sum to 1, and the matrix is squared `times` times — i.e. the result is
    the transition matrix raised to the power 2**times.
    """
    sharpened = affinity_dense ** beta
    trans_mat = sharpened / sharpened.sum(dim=0, keepdim=True)
    for _ in range(times):
        trans_mat = trans_mat @ trans_mat
    return trans_mat
def propagate_to_edge(x, edge, radius=5, beta=10, exp_times=8):
    """Propagate activation map `x` by a random walk whose transitions are
    blocked by the edge map `edge`.

    Returns a tensor of shape (channels, 1, height, width).
    """
    height, width = x.shape[-2:]
    padded_w = width + 2 * radius
    padded_h = height + radius

    path_index = PathIndex(radius=radius, default_size=(padded_h, padded_w))

    # Pad with edge value 1.0 so the walk cannot leak across the border.
    edge_padded = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0)
    sparse_aff = edge_to_affinity(torch.unsqueeze(edge_padded, 0),
                                  path_index.path_indices)
    dense_aff = affinity_sparse2dense(sparse_aff, path_index.src_indices,
                                      path_index.dst_indices, padded_h * padded_w)

    # Crop the padding back out of both vertex dimensions, then flatten.
    dense_aff = dense_aff.view(padded_h, padded_w, padded_h, padded_w)
    dense_aff = dense_aff[:-radius, radius:-radius, :-radius, radius:-radius]
    dense_aff = dense_aff.reshape(height * width, height * width)

    trans_mat = to_transition_matrix(dense_aff, beta=beta, times=exp_times)

    # Suppress activations on edges before walking.
    x = x.view(-1, height, width) * (1 - edge)
    rw = torch.matmul(x.view(-1, height * width), trans_mat)
    return rw.view(rw.size(0), 1, height, width)
class GetAffinityLabelFromIndices():
    """Derive affinity supervision labels from a segmentation map.

    For precomputed (from, to) pixel-index pairs, classifies each pair as
    background-positive (same label, background), foreground-positive (same
    label, foreground), or negative (different labels). Pairs where either
    end carries a label >= n_classes (e.g. the 255 'ignore' label) are
    dropped from all three masks.
    """

    def __init__(self, indices_from, indices_to, n_classes=21):
        """
        Args:
            indices_from: flat source pixel indices.
            indices_to: flat destination pixel indices (any shape).
            n_classes: number of valid class labels; generalizes the
                previously hard-coded 21 (Pascal VOC: background + 20 classes).
        """
        self.indices_from = indices_from
        self.indices_to = indices_to
        self.n_classes = n_classes

    def __call__(self, segm_map):
        segm_map_flat = np.reshape(segm_map, -1)

        segm_label_from = np.expand_dims(segm_map_flat[self.indices_from], axis=0)
        segm_label_to = segm_map_flat[self.indices_to]

        # A pair counts only when both ends carry a real class label.
        valid_label = np.logical_and(np.less(segm_label_from, self.n_classes),
                                     np.less(segm_label_to, self.n_classes))
        equal_label = np.equal(segm_label_from, segm_label_to)

        pos_affinity_label = np.logical_and(equal_label, valid_label)
        # Class 0 is background; everything above is foreground.
        bg_pos_affinity_label = np.logical_and(pos_affinity_label,
                                               np.equal(segm_label_from, 0)).astype(np.float32)
        fg_pos_affinity_label = np.logical_and(pos_affinity_label,
                                               np.greater(segm_label_from, 0)).astype(np.float32)
        neg_affinity_label = np.logical_and(np.logical_not(equal_label),
                                            valid_label).astype(np.float32)

        return (torch.from_numpy(bg_pos_affinity_label),
                torch.from_numpy(fg_pos_affinity_label),
                torch.from_numpy(neg_affinity_label))