File size: 5,037 Bytes
d526dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
r""" Provides functions that manipulate boxes and points """

import math

import torch.nn.functional as F
import torch


class Geometry:
    r""" Class-level helpers for keypoint / correlation-map geometry.

    All state lives on the class itself and is created by `initialize`,
    which must be called before any of the classmethods that read
    `img_size`, `spatial_side`, `grid`, `norm_grid_x/y` or `feat_idx`.

    Keypoint tensors use `PAD_VALUE` (-2) to mark padded (absent) entries.
    """

    # Sentinel stored in keypoint tensors for padded (non-existent) points.
    PAD_VALUE = -2

    @classmethod
    def initialize(cls, img_size):
        r""" Build and cache the normalized feature-grid coordinates.

        Sets on the class:
            img_size:      the given image size
            spatial_side:  int(img_size / 8), side length of the feature grid
            norm_grid_x/y: (1, 1, side * side) x / y coordinate in [-1, 1] of
                           every grid cell, flattened row-major
            grid:          (side, side, 2), grid[i, j] = (x_j, y_i)
            feat_idx:      (side,) float tensor 0 .. side - 1
        """
        cls.img_size = img_size
        cls.spatial_side = int(img_size / 8)
        side = cls.spatial_side

        norm_grid1d = torch.linspace(-1, 1, side)
        # x varies along columns, y along rows.
        grid_x = norm_grid1d.view(1, side).repeat(side, 1)
        grid_y = norm_grid1d.view(side, 1).repeat(1, side)

        cls.norm_grid_x = grid_x.reshape(1, 1, side * side)
        cls.norm_grid_y = grid_y.reshape(1, 1, side * side)
        # Built directly instead of via torch.meshgrid, whose default
        # `indexing` emits a deprecation warning on recent torch releases.
        cls.grid = torch.stack([grid_x, grid_y], dim=2)

        cls.feat_idx = torch.arange(0, side).float()

    @classmethod
    def normalize_kps(cls, kps):
        r""" Map pixel keypoints to normalized coordinates: (kps - half) / half,
        with half = img_size // 2.

        Entries equal to PAD_VALUE are left untouched. Returns a detached
        copy; the input is not modified. Integer input is converted to float.
        """
        kps = kps.clone().detach()
        if not kps.is_floating_point():
            kps = kps.float()

        half = cls.img_size // 2
        # The mask is taken ONCE from the input: recomputing it after the
        # subtraction would mistake a coordinate of (half - 2) for padding.
        valid = kps != cls.PAD_VALUE
        kps[valid] = (kps[valid] - half) / half
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        r""" Inverse of `normalize_kps`: kps * half + half, half = img_size // 2.

        Entries equal to PAD_VALUE are left untouched. Returns a detached
        copy; the input is not modified. Integer input is converted to float.
        """
        kps = kps.clone().detach()
        if not kps.is_floating_point():
            kps = kps.float()

        half = cls.img_size // 2
        # Mask computed once, for the same reason as in normalize_kps.
        valid = kps != cls.PAD_VALUE
        kps[valid] = kps[valid] * half + half
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r"""kps: normalized keypoints x, y (N, 2)

            returns attentive index map(N, spatial_side, spatial_side)

            Each map holds weights max(thres - dist, 0) between the keypoint
            and every grid cell, normalized to sum to 1 per keypoint.
            NOTE: a keypoint farther than `thres` from every grid cell has an
            all-zero map before normalization and therefore yields NaNs.
        """
        side = cls.spatial_side
        nkps = kps.size(0)
        kps = kps.reshape(nkps, 1, 1, 2)

        eps = 1e-5  # keeps the square root finite at zero distance
        grid = cls.grid.to(kps.device).unsqueeze(0)
        sq_dist = (grid - kps).pow(2).sum(dim=3)  # broadcasts to (N, side, side)
        dist = (sq_dist + eps).pow(0.5)

        # Explicit sizes (no -1) so that N == 0 reshapes cleanly.
        attmap = (thres - dist).clamp(min=0).reshape(nkps, side * side)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)

        return attmap.view(nkps, side, side)

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        r""" Re-weight each row of `corr` by a 2D Gaussian centred on its argmax.

        corr: (bsz, n_src, spatial_side * spatial_side); the last dimension is
              interpreted as a row-major spatial_side x spatial_side map.
        sigma: standard deviation of the Gaussian, in grid cells.

        Returns a tensor with the same shape as `corr`.
        """
        side = cls.spatial_side
        bsz, n_src, n_trg = corr.size()

        # Row-major flat argmax -> (row, col) on the grid.
        center = corr.max(dim=2)[1]
        center_y = center // side
        center_x = center % side

        feat_idx = cls.feat_idx.to(corr.device).view(1, 1, side)
        y = feat_idx - center_y.unsqueeze(2)  # (bsz, n_src, side)
        x = feat_idx - center_x.unsqueeze(2)  # (bsz, n_src, side)

        # y indexes grid rows, x grid columns; broadcasting builds the
        # (bsz, n_src, side, side) squared distance to the centre.
        sq_dist = x.unsqueeze(2).pow(2) + y.unsqueeze(3).pow(2)
        gauss_kernel = torch.exp(-sq_dist / (2 * sigma ** 2))

        filtered_corr = gauss_kernel * corr.reshape(bsz, n_src, side, side)
        return filtered_corr.view(bsz, n_src, n_trg)

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfer keypoints by weighted average

        confidence_ts: (bsz, S, S) correlation scores with
                       S = spatial_side ** 2; softmax is taken over dim 2.
        src_kps:       (bsz, 2, max_pts) keypoints (row 0 = x, row 1 = y),
                       padded with PAD_VALUE.
        n_pts:         per-sample number of valid keypoints (ints or 0-d
                       tensors); only the first n_pts[i] columns are used.
        normalized:    if False, src_kps are first passed to normalize_kps.

        Returns (bsz, 2, max_pts) predicted normalized keypoints, padded with
        PAD_VALUE.
        """
        if not normalized:
            src_kps = cls.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        # Expected (x, y) on the normalized grid for every dim-1 position.
        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x.to(pdf.device)).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y.to(pdf.device)).sum(dim=2)

        n_cells = cls.spatial_side ** 2
        prd_kps = []
        for x, y, src_kp, n_valid in zip(prd_x, prd_y, src_kps, n_pts):
            n_valid = int(n_valid)
            max_pts = src_kp.size(1)
            prd_xy = torch.stack([x, y], dim=1)  # (S, 2)

            valid_kp = src_kp[:, :n_valid].t()  # (n_valid, 2)
            # Explicit size instead of -1 so samples with no keypoints work.
            attmap = cls.attentive_indexing(valid_kp).view(n_valid, n_cells)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()

            # Pad back to max_pts on the same device / dtype as the prediction.
            pads = prd_kp.new_full((2, max_pts - n_valid), cls.PAD_VALUE)
            prd_kps.append(torch.cat([prd_kp, pads], dim=1))

        return torch.stack(prd_kps)

    @staticmethod
    def get_coord1d(coord4d, ksz):
        r""" Flatten a 4D index (i, j, k, l) of a ksz^4 array to its 1D offset """
        i, j, k, l = coord4d
        return i * (ksz ** 3) + j * (ksz ** 2) + k * ksz + l

    @staticmethod
    def get_distance(coord1, coord2):
        r""" SQUARED Euclidean distance between two (y, x) coordinates.

        Each squared delta is truncated to int before summing.
        """
        delta_y = int((coord1[0] - coord2[0]) ** 2)
        delta_x = int((coord1[1] - coord2[1]) ** 2)
        return delta_y + delta_x

    @staticmethod
    def interpolate4d(tensor4d, size):
        r""" Bilinearly resize both spatial pairs of a (bsz, h1, w1, h2, w2) tensor.

        size: (out_h, out_w), or a single int for a square output.

        Returns a (bsz, out_h, out_w, out_h, out_w) tensor.
        """
        if isinstance(size, int):
            size = (size, size)
        out_h, out_w = size

        bsz, h1, w1, h2, w2 = tensor4d.size()
        # Resize (h1, w1), treating the flattened (h2, w2) as channels.
        tensor4d = tensor4d.reshape(bsz, h1, w1, h2 * w2).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, (out_h, out_w), mode='bilinear', align_corners=True)
        # Resize (h2, w2), treating the flattened first pair as channels.
        tensor4d = tensor4d.reshape(bsz, h2, w2, out_h * out_w).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, (out_h, out_w), mode='bilinear', align_corners=True)

        return tensor4d.reshape(bsz, out_h, out_w, out_h, out_w)

    @staticmethod
    def init_idx4d(ksz):
        r""" Enumerate every 4D index of a ksz^4 array.

        Returns a (ksz ** 4, 4) numpy array whose rows are in lexicographic
        order (last coordinate varies fastest).
        """
        base = torch.arange(0, ksz)
        i0 = base.repeat(ksz ** 3)
        i1 = base.unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
        i2 = base.unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
        i3 = base.unsqueeze(1).repeat(1, ksz ** 3).view(-1)
        idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()

        return idx4d