PUMP / core /functional.py
Philippe Weinzaepfel
huggingface demo
3ef85e9
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn.functional as F
def affmul( aff, vecs ):
    """ Apply an affine transform to 2-D points stored along the last axis.

    aff: None (identity: `vecs` itself is returned), a (2,3) matrix, a (3,3) homogeneous
         matrix whose last row is (0, 0, 1), or a pair of such transforms (tuple/list of
         two, or a 3-D array/tensor). For a pair, the first transform is applied to
         vecs[..., 0:2] and the second to vecs[..., 2:4]; any other column is copied as is.
    vecs: numpy array or torch tensor of points, (..., 2), or (..., >=4) for a pair.

    Returns `vecs @ A.T + t` with A = aff[:2,:2] and t = aff[:2,2], i.e. every point
    (a row of `vecs`) is transformed as a column vector. For a pair of transforms a
    modified copy is returned and `vecs` is left untouched.
    Failed shape checks drop into the debugger (bb).
    """
    if aff is None:
        return vecs

    if not (isinstance(aff, (tuple,list)) or aff.ndim == 3):
        # single transform acting on (..., 2) points
        assert vecs.shape[-1] == 2, bb()
        assert aff.shape == (2,3) or (aff.shape == (3,3) and
            aff[2,0] == aff[2,1] == 0 and aff[2,2] == 1), bb()
        linear, shift = aff[:2,:2], aff[:2,2]
        return (vecs @ linear.T) + shift

    # pair of transforms: one per (x, y) couple of each row
    assert len(aff) == 2
    assert 4 <= vecs.shape[-1], bb()
    out = vecs.clone() if isinstance(vecs, torch.Tensor) else vecs.copy()
    for k in (0, 1):
        cols = slice(2*k, 2*k + 2)
        out[..., cols] = affmul(aff[k], out[..., cols])
    return out
