File size: 5,037 Bytes
d526dbf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
r""" Provides functions that manipulate boxes and points """
import math
import torch.nn.functional as F
import torch
class Geometry(object):
    r"""Class-level helpers for normalized keypoints and correlation maps.

    All state is stored on the class itself. Call :meth:`initialize` once
    before using any classmethod that reads ``img_size``, ``spatial_side``,
    ``grid``, ``norm_grid_x``, ``norm_grid_y`` or ``feat_idx``.
    """

    # Sentinel stored in padded (absent) keypoint slots; such entries are
    # skipped by (un)normalization and appended by ``transfer_kps``.
    _PAD_VALUE = -2

    @classmethod
    def initialize(cls, img_size):
        r"""Precompute the normalized coordinate grids for ``img_size``.

        The feature side length is ``int(img_size / 8)``. Coordinates run
        linearly from -1 to 1 along each axis.
        """
        cls.img_size = img_size
        cls.spatial_side = int(img_size / 8)
        side = cls.spatial_side

        norm_grid1d = torch.linspace(-1, 1, side)
        cols = norm_grid1d.view(1, -1)  # x varies along the last axis
        rows = norm_grid1d.view(-1, 1)  # y varies along the first axis

        # Flattened (row-major) x / y coordinate of every grid cell.
        cls.norm_grid_x = cols.repeat(side, 1).view(1, 1, -1)
        cls.norm_grid_y = rows.repeat(1, side).view(1, 1, -1)

        # grid[i, j] == (x_j, y_i). Built by explicit expansion, which matches
        # 'ij'-indexed meshgrid output with the two planes swapped, without
        # relying on torch.meshgrid's default-indexing behaviour.
        cls.grid = torch.stack(
            [cols.expand(side, side), rows.expand(side, side)], dim=-1
        )
        cls.feat_idx = torch.arange(0, side).float()

    @classmethod
    def normalize_kps(cls, kps):
        r"""Map pixel keypoints to the normalized range (centre -> 0).

        Entries equal to the padding sentinel (-2) are left untouched.
        Returns a detached copy; the input is not modified.
        """
        kps = kps.clone().detach()
        half = cls.img_size // 2
        # Compute the mask ONCE: after the shift a real coordinate may become
        # exactly -2 and must not be mistaken for padding.
        valid = kps != cls._PAD_VALUE
        kps[valid] -= half
        kps[valid] /= half
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        r"""Inverse of :meth:`normalize_kps` (normalized -> pixel units).

        Entries equal to the padding sentinel (-2) are left untouched.
        Returns a detached copy; the input is not modified.
        """
        kps = kps.clone().detach()
        half = cls.img_size // 2
        # Same reasoning as normalize_kps: scaling can yield exactly -2 for a
        # real coordinate, so the mask is fixed before any arithmetic.
        valid = kps != cls._PAD_VALUE
        kps[valid] *= half
        kps[valid] += half
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r"""kps: normalized keypoints x, y (N, 2)
        returns attentive index map(N, spatial_side, spatial_side)

        Each map weights grid cells by ``max(thres - distance, 0)`` and is
        normalized to sum to one. NOTE: if a keypoint is farther than
        ``thres`` from every grid cell the sum is zero and that map is NaN.
        """
        nkps = kps.size(0)
        kps = kps.view(nkps, 1, 1, 2)
        eps = 1e-5  # keeps the square root differentiable at zero distance

        # Broadcasting (1, S, S, 2) - (N, 1, 1, 2) -> (N, S, S, 2).
        sq_dist = (cls.grid.unsqueeze(0) - kps).pow(2).sum(dim=3)
        dist = (sq_dist + eps).pow(0.5)
        attmap = (thres - dist).clamp(min=0).view(nkps, -1)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)
        return attmap.view(nkps, cls.spatial_side, cls.spatial_side)

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        r"""Weight each row of ``corr`` by a Gaussian centred on its argmax.

        corr: 3-D tensor ``(bsz, n_src, n_trg)``; the last axis is reshaped
        to ``(spatial_side, spatial_side)``, so ``n_trg`` is expected to be
        ``spatial_side ** 2``.
        Returns a tensor with the same shape as ``corr``.
        """
        bsz, n_src, n_trg = corr.size()
        side = cls.spatial_side

        # Flat argmax along the last axis -> 2-D (row, col) peak location.
        center = corr.max(dim=2)[1]
        center_y = center // side
        center_x = center % side

        # Per-axis offsets from the peak, shape (bsz, n_src, side).
        offsets = cls.feat_idx.view(1, 1, side)
        dy = offsets - center_y.unsqueeze(2)
        dx = offsets - center_x.unsqueeze(2)

        # Broadcast to (bsz, n_src, side, side): rows follow y, cols follow x.
        sq_dist = dx.pow(2).unsqueeze(2) + dy.pow(2).unsqueeze(3)
        gauss_kernel = torch.exp(-sq_dist / (2 * sigma ** 2))

        filtered_corr = gauss_kernel * corr.view(bsz, -1, side, side)
        return filtered_corr.view(bsz, n_src, n_trg)

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfer keypoints by weighted average

        confidence_ts: correlation tensor passed to
            :meth:`apply_gaussian_kernel` and soft-maxed over its last axis.
        src_kps: per-sample keypoints, each indexed as ``(2, max_pts)``.
        n_pts: per-sample number of valid keypoints.
        normalized: if false, ``src_kps`` is normalized first.
        Returns the stacked predictions, each ``(2, max_pts)`` with unused
        slots filled with -2.
        """
        if not normalized:
            src_kps = cls.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        # Expected (soft-argmax) target coordinate for every source cell.
        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y).sum(dim=2)

        prd_kps = []
        for x, y, src_kp, n_valid in zip(prd_x, prd_y, src_kps, n_pts):
            n_valid = int(n_valid)  # accept Python ints and 0-d tensors
            max_pts = src_kp.size()[1]
            prd_xy = torch.stack([x, y]).t()

            valid_kp = src_kp[:, :n_valid].t()
            attmap = cls.attentive_indexing(valid_kp).view(n_valid, -1)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()

            # Pad on the same device / dtype as the prediction.
            pads = prd_kp.new_full((2, max_pts - n_valid), cls._PAD_VALUE)
            prd_kps.append(torch.cat([prd_kp, pads], dim=1))
        return torch.stack(prd_kps)

    @staticmethod
    def get_coord1d(coord4d, ksz):
        r"""Flatten a 4-D index ``(i, j, k, l)`` into a base-``ksz`` offset."""
        i, j, k, l = coord4d
        coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
        return coord1d

    @staticmethod
    def get_distance(coord1, coord2):
        r"""Return the squared distance between two ``(y, x)`` coordinates.

        Each squared component is truncated to ``int`` before summing.
        """
        delta_y = int(math.pow(coord1[0] - coord2[0], 2))
        delta_x = int(math.pow(coord1[1] - coord2[1], 2))
        return delta_y + delta_x

    @staticmethod
    def interpolate4d(tensor4d, size):
        r"""Bilinearly resize both spatial pairs of a 5-D tensor.

        tensor4d: ``(bsz, h1, w1, h2, w2)``.
        size: ``(height, width)`` applied to both ``(h1, w1)`` and
            ``(h2, w2)``; it does not have to be square.
        Returns ``(bsz, size[0], size[1], size[0], size[1])``.
        """
        bsz, h1, w1, h2, w2 = tensor4d.size()
        # Resize the first pair, treating the flattened second pair as channels.
        tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        # Resize the second pair, treating the flattened first pair as channels.
        tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        # Use both entries of ``size`` so rectangular targets reshape correctly.
        tensor4d = tensor4d.view(bsz, size[0], size[1], size[0], size[1])
        return tensor4d

    @staticmethod
    def init_idx4d(ksz):
        r"""Enumerate all 4-D indices in ``[0, ksz)`` as a NumPy array.

        Returns an array of shape ``(ksz ** 4, 4)`` in row-major order (the
        last column varies fastest).
        """
        i0 = torch.arange(0, ksz).repeat(ksz ** 3)
        i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
        i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
        i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
        idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()
        return idx4d
|