WSSS_ResNet50 / core /aff_utils.py
kittendev's picture
Upload 176 files
c20a1af verified
import torch
import torch.nn.functional as F
import numpy as np
class PathIndex:
    """Precomputed pixel paths used to turn a boundary map into affinities.

    For every displacement inside a disc of the given radius (right half-plane
    plus the lower half, so each unordered pixel pair is visited once) the
    class stores the lattice cells lying close to the straight segment from
    the origin to that displacement, and the flat pixel indices those cells
    map to on a grid of ``default_size``.

    Attributes:
        radius: the search radius passed in.
        radius_floor: ``ceil(radius) - 1``; margin cropped from the left,
            right and bottom of the grid so every path stays in bounds.
        search_paths: list of arrays, one per distinct path length, each of
            shape ``(n_paths, path_length, 2)`` holding ``(dy, dx)`` offsets.
        search_dst: array ``(total_paths, 2)`` of path end offsets.
        path_indices: list of arrays ``(n_paths, path_length, n_src)`` of
            flat grid indices, aligned with ``search_paths``.
        src_indices: flat indices ``(n_src,)`` of the source pixels.
        dst_indices: flat indices ``(total_paths, n_src)`` of path end pixels.
    """

    def __init__(self, radius, default_size):
        self.radius = radius
        self.radius_floor = int(np.ceil(radius) - 1)

        paths, destinations = self.get_search_paths_dst(self.radius)
        self.search_paths = paths
        self.search_dst = destinations

        # get_path_indices reads self.search_paths, so it must run last.
        indices, sources, targets = self.get_path_indices(default_size)
        self.path_indices = indices
        self.src_indices = sources
        self.dst_indices = targets

    def get_search_paths_dst(self, max_radius=5):
        """Enumerate search displacements and the cells along each path.

        Returns:
            A pair ``(paths_by_length, destinations)``: the first is a list
            of ``(n_paths, path_length, 2)`` arrays (empty lengths skipped),
            the second stacks the first cell of every path, i.e. its far end.
        """
        # One bucket per possible path length; empty buckets are dropped below.
        buckets = [[] for _ in range(max_radius * 4)]

        # Purely horizontal moves first, then every move with dy >= 1 that
        # stays strictly inside the disc.
        offsets = [(0, col) for col in range(1, max_radius)]
        offsets.extend(
            (row, col)
            for row in range(1, max_radius)
            for col in range(-max_radius + 1, max_radius)
            if col * col + row * row < max_radius ** 2
        )

        for off_y, off_x in offsets:
            norm_sq = off_y ** 2 + off_x ** 2
            lo_y, hi_y = sorted((0, off_y))
            lo_x, hi_x = sorted((0, off_x))

            # Keep bounding-box cells whose squared perpendicular distance to
            # the origin->offset line is below 1.
            cells = [
                [cy, cx]
                for cy in range(lo_y, hi_y + 1)
                for cx in range(lo_x, hi_x + 1)
                if (off_y * cx - off_x * cy) ** 2 / norm_sq < 1
            ]

            # Farthest (L1) cell first, so index 0 is the path destination.
            cells.sort(key=lambda cell: -abs(cell[0]) - abs(cell[1]))
            buckets[len(cells)].append(cells)

        grouped = [np.asarray(bucket) for bucket in buckets if bucket]
        endpoints = np.concatenate([group[:, 0] for group in grouped], axis=0)
        return grouped, endpoints

    def get_path_indices(self, size):
        """Map every path cell to flat pixel indices on a ``size`` grid.

        Args:
            size: ``(height, width)`` of the grid (only the first two entries
                are read).

        Returns:
            ``(path_indices, src_indices, dst_indices)`` as described on the
            class.
        """
        rows, cols = size[0], size[1]
        margin = self.radius_floor
        grid = np.reshape(np.arange(0, rows * cols, dtype=np.int64), (rows, cols))

        # Source window: the bottom margin and both side margins are cropped.
        win_h = rows - margin
        win_w = cols - 2 * margin

        def shifted_window(dy, dx):
            # Flat indices of the source window translated by (dy, dx).
            left = margin + dx
            return np.reshape(grid[dy:dy + win_h, left:left + win_w], [-1])

        per_group = [
            np.array([[shifted_window(dy, dx) for dy, dx in path] for path in group])
            for group in self.search_paths
        ]

        sources = np.reshape(grid[:win_h, margin:margin + win_w], -1)
        targets = np.concatenate([group[:, 0] for group in per_group], axis=0)
        return per_group, sources, targets