def imresize( img, max_size, mode='area' ):
    """ Downscale an image so that its longest side is at most `max_size`.

    img: image tensor whose last two dims are (H, W) (it is interpolated as img[None] and
         cast back with `.byte()`, so presumably a (C, H, W) uint8 image), or a tuple
         (image, trf) where trf is a 3x3 matrix mapping current pixel coords (x, y)
         to older pixel coords. Without a tuple, trf starts as the identity.
    max_size: maximum allowed length of the longest side; <= 0 disables resizing.
    mode: interpolation mode forwarded to F.interpolate.

    Returns (img, trf): the possibly resized image (unchanged object if no resizing
    was needed) and the transform updated to map resized pixels to older pixels.
    """
    # trf: cur_pix --> old_pix
    img, trf = img if isinstance(img,tuple) else (img, torch.eye(3,device=img.device))

    H, W = img.shape[-2:]
    if max_size > 0 and max(H, W) > max_size:
        # same factor on both sides; never shrink a side to 0 pixels
        new_H, new_W = (max(1, i * max_size // max(H, W)) for i in (H, W))
        img = F.interpolate( img[None].float(), size=(new_H, new_W), mode=mode )[0]
        img.clamp_(min=0, max=255)
        # transforms act on (x, y) coordinates: the x (width) factor comes first.
        # Both factors differ slightly because of the integer rounding above.
        sca = torch.diag(torch.tensor((W/new_W, H/new_H, 1), device=img.device))
        img = img.byte()
        trf = trf @ sca # undo sca first

    return img, trf
def rotate_img( img, angle, crop=False ):
    """ Rotate an image by an arbitrary angle (in degrees) by resampling it.

    img: tuple (image, trf): image is a (C, H, W) tensor, trf a 3x3 matrix mapping
         current pixel coordinates to older ones (same convention as rotate_img_90).
    angle: rotation in degrees. Exactly 0, 90, 180 or 270 is delegated to rotate_img_90
         (exact transpose/flip, no interpolation). Other multiples of 90 (e.g. -90, 360)
         are NOT delegated and go through resampling.
    crop: if True, the output keeps the input size and what falls outside is cut off;
         otherwise the output is enlarged to contain the whole rotated image, each side
         being rounded down to a multiple of 4 ("legacy").

    Returns (rotated_image, trf @ rot), where rot maps rotated pixel coordinates to
    input pixel coordinates. F.grid_sample uses its defaults (bilinear, zeros outside the
    input) and the result is cast back to the input dtype.
    """
    if angle in (0, 90, 180, 270):
        return rotate_img_90(img,angle)
    img, trf = img
    assert trf.shape == (3,3)
    def centered_rotation(rotation, shape, **device):
        """ Build rot (rotated pixel --> original pixel) for a rotation of `rotation`
        degrees about the image center, plus the (height, width) of the rotated image.
        Reads `crop` from the enclosing scope; `device` is forwarded to torch.tensor.
        """
        # rotation matrix
        # pt_in_original_image = rot * pt_in_rotated_image
        angle = rotation * np.pi / 180
        c, s = np.cos(angle), np.sin(angle)
        rot = torch.tensor([(c, -s, 0), (s, c, 0), (0, 0, 1)], dtype=torch.float32, **device)
        # determine center of rotation before
        # NOTE(review): the center is (W/2, H/2) whereas pixel centers span [0, W-1];
        # there may be a half-pixel offset — verify if sub-pixel accuracy matters.
        H, W = shape
        c_before = torch.tensor((W,H), **device) / 2
        if crop:
            # same output size: the center stays where it was
            c_after = c_before
            rot_size = (W,H)
        else:
            # enlarge image to fit everything:
            # bounding box of the 4 image corners once rotated (translation of rot is still 0 here)
            corners = torch.tensor([(0, W, W, 0), (0, 0, H, H)], dtype=torch.float32, **device)
            corners = affmul(rot, corners.T).T
            rot_size = (corners.max(dim=1).values - corners.min(dim=1).values + 0.5).int()
            rot_size = (rot_size // 4) * 4 # legacy
            c_after = rot_size / 2
        # fix translation so that the output center c_after lands on the input center c_before
        rot[:2,2] = c_before - affmul(rot, c_after) # fix translation
        # rot_size is (width, height) -> returned as (height, width).
        # NOTE(review): in the non-crop case these are 0-d int tensors, not python ints.
        return rot, tuple(rot_size)[::-1]
    C, H, W = img.shape
    rot, (OH, OW) = centered_rotation(angle, (H,W), device=img.device)
    # pt_in_original_image = rot * pt_in_rotated_image
    # but pytorch works in [-1,1] coordinates... annoying
    # pt_in_original_1_1 = orig_px_to_1_1 * rot * rotated_1_1_to_px * pt_in_rotated_1_1
    # _1_1_to_px(w, h) maps [-1,1] coords to pixel coords [0,w] x [0,h]; it is called with
    # (size - 1) to match align_corners=True below.
    _1_1_to_px = lambda W,H: torch.tensor(((W/2, 0, W/2), (0, H/2, H/2), (0, 0, 1)), device=img.device)
    theta = torch.inverse(_1_1_to_px(W-1,H-1)) @ rot @ _1_1_to_px(OW-1,OH-1)
    # for every output pixel p, sample the input at rot(p)
    grid = F.affine_grid(theta[None,:2], (1, C, OH, OW), align_corners=True)
    res = F.grid_sample(img[None].float(), grid, align_corners=True).to(dtype=img.dtype)[0]
    return res, trf @ rot
def rotate_img_90( img, angle ):
    """ Rotate an image by a multiple of 90 degrees with transpose and flip only.

    img: tuple (image, existing_trf).
         image: numpy array laid out (H, W, C) or torch tensor laid out (C, H, W),
                with 1 to 3 channels.
         existing_trf: 3x3 matrix, current pixel (x, y) --> old pixel.
    angle: any multiple of 90 (taken modulo 360).

    Returns (rotated_image, trf) where trf now maps rotated pixels to old pixels.
    For angle % 360 == 0 both inputs are returned unchanged.
    """
    angle %= 360
    assert angle in (0, 90, 180, 270), 'cannot handle rotation other than multiple of 90 degrees'

    image, trf = img
    assert trf.shape == (3,3)

    if isinstance(image, np.ndarray):
        assert image.ndim == 3 and 1 <= image.shape[2] <= 3
        make_matrix = np.float32
        ax_x, ax_y = 1, 0      # (H, W, C)
        def flip(arr, dim):
            return np.flip(arr, axis=dim)
    elif isinstance(image, torch.Tensor):
        assert image.ndim == 3 and 1 <= image.shape[0] <= 3
        make_matrix = trf.new
        ax_x, ax_y = -1, -2    # (C, H, W)
        def flip(arr, dim):
            return arr.flip(dims=[dim])
    height, width = image.shape[ax_y], image.shape[ax_x]

    # new_to_cur: inverse transform, rotated pixel --> current pixel
    if angle == 90:
        # (0,0) --> (0, H-1); (W-1,0) --> (0,0)
        image = flip(image.swapaxes(ax_x, ax_y), ax_y)
        new_to_cur = [[0,-1,width-1],[1,0,0],[0,0,1]]
    elif angle == 180:
        # (0,0) --> (W-1, H-1)
        image = flip(flip(image, ax_x), ax_y)
        new_to_cur = [[-1,0,width-1],[0,-1,height-1],[0,0,1]]
    elif angle == 270:
        # (0,0) --> (H-1, 0); (0,H-1) --> (0,0)
        image = flip(image.swapaxes(ax_x, ax_y), ax_x)
        new_to_cur = [[0,1,0],[-1,0,height-1],[0,0,1]]
    else:
        return image, trf
    return image, trf @ make_matrix(new_to_cur)
def encode_scale_rot(scale, rot):
    """ Pack a (scale, rotation) pair into a single integer code = 8*s + r.

    s: scale expressed in half-octaves, round(log2(scale) * 2), may be negative.
    r: rotation -rot (degrees) in steps of 45 degrees, wrapped to 0..7.
    Works element-wise on numpy arrays; returns np.int32 values.
    """
    half_octave = 0.5 * np.log(2)
    scale_idx = np.int32(np.rint(np.log(scale) / half_octave))
    rot_idx = np.int32(np.rint(((-rot) % 360) / 45)) % 8   # 360 degrees rounds to 8 -> 0
    return 8 * scale_idx + rot_idx
def decode_scale_rot( code ):
    """ Unpack a code = 8*s + r into (scale, angle in degrees).

    scale = 2 ** (s / 2)
    angle = -45 * r, wrapped to the interval (-180, 180]
    Floor division / modulo keep this valid for negative codes (s < 0).
    """
    scale_idx = code // 8
    rot_idx = code % 8
    scale = 2 ** (scale_idx / 2)
    wrapped = (45 * rot_idx + 180) % 360 - 180
    return scale, -wrapped
def normalized_corr(patches, img, padding='ncc', extra_patch=False, ret_norms=False):
    """ Normalized cross-correlation of P square patches over an image.

    patches: (P, C, K, K) tensor (e.g. P = H*W patches).
    img: (C, H, W) tensor with the same number of channels C.
    padding: 'ncc'   -> zero padding, then border scores are re-normalized by the norm of
                        the part of each patch that actually overlaps the image;
             'ninth' -> constant padding with a value computed from `img`, no border
                        re-normalization;
             other   -> zero padding, no border re-normalization.
    extra_patch: unused, kept for backward compatibility.
    ret_norms: if True, also return the norms of the input patches.

    Returns corr, a (P, H', W') tensor with H' = H + 2*(K//2) - K + 1 (same for W'),
    or (corr, norms) when ret_norms is True.
    """
    assert patches.ndim == 4, 'patches shape must be (H*W, C, K, K)'
    P, C, K, K = patches.shape
    assert img.ndim == 3 and img.shape[0] == C, 'img shape must be (C, H, W)'
    eps = torch.finfo(patches.dtype).tiny
    half = K // 2  # padding added on every side of the image

    # normalize on patches side
    norms = patches.view(P,-1).norm(dim=-1)
    patches = patches / norms[:,None,None,None].clamp(min=eps)

    # convolve normalized patches on unnormalized image
    ninth = 0
    if padding == 'ninth':
        # NOTE(review): img[:,-1] is the last row of every channel, not the last channel,
        # despite the 'ninth dimension' wording — confirm intent.
        ninth = img[:,-1].mean() # ninth dimension
    img = F.pad(img[None], (half,half)*2, mode='constant', value=ninth)[0]

    corr = F.conv2d(img[None], patches, padding=0, bias=None)[0]

    # normalize on img's side; the clamp turns all-zero windows into 0 instead of 0/0 = NaN
    ones = patches.new_ones((1, C, K, K))
    local_norm = torch.sqrt(F.conv2d(img[None]**2, ones))[0]
    corr /= local_norm.clamp_(min=eps)

    # normalize on patches' side (image borders).
    # local_norm is (P, 2*half+1, 2*half+1): for j in 0..half-1 (resp. -half..-1), corr row j
    # (resp. row j from the bottom) only partially overlaps the image, and local_norm[:, j]
    # is the norm of the overlapping part of each patch; index `half` stands for the
    # interior, which is already normalized (patches have unit norm). Same for columns.
    if padding == 'ncc':
        local_norm = torch.sqrt(F.conv2d(ones, patches**2, padding=half))[0]
        local_norm.clamp_(min=eps)
        inner = slice(half, -half)
        for j in range(-half, half+1):
            for i in range(-half, half+1):
                if i == j == half: continue # normal case is already normalized
                jj = inner if j == half else j
                ii = inner if i == half else i
                corr[:,jj,ii] /= local_norm[:,jj,ii]

    return (corr, norms) if ret_norms else corr
def true_corr_shape( corr_shape, level ):
    """ Recover the true size of a correlation volume shape.

    The last four dims are (H1, W1, H2, W2). For level > 0, H1 and W1 are each reduced
    by one (those levels carry one extra row/column); level 0 is returned unchanged.
    Leading dims are kept as is.
    """
    H1, W1, H2, W2 = corr_shape[-4:]
    trim = 1 if level > 0 else 0
    return corr_shape[:-4] + (H1 - trim, W1 - trim, H2, W2)
def children(level, H1, W1, H2, W2):
    """ Flat indices of the 4 children of every parent cell.

    level: parent level (>= 1). Children sit at +-gap around their parent,
           with gap = 2**(level-2):
             level 1: gap=0.5 (parent at x=1 has children at x=[0.5, 1.5])
             level 2: gap=1   (parent at x=1 has children at x=[0, 2])
             level 3: gap=2   (parent at x=2 has children at x=[0, 4])
    H1, W1: size parameters of the children grid. For level 1 children live on an
            H1 x W1 grid; otherwise on an (H1+1) x (W1+1) grid.
    H2, W2: unused.

    Returns an int64 tensor of shape ((H1+1)*(W1+1), 4): for each parent (row-major),
    its top-left, top-right, bottom-left and bottom-right child, -1 when outside.
    """
    gap = 2 ** (level - 2)
    # (dx, dy) in the order TL, TR, BL, BR
    offsets = [(gap * dx, gap * dy) for dy in (-1, 1) for dx in (-1, 1)]

    def ravel_child(x, y):
        # (x, y) is the center of the child patch
        if gap < 1:
            assert x % 1 == y % 1 == 0.5, bb()
            flat = (x - 0.5) + (y - 0.5) * W1
        else:
            assert x % 1 == y % 1 == 0, bb()
            flat = x + y * (W1 + 1)
        inside = 0 <= x <= W1 and 0 <= y <= H1
        return int(flat) if inside else -1

    parents = [[ravel_child(w + ox, h + oy) for ox, oy in offsets]
               for h in range(H1 + 1) for w in range(W1 + 1)]
    return torch.tensor(parents, dtype=torch.int64)
def sparse_conv(level, corr, weights=None, reverse=False, norm=0.9):
    """ Sparse 'convolution' between two consecutive levels of the correlation pyramid.

    Every parent cell of `level` has up to 4 children at `level-1` (see children()):
    top-left, top-right, bottom-left, bottom-right.

    Forward (reverse=False):
        corr holds the level-1 correlation maps (last two dims = one map, leading dims =
        child cells). Each parent map is the sum of its children's maps, every child being
        shifted by one pixel according to its quadrant. Optional `weights` (one per child
        map) scale the children. The weight accumulated in each of the 9 border/interior
        regions is tracked in `nrm` and used by norm_borders() with exponent `norm`.
        Returns (res, weights): res of shape (H1+1, W1+1, h, w) and the per-parent
        weights returned by norm_borders (garbage channel removed).
    Reverse (reverse=True):
        corr must be (a view of) the `res` returned by a forward call: its `_base` still
        holds the extra garbage channel (checked by the assert below). Each child map
        becomes the element-wise max over its parents' maps, shifted back.
        Returns res of shape (H1+ex, W1+ex, h, w) for the children level.

    Side effects: in forward mode `corr` is multiplied by `weights` in place and divided
    back afterwards (up to floating-point rounding; children with weight 0 cannot be
    restored). In reverse mode the garbage channel of `corr._base` is zeroed.
    """
    # forward: true size of the children level (level-1); reverse: of the parent level
    H1, W1, H2, W2 = true_corr_shape(corr.shape, level-1 + reverse)
    # (n_parents, 4) child indices per parent, -1 = no child
    parents = children(level, H1, W1, H2, W2).to(corr.device)
    n_parents = len(parents)
    # perform the sparse convolution 'manually'
    # since sparse convolutions are not implemented in pytorch currently
    corr = corr.view(-1, *corr.shape[-2:])
    if not reverse:
        res = corr.new_zeros((n_parents+1,)+corr.shape[-2:]) # last one = garbage channel
        # accumulated weight per 3x3 region (first/inner/last row x column) of each parent map
        nrm = corr.new_full((n_parents+1,3,3), 1e-8)
        ones = nrm.new_ones((len(corr),1,1))
        ex = 1
        if weights is not None:
            weights = weights.view(len(corr),1,1)
            corr *= weights # apply weights to correlation maps without increasing memory footprint
            ones *= weights
    else:
        assert corr._base is not None and corr._base.shape[0] == n_parents+1
        corr._base[-1] = 0 # reset garbage layer
        # level-0 children have no extra row/column (see true_corr_shape)
        ex = 1 if level > 1 else 0
        n_children = (H1+ex) * (W1+ex)
        res = corr.new_zeros((n_children,)+corr.shape[-2:])
    # sl(-1) drops the last row/column, sl(+1) drops the first one: shifts a map by one pixel
    sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None)
    c = 0
    # c-th column of `parents` = child in quadrant (x, y), in the order TL, TR, BL, BR
    for y in (-1, 1):
        for x in (-1, 1):
            src_layers = parents[:,c]; c+= 1
            # we want to do: res += corr[src_layers] (for all children != -1)
            # but we only have 'res.index_add_()' <==> res[tgt_layers] += corr
            # tgt_layers[child] = its parent in this quadrant, or n_parents (garbage) if none
            tgt_layers = inverse_mapping(src_layers, max_elem=len(corr), default=n_parents)[:-1]
            if not reverse:
                # All of corr's channels MUST be utilized. for level>1, this doesn't hold,
                # so we'll send them to a garbage channel ==> res[n_parents]
                # smallest contiguous range of children that have a parent in this quadrant
                sel = good_slice( tgt_layers < n_parents )
                res[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], corr[sel,sl(y),sl(x)])
                nrm[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2))
            else:
                ''' parent=199=11*17+12 @ (x=48, y=44) at level=1
|-- child=171 @ (x=46,y=42) at level0
|-- child=172 @ (x=50,y=42) at level0
|-- child=187 @ (x=46,y=46) at level0
|-- child=188 @ (x=50,y=46) at level0
'''
                # res[child] = max(res[child], shifted map of its parent in this quadrant);
                # children without such a parent read the zeroed garbage layer
                out = res[:,sl(y),sl(x)]
                sel = tgt_layers[:n_children]
                torch.maximum(out, corr._base[sel,sl(-y),sl(-x)], out=out)
    if not reverse:
        if weights is not None: corr /= weights.clamp(min=1e-12) # cancel weights
        weights = norm_borders(res, nrm, norm=norm)[:-1]
        res = res[:-1] # remove garbage channel
    res = res.view(H1+ex, W1+ex, *res.shape[-2:])
    return res if reverse else (res, weights)
