|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def affmul( aff, vecs ): |
|
""" affine multiplication: |
|
computes aff @ vecs.T """ |
|
if aff is None: return vecs |
|
if isinstance(aff, (tuple,list)) or aff.ndim==3: |
|
assert len(aff) == 2 |
|
assert 4 <= vecs.shape[-1], bb() |
|
vecs = vecs.clone() if isinstance(vecs, torch.Tensor) else vecs.copy() |
|
vecs[...,0:2] = affmul(aff[0], vecs[...,0:2]) |
|
vecs[...,2:4] = affmul(aff[1], vecs[...,2:4]) |
|
return vecs |
|
else: |
|
assert vecs.shape[-1] == 2, bb() |
|
assert aff.shape == (2,3) or (aff.shape==(3,3) and |
|
aff[2,0] == aff[2,1] == 0 and aff[2,2] == 1), bb() |
|
return (vecs @ aff[:2,:2].T) + aff[:2,2] |
|
|
|
|
|
def imresize( img, max_size, mode='area' ): |
|
|
|
img, trf = img if isinstance(img,tuple) else (img, torch.eye(3,device=img.device)) |
|
|
|
shape = img.shape[-2:] |
|
if max_size > 0 and max(shape) > max_size: |
|
new_shape = tuple(i * max_size // max(shape) for i in shape) |
|
img = F.interpolate( img[None].float(), size=new_shape, mode=mode )[0] |
|
img.clamp_(min=0, max=255) |
|
sca = torch.diag(torch.tensor((shape[0]/new_shape[0],shape[1]/new_shape[1],1), device=img.device)) |
|
img = img.byte() |
|
trf = trf @ sca |
|
|
|
return img, trf |
|
|
|
|
|
def rotate_img( img, angle, crop=False ): |
|
if angle in (0, 90, 180, 270): |
|
return rotate_img_90(img,angle) |
|
|
|
img, trf = img |
|
assert trf.shape == (3,3) |
|
|
|
def centered_rotation(rotation, shape, **device): |
|
|
|
|
|
angle = rotation * np.pi / 180 |
|
c, s = np.cos(angle), np.sin(angle) |
|
rot = torch.tensor([(c, -s, 0), (s, c, 0), (0, 0, 1)], dtype=torch.float32, **device) |
|
|
|
|
|
H, W = shape |
|
c_before = torch.tensor((W,H), **device) / 2 |
|
if crop: |
|
c_after = c_before |
|
rot_size = (W,H) |
|
else: |
|
|
|
corners = torch.tensor([(0, W, W, 0), (0, 0, H, H)], dtype=torch.float32, **device) |
|
corners = affmul(rot, corners.T).T |
|
rot_size = (corners.max(dim=1).values - corners.min(dim=1).values + 0.5).int() |
|
rot_size = (rot_size // 4) * 4 |
|
c_after = rot_size / 2 |
|
|
|
rot[:2,2] = c_before - affmul(rot, c_after) |
|
return rot, tuple(rot_size)[::-1] |
|
|
|
C, H, W = img.shape |
|
rot, (OH, OW) = centered_rotation(angle, (H,W), device=img.device) |
|
|
|
|
|
|
|
|
|
_1_1_to_px = lambda W,H: torch.tensor(((W/2, 0, W/2), (0, H/2, H/2), (0, 0, 1)), device=img.device) |
|
theta = torch.inverse(_1_1_to_px(W-1,H-1)) @ rot @ _1_1_to_px(OW-1,OH-1) |
|
|
|
grid = F.affine_grid(theta[None,:2], (1, C, OH, OW), align_corners=True) |
|
res = F.grid_sample(img[None].float(), grid, align_corners=True).to(dtype=img.dtype)[0] |
|
return res, trf @ rot |
|
|
|
|
|
|
|
def rotate_img_90( img, angle ): |
|
""" Rotate an image by a multiple of 90 degrees using simple transpose and flip ops. |
|
img = tuple( image, existing_trf ) |
|
existing_trf: current --> old |
|
""" |
|
angle = angle % 360 |
|
assert angle in (0, 90, 180, 270), 'cannot handle rotation other than multiple of 90 degrees' |
|
img, trf = img |
|
assert trf.shape == (3,3) |
|
|
|
if isinstance(img, np.ndarray): |
|
assert img.ndim == 3 and 1 <= img.shape[2] <= 3 |
|
new, x, y = np.float32, 1, 0 |
|
flip = lambda i,d: np.flip(i,axis=d) |
|
elif isinstance(img, torch.Tensor): |
|
assert img.ndim == 3 and 1 <= img.shape[0] <= 3 |
|
new, x, y = trf.new, -1, -2 |
|
flip = lambda i,d: i.flip(dims=[d]) |
|
H, W = img.shape[y], img.shape[x] |
|
|
|
if angle == 90: |
|
|
|
img = flip(img.swapaxes(x,y),y) |
|
trf = trf @ new([[0,-1,W-1],[1,0,0],[0,0,1]]) |
|
if angle == 180: |
|
|
|
img = flip(flip(img,x),y) |
|
trf = trf @ new([[-1,0,W-1],[0,-1,H-1],[0,0,1]]) |
|
if angle == 270: |
|
|
|
img = flip(img.swapaxes(x,y),x) |
|
trf = trf @ new([[0,1,0],[-1,0,H-1],[0,0,1]]) |
|
return img, trf |
|
|
|
|
|
def encode_scale_rot(scale, rot): |
|
s = np.int32(np.rint(np.log(scale) / (0.5*np.log(2)))) |
|
r = np.int32(np.rint(((-rot) % 360) / 45)) % 8 |
|
return 8*s + (r%8) |
|
|
|
def decode_scale_rot( code ): |
|
s = code // 8 |
|
r = (code % 8) |
|
return 2 ** (s/2), -((45 * r + 180) % 360 - 180) |
|
|
|
|
|
def normalized_corr(patches, img, padding='ncc', extra_patch=False, ret_norms=False): |
|
assert patches.ndim == 4, 'patches shape must be (H*W, C, K, K)' |
|
P, C, K, K = patches.shape |
|
assert img.ndim == 3 and img.shape[0] == C, 'img shape must be (C, W, H)' |
|
eps = torch.finfo(patches.dtype).tiny |
|
|
|
|
|
norms = patches.view(P,-1).norm(dim=-1) |
|
patches = patches / norms[:,None,None,None].clamp(min=eps) |
|
|
|
|
|
ninth = 0 |
|
if padding == 'ninth': |
|
ninth = img[:,-1].mean() |
|
img = F.pad(img[None], (K//2,K//2)*2, mode='constant', value=ninth)[0] |
|
|
|
corr = F.conv2d(img[None], patches, padding=0, bias=None)[0] |
|
|
|
|
|
ones = patches.new_ones((1, C, K, K)) |
|
local_norm = torch.sqrt(F.conv2d(img[None]**2, ones))[0] |
|
corr /= local_norm |
|
|
|
|
|
if padding == 'ncc': |
|
local_norm = torch.sqrt(F.conv2d(ones, patches**2, padding=2))[0] |
|
local_norm.clamp_(min=eps) |
|
for j in range(-2, 3): |
|
for i in range(-2,3): |
|
if i == j == 2: continue |
|
if i == 2: i = slice(2,-2) |
|
if j == 2: j = slice(2,-2) |
|
corr[:,j,i] /= local_norm[:,j,i] |
|
|
|
return (corr, norms) if ret_norms else corr |
|
|
|
|
|
def true_corr_shape( corr_shape, level ): |
|
H1, W1, H2, W2 = corr_shape[-4:] |
|
if level > 0: |
|
H1, W1 = H1-1, W1-1 |
|
return corr_shape[:-4] + (H1, W1, H2, W2) |
|
|
|
def children(level, H1, W1, H2, W2): |
|
""" level: parent level (> 1) """ |
|
gap = 2**(level-2) |
|
|
|
|
|
|
|
|
|
|
|
def ravel_child(x, y): |
|
|
|
inside = (0 <= x <= W1) and (0 <= y <= H1) |
|
if gap < 1: |
|
assert x % 1 == y % 1 == 0.5, bb() |
|
return int((x-0.5) + (y-0.5) * W1) if inside else -1 |
|
else: |
|
assert x % 1 == y % 1 == 0, bb() |
|
return int(x + y * (W1+1)) if inside else -1 |
|
|
|
|
|
parents = [] |
|
for h in range(H1+1): |
|
for w in range(W1+1): |
|
|
|
children = [ravel_child(w + gap*tx, h + gap*ty) for ty in (-1,1) for tx in (-1,1)] |
|
parents.append(children) |
|
|
|
return torch.tensor(parents, dtype=torch.int64) |
|
|
|
|
|
def sparse_conv(level, corr, weights=None, reverse=False, norm=0.9): |
|
H1, W1, H2, W2 = true_corr_shape(corr.shape, level-1 + reverse) |
|
parents = children(level, H1, W1, H2, W2).to(corr.device) |
|
n_parents = len(parents) |
|
|
|
|
|
|
|
corr = corr.view(-1, *corr.shape[-2:]) |
|
if not reverse: |
|
res = corr.new_zeros((n_parents+1,)+corr.shape[-2:]) |
|
nrm = corr.new_full((n_parents+1,3,3), 1e-8) |
|
ones = nrm.new_ones((len(corr),1,1)) |
|
ex = 1 |
|
if weights is not None: |
|
weights = weights.view(len(corr),1,1) |
|
corr *= weights |
|
ones *= weights |
|
else: |
|
assert corr._base is not None and corr._base.shape[0] == n_parents+1 |
|
corr._base[-1] = 0 |
|
ex = 1 if level > 1 else 0 |
|
n_children = (H1+ex) * (W1+ex) |
|
res = corr.new_zeros((n_children,)+corr.shape[-2:]) |
|
|
|
sl = lambda v: slice(0,-1 or None) if v < 0 else slice(1,None) |
|
c = 0 |
|
for y in (-1, 1): |
|
for x in (-1, 1): |
|
src_layers = parents[:,c]; c+= 1 |
|
|
|
|
|
tgt_layers = inverse_mapping(src_layers, max_elem=len(corr), default=n_parents)[:-1] |
|
|
|
if not reverse: |
|
|
|
|
|
sel = good_slice( tgt_layers < n_parents ) |
|
|
|
res[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], corr[sel,sl(y),sl(x)]) |
|
nrm[:,sl(-y),sl(-x)].index_add_(0, tgt_layers[sel], ones[sel].expand(-1,2,2)) |
|
else: |
|
''' parent=199=11*17+12 @ (x=48, y=44) at level=1 |
|
|-- child=171 @ (x=46,y=42) at level0 |
|
|-- child=172 @ (x=50,y=42) at level0 |
|
|-- child=187 @ (x=46,y=46) at level0 |
|
|-- child=188 @ (x=50,y=46) at level0 |
|
''' |
|
out = res[:,sl(y),sl(x)] |
|
sel = tgt_layers[:n_children] |
|
torch.maximum(out, corr._base[sel,sl(-y),sl(-x)], out=out) |
|
|
|
if not reverse: |
|
if weights is not None: corr /= weights.clamp(min=1e-12) |
|
weights = norm_borders(res, nrm, norm=norm)[:-1] |
|
res = res[:-1] |
|
res = res.view(H1+ex, W1+ex, *res.shape[-2:]) |
|
return res if reverse else (res, weights) |
|
|
|
def norm_borders( res, nrm, norm=0.9 ): |
|
""" apply some border normalization, modulated by `norm` |
|
- if norm=0: no normalization at all |
|
- if norm=1: full normalization |
|
Formula: nrm = k * (nrm/k)**p = k**(1-p) * nrm**p, |
|
with k=nrm[:,1,1] and p=norm |
|
""" |
|
new_weights = nrm[...,1,1].clone() |
|
nrm = (nrm[...,1:2,1:2] ** (1-norm)) * (nrm ** norm) |
|
|
|
|
|
|
|
res[...,0 ,0 ] /= nrm[...,0 ,0 ] |
|
res[...,0 ,1:-1] /= nrm[...,0 ,1:2] |
|
res[...,0 , -1] /= nrm[...,0 ,2 ] |
|
res[...,1:-1,0 ] /= nrm[...,1:2,0 ] |
|
res[...,1:-1,1:-1] /= nrm[...,1:2,1:2] |
|
res[...,1:-1, -1] /= nrm[...,1:2,2 ] |
|
res[..., -1,0 ] /= nrm[...,2 ,0 ] |
|
res[..., -1,1:-1] /= nrm[...,2 ,1:2] |
|
res[..., -1, -1] /= nrm[...,2 ,2 ] |
|
return new_weights |
|
|
|
|
|
def inverse_mapping( map, max_elem=None, default=None): |
|
""" given a mapping {i:j} we output {j:i} |
|
(the mapping is a torch array) |
|
""" |
|
assert isinstance(map, torch.Tensor) and map.ndim == 1 |
|
if max_elem is None: max_elem = map.max() |
|
if default is None: |
|
index = torch.empty(max_elem+1, dtype=torch.int64, device=map.device) |
|
else: |
|
index = torch.full((max_elem+1,), default, dtype=torch.int64, device=map.device) |
|
index[map] = torch.arange(len(map), device=map.device) |
|
return index |
|
|
|
|
|
def good_slice( nonzero ): |
|
good = nonzero.nonzero().ravel() |
|
return slice(good.min().item(), good.max().item()+1) |
|
|
|
|
|
def max_unpool(upper, lower, exclude_border=True): |
|
|
|
if exclude_border: |
|
|
|
_, pos = F.max_pool2d(lower[:,:,:-1,:-1], 3, padding=1, stride=2, return_indices=True, ceil_mode=True) |
|
W1 = lower.shape[-1] |
|
pos = (pos//(W1-1))*W1 + (pos%(W1-1)) |
|
else: |
|
_, pos = F.max_pool2d(lower, 3, padding=1, stride=2, return_indices=True) |
|
|
|
|
|
|
|
for i in range(2): |
|
for j in range(2): |
|
|
|
tmp = F.max_unpool2d(upper[:,:,i::2,j::2], pos[:,:,i::2,j::2], kernel_size=3, padding=0, stride=4, output_size=lower.shape[-2:]) |
|
if i == j == 0: |
|
res = tmp |
|
else: |
|
torch.maximum(res, tmp, out=res) |
|
|
|
|
|
lower += res |
|
return lower |
|
|
|
|
|
def mgrid( shape, **kw ): |
|
""" Returns in (x, y) order (contrary to numpy which is (y,x) """ |
|
if isinstance(shape, torch.Tensor): shape = shape.shape |
|
res = torch.meshgrid(*[torch.arange(n, dtype=torch.float32, **kw) for n in shape], indexing='ij') |
|
return torch.stack(res[::-1], dim=-1).view(-1,2) |
|
|
|
|
|
def check_corres( corres, step, rot=None ): |
|
H, W, two = corres.shape |
|
assert two == 2 |
|
if isinstance(corres, np.ndarray): |
|
corres = torch.from_numpy(corres) |
|
if rot is not None: |
|
corres = affmul(rot, corres) |
|
gt = mgrid(corres.shape[:2]).view(H,W,2) |
|
assert ((gt - corres // step).abs() <= 2).float().mean() > 0.99, bb() |
|
|
|
|
|
def best_correspondences(corr): |
|
""" All positions are returned as x1, y1, x2, y2 |
|
""" |
|
if isinstance(corr, tuple): return corr |
|
H1, W1, H2, W2 = corr.shape |
|
fix1 = lambda arr: 4*arr+2 |
|
div = lambda a,b: torch.div(a, b, rounding_mode='trunc') |
|
|
|
|
|
score1, pos1 = corr.view(H1, W1, H2*W2).max(dim=-1) |
|
pos1 = torch.cat((fix1(mgrid(score1, device=pos1.device)), pos1.view(-1,1)%W2, div(pos1.view(-1,1),W2)), dim=-1) |
|
|
|
|
|
score2, pos2 = max_pool3d( corr, kernel_size=4, stride=4 ) |
|
pos2, score2 = pos2.view(-1,1), score2.squeeze() |
|
pos2 = torch.cat((fix1(div(pos2,W2*H2)%W1), fix1(div(pos2,(W1*H2*W2))), pos2%W2, div(pos2,W2)%H2), dim=-1).float() |
|
|
|
return (pos1, score1), (pos2, score2) |
|
|
|
|
|
def intersection( set1_, set2_ ): |
|
""" Returns the indices of values in set1 that are duplicated in set2 |
|
""" |
|
set1, map1 = set1_.squeeze().unique(return_inverse=True) |
|
set2 = set2_.squeeze().unique() |
|
combined = torch.cat((set1, set2)) |
|
|
|
uniques, inverse, counts = combined.unique(return_counts=True, return_inverse=True) |
|
|
|
|
|
|
|
|
|
inverse1 = inverse_mapping(inverse[:len(set1)], max_elem=len(uniques)-1) |
|
intersected_indices1 = inverse1[counts>1] |
|
return inverse_mapping(map1, max_elem=len(set1)-1)[intersected_indices1] |
|
|
|
|
|
def reciprocal(self, corres1, corres2 ): |
|
pos1, score1 = corres1 |
|
pos2, score2 = corres2 |
|
(H1, W1), (H2, W2) = score1.shape, map(lambda i: 4*i+1, score2.shape) |
|
|
|
to_int = pos1.new_tensor((W1*H2*W2, H2*W2, W2, 1), dtype=torch.float32) |
|
inter1 = intersection(pos1@to_int, pos2@to_int) |
|
res = torch.cat((pos1[inter1], score1.view(-1,1)[inter1], 0*score1.view(-1,1)[inter1]), dim=-1) |
|
return res |
|
|
|
|
|
def max_pool3d( corr, kernel_size=4, stride=4 ): |
|
H1, W1, H2, W2 = corr.shape |
|
ks, st = kernel_size, stride |
|
if corr.numel() >= 2**31 and corr.device != torch.device('cpu'): |
|
|
|
import core.cuda_deepm as kernels |
|
return kernels.max_pool3d( corr.view(1, H1*W1, H2, W2), kernel_size, stride) |
|
else: |
|
return F.max_pool3d( corr.view(1, 1, H1*W1, H2, W2), kernel_size=(H1*W1,ks,ks), stride=(1,st,st), return_indices=True) |
|
|
|
|
|
def forward_cuda(self, level, lower, weights=None, pooled=False): |
|
import core.cuda_deepm as kernels |
|
assert lower.numel() < 2**31, 'please use cuda-lowmem, pytorch cannot handle big tensors' |
|
pooled = lower if pooled else F.max_pool2d(lower, 3, padding=1, stride=2) |
|
return kernels.forward_agg(level, self.border_inv, pooled, weights) |
|
|
|
def forward_cuda_lowmem(self, level, lower, weights=None): |
|
import core.cuda_deepm as kernels |
|
return kernels.forward_pool_agg(level, self.border_inv, lower, weights) |
|
|
|
def backward_cuda(self, level, pyramid): |
|
import core.cuda_deepm as kernels |
|
kernels.backward_agg_unpool(level, pyramid[level], pyramid[level-1], True) |
|
|
|
return pyramid[level-1] |
|
|
|
def merge_corres(self, corres, rots, all_corres, code): |
|
" rot : reference --> rotated " |
|
all_step = self.matcher.pixel_desc.get_atomic_patch_size() // 2 |
|
dev = all_corres[0][1].device |
|
|
|
|
|
corres = [torch.cat((p.view(*s.shape,4),s[:,:,None],torch.full_like(s[:,:,None],code)),dim=2) for (p,s) in corres] |
|
|
|
import core.cuda_deepm as kernels |
|
kernels.merge_corres_one_side( corres[0].to(dev), 0, rots[0].to(dev), all_corres[0][1], all_step ) |
|
kernels.merge_corres_one_side( corres[1].to(dev), 2, rots[1].to(dev), all_corres[1][1], all_step ) |
|
|
|
|