def edge_to_affinity(edge, paths_indices):
    """Convert a boundary (edge) map into per-path affinities.

    The affinity of a path is ``1 - max(edge)`` over the pixels on that path.

    Args:
        edge: tensor whose first dimension is the batch; all remaining
            dimensions are flattened into one pixel axis.
        paths_indices: list of index arrays/tensors, each of shape
            ``(n_paths, path_length, n_pixels)``, holding flat pixel indices
            into the flattened ``edge``. NumPy arrays are converted to
            tensors, and every entry is moved to ``edge``'s device. The
            converted tensors are written back into this list so repeated
            calls skip the conversion.

    Returns:
        Tensor of shape ``(batch, total_paths, n_pixels)`` where
        ``total_paths`` is the sum of ``n_paths`` over all entries.
    """
    edge = edge.view(edge.size(0), -1)
    # Follow the input's device instead of assuming CUDA, so the function
    # also works on CPU-only hosts and on non-default GPUs.
    device = edge.device

    for i, ind in enumerate(paths_indices):
        if isinstance(ind, np.ndarray):
            ind = torch.from_numpy(ind)
        paths_indices[i] = ind.to(device, non_blocking=True)

    aff_list = []
    for ind in paths_indices:
        n_paths, path_length, n_pixels = ind.size(0), ind.size(1), ind.size(2)
        dist = torch.index_select(edge, dim=-1, index=ind.reshape(-1))
        dist = dist.view(dist.size(0), n_paths, path_length, n_pixels)
        # Pooling over the whole path-length axis takes the strongest edge
        # response met along each path.
        strongest = F.max_pool2d(dist, (path_length, 1))
        aff_list.append(torch.squeeze(1 - strongest, dim=2))

    return torch.cat(aff_list, dim=1)
def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices):
    """Scatter per-path affinities into a dense symmetric matrix.

    Args:
        affinity_sparse: tensor with ``ind_to.numel()`` values, ordered like
            ``ind_to`` flattened (one row per path, one column per source).
        ind_from: integer array ``(n_src,)`` of source vertex indices.
        ind_to: integer array ``(n_paths, n_src)`` of destination vertex
            indices for every path/source pair.
        n_vertices: number of graph vertices; all indices must be smaller.

    Returns:
        Float32 tensor ``(n_vertices, n_vertices)`` on the same device as
        ``affinity_sparse``. Entry ``(i, j)`` and ``(j, i)`` both receive the
        affinity of the pair, and the diagonal is 1. Values for repeated
        index pairs are summed.
    """
    # Return on the caller's device rather than hard-coding CUDA.
    device = affinity_sparse.device

    ind_from = torch.as_tensor(ind_from, dtype=torch.long, device='cpu')
    ind_to = torch.as_tensor(ind_to, dtype=torch.long, device='cpu')

    # The matrix is assembled on the CPU and moved once at the end.
    values = affinity_sparse.reshape(-1).float().cpu()
    # Repeat the source indices once per path so they line up with ind_to.
    ind_from = ind_from.repeat(ind_to.size(0)).reshape(-1)
    ind_to = ind_to.reshape(-1)

    diagonal = torch.arange(n_vertices, dtype=torch.long)
    indices = torch.cat(
        [
            torch.stack([ind_from, ind_to]),
            torch.stack([diagonal, diagonal]),
            torch.stack([ind_to, ind_from]),
        ],
        dim=1,
    )
    all_values = torch.cat([values, torch.ones([n_vertices]), values])

    # sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    # constructor; the explicit size pins the result to n_vertices squared.
    affinity_dense = torch.sparse_coo_tensor(
        indices, all_values, size=(n_vertices, n_vertices)
    ).to_dense()

    return affinity_dense.to(device)