def norm_borders( res, nrm, norm=0.9 ):
    """ Divide the 9 border/interior regions of `res` by their accumulated weights.

    res: (..., H, W) maps, modified in place. The regions are first/inner/last row
         crossed with first/inner/last column.
    nrm: (..., 3, 3) weights, indexed the same way (0 = first, 1 = inner, 2 = last).
    norm: exponent p blending each weight with the interior one k = nrm[..., 1, 1]:
          nrm = k * (nrm/k)**p = k**(1-p) * nrm**p
          - p=0: every region is divided by k (no border-specific normalization)
          - p=1: every region is divided by its own weight (full normalization)

    Returns a copy of the interior weights nrm[..., 1, 1], taken before blending.
    """
    interior_w = nrm[..., 1, 1].clone()
    blended = (nrm[..., 1:2, 1:2] ** (1 - norm)) * (nrm ** norm)

    # (index into res, index into the 3x3 weight table) for first / inner / last
    regions = ((0, 0), (slice(1, -1), slice(1, 2)), (-1, 2))
    for res_y, w_y in regions:
        for res_x, w_x in regions:
            res[..., res_y, res_x] /= blended[..., w_y, w_x]
    return interior_w
def inverse_mapping( map, max_elem=None, default=None):
    """ given a mapping {i:j} we output {j:i}
    (the mapping is a torch array)

    map: 1-D integer tensor, map[i] = j. A negative j indexes from the end of the
         output; callers pass -1 for "no target", which lands in the last slot
         (the garbage element).
    max_elem: largest target to make room for; the output has max_elem+1 entries.
         Defaults to map.max(), or -1 for an empty map (empty output).
    default: value of the entries that no i maps to; if None they are left
         uninitialized (torch.empty).

    Returns index, a 1-D int64 tensor with index[map[i]] = i. When several i share
    the same j, which i is kept is not specified by PyTorch.
    """
    assert isinstance(map, torch.Tensor) and map.ndim == 1
    if max_elem is None:
        # python int so that it is a valid size; .max() of an empty tensor would raise
        max_elem = int(map.max()) if map.numel() else -1
    size = (int(max_elem) + 1,)
    if default is None:
        index = torch.empty(size, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    else:
        index = torch.full(size, default, dtype=torch.int64, device=map.device) # same size as corr, last elem == garbage
    index[map] = torch.arange(len(map), device=map.device)
    return index
def good_slice( nonzero ):
    """ Smallest contiguous slice covering every non-zero (True) entry of a 1-D tensor.

    Entries inside the returned slice are not necessarily all non-zero.
    Returns the empty slice(0, 0) when there is no non-zero entry (instead of
    failing on the min/max of an empty tensor).
    Expects a 1-D tensor: for higher ranks `.nonzero().ravel()` would mix coordinates.
    """
    good = nonzero.nonzero().ravel()
    if good.numel() == 0:
        return slice(0, 0)
    return slice(good.min().item(), good.max().item()+1)
def max_unpool(upper, lower, exclude_border=True):
    """ Propagate the scores of a coarser level back onto a finer one (max-unpooling).

    The max-pooling indices of `lower` (3x3 window, stride 2, padding 1) are recomputed.
    Every value of `upper` is written at the argmax position of its pooling window, and
    the result is added in place to `lower`.

    upper: (N, C, h, w) coarse level — presumably the pooled/aggregated version of
           `lower` (TODO confirm with callers).
    lower: (N, C, H, W) fine level; modified in place and returned.
    exclude_border: if True, the pooling indices are computed without the last row and
           column of `lower` (legacy behavior), then converted back to flat indices of
           the full H x W map.
    """
    # re-compute max-pool indices
    if exclude_border:
        # apparently, we cannot unpool on the bottom and right borders in legacy code (local_argmax with ex=1)
        _, pos = F.max_pool2d(lower[:,:,:-1,:-1], 3, padding=1, stride=2, return_indices=True, ceil_mode=True)
        W1 = lower.shape[-1]
        # flat index in the (H-1) x (W-1) crop --> flat index in the full H x W map
        pos = (pos//(W1-1))*W1 + (pos%(W1-1)) # fix the shortening
    else:
        _, pos = F.max_pool2d(lower, 3, padding=1, stride=2, return_indices=True)
    # because there are potential collisions between overlapping 3x3 cells,
    # that pytorch does not handle, we unpool in 4 successive non-overlapping steps.
    # (every other upper cell in each direction -> their 3x3 windows do not overlap)
    for i in range(2):
        for j in range(2):
            # NOTE(review): an older remark here read "stride=0 instead of 1 ... this is a hack",
            # but the call uses stride=4. Since `pos` holds absolute flat indices, the stride
            # presumably only serves to pass PyTorch's output_size check.
            # Positions not targeted by any index are 0 in `tmp`; combined with torch.maximum,
            # negative scores (if any) would therefore be replaced by 0 in `res`.
            tmp = F.max_unpool2d(upper[:,:,i::2,j::2], pos[:,:,i::2,j::2], kernel_size=3, padding=0, stride=4, output_size=lower.shape[-2:])
            if i == j == 0:
                res = tmp
            else:
                torch.maximum(res, tmp, out=res)
    # add scores to existing lower correlation map
    lower += res
    return lower
def mgrid( shape, **kw ):
    """ Flattened grid of 2-D pixel coordinates in (x, y) order (numpy's mgrid is (y, x)).

    shape: (H, W) sequence, or a tensor whose own shape is used.
    kw: forwarded to torch.arange (e.g. device=...).
    Returns a float32 tensor of shape (H*W, 2) where row k = (k % W, k // W).
    """
    if isinstance(shape, torch.Tensor):
        shape = shape.shape
    axes = [torch.arange(n, dtype=torch.float32, **kw) for n in shape]
    rows_cols = torch.meshgrid(*axes, indexing='ij')
    xy = torch.stack(tuple(reversed(rows_cols)), dim=-1)
    return xy.view(-1, 2)
def check_corres( corres, step, rot=None ):
    """ Debug sanity check: `corres` must roughly be a regular grid of step `step`.

    corres: (H, W, 2) numpy array or tensor of 2-D positions.
    step: grid step; `corres // step` is compared with mgrid's (x, y) index of each cell.
    rot: optional transform applied to `corres` first (see affmul).

    Drops into the debugger (bb) unless more than 99% of the coordinates (x and y
    counted separately) are within 2 of their expected grid index.
    """
    H, W, two = corres.shape
    assert two == 2
    if isinstance(corres, np.ndarray):
        corres = torch.from_numpy(corres)
    if rot is not None:
        corres = affmul(rot, corres)
    expected = mgrid(corres.shape[:2]).view(H,W,2)
    deviation = (expected - corres // step).abs()
    inlier_ratio = (deviation <= 2).float().mean()
    assert inlier_ratio > 0.99, bb()
def best_correspondences(corr):
    """ Best matches in both directions from a 4-D correlation volume.

    corr: (H1, W1, H2, W2) scores between cells of img1 and cells of img2.
          A tuple is returned unchanged (legacy: correspondences already computed).

    Returns ((pos1, score1), (pos2, score2)). All positions are returned as x1, y1, x2, y2:
    - pos1: (H1*W1, 4), the best img2 cell for every img1 cell (row-major over img1);
      score1: (H1, W1) the corresponding scores.
    - pos2: (N, 4) float, for every 4x4 block of img2 cells, the best (img1 cell, img2 cell)
      pair as found by max_pool3d; score2: the pooled scores, squeezed.
    x1, y1 are converted with 4*i+2 ("center of cells in img1", presumably 4-pixel cells),
    whereas x2, y2 are left as cell indices of img2.
    """
    if isinstance(corr, tuple): return corr # for legacy
    H1, W1, H2, W2 = corr.shape
    fix1 = lambda arr: 4*arr+2 # center of cells in img1
    div = lambda a,b: torch.div(a, b, rounding_mode='trunc') # because of warning in pytorch 1.9+
    # best scores in img1
    # pos1 is the flat index into (H2, W2) of the best img2 cell -> x2 = pos1 % W2, y2 = pos1 // W2
    score1, pos1 = corr.view(H1, W1, H2*W2).max(dim=-1)
    pos1 = torch.cat((fix1(mgrid(score1, device=pos1.device)), pos1.view(-1,1)%W2, div(pos1.view(-1,1),W2)), dim=-1)
    # best scores in img2
    # max over all img1 cells and over each 4x4 block of img2 cells; pos2 is a flat index
    # into (H1*W1, H2, W2), decoded below as x1, y1, x2, y2
    score2, pos2 = max_pool3d( corr, kernel_size=4, stride=4 )
    pos2, score2 = pos2.view(-1,1), score2.squeeze()
    pos2 = torch.cat((fix1(div(pos2,W2*H2)%W1), fix1(div(pos2,(W1*H2*W2))), pos2%W2, div(pos2,W2)%H2), dim=-1).float()
    return (pos1, score1), (pos2, score2)
def intersection( set1_, set2_ ):
    """ Return indices into `set1_` of the values that also appear in `set2_`.

    Both inputs are squeezed, so they are presumably 1-D (or (N, 1)) tensors of keys.
    One index is returned per distinct shared value, ordered by increasing value
    (torch.unique sorts its output). When a shared value occurs several times in
    `set1_`, which of its positions is returned depends on the write order of duplicate
    indices inside inverse_mapping, which PyTorch does not specify.
    """
    # set1: sorted distinct values; map1[i1] = j1 (position of set1_[i1] in set1)
    set1, map1 = set1_.squeeze().unique(return_inverse=True) # map1: i1 -> j1
    set2 = set2_.squeeze().unique()
    combined = torch.cat((set1, set2))
    # set1 and set2 are each duplicate-free, so counts == 2 exactly for shared values
    uniques, inverse, counts = combined.unique(return_counts=True, return_inverse=True)
    # j -> u, i -> j, j -> n
    # we are interested only in (j -> i) for n > 1:
    # assert counts.max() <= 2, 'there were non-unique values in either set1 or set2'+bb()
    # intersected_values = uniques[counts > 1]
    # inverse1[u] = position in set1 of unique value u (left uninitialized for values only in set2,
    # which are never read since only counts > 1 entries are selected)
    inverse1 = inverse_mapping(inverse[:len(set1)], max_elem=len(uniques)-1)
    intersected_indices1 = inverse1[counts>1]
    # map positions in set1 back to positions in the original set1_
    return inverse_mapping(map1, max_elem=len(set1)-1)[intersected_indices1]
def reciprocal(self, corres1, corres2 ):
    """ Keep the correspondences that are found in both directions.

    corres1 = (pos1, score1) and corres2 = (pos2, score2), as returned by
    best_correspondences(): every row of pos1 / pos2 is (x1, y1, x2, y2) with
    integer-valued coordinates. `self` is unused.

    Returns a (N, 6) tensor: the rows of pos1 that also appear in pos2, followed by
    their score1 value and a column of zeros.
    """
    pos1, score1 = corres1
    pos2, score2 = corres2

    # Encode each (x1, y1, x2, y2) row as one exact int64 key (mixed radix over the
    # observed coordinate ranges). A float32 product cannot represent keys above 2**24
    # exactly, so distinct rows could otherwise collide.
    ipos1, ipos2 = pos1.round().long(), pos2.round().long()
    lo = torch.minimum(ipos1.min(dim=0).values, ipos2.min(dim=0).values)
    span = torch.maximum(ipos1.max(dim=0).values, ipos2.max(dim=0).values) - lo + 1

    def to_key(ipos):
        # lexicographic in (x1, y1, x2, y2)
        ipos = ipos - lo
        key = ipos[:, 0]
        for col in range(1, ipos.shape[1]):
            key = key * span[col] + ipos[:, col]
        return key

    inter1 = intersection(to_key(ipos1), to_key(ipos2))
    scores = score1.view(-1,1)[inter1]
    return torch.cat((pos1[inter1], scores, 0*scores), dim=-1)
def max_pool3d( corr, kernel_size=4, stride=4 ):
    """ Max over all img1 cells and over each (kernel_size x kernel_size) block of img2 cells.

    corr: (H1, W1, H2, W2) correlation volume.
    Returns (scores, indices) from F.max_pool3d on corr viewed as (1, 1, H1*W1, H2, W2),
    indices being flat positions in (H1*W1, H2, W2). Huge non-CPU volumes
    (>= 2**31 elements) are delegated to a custom kernel (core.cuda_deepm) instead,
    due to a bug in pytorch.
    """
    H1, W1, H2, W2 = corr.shape
    n_src = H1 * W1
    huge_on_gpu = corr.numel() >= 2**31 and corr.device != torch.device('cpu')
    if not huge_on_gpu:
        return F.max_pool3d(corr.view(1, 1, n_src, H2, W2),
                            kernel_size=(n_src, kernel_size, kernel_size),
                            stride=(1, stride, stride), return_indices=True)
    # re-implementation due to a bug in pytorch
    import core.cuda_deepm as kernels
    return kernels.max_pool3d(corr.view(1, n_src, H2, W2), kernel_size, stride)
def forward_cuda(self, level, lower, weights=None, pooled=False):
    """ Aggregate `level` from the level below with the compiled CUDA kernel.

    lower: correlation tensor of the level below (must have < 2**31 elements). Unless
           `pooled` is truthy, it is first max-pooled (3x3 window, stride 2, padding 1).
    weights: optional, forwarded to the kernel together with self.border_inv.
    Returns whatever kernels.forward_agg returns.
    """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    assert lower.numel() < 2**31, 'please use cuda-lowmem, pytorch cannot handle big tensors'
    if not pooled:
        lower = F.max_pool2d(lower, 3, padding=1, stride=2)
    return kernels.forward_agg(level, self.border_inv, lower, weights)
def forward_cuda_lowmem(self, level, lower, weights=None):
    """ Low-memory variant of forward_cuda: `lower` is handed to the kernel un-pooled
    (pooling presumably happens inside kernels.forward_pool_agg, per its name).
    self.border_inv and `weights` are forwarded as is.
    Returns whatever kernels.forward_pool_agg returns.
    """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    kernel_args = (level, self.border_inv, lower, weights)
    return kernels.forward_pool_agg(*kernel_args)
def backward_cuda(self, level, pyramid):
    """ Push the scores of pyramid[level] down into pyramid[level-1] with the CUDA kernel
    and return pyramid[level-1] (presumably updated in place by the kernel — TODO confirm
    in core.cuda_deepm).
    """
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    below = level - 1
    kernels.backward_agg_unpool(level, pyramid[level], pyramid[below], True)
    return pyramid[below]
def merge_corres(self, corres, rots, all_corres, code):
    """ Merge one pair of one-sided correspondence sets into `all_corres` (CUDA kernel).

    rot : reference --> rotated
    corres: pair of (p, s): s is a 2-D score map and p holds 4 coordinates per cell of s
            (it is reshaped to (*s.shape, 4)).
    rots: pair of transforms, one per side, moved to the device of all_corres.
    all_corres: pair whose element [1] is the per-side accumulator passed to the kernel
            (presumably updated in place — TODO confirm in core.cuda_deepm).
    code: value stored next to every score as an extra channel (presumably an
            encode_scale_rot code — TODO confirm with callers).
    """
    all_step = self.matcher.pixel_desc.get_atomic_patch_size() // 2 # step size in all_corres
    dev = all_corres[0][1].device
    # stack correspondences: per side a (H, W, 6) tensor = 4 coordinates, score, code
    corres = [torch.cat((p.view(*s.shape,4),s[:,:,None],torch.full_like(s[:,:,None],code)),dim=2) for (p,s) in corres]
    import core.cuda_deepm as kernels # must be imported after torch_set_gpu()
    # second argument 0 / 2: presumably the column offset of the side's own coordinates
    # (x1,y1 vs x2,y2) — TODO confirm in the kernel
    kernels.merge_corres_one_side( corres[0].to(dev), 0, rots[0].to(dev), all_corres[0][1], all_step )
    kernels.merge_corres_one_side( corres[1].to(dev), 2, rots[1].to(dev), all_corres[1][1], all_step )