from pdb import set_trace as bb
from PIL import Image
import numpy as np

from core import functional as myF
from tools.common import todevice
from .transforms import instanciate_transforms
from .utils import *


class FastPairLoader (DatasetWithRng):
    """ On-the-fly generation of related image pairs.

    crop:         size of the square crop applied to both images
    transform:    random transforms applied to img2
    p_flip:       probability of horizontally flipping the pair
    p_swap:       probability of swapping img1 and img2
    scale_jitter: random scale jittering applied to img2

    self[idx] returns: (img1, img2), dict(homography=..., corres=...)
        (homography: 3x3 array, can be nan)
    """
    def __init__(self, dataset, crop=256, transform='', p_flip=0, p_swap=0, scale_jitter=0, seed=None):
        super().__init__(seed)
        self.dataset = self.with_same_rng(dataset)
        self.transform = instanciate_transforms( transform, rng=self.rng )
        self.crop_size = crop
        self.p_swap = p_swap
        self.p_flip = p_flip
        self.scale_jitter = abs(np.log1p(scale_jitter))

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        fmt_str = f'FastPairLoader({self.dataset},\n'
        short_repr = lambda s: repr(s).strip().replace('\n',', ')[14:-1].replace('    ',' ')
        fmt_str += ' Transform:\t%s\n' % short_repr(self.transform)
        fmt_str += f' Crop={self.crop_size}, scale_jitter=x{np.exp(self.scale_jitter):g}, p_swap={self.p_swap:g}'
        return fmt_str

    def init_worker(self, tid):
        super().init_worker(tid)
        self.dataset.init_worker(tid)

    def set_epoch(self, epoch):
        self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        self.init_worker(idx)
        (img1, img2), gt = self.dataset[idx]

        if self.rng.random() < self.p_swap:
            # swap the roles of img1 and img2, updating the ground truth accordingly
            img1, img2 = img2, img1
            if 'homography' in gt: gt['homography'] = invh(gt['homography'])
            if 'corres' in gt: gt['corres'] = swap_corres(gt['corres'])

        if self.rng.random() < self.p_flip:
            img1, img2, gt = flip_image_pair(img1, img2, gt)

        img2 = self.transform(dict(img=img2))

        homography, corres = spatial_relationship( img1, img2, gt )

        img1, img2 = map(self._pad_rgb_numpy, (img1, img2['img']))

        if False:  # debug visualization of the raw correspondences
            from tools.viz import show_correspondences
            print(np.median(corres[:,5]))
            show_correspondences(img1, img2, corres, bb=bb)

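        # Given the index of a correspondence, build one crop window per image, centered
        # on the corresponding points. The window is enlarged in the image where the scene
        # appears at the larger scale (column 6 of corres holds the relative scale), so
        # that both crops roughly cover the same scene area.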
        def windows_from_corres( idx, scale_jitter=1 ):
            c = corres[idx]
            p1, p2, scale = c[0:2], c[2:4], c[6]
            scale *= scale_jitter

            win1 = window(*p1, self.crop_size, max(1, 1/scale), img1.shape)
            win2 = window(*p2, self.crop_size, max(1, scale/1), img2.shape)
            return win1, win2

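        # Sample up to 5 candidate correspondences and keep the one whose pair of crop
        # windows retains the largest number of mutually visible correspondences.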
        best = 0, None
        for idx in self.rng.choice(len(corres), size=min(len(corres),5), replace=False):
            win1, win2 = windows_from_corres( idx )

            score = score_windows(is_in(corres[:,0:2],win1), is_in(corres[:,2:4],win2))
            if score > best[0]: best = score, idx

        others = {}
        if None in best:
            # no valid window found: return a dummy black pair with no correspondences
            img1 = img2 = np.zeros((self.crop_size,self.crop_size,3), dtype=np.uint8)
            corres = np.empty((0, 6), dtype=np.float32)
        else:
            # crop both images around the selected correspondence, with random scale jittering
            scale_jitter = np.exp(self.rng.uniform(-self.scale_jitter, self.scale_jitter))
            win1, win2 = windows_from_corres( best[1], scale_jitter )

            img1, img2 = imresize(img1[win1], self.crop_size), imresize(img2[win2], self.crop_size)
            trf1, trf2 = wintrf(win1, img1), wintrf(win2, img2)

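            # Estimate the dominant relative rotation from the encoded scale/rotation codes
            # (assuming 8 rotation bins of 45 degrees in column 5) and pre-rotate img2 by a
            # multiple of 90 degrees so that the residual rotation stays roughly within +/-45 degrees.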
            angle_scores = np.bincount(corres[:,5].astype(int) % 8)
            rot90 = int((((angle_scores.argmax() + 4) % 8) - 4) / 2)
            if rot90:
                img2, trf = myF.rotate_img_90((img2, np.eye(3)), 90*rot90)
                trf2 = invh(trf) @ trf2

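            # re-express the ground-truth homography and correspondences in the
            # coordinate frames of the two crops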
            homography = trf2 @ homography @ invh(trf1)
            corres = myF.affmul((trf1,trf2), corres)

        # ensure contiguous, aligned and writeable arrays before collating
        f32c = lambda i,**kw: np.require(i, requirements='CWAE', **kw)
        return (f32c(img1), f32c(img2)), dict(homography = f32c(homography, dtype=np.float32), corres=corres, **others)

    def _pad_rgb_numpy(self, img):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if min(img.size) < self.crop_size:
            w, h = img.size
            result = Image.new('RGB', (max(w,self.crop_size), max(h,self.crop_size)), 0)
            result.paste(img, (0, 0))
            img = result
        return np.asarray(img)


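# Exchange the roles of the two images inside a correspondence array: swap the (x1,y1)
# and (x2,y2) columns, and invert the scale / negate the rotation in the encoded
# scale/rot column when it is present.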
def swap_corres( corres ):
    res = corres.copy()
    res[:,[0,1,2,3]] = corres[:,[2,3,0,1]]
    if corres.shape[1] > 4:
        scale, rot = myF.decode_scale_rot(corres[:,5])
        res[:,5] = myF.encode_scale_rot(1/scale, -rot)
    return res


def flip(img):
    # horizontal flip, returned together with the corresponding 3x3 pixel transform
    w, h = img.size
    return img.transpose(Image.FLIP_LEFT_RIGHT), np.float32( [[-1,0,w-1],[0,1,0],[0,0,1]] )


def flip_image_pair(img1, img2, gt):
    img1, F1 = flip(img1)
    img2, F2 = flip(img2)
    res = {}
    for key, value in gt.items():
        if key == 'homography':
            res['homography'] = F2 @ value @ F1
        elif key == 'aflow':
            assert False, 'flip for aflow: todo'
        elif key == 'corres':
            new_corres = np.c_[applyh(F1,value[:,0:2]), applyh(F2,value[:,2:4])]
            if value.shape[1] == 4: pass
            elif value.shape[1] == 6:
                scale, rot = myF.decode_scale_rot(value[:,5])
                new_code = myF.encode_scale_rot(scale, -rot)
                new_corres = np.c_[new_corres,value[:,4],new_code]
            res['corres'] = new_corres
        else:
            raise ValueError(f"flip_image_pair: bad gt field '{key}'")
    return img1, img2, res


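# Express the ground-truth relation between img1 and the (possibly re-transformed) img2
# as a homography (all-NaN when only sparse correspondences are known) plus a 7-column
# correspondence array, roughly (x1, y1, x2, y2, per-match weight, encoded scale/rot
# code, relative scale).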
def spatial_relationship( img1, img2, gt ):
    if 'homography' in gt:
        homography = gt['homography']
        if 'homography' in img2:  # img2 was further transformed
            homography = np.float32(img2['homography']) @ homography
        corres = corres_from_homography(homography, *img1.size)

    elif 'corres' in gt:
        homography = np.full((3,3), np.nan, dtype=np.float32)
        corres = gt['corres']
        if 'homography' in img2:  # img2 was further transformed
            corres[:,2:4] = applyh(img2['homography'], corres[:,2:4])
        else:
            img2['homography'] = np.eye(3)
        scales = np.sqrt(np.abs(np.linalg.det(jacobianh(img2['homography'], corres[:,0:2]).T)))

        if corres.shape[1] == 4:  # scale and rotation are missing, estimate them
            scales, rots = scale_rot_from_corres(corres)
            corres = np.c_[corres, np.ones_like(scales), myF.encode_scale_rot(scales,rots*180/np.pi), scales]
        elif corres.shape[1] == 6:
            corres = np.c_[corres, scales * myF.decode_scale_rot(corres[:,5])[0]]
        else:
            raise ValueError(f'bad shape for corres: {corres.shape}')

    return homography, corres


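# Estimate a per-correspondence relative scale and rotation when they are not provided:
# for each correspondence, fit a local affine transform by least squares over its `nn`
# nearest neighbours (taken from a random subset of at most `sub` correspondences), then
# read the scale and rotation off that affine map.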
def scale_rot_from_corres( corres, sub=256, nn=16 ):
    sub = np.random.choice(len(corres), size=min(len(corres),sub), replace=False)
    sub = corres[sub]

    from scipy.spatial.distance import cdist
    nns = cdist(corres, sub, metric='sqeuclidean').argsort(axis=1)[:,:nn]

    pts = sub[nns]

    P1, P2 = pts[:,:,0:2], pts[:,:,2:4]
    P1 = np.concatenate((P1,np.ones_like(P1[:,:,:1])),axis=-1)
    A = (np.linalg.pinv(P1) @ P2).transpose(0,2,1)

    scale, (angy,angx) = detect_scale_rotation(A.transpose(1,2,0)[:,1::-1])
    rot = np.arctan2(angy, angx)
    return scale.clip(min=0.2, max=5), rot


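# window1 / window: build slices for a crop of size win_size*scale centered on (cx, cy),
# shifted (and if necessary truncated) so that it fits inside the image bounds.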
def window1(x, size, w):
    l = x - int(0.5 + size / 2)
    r = l + int(0.5 + size)
    if l < 0: l,r = (0, r - l)
    if r > w: l,r = (l + w - r, w)
    if l < 0: l,r = 0,w
    return slice(l,r)


def window(cx, cy, win_size, scale, img_shape):
    return (window1(int(cy), win_size*scale, img_shape[0]),
            window1(int(cx), win_size*scale, img_shape[1]))


def is_in( pts, window ):
    x, y = pts.T
    sly, slx = window
    return (slx.start <= x) & (x < slx.stop) & (sly.start <= y) & (y < sly.stop)


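# Score a pair of candidate windows: the number of correspondences visible in both crops,
# weighted by the worse of the two coverage ratios (so both crops must be well covered).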
def score_windows( valid1, valid2 ):
    inter = (valid1 & valid2).sum()
    iou1 = inter / (valid1.sum() + 1e-8)
    iou2 = inter / (valid2.sum() + 1e-8)
    return inter * min(iou1, iou2)


def imresize( img, max_size, resample=Image.ANTIALIAS):
    if max(img.shape[:2]) > max_size:
        if img.shape[-1] == 2:
            # 2-channel arrays cannot be handled by PIL directly: resize each channel separately
            img = np.stack([np.float32(Image.fromarray(img[...,i]).resize((max_size,max_size), resample=resample)) for i in range(2)], axis=-1)
        else:
            img = np.asarray(Image.fromarray(img).resize((max_size,max_size), resample=resample))
    assert img.shape[0] == img.shape[1] == max_size, bb()
    return img


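# Homography mapping coordinates in the original image to coordinates in the
# cropped-and-resized image described by `window` and `final_img`.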
def wintrf( window, final_img ):
    wy, wx = window
    H, W = final_img.shape[:2]
    T = np.float32((((wx.stop-wx.start)/W, 0, wx.start),
                    (0, (wy.stop-wy.start)/H, wy.start),
                    (0, 0, 1)) )
    return invh(T)


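# Collate a batch so that supervised pairs (finite ground-truth homography) come first,
# followed by pairs with a NaN homography (only correspondences available); 'corres'
# entries are returned as a list since they have variable length.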
def collate_ordered(batch, _use_shared_memory=True):
    pairs, gt = zip(*batch)
    imgs1, imgs2 = zip(*pairs)
    assert len(imgs1) == len(imgs2) == len(gt) and isinstance(gt[0], dict)

    supervised = [i for i,b in enumerate(gt) if np.isfinite(b['homography']).all()]
    unsupervised = [i for i,b in enumerate(gt) if np.isnan(b['homography']).any()]
    order = supervised + unsupervised

    def collate( tensors, key=None ):
        import torch
        batch = todevice([tensors[i] for i in order], 'cpu')
        if key == 'corres': return batch
        out = None
        if _use_shared_memory:
            # stack directly into shared memory to avoid an extra copy in worker processes
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, dim=0, out=out)

    return (collate(imgs1), collate(imgs2)), {k:collate([b[k] for b in gt],k) for k in gt[0]}


if __name__ == '__main__':
    from datasets import *
    from tools.viz import show_random_pairs

    db = BalancedCatImagePairs(
            3125, SyntheticImagePairs(RandomWebImages(0,52),distort='RandomTilting(0.5)'),
            4875, SyntheticImagePairs(SfM120k_Images(),distort='RandomTilting(0.5)'),
            8000, SfM120k_Pairs())

    db = FastPairLoader(db,
            crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise()',
            p_swap=0.5, p_flip=0.5, scale_jitter=0, seed=777)

    show_random_pairs(db)