Spaces:
Runtime error
Runtime error
File size: 7,561 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import math
import torch
from .pair_wise_distance import PairwiseDistFunction
from .sparse_utils import naive_sparse_bmm
def calc_init_centroid(images, num_spixels_width, num_spixels_height):
    """
    Build the initial superpixel centroids and the initial pixel labels.

    The image plane is split into a regular grid with
    ``num_spixels_height`` rows and ``num_spixels_width`` columns. Each
    centroid is the adaptive-average-pooled feature of its grid cell, and
    the label map is obtained by nearest-neighbour upsampling of the
    row-major grid-cell indices back to the image resolution.

    Args:
        images: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels_width: int
            number of superpixels along the horizontal axis (grid columns)
        num_spixels_height: int
            number of superpixels along the vertical axis (grid rows)
    Return:
        centroids: torch.Tensor
            A Tensor of shape (B, C, num_spixels_height * num_spixels_width);
            it stays differentiable with respect to ``images``
        init_label_map: torch.Tensor
            A Tensor of shape (B, H * W) with each pixel's initial superpixel
            index, stored in the floating dtype of ``images``
            (NOTE(review): half/bfloat16 cannot represent large integer
            indices exactly -- confirm this is never hit in practice)
    """
    nn_f = torch.nn.functional
    batch, chans, img_h, img_w = images.shape
    grid_shape = (num_spixels_height, num_spixels_width)

    # Mean feature of every grid cell. Computed outside no_grad so that
    # gradients can flow back to ``images``.
    pooled = nn_f.adaptive_avg_pool2d(images, grid_shape)

    with torch.no_grad():
        # One id per grid cell, laid out on the (rows, cols) grid.
        cell_ids = torch.arange(num_spixels_width * num_spixels_height, device=images.device)
        cell_ids = cell_ids.reshape(1, 1, *pooled.shape[-2:]).type_as(pooled)
        # Nearest-neighbour upsampling gives each pixel the id of a grid cell.
        label_map = nn_f.interpolate(cell_ids, size=(img_h, img_w), mode="nearest")
        label_map = label_map.repeat(batch, 1, 1, 1).reshape(batch, -1)

    return pooled.reshape(batch, chans, -1), label_map
@torch.no_grad()
def get_abs_indices(init_label_map, num_spixels_width):
    """
    Build COO indices linking every pixel to its 9 candidate superpixels.

    The candidates are the pixel's initial superpixel plus the 3x3 grid
    neighbourhood around it. Row/column offsets of -1/0/+1 are expressed as
    flat-index offsets using ``num_spixels_width``.

    Args:
        init_label_map: torch.Tensor
            A Tensor of shape (B, N) with each pixel's initial superpixel index
        num_spixels_width: int
            number of superpixels along the horizontal axis of the grid
    Return:
        int64 Tensor of shape (3, B * 9 * N). Its rows are
        (batch index, superpixel index, pixel index), enumerated
        batch-major, then neighbour slot, then pixel.

    NOTE(review): superpixel indices are not range-checked here. They can be
    negative or past the last superpixel, and a -1/+1 column offset at a grid
    edge lands in the adjacent row. The iteration functions in this file only
    mask the out-of-range ids.
    """
    batch, n_pixel = init_label_map.shape
    dev = init_label_map.device

    # Flat offsets of the 3x3 neighbourhood: row above, same row, row below.
    steps = torch.arange(-1, 2.0, device=dev)
    neighbour_offsets = torch.cat(
        [steps + shift for shift in (-num_spixels_width, 0, num_spixels_width)], 0)
    n_neigh = neighbour_offsets.numel()

    spix_idx = init_label_map.unsqueeze(1) + neighbour_offsets.view(1, -1, 1)
    pix_idx = torch.arange(n_pixel, device=dev).repeat(batch * n_neigh)
    batch_idx = torch.arange(batch, device=dev).repeat_interleave(n_neigh * n_pixel)

    return torch.stack(
        [batch_idx.long(), spix_idx.reshape(-1).long(), pix_idx.long()], 0)
@torch.no_grad()
def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width):
    """
    Turn soft neighbour affinities into absolute hard superpixel labels.

    Args:
        affinity_matrix: torch.Tensor
            Affinities, expected shape (B, 9, N). Dim 1 indexes the 3x3
            neighbour slots in the order row above, same row, row below,
            with columns -1/0/+1 inside each row.
        init_label_map: torch.Tensor
            A Tensor of shape (B, N) with each pixel's initial superpixel index
        num_spixels_width: int
            number of superpixels along the horizontal axis of the grid
    Return:
        int64 Tensor of shape (B, N): the initial label plus the flat offset
        of the neighbour slot with the highest affinity.
    """
    best_slot = affinity_matrix.max(dim=1).indices
    steps = torch.arange(-1, 2.0, device=affinity_matrix.device)
    slot_offsets = torch.cat(
        [steps + shift for shift in (-num_spixels_width, 0, num_spixels_width)], 0)
    return (init_label_map + slot_offsets[best_slot]).long()
