Spaces:
Running
Running
xuelunshen
commited on
Commit
·
1bfbd08
1
Parent(s):
8d7cbc7
update: gim code
Browse files- .gitignore +3 -0
- common/utils.py +1 -0
- hloc/match_dense.py +17 -0
- hloc/matchers/gim.py +121 -0
- hloc/matchers/networks/dkm/__init__.py +4 -0
- hloc/matchers/networks/dkm/datasets/__init__.py +1 -0
- hloc/matchers/networks/dkm/datasets/megadepth.py +177 -0
- hloc/matchers/networks/dkm/datasets/scannet.py +151 -0
- hloc/matchers/networks/dkm/models/__init__.py +4 -0
- hloc/matchers/networks/dkm/models/dkm.py +751 -0
- hloc/matchers/networks/dkm/models/encoders.py +148 -0
- hloc/matchers/networks/dkm/models/model_zoo/DKMv3.py +145 -0
- hloc/matchers/networks/dkm/models/model_zoo/__init__.py +39 -0
- hloc/matchers/networks/dkm/utils/__init__.py +13 -0
- hloc/matchers/networks/dkm/utils/kde.py +26 -0
- hloc/matchers/networks/dkm/utils/local_correlation.py +40 -0
- hloc/matchers/networks/dkm/utils/transforms.py +104 -0
- hloc/matchers/networks/dkm/utils/utils.py +341 -0
.gitignore
CHANGED
@@ -21,3 +21,6 @@ gradio_cached_examples
|
|
21 |
hloc/matchers/quadtree.py
|
22 |
third_party/QuadTreeAttention
|
23 |
desktop.ini
|
|
|
|
|
|
|
|
21 |
hloc/matchers/quadtree.py
|
22 |
third_party/QuadTreeAttention
|
23 |
desktop.ini
|
24 |
+
|
25 |
+
*/.DS_Store
|
26 |
+
.DS_Store
|
common/utils.py
CHANGED
@@ -448,6 +448,7 @@ ransac_zoo = {
|
|
448 |
|
449 |
# Matchers collections
|
450 |
matcher_zoo = {
|
|
|
451 |
"gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
|
452 |
"sold2": {"config": match_dense.confs["sold2"], "dense": True},
|
453 |
# 'dedode-sparse': {
|
|
|
448 |
|
449 |
# Matchers collections
|
450 |
matcher_zoo = {
|
451 |
+
"gim": {"config": match_dense.confs["gim"], "dense": True},
|
452 |
"gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
|
453 |
"sold2": {"config": match_dense.confs["sold2"], "dense": True},
|
454 |
# 'dedode-sparse': {
|
hloc/match_dense.py
CHANGED
@@ -9,6 +9,23 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
9 |
|
10 |
confs = {
|
11 |
# Best quality but loads of points. Only use for small scenes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"loftr": {
|
13 |
"output": "matches-loftr",
|
14 |
"model": {
|
|
|
9 |
|
10 |
confs = {
|
11 |
# Best quality but loads of points. Only use for small scenes
|
12 |
+
"gim": {
|
13 |
+
"output": "matches-gim",
|
14 |
+
"model": {
|
15 |
+
"name": "gim",
|
16 |
+
"weights": "gim_dkm_100h.ckpt",
|
17 |
+
"max_keypoints": 2000,
|
18 |
+
"match_threshold": 0.2,
|
19 |
+
},
|
20 |
+
"preprocessing": {
|
21 |
+
"grayscale": False,
|
22 |
+
"force_resize": True,
|
23 |
+
"resize_max": 1024,
|
24 |
+
"width": 80,
|
25 |
+
"height": 60,
|
26 |
+
"dfactor": 8,
|
27 |
+
},
|
28 |
+
},
|
29 |
"loftr": {
|
30 |
"output": "matches-loftr",
|
31 |
"model": {
|
hloc/matchers/gim.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
from pathlib import Path
|
6 |
+
from ..utils.base_model import BaseModel
|
7 |
+
from .. import logger
|
8 |
+
|
9 |
+
from .networks.dkm.models.model_zoo.DKMv3 import DKMv3
|
10 |
+
|
11 |
+
weight_path = Path(__file__).parent / 'networks' / 'dkm'
|
12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
|
14 |
+
|
15 |
+
class GIM(BaseModel):
|
16 |
+
default_conf = {
|
17 |
+
"model_name": "gim_dkm_100h.ckpt",
|
18 |
+
"match_threshold": 0.2,
|
19 |
+
"checkpoint_dir": weight_path,
|
20 |
+
}
|
21 |
+
required_inputs = [
|
22 |
+
"image0",
|
23 |
+
"image1",
|
24 |
+
]
|
25 |
+
# Models exported using
|
26 |
+
# dkm_models = {
|
27 |
+
# "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
|
28 |
+
# "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
|
29 |
+
# }
|
30 |
+
|
31 |
+
def _init(self, conf):
|
32 |
+
model_path = weight_path / conf["model_name"]
|
33 |
+
|
34 |
+
# Download the model.
|
35 |
+
if not model_path.exists():
|
36 |
+
model_path.parent.mkdir(exist_ok=True)
|
37 |
+
link = self.dkm_models[conf["model_name"]]
|
38 |
+
cmd = ["wget", link, "-O", str(model_path)]
|
39 |
+
logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
|
40 |
+
subprocess.run(cmd, check=True)
|
41 |
+
logger.info(f"Loading DKMv3 model...")
|
42 |
+
# self.net = DKMv3(path_to_weights=str(model_path), device=device)
|
43 |
+
|
44 |
+
model = DKMv3(None, 672, 896, upsample_preds=True)
|
45 |
+
|
46 |
+
checkpoints_path = join('checkpoints', conf['weights'])
|
47 |
+
state_dict = torch.load(checkpoints_path, map_location='cpu')
|
48 |
+
if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
|
49 |
+
for k in list(state_dict.keys()):
|
50 |
+
if k.startswith('model.'):
|
51 |
+
state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
|
52 |
+
if 'encoder.net.fc' in k:
|
53 |
+
state_dict.pop(k)
|
54 |
+
model.load_state_dict(state_dict)
|
55 |
+
|
56 |
+
self.net = model
|
57 |
+
|
58 |
+
def _forward(self, data):
|
59 |
+
# img0 = data["image0"].cpu().numpy().squeeze() * 255
|
60 |
+
# img1 = data["image1"].cpu().numpy().squeeze() * 255
|
61 |
+
# img0 = img0.transpose(1, 2, 0)
|
62 |
+
# img1 = img1.transpose(1, 2, 0)
|
63 |
+
# img0 = Image.fromarray(img0.astype("uint8"))
|
64 |
+
# img1 = Image.fromarray(img1.astype("uint8"))
|
65 |
+
# W_A, H_A = img0.size
|
66 |
+
# W_B, H_B = img1.size
|
67 |
+
#
|
68 |
+
# warp, certainty = self.net.match(img0, img1, device=device)
|
69 |
+
# matches, certainty = self.net.sample(warp, certainty)
|
70 |
+
# kpts1, kpts2 = self.net.to_pixel_coordinates(
|
71 |
+
# matches, H_A, W_A, H_B, W_B
|
72 |
+
# )
|
73 |
+
|
74 |
+
image0, image1 = data['image0'], data['image1']
|
75 |
+
orig_width = image0.shape[3]
|
76 |
+
orig_height = image0.shape[2]
|
77 |
+
aspect_ratio = 896 / 672
|
78 |
+
new_width = max(orig_width, int(orig_height * aspect_ratio))
|
79 |
+
new_height = max(orig_height, int(orig_width / aspect_ratio))
|
80 |
+
pad_height = new_height - orig_height
|
81 |
+
pad_width = new_width - orig_width
|
82 |
+
pad_top = pad_height // 2
|
83 |
+
pad_bottom = pad_height - pad_top
|
84 |
+
pad_left = pad_width // 2
|
85 |
+
pad_right = pad_width - pad_left
|
86 |
+
image0 = torch.nn.functional.pad(image0, (pad_left, pad_right, pad_top, pad_bottom))
|
87 |
+
image1 = torch.nn.functional.pad(image1, (pad_left, pad_right, pad_top, pad_bottom))
|
88 |
+
dense_matches, dense_certainty = self.net.match(image0, image1)
|
89 |
+
sparse_matches, mconf = self.net.sample(dense_matches, dense_certainty, 2048)
|
90 |
+
height0, width0 = image0.shape[-2:]
|
91 |
+
height1, width1 = image1.shape[-2:]
|
92 |
+
kpts0 = sparse_matches[:, :2]
|
93 |
+
kpts1 = sparse_matches[:, 2:]
|
94 |
+
kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1, )
|
95 |
+
kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1, )
|
96 |
+
b_ids, i_ids = torch.where(mconf[None])
|
97 |
+
# before padding
|
98 |
+
kpts0 -= kpts0.new_tensor((pad_left, pad_top))[None]
|
99 |
+
kpts1 -= kpts1.new_tensor((pad_left, pad_top))[None]
|
100 |
+
mask = (kpts0[:, 0] > 0) & \
|
101 |
+
(kpts0[:, 1] > 0) & \
|
102 |
+
(kpts1[:, 0] > 0) & \
|
103 |
+
(kpts1[:, 1] > 0)
|
104 |
+
mask = mask & \
|
105 |
+
(kpts0[:, 0] <= (orig_width - 1)) & \
|
106 |
+
(kpts1[:, 0] <= (orig_width - 1)) & \
|
107 |
+
(kpts0[:, 1] <= (orig_height - 1)) & \
|
108 |
+
(kpts1[:, 1] <= (orig_height - 1))
|
109 |
+
pred = {
|
110 |
+
'keypoints0': kpts0[i_ids],
|
111 |
+
'keypoints1': kpts1[i_ids],
|
112 |
+
'confidence': mconf[i_ids],
|
113 |
+
'batch_indexes': b_ids,
|
114 |
+
}
|
115 |
+
scores, b_ids = pred['confidence'], pred['batch_indexes']
|
116 |
+
kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
|
117 |
+
pred['confidence'], pred['batch_indexes'] = scores[mask], b_ids[mask]
|
118 |
+
pred['keypoints0'], pred['keypoints1'] = kpts0[mask], kpts1[mask]
|
119 |
+
|
120 |
+
out = {"keypoints0": pred['keypoints0'], "keypoints1": pred['keypoints1']}
|
121 |
+
return out
|
hloc/matchers/networks/dkm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import (
|
2 |
+
DKMv3_outdoor,
|
3 |
+
DKMv3_indoor,
|
4 |
+
)
|
hloc/matchers/networks/dkm/datasets/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .megadepth import MegadepthBuilder
|
hloc/matchers/networks/dkm/datasets/megadepth.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from PIL import Image
|
4 |
+
import h5py
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
8 |
+
|
9 |
+
from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
10 |
+
import torchvision.transforms.functional as tvf
|
11 |
+
from dkm.utils.transforms import GeometricSequential
|
12 |
+
import kornia.augmentation as K
|
13 |
+
|
14 |
+
|
15 |
+
class MegadepthScene:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
data_root,
|
19 |
+
scene_info,
|
20 |
+
ht=384,
|
21 |
+
wt=512,
|
22 |
+
min_overlap=0.0,
|
23 |
+
shake_t=0,
|
24 |
+
rot_prob=0.0,
|
25 |
+
normalize=True,
|
26 |
+
) -> None:
|
27 |
+
self.data_root = data_root
|
28 |
+
self.image_paths = scene_info["image_paths"]
|
29 |
+
self.depth_paths = scene_info["depth_paths"]
|
30 |
+
self.intrinsics = scene_info["intrinsics"]
|
31 |
+
self.poses = scene_info["poses"]
|
32 |
+
self.pairs = scene_info["pairs"]
|
33 |
+
self.overlaps = scene_info["overlaps"]
|
34 |
+
threshold = self.overlaps > min_overlap
|
35 |
+
self.pairs = self.pairs[threshold]
|
36 |
+
self.overlaps = self.overlaps[threshold]
|
37 |
+
if len(self.pairs) > 100000:
|
38 |
+
pairinds = np.random.choice(
|
39 |
+
np.arange(0, len(self.pairs)), 100000, replace=False
|
40 |
+
)
|
41 |
+
self.pairs = self.pairs[pairinds]
|
42 |
+
self.overlaps = self.overlaps[pairinds]
|
43 |
+
# counts, bins = np.histogram(self.overlaps,20)
|
44 |
+
# print(counts)
|
45 |
+
self.im_transform_ops = get_tuple_transform_ops(
|
46 |
+
resize=(ht, wt), normalize=normalize
|
47 |
+
)
|
48 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
49 |
+
resize=(ht, wt), normalize=False
|
50 |
+
)
|
51 |
+
self.wt, self.ht = wt, ht
|
52 |
+
self.shake_t = shake_t
|
53 |
+
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
54 |
+
|
55 |
+
def load_im(self, im_ref, crop=None):
|
56 |
+
im = Image.open(im_ref)
|
57 |
+
return im
|
58 |
+
|
59 |
+
def load_depth(self, depth_ref, crop=None):
|
60 |
+
depth = np.array(h5py.File(depth_ref, "r")["depth"])
|
61 |
+
return torch.from_numpy(depth)
|
62 |
+
|
63 |
+
def __len__(self):
|
64 |
+
return len(self.pairs)
|
65 |
+
|
66 |
+
def scale_intrinsic(self, K, wi, hi):
|
67 |
+
sx, sy = self.wt / wi, self.ht / hi
|
68 |
+
sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
|
69 |
+
return sK @ K
|
70 |
+
|
71 |
+
def rand_shake(self, *things):
|
72 |
+
t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
|
73 |
+
return [
|
74 |
+
tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
|
75 |
+
for thing in things
|
76 |
+
], t
|
77 |
+
|
78 |
+
def __getitem__(self, pair_idx):
|
79 |
+
# read intrinsics of original size
|
80 |
+
idx1, idx2 = self.pairs[pair_idx]
|
81 |
+
K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
|
82 |
+
K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
|
83 |
+
|
84 |
+
# read and compute relative poses
|
85 |
+
T1 = self.poses[idx1]
|
86 |
+
T2 = self.poses[idx2]
|
87 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
|
88 |
+
:4, :4
|
89 |
+
] # (4, 4)
|
90 |
+
|
91 |
+
# Load positive pair data
|
92 |
+
im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
|
93 |
+
depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
|
94 |
+
im_src_ref = os.path.join(self.data_root, im1)
|
95 |
+
im_pos_ref = os.path.join(self.data_root, im2)
|
96 |
+
depth_src_ref = os.path.join(self.data_root, depth1)
|
97 |
+
depth_pos_ref = os.path.join(self.data_root, depth2)
|
98 |
+
# return torch.randn((1000,1000))
|
99 |
+
im_src = self.load_im(im_src_ref)
|
100 |
+
im_pos = self.load_im(im_pos_ref)
|
101 |
+
depth_src = self.load_depth(depth_src_ref)
|
102 |
+
depth_pos = self.load_depth(depth_pos_ref)
|
103 |
+
|
104 |
+
# Recompute camera intrinsic matrix due to the resize
|
105 |
+
K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
|
106 |
+
K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
|
107 |
+
# Process images
|
108 |
+
im_src, im_pos = self.im_transform_ops((im_src, im_pos))
|
109 |
+
depth_src, depth_pos = self.depth_transform_ops(
|
110 |
+
(depth_src[None, None], depth_pos[None, None])
|
111 |
+
)
|
112 |
+
[im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
|
113 |
+
im_src, im_pos, depth_src, depth_pos
|
114 |
+
)
|
115 |
+
im_src, Hq = self.H_generator(im_src[None])
|
116 |
+
depth_src = self.H_generator.apply_transform(depth_src, Hq)
|
117 |
+
K1[:2, 2] += t
|
118 |
+
K2[:2, 2] += t
|
119 |
+
K1 = Hq[0] @ K1
|
120 |
+
data_dict = {
|
121 |
+
"query": im_src[0],
|
122 |
+
"query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
|
123 |
+
"support": im_pos,
|
124 |
+
"support_identifier": self.image_paths[idx2]
|
125 |
+
.split("/")[-1]
|
126 |
+
.split(".jpg")[0],
|
127 |
+
"query_depth": depth_src[0, 0],
|
128 |
+
"support_depth": depth_pos[0, 0],
|
129 |
+
"K1": K1,
|
130 |
+
"K2": K2,
|
131 |
+
"T_1to2": T_1to2,
|
132 |
+
}
|
133 |
+
return data_dict
|
134 |
+
|
135 |
+
|
136 |
+
class MegadepthBuilder:
|
137 |
+
def __init__(self, data_root="data/megadepth") -> None:
|
138 |
+
self.data_root = data_root
|
139 |
+
self.scene_info_root = os.path.join(data_root, "prep_scene_info")
|
140 |
+
self.all_scenes = os.listdir(self.scene_info_root)
|
141 |
+
self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
|
142 |
+
self.test_scenes_loftr = ["0015.npy", "0022.npy"]
|
143 |
+
|
144 |
+
def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
|
145 |
+
if split == "train":
|
146 |
+
scene_names = set(self.all_scenes) - set(self.test_scenes)
|
147 |
+
elif split == "train_loftr":
|
148 |
+
scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
|
149 |
+
elif split == "test":
|
150 |
+
scene_names = self.test_scenes
|
151 |
+
elif split == "test_loftr":
|
152 |
+
scene_names = self.test_scenes_loftr
|
153 |
+
else:
|
154 |
+
raise ValueError(f"Split {split} not available")
|
155 |
+
scenes = []
|
156 |
+
for scene_name in scene_names:
|
157 |
+
scene_info = np.load(
|
158 |
+
os.path.join(self.scene_info_root, scene_name), allow_pickle=True
|
159 |
+
).item()
|
160 |
+
scenes.append(
|
161 |
+
MegadepthScene(
|
162 |
+
self.data_root, scene_info, min_overlap=min_overlap, **kwargs
|
163 |
+
)
|
164 |
+
)
|
165 |
+
return scenes
|
166 |
+
|
167 |
+
def weight_scenes(self, concat_dataset, alpha=0.5):
|
168 |
+
ns = []
|
169 |
+
for d in concat_dataset.datasets:
|
170 |
+
ns.append(len(d))
|
171 |
+
ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
|
172 |
+
return ws
|
173 |
+
|
174 |
+
|
175 |
+
if __name__ == "__main__":
|
176 |
+
mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
|
177 |
+
mega_test[0]
|
hloc/matchers/networks/dkm/datasets/scannet.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import (
|
9 |
+
Dataset,
|
10 |
+
DataLoader,
|
11 |
+
ConcatDataset)
|
12 |
+
|
13 |
+
import torchvision.transforms.functional as tvf
|
14 |
+
import kornia.augmentation as K
|
15 |
+
import os.path as osp
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
18 |
+
from dkm.utils.transforms import GeometricSequential
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
class ScanNetScene:
|
23 |
+
def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
|
24 |
+
self.scene_root = osp.join(data_root,"scans","scans_train")
|
25 |
+
self.data_names = scene_info['name']
|
26 |
+
self.overlaps = scene_info['score']
|
27 |
+
# Only sample 10s
|
28 |
+
valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
|
29 |
+
self.overlaps = self.overlaps[valid]
|
30 |
+
self.data_names = self.data_names[valid]
|
31 |
+
if len(self.data_names) > 10000:
|
32 |
+
pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
|
33 |
+
self.data_names = self.data_names[pairinds]
|
34 |
+
self.overlaps = self.overlaps[pairinds]
|
35 |
+
self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
|
36 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
|
37 |
+
self.wt, self.ht = wt, ht
|
38 |
+
self.shake_t = shake_t
|
39 |
+
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
40 |
+
|
41 |
+
def load_im(self, im_ref, crop=None):
|
42 |
+
im = Image.open(im_ref)
|
43 |
+
return im
|
44 |
+
|
45 |
+
def load_depth(self, depth_ref, crop=None):
|
46 |
+
depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
|
47 |
+
depth = depth / 1000
|
48 |
+
depth = torch.from_numpy(depth).float() # (h, w)
|
49 |
+
return depth
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return len(self.data_names)
|
53 |
+
|
54 |
+
def scale_intrinsic(self, K, wi, hi):
|
55 |
+
sx, sy = self.wt / wi, self.ht / hi
|
56 |
+
sK = torch.tensor([[sx, 0, 0],
|
57 |
+
[0, sy, 0],
|
58 |
+
[0, 0, 1]])
|
59 |
+
return sK@K
|
60 |
+
|
61 |
+
def read_scannet_pose(self,path):
|
62 |
+
""" Read ScanNet's Camera2World pose and transform it to World2Camera.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
pose_w2c (np.ndarray): (4, 4)
|
66 |
+
"""
|
67 |
+
cam2world = np.loadtxt(path, delimiter=' ')
|
68 |
+
world2cam = np.linalg.inv(cam2world)
|
69 |
+
return world2cam
|
70 |
+
|
71 |
+
|
72 |
+
def read_scannet_intrinsic(self,path):
|
73 |
+
""" Read ScanNet's intrinsic matrix and return the 3x3 matrix.
|
74 |
+
"""
|
75 |
+
intrinsic = np.loadtxt(path, delimiter=' ')
|
76 |
+
return intrinsic[:-1, :-1]
|
77 |
+
|
78 |
+
def __getitem__(self, pair_idx):
|
79 |
+
# read intrinsics of original size
|
80 |
+
data_name = self.data_names[pair_idx]
|
81 |
+
scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
|
82 |
+
scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
|
83 |
+
|
84 |
+
# read the intrinsic of depthmap
|
85 |
+
K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
|
86 |
+
scene_name,
|
87 |
+
'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
|
88 |
+
# read and compute relative poses
|
89 |
+
T1 = self.read_scannet_pose(osp.join(self.scene_root,
|
90 |
+
scene_name,
|
91 |
+
'pose', f'{stem_name_1}.txt'))
|
92 |
+
T2 = self.read_scannet_pose(osp.join(self.scene_root,
|
93 |
+
scene_name,
|
94 |
+
'pose', f'{stem_name_2}.txt'))
|
95 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
|
96 |
+
|
97 |
+
# Load positive pair data
|
98 |
+
im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
|
99 |
+
im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
|
100 |
+
depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
|
101 |
+
depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
|
102 |
+
|
103 |
+
im_src = self.load_im(im_src_ref)
|
104 |
+
im_pos = self.load_im(im_pos_ref)
|
105 |
+
depth_src = self.load_depth(depth_src_ref)
|
106 |
+
depth_pos = self.load_depth(depth_pos_ref)
|
107 |
+
|
108 |
+
# Recompute camera intrinsic matrix due to the resize
|
109 |
+
K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
|
110 |
+
K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
|
111 |
+
# Process images
|
112 |
+
im_src, im_pos = self.im_transform_ops((im_src, im_pos))
|
113 |
+
depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
|
114 |
+
|
115 |
+
data_dict = {'query': im_src,
|
116 |
+
'support': im_pos,
|
117 |
+
'query_depth': depth_src[0,0],
|
118 |
+
'support_depth': depth_pos[0,0],
|
119 |
+
'K1': K1,
|
120 |
+
'K2': K2,
|
121 |
+
'T_1to2':T_1to2,
|
122 |
+
}
|
123 |
+
return data_dict
|
124 |
+
|
125 |
+
|
126 |
+
class ScanNetBuilder:
|
127 |
+
def __init__(self, data_root = 'data/scannet') -> None:
|
128 |
+
self.data_root = data_root
|
129 |
+
self.scene_info_root = os.path.join(data_root,'scannet_indices')
|
130 |
+
self.all_scenes = os.listdir(self.scene_info_root)
|
131 |
+
|
132 |
+
def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
|
133 |
+
# Note: split doesn't matter here as we always use same scannet_train scenes
|
134 |
+
scene_names = self.all_scenes
|
135 |
+
scenes = []
|
136 |
+
for scene_name in tqdm(scene_names):
|
137 |
+
scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
|
138 |
+
scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
|
139 |
+
return scenes
|
140 |
+
|
141 |
+
def weight_scenes(self, concat_dataset, alpha=.5):
|
142 |
+
ns = []
|
143 |
+
for d in concat_dataset.datasets:
|
144 |
+
ns.append(len(d))
|
145 |
+
ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
|
146 |
+
return ws
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
|
151 |
+
mega_test[0]
|
hloc/matchers/networks/dkm/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model_zoo import (
|
2 |
+
DKMv3_outdoor,
|
3 |
+
DKMv3_indoor,
|
4 |
+
)
|
hloc/matchers/networks/dkm/models/dkm.py
ADDED
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from ..utils.kde import kde
|
7 |
+
from ..utils import get_tuple_transform_ops
|
8 |
+
from ..utils.local_correlation import local_correlation
|
9 |
+
|
10 |
+
|
11 |
+
class ConvRefiner(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
in_dim=6,
|
15 |
+
hidden_dim=16,
|
16 |
+
out_dim=2,
|
17 |
+
dw=False,
|
18 |
+
kernel_size=5,
|
19 |
+
hidden_blocks=3,
|
20 |
+
displacement_emb = None,
|
21 |
+
displacement_emb_dim = None,
|
22 |
+
local_corr_radius = None,
|
23 |
+
corr_in_other = None,
|
24 |
+
no_support_fm = False,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.block1 = self.create_block(
|
28 |
+
in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
|
29 |
+
)
|
30 |
+
self.hidden_blocks = nn.Sequential(
|
31 |
+
*[
|
32 |
+
self.create_block(
|
33 |
+
hidden_dim,
|
34 |
+
hidden_dim,
|
35 |
+
dw=dw,
|
36 |
+
kernel_size=kernel_size,
|
37 |
+
)
|
38 |
+
for hb in range(hidden_blocks)
|
39 |
+
]
|
40 |
+
)
|
41 |
+
self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
|
42 |
+
if displacement_emb:
|
43 |
+
self.has_displacement_emb = True
|
44 |
+
self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
|
45 |
+
else:
|
46 |
+
self.has_displacement_emb = False
|
47 |
+
self.local_corr_radius = local_corr_radius
|
48 |
+
self.corr_in_other = corr_in_other
|
49 |
+
self.no_support_fm = no_support_fm
|
50 |
+
def create_block(
|
51 |
+
self,
|
52 |
+
in_dim,
|
53 |
+
out_dim,
|
54 |
+
dw=False,
|
55 |
+
kernel_size=5,
|
56 |
+
):
|
57 |
+
num_groups = 1 if not dw else in_dim
|
58 |
+
if dw:
|
59 |
+
assert (
|
60 |
+
out_dim % in_dim == 0
|
61 |
+
), "outdim must be divisible by indim for depthwise"
|
62 |
+
conv1 = nn.Conv2d(
|
63 |
+
in_dim,
|
64 |
+
out_dim,
|
65 |
+
kernel_size=kernel_size,
|
66 |
+
stride=1,
|
67 |
+
padding=kernel_size // 2,
|
68 |
+
groups=num_groups,
|
69 |
+
)
|
70 |
+
norm = nn.BatchNorm2d(out_dim)
|
71 |
+
relu = nn.ReLU(inplace=True)
|
72 |
+
conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
|
73 |
+
return nn.Sequential(conv1, norm, relu, conv2)
|
74 |
+
|
75 |
+
def forward(self, x, y, flow):
|
76 |
+
"""Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
|
77 |
+
|
78 |
+
Args:
|
79 |
+
x ([type]): [description]
|
80 |
+
y ([type]): [description]
|
81 |
+
flow ([type]): [description]
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
[type]: [description]
|
85 |
+
"""
|
86 |
+
device = x.device
|
87 |
+
b,c,hs,ws = x.shape
|
88 |
+
with torch.no_grad():
|
89 |
+
x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
|
90 |
+
if self.has_displacement_emb:
|
91 |
+
query_coords = torch.meshgrid(
|
92 |
+
(
|
93 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
94 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
95 |
+
)
|
96 |
+
)
|
97 |
+
query_coords = torch.stack((query_coords[1], query_coords[0]))
|
98 |
+
query_coords = query_coords[None].expand(b, 2, hs, ws)
|
99 |
+
in_displacement = flow-query_coords
|
100 |
+
emb_in_displacement = self.disp_emb(in_displacement)
|
101 |
+
if self.local_corr_radius:
|
102 |
+
#TODO: should corr have gradient?
|
103 |
+
if self.corr_in_other:
|
104 |
+
# Corr in other means take a kxk grid around the predicted coordinate in other image
|
105 |
+
local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
|
106 |
+
else:
|
107 |
+
# Otherwise we use the warp to sample in the first image
|
108 |
+
# This is actually different operations, especially for large viewpoint changes
|
109 |
+
local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
|
110 |
+
if self.no_support_fm:
|
111 |
+
x_hat = torch.zeros_like(x)
|
112 |
+
d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
|
113 |
+
else:
|
114 |
+
d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
|
115 |
+
else:
|
116 |
+
if self.no_support_fm:
|
117 |
+
x_hat = torch.zeros_like(x)
|
118 |
+
d = torch.cat((x, x_hat), dim=1)
|
119 |
+
d = self.block1(d)
|
120 |
+
d = self.hidden_blocks(d)
|
121 |
+
d = self.out_conv(d)
|
122 |
+
certainty, displacement = d[:, :-2], d[:, -2:]
|
123 |
+
return certainty, displacement
|
124 |
+
|
125 |
+
|
126 |
+
class CosKernel(nn.Module): # similar to softmax kernel
|
127 |
+
def __init__(self, T, learn_temperature=False):
|
128 |
+
super().__init__()
|
129 |
+
self.learn_temperature = learn_temperature
|
130 |
+
if self.learn_temperature:
|
131 |
+
self.T = nn.Parameter(torch.tensor(T))
|
132 |
+
else:
|
133 |
+
self.T = T
|
134 |
+
|
135 |
+
def __call__(self, x, y, eps=1e-6):
|
136 |
+
c = torch.einsum("bnd,bmd->bnm", x, y) / (
|
137 |
+
x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
|
138 |
+
)
|
139 |
+
if self.learn_temperature:
|
140 |
+
T = self.T.abs() + 0.01
|
141 |
+
else:
|
142 |
+
T = torch.tensor(self.T, device=c.device)
|
143 |
+
K = ((c - 1.0) / T).exp()
|
144 |
+
return K
|
145 |
+
|
146 |
+
|
147 |
+
class CAB(nn.Module):
|
148 |
+
def __init__(self, in_channels, out_channels):
|
149 |
+
super(CAB, self).__init__()
|
150 |
+
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
151 |
+
self.conv1 = nn.Conv2d(
|
152 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
153 |
+
)
|
154 |
+
self.relu = nn.ReLU()
|
155 |
+
self.conv2 = nn.Conv2d(
|
156 |
+
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
157 |
+
)
|
158 |
+
self.sigmod = nn.Sigmoid()
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
x1, x2 = x # high, low (old, new)
|
162 |
+
x = torch.cat([x1, x2], dim=1)
|
163 |
+
x = self.global_pooling(x)
|
164 |
+
x = self.conv1(x)
|
165 |
+
x = self.relu(x)
|
166 |
+
x = self.conv2(x)
|
167 |
+
x = self.sigmod(x)
|
168 |
+
x2 = x * x2
|
169 |
+
res = x2 + x1
|
170 |
+
return res
|
171 |
+
|
172 |
+
|
173 |
+
class RRB(nn.Module):
|
174 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
175 |
+
super(RRB, self).__init__()
|
176 |
+
self.conv1 = nn.Conv2d(
|
177 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
178 |
+
)
|
179 |
+
self.conv2 = nn.Conv2d(
|
180 |
+
out_channels,
|
181 |
+
out_channels,
|
182 |
+
kernel_size=kernel_size,
|
183 |
+
stride=1,
|
184 |
+
padding=kernel_size // 2,
|
185 |
+
)
|
186 |
+
self.relu = nn.ReLU()
|
187 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
188 |
+
self.conv3 = nn.Conv2d(
|
189 |
+
out_channels,
|
190 |
+
out_channels,
|
191 |
+
kernel_size=kernel_size,
|
192 |
+
stride=1,
|
193 |
+
padding=kernel_size // 2,
|
194 |
+
)
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
x = self.conv1(x)
|
198 |
+
res = self.conv2(x)
|
199 |
+
res = self.bn(res)
|
200 |
+
res = self.relu(res)
|
201 |
+
res = self.conv3(res)
|
202 |
+
return self.relu(x + res)
|
203 |
+
|
204 |
+
|
205 |
+
class DFN(nn.Module):
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
internal_dim,
|
209 |
+
feat_input_modules,
|
210 |
+
pred_input_modules,
|
211 |
+
rrb_d_dict,
|
212 |
+
cab_dict,
|
213 |
+
rrb_u_dict,
|
214 |
+
use_global_context=False,
|
215 |
+
global_dim=None,
|
216 |
+
terminal_module=None,
|
217 |
+
upsample_mode="bilinear",
|
218 |
+
align_corners=False,
|
219 |
+
):
|
220 |
+
super().__init__()
|
221 |
+
if use_global_context:
|
222 |
+
assert (
|
223 |
+
global_dim is not None
|
224 |
+
), "Global dim must be provided when using global context"
|
225 |
+
self.align_corners = align_corners
|
226 |
+
self.internal_dim = internal_dim
|
227 |
+
self.feat_input_modules = feat_input_modules
|
228 |
+
self.pred_input_modules = pred_input_modules
|
229 |
+
self.rrb_d = rrb_d_dict
|
230 |
+
self.cab = cab_dict
|
231 |
+
self.rrb_u = rrb_u_dict
|
232 |
+
self.use_global_context = use_global_context
|
233 |
+
if use_global_context:
|
234 |
+
self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
|
235 |
+
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
236 |
+
self.terminal_module = (
|
237 |
+
terminal_module if terminal_module is not None else nn.Identity()
|
238 |
+
)
|
239 |
+
self.upsample_mode = upsample_mode
|
240 |
+
self._scales = [int(key) for key in self.terminal_module.keys()]
|
241 |
+
|
242 |
+
def scales(self):
|
243 |
+
return self._scales.copy()
|
244 |
+
|
245 |
+
def forward(self, embeddings, feats, context, key):
|
246 |
+
feats = self.feat_input_modules[str(key)](feats)
|
247 |
+
embeddings = torch.cat([feats, embeddings], dim=1)
|
248 |
+
embeddings = self.rrb_d[str(key)](embeddings)
|
249 |
+
context = self.cab[str(key)]([context, embeddings])
|
250 |
+
context = self.rrb_u[str(key)](context)
|
251 |
+
preds = self.terminal_module[str(key)](context)
|
252 |
+
pred_coord = preds[:, -2:]
|
253 |
+
pred_certainty = preds[:, :-2]
|
254 |
+
return pred_coord, pred_certainty, context
|
255 |
+
|
256 |
+
|
257 |
+
class GP(nn.Module):
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
kernel,
|
261 |
+
T=1,
|
262 |
+
learn_temperature=False,
|
263 |
+
only_attention=False,
|
264 |
+
gp_dim=64,
|
265 |
+
basis="fourier",
|
266 |
+
covar_size=5,
|
267 |
+
only_nearest_neighbour=False,
|
268 |
+
sigma_noise=0.1,
|
269 |
+
no_cov=False,
|
270 |
+
predict_features = False,
|
271 |
+
):
|
272 |
+
super().__init__()
|
273 |
+
self.K = kernel(T=T, learn_temperature=learn_temperature)
|
274 |
+
self.sigma_noise = sigma_noise
|
275 |
+
self.covar_size = covar_size
|
276 |
+
self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
|
277 |
+
self.only_attention = only_attention
|
278 |
+
self.only_nearest_neighbour = only_nearest_neighbour
|
279 |
+
self.basis = basis
|
280 |
+
self.no_cov = no_cov
|
281 |
+
self.dim = gp_dim
|
282 |
+
self.predict_features = predict_features
|
283 |
+
|
284 |
+
def get_local_cov(self, cov):
|
285 |
+
K = self.covar_size
|
286 |
+
b, h, w, h, w = cov.shape
|
287 |
+
hw = h * w
|
288 |
+
cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
|
289 |
+
delta = torch.stack(
|
290 |
+
torch.meshgrid(
|
291 |
+
torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
|
292 |
+
),
|
293 |
+
dim=-1,
|
294 |
+
)
|
295 |
+
positions = torch.stack(
|
296 |
+
torch.meshgrid(
|
297 |
+
torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
|
298 |
+
),
|
299 |
+
dim=-1,
|
300 |
+
)
|
301 |
+
neighbours = positions[:, :, None, None, :] + delta[None, :, :]
|
302 |
+
points = torch.arange(hw)[:, None].expand(hw, K**2)
|
303 |
+
local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
|
304 |
+
:,
|
305 |
+
points.flatten(),
|
306 |
+
neighbours[..., 0].flatten(),
|
307 |
+
neighbours[..., 1].flatten(),
|
308 |
+
].reshape(b, h, w, K**2)
|
309 |
+
return local_cov
|
310 |
+
|
311 |
+
def reshape(self, x):
|
312 |
+
return rearrange(x, "b d h w -> b (h w) d")
|
313 |
+
|
314 |
+
def project_to_basis(self, x):
|
315 |
+
if self.basis == "fourier":
|
316 |
+
return torch.cos(8 * math.pi * self.pos_conv(x))
|
317 |
+
elif self.basis == "linear":
|
318 |
+
return self.pos_conv(x)
|
319 |
+
else:
|
320 |
+
raise ValueError(
|
321 |
+
"No other bases other than fourier and linear currently supported in public release"
|
322 |
+
)
|
323 |
+
|
324 |
+
def get_pos_enc(self, y):
|
325 |
+
b, c, h, w = y.shape
|
326 |
+
coarse_coords = torch.meshgrid(
|
327 |
+
(
|
328 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
|
329 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
|
330 |
+
)
|
331 |
+
)
|
332 |
+
|
333 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
334 |
+
None
|
335 |
+
].expand(b, h, w, 2)
|
336 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
337 |
+
coarse_embedded_coords = self.project_to_basis(coarse_coords)
|
338 |
+
return coarse_embedded_coords
|
339 |
+
|
340 |
+
def forward(self, x, y, **kwargs):
|
341 |
+
b, c, h1, w1 = x.shape
|
342 |
+
b, c, h2, w2 = y.shape
|
343 |
+
f = self.get_pos_enc(y)
|
344 |
+
if self.predict_features:
|
345 |
+
f = f + y[:,:self.dim] # Stupid way to predict features
|
346 |
+
b, d, h2, w2 = f.shape
|
347 |
+
#assert x.shape == y.shape
|
348 |
+
x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
|
349 |
+
K_xx = self.K(x, x)
|
350 |
+
K_yy = self.K(y, y)
|
351 |
+
K_xy = self.K(x, y)
|
352 |
+
K_yx = K_xy.permute(0, 2, 1)
|
353 |
+
sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
|
354 |
+
# Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
|
355 |
+
if len(K_yy[0]) > 2000:
|
356 |
+
K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
|
357 |
+
else:
|
358 |
+
K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
|
359 |
+
|
360 |
+
mu_x = K_xy.matmul(K_yy_inv.matmul(f))
|
361 |
+
mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
|
362 |
+
if not self.no_cov:
|
363 |
+
cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
|
364 |
+
cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
|
365 |
+
local_cov_x = self.get_local_cov(cov_x)
|
366 |
+
local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
|
367 |
+
gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
|
368 |
+
else:
|
369 |
+
gp_feats = mu_x
|
370 |
+
return gp_feats
|
371 |
+
|
372 |
+
|
373 |
+
class Encoder(nn.Module):
|
374 |
+
def __init__(self, resnet):
|
375 |
+
super().__init__()
|
376 |
+
self.resnet = resnet
|
377 |
+
def forward(self, x):
|
378 |
+
x0 = x
|
379 |
+
b, c, h, w = x.shape
|
380 |
+
x = self.resnet.conv1(x)
|
381 |
+
x = self.resnet.bn1(x)
|
382 |
+
x1 = self.resnet.relu(x)
|
383 |
+
|
384 |
+
x = self.resnet.maxpool(x1)
|
385 |
+
x2 = self.resnet.layer1(x)
|
386 |
+
|
387 |
+
x3 = self.resnet.layer2(x2)
|
388 |
+
|
389 |
+
x4 = self.resnet.layer3(x3)
|
390 |
+
|
391 |
+
x5 = self.resnet.layer4(x4)
|
392 |
+
feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
|
393 |
+
return feats
|
394 |
+
|
395 |
+
def train(self, mode=True):
|
396 |
+
super().train(mode)
|
397 |
+
for m in self.modules():
|
398 |
+
if isinstance(m, nn.BatchNorm2d):
|
399 |
+
m.eval()
|
400 |
+
pass
|
401 |
+
|
402 |
+
|
403 |
+
class Decoder(nn.Module):
|
404 |
+
def __init__(
|
405 |
+
self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
|
406 |
+
):
|
407 |
+
super().__init__()
|
408 |
+
self.embedding_decoder = embedding_decoder
|
409 |
+
self.gps = gps
|
410 |
+
self.proj = proj
|
411 |
+
self.conv_refiner = conv_refiner
|
412 |
+
self.detach = detach
|
413 |
+
if scales == "all":
|
414 |
+
self.scales = ["32", "16", "8", "4", "2", "1"]
|
415 |
+
else:
|
416 |
+
self.scales = scales
|
417 |
+
|
418 |
+
def upsample_preds(self, flow, certainty, query, support):
|
419 |
+
b, hs, ws, d = flow.shape
|
420 |
+
b, c, h, w = query.shape
|
421 |
+
flow = flow.permute(0, 3, 1, 2)
|
422 |
+
certainty = F.interpolate(
|
423 |
+
certainty, size=(h, w), align_corners=False, mode="bilinear"
|
424 |
+
)
|
425 |
+
flow = F.interpolate(
|
426 |
+
flow, size=(h, w), align_corners=False, mode="bilinear"
|
427 |
+
)
|
428 |
+
delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
|
429 |
+
flow = torch.stack(
|
430 |
+
(
|
431 |
+
flow[:, 0] + delta_flow[:, 0] / (4 * w),
|
432 |
+
flow[:, 1] + delta_flow[:, 1] / (4 * h),
|
433 |
+
),
|
434 |
+
dim=1,
|
435 |
+
)
|
436 |
+
flow = flow.permute(0, 2, 3, 1)
|
437 |
+
certainty = certainty + delta_certainty
|
438 |
+
return flow, certainty
|
439 |
+
|
440 |
+
def get_placeholder_flow(self, b, h, w, device):
|
441 |
+
coarse_coords = torch.meshgrid(
|
442 |
+
(
|
443 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
444 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
445 |
+
)
|
446 |
+
)
|
447 |
+
coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
|
448 |
+
None
|
449 |
+
].expand(b, h, w, 2)
|
450 |
+
coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
|
451 |
+
return coarse_coords
|
452 |
+
|
453 |
+
|
454 |
+
def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
|
455 |
+
coarse_scales = self.embedding_decoder.scales()
|
456 |
+
all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
|
457 |
+
sizes = {scale: f1[scale].shape[-2:] for scale in f1}
|
458 |
+
h, w = sizes[1]
|
459 |
+
b = f1[1].shape[0]
|
460 |
+
device = f1[1].device
|
461 |
+
coarsest_scale = int(all_scales[0])
|
462 |
+
old_stuff = torch.zeros(
|
463 |
+
b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
|
464 |
+
)
|
465 |
+
dense_corresps = {}
|
466 |
+
if not upsample:
|
467 |
+
dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
|
468 |
+
dense_certainty = 0.0
|
469 |
+
else:
|
470 |
+
dense_flow = F.interpolate(
|
471 |
+
dense_flow,
|
472 |
+
size=sizes[coarsest_scale],
|
473 |
+
align_corners=False,
|
474 |
+
mode="bilinear",
|
475 |
+
)
|
476 |
+
dense_certainty = F.interpolate(
|
477 |
+
dense_certainty,
|
478 |
+
size=sizes[coarsest_scale],
|
479 |
+
align_corners=False,
|
480 |
+
mode="bilinear",
|
481 |
+
)
|
482 |
+
for new_scale in all_scales:
|
483 |
+
ins = int(new_scale)
|
484 |
+
f1_s, f2_s = f1[ins], f2[ins]
|
485 |
+
if new_scale in self.proj:
|
486 |
+
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
487 |
+
b, c, hs, ws = f1_s.shape
|
488 |
+
if ins in coarse_scales:
|
489 |
+
old_stuff = F.interpolate(
|
490 |
+
old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
|
491 |
+
)
|
492 |
+
new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
|
493 |
+
dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
|
494 |
+
new_stuff, f1_s, old_stuff, new_scale
|
495 |
+
)
|
496 |
+
|
497 |
+
if new_scale in self.conv_refiner:
|
498 |
+
delta_certainty, displacement = self.conv_refiner[new_scale](
|
499 |
+
f1_s, f2_s, dense_flow
|
500 |
+
)
|
501 |
+
dense_flow = torch.stack(
|
502 |
+
(
|
503 |
+
dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
|
504 |
+
dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
|
505 |
+
),
|
506 |
+
dim=1,
|
507 |
+
)
|
508 |
+
dense_certainty = (
|
509 |
+
dense_certainty + delta_certainty
|
510 |
+
) # predict both certainty and displacement
|
511 |
+
|
512 |
+
dense_corresps[ins] = {
|
513 |
+
"dense_flow": dense_flow,
|
514 |
+
"dense_certainty": dense_certainty,
|
515 |
+
}
|
516 |
+
|
517 |
+
if new_scale != "1":
|
518 |
+
dense_flow = F.interpolate(
|
519 |
+
dense_flow,
|
520 |
+
size=sizes[ins // 2],
|
521 |
+
align_corners=False,
|
522 |
+
mode="bilinear",
|
523 |
+
)
|
524 |
+
|
525 |
+
dense_certainty = F.interpolate(
|
526 |
+
dense_certainty,
|
527 |
+
size=sizes[ins // 2],
|
528 |
+
align_corners=False,
|
529 |
+
mode="bilinear",
|
530 |
+
)
|
531 |
+
if self.detach:
|
532 |
+
dense_flow = dense_flow.detach()
|
533 |
+
dense_certainty = dense_certainty.detach()
|
534 |
+
return dense_corresps
|
535 |
+
|
536 |
+
|
537 |
+
class RegressionMatcher(nn.Module):
|
538 |
+
def __init__(
|
539 |
+
self,
|
540 |
+
encoder,
|
541 |
+
decoder,
|
542 |
+
h=384,
|
543 |
+
w=512,
|
544 |
+
use_contrastive_loss = False,
|
545 |
+
alpha = 1,
|
546 |
+
beta = 0,
|
547 |
+
sample_mode = "threshold",
|
548 |
+
upsample_preds = False,
|
549 |
+
symmetric = False,
|
550 |
+
name = None,
|
551 |
+
use_soft_mutual_nearest_neighbours = False,
|
552 |
+
):
|
553 |
+
super().__init__()
|
554 |
+
self.encoder = encoder
|
555 |
+
self.decoder = decoder
|
556 |
+
self.w_resized = w
|
557 |
+
self.h_resized = h
|
558 |
+
self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
|
559 |
+
self.use_contrastive_loss = use_contrastive_loss
|
560 |
+
self.alpha = alpha
|
561 |
+
self.beta = beta
|
562 |
+
self.sample_mode = sample_mode
|
563 |
+
self.upsample_preds = upsample_preds
|
564 |
+
self.symmetric = symmetric
|
565 |
+
self.name = name
|
566 |
+
self.sample_thresh = 0.05
|
567 |
+
self.upsample_res = (1152, 1536)
|
568 |
+
if use_soft_mutual_nearest_neighbours:
|
569 |
+
assert symmetric, "MNS requires symmetric inference"
|
570 |
+
self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
|
571 |
+
|
572 |
+
def extract_backbone_features(self, batch, batched = True, upsample = True):
|
573 |
+
#TODO: only extract stride [1,2,4,8] for upsample = True
|
574 |
+
x_q = batch["query"]
|
575 |
+
x_s = batch["support"]
|
576 |
+
if batched:
|
577 |
+
X = torch.cat((x_q, x_s))
|
578 |
+
feature_pyramid = self.encoder(X)
|
579 |
+
else:
|
580 |
+
feature_pyramid = self.encoder(x_q), self.encoder(x_s)
|
581 |
+
return feature_pyramid
|
582 |
+
|
583 |
+
def sample(
|
584 |
+
self,
|
585 |
+
dense_matches,
|
586 |
+
dense_certainty,
|
587 |
+
num=10000,
|
588 |
+
):
|
589 |
+
if "threshold" in self.sample_mode:
|
590 |
+
upper_thresh = self.sample_thresh
|
591 |
+
dense_certainty = dense_certainty.clone()
|
592 |
+
dense_certainty_ = dense_certainty.clone()
|
593 |
+
dense_certainty[dense_certainty > upper_thresh] = 1
|
594 |
+
elif "pow" in self.sample_mode:
|
595 |
+
dense_certainty = dense_certainty**(1/3)
|
596 |
+
elif "naive" in self.sample_mode:
|
597 |
+
dense_certainty = torch.ones_like(dense_certainty)
|
598 |
+
matches, certainty = (
|
599 |
+
dense_matches.reshape(-1, 4),
|
600 |
+
dense_certainty.reshape(-1),
|
601 |
+
)
|
602 |
+
certainty_ = dense_certainty_.reshape(-1)
|
603 |
+
expansion_factor = 4 if "balanced" in self.sample_mode else 1
|
604 |
+
if not certainty.sum(): certainty = certainty + 1e-8
|
605 |
+
good_samples = torch.multinomial(certainty,
|
606 |
+
num_samples = min(expansion_factor*num, len(certainty)),
|
607 |
+
replacement=False)
|
608 |
+
good_matches, good_certainty = matches[good_samples], certainty[good_samples]
|
609 |
+
good_certainty_ = certainty_[good_samples]
|
610 |
+
good_certainty = good_certainty_
|
611 |
+
if "balanced" not in self.sample_mode:
|
612 |
+
return good_matches, good_certainty
|
613 |
+
|
614 |
+
density = kde(good_matches, std=0.1)
|
615 |
+
p = 1 / (density+1)
|
616 |
+
p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
|
617 |
+
balanced_samples = torch.multinomial(p,
|
618 |
+
num_samples = min(num,len(good_certainty)),
|
619 |
+
replacement=False)
|
620 |
+
return good_matches[balanced_samples], good_certainty[balanced_samples]
|
621 |
+
|
622 |
+
def forward(self, batch, batched = True):
|
623 |
+
feature_pyramid = self.extract_backbone_features(batch, batched=batched)
|
624 |
+
if batched:
|
625 |
+
f_q_pyramid = {
|
626 |
+
scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
|
627 |
+
}
|
628 |
+
f_s_pyramid = {
|
629 |
+
scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
|
630 |
+
}
|
631 |
+
else:
|
632 |
+
f_q_pyramid, f_s_pyramid = feature_pyramid
|
633 |
+
dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
|
634 |
+
if self.training and self.use_contrastive_loss:
|
635 |
+
return dense_corresps, (f_q_pyramid, f_s_pyramid)
|
636 |
+
else:
|
637 |
+
return dense_corresps
|
638 |
+
|
639 |
+
def forward_symmetric(self, batch, upsample = False, batched = True):
|
640 |
+
feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
|
641 |
+
f_q_pyramid = feature_pyramid
|
642 |
+
f_s_pyramid = {
|
643 |
+
scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
|
644 |
+
for scale, f_scale in feature_pyramid.items()
|
645 |
+
}
|
646 |
+
dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
|
647 |
+
return dense_corresps
|
648 |
+
|
649 |
+
def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
|
650 |
+
kpts_A, kpts_B = matches[...,:2], matches[...,2:]
|
651 |
+
kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
652 |
+
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
653 |
+
return kpts_A, kpts_B
|
654 |
+
|
655 |
+
def match(
|
656 |
+
self,
|
657 |
+
im1_path,
|
658 |
+
im2_path,
|
659 |
+
*args,
|
660 |
+
batched=False,
|
661 |
+
):
|
662 |
+
assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
|
663 |
+
symmetric = self.symmetric
|
664 |
+
self.train(False)
|
665 |
+
with torch.no_grad():
|
666 |
+
if not batched:
|
667 |
+
b = 1
|
668 |
+
ws = self.w_resized
|
669 |
+
hs = self.h_resized
|
670 |
+
query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False)
|
671 |
+
support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False)
|
672 |
+
batch = {"query": query, "support": support}
|
673 |
+
else:
|
674 |
+
b, c, h, w = im1_path.shape
|
675 |
+
b, c, h2, w2 = im2_path.shape
|
676 |
+
assert w == w2 and h == h2, "For batched images we assume same size"
|
677 |
+
batch = {"query": im1_path, "support": im2_path}
|
678 |
+
hs, ws = self.h_resized, self.w_resized
|
679 |
+
finest_scale = 1
|
680 |
+
# Run matcher
|
681 |
+
if symmetric:
|
682 |
+
dense_corresps = self.forward_symmetric(batch, batched = True)
|
683 |
+
else:
|
684 |
+
dense_corresps = self.forward(batch, batched = True)
|
685 |
+
|
686 |
+
if self.upsample_preds:
|
687 |
+
hs, ws = self.upsample_res
|
688 |
+
low_res_certainty = F.interpolate(
|
689 |
+
dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
|
690 |
+
)
|
691 |
+
cert_clamp = 0
|
692 |
+
factor = 0.5
|
693 |
+
low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
|
694 |
+
|
695 |
+
if self.upsample_preds:
|
696 |
+
query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False)
|
697 |
+
support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False)
|
698 |
+
batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
|
699 |
+
if symmetric:
|
700 |
+
dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
|
701 |
+
else:
|
702 |
+
dense_corresps = self.forward(batch, batched = True, upsample=True)
|
703 |
+
query_to_support = dense_corresps[finest_scale]["dense_flow"]
|
704 |
+
dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
|
705 |
+
|
706 |
+
# Get certainty interpolation
|
707 |
+
dense_certainty = dense_certainty - low_res_certainty
|
708 |
+
query_to_support = query_to_support.permute(
|
709 |
+
0, 2, 3, 1
|
710 |
+
)
|
711 |
+
# Create im1 meshgrid
|
712 |
+
query_coords = torch.meshgrid(
|
713 |
+
(
|
714 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=im1_path.device),
|
715 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=im1_path.device),
|
716 |
+
)
|
717 |
+
)
|
718 |
+
query_coords = torch.stack((query_coords[1], query_coords[0]))
|
719 |
+
query_coords = query_coords[None].expand(b, 2, hs, ws)
|
720 |
+
dense_certainty = dense_certainty.sigmoid() # logits -> probs
|
721 |
+
query_coords = query_coords.permute(0, 2, 3, 1)
|
722 |
+
if (query_to_support.abs() > 1).any() and True:
|
723 |
+
wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
|
724 |
+
dense_certainty[wrong[:,None]] = 0
|
725 |
+
# remove black pixels
|
726 |
+
black_mask1 = (im1_path[0, 0] < 0.03125) & (im1_path[0, 1] < 0.03125) & (im1_path[0, 2] < 0.03125)
|
727 |
+
black_mask2 = (im2_path[0, 0] < 0.03125) & (im2_path[0, 1] < 0.03125) & (im2_path[0, 2] < 0.03125)
|
728 |
+
black_mask = torch.stack((black_mask1, black_mask2))[:, None]
|
729 |
+
black_mask = F.interpolate(black_mask.float(), size=tuple(dense_certainty.shape[-2:]), mode='nearest').bool()
|
730 |
+
dense_certainty[black_mask] = 0
|
731 |
+
|
732 |
+
query_to_support = torch.clamp(query_to_support, -1, 1)
|
733 |
+
if symmetric:
|
734 |
+
support_coords = query_coords
|
735 |
+
qts, stq = query_to_support.chunk(2)
|
736 |
+
q_warp = torch.cat((query_coords, qts), dim=-1)
|
737 |
+
s_warp = torch.cat((stq, support_coords), dim=-1)
|
738 |
+
warp = torch.cat((q_warp, s_warp),dim=2)
|
739 |
+
dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
|
740 |
+
else:
|
741 |
+
warp = torch.cat((query_coords, query_to_support), dim=-1)
|
742 |
+
if batched:
|
743 |
+
return (
|
744 |
+
warp,
|
745 |
+
dense_certainty
|
746 |
+
)
|
747 |
+
else:
|
748 |
+
return (
|
749 |
+
warp[0],
|
750 |
+
dense_certainty[0],
|
751 |
+
)
|
hloc/matchers/networks/dkm/models/encoders.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as tvm
|
5 |
+
|
6 |
+
class ResNet18(nn.Module):
|
7 |
+
def __init__(self, pretrained=False) -> None:
|
8 |
+
super().__init__()
|
9 |
+
self.net = tvm.resnet18(pretrained=pretrained)
|
10 |
+
def forward(self, x):
|
11 |
+
self = self.net
|
12 |
+
x1 = x
|
13 |
+
x = self.conv1(x1)
|
14 |
+
x = self.bn1(x)
|
15 |
+
x2 = self.relu(x)
|
16 |
+
x = self.maxpool(x2)
|
17 |
+
x4 = self.layer1(x)
|
18 |
+
x8 = self.layer2(x4)
|
19 |
+
x16 = self.layer3(x8)
|
20 |
+
x32 = self.layer4(x16)
|
21 |
+
return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
|
22 |
+
|
23 |
+
def train(self, mode=True):
|
24 |
+
super().train(mode)
|
25 |
+
for m in self.modules():
|
26 |
+
if isinstance(m, nn.BatchNorm2d):
|
27 |
+
m.eval()
|
28 |
+
pass
|
29 |
+
|
30 |
+
class ResNet50(nn.Module):
|
31 |
+
def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
|
32 |
+
super().__init__()
|
33 |
+
if dilation is None:
|
34 |
+
dilation = [False,False,False]
|
35 |
+
if anti_aliased:
|
36 |
+
pass
|
37 |
+
else:
|
38 |
+
if weights is not None:
|
39 |
+
self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
|
40 |
+
else:
|
41 |
+
self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
|
42 |
+
|
43 |
+
del self.net.fc
|
44 |
+
self.high_res = high_res
|
45 |
+
self.freeze_bn = freeze_bn
|
46 |
+
def forward(self, x):
|
47 |
+
net = self.net
|
48 |
+
feats = {1:x}
|
49 |
+
x = net.conv1(x)
|
50 |
+
x = net.bn1(x)
|
51 |
+
x = net.relu(x)
|
52 |
+
feats[2] = x
|
53 |
+
x = net.maxpool(x)
|
54 |
+
x = net.layer1(x)
|
55 |
+
feats[4] = x
|
56 |
+
x = net.layer2(x)
|
57 |
+
feats[8] = x
|
58 |
+
x = net.layer3(x)
|
59 |
+
feats[16] = x
|
60 |
+
x = net.layer4(x)
|
61 |
+
feats[32] = x
|
62 |
+
return feats
|
63 |
+
|
64 |
+
def train(self, mode=True):
|
65 |
+
super().train(mode)
|
66 |
+
if self.freeze_bn:
|
67 |
+
for m in self.modules():
|
68 |
+
if isinstance(m, nn.BatchNorm2d):
|
69 |
+
m.eval()
|
70 |
+
pass
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
class ResNet101(nn.Module):
|
76 |
+
def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
|
77 |
+
super().__init__()
|
78 |
+
if weights is not None:
|
79 |
+
self.net = tvm.resnet101(weights = weights)
|
80 |
+
else:
|
81 |
+
self.net = tvm.resnet101(pretrained=pretrained)
|
82 |
+
self.high_res = high_res
|
83 |
+
self.scale_factor = 1 if not high_res else 1.5
|
84 |
+
def forward(self, x):
|
85 |
+
net = self.net
|
86 |
+
feats = {1:x}
|
87 |
+
sf = self.scale_factor
|
88 |
+
if self.high_res:
|
89 |
+
x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
|
90 |
+
x = net.conv1(x)
|
91 |
+
x = net.bn1(x)
|
92 |
+
x = net.relu(x)
|
93 |
+
feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
94 |
+
x = net.maxpool(x)
|
95 |
+
x = net.layer1(x)
|
96 |
+
feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
97 |
+
x = net.layer2(x)
|
98 |
+
feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
99 |
+
x = net.layer3(x)
|
100 |
+
feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
101 |
+
x = net.layer4(x)
|
102 |
+
feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
103 |
+
return feats
|
104 |
+
|
105 |
+
def train(self, mode=True):
|
106 |
+
super().train(mode)
|
107 |
+
for m in self.modules():
|
108 |
+
if isinstance(m, nn.BatchNorm2d):
|
109 |
+
m.eval()
|
110 |
+
pass
|
111 |
+
|
112 |
+
|
113 |
+
class WideResNet50(nn.Module):
|
114 |
+
def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
|
115 |
+
super().__init__()
|
116 |
+
if weights is not None:
|
117 |
+
self.net = tvm.wide_resnet50_2(weights = weights)
|
118 |
+
else:
|
119 |
+
self.net = tvm.wide_resnet50_2(pretrained=pretrained)
|
120 |
+
self.high_res = high_res
|
121 |
+
self.scale_factor = 1 if not high_res else 1.5
|
122 |
+
def forward(self, x):
|
123 |
+
net = self.net
|
124 |
+
feats = {1:x}
|
125 |
+
sf = self.scale_factor
|
126 |
+
if self.high_res:
|
127 |
+
x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
|
128 |
+
x = net.conv1(x)
|
129 |
+
x = net.bn1(x)
|
130 |
+
x = net.relu(x)
|
131 |
+
feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
132 |
+
x = net.maxpool(x)
|
133 |
+
x = net.layer1(x)
|
134 |
+
feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
135 |
+
x = net.layer2(x)
|
136 |
+
feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
137 |
+
x = net.layer3(x)
|
138 |
+
feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
139 |
+
x = net.layer4(x)
|
140 |
+
feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
|
141 |
+
return feats
|
142 |
+
|
143 |
+
def train(self, mode=True):
|
144 |
+
super().train(mode)
|
145 |
+
for m in self.modules():
|
146 |
+
if isinstance(m, nn.BatchNorm2d):
|
147 |
+
m.eval()
|
148 |
+
pass
|
hloc/matchers/networks/dkm/models/model_zoo/DKMv3.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ...models.dkm import *
|
2 |
+
from ...models.encoders import *
|
3 |
+
|
4 |
+
|
5 |
+
def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs):
|
6 |
+
gp_dim = 256
|
7 |
+
dfn_dim = 384
|
8 |
+
feat_dim = 256
|
9 |
+
coordinate_decoder = DFN(
|
10 |
+
internal_dim=dfn_dim,
|
11 |
+
feat_input_modules=nn.ModuleDict(
|
12 |
+
{
|
13 |
+
"32": nn.Conv2d(512, feat_dim, 1, 1),
|
14 |
+
"16": nn.Conv2d(512, feat_dim, 1, 1),
|
15 |
+
}
|
16 |
+
),
|
17 |
+
pred_input_modules=nn.ModuleDict(
|
18 |
+
{
|
19 |
+
"32": nn.Identity(),
|
20 |
+
"16": nn.Identity(),
|
21 |
+
}
|
22 |
+
),
|
23 |
+
rrb_d_dict=nn.ModuleDict(
|
24 |
+
{
|
25 |
+
"32": RRB(gp_dim + feat_dim, dfn_dim),
|
26 |
+
"16": RRB(gp_dim + feat_dim, dfn_dim),
|
27 |
+
}
|
28 |
+
),
|
29 |
+
cab_dict=nn.ModuleDict(
|
30 |
+
{
|
31 |
+
"32": CAB(2 * dfn_dim, dfn_dim),
|
32 |
+
"16": CAB(2 * dfn_dim, dfn_dim),
|
33 |
+
}
|
34 |
+
),
|
35 |
+
rrb_u_dict=nn.ModuleDict(
|
36 |
+
{
|
37 |
+
"32": RRB(dfn_dim, dfn_dim),
|
38 |
+
"16": RRB(dfn_dim, dfn_dim),
|
39 |
+
}
|
40 |
+
),
|
41 |
+
terminal_module=nn.ModuleDict(
|
42 |
+
{
|
43 |
+
"32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
44 |
+
"16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
|
45 |
+
}
|
46 |
+
),
|
47 |
+
)
|
48 |
+
dw = True
|
49 |
+
hidden_blocks = 8
|
50 |
+
kernel_size = 5
|
51 |
+
displacement_emb = "linear"
|
52 |
+
conv_refiner = nn.ModuleDict(
|
53 |
+
{
|
54 |
+
"16": ConvRefiner(
|
55 |
+
2 * 512+128+(2*7+1)**2,
|
56 |
+
2 * 512+128+(2*7+1)**2,
|
57 |
+
3,
|
58 |
+
kernel_size=kernel_size,
|
59 |
+
dw=dw,
|
60 |
+
hidden_blocks=hidden_blocks,
|
61 |
+
displacement_emb=displacement_emb,
|
62 |
+
displacement_emb_dim=128,
|
63 |
+
local_corr_radius = 7,
|
64 |
+
corr_in_other = True,
|
65 |
+
),
|
66 |
+
"8": ConvRefiner(
|
67 |
+
2 * 512+64+(2*3+1)**2,
|
68 |
+
2 * 512+64+(2*3+1)**2,
|
69 |
+
3,
|
70 |
+
kernel_size=kernel_size,
|
71 |
+
dw=dw,
|
72 |
+
hidden_blocks=hidden_blocks,
|
73 |
+
displacement_emb=displacement_emb,
|
74 |
+
displacement_emb_dim=64,
|
75 |
+
local_corr_radius = 3,
|
76 |
+
corr_in_other = True,
|
77 |
+
),
|
78 |
+
"4": ConvRefiner(
|
79 |
+
2 * 256+32+(2*2+1)**2,
|
80 |
+
2 * 256+32+(2*2+1)**2,
|
81 |
+
3,
|
82 |
+
kernel_size=kernel_size,
|
83 |
+
dw=dw,
|
84 |
+
hidden_blocks=hidden_blocks,
|
85 |
+
displacement_emb=displacement_emb,
|
86 |
+
displacement_emb_dim=32,
|
87 |
+
local_corr_radius = 2,
|
88 |
+
corr_in_other = True,
|
89 |
+
),
|
90 |
+
"2": ConvRefiner(
|
91 |
+
2 * 64+16,
|
92 |
+
128+16,
|
93 |
+
3,
|
94 |
+
kernel_size=kernel_size,
|
95 |
+
dw=dw,
|
96 |
+
hidden_blocks=hidden_blocks,
|
97 |
+
displacement_emb=displacement_emb,
|
98 |
+
displacement_emb_dim=16,
|
99 |
+
),
|
100 |
+
"1": ConvRefiner(
|
101 |
+
2 * 3+6,
|
102 |
+
24,
|
103 |
+
3,
|
104 |
+
kernel_size=kernel_size,
|
105 |
+
dw=dw,
|
106 |
+
hidden_blocks=hidden_blocks,
|
107 |
+
displacement_emb=displacement_emb,
|
108 |
+
displacement_emb_dim=6,
|
109 |
+
),
|
110 |
+
}
|
111 |
+
)
|
112 |
+
kernel_temperature = 0.2
|
113 |
+
learn_temperature = False
|
114 |
+
no_cov = True
|
115 |
+
kernel = CosKernel
|
116 |
+
only_attention = False
|
117 |
+
basis = "fourier"
|
118 |
+
gp32 = GP(
|
119 |
+
kernel,
|
120 |
+
T=kernel_temperature,
|
121 |
+
learn_temperature=learn_temperature,
|
122 |
+
only_attention=only_attention,
|
123 |
+
gp_dim=gp_dim,
|
124 |
+
basis=basis,
|
125 |
+
no_cov=no_cov,
|
126 |
+
)
|
127 |
+
gp16 = GP(
|
128 |
+
kernel,
|
129 |
+
T=kernel_temperature,
|
130 |
+
learn_temperature=learn_temperature,
|
131 |
+
only_attention=only_attention,
|
132 |
+
gp_dim=gp_dim,
|
133 |
+
basis=basis,
|
134 |
+
no_cov=no_cov,
|
135 |
+
)
|
136 |
+
gps = nn.ModuleDict({"32": gp32, "16": gp16})
|
137 |
+
proj = nn.ModuleDict(
|
138 |
+
{"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
|
139 |
+
)
|
140 |
+
decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
|
141 |
+
|
142 |
+
encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False)
|
143 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs)
|
144 |
+
# res = matcher.load_state_dict(weights)
|
145 |
+
return matcher
|
hloc/matchers/networks/dkm/models/model_zoo/__init__.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
weight_urls = {
|
2 |
+
"DKMv3": {
|
3 |
+
"outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
|
4 |
+
"indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
|
5 |
+
},
|
6 |
+
}
|
7 |
+
import torch
|
8 |
+
from .DKMv3 import DKMv3
|
9 |
+
|
10 |
+
|
11 |
+
def DKMv3_outdoor(path_to_weights = None, device=None):
|
12 |
+
"""
|
13 |
+
Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default
|
14 |
+
resolution can be changed by setting model.h_resized, model.w_resized later.
|
15 |
+
Additionally upsamples preds to fixed resolution of (864, 1152),
|
16 |
+
can be turned off by model.upsample_preds = False
|
17 |
+
"""
|
18 |
+
if device is None:
|
19 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
20 |
+
if path_to_weights is not None:
|
21 |
+
weights = torch.load(path_to_weights, map_location=device)
|
22 |
+
else:
|
23 |
+
weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"],
|
24 |
+
map_location=device)
|
25 |
+
return DKMv3(weights, 540, 720, upsample_preds = True, device=device)
|
26 |
+
|
27 |
+
def DKMv3_indoor(path_to_weights = None, device=None):
|
28 |
+
"""
|
29 |
+
Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default
|
30 |
+
Resolution can be changed by setting model.h_resized, model.w_resized later.
|
31 |
+
"""
|
32 |
+
if device is None:
|
33 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
34 |
+
if path_to_weights is not None:
|
35 |
+
weights = torch.load(path_to_weights, map_location=device)
|
36 |
+
else:
|
37 |
+
weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"],
|
38 |
+
map_location=device)
|
39 |
+
return DKMv3(weights, 480, 640, upsample_preds = False, device=device)
|
hloc/matchers/networks/dkm/utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import (
|
2 |
+
pose_auc,
|
3 |
+
get_pose,
|
4 |
+
compute_relative_pose,
|
5 |
+
compute_pose_error,
|
6 |
+
estimate_pose,
|
7 |
+
rotate_intrinsic,
|
8 |
+
get_tuple_transform_ops,
|
9 |
+
get_depth_tuple_transform_ops,
|
10 |
+
warp_kpts,
|
11 |
+
numpy_to_pil,
|
12 |
+
tensor_to_pil,
|
13 |
+
)
|
hloc/matchers/networks/dkm/utils/kde.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1):
|
6 |
+
raise NotImplementedError("WIP, use at your own risk.")
|
7 |
+
# Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid
|
8 |
+
x = x.permute(0,3,1,2)
|
9 |
+
B,C,H,W = x.shape
|
10 |
+
K = kernel_size ** 2
|
11 |
+
unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W)
|
12 |
+
scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp()
|
13 |
+
density = scores.sum(dim=1)
|
14 |
+
return density
|
15 |
+
|
16 |
+
|
17 |
+
def kde(x, std = 0.1, device=None):
|
18 |
+
if device is None:
|
19 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
20 |
+
if isinstance(x, np.ndarray):
|
21 |
+
x = torch.from_numpy(x)
|
22 |
+
# use a gaussian kernel to estimate density
|
23 |
+
x = x.to(device)
|
24 |
+
scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
|
25 |
+
density = scores.sum(dim=-1)
|
26 |
+
return density
|
hloc/matchers/networks/dkm/utils/local_correlation.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def local_correlation(
|
6 |
+
feature0,
|
7 |
+
feature1,
|
8 |
+
local_radius,
|
9 |
+
padding_mode="zeros",
|
10 |
+
flow = None
|
11 |
+
):
|
12 |
+
device = feature0.device
|
13 |
+
b, c, h, w = feature0.size()
|
14 |
+
if flow is None:
|
15 |
+
# If flow is None, assume feature0 and feature1 are aligned
|
16 |
+
coords = torch.meshgrid(
|
17 |
+
(
|
18 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
19 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
20 |
+
))
|
21 |
+
coords = torch.stack((coords[1], coords[0]), dim=-1)[
|
22 |
+
None
|
23 |
+
].expand(b, h, w, 2)
|
24 |
+
else:
|
25 |
+
coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
|
26 |
+
r = local_radius
|
27 |
+
local_window = torch.meshgrid(
|
28 |
+
(
|
29 |
+
torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device),
|
30 |
+
torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device),
|
31 |
+
))
|
32 |
+
local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
|
33 |
+
None
|
34 |
+
].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2)
|
35 |
+
coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2)
|
36 |
+
window_feature = F.grid_sample(
|
37 |
+
feature1, coords, padding_mode=padding_mode, align_corners=False
|
38 |
+
)[...,None].reshape(b,c,h,w,(2*r+1)**2)
|
39 |
+
corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5)
|
40 |
+
return corr
|
hloc/matchers/networks/dkm/utils/transforms.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import kornia.augmentation as K
|
5 |
+
from kornia.geometry.transform import warp_perspective
|
6 |
+
|
7 |
+
# Adapted from Kornia
|
8 |
+
class GeometricSequential:
|
9 |
+
def __init__(self, *transforms, align_corners=True) -> None:
|
10 |
+
self.transforms = transforms
|
11 |
+
self.align_corners = align_corners
|
12 |
+
|
13 |
+
def __call__(self, x, mode="bilinear"):
|
14 |
+
b, c, h, w = x.shape
|
15 |
+
M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
|
16 |
+
for t in self.transforms:
|
17 |
+
if np.random.rand() < t.p:
|
18 |
+
M = M.matmul(
|
19 |
+
t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
|
20 |
+
)
|
21 |
+
return (
|
22 |
+
warp_perspective(
|
23 |
+
x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
|
24 |
+
),
|
25 |
+
M,
|
26 |
+
)
|
27 |
+
|
28 |
+
def apply_transform(self, x, M, mode="bilinear"):
|
29 |
+
b, c, h, w = x.shape
|
30 |
+
return warp_perspective(
|
31 |
+
x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class RandomPerspective(K.RandomPerspective):
|
36 |
+
def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
|
37 |
+
distortion_scale = torch.as_tensor(
|
38 |
+
self.distortion_scale, device=self._device, dtype=self._dtype
|
39 |
+
)
|
40 |
+
return self.random_perspective_generator(
|
41 |
+
batch_shape[0],
|
42 |
+
batch_shape[-2],
|
43 |
+
batch_shape[-1],
|
44 |
+
distortion_scale,
|
45 |
+
self.same_on_batch,
|
46 |
+
self.device,
|
47 |
+
self.dtype,
|
48 |
+
)
|
49 |
+
|
50 |
+
def random_perspective_generator(
|
51 |
+
self,
|
52 |
+
batch_size: int,
|
53 |
+
height: int,
|
54 |
+
width: int,
|
55 |
+
distortion_scale: torch.Tensor,
|
56 |
+
same_on_batch: bool = False,
|
57 |
+
device: torch.device = torch.device("cpu"),
|
58 |
+
dtype: torch.dtype = torch.float32,
|
59 |
+
) -> Dict[str, torch.Tensor]:
|
60 |
+
r"""Get parameters for ``perspective`` for a random perspective transform.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
batch_size (int): the tensor batch size.
|
64 |
+
height (int) : height of the image.
|
65 |
+
width (int): width of the image.
|
66 |
+
distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
|
67 |
+
same_on_batch (bool): apply the same transformation across the batch. Default: False.
|
68 |
+
device (torch.device): the device on which the random numbers will be generated. Default: cpu.
|
69 |
+
dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
params Dict[str, torch.Tensor]: parameters to be passed for transformation.
|
73 |
+
- start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
|
74 |
+
- end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
|
75 |
+
|
76 |
+
Note:
|
77 |
+
The generated random numbers are not reproducible across different devices and dtypes.
|
78 |
+
"""
|
79 |
+
if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
|
80 |
+
raise AssertionError(
|
81 |
+
f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
|
82 |
+
)
|
83 |
+
if not (
|
84 |
+
type(height) is int and height > 0 and type(width) is int and width > 0
|
85 |
+
):
|
86 |
+
raise AssertionError(
|
87 |
+
f"'height' and 'width' must be integers. Got {height}, {width}."
|
88 |
+
)
|
89 |
+
|
90 |
+
start_points: torch.Tensor = torch.tensor(
|
91 |
+
[[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
|
92 |
+
device=distortion_scale.device,
|
93 |
+
dtype=distortion_scale.dtype,
|
94 |
+
).expand(batch_size, -1, -1)
|
95 |
+
|
96 |
+
# generate random offset not larger than half of the image
|
97 |
+
fx = distortion_scale * width / 2
|
98 |
+
fy = distortion_scale * height / 2
|
99 |
+
|
100 |
+
factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
|
101 |
+
offset = (torch.rand_like(start_points) - 0.5) * 2
|
102 |
+
end_points = start_points + factor * offset
|
103 |
+
|
104 |
+
return dict(start_points=start_points, end_points=end_points)
|
hloc/matchers/networks/dkm/utils/utils.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
from torchvision.transforms.functional import InterpolationMode
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
|
11 |
+
# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
|
12 |
+
# --- GEOMETRY ---
|
13 |
+
def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
|
14 |
+
if len(kpts0) < 5:
|
15 |
+
return None
|
16 |
+
K0inv = np.linalg.inv(K0[:2,:2])
|
17 |
+
K1inv = np.linalg.inv(K1[:2,:2])
|
18 |
+
|
19 |
+
kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
|
20 |
+
kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
|
21 |
+
|
22 |
+
E, mask = cv2.findEssentialMat(
|
23 |
+
kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
|
24 |
+
)
|
25 |
+
|
26 |
+
ret = None
|
27 |
+
if E is not None:
|
28 |
+
best_num_inliers = 0
|
29 |
+
|
30 |
+
for _E in np.split(E, len(E) / 3):
|
31 |
+
n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
|
32 |
+
if n > best_num_inliers:
|
33 |
+
best_num_inliers = n
|
34 |
+
ret = (R, t, mask.ravel() > 0)
|
35 |
+
return ret
|
36 |
+
|
37 |
+
|
38 |
+
def rotate_intrinsic(K, n):
|
39 |
+
base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
|
40 |
+
rot = np.linalg.matrix_power(base_rot, n)
|
41 |
+
return rot @ K
|
42 |
+
|
43 |
+
|
44 |
+
def rotate_pose_inplane(i_T_w, rot):
|
45 |
+
rotation_matrices = [
|
46 |
+
np.array(
|
47 |
+
[
|
48 |
+
[np.cos(r), -np.sin(r), 0.0, 0.0],
|
49 |
+
[np.sin(r), np.cos(r), 0.0, 0.0],
|
50 |
+
[0.0, 0.0, 1.0, 0.0],
|
51 |
+
[0.0, 0.0, 0.0, 1.0],
|
52 |
+
],
|
53 |
+
dtype=np.float32,
|
54 |
+
)
|
55 |
+
for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
|
56 |
+
]
|
57 |
+
return np.dot(rotation_matrices[rot], i_T_w)
|
58 |
+
|
59 |
+
|
60 |
+
def scale_intrinsics(K, scales):
|
61 |
+
scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
|
62 |
+
return np.dot(scales, K)
|
63 |
+
|
64 |
+
|
65 |
+
def to_homogeneous(points):
|
66 |
+
return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
|
67 |
+
|
68 |
+
|
69 |
+
def angle_error_mat(R1, R2):
|
70 |
+
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
|
71 |
+
cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
|
72 |
+
return np.rad2deg(np.abs(np.arccos(cos)))
|
73 |
+
|
74 |
+
|
75 |
+
def angle_error_vec(v1, v2):
|
76 |
+
n = np.linalg.norm(v1) * np.linalg.norm(v2)
|
77 |
+
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
|
78 |
+
|
79 |
+
|
80 |
+
def compute_pose_error(T_0to1, R, t):
|
81 |
+
R_gt = T_0to1[:3, :3]
|
82 |
+
t_gt = T_0to1[:3, 3]
|
83 |
+
error_t = angle_error_vec(t.squeeze(), t_gt)
|
84 |
+
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
|
85 |
+
error_R = angle_error_mat(R, R_gt)
|
86 |
+
return error_t, error_R
|
87 |
+
|
88 |
+
|
89 |
+
def pose_auc(errors, thresholds):
|
90 |
+
sort_idx = np.argsort(errors)
|
91 |
+
errors = np.array(errors.copy())[sort_idx]
|
92 |
+
recall = (np.arange(len(errors)) + 1) / len(errors)
|
93 |
+
errors = np.r_[0.0, errors]
|
94 |
+
recall = np.r_[0.0, recall]
|
95 |
+
aucs = []
|
96 |
+
for t in thresholds:
|
97 |
+
last_index = np.searchsorted(errors, t)
|
98 |
+
r = np.r_[recall[:last_index], recall[last_index - 1]]
|
99 |
+
e = np.r_[errors[:last_index], t]
|
100 |
+
aucs.append(np.trapz(r, x=e) / t)
|
101 |
+
return aucs
|
102 |
+
|
103 |
+
|
104 |
+
# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
|
105 |
+
def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
|
106 |
+
ops = []
|
107 |
+
if resize:
|
108 |
+
ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
|
109 |
+
return TupleCompose(ops)
|
110 |
+
|
111 |
+
|
112 |
+
def get_tuple_transform_ops(resize=None, normalize=True, unscale=False):
|
113 |
+
ops = []
|
114 |
+
if resize:
|
115 |
+
ops.append(TupleResize(resize))
|
116 |
+
if normalize:
|
117 |
+
ops.append(TupleToTensorScaled())
|
118 |
+
# ops.append(
|
119 |
+
# TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
120 |
+
# ) # Imagenet mean/std
|
121 |
+
else:
|
122 |
+
if unscale:
|
123 |
+
ops.append(TupleToTensorUnscaled())
|
124 |
+
else:
|
125 |
+
ops.append(TupleToTensorScaled())
|
126 |
+
return TupleCompose(ops)
|
127 |
+
|
128 |
+
|
129 |
+
class ToTensorScaled(object):
|
130 |
+
"""Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
|
131 |
+
|
132 |
+
def __call__(self, im):
|
133 |
+
if not isinstance(im, torch.Tensor):
|
134 |
+
im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
|
135 |
+
im /= 255.0
|
136 |
+
return torch.from_numpy(im)
|
137 |
+
else:
|
138 |
+
return im
|
139 |
+
|
140 |
+
def __repr__(self):
|
141 |
+
return "ToTensorScaled(./255)"
|
142 |
+
|
143 |
+
|
144 |
+
class TupleToTensorScaled(object):
|
145 |
+
def __init__(self):
|
146 |
+
self.to_tensor = ToTensorScaled()
|
147 |
+
|
148 |
+
def __call__(self, im_tuple):
|
149 |
+
return [self.to_tensor(im) for im in im_tuple]
|
150 |
+
|
151 |
+
def __repr__(self):
|
152 |
+
return "TupleToTensorScaled(./255)"
|
153 |
+
|
154 |
+
|
155 |
+
class ToTensorUnscaled(object):
|
156 |
+
"""Convert a RGB PIL Image to a CHW ordered Tensor"""
|
157 |
+
|
158 |
+
def __call__(self, im):
|
159 |
+
return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
|
160 |
+
|
161 |
+
def __repr__(self):
|
162 |
+
return "ToTensorUnscaled()"
|
163 |
+
|
164 |
+
|
165 |
+
class TupleToTensorUnscaled(object):
|
166 |
+
"""Convert a RGB PIL Image to a CHW ordered Tensor"""
|
167 |
+
|
168 |
+
def __init__(self):
|
169 |
+
self.to_tensor = ToTensorUnscaled()
|
170 |
+
|
171 |
+
def __call__(self, im_tuple):
|
172 |
+
return [self.to_tensor(im) for im in im_tuple]
|
173 |
+
|
174 |
+
def __repr__(self):
|
175 |
+
return "TupleToTensorUnscaled()"
|
176 |
+
|
177 |
+
|
178 |
+
class TupleResize(object):
|
179 |
+
def __init__(self, size, mode=InterpolationMode.BICUBIC):
|
180 |
+
self.size = size
|
181 |
+
self.resize = transforms.Resize(size, mode)
|
182 |
+
|
183 |
+
def __call__(self, im_tuple):
|
184 |
+
return [self.resize(im) for im in im_tuple]
|
185 |
+
|
186 |
+
def __repr__(self):
|
187 |
+
return "TupleResize(size={})".format(self.size)
|
188 |
+
|
189 |
+
|
190 |
+
class TupleNormalize(object):
|
191 |
+
def __init__(self, mean, std):
|
192 |
+
self.mean = mean
|
193 |
+
self.std = std
|
194 |
+
self.normalize = transforms.Normalize(mean=mean, std=std)
|
195 |
+
|
196 |
+
def __call__(self, im_tuple):
|
197 |
+
return [self.normalize(im) for im in im_tuple]
|
198 |
+
|
199 |
+
def __repr__(self):
|
200 |
+
return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
|
201 |
+
|
202 |
+
|
203 |
+
class TupleCompose(object):
|
204 |
+
def __init__(self, transforms):
|
205 |
+
self.transforms = transforms
|
206 |
+
|
207 |
+
def __call__(self, im_tuple):
|
208 |
+
for t in self.transforms:
|
209 |
+
im_tuple = t(im_tuple)
|
210 |
+
return im_tuple
|
211 |
+
|
212 |
+
def __repr__(self):
|
213 |
+
format_string = self.__class__.__name__ + "("
|
214 |
+
for t in self.transforms:
|
215 |
+
format_string += "\n"
|
216 |
+
format_string += " {0}".format(t)
|
217 |
+
format_string += "\n)"
|
218 |
+
return format_string
|
219 |
+
|
220 |
+
|
221 |
+
@torch.no_grad()
|
222 |
+
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
|
223 |
+
"""Warp kpts0 from I0 to I1 with depth, K and Rt
|
224 |
+
Also check covisibility and depth consistency.
|
225 |
+
Depth is consistent if relative error < 0.2 (hard-coded).
|
226 |
+
# https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
|
227 |
+
Args:
|
228 |
+
kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
|
229 |
+
depth0 (torch.Tensor): [N, H, W],
|
230 |
+
depth1 (torch.Tensor): [N, H, W],
|
231 |
+
T_0to1 (torch.Tensor): [N, 3, 4],
|
232 |
+
K0 (torch.Tensor): [N, 3, 3],
|
233 |
+
K1 (torch.Tensor): [N, 3, 3],
|
234 |
+
Returns:
|
235 |
+
calculable_mask (torch.Tensor): [N, L]
|
236 |
+
warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
|
237 |
+
"""
|
238 |
+
(
|
239 |
+
n,
|
240 |
+
h,
|
241 |
+
w,
|
242 |
+
) = depth0.shape
|
243 |
+
kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[
|
244 |
+
:, 0, :, 0
|
245 |
+
]
|
246 |
+
kpts0 = torch.stack(
|
247 |
+
(w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
|
248 |
+
) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
|
249 |
+
# Sample depth, get calculable_mask on depth != 0
|
250 |
+
nonzero_mask = kpts0_depth != 0
|
251 |
+
|
252 |
+
# Unproject
|
253 |
+
kpts0_h = (
|
254 |
+
torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
|
255 |
+
* kpts0_depth[..., None]
|
256 |
+
) # (N, L, 3)
|
257 |
+
kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
|
258 |
+
kpts0_cam = kpts0_n
|
259 |
+
|
260 |
+
# Rigid Transform
|
261 |
+
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
|
262 |
+
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
|
263 |
+
|
264 |
+
# Project
|
265 |
+
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
|
266 |
+
w_kpts0 = w_kpts0_h[:, :, :2] / (
|
267 |
+
w_kpts0_h[:, :, [2]] + 1e-4
|
268 |
+
) # (N, L, 2), +1e-4 to avoid zero depth
|
269 |
+
|
270 |
+
# Covisible Check
|
271 |
+
h, w = depth1.shape[1:3]
|
272 |
+
covisible_mask = (
|
273 |
+
(w_kpts0[:, :, 0] > 0)
|
274 |
+
* (w_kpts0[:, :, 0] < w - 1)
|
275 |
+
* (w_kpts0[:, :, 1] > 0)
|
276 |
+
* (w_kpts0[:, :, 1] < h - 1)
|
277 |
+
)
|
278 |
+
w_kpts0 = torch.stack(
|
279 |
+
(2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
|
280 |
+
) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
|
281 |
+
# w_kpts0[~covisible_mask, :] = -5 # xd
|
282 |
+
|
283 |
+
w_kpts0_depth = F.grid_sample(
|
284 |
+
depth1[:, None], w_kpts0[:, :, None], mode="bilinear"
|
285 |
+
)[:, 0, :, 0]
|
286 |
+
consistent_mask = (
|
287 |
+
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
|
288 |
+
).abs() < 0.05
|
289 |
+
valid_mask = nonzero_mask * covisible_mask * consistent_mask
|
290 |
+
|
291 |
+
return valid_mask, w_kpts0
|
292 |
+
|
293 |
+
|
294 |
+
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
|
295 |
+
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
|
296 |
+
|
297 |
+
|
298 |
+
def numpy_to_pil(x: np.ndarray):
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
x: Assumed to be of shape (h,w,c)
|
302 |
+
"""
|
303 |
+
if isinstance(x, torch.Tensor):
|
304 |
+
x = x.detach().cpu().numpy()
|
305 |
+
if x.max() <= 1.01:
|
306 |
+
x *= 255
|
307 |
+
x = x.astype(np.uint8)
|
308 |
+
return Image.fromarray(x)
|
309 |
+
|
310 |
+
|
311 |
+
def tensor_to_pil(x, unnormalize=False):
|
312 |
+
if unnormalize:
|
313 |
+
x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None]
|
314 |
+
x = x.detach().permute(1, 2, 0).cpu().numpy()
|
315 |
+
x = np.clip(x, 0.0, 1.0)
|
316 |
+
return numpy_to_pil(x)
|
317 |
+
|
318 |
+
|
319 |
+
def to_cuda(batch):
|
320 |
+
for key, value in batch.items():
|
321 |
+
if isinstance(value, torch.Tensor):
|
322 |
+
batch[key] = value.to(device)
|
323 |
+
return batch
|
324 |
+
|
325 |
+
|
326 |
+
def to_cpu(batch):
|
327 |
+
for key, value in batch.items():
|
328 |
+
if isinstance(value, torch.Tensor):
|
329 |
+
batch[key] = value.cpu()
|
330 |
+
return batch
|
331 |
+
|
332 |
+
|
333 |
+
def get_pose(calib):
|
334 |
+
w, h = np.array(calib["imsize"])[0]
|
335 |
+
return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
|
336 |
+
|
337 |
+
|
338 |
+
def compute_relative_pose(R1, t1, R2, t2):
|
339 |
+
rots = R2 @ (R1.T)
|
340 |
+
trans = -rots @ t1 + t2
|
341 |
+
return rots, trans
|