minchul commited on
Commit
e628e01
1 Parent(s): f2be5ee

Upload directory

Browse files
models/vit_kprpe/RPE/KPRPE/relative_keypoints.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+
4
+ @torch.no_grad()
5
+ def make_rel_keypoints(keyponints, query):
6
+ seq_length = query.shape[1]
7
+ side = int(math.sqrt(seq_length))
8
+ assert side == math.sqrt(seq_length)
9
+
10
+ # make a grid of points from 0 to 1
11
+ coord = torch.linspace(0, 1, side+1, device=query.device, dtype=query.dtype)
12
+ coord = (coord[:-1] + coord[1:]) / 2 # get center of patches
13
+
14
+ x, y = torch.meshgrid(coord, coord, indexing='ij')
15
+ grid = torch.stack([y, x], dim=-1).reshape(-1, 2).unsqueeze(0).unsqueeze(-2) # BxNx1x2
16
+ _keyponints = keyponints.unsqueeze(-3) # Bx1x5x2
17
+ diff = (grid - _keyponints) # BxNx5x2
18
+ diff = diff.flatten(2) # BxNx10
19
+ return diff