def to_transition_matrix(affinity_dense, beta, times):
    """Build a random-walk transition matrix from a dense affinity matrix.

    The affinities are raised element-wise to the power ``beta``, each column
    is divided by its sum, and the resulting matrix is then squared ``times``
    times (i.e. raised to the power ``2 ** times``).

    Args:
        affinity_dense: square affinity tensor.
        beta: element-wise exponent applied before normalisation.
        times: number of repeated squarings.

    Returns:
        The transition matrix after the repeated squaring.
    """
    sharpened = affinity_dense ** beta
    column_mass = sharpened.sum(dim=0, keepdim=True)
    transition = sharpened / column_mass

    for _ in range(times):
        transition = transition @ transition

    return transition
def propagate_to_edge(x, edge, radius=5, beta=10, exp_times=8):
    """Propagate ``x`` by a random walk whose transitions avoid edges.

    Args:
        x: tensor whose last two dimensions are ``(height, width)``; all
            leading dimensions are merged into one.
        edge: boundary map multiplied against ``x`` and padded on its last
            two dimensions -- presumably shaped ``(height, width)``; confirm
            against the caller.
        radius: search radius for pixel pairs, also used as padding width.
        beta: exponent passed to ``to_transition_matrix``.
        exp_times: number of squarings passed to ``to_transition_matrix``.

    Returns:
        Tensor of shape ``(N, 1, height, width)`` where ``N`` is the product
        of ``x``'s leading dimensions.
    """
    height, width = x.shape[-2:]
    n_pixels = height * width

    # The map is padded left/right/bottom so paths that leave the image land
    # on padding pixels instead of wrapping around.
    padded_h = height + radius
    padded_w = width + radius * 2
    index = PathIndex(radius=radius, default_size=(padded_h, padded_w))

    # Padding with 1.0 gives paths through the padding an affinity of 0.
    padded_edge = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0)
    path_affinity = edge_to_affinity(padded_edge.unsqueeze(0), index.path_indices)

    affinity = affinity_sparse2dense(
        path_affinity, index.src_indices, index.dst_indices, padded_h * padded_w
    )

    # Drop rows/columns that belong to padding pixels, on both matrix axes.
    keep_rows = slice(None, -radius)
    keep_cols = slice(radius, -radius)
    affinity = affinity.view(padded_h, padded_w, padded_h, padded_w)
    affinity = affinity[keep_rows, keep_cols, keep_rows, keep_cols]
    affinity = affinity.reshape(n_pixels, n_pixels)

    transition = to_transition_matrix(affinity, beta=beta, times=exp_times)

    # Suppress activations that sit on edges before walking.
    masked = x.view(-1, height, width) * (1 - edge)
    walked = torch.matmul(masked.view(-1, n_pixels), transition)
    return walked.view(walked.size(0), 1, height, width)
class GetAffinityLabelFromIndices:
    """Derive pairwise affinity labels from a segmentation label map.

    Args:
        indices_from: integer array ``(n_src,)`` of flat source pixel indices.
        indices_to: integer array ``(n_paths, n_src)`` of flat destination
            pixel indices paired with each source pixel.
        num_classes: labels ``>= num_classes`` are treated as invalid (for
            example an ignore label such as 255) and never contribute to any
            of the returned masks. Defaults to 21, the value that used to be
            hard-coded here.
    """

    def __init__(self, indices_from, indices_to, num_classes=21):
        self.indices_from = indices_from
        self.indices_to = indices_to
        self.num_classes = num_classes

    def __call__(self, segm_map):
        """Return ``(bg_pos, fg_pos, neg)`` float32 masks for ``segm_map``.

        Each mask has shape ``(n_paths, n_src)``:
          * ``bg_pos`` -- both pixels valid, same label, and that label is 0;
          * ``fg_pos`` -- both pixels valid, same label, and that label > 0;
          * ``neg``    -- both pixels valid and their labels differ.
        """
        flat = np.asarray(segm_map).reshape(-1)

        # Leading axis added so the source labels broadcast over paths.
        label_from = np.expand_dims(flat[self.indices_from], axis=0)
        label_to = flat[self.indices_to]

        valid = (label_from < self.num_classes) & (label_to < self.num_classes)
        same = label_from == label_to
        positive = same & valid

        bg_pos = (positive & (label_from == 0)).astype(np.float32)
        fg_pos = (positive & (label_from > 0)).astype(np.float32)
        neg = (~same & valid).astype(np.float32)

        return torch.from_numpy(bg_pos), torch.from_numpy(fg_pos), torch.from_numpy(neg)