@torch.no_grad()
def sparse_ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations with sparse matrix
    detailed process is in Algorithm 1, line 2 - 6
    NOTE: this function does NOT guarantee the backward computation
    (it runs entirely under ``torch.no_grad``).
    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            Requested number of superpixels. The actual grid is
            int(sqrt(num_spixels * W / H)) columns by
            int(sqrt(num_spixels * H / W)) rows, which can be fewer cells
            than ``num_spixels``.
        n_iter: int
            A number of iterations; must be >= 1
    Return:
        sparse_abs_affinity: torch.Tensor
            sparse COO Tensor of shape (B, S, H * W) holding the soft
            pixel-to-superpixel assignment from the last iteration. S is
            inferred by ``torch.sparse_coo_tensor`` from the largest kept
            superpixel index; it equals ``num_spixels`` when ``num_spixels``
            is exactly the grid size.
        hard_labels: torch.Tensor
            int64 Tensor of shape (B, H * W) with the hard assignment
        spixel_features: torch.Tensor
            A Tensor of shape (B, C, S)
    Raises:
        ValueError: if ``n_iter`` < 1 (no affinity would ever be computed)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # Drop candidate (batch, superpixel, pixel) triples whose superpixel id
    # is out of range. This depends only on the fixed initial labels, so it is
    # computed once rather than on every iteration.
    # NOTE(review): the upper bound is the requested ``num_spixels``, not the
    # actual grid size num_spixels_width * num_spixels_height -- confirm intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)  # (B, C, N)
    permuted_pixel_features = pixel_features.permute(0, 2, 1)  # (B, N, C)

    for _ in range(n_iter):
        # distances from each pixel to its 9 candidate superpixels
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)
        affinity_matrix = (-dist_matrix).softmax(1)

        sparse_abs_affinity = torch.sparse_coo_tensor(
            valid_indices, affinity_matrix.reshape(-1)[mask])
        # Affinity-weighted mean of the pixel features for each superpixel.
        spixel_features = naive_sparse_bmm(sparse_abs_affinity, permuted_pixel_features) \
            / (torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16)
        spixel_features = spixel_features.permute(0, 2, 1)

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)
    return sparse_abs_affinity, hard_labels, spixel_features
def ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations
    detailed process is in Algorithm 1, line 2 - 6
    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            Requested number of superpixels. The actual grid is
            int(sqrt(num_spixels * W / H)) columns by
            int(sqrt(num_spixels * H / W)) rows, which can be fewer cells
            than ``num_spixels``.
        n_iter: int
            A number of iterations; must be >= 1
    Return:
        abs_affinity: torch.Tensor
            dense Tensor of shape (B, S, H * W) holding the soft
            pixel-to-superpixel assignment from the last iteration. S is
            inferred by ``torch.sparse_coo_tensor`` from the largest kept
            superpixel index; it equals ``num_spixels`` when ``num_spixels``
            is exactly the grid size.
        hard_labels: torch.Tensor
            int64 Tensor of shape (B, H * W) with the hard assignment
        spixel_features: torch.Tensor
            A Tensor of shape (B, C, S)
    Raises:
        ValueError: if ``n_iter`` < 1 (no affinity would ever be computed)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    # spixel_features: (B, C, S); init_label_map: (B, N) with N = H * W
    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    # (batch, superpixel, pixel) indices of each pixel's 9 grid neighbours
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # Drop candidates whose superpixel id is out of range. This depends only
    # on the fixed initial labels, so it is computed once, not per iteration.
    # NOTE(review): the upper bound is the requested ``num_spixels``, not the
    # actual grid size num_spixels_width * num_spixels_height -- confirm intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)  # (B, C, N)
    permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous()  # (B, N, C)

    for _ in range(n_iter):
        # distances from each pixel to its 9 candidate superpixels
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)
        affinity_matrix = (-dist_matrix).softmax(1)

        # Scatter the neighbour affinities into a (B, S, N) matrix.
        sparse_abs_affinity = torch.sparse_coo_tensor(
            valid_indices, affinity_matrix.reshape(-1)[mask])
        abs_affinity = sparse_abs_affinity.to_dense().contiguous()
        # Affinity-weighted mean of the pixel features for each superpixel.
        spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \
            / (abs_affinity.sum(2, keepdim=True) + 1e-16)
        spixel_features = spixel_features.permute(0, 2, 1).contiguous()  # (B, C, S)

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)
    return abs_affinity, hard_labels, spixel_features
def ssn_iter2(pixel_features, num_spixels, n_iter, init_spixel_features, temp=1):
    """
    computing assignment iterations for second layer

    Every pixel is compared against every superpixel (dense ``torch.cdist``),
    so no spatial grid or neighbourhood restriction is used here.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, N)
        num_spixels: int
            A number of superpixels. Not used by the computation (S is taken
            from ``init_spixel_features``); kept for interface compatibility.
        n_iter: int
            A number of iterations; must be >= 1
        init_spixel_features: torch.Tensor
            A Tensor of shape (B, C, S)
        temp: float
            Scale multiplied into the distances before the softmax. For
            positive values, larger ``temp`` gives harder assignments (it
            acts as an inverse temperature).
    Return:
        aff: torch.Tensor
            soft assignment of shape (B, S, N); each column sums to 1
        hard_labels: torch.Tensor
            int64 Tensor of shape (B, N), the argmax of ``aff`` over S
        spixel_features: torch.Tensor
            A Tensor of shape (B, S, C). NOTE: this is transposed relative
            to ``init_spixel_features``.
    Raises:
        ValueError: if ``n_iter`` < 1 (no assignment would ever be computed)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    spixel_features = init_spixel_features.permute(0, 2, 1)  # (B, S, C)
    pixel_features = pixel_features.permute(0, 2, 1)  # (B, N, C)
    for _ in range(n_iter):
        # distance from every pixel to every superpixel centre
        dist = torch.cdist(pixel_features, spixel_features)  # (B, N, S)
        aff = (-dist * temp).softmax(-1).permute(0, 2, 1)  # (B, S, N)
        # new centres: affinity-weighted mean of the pixel features
        spixel_features = torch.bmm(aff, pixel_features) / (aff.sum(2, keepdim=True) + 1e-6)  # (B, S, C)
    hard_labels = torch.argmax(aff, dim=1)
    return aff, hard_labels, spixel_features
|