xuelunshen commited on
Commit
1bfbd08
·
1 Parent(s): 8d7cbc7

update: gim code

Browse files
.gitignore CHANGED
@@ -21,3 +21,6 @@ gradio_cached_examples
21
  hloc/matchers/quadtree.py
22
  third_party/QuadTreeAttention
23
  desktop.ini
 
 
 
 
21
  hloc/matchers/quadtree.py
22
  third_party/QuadTreeAttention
23
  desktop.ini
24
+
25
+ */.DS_Store
26
+ .DS_Store
common/utils.py CHANGED
@@ -448,6 +448,7 @@ ransac_zoo = {
448
 
449
  # Matchers collections
450
  matcher_zoo = {
 
451
  "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
452
  "sold2": {"config": match_dense.confs["sold2"], "dense": True},
453
  # 'dedode-sparse': {
 
448
 
449
  # Matchers collections
450
  matcher_zoo = {
451
+ "gim": {"config": match_dense.confs["gim"], "dense": True},
452
  "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
453
  "sold2": {"config": match_dense.confs["sold2"], "dense": True},
454
  # 'dedode-sparse': {
hloc/match_dense.py CHANGED
@@ -9,6 +9,23 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  confs = {
11
  # Best quality but loads of points. Only use for small scenes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  "loftr": {
13
  "output": "matches-loftr",
14
  "model": {
 
9
 
10
  confs = {
11
  # Best quality but loads of points. Only use for small scenes
12
+ "gim": {
13
+ "output": "matches-gim",
14
+ "model": {
15
+ "name": "gim",
16
+ "weights": "gim_dkm_100h.ckpt",
17
+ "max_keypoints": 2000,
18
+ "match_threshold": 0.2,
19
+ },
20
+ "preprocessing": {
21
+ "grayscale": False,
22
+ "force_resize": True,
23
+ "resize_max": 1024,
24
+ "width": 80,
25
+ "height": 60,
26
+ "dfactor": 8,
27
+ },
28
+ },
29
  "loftr": {
30
  "output": "matches-loftr",
31
  "model": {
hloc/matchers/gim.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import subprocess
4
+
5
+ from pathlib import Path
6
+ from ..utils.base_model import BaseModel
7
+ from .. import logger
8
+
9
+ from .networks.dkm.models.model_zoo.DKMv3 import DKMv3
10
+
11
+ weight_path = Path(__file__).parent / 'networks' / 'dkm'
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+
15
+ class GIM(BaseModel):
16
+ default_conf = {
17
+ "model_name": "gim_dkm_100h.ckpt",
18
+ "match_threshold": 0.2,
19
+ "checkpoint_dir": weight_path,
20
+ }
21
+ required_inputs = [
22
+ "image0",
23
+ "image1",
24
+ ]
25
+ # Models exported using
26
+ # dkm_models = {
27
+ # "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
28
+ # "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
29
+ # }
30
+
31
+ def _init(self, conf):
32
+ model_path = weight_path / conf["model_name"]
33
+
34
+ # Download the model.
35
+ if not model_path.exists():
36
+ model_path.parent.mkdir(exist_ok=True)
37
+ link = self.dkm_models[conf["model_name"]]
38
+ cmd = ["wget", link, "-O", str(model_path)]
39
+ logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
40
+ subprocess.run(cmd, check=True)
41
+ logger.info(f"Loading DKMv3 model...")
42
+ # self.net = DKMv3(path_to_weights=str(model_path), device=device)
43
+
44
+ model = DKMv3(None, 672, 896, upsample_preds=True)
45
+
46
+ checkpoints_path = join('checkpoints', conf['weights'])
47
+ state_dict = torch.load(checkpoints_path, map_location='cpu')
48
+ if 'state_dict' in state_dict.keys(): state_dict = state_dict['state_dict']
49
+ for k in list(state_dict.keys()):
50
+ if k.startswith('model.'):
51
+ state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
52
+ if 'encoder.net.fc' in k:
53
+ state_dict.pop(k)
54
+ model.load_state_dict(state_dict)
55
+
56
+ self.net = model
57
+
58
+ def _forward(self, data):
59
+ # img0 = data["image0"].cpu().numpy().squeeze() * 255
60
+ # img1 = data["image1"].cpu().numpy().squeeze() * 255
61
+ # img0 = img0.transpose(1, 2, 0)
62
+ # img1 = img1.transpose(1, 2, 0)
63
+ # img0 = Image.fromarray(img0.astype("uint8"))
64
+ # img1 = Image.fromarray(img1.astype("uint8"))
65
+ # W_A, H_A = img0.size
66
+ # W_B, H_B = img1.size
67
+ #
68
+ # warp, certainty = self.net.match(img0, img1, device=device)
69
+ # matches, certainty = self.net.sample(warp, certainty)
70
+ # kpts1, kpts2 = self.net.to_pixel_coordinates(
71
+ # matches, H_A, W_A, H_B, W_B
72
+ # )
73
+
74
+ image0, image1 = data['image0'], data['image1']
75
+ orig_width = image0.shape[3]
76
+ orig_height = image0.shape[2]
77
+ aspect_ratio = 896 / 672
78
+ new_width = max(orig_width, int(orig_height * aspect_ratio))
79
+ new_height = max(orig_height, int(orig_width / aspect_ratio))
80
+ pad_height = new_height - orig_height
81
+ pad_width = new_width - orig_width
82
+ pad_top = pad_height // 2
83
+ pad_bottom = pad_height - pad_top
84
+ pad_left = pad_width // 2
85
+ pad_right = pad_width - pad_left
86
+ image0 = torch.nn.functional.pad(image0, (pad_left, pad_right, pad_top, pad_bottom))
87
+ image1 = torch.nn.functional.pad(image1, (pad_left, pad_right, pad_top, pad_bottom))
88
+ dense_matches, dense_certainty = self.net.match(image0, image1)
89
+ sparse_matches, mconf = self.net.sample(dense_matches, dense_certainty, 2048)
90
+ height0, width0 = image0.shape[-2:]
91
+ height1, width1 = image1.shape[-2:]
92
+ kpts0 = sparse_matches[:, :2]
93
+ kpts1 = sparse_matches[:, 2:]
94
+ kpts0 = torch.stack((width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1, )
95
+ kpts1 = torch.stack((width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1, )
96
+ b_ids, i_ids = torch.where(mconf[None])
97
+ # before padding
98
+ kpts0 -= kpts0.new_tensor((pad_left, pad_top))[None]
99
+ kpts1 -= kpts1.new_tensor((pad_left, pad_top))[None]
100
+ mask = (kpts0[:, 0] > 0) & \
101
+ (kpts0[:, 1] > 0) & \
102
+ (kpts1[:, 0] > 0) & \
103
+ (kpts1[:, 1] > 0)
104
+ mask = mask & \
105
+ (kpts0[:, 0] <= (orig_width - 1)) & \
106
+ (kpts1[:, 0] <= (orig_width - 1)) & \
107
+ (kpts0[:, 1] <= (orig_height - 1)) & \
108
+ (kpts1[:, 1] <= (orig_height - 1))
109
+ pred = {
110
+ 'keypoints0': kpts0[i_ids],
111
+ 'keypoints1': kpts1[i_ids],
112
+ 'confidence': mconf[i_ids],
113
+ 'batch_indexes': b_ids,
114
+ }
115
+ scores, b_ids = pred['confidence'], pred['batch_indexes']
116
+ kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
117
+ pred['confidence'], pred['batch_indexes'] = scores[mask], b_ids[mask]
118
+ pred['keypoints0'], pred['keypoints1'] = kpts0[mask], kpts1[mask]
119
+
120
+ out = {"keypoints0": pred['keypoints0'], "keypoints1": pred['keypoints1']}
121
+ return out
hloc/matchers/networks/dkm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .models import (
2
+ DKMv3_outdoor,
3
+ DKMv3_indoor,
4
+ )
hloc/matchers/networks/dkm/datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .megadepth import MegadepthBuilder
hloc/matchers/networks/dkm/datasets/megadepth.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset
8
+
9
+ from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
10
+ import torchvision.transforms.functional as tvf
11
+ from dkm.utils.transforms import GeometricSequential
12
+ import kornia.augmentation as K
13
+
14
+
15
+ class MegadepthScene:
16
+ def __init__(
17
+ self,
18
+ data_root,
19
+ scene_info,
20
+ ht=384,
21
+ wt=512,
22
+ min_overlap=0.0,
23
+ shake_t=0,
24
+ rot_prob=0.0,
25
+ normalize=True,
26
+ ) -> None:
27
+ self.data_root = data_root
28
+ self.image_paths = scene_info["image_paths"]
29
+ self.depth_paths = scene_info["depth_paths"]
30
+ self.intrinsics = scene_info["intrinsics"]
31
+ self.poses = scene_info["poses"]
32
+ self.pairs = scene_info["pairs"]
33
+ self.overlaps = scene_info["overlaps"]
34
+ threshold = self.overlaps > min_overlap
35
+ self.pairs = self.pairs[threshold]
36
+ self.overlaps = self.overlaps[threshold]
37
+ if len(self.pairs) > 100000:
38
+ pairinds = np.random.choice(
39
+ np.arange(0, len(self.pairs)), 100000, replace=False
40
+ )
41
+ self.pairs = self.pairs[pairinds]
42
+ self.overlaps = self.overlaps[pairinds]
43
+ # counts, bins = np.histogram(self.overlaps,20)
44
+ # print(counts)
45
+ self.im_transform_ops = get_tuple_transform_ops(
46
+ resize=(ht, wt), normalize=normalize
47
+ )
48
+ self.depth_transform_ops = get_depth_tuple_transform_ops(
49
+ resize=(ht, wt), normalize=False
50
+ )
51
+ self.wt, self.ht = wt, ht
52
+ self.shake_t = shake_t
53
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
54
+
55
+ def load_im(self, im_ref, crop=None):
56
+ im = Image.open(im_ref)
57
+ return im
58
+
59
+ def load_depth(self, depth_ref, crop=None):
60
+ depth = np.array(h5py.File(depth_ref, "r")["depth"])
61
+ return torch.from_numpy(depth)
62
+
63
+ def __len__(self):
64
+ return len(self.pairs)
65
+
66
+ def scale_intrinsic(self, K, wi, hi):
67
+ sx, sy = self.wt / wi, self.ht / hi
68
+ sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
69
+ return sK @ K
70
+
71
+ def rand_shake(self, *things):
72
+ t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
73
+ return [
74
+ tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
75
+ for thing in things
76
+ ], t
77
+
78
+ def __getitem__(self, pair_idx):
79
+ # read intrinsics of original size
80
+ idx1, idx2 = self.pairs[pair_idx]
81
+ K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
82
+ K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
83
+
84
+ # read and compute relative poses
85
+ T1 = self.poses[idx1]
86
+ T2 = self.poses[idx2]
87
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
88
+ :4, :4
89
+ ] # (4, 4)
90
+
91
+ # Load positive pair data
92
+ im1, im2 = self.image_paths[idx1], self.image_paths[idx2]
93
+ depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
94
+ im_src_ref = os.path.join(self.data_root, im1)
95
+ im_pos_ref = os.path.join(self.data_root, im2)
96
+ depth_src_ref = os.path.join(self.data_root, depth1)
97
+ depth_pos_ref = os.path.join(self.data_root, depth2)
98
+ # return torch.randn((1000,1000))
99
+ im_src = self.load_im(im_src_ref)
100
+ im_pos = self.load_im(im_pos_ref)
101
+ depth_src = self.load_depth(depth_src_ref)
102
+ depth_pos = self.load_depth(depth_pos_ref)
103
+
104
+ # Recompute camera intrinsic matrix due to the resize
105
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
106
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
107
+ # Process images
108
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
109
+ depth_src, depth_pos = self.depth_transform_ops(
110
+ (depth_src[None, None], depth_pos[None, None])
111
+ )
112
+ [im_src, im_pos, depth_src, depth_pos], t = self.rand_shake(
113
+ im_src, im_pos, depth_src, depth_pos
114
+ )
115
+ im_src, Hq = self.H_generator(im_src[None])
116
+ depth_src = self.H_generator.apply_transform(depth_src, Hq)
117
+ K1[:2, 2] += t
118
+ K2[:2, 2] += t
119
+ K1 = Hq[0] @ K1
120
+ data_dict = {
121
+ "query": im_src[0],
122
+ "query_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
123
+ "support": im_pos,
124
+ "support_identifier": self.image_paths[idx2]
125
+ .split("/")[-1]
126
+ .split(".jpg")[0],
127
+ "query_depth": depth_src[0, 0],
128
+ "support_depth": depth_pos[0, 0],
129
+ "K1": K1,
130
+ "K2": K2,
131
+ "T_1to2": T_1to2,
132
+ }
133
+ return data_dict
134
+
135
+
136
+ class MegadepthBuilder:
137
+ def __init__(self, data_root="data/megadepth") -> None:
138
+ self.data_root = data_root
139
+ self.scene_info_root = os.path.join(data_root, "prep_scene_info")
140
+ self.all_scenes = os.listdir(self.scene_info_root)
141
+ self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
142
+ self.test_scenes_loftr = ["0015.npy", "0022.npy"]
143
+
144
+ def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
145
+ if split == "train":
146
+ scene_names = set(self.all_scenes) - set(self.test_scenes)
147
+ elif split == "train_loftr":
148
+ scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
149
+ elif split == "test":
150
+ scene_names = self.test_scenes
151
+ elif split == "test_loftr":
152
+ scene_names = self.test_scenes_loftr
153
+ else:
154
+ raise ValueError(f"Split {split} not available")
155
+ scenes = []
156
+ for scene_name in scene_names:
157
+ scene_info = np.load(
158
+ os.path.join(self.scene_info_root, scene_name), allow_pickle=True
159
+ ).item()
160
+ scenes.append(
161
+ MegadepthScene(
162
+ self.data_root, scene_info, min_overlap=min_overlap, **kwargs
163
+ )
164
+ )
165
+ return scenes
166
+
167
+ def weight_scenes(self, concat_dataset, alpha=0.5):
168
+ ns = []
169
+ for d in concat_dataset.datasets:
170
+ ns.append(len(d))
171
+ ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
172
+ return ws
173
+
174
+
175
+ if __name__ == "__main__":
176
+ mega_test = ConcatDataset(MegadepthBuilder().build_scenes(split="train"))
177
+ mega_test[0]
hloc/matchers/networks/dkm/datasets/scannet.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from PIL import Image
4
+ import cv2
5
+ import h5py
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import (
9
+ Dataset,
10
+ DataLoader,
11
+ ConcatDataset)
12
+
13
+ import torchvision.transforms.functional as tvf
14
+ import kornia.augmentation as K
15
+ import os.path as osp
16
+ import matplotlib.pyplot as plt
17
+ from dkm.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
18
+ from dkm.utils.transforms import GeometricSequential
19
+
20
+ from tqdm import tqdm
21
+
22
+ class ScanNetScene:
23
+ def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.) -> None:
24
+ self.scene_root = osp.join(data_root,"scans","scans_train")
25
+ self.data_names = scene_info['name']
26
+ self.overlaps = scene_info['score']
27
+ # Only sample 10s
28
+ valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
29
+ self.overlaps = self.overlaps[valid]
30
+ self.data_names = self.data_names[valid]
31
+ if len(self.data_names) > 10000:
32
+ pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
33
+ self.data_names = self.data_names[pairinds]
34
+ self.overlaps = self.overlaps[pairinds]
35
+ self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
36
+ self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
37
+ self.wt, self.ht = wt, ht
38
+ self.shake_t = shake_t
39
+ self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
40
+
41
+ def load_im(self, im_ref, crop=None):
42
+ im = Image.open(im_ref)
43
+ return im
44
+
45
+ def load_depth(self, depth_ref, crop=None):
46
+ depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
47
+ depth = depth / 1000
48
+ depth = torch.from_numpy(depth).float() # (h, w)
49
+ return depth
50
+
51
+ def __len__(self):
52
+ return len(self.data_names)
53
+
54
+ def scale_intrinsic(self, K, wi, hi):
55
+ sx, sy = self.wt / wi, self.ht / hi
56
+ sK = torch.tensor([[sx, 0, 0],
57
+ [0, sy, 0],
58
+ [0, 0, 1]])
59
+ return sK@K
60
+
61
+ def read_scannet_pose(self,path):
62
+ """ Read ScanNet's Camera2World pose and transform it to World2Camera.
63
+
64
+ Returns:
65
+ pose_w2c (np.ndarray): (4, 4)
66
+ """
67
+ cam2world = np.loadtxt(path, delimiter=' ')
68
+ world2cam = np.linalg.inv(cam2world)
69
+ return world2cam
70
+
71
+
72
+ def read_scannet_intrinsic(self,path):
73
+ """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
74
+ """
75
+ intrinsic = np.loadtxt(path, delimiter=' ')
76
+ return intrinsic[:-1, :-1]
77
+
78
+ def __getitem__(self, pair_idx):
79
+ # read intrinsics of original size
80
+ data_name = self.data_names[pair_idx]
81
+ scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
82
+ scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
83
+
84
+ # read the intrinsic of depthmap
85
+ K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
86
+ scene_name,
87
+ 'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
88
+ # read and compute relative poses
89
+ T1 = self.read_scannet_pose(osp.join(self.scene_root,
90
+ scene_name,
91
+ 'pose', f'{stem_name_1}.txt'))
92
+ T2 = self.read_scannet_pose(osp.join(self.scene_root,
93
+ scene_name,
94
+ 'pose', f'{stem_name_2}.txt'))
95
+ T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
96
+
97
+ # Load positive pair data
98
+ im_src_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
99
+ im_pos_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
100
+ depth_src_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
101
+ depth_pos_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
102
+
103
+ im_src = self.load_im(im_src_ref)
104
+ im_pos = self.load_im(im_pos_ref)
105
+ depth_src = self.load_depth(depth_src_ref)
106
+ depth_pos = self.load_depth(depth_pos_ref)
107
+
108
+ # Recompute camera intrinsic matrix due to the resize
109
+ K1 = self.scale_intrinsic(K1, im_src.width, im_src.height)
110
+ K2 = self.scale_intrinsic(K2, im_pos.width, im_pos.height)
111
+ # Process images
112
+ im_src, im_pos = self.im_transform_ops((im_src, im_pos))
113
+ depth_src, depth_pos = self.depth_transform_ops((depth_src[None,None], depth_pos[None,None]))
114
+
115
+ data_dict = {'query': im_src,
116
+ 'support': im_pos,
117
+ 'query_depth': depth_src[0,0],
118
+ 'support_depth': depth_pos[0,0],
119
+ 'K1': K1,
120
+ 'K2': K2,
121
+ 'T_1to2':T_1to2,
122
+ }
123
+ return data_dict
124
+
125
+
126
+ class ScanNetBuilder:
127
+ def __init__(self, data_root = 'data/scannet') -> None:
128
+ self.data_root = data_root
129
+ self.scene_info_root = os.path.join(data_root,'scannet_indices')
130
+ self.all_scenes = os.listdir(self.scene_info_root)
131
+
132
+ def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
133
+ # Note: split doesn't matter here as we always use same scannet_train scenes
134
+ scene_names = self.all_scenes
135
+ scenes = []
136
+ for scene_name in tqdm(scene_names):
137
+ scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
138
+ scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
139
+ return scenes
140
+
141
+ def weight_scenes(self, concat_dataset, alpha=.5):
142
+ ns = []
143
+ for d in concat_dataset.datasets:
144
+ ns.append(len(d))
145
+ ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
146
+ return ws
147
+
148
+
149
+ if __name__ == "__main__":
150
+ mega_test = ConcatDataset(ScanNetBuilder("data/scannet").build_scenes(split='train'))
151
+ mega_test[0]
hloc/matchers/networks/dkm/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .model_zoo import (
2
+ DKMv3_outdoor,
3
+ DKMv3_indoor,
4
+ )
hloc/matchers/networks/dkm/models/dkm.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from ..utils.kde import kde
7
+ from ..utils import get_tuple_transform_ops
8
+ from ..utils.local_correlation import local_correlation
9
+
10
+
11
+ class ConvRefiner(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_dim=6,
15
+ hidden_dim=16,
16
+ out_dim=2,
17
+ dw=False,
18
+ kernel_size=5,
19
+ hidden_blocks=3,
20
+ displacement_emb = None,
21
+ displacement_emb_dim = None,
22
+ local_corr_radius = None,
23
+ corr_in_other = None,
24
+ no_support_fm = False,
25
+ ):
26
+ super().__init__()
27
+ self.block1 = self.create_block(
28
+ in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
29
+ )
30
+ self.hidden_blocks = nn.Sequential(
31
+ *[
32
+ self.create_block(
33
+ hidden_dim,
34
+ hidden_dim,
35
+ dw=dw,
36
+ kernel_size=kernel_size,
37
+ )
38
+ for hb in range(hidden_blocks)
39
+ ]
40
+ )
41
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
42
+ if displacement_emb:
43
+ self.has_displacement_emb = True
44
+ self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
45
+ else:
46
+ self.has_displacement_emb = False
47
+ self.local_corr_radius = local_corr_radius
48
+ self.corr_in_other = corr_in_other
49
+ self.no_support_fm = no_support_fm
50
+ def create_block(
51
+ self,
52
+ in_dim,
53
+ out_dim,
54
+ dw=False,
55
+ kernel_size=5,
56
+ ):
57
+ num_groups = 1 if not dw else in_dim
58
+ if dw:
59
+ assert (
60
+ out_dim % in_dim == 0
61
+ ), "outdim must be divisible by indim for depthwise"
62
+ conv1 = nn.Conv2d(
63
+ in_dim,
64
+ out_dim,
65
+ kernel_size=kernel_size,
66
+ stride=1,
67
+ padding=kernel_size // 2,
68
+ groups=num_groups,
69
+ )
70
+ norm = nn.BatchNorm2d(out_dim)
71
+ relu = nn.ReLU(inplace=True)
72
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
73
+ return nn.Sequential(conv1, norm, relu, conv2)
74
+
75
+ def forward(self, x, y, flow):
76
+ """Computes the relative refining displacement in pixels for a given image x,y and a coarse flow-field between them
77
+
78
+ Args:
79
+ x ([type]): [description]
80
+ y ([type]): [description]
81
+ flow ([type]): [description]
82
+
83
+ Returns:
84
+ [type]: [description]
85
+ """
86
+ device = x.device
87
+ b,c,hs,ws = x.shape
88
+ with torch.no_grad():
89
+ x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)
90
+ if self.has_displacement_emb:
91
+ query_coords = torch.meshgrid(
92
+ (
93
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
94
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
95
+ )
96
+ )
97
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
98
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
99
+ in_displacement = flow-query_coords
100
+ emb_in_displacement = self.disp_emb(in_displacement)
101
+ if self.local_corr_radius:
102
+ #TODO: should corr have gradient?
103
+ if self.corr_in_other:
104
+ # Corr in other means take a kxk grid around the predicted coordinate in other image
105
+ local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow)
106
+ else:
107
+ # Otherwise we use the warp to sample in the first image
108
+ # This is actually different operations, especially for large viewpoint changes
109
+ local_corr = local_correlation(x, x_hat, local_radius=self.local_corr_radius,)
110
+ if self.no_support_fm:
111
+ x_hat = torch.zeros_like(x)
112
+ d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
113
+ else:
114
+ d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
115
+ else:
116
+ if self.no_support_fm:
117
+ x_hat = torch.zeros_like(x)
118
+ d = torch.cat((x, x_hat), dim=1)
119
+ d = self.block1(d)
120
+ d = self.hidden_blocks(d)
121
+ d = self.out_conv(d)
122
+ certainty, displacement = d[:, :-2], d[:, -2:]
123
+ return certainty, displacement
124
+
125
+
126
+ class CosKernel(nn.Module): # similar to softmax kernel
127
+ def __init__(self, T, learn_temperature=False):
128
+ super().__init__()
129
+ self.learn_temperature = learn_temperature
130
+ if self.learn_temperature:
131
+ self.T = nn.Parameter(torch.tensor(T))
132
+ else:
133
+ self.T = T
134
+
135
+ def __call__(self, x, y, eps=1e-6):
136
+ c = torch.einsum("bnd,bmd->bnm", x, y) / (
137
+ x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
138
+ )
139
+ if self.learn_temperature:
140
+ T = self.T.abs() + 0.01
141
+ else:
142
+ T = torch.tensor(self.T, device=c.device)
143
+ K = ((c - 1.0) / T).exp()
144
+ return K
145
+
146
+
147
+ class CAB(nn.Module):
148
+ def __init__(self, in_channels, out_channels):
149
+ super(CAB, self).__init__()
150
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
151
+ self.conv1 = nn.Conv2d(
152
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
153
+ )
154
+ self.relu = nn.ReLU()
155
+ self.conv2 = nn.Conv2d(
156
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.sigmod = nn.Sigmoid()
159
+
160
+ def forward(self, x):
161
+ x1, x2 = x # high, low (old, new)
162
+ x = torch.cat([x1, x2], dim=1)
163
+ x = self.global_pooling(x)
164
+ x = self.conv1(x)
165
+ x = self.relu(x)
166
+ x = self.conv2(x)
167
+ x = self.sigmod(x)
168
+ x2 = x * x2
169
+ res = x2 + x1
170
+ return res
171
+
172
+
173
+ class RRB(nn.Module):
174
+ def __init__(self, in_channels, out_channels, kernel_size=3):
175
+ super(RRB, self).__init__()
176
+ self.conv1 = nn.Conv2d(
177
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
178
+ )
179
+ self.conv2 = nn.Conv2d(
180
+ out_channels,
181
+ out_channels,
182
+ kernel_size=kernel_size,
183
+ stride=1,
184
+ padding=kernel_size // 2,
185
+ )
186
+ self.relu = nn.ReLU()
187
+ self.bn = nn.BatchNorm2d(out_channels)
188
+ self.conv3 = nn.Conv2d(
189
+ out_channels,
190
+ out_channels,
191
+ kernel_size=kernel_size,
192
+ stride=1,
193
+ padding=kernel_size // 2,
194
+ )
195
+
196
+ def forward(self, x):
197
+ x = self.conv1(x)
198
+ res = self.conv2(x)
199
+ res = self.bn(res)
200
+ res = self.relu(res)
201
+ res = self.conv3(res)
202
+ return self.relu(x + res)
203
+
204
+
205
+ class DFN(nn.Module):
206
+ def __init__(
207
+ self,
208
+ internal_dim,
209
+ feat_input_modules,
210
+ pred_input_modules,
211
+ rrb_d_dict,
212
+ cab_dict,
213
+ rrb_u_dict,
214
+ use_global_context=False,
215
+ global_dim=None,
216
+ terminal_module=None,
217
+ upsample_mode="bilinear",
218
+ align_corners=False,
219
+ ):
220
+ super().__init__()
221
+ if use_global_context:
222
+ assert (
223
+ global_dim is not None
224
+ ), "Global dim must be provided when using global context"
225
+ self.align_corners = align_corners
226
+ self.internal_dim = internal_dim
227
+ self.feat_input_modules = feat_input_modules
228
+ self.pred_input_modules = pred_input_modules
229
+ self.rrb_d = rrb_d_dict
230
+ self.cab = cab_dict
231
+ self.rrb_u = rrb_u_dict
232
+ self.use_global_context = use_global_context
233
+ if use_global_context:
234
+ self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
235
+ self.global_pooling = nn.AdaptiveAvgPool2d(1)
236
+ self.terminal_module = (
237
+ terminal_module if terminal_module is not None else nn.Identity()
238
+ )
239
+ self.upsample_mode = upsample_mode
240
+ self._scales = [int(key) for key in self.terminal_module.keys()]
241
+
242
+ def scales(self):
243
+ return self._scales.copy()
244
+
245
+ def forward(self, embeddings, feats, context, key):
246
+ feats = self.feat_input_modules[str(key)](feats)
247
+ embeddings = torch.cat([feats, embeddings], dim=1)
248
+ embeddings = self.rrb_d[str(key)](embeddings)
249
+ context = self.cab[str(key)]([context, embeddings])
250
+ context = self.rrb_u[str(key)](context)
251
+ preds = self.terminal_module[str(key)](context)
252
+ pred_coord = preds[:, -2:]
253
+ pred_certainty = preds[:, :-2]
254
+ return pred_coord, pred_certainty, context
255
+
256
+
257
+ class GP(nn.Module):
258
+ def __init__(
259
+ self,
260
+ kernel,
261
+ T=1,
262
+ learn_temperature=False,
263
+ only_attention=False,
264
+ gp_dim=64,
265
+ basis="fourier",
266
+ covar_size=5,
267
+ only_nearest_neighbour=False,
268
+ sigma_noise=0.1,
269
+ no_cov=False,
270
+ predict_features = False,
271
+ ):
272
+ super().__init__()
273
+ self.K = kernel(T=T, learn_temperature=learn_temperature)
274
+ self.sigma_noise = sigma_noise
275
+ self.covar_size = covar_size
276
+ self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
277
+ self.only_attention = only_attention
278
+ self.only_nearest_neighbour = only_nearest_neighbour
279
+ self.basis = basis
280
+ self.no_cov = no_cov
281
+ self.dim = gp_dim
282
+ self.predict_features = predict_features
283
+
284
+ def get_local_cov(self, cov):
285
+ K = self.covar_size
286
+ b, h, w, h, w = cov.shape
287
+ hw = h * w
288
+ cov = F.pad(cov, 4 * (K // 2,)) # pad v_q
289
+ delta = torch.stack(
290
+ torch.meshgrid(
291
+ torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
292
+ ),
293
+ dim=-1,
294
+ )
295
+ positions = torch.stack(
296
+ torch.meshgrid(
297
+ torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
298
+ ),
299
+ dim=-1,
300
+ )
301
+ neighbours = positions[:, :, None, None, :] + delta[None, :, :]
302
+ points = torch.arange(hw)[:, None].expand(hw, K**2)
303
+ local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
304
+ :,
305
+ points.flatten(),
306
+ neighbours[..., 0].flatten(),
307
+ neighbours[..., 1].flatten(),
308
+ ].reshape(b, h, w, K**2)
309
+ return local_cov
310
+
311
+ def reshape(self, x):
312
+ return rearrange(x, "b d h w -> b (h w) d")
313
+
314
+ def project_to_basis(self, x):
315
+ if self.basis == "fourier":
316
+ return torch.cos(8 * math.pi * self.pos_conv(x))
317
+ elif self.basis == "linear":
318
+ return self.pos_conv(x)
319
+ else:
320
+ raise ValueError(
321
+ "No other bases other than fourier and linear currently supported in public release"
322
+ )
323
+
324
+ def get_pos_enc(self, y):
325
+ b, c, h, w = y.shape
326
+ coarse_coords = torch.meshgrid(
327
+ (
328
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
329
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
330
+ )
331
+ )
332
+
333
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
334
+ None
335
+ ].expand(b, h, w, 2)
336
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
337
+ coarse_embedded_coords = self.project_to_basis(coarse_coords)
338
+ return coarse_embedded_coords
339
+
340
+ def forward(self, x, y, **kwargs):
341
+ b, c, h1, w1 = x.shape
342
+ b, c, h2, w2 = y.shape
343
+ f = self.get_pos_enc(y)
344
+ if self.predict_features:
345
+ f = f + y[:,:self.dim] # Stupid way to predict features
346
+ b, d, h2, w2 = f.shape
347
+ #assert x.shape == y.shape
348
+ x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
349
+ K_xx = self.K(x, x)
350
+ K_yy = self.K(y, y)
351
+ K_xy = self.K(x, y)
352
+ K_yx = K_xy.permute(0, 2, 1)
353
+ sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
354
+ # Due to https://github.com/pytorch/pytorch/issues/16963 annoying warnings, remove batch if N large
355
+ if len(K_yy[0]) > 2000:
356
+ K_yy_inv = torch.cat([torch.linalg.inv(K_yy[k:k+1] + sigma_noise[k:k+1]) for k in range(b)])
357
+ else:
358
+ K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
359
+
360
+ mu_x = K_xy.matmul(K_yy_inv.matmul(f))
361
+ mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
362
+ if not self.no_cov:
363
+ cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
364
+ cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
365
+ local_cov_x = self.get_local_cov(cov_x)
366
+ local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
367
+ gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
368
+ else:
369
+ gp_feats = mu_x
370
+ return gp_feats
371
+
372
+
373
+ class Encoder(nn.Module):
374
+ def __init__(self, resnet):
375
+ super().__init__()
376
+ self.resnet = resnet
377
+ def forward(self, x):
378
+ x0 = x
379
+ b, c, h, w = x.shape
380
+ x = self.resnet.conv1(x)
381
+ x = self.resnet.bn1(x)
382
+ x1 = self.resnet.relu(x)
383
+
384
+ x = self.resnet.maxpool(x1)
385
+ x2 = self.resnet.layer1(x)
386
+
387
+ x3 = self.resnet.layer2(x2)
388
+
389
+ x4 = self.resnet.layer3(x3)
390
+
391
+ x5 = self.resnet.layer4(x4)
392
+ feats = {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}
393
+ return feats
394
+
395
+ def train(self, mode=True):
396
+ super().train(mode)
397
+ for m in self.modules():
398
+ if isinstance(m, nn.BatchNorm2d):
399
+ m.eval()
400
+ pass
401
+
402
+
403
+ class Decoder(nn.Module):
404
+ def __init__(
405
+ self, embedding_decoder, gps, proj, conv_refiner, transformers = None, detach=False, scales="all", pos_embeddings = None,
406
+ ):
407
+ super().__init__()
408
+ self.embedding_decoder = embedding_decoder
409
+ self.gps = gps
410
+ self.proj = proj
411
+ self.conv_refiner = conv_refiner
412
+ self.detach = detach
413
+ if scales == "all":
414
+ self.scales = ["32", "16", "8", "4", "2", "1"]
415
+ else:
416
+ self.scales = scales
417
+
418
+ def upsample_preds(self, flow, certainty, query, support):
419
+ b, hs, ws, d = flow.shape
420
+ b, c, h, w = query.shape
421
+ flow = flow.permute(0, 3, 1, 2)
422
+ certainty = F.interpolate(
423
+ certainty, size=(h, w), align_corners=False, mode="bilinear"
424
+ )
425
+ flow = F.interpolate(
426
+ flow, size=(h, w), align_corners=False, mode="bilinear"
427
+ )
428
+ delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
429
+ flow = torch.stack(
430
+ (
431
+ flow[:, 0] + delta_flow[:, 0] / (4 * w),
432
+ flow[:, 1] + delta_flow[:, 1] / (4 * h),
433
+ ),
434
+ dim=1,
435
+ )
436
+ flow = flow.permute(0, 2, 3, 1)
437
+ certainty = certainty + delta_certainty
438
+ return flow, certainty
439
+
440
+ def get_placeholder_flow(self, b, h, w, device):
441
+ coarse_coords = torch.meshgrid(
442
+ (
443
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
444
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
445
+ )
446
+ )
447
+ coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
448
+ None
449
+ ].expand(b, h, w, 2)
450
+ coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
451
+ return coarse_coords
452
+
453
+
454
+ def forward(self, f1, f2, upsample = False, dense_flow = None, dense_certainty = None):
455
+ coarse_scales = self.embedding_decoder.scales()
456
+ all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
457
+ sizes = {scale: f1[scale].shape[-2:] for scale in f1}
458
+ h, w = sizes[1]
459
+ b = f1[1].shape[0]
460
+ device = f1[1].device
461
+ coarsest_scale = int(all_scales[0])
462
+ old_stuff = torch.zeros(
463
+ b, self.embedding_decoder.internal_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
464
+ )
465
+ dense_corresps = {}
466
+ if not upsample:
467
+ dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
468
+ dense_certainty = 0.0
469
+ else:
470
+ dense_flow = F.interpolate(
471
+ dense_flow,
472
+ size=sizes[coarsest_scale],
473
+ align_corners=False,
474
+ mode="bilinear",
475
+ )
476
+ dense_certainty = F.interpolate(
477
+ dense_certainty,
478
+ size=sizes[coarsest_scale],
479
+ align_corners=False,
480
+ mode="bilinear",
481
+ )
482
+ for new_scale in all_scales:
483
+ ins = int(new_scale)
484
+ f1_s, f2_s = f1[ins], f2[ins]
485
+ if new_scale in self.proj:
486
+ f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
487
+ b, c, hs, ws = f1_s.shape
488
+ if ins in coarse_scales:
489
+ old_stuff = F.interpolate(
490
+ old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
491
+ )
492
+ new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
493
+ dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
494
+ new_stuff, f1_s, old_stuff, new_scale
495
+ )
496
+
497
+ if new_scale in self.conv_refiner:
498
+ delta_certainty, displacement = self.conv_refiner[new_scale](
499
+ f1_s, f2_s, dense_flow
500
+ )
501
+ dense_flow = torch.stack(
502
+ (
503
+ dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
504
+ dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
505
+ ),
506
+ dim=1,
507
+ )
508
+ dense_certainty = (
509
+ dense_certainty + delta_certainty
510
+ ) # predict both certainty and displacement
511
+
512
+ dense_corresps[ins] = {
513
+ "dense_flow": dense_flow,
514
+ "dense_certainty": dense_certainty,
515
+ }
516
+
517
+ if new_scale != "1":
518
+ dense_flow = F.interpolate(
519
+ dense_flow,
520
+ size=sizes[ins // 2],
521
+ align_corners=False,
522
+ mode="bilinear",
523
+ )
524
+
525
+ dense_certainty = F.interpolate(
526
+ dense_certainty,
527
+ size=sizes[ins // 2],
528
+ align_corners=False,
529
+ mode="bilinear",
530
+ )
531
+ if self.detach:
532
+ dense_flow = dense_flow.detach()
533
+ dense_certainty = dense_certainty.detach()
534
+ return dense_corresps
535
+
536
+
537
+ class RegressionMatcher(nn.Module):
538
+ def __init__(
539
+ self,
540
+ encoder,
541
+ decoder,
542
+ h=384,
543
+ w=512,
544
+ use_contrastive_loss = False,
545
+ alpha = 1,
546
+ beta = 0,
547
+ sample_mode = "threshold",
548
+ upsample_preds = False,
549
+ symmetric = False,
550
+ name = None,
551
+ use_soft_mutual_nearest_neighbours = False,
552
+ ):
553
+ super().__init__()
554
+ self.encoder = encoder
555
+ self.decoder = decoder
556
+ self.w_resized = w
557
+ self.h_resized = h
558
+ self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
559
+ self.use_contrastive_loss = use_contrastive_loss
560
+ self.alpha = alpha
561
+ self.beta = beta
562
+ self.sample_mode = sample_mode
563
+ self.upsample_preds = upsample_preds
564
+ self.symmetric = symmetric
565
+ self.name = name
566
+ self.sample_thresh = 0.05
567
+ self.upsample_res = (1152, 1536)
568
+ if use_soft_mutual_nearest_neighbours:
569
+ assert symmetric, "MNS requires symmetric inference"
570
+ self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours
571
+
572
+ def extract_backbone_features(self, batch, batched = True, upsample = True):
573
+ #TODO: only extract stride [1,2,4,8] for upsample = True
574
+ x_q = batch["query"]
575
+ x_s = batch["support"]
576
+ if batched:
577
+ X = torch.cat((x_q, x_s))
578
+ feature_pyramid = self.encoder(X)
579
+ else:
580
+ feature_pyramid = self.encoder(x_q), self.encoder(x_s)
581
+ return feature_pyramid
582
+
583
+ def sample(
584
+ self,
585
+ dense_matches,
586
+ dense_certainty,
587
+ num=10000,
588
+ ):
589
+ if "threshold" in self.sample_mode:
590
+ upper_thresh = self.sample_thresh
591
+ dense_certainty = dense_certainty.clone()
592
+ dense_certainty_ = dense_certainty.clone()
593
+ dense_certainty[dense_certainty > upper_thresh] = 1
594
+ elif "pow" in self.sample_mode:
595
+ dense_certainty = dense_certainty**(1/3)
596
+ elif "naive" in self.sample_mode:
597
+ dense_certainty = torch.ones_like(dense_certainty)
598
+ matches, certainty = (
599
+ dense_matches.reshape(-1, 4),
600
+ dense_certainty.reshape(-1),
601
+ )
602
+ certainty_ = dense_certainty_.reshape(-1)
603
+ expansion_factor = 4 if "balanced" in self.sample_mode else 1
604
+ if not certainty.sum(): certainty = certainty + 1e-8
605
+ good_samples = torch.multinomial(certainty,
606
+ num_samples = min(expansion_factor*num, len(certainty)),
607
+ replacement=False)
608
+ good_matches, good_certainty = matches[good_samples], certainty[good_samples]
609
+ good_certainty_ = certainty_[good_samples]
610
+ good_certainty = good_certainty_
611
+ if "balanced" not in self.sample_mode:
612
+ return good_matches, good_certainty
613
+
614
+ density = kde(good_matches, std=0.1)
615
+ p = 1 / (density+1)
616
+ p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
617
+ balanced_samples = torch.multinomial(p,
618
+ num_samples = min(num,len(good_certainty)),
619
+ replacement=False)
620
+ return good_matches[balanced_samples], good_certainty[balanced_samples]
621
+
622
+ def forward(self, batch, batched = True):
623
+ feature_pyramid = self.extract_backbone_features(batch, batched=batched)
624
+ if batched:
625
+ f_q_pyramid = {
626
+ scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
627
+ }
628
+ f_s_pyramid = {
629
+ scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
630
+ }
631
+ else:
632
+ f_q_pyramid, f_s_pyramid = feature_pyramid
633
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
634
+ if self.training and self.use_contrastive_loss:
635
+ return dense_corresps, (f_q_pyramid, f_s_pyramid)
636
+ else:
637
+ return dense_corresps
638
+
639
+ def forward_symmetric(self, batch, upsample = False, batched = True):
640
+ feature_pyramid = self.extract_backbone_features(batch, upsample = upsample, batched = batched)
641
+ f_q_pyramid = feature_pyramid
642
+ f_s_pyramid = {
643
+ scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
644
+ for scale, f_scale in feature_pyramid.items()
645
+ }
646
+ dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid, upsample = upsample, **(batch["corresps"] if "corresps" in batch else {}))
647
+ return dense_corresps
648
+
649
+ def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
650
+ kpts_A, kpts_B = matches[...,:2], matches[...,2:]
651
+ kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
652
+ kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
653
+ return kpts_A, kpts_B
654
+
655
+ def match(
656
+ self,
657
+ im1_path,
658
+ im2_path,
659
+ *args,
660
+ batched=False,
661
+ ):
662
+ assert not (batched and self.upsample_preds), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
663
+ symmetric = self.symmetric
664
+ self.train(False)
665
+ with torch.no_grad():
666
+ if not batched:
667
+ b = 1
668
+ ws = self.w_resized
669
+ hs = self.h_resized
670
+ query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False)
671
+ support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False)
672
+ batch = {"query": query, "support": support}
673
+ else:
674
+ b, c, h, w = im1_path.shape
675
+ b, c, h2, w2 = im2_path.shape
676
+ assert w == w2 and h == h2, "For batched images we assume same size"
677
+ batch = {"query": im1_path, "support": im2_path}
678
+ hs, ws = self.h_resized, self.w_resized
679
+ finest_scale = 1
680
+ # Run matcher
681
+ if symmetric:
682
+ dense_corresps = self.forward_symmetric(batch, batched = True)
683
+ else:
684
+ dense_corresps = self.forward(batch, batched = True)
685
+
686
+ if self.upsample_preds:
687
+ hs, ws = self.upsample_res
688
+ low_res_certainty = F.interpolate(
689
+ dense_corresps[16]["dense_certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
690
+ )
691
+ cert_clamp = 0
692
+ factor = 0.5
693
+ low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
694
+
695
+ if self.upsample_preds:
696
+ query = F.interpolate(im1_path, size=(hs, ws), mode='bilinear', align_corners=False)
697
+ support = F.interpolate(im2_path, size=(hs, ws), mode='bilinear', align_corners=False)
698
+ batch = {"query": query, "support": support, "corresps": dense_corresps[finest_scale]}
699
+ if symmetric:
700
+ dense_corresps = self.forward_symmetric(batch, upsample = True, batched=True)
701
+ else:
702
+ dense_corresps = self.forward(batch, batched = True, upsample=True)
703
+ query_to_support = dense_corresps[finest_scale]["dense_flow"]
704
+ dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
705
+
706
+ # Get certainty interpolation
707
+ dense_certainty = dense_certainty - low_res_certainty
708
+ query_to_support = query_to_support.permute(
709
+ 0, 2, 3, 1
710
+ )
711
+ # Create im1 meshgrid
712
+ query_coords = torch.meshgrid(
713
+ (
714
+ torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=im1_path.device),
715
+ torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=im1_path.device),
716
+ )
717
+ )
718
+ query_coords = torch.stack((query_coords[1], query_coords[0]))
719
+ query_coords = query_coords[None].expand(b, 2, hs, ws)
720
+ dense_certainty = dense_certainty.sigmoid() # logits -> probs
721
+ query_coords = query_coords.permute(0, 2, 3, 1)
722
+ if (query_to_support.abs() > 1).any() and True:
723
+ wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
724
+ dense_certainty[wrong[:,None]] = 0
725
+ # remove black pixels
726
+ black_mask1 = (im1_path[0, 0] < 0.03125) & (im1_path[0, 1] < 0.03125) & (im1_path[0, 2] < 0.03125)
727
+ black_mask2 = (im2_path[0, 0] < 0.03125) & (im2_path[0, 1] < 0.03125) & (im2_path[0, 2] < 0.03125)
728
+ black_mask = torch.stack((black_mask1, black_mask2))[:, None]
729
+ black_mask = F.interpolate(black_mask.float(), size=tuple(dense_certainty.shape[-2:]), mode='nearest').bool()
730
+ dense_certainty[black_mask] = 0
731
+
732
+ query_to_support = torch.clamp(query_to_support, -1, 1)
733
+ if symmetric:
734
+ support_coords = query_coords
735
+ qts, stq = query_to_support.chunk(2)
736
+ q_warp = torch.cat((query_coords, qts), dim=-1)
737
+ s_warp = torch.cat((stq, support_coords), dim=-1)
738
+ warp = torch.cat((q_warp, s_warp),dim=2)
739
+ dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:,0]
740
+ else:
741
+ warp = torch.cat((query_coords, query_to_support), dim=-1)
742
+ if batched:
743
+ return (
744
+ warp,
745
+ dense_certainty
746
+ )
747
+ else:
748
+ return (
749
+ warp[0],
750
+ dense_certainty[0],
751
+ )
hloc/matchers/networks/dkm/models/encoders.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as tvm
5
+
6
+ class ResNet18(nn.Module):
7
+ def __init__(self, pretrained=False) -> None:
8
+ super().__init__()
9
+ self.net = tvm.resnet18(pretrained=pretrained)
10
+ def forward(self, x):
11
+ self = self.net
12
+ x1 = x
13
+ x = self.conv1(x1)
14
+ x = self.bn1(x)
15
+ x2 = self.relu(x)
16
+ x = self.maxpool(x2)
17
+ x4 = self.layer1(x)
18
+ x8 = self.layer2(x4)
19
+ x16 = self.layer3(x8)
20
+ x32 = self.layer4(x16)
21
+ return {32:x32,16:x16,8:x8,4:x4,2:x2,1:x1}
22
+
23
+ def train(self, mode=True):
24
+ super().train(mode)
25
+ for m in self.modules():
26
+ if isinstance(m, nn.BatchNorm2d):
27
+ m.eval()
28
+ pass
29
+
30
+ class ResNet50(nn.Module):
31
+ def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
32
+ super().__init__()
33
+ if dilation is None:
34
+ dilation = [False,False,False]
35
+ if anti_aliased:
36
+ pass
37
+ else:
38
+ if weights is not None:
39
+ self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
40
+ else:
41
+ self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
42
+
43
+ del self.net.fc
44
+ self.high_res = high_res
45
+ self.freeze_bn = freeze_bn
46
+ def forward(self, x):
47
+ net = self.net
48
+ feats = {1:x}
49
+ x = net.conv1(x)
50
+ x = net.bn1(x)
51
+ x = net.relu(x)
52
+ feats[2] = x
53
+ x = net.maxpool(x)
54
+ x = net.layer1(x)
55
+ feats[4] = x
56
+ x = net.layer2(x)
57
+ feats[8] = x
58
+ x = net.layer3(x)
59
+ feats[16] = x
60
+ x = net.layer4(x)
61
+ feats[32] = x
62
+ return feats
63
+
64
+ def train(self, mode=True):
65
+ super().train(mode)
66
+ if self.freeze_bn:
67
+ for m in self.modules():
68
+ if isinstance(m, nn.BatchNorm2d):
69
+ m.eval()
70
+ pass
71
+
72
+
73
+
74
+
75
+ class ResNet101(nn.Module):
76
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
77
+ super().__init__()
78
+ if weights is not None:
79
+ self.net = tvm.resnet101(weights = weights)
80
+ else:
81
+ self.net = tvm.resnet101(pretrained=pretrained)
82
+ self.high_res = high_res
83
+ self.scale_factor = 1 if not high_res else 1.5
84
+ def forward(self, x):
85
+ net = self.net
86
+ feats = {1:x}
87
+ sf = self.scale_factor
88
+ if self.high_res:
89
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
90
+ x = net.conv1(x)
91
+ x = net.bn1(x)
92
+ x = net.relu(x)
93
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
94
+ x = net.maxpool(x)
95
+ x = net.layer1(x)
96
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
97
+ x = net.layer2(x)
98
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
99
+ x = net.layer3(x)
100
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
101
+ x = net.layer4(x)
102
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
103
+ return feats
104
+
105
+ def train(self, mode=True):
106
+ super().train(mode)
107
+ for m in self.modules():
108
+ if isinstance(m, nn.BatchNorm2d):
109
+ m.eval()
110
+ pass
111
+
112
+
113
+ class WideResNet50(nn.Module):
114
+ def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
115
+ super().__init__()
116
+ if weights is not None:
117
+ self.net = tvm.wide_resnet50_2(weights = weights)
118
+ else:
119
+ self.net = tvm.wide_resnet50_2(pretrained=pretrained)
120
+ self.high_res = high_res
121
+ self.scale_factor = 1 if not high_res else 1.5
122
+ def forward(self, x):
123
+ net = self.net
124
+ feats = {1:x}
125
+ sf = self.scale_factor
126
+ if self.high_res:
127
+ x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
128
+ x = net.conv1(x)
129
+ x = net.bn1(x)
130
+ x = net.relu(x)
131
+ feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
132
+ x = net.maxpool(x)
133
+ x = net.layer1(x)
134
+ feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
135
+ x = net.layer2(x)
136
+ feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
137
+ x = net.layer3(x)
138
+ feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
139
+ x = net.layer4(x)
140
+ feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")
141
+ return feats
142
+
143
+ def train(self, mode=True):
144
+ super().train(mode)
145
+ for m in self.modules():
146
+ if isinstance(m, nn.BatchNorm2d):
147
+ m.eval()
148
+ pass
hloc/matchers/networks/dkm/models/model_zoo/DKMv3.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...models.dkm import *
2
+ from ...models.encoders import *
3
+
4
+
5
+ def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs):
6
+ gp_dim = 256
7
+ dfn_dim = 384
8
+ feat_dim = 256
9
+ coordinate_decoder = DFN(
10
+ internal_dim=dfn_dim,
11
+ feat_input_modules=nn.ModuleDict(
12
+ {
13
+ "32": nn.Conv2d(512, feat_dim, 1, 1),
14
+ "16": nn.Conv2d(512, feat_dim, 1, 1),
15
+ }
16
+ ),
17
+ pred_input_modules=nn.ModuleDict(
18
+ {
19
+ "32": nn.Identity(),
20
+ "16": nn.Identity(),
21
+ }
22
+ ),
23
+ rrb_d_dict=nn.ModuleDict(
24
+ {
25
+ "32": RRB(gp_dim + feat_dim, dfn_dim),
26
+ "16": RRB(gp_dim + feat_dim, dfn_dim),
27
+ }
28
+ ),
29
+ cab_dict=nn.ModuleDict(
30
+ {
31
+ "32": CAB(2 * dfn_dim, dfn_dim),
32
+ "16": CAB(2 * dfn_dim, dfn_dim),
33
+ }
34
+ ),
35
+ rrb_u_dict=nn.ModuleDict(
36
+ {
37
+ "32": RRB(dfn_dim, dfn_dim),
38
+ "16": RRB(dfn_dim, dfn_dim),
39
+ }
40
+ ),
41
+ terminal_module=nn.ModuleDict(
42
+ {
43
+ "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
44
+ "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
45
+ }
46
+ ),
47
+ )
48
+ dw = True
49
+ hidden_blocks = 8
50
+ kernel_size = 5
51
+ displacement_emb = "linear"
52
+ conv_refiner = nn.ModuleDict(
53
+ {
54
+ "16": ConvRefiner(
55
+ 2 * 512+128+(2*7+1)**2,
56
+ 2 * 512+128+(2*7+1)**2,
57
+ 3,
58
+ kernel_size=kernel_size,
59
+ dw=dw,
60
+ hidden_blocks=hidden_blocks,
61
+ displacement_emb=displacement_emb,
62
+ displacement_emb_dim=128,
63
+ local_corr_radius = 7,
64
+ corr_in_other = True,
65
+ ),
66
+ "8": ConvRefiner(
67
+ 2 * 512+64+(2*3+1)**2,
68
+ 2 * 512+64+(2*3+1)**2,
69
+ 3,
70
+ kernel_size=kernel_size,
71
+ dw=dw,
72
+ hidden_blocks=hidden_blocks,
73
+ displacement_emb=displacement_emb,
74
+ displacement_emb_dim=64,
75
+ local_corr_radius = 3,
76
+ corr_in_other = True,
77
+ ),
78
+ "4": ConvRefiner(
79
+ 2 * 256+32+(2*2+1)**2,
80
+ 2 * 256+32+(2*2+1)**2,
81
+ 3,
82
+ kernel_size=kernel_size,
83
+ dw=dw,
84
+ hidden_blocks=hidden_blocks,
85
+ displacement_emb=displacement_emb,
86
+ displacement_emb_dim=32,
87
+ local_corr_radius = 2,
88
+ corr_in_other = True,
89
+ ),
90
+ "2": ConvRefiner(
91
+ 2 * 64+16,
92
+ 128+16,
93
+ 3,
94
+ kernel_size=kernel_size,
95
+ dw=dw,
96
+ hidden_blocks=hidden_blocks,
97
+ displacement_emb=displacement_emb,
98
+ displacement_emb_dim=16,
99
+ ),
100
+ "1": ConvRefiner(
101
+ 2 * 3+6,
102
+ 24,
103
+ 3,
104
+ kernel_size=kernel_size,
105
+ dw=dw,
106
+ hidden_blocks=hidden_blocks,
107
+ displacement_emb=displacement_emb,
108
+ displacement_emb_dim=6,
109
+ ),
110
+ }
111
+ )
112
+ kernel_temperature = 0.2
113
+ learn_temperature = False
114
+ no_cov = True
115
+ kernel = CosKernel
116
+ only_attention = False
117
+ basis = "fourier"
118
+ gp32 = GP(
119
+ kernel,
120
+ T=kernel_temperature,
121
+ learn_temperature=learn_temperature,
122
+ only_attention=only_attention,
123
+ gp_dim=gp_dim,
124
+ basis=basis,
125
+ no_cov=no_cov,
126
+ )
127
+ gp16 = GP(
128
+ kernel,
129
+ T=kernel_temperature,
130
+ learn_temperature=learn_temperature,
131
+ only_attention=only_attention,
132
+ gp_dim=gp_dim,
133
+ basis=basis,
134
+ no_cov=no_cov,
135
+ )
136
+ gps = nn.ModuleDict({"32": gp32, "16": gp16})
137
+ proj = nn.ModuleDict(
138
+ {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
139
+ )
140
+ decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
141
+
142
+ encoder = ResNet50(pretrained = False, high_res = False, freeze_bn=False)
143
+ matcher = RegressionMatcher(encoder, decoder, h=h, w=w, name = "DKMv3", sample_mode=sample_mode, symmetric = symmetric, **kwargs)
144
+ # res = matcher.load_state_dict(weights)
145
+ return matcher
hloc/matchers/networks/dkm/models/model_zoo/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ weight_urls = {
2
+ "DKMv3": {
3
+ "outdoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
4
+ "indoor": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
5
+ },
6
+ }
7
+ import torch
8
+ from .DKMv3 import DKMv3
9
+
10
+
11
+ def DKMv3_outdoor(path_to_weights = None, device=None):
12
+ """
13
+ Loads DKMv3 outdoor weights, uses internal resolution of (540, 720) by default
14
+ resolution can be changed by setting model.h_resized, model.w_resized later.
15
+ Additionally upsamples preds to fixed resolution of (864, 1152),
16
+ can be turned off by model.upsample_preds = False
17
+ """
18
+ if device is None:
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ if path_to_weights is not None:
21
+ weights = torch.load(path_to_weights, map_location=device)
22
+ else:
23
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["outdoor"],
24
+ map_location=device)
25
+ return DKMv3(weights, 540, 720, upsample_preds = True, device=device)
26
+
27
+ def DKMv3_indoor(path_to_weights = None, device=None):
28
+ """
29
+ Loads DKMv3 indoor weights, uses internal resolution of (480, 640) by default
30
+ Resolution can be changed by setting model.h_resized, model.w_resized later.
31
+ """
32
+ if device is None:
33
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
+ if path_to_weights is not None:
35
+ weights = torch.load(path_to_weights, map_location=device)
36
+ else:
37
+ weights = torch.hub.load_state_dict_from_url(weight_urls["DKMv3"]["indoor"],
38
+ map_location=device)
39
+ return DKMv3(weights, 480, 640, upsample_preds = False, device=device)
hloc/matchers/networks/dkm/utils/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import (
2
+ pose_auc,
3
+ get_pose,
4
+ compute_relative_pose,
5
+ compute_pose_error,
6
+ estimate_pose,
7
+ rotate_intrinsic,
8
+ get_tuple_transform_ops,
9
+ get_depth_tuple_transform_ops,
10
+ warp_kpts,
11
+ numpy_to_pil,
12
+ tensor_to_pil,
13
+ )
hloc/matchers/networks/dkm/utils/kde.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+ def fast_kde(x, std = 0.1, kernel_size = 9, dilation = 3, padding = 9//2, stride = 1):
6
+ raise NotImplementedError("WIP, use at your own risk.")
7
+ # Note: when doing symmetric matching this might not be very exact, since we only check neighbours on the grid
8
+ x = x.permute(0,3,1,2)
9
+ B,C,H,W = x.shape
10
+ K = kernel_size ** 2
11
+ unfolded_x = F.unfold(x,kernel_size=kernel_size, dilation = dilation, padding = padding, stride = stride).reshape(B, C, K, H, W)
12
+ scores = (-(unfolded_x - x[:,:,None]).sum(dim=1)**2/(2*std**2)).exp()
13
+ density = scores.sum(dim=1)
14
+ return density
15
+
16
+
17
+ def kde(x, std = 0.1, device=None):
18
+ if device is None:
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ if isinstance(x, np.ndarray):
21
+ x = torch.from_numpy(x)
22
+ # use a gaussian kernel to estimate density
23
+ x = x.to(device)
24
+ scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
25
+ density = scores.sum(dim=-1)
26
+ return density
hloc/matchers/networks/dkm/utils/local_correlation.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def local_correlation(
6
+ feature0,
7
+ feature1,
8
+ local_radius,
9
+ padding_mode="zeros",
10
+ flow = None
11
+ ):
12
+ device = feature0.device
13
+ b, c, h, w = feature0.size()
14
+ if flow is None:
15
+ # If flow is None, assume feature0 and feature1 are aligned
16
+ coords = torch.meshgrid(
17
+ (
18
+ torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
19
+ torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
20
+ ))
21
+ coords = torch.stack((coords[1], coords[0]), dim=-1)[
22
+ None
23
+ ].expand(b, h, w, 2)
24
+ else:
25
+ coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
26
+ r = local_radius
27
+ local_window = torch.meshgrid(
28
+ (
29
+ torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=device),
30
+ torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=device),
31
+ ))
32
+ local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
33
+ None
34
+ ].expand(b, 2*r+1, 2*r+1, 2).reshape(b, (2*r+1)**2, 2)
35
+ coords = (coords[:,:,:,None]+local_window[:,None,None]).reshape(b,h,w*(2*r+1)**2,2)
36
+ window_feature = F.grid_sample(
37
+ feature1, coords, padding_mode=padding_mode, align_corners=False
38
+ )[...,None].reshape(b,c,h,w,(2*r+1)**2)
39
+ corr = torch.einsum("bchw, bchwk -> bkhw", feature0, window_feature)/(c**.5)
40
+ return corr
hloc/matchers/networks/dkm/utils/transforms.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ import torch
4
+ import kornia.augmentation as K
5
+ from kornia.geometry.transform import warp_perspective
6
+
7
+ # Adapted from Kornia
8
+ class GeometricSequential:
9
+ def __init__(self, *transforms, align_corners=True) -> None:
10
+ self.transforms = transforms
11
+ self.align_corners = align_corners
12
+
13
+ def __call__(self, x, mode="bilinear"):
14
+ b, c, h, w = x.shape
15
+ M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
16
+ for t in self.transforms:
17
+ if np.random.rand() < t.p:
18
+ M = M.matmul(
19
+ t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
20
+ )
21
+ return (
22
+ warp_perspective(
23
+ x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
24
+ ),
25
+ M,
26
+ )
27
+
28
+ def apply_transform(self, x, M, mode="bilinear"):
29
+ b, c, h, w = x.shape
30
+ return warp_perspective(
31
+ x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
32
+ )
33
+
34
+
35
+ class RandomPerspective(K.RandomPerspective):
36
+ def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
37
+ distortion_scale = torch.as_tensor(
38
+ self.distortion_scale, device=self._device, dtype=self._dtype
39
+ )
40
+ return self.random_perspective_generator(
41
+ batch_shape[0],
42
+ batch_shape[-2],
43
+ batch_shape[-1],
44
+ distortion_scale,
45
+ self.same_on_batch,
46
+ self.device,
47
+ self.dtype,
48
+ )
49
+
50
+ def random_perspective_generator(
51
+ self,
52
+ batch_size: int,
53
+ height: int,
54
+ width: int,
55
+ distortion_scale: torch.Tensor,
56
+ same_on_batch: bool = False,
57
+ device: torch.device = torch.device("cpu"),
58
+ dtype: torch.dtype = torch.float32,
59
+ ) -> Dict[str, torch.Tensor]:
60
+ r"""Get parameters for ``perspective`` for a random perspective transform.
61
+
62
+ Args:
63
+ batch_size (int): the tensor batch size.
64
+ height (int) : height of the image.
65
+ width (int): width of the image.
66
+ distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
67
+ same_on_batch (bool): apply the same transformation across the batch. Default: False.
68
+ device (torch.device): the device on which the random numbers will be generated. Default: cpu.
69
+ dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
70
+
71
+ Returns:
72
+ params Dict[str, torch.Tensor]: parameters to be passed for transformation.
73
+ - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
74
+ - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
75
+
76
+ Note:
77
+ The generated random numbers are not reproducible across different devices and dtypes.
78
+ """
79
+ if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
80
+ raise AssertionError(
81
+ f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
82
+ )
83
+ if not (
84
+ type(height) is int and height > 0 and type(width) is int and width > 0
85
+ ):
86
+ raise AssertionError(
87
+ f"'height' and 'width' must be integers. Got {height}, {width}."
88
+ )
89
+
90
+ start_points: torch.Tensor = torch.tensor(
91
+ [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
92
+ device=distortion_scale.device,
93
+ dtype=distortion_scale.dtype,
94
+ ).expand(batch_size, -1, -1)
95
+
96
+ # generate random offset not larger than half of the image
97
+ fx = distortion_scale * width / 2
98
+ fy = distortion_scale * height / 2
99
+
100
+ factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
101
+ offset = (torch.rand_like(start_points) - 0.5) * 2
102
+ end_points = start_points + factor * offset
103
+
104
+ return dict(start_points=start_points, end_points=end_points)
hloc/matchers/networks/dkm/utils/utils.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ from torchvision import transforms
5
+ from torchvision.transforms.functional import InterpolationMode
6
+ import torch.nn.functional as F
7
+ from PIL import Image
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
12
+ # --- GEOMETRY ---
13
+ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
14
+ if len(kpts0) < 5:
15
+ return None
16
+ K0inv = np.linalg.inv(K0[:2,:2])
17
+ K1inv = np.linalg.inv(K1[:2,:2])
18
+
19
+ kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
20
+ kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
21
+
22
+ E, mask = cv2.findEssentialMat(
23
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=cv2.RANSAC
24
+ )
25
+
26
+ ret = None
27
+ if E is not None:
28
+ best_num_inliers = 0
29
+
30
+ for _E in np.split(E, len(E) / 3):
31
+ n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
32
+ if n > best_num_inliers:
33
+ best_num_inliers = n
34
+ ret = (R, t, mask.ravel() > 0)
35
+ return ret
36
+
37
+
38
+ def rotate_intrinsic(K, n):
39
+ base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
40
+ rot = np.linalg.matrix_power(base_rot, n)
41
+ return rot @ K
42
+
43
+
44
+ def rotate_pose_inplane(i_T_w, rot):
45
+ rotation_matrices = [
46
+ np.array(
47
+ [
48
+ [np.cos(r), -np.sin(r), 0.0, 0.0],
49
+ [np.sin(r), np.cos(r), 0.0, 0.0],
50
+ [0.0, 0.0, 1.0, 0.0],
51
+ [0.0, 0.0, 0.0, 1.0],
52
+ ],
53
+ dtype=np.float32,
54
+ )
55
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
56
+ ]
57
+ return np.dot(rotation_matrices[rot], i_T_w)
58
+
59
+
60
+ def scale_intrinsics(K, scales):
61
+ scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
62
+ return np.dot(scales, K)
63
+
64
+
65
+ def to_homogeneous(points):
66
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
67
+
68
+
69
+ def angle_error_mat(R1, R2):
70
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
71
+ cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
72
+ return np.rad2deg(np.abs(np.arccos(cos)))
73
+
74
+
75
+ def angle_error_vec(v1, v2):
76
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
77
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
78
+
79
+
80
+ def compute_pose_error(T_0to1, R, t):
81
+ R_gt = T_0to1[:3, :3]
82
+ t_gt = T_0to1[:3, 3]
83
+ error_t = angle_error_vec(t.squeeze(), t_gt)
84
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
85
+ error_R = angle_error_mat(R, R_gt)
86
+ return error_t, error_R
87
+
88
+
89
+ def pose_auc(errors, thresholds):
90
+ sort_idx = np.argsort(errors)
91
+ errors = np.array(errors.copy())[sort_idx]
92
+ recall = (np.arange(len(errors)) + 1) / len(errors)
93
+ errors = np.r_[0.0, errors]
94
+ recall = np.r_[0.0, recall]
95
+ aucs = []
96
+ for t in thresholds:
97
+ last_index = np.searchsorted(errors, t)
98
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
99
+ e = np.r_[errors[:last_index], t]
100
+ aucs.append(np.trapz(r, x=e) / t)
101
+ return aucs
102
+
103
+
104
+ # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
105
+ def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
106
+ ops = []
107
+ if resize:
108
+ ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
109
+ return TupleCompose(ops)
110
+
111
+
112
+ def get_tuple_transform_ops(resize=None, normalize=True, unscale=False):
113
+ ops = []
114
+ if resize:
115
+ ops.append(TupleResize(resize))
116
+ if normalize:
117
+ ops.append(TupleToTensorScaled())
118
+ # ops.append(
119
+ # TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
120
+ # ) # Imagenet mean/std
121
+ else:
122
+ if unscale:
123
+ ops.append(TupleToTensorUnscaled())
124
+ else:
125
+ ops.append(TupleToTensorScaled())
126
+ return TupleCompose(ops)
127
+
128
+
129
+ class ToTensorScaled(object):
130
+ """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
131
+
132
+ def __call__(self, im):
133
+ if not isinstance(im, torch.Tensor):
134
+ im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
135
+ im /= 255.0
136
+ return torch.from_numpy(im)
137
+ else:
138
+ return im
139
+
140
+ def __repr__(self):
141
+ return "ToTensorScaled(./255)"
142
+
143
+
144
+ class TupleToTensorScaled(object):
145
+ def __init__(self):
146
+ self.to_tensor = ToTensorScaled()
147
+
148
+ def __call__(self, im_tuple):
149
+ return [self.to_tensor(im) for im in im_tuple]
150
+
151
+ def __repr__(self):
152
+ return "TupleToTensorScaled(./255)"
153
+
154
+
155
+ class ToTensorUnscaled(object):
156
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
157
+
158
+ def __call__(self, im):
159
+ return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
160
+
161
+ def __repr__(self):
162
+ return "ToTensorUnscaled()"
163
+
164
+
165
+ class TupleToTensorUnscaled(object):
166
+ """Convert a RGB PIL Image to a CHW ordered Tensor"""
167
+
168
+ def __init__(self):
169
+ self.to_tensor = ToTensorUnscaled()
170
+
171
+ def __call__(self, im_tuple):
172
+ return [self.to_tensor(im) for im in im_tuple]
173
+
174
+ def __repr__(self):
175
+ return "TupleToTensorUnscaled()"
176
+
177
+
178
+ class TupleResize(object):
179
+ def __init__(self, size, mode=InterpolationMode.BICUBIC):
180
+ self.size = size
181
+ self.resize = transforms.Resize(size, mode)
182
+
183
+ def __call__(self, im_tuple):
184
+ return [self.resize(im) for im in im_tuple]
185
+
186
+ def __repr__(self):
187
+ return "TupleResize(size={})".format(self.size)
188
+
189
+
190
+ class TupleNormalize(object):
191
+ def __init__(self, mean, std):
192
+ self.mean = mean
193
+ self.std = std
194
+ self.normalize = transforms.Normalize(mean=mean, std=std)
195
+
196
+ def __call__(self, im_tuple):
197
+ return [self.normalize(im) for im in im_tuple]
198
+
199
+ def __repr__(self):
200
+ return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
201
+
202
+
203
+ class TupleCompose(object):
204
+ def __init__(self, transforms):
205
+ self.transforms = transforms
206
+
207
+ def __call__(self, im_tuple):
208
+ for t in self.transforms:
209
+ im_tuple = t(im_tuple)
210
+ return im_tuple
211
+
212
+ def __repr__(self):
213
+ format_string = self.__class__.__name__ + "("
214
+ for t in self.transforms:
215
+ format_string += "\n"
216
+ format_string += " {0}".format(t)
217
+ format_string += "\n)"
218
+ return format_string
219
+
220
+
221
+ @torch.no_grad()
222
+ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
223
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
224
+ Also check covisibility and depth consistency.
225
+ Depth is consistent if relative error < 0.2 (hard-coded).
226
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
227
+ Args:
228
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
229
+ depth0 (torch.Tensor): [N, H, W],
230
+ depth1 (torch.Tensor): [N, H, W],
231
+ T_0to1 (torch.Tensor): [N, 3, 4],
232
+ K0 (torch.Tensor): [N, 3, 3],
233
+ K1 (torch.Tensor): [N, 3, 3],
234
+ Returns:
235
+ calculable_mask (torch.Tensor): [N, L]
236
+ warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
237
+ """
238
+ (
239
+ n,
240
+ h,
241
+ w,
242
+ ) = depth0.shape
243
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode="bilinear")[
244
+ :, 0, :, 0
245
+ ]
246
+ kpts0 = torch.stack(
247
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
248
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
249
+ # Sample depth, get calculable_mask on depth != 0
250
+ nonzero_mask = kpts0_depth != 0
251
+
252
+ # Unproject
253
+ kpts0_h = (
254
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
255
+ * kpts0_depth[..., None]
256
+ ) # (N, L, 3)
257
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
258
+ kpts0_cam = kpts0_n
259
+
260
+ # Rigid Transform
261
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
262
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
263
+
264
+ # Project
265
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
266
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
267
+ w_kpts0_h[:, :, [2]] + 1e-4
268
+ ) # (N, L, 2), +1e-4 to avoid zero depth
269
+
270
+ # Covisible Check
271
+ h, w = depth1.shape[1:3]
272
+ covisible_mask = (
273
+ (w_kpts0[:, :, 0] > 0)
274
+ * (w_kpts0[:, :, 0] < w - 1)
275
+ * (w_kpts0[:, :, 1] > 0)
276
+ * (w_kpts0[:, :, 1] < h - 1)
277
+ )
278
+ w_kpts0 = torch.stack(
279
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
280
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
281
+ # w_kpts0[~covisible_mask, :] = -5 # xd
282
+
283
+ w_kpts0_depth = F.grid_sample(
284
+ depth1[:, None], w_kpts0[:, :, None], mode="bilinear"
285
+ )[:, 0, :, 0]
286
+ consistent_mask = (
287
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
288
+ ).abs() < 0.05
289
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
290
+
291
+ return valid_mask, w_kpts0
292
+
293
+
294
+ imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
295
+ imagenet_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
296
+
297
+
298
+ def numpy_to_pil(x: np.ndarray):
299
+ """
300
+ Args:
301
+ x: Assumed to be of shape (h,w,c)
302
+ """
303
+ if isinstance(x, torch.Tensor):
304
+ x = x.detach().cpu().numpy()
305
+ if x.max() <= 1.01:
306
+ x *= 255
307
+ x = x.astype(np.uint8)
308
+ return Image.fromarray(x)
309
+
310
+
311
+ def tensor_to_pil(x, unnormalize=False):
312
+ if unnormalize:
313
+ x = x * imagenet_std[:, None, None] + imagenet_mean[:, None, None]
314
+ x = x.detach().permute(1, 2, 0).cpu().numpy()
315
+ x = np.clip(x, 0.0, 1.0)
316
+ return numpy_to_pil(x)
317
+
318
+
319
+ def to_cuda(batch):
320
+ for key, value in batch.items():
321
+ if isinstance(value, torch.Tensor):
322
+ batch[key] = value.to(device)
323
+ return batch
324
+
325
+
326
+ def to_cpu(batch):
327
+ for key, value in batch.items():
328
+ if isinstance(value, torch.Tensor):
329
+ batch[key] = value.cpu()
330
+ return batch
331
+
332
+
333
+ def get_pose(calib):
334
+ w, h = np.array(calib["imsize"])[0]
335
+ return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
336
+
337
+
338
+ def compute_relative_pose(R1, t1, R2, t2):
339
+ rots = R2 @ (R1.T)
340
+ trans = -rots @ t1 + t2
341
+ return rots, trans