yansong1616 commited on
Commit
56cd6b7
·
verified ·
1 Parent(s): 633d2c0

Upload 90 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. dust3r/__init__.py +2 -0
  2. dust3r/__pycache__/__init__.cpython-310.pyc +0 -0
  3. dust3r/__pycache__/__init__.cpython-38.pyc +0 -0
  4. dust3r/__pycache__/__init__.cpython-39.pyc +0 -0
  5. dust3r/__pycache__/image_pairs.cpython-310.pyc +0 -0
  6. dust3r/__pycache__/image_pairs.cpython-38.pyc +0 -0
  7. dust3r/__pycache__/inference.cpython-310.pyc +0 -0
  8. dust3r/__pycache__/inference.cpython-38.pyc +0 -0
  9. dust3r/__pycache__/inference.cpython-39.pyc +0 -0
  10. dust3r/__pycache__/model.cpython-310.pyc +0 -0
  11. dust3r/__pycache__/model.cpython-38.pyc +0 -0
  12. dust3r/__pycache__/model.cpython-39.pyc +0 -0
  13. dust3r/__pycache__/optim_factory.cpython-310.pyc +0 -0
  14. dust3r/__pycache__/optim_factory.cpython-38.pyc +0 -0
  15. dust3r/__pycache__/patch_embed.cpython-310.pyc +0 -0
  16. dust3r/__pycache__/patch_embed.cpython-38.pyc +0 -0
  17. dust3r/__pycache__/post_process.cpython-310.pyc +0 -0
  18. dust3r/__pycache__/render_to_3d.cpython-310.pyc +0 -0
  19. dust3r/__pycache__/viz.cpython-310.pyc +0 -0
  20. dust3r/__pycache__/viz.cpython-38.pyc +0 -0
  21. dust3r/cloud_opt/__init__.py +29 -0
  22. dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc +0 -0
  23. dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc +0 -0
  24. dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc +0 -0
  25. dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc +0 -0
  26. dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc +0 -0
  27. dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc +0 -0
  28. dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc +0 -0
  29. dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc +0 -0
  30. dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc +0 -0
  31. dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc +0 -0
  32. dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc +0 -0
  33. dust3r/cloud_opt/base_opt.py +380 -0
  34. dust3r/cloud_opt/commons.py +91 -0
  35. dust3r/cloud_opt/init_im_poses.py +316 -0
  36. dust3r/cloud_opt/optimizer.py +249 -0
  37. dust3r/cloud_opt/pair_viewer.py +125 -0
  38. dust3r/datasets/__init__.py +42 -0
  39. dust3r/datasets/base/__init__.py +2 -0
  40. dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
  41. dust3r/datasets/base/batched_sampler.py +74 -0
  42. dust3r/datasets/base/easy_dataset.py +157 -0
  43. dust3r/datasets/co3d.py +146 -0
  44. dust3r/datasets/utils/__init__.py +2 -0
  45. dust3r/datasets/utils/cropping.py +119 -0
  46. dust3r/datasets/utils/transforms.py +11 -0
  47. dust3r/heads/__init__.py +19 -0
  48. dust3r/heads/__pycache__/__init__.cpython-310.pyc +0 -0
  49. dust3r/heads/__pycache__/__init__.cpython-38.pyc +0 -0
  50. dust3r/heads/__pycache__/__init__.cpython-39.pyc +0 -0
dust3r/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes). View file
 
dust3r/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (143 Bytes). View file
 
dust3r/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (143 Bytes). View file
 
dust3r/__pycache__/image_pairs.cpython-310.pyc ADDED
Binary file (3.19 kB). View file
 
dust3r/__pycache__/image_pairs.cpython-38.pyc ADDED
Binary file (3.25 kB). View file
 
dust3r/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.2 kB). View file
 
dust3r/__pycache__/inference.cpython-38.pyc ADDED
Binary file (5.21 kB). View file
 
dust3r/__pycache__/inference.cpython-39.pyc ADDED
Binary file (5.2 kB). View file
 
dust3r/__pycache__/model.cpython-310.pyc ADDED
Binary file (5.99 kB). View file
 
dust3r/__pycache__/model.cpython-38.pyc ADDED
Binary file (5.96 kB). View file
 
dust3r/__pycache__/model.cpython-39.pyc ADDED
Binary file (5.97 kB). View file
 
dust3r/__pycache__/optim_factory.cpython-310.pyc ADDED
Binary file (371 Bytes). View file
 
dust3r/__pycache__/optim_factory.cpython-38.pyc ADDED
Binary file (367 Bytes). View file
 
dust3r/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (2.74 kB). View file
 
dust3r/__pycache__/patch_embed.cpython-38.pyc ADDED
Binary file (2.76 kB). View file
 
dust3r/__pycache__/post_process.cpython-310.pyc ADDED
Binary file (1.65 kB). View file
 
dust3r/__pycache__/render_to_3d.cpython-310.pyc ADDED
Binary file (2.91 kB). View file
 
dust3r/__pycache__/viz.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
dust3r/__pycache__/viz.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
dust3r/cloud_opt/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # global alignment optimization wrapper function
6
+ # --------------------------------------------------------
7
+ from enum import Enum
8
+
9
+ from .optimizer import PointCloudOptimizer
10
+ from .pair_viewer import PairViewer
11
+
12
+
13
+ class GlobalAlignerMode(Enum):
14
+ PointCloudOptimizer = "PointCloudOptimizer"
15
+ PairViewer = "PairViewer"
16
+
17
+
18
+ def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
19
+ # extract all inputs
20
+ view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
21
+ # build the optimizer
22
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
23
+ net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
24
+ elif mode == GlobalAlignerMode.PairViewer:
25
+ net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
26
+ else:
27
+ raise NotImplementedError(f'Unknown mode {mode}')
28
+
29
+ return net
dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.04 kB). View file
 
dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.03 kB). View file
 
dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc ADDED
Binary file (15.8 kB). View file
 
dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc ADDED
Binary file (3.36 kB). View file
 
dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc ADDED
Binary file (3.41 kB). View file
 
dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc ADDED
Binary file (8.42 kB). View file
 
dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc ADDED
Binary file (8.45 kB). View file
 
dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc ADDED
Binary file (11.4 kB). View file
 
dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc ADDED
Binary file (4.89 kB). View file
 
dust3r/cloud_opt/base_opt.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Base class for the global alignement procedure
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import roma
13
+ from copy import deepcopy
14
+ import tqdm
15
+
16
+ from dust3r.utils.geometry import inv, geotrf
17
+ from dust3r.utils.device import to_numpy
18
+ from dust3r.utils.image import rgb
19
+ from dust3r.viz import SceneViz, segment_sky, auto_cam_size
20
+ from dust3r.optim_factory import adjust_learning_rate_by_lr
21
+
22
+ from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
23
+ cosine_schedule, linear_schedule, get_conf_trf)
24
+ import dust3r.cloud_opt.init_im_poses as init_fun
25
+
26
+
27
+ class BasePCOptimizer (nn.Module):
28
+ """ Optimize a global scene, given a list of pairwise observations.
29
+ Graph node: images
30
+ Graph edges: observations = (pred1, pred2)
31
+ """
32
+
33
+ def __init__(self, *args, **kwargs):
34
+ if len(args) == 1 and len(kwargs) == 0:
35
+ other = deepcopy(args[0])
36
+ attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
37
+ min_conf_thr conf_thr conf_i conf_j im_conf
38
+ base_scale norm_pw_scale POSE_DIM pw_poses
39
+ pw_adaptors pw_adaptors has_im_poses rand_pose imgs'''.split()
40
+ self.__dict__.update({k: other[k] for k in attrs})
41
+ else:
42
+ self._init_from_views(*args, **kwargs)
43
+
44
+ def _init_from_views(self, view1, view2, pred1, pred2,
45
+ dist='l1',
46
+ conf='log',
47
+ min_conf_thr=3,
48
+ base_scale=0.5,
49
+ allow_pw_adaptors=False,
50
+ pw_break=20,
51
+ rand_pose=torch.randn,
52
+ iterationsCount=None,
53
+ ):
54
+ super().__init__()
55
+ if not isinstance(view1['idx'], list):
56
+ view1['idx'] = view1['idx'].tolist()
57
+ if not isinstance(view2['idx'], list):
58
+ view2['idx'] = view2['idx'].tolist()
59
+ self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
60
+ self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
61
+ self.dist = ALL_DISTS[dist]
62
+
63
+
64
+ self.n_imgs = self._check_edges()
65
+
66
+ # input data
67
+ pred1_pts = pred1['pts3d']
68
+ pred2_pts = pred2['pts3d_in_other_view']
69
+ self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
70
+ self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
71
+ self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
72
+
73
+ # work in log-scale with conf
74
+ pred1_conf = pred1['conf']
75
+ pred2_conf = pred2['conf']
76
+ self.min_conf_thr = min_conf_thr
77
+ self.conf_trf = get_conf_trf(conf)
78
+
79
+ self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
80
+ self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
81
+ self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
82
+
83
+ # pairwise pose parameters
84
+ self.base_scale = base_scale
85
+ self.norm_pw_scale = True
86
+ self.pw_break = pw_break
87
+ self.POSE_DIM = 7
88
+ self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses
89
+ self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation
90
+ self.pw_adaptors.requires_grad_(allow_pw_adaptors)
91
+ self.has_im_poses = False
92
+ self.rand_pose = rand_pose
93
+
94
+ # possibly store images for show_pointcloud
95
+ self.imgs = None
96
+ if 'img' in view1 and 'img' in view2:
97
+ imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
98
+ for v in range(len(self.edges)):
99
+ idx = view1['idx'][v]
100
+ imgs[idx] = view1['img'][v]
101
+ idx = view2['idx'][v]
102
+ imgs[idx] = view2['img'][v]
103
+ self.imgs = rgb(imgs)
104
+
105
+ @property
106
+ def n_edges(self):
107
+ return len(self.edges)
108
+
109
+ @property
110
+ def str_edges(self):
111
+ return [edge_str(i, j) for i, j in self.edges]
112
+
113
+ @property
114
+ def imsizes(self):
115
+ return [(w, h) for h, w in self.imshapes]
116
+
117
+ @property
118
+ def device(self):
119
+ return next(iter(self.parameters())).device
120
+
121
+ def state_dict(self, trainable=True):
122
+ all_params = super().state_dict()
123
+ return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
124
+
125
+ def load_state_dict(self, data):
126
+ return super().load_state_dict(self.state_dict(trainable=False) | data)
127
+
128
+ def _check_edges(self):
129
+ indices = sorted({i for edge in self.edges for i in edge})
130
+ assert indices == list(range(len(indices))), 'bad pair indices: missing values '
131
+ return len(indices)
132
+
133
+ @torch.no_grad()
134
+ def _compute_img_conf(self, pred1_conf, pred2_conf):
135
+ im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
136
+ for e, (i, j) in enumerate(self.edges):
137
+ im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
138
+ im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
139
+ return im_conf
140
+
141
+ def get_adaptors(self): # 公式(5)中的σ_e
142
+ adapt = self.pw_adaptors
143
+ adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z)
144
+ if self.norm_pw_scale: # normalize so that the product == 1
145
+ adapt = adapt - adapt.mean(dim=1, keepdim=True) # 归一化
146
+ return (adapt / self.pw_break).exp() # TODO gys:公式(5)中的σ_e是什么?
147
+
148
+ def _get_poses(self, poses): # self.im_poses 或者 self.pw_poses
149
+ # normalize rotation
150
+ Q = poses[:, :4]
151
+ T = signed_expm1(poses[:, 4:7])
152
+ RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
153
+ return RT
154
+
155
+ def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
156
+ # all poses == cam-to-world
157
+ pose = poses[idx]
158
+ if not (pose.requires_grad or force):
159
+ return pose
160
+
161
+ if R.shape == (4, 4):
162
+ assert T is None
163
+ T = R[:3, 3]
164
+ R = R[:3, :3]
165
+
166
+ if R is not None:
167
+ pose.data[0:4] = roma.rotmat_to_unitquat(R)
168
+ if T is not None:
169
+ pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale
170
+
171
+ if scale is not None:
172
+ assert poses.shape[-1] in (8, 13)
173
+ pose.data[-1] = np.log(float(scale))
174
+ return pose
175
+
176
+ def get_pw_norm_scale_factor(self):
177
+ if self.norm_pw_scale:
178
+ # normalize scales so that things cannot go south
179
+ # we want that exp(scale) ~= self.base_scale
180
+ return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
181
+ else:
182
+ return 1 # don't norm scale for known poses
183
+
184
+ def get_pw_scale(self):
185
+ scale = self.pw_poses[:, -1].exp() # (n_edges,)
186
+ scale = scale * self.get_pw_norm_scale_factor()
187
+ return scale
188
+
189
+ def get_pw_poses(self): # cam to world
190
+ RT = self._get_poses(self.pw_poses)
191
+ scaled_RT = RT.clone()
192
+ scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation
193
+ return scaled_RT
194
+
195
+ def get_masks(self):
196
+ return [(conf > self.min_conf_thr) for conf in self.im_conf]
197
+
198
+ def depth_to_pts3d(self):
199
+ raise NotImplementedError()
200
+
201
+ def get_pts3d(self, raw=False):
202
+ res = self.depth_to_pts3d()
203
+ if not raw:
204
+ res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
205
+ return res
206
+
207
+ def _set_focal(self, idx, focal, force=False):
208
+ raise NotImplementedError()
209
+
210
+ def get_focals(self):
211
+ raise NotImplementedError()
212
+
213
+ def get_known_focal_mask(self):
214
+ raise NotImplementedError()
215
+
216
+ def get_principal_points(self):
217
+ raise NotImplementedError()
218
+
219
+ def get_conf(self, mode=None):
220
+ trf = self.conf_trf if mode is None else get_conf_trf(mode)
221
+ return [trf(c) for c in self.im_conf]
222
+
223
+ def get_im_poses(self):
224
+ raise NotImplementedError()
225
+
226
+ def _set_depthmap(self, idx, depth, force=False):
227
+ raise NotImplementedError()
228
+
229
+ def get_depthmaps(self, raw=False):
230
+ raise NotImplementedError()
231
+
232
+ @torch.no_grad()
233
+ def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
234
+ """ Method:
235
+ 1) express all 3d points in each camera coordinate frame
236
+ 2) if they're in front of a depthmap --> then lower their confidence
237
+ """
238
+ assert 0 <= tol < 1
239
+ cams = inv(self.get_im_poses())
240
+ K = self.get_intrinsics()
241
+ depthmaps = self.get_depthmaps()
242
+ res = deepcopy(self)
243
+
244
+ for i, pts3d in enumerate(self.depth_to_pts3d()):
245
+ for j in range(self.n_imgs):
246
+ if i == j:
247
+ continue
248
+
249
+ # project 3dpts in other view
250
+ Hi, Wi = self.imshapes[i]
251
+ Hj, Wj = self.imshapes[j]
252
+ proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
253
+ proj_depth = proj[:, :, 2]
254
+ u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
255
+
256
+ # check which points are actually in the visible cone
257
+ msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
258
+ msk_j = v[msk_i], u[msk_i]
259
+
260
+ # find bad points = those in front but less confident
261
+ bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
262
+ ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
263
+
264
+ bad_msk_i = msk_i.clone()
265
+ bad_msk_i[msk_i] = bad_points
266
+ res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
267
+
268
+ return res
269
+
270
+ def forward(self, ret_details=False):
271
+ pw_poses = self.get_pw_poses() # cam-to-world
272
+ pw_adapt = self.get_adaptors()
273
+ proj_pts3d = self.get_pts3d()
274
+ # pre-compute pixel weights
275
+ weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
276
+ weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
277
+
278
+ loss = 0
279
+ if ret_details:
280
+ details = -torch.ones((self.n_imgs, self.n_imgs))
281
+
282
+ for e, (i, j) in enumerate(self.edges):
283
+ i_j = edge_str(i, j)
284
+ # distance in image i and j
285
+ aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
286
+ aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
287
+ li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
288
+ lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
289
+ loss = loss + li + lj
290
+
291
+ if ret_details:
292
+ details[i, j] = li + lj
293
+ loss /= self.n_edges # average over all pairs
294
+
295
+ if ret_details:
296
+ return loss, details
297
+ return loss
298
+
299
+ def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
300
+ if init is None:
301
+ pass
302
+ elif init == 'msp' or init == 'mst':
303
+ # ==============3.3.Downstream Applications:主要是为3.4. Global Alignment中的公式(5)初始化内外参矩阵和待估计的世界坐标系的坐标============
304
+ init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
305
+ elif init == 'known_poses':
306
+ init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
307
+ else:
308
+ raise ValueError(f'bad value for {init=}')
309
+
310
+ global_alignment_loop(self, **kw) # 3.4. Global Alignment:梯度下降公式(5)
311
+
312
+ @torch.no_grad()
313
+ def mask_sky(self):
314
+ res = deepcopy(self)
315
+ for i in range(self.n_imgs):
316
+ sky = segment_sky(self.imgs[i])
317
+ res.im_conf[i][sky] = 0
318
+ return res
319
+
320
+ def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
321
+ viz = SceneViz()
322
+ if self.imgs is None:
323
+ colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
324
+ colors = list(map(tuple, colors.tolist()))
325
+ for n in range(self.n_imgs):
326
+ viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
327
+ else:
328
+ viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
329
+ colors = np.random.randint(256, size=(self.n_imgs, 3))
330
+
331
+ # camera poses
332
+ im_poses = to_numpy(self.get_im_poses())
333
+ if cam_size is None:
334
+ cam_size = auto_cam_size(im_poses)
335
+ viz.add_cameras(im_poses, self.get_focals(), colors=colors,
336
+ images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
337
+ if show_pw_cams:
338
+ pw_poses = self.get_pw_poses()
339
+ viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
340
+
341
+ if show_pw_pts3d:
342
+ pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
343
+ viz.add_pointcloud(pts, (128, 0, 128))
344
+
345
+ viz.show(**kw)
346
+ return viz
347
+
348
+
349
+ def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
350
+ params = [p for p in net.parameters() if p.requires_grad]
351
+ if not params:
352
+ return net
353
+
354
+ if verbose:
355
+ print([name for name, value in net.named_parameters() if value.requires_grad])
356
+
357
+ lr_base = lr
358
+ optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
359
+
360
+ with tqdm.tqdm(total=niter) as bar:
361
+ while bar.n < bar.total:
362
+ t = bar.n / bar.total
363
+
364
+ if schedule == 'cosine':
365
+ lr = cosine_schedule(t, lr_base, lr_min)
366
+ elif schedule == 'linear':
367
+ lr = linear_schedule(t, lr_base, lr_min)
368
+ else:
369
+ raise ValueError(f'bad lr {schedule=}')
370
+ adjust_learning_rate_by_lr(optimizer, lr)
371
+
372
+ optimizer.zero_grad()
373
+ loss = net() # 论文中:Global optimization
374
+ loss.backward()
375
+ optimizer.step()
376
+ loss = float(loss)
377
+ bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
378
+ if bar.n % 30 == 0:
379
+ print(' ')
380
+ bar.update()
dust3r/cloud_opt/commons.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for global alignment
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+
11
+
12
+ def edge_str(i, j):
13
+ return f'{i}_{j}'
14
+
15
+
16
+ def i_j_ij(ij):
17
+ return edge_str(*ij), ij
18
+
19
+
20
+ def edge_conf(conf_i, conf_j, edge):
21
+ return float(conf_i[edge].mean() * conf_j[edge].mean())
22
+ # edge对应的两张图片经dust3r输出的置信度,分别对两张图片所有像素点的置信度取平均值再相乘,作为当前edge的置信度
23
+
24
+
25
+ def compute_edge_scores(edges, conf_i, conf_j):# edge对应的两张图片经dust3r会输出两个置信度矩阵,分别对两张图片所有像素点的置信度取平均值再相乘,作为当前edge的置信度
26
+ return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
27
+
28
+
29
+ def NoGradParamDict(x):
30
+ assert isinstance(x, dict)
31
+ return nn.ParameterDict(x).requires_grad_(False)
32
+
33
+
34
+ def get_imshapes(edges, pred_i, pred_j):
35
+ n_imgs = max(max(e) for e in edges) + 1
36
+ imshapes = [None] * n_imgs
37
+ for e, (i, j) in enumerate(edges):
38
+ shape_i = tuple(pred_i[e].shape[0:2])
39
+ shape_j = tuple(pred_j[e].shape[0:2])
40
+ if imshapes[i]:
41
+ assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
42
+ if imshapes[j]:
43
+ assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
44
+ imshapes[i] = shape_i
45
+ imshapes[j] = shape_j
46
+ return imshapes
47
+
48
+
49
+ def get_conf_trf(mode):
50
+ if mode == 'log':
51
+ def conf_trf(x): return x.log()
52
+ elif mode == 'sqrt':
53
+ def conf_trf(x): return x.sqrt()
54
+ elif mode == 'm1':
55
+ def conf_trf(x): return x-1
56
+ elif mode in ('id', 'none'):
57
+ def conf_trf(x): return x
58
+ else:
59
+ raise ValueError(f'bad mode for {mode=}')
60
+ return conf_trf
61
+
62
+
63
+ def l2_dist(a, b, weight):
64
+ return ((a - b).square().sum(dim=-1) * weight)
65
+
66
+
67
+ def l1_dist(a, b, weight):
68
+ return ((a - b).norm(dim=-1) * weight) # torch.norm()是求范式的损失,默认是第二范式
69
+
70
+
71
+ ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
72
+
73
+
74
+ def signed_log1p(x):
75
+ sign = torch.sign(x)
76
+ return sign * torch.log1p(torch.abs(x))
77
+
78
+
79
+ def signed_expm1(x):
80
+ sign = torch.sign(x)
81
+ return sign * torch.expm1(torch.abs(x))
82
+
83
+
84
+ def cosine_schedule(t, lr_start, lr_end):
85
+ assert 0 <= t <= 1
86
+ return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
87
+
88
+
89
+ def linear_schedule(t, lr_start, lr_end):
90
+ assert 0 <= t <= 1
91
+ return lr_start + (lr_end - lr_start) * t
dust3r/cloud_opt/init_im_poses.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Initialization functions for global alignment
6
+ # --------------------------------------------------------
7
+ from functools import cache
8
+
9
+ import numpy as np
10
+ import scipy.sparse as sp
11
+ import torch
12
+ import cv2
13
+ import roma
14
+ from tqdm import tqdm
15
+
16
+ from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
17
+ from dust3r.post_process import estimate_focal_knowing_depth
18
+ from dust3r.viz import to_numpy
19
+
20
+ from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
21
+
22
+
23
+ @torch.no_grad()
24
+ def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
25
+ device = self.device
26
+
27
+ # indices of known poses
28
+ nkp, known_poses_msk, known_poses = get_known_poses(self)
29
+ assert nkp == self.n_imgs, 'not all poses are known'
30
+
31
+ # get all focals
32
+ nkf, _, im_focals = get_known_focals(self)
33
+ assert nkf == self.n_imgs
34
+ im_pp = self.get_principal_points()
35
+
36
+ best_depthmaps = {}
37
+ # init all pairwise poses
38
+ for e, (i, j) in enumerate(tqdm(self.edges)):
39
+ i_j = edge_str(i, j)
40
+
41
+ # find relative pose for this pair
42
+ P1 = torch.eye(4, device=device)
43
+ msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
44
+ _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
45
+ pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
46
+
47
+ # align the two predicted camera with the two gt cameras
48
+ s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
49
+ # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
50
+ # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
51
+ self._set_pose(self.pw_poses, e, R, T, scale=s)
52
+
53
+ # remember if this is a good depthmap
54
+ score = float(self.conf_i[i_j].mean())
55
+ if score > best_depthmaps.get(i, (0,))[0]:
56
+ best_depthmaps[i] = score, i_j, s
57
+
58
+ # init all image poses
59
+ for n in range(self.n_imgs):
60
+ assert known_poses_msk[n]
61
+ _, i_j, scale = best_depthmaps[n]
62
+ depth = self.pred_i[i_j][:, :, 2]
63
+ self._set_depthmap(n, depth * scale)
64
+
65
+
66
+ @torch.no_grad()
67
+ def init_minimum_spanning_tree(self, **kw):
68
+ """ Init all camera poses (image-wise and pairwise poses) given
69
+ an initial set of pairwise estimations.
70
+ """
71
+ device = self.device
72
+ pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
73
+ self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
74
+ device, has_im_poses=self.has_im_poses, **kw)
75
+
76
+ return init_from_pts3d(self, pts3d, im_focals, im_poses) # 初始化
77
+
78
+
79
+ def init_from_pts3d(self, pts3d, im_focals, im_poses):
80
+ # init poses
81
+ nkp, known_poses_msk, known_poses = get_known_poses(self)
82
+ if nkp == 1: # 0
83
+ raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
84
+ elif nkp > 1:
85
+ # global rigid SE3 alignment
86
+ s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
87
+ trf = sRT_to_4x4(s, R, T, device=known_poses.device)
88
+
89
+ # rotate everything
90
+ im_poses = trf @ im_poses
91
+ im_poses[:, :3, :3] /= s # undo scaling on the rotation part
92
+ for img_pts3d in pts3d:
93
+ img_pts3d[:] = geotrf(trf, img_pts3d)
94
+
95
+ # pw_poses:遍历所有的edge,计算每个edge对应的(即输入dust3r的第一张图片的)相机坐标系转成“世界坐标系”的转换矩阵即P_e
96
+ for e, (i, j) in enumerate(self.edges):
97
+ i_j = edge_str(i, j)
98
+ # compute transform that goes from cam to world
99
+ # pred_i:dust3r输出的第一张图片对应的3D点云
100
+ s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]) # 估计每个edge对应的相机坐标系转成世界坐标系的外参矩阵
101
+ self._set_pose(self.pw_poses, e, R, T, scale=s) # pw_poses *****************
102
+
103
+ # TODO gys:s_factor是什么? take into account the scale normalization
104
+ s_factor = self.get_pw_norm_scale_factor()
105
+ im_poses[:, :3, 3] *= s_factor # apply downscaling factorS
106
+ for img_pts3d in pts3d:
107
+ img_pts3d *= s_factor
108
+
109
+ # init all image poses
110
+ if self.has_im_poses:
111
+ for i in range(self.n_imgs):
112
+ cam2world = im_poses[i]
113
+ depth = geotrf(inv(cam2world), pts3d[i])[..., 2] # 将世界坐标系的点pts3d[i]转成相机坐标系
114
+ self._set_depthmap(i, depth)
115
+ self._set_pose(self.im_poses, i, cam2world) # im_poses ********************
116
+ if im_focals[i] is not None:
117
+ self._set_focal(i, im_focals[i])
118
+
119
+ print(' init loss =', float(self()))
120
+
121
+
122
+ def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
123
+ device, has_im_poses=True, niter_PnP=10):
124
+ n_imgs = len(imshapes)
125
+ sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j)) # 计算置信度,返回一个矩阵,表示两两图片表示的edge的置信度
126
+ msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo() # 将上面的矩阵转换成最小生成树,因为sparse_graph加了负号,所以这里筛选出来的其实是最大的置信度
127
+ # 上面找最小生成树的目的是:为每个图片尽量选一个置信度最大的edge,因为每两两图片之间都存在一个edge
128
+ # temp variable to store 3d points
129
+ pts3d = [None] * len(imshapes) # 长度为5的空list(输入图片的数量是5)
130
+
131
+ todo = sorted(zip(-msp.data, msp.row, msp.col)) # 根据最小生成树选出:平均置信度最大的4个edge(输入图片的数量是5),这4个edge一定包含5张输入图像 ,因为是生成树 # sorted edges
132
+ im_poses = [None] * n_imgs
133
+ im_focals = [None] * n_imgs
134
+
135
+ # init with strongest edge
136
+ score, i, j = todo.pop() # 这里的socre是compute_edge_scores函数计算出的置信度
137
+ print(f' init edge ({i}*,{j}*) {score=}')
138
+ i_j = edge_str(i, j)
139
+ pts3d[i] = pred_i[i_j].clone() # 置信度最大的edge对应的两张图片的三维点云(对与所有图片,每两张图片经dust3r都会输出两个三维点云)
140
+ pts3d[j] = pred_j[i_j].clone()
141
+ done = {i, j}
142
+ if has_im_poses: #============选择置信度最高edge中的第一张图片的相机坐标系为世界坐标系==============
143
+ im_poses[i] = torch.eye(4, device=device) # 4*4的单位矩阵,因为该图片的相机坐标系就是世界坐标系,所以外参矩阵为单位矩阵
144
+ im_focals[i] = estimate_focal(pred_i[i_j]) # 3.3 估计内参矩阵
145
+
146
+ # set initial pointcloud based on pairwise graph
147
+ msp_edges = [(i, j)]
148
+ while todo:
149
+ # each time, predict the next one
150
+ score, i, j = todo.pop() # pop把list最后一个元素弹出
151
+
152
+ if im_focals[i] is None: # 图片i对应的相机内参已经计算过了
153
+ im_focals[i] = estimate_focal(pred_i[i_j])
154
+
155
+ if i in done:
156
+ print(f' init edge ({i},{j}*) {score=}')
157
+ assert j not in done
158
+ # align pred[i] with pts3d[i], and then set j accordingly
159
+ i_j = edge_str(i, j)
160
+ s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j]) # 3.3 外参估计,s是sigma;直接调用roma工具包实现的
161
+ trf = sRT_to_4x4(s, R, T, device) # 存放到4*4的矩阵中,第四行是[0,0,0,1],对应齐次坐标的转换
162
+ pts3d[j] = geotrf(trf, pred_j[i_j]) # pred_j[i_j]表示dust3r的输出:图片j在i的相机坐标系下的三维点云
163
+ done.add(j)
164
+ msp_edges.append((i, j))
165
+
166
+ if has_im_poses and im_poses[i] is None:
167
+ im_poses[i] = sRT_to_4x4(1, R, T, device)
168
+
169
+ elif j in done:
170
+ print(f' init edge ({i}*,{j}) {score=}')
171
+ assert i not in done
172
+ i_j = edge_str(i, j)
173
+ s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j]) # 从pred_j[i_j]转换到 pts3d[j]的外参矩阵
174
+ trf = sRT_to_4x4(s, R, T, device)
175
+ pts3d[i] = geotrf(trf, pred_i[i_j]) # 应用估计出的外参矩阵将相机坐标系的点转成世界坐标系
176
+ done.add(i)
177
+ msp_edges.append((i, j))
178
+
179
+ if has_im_poses and im_poses[i] is None:
180
+ im_poses[i] = sRT_to_4x4(1, R, T, device)
181
+ else:
182
+ # let's try again later
183
+ todo.insert(0, (score, i, j))
184
+
185
+ if has_im_poses:
186
+ # complete all missing informations
187
+ pair_scores = list(sparse_graph.values()) # already negative scores: less is best
188
+ edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
189
+ for i, j in edges_from_best_to_worse.tolist():
190
+ if im_focals[i] is None:
191
+ im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
192
+
193
+ for i in range(n_imgs):
194
+ if im_poses[i] is None:
195
+ msk = im_conf[i] > min_conf_thr # 使用PnP算法估计外参矩阵
196
+ res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
197
+ if res:
198
+ im_focals[i], im_poses[i] = res
199
+ if im_poses[i] is None:
200
+ im_poses[i] = torch.eye(4, device=device)
201
+ im_poses = torch.stack(im_poses)
202
+ else:
203
+ im_poses = im_focals = None
204
+
205
+ return pts3d, msp_edges, im_focals, im_poses # pts3d表示:每个输入的图片在自己的相机坐标系下的三维点经im_poses转换成世界坐标系的点
206
+
207
+
208
+ def dict_to_sparse_graph(dic):
209
+ n_imgs = max(max(e) for e in dic) + 1 # 取出照片数量
210
+ for e in dic:
211
+ a1 = max(e)
212
+ a2 = 2
213
+ res = sp.dok_array((n_imgs, n_imgs))
214
+ for edge, value in dic.items():
215
+ res[edge] = value
216
+ return res # 将edge中存放的置信度转移到一个n_imgs * n_imgs大小的列表中
217
+
218
+
219
+ def rigid_points_registration(pts1, pts2, conf):
220
+ R, T, s = roma.rigid_points_registration( # 调用roma的工具类函数
221
+ pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
222
+ return s, R, T # return un-scaled (R, T)
223
+
224
+
225
+ def sRT_to_4x4(scale, R, T, device):
226
+ trf = torch.eye(4, device=device) # 单位矩阵
227
+ trf[:3, :3] = R * scale
228
+ trf[:3, 3] = T.ravel() # doesn't need scaling
229
+ return trf # 外参矩阵 3*4
230
+
231
+
232
+ def estimate_focal(pts3d_i, pp=None):
233
+ if pp is None:
234
+ H, W, THREE = pts3d_i.shape
235
+ assert THREE == 3
236
+ pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
237
+ focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(
238
+ 0), focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
239
+ return float(focal)
240
+
241
+
242
+ @cache
243
+ def pixel_grid(H, W):
244
+ return np.mgrid[:W, :H].T.astype(np.float32)
245
+
246
+
247
+ def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
248
+ # extract camera poses and focals with RANSAC-PnP
249
+ if msk.sum() < 4:
250
+ return None # we need at least 4 points for PnP
251
+ pts3d, msk = map(to_numpy, (pts3d, msk))
252
+
253
+ H, W, THREE = pts3d.shape
254
+ assert THREE == 3
255
+ pixels = pixel_grid(H, W)
256
+
257
+ if focal is None:
258
+ S = max(W, H)
259
+ tentative_focals = np.geomspace(S/2, S*3, 21)
260
+ else:
261
+ tentative_focals = [focal]
262
+
263
+ if pp is None:
264
+ pp = (W/2, H/2)
265
+ else:
266
+ pp = to_numpy(pp)
267
+
268
+ best = 0,
269
+ for focal in tentative_focals:
270
+ K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
271
+
272
+ success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
273
+ iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
274
+ if not success:
275
+ continue
276
+
277
+ score = len(inliers)
278
+ if success and score > best[0]:
279
+ best = score, R, T, focal
280
+
281
+ if not best[0]:
282
+ return None
283
+
284
+ _, R, T, best_focal = best
285
+ R = cv2.Rodrigues(R)[0] # world to cam
286
+ R, T = map(torch.from_numpy, (R, T))
287
+ return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world
288
+
289
+
290
+ def get_known_poses(self):
291
+ if self.has_im_poses:
292
+ known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
293
+ known_poses = self.get_im_poses()
294
+ return known_poses_msk.sum(), known_poses_msk, known_poses
295
+ else:
296
+ return 0, None, None
297
+
298
+
299
+ def get_known_focals(self):
300
+ if self.has_im_poses:
301
+ known_focal_msk = self.get_known_focal_mask()
302
+ known_focals = self.get_focals()
303
+ return known_focal_msk.sum(), known_focal_msk, known_focals
304
+ else:
305
+ return 0, None, None
306
+
307
+
308
+ def align_multiple_poses(src_poses, target_poses):
309
+ N = len(src_poses)
310
+ assert src_poses.shape == target_poses.shape == (N, 4, 4)
311
+
312
+ def center_and_z(poses):
313
+ eps = get_med_dist_between_poses(poses) / 100
314
+ return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
315
+ R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
316
+ return s, R, T
dust3r/cloud_opt/optimizer.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Main class for the implementation of the global alignment
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from dust3r.cloud_opt.base_opt import BasePCOptimizer
12
+ from dust3r.utils.geometry import xy_grid, geotrf
13
+ from dust3r.utils.device import to_cpu, to_numpy
14
+
15
+
16
+ class PointCloudOptimizer(BasePCOptimizer):
17
+ """ Optimize a global scene, given a list of pairwise observations.
18
+ Graph node: images
19
+ Graph edges: observations = (pred1, pred2)
20
+ """
21
+
22
+ def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+
25
+ self.has_im_poses = True # by definition of this class
26
+ self.focal_break = focal_break
27
+
28
+ # adding thing to optimize
29
+ self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
30
+ self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
31
+ self.im_focals = nn.ParameterList(torch.FloatTensor(
32
+ [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics
33
+ self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
34
+ self.im_pp.requires_grad_(optimize_pp)
35
+
36
+ self.imshape = self.imshapes[0]
37
+ im_areas = [h*w for h, w in self.imshapes]
38
+
39
+
40
+ self.max_area = max(im_areas)
41
+
42
+
43
+ # adding thing to optimize
44
+ self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
45
+ self.im_poses = ParameterStack(self.im_poses, is_param=True)
46
+ self.im_focals = ParameterStack(self.im_focals, is_param=True)
47
+ self.im_pp = ParameterStack(self.im_pp, is_param=True)
48
+ self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
49
+ self.register_buffer('_grid', ParameterStack(
50
+ [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
51
+
52
+ # pre-compute pixel weights
53
+ self.register_buffer('_weight_i', ParameterStack(
54
+ [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
55
+ self.register_buffer('_weight_j', ParameterStack(
56
+ [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
57
+
58
+ # precompute
59
+ self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
60
+ self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
61
+ self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
62
+ self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
63
+ self.total_area_i = sum([im_areas[i] for i, j in self.edges])
64
+ self.total_area_j = sum([im_areas[j] for i, j in self.edges])
65
+
66
+
67
+ def _check_all_imgs_are_selected(self, msk):
68
+ assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
69
+
70
+ def preset_pose(self, known_poses, pose_msk=None): # cam-to-world
71
+ self._check_all_imgs_are_selected(pose_msk)
72
+
73
+ if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
74
+ known_poses = [known_poses]
75
+ for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
76
+ print(f' (setting pose #{idx} = {pose[:3,3]})')
77
+ self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
78
+
79
+ # normalize scale if there's less than 1 known pose
80
+ n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
81
+ self.norm_pw_scale = (n_known_poses <= 1)
82
+
83
+ self.im_poses.requires_grad_(False)
84
+ self.norm_pw_scale = False
85
+
86
+ def preset_focal(self, known_focals, msk=None):
87
+ self._check_all_imgs_are_selected(msk)
88
+
89
+ for idx, focal in zip(self._get_msk_indices(msk), known_focals):
90
+ print(f' (setting focal #{idx} = {focal})')
91
+ self._no_grad(self._set_focal(idx, focal))
92
+
93
+ self.im_focals.requires_grad_(False)
94
+
95
+ def preset_principal_point(self, known_pp, msk=None):
96
+ self._check_all_imgs_are_selected(msk)
97
+
98
+ for idx, pp in zip(self._get_msk_indices(msk), known_pp):
99
+ print(f' (setting principal point #{idx} = {pp})')
100
+ self._no_grad(self._set_principal_point(idx, pp))
101
+
102
+ self.im_pp.requires_grad_(False)
103
+
104
+ def _get_msk_indices(self, msk):
105
+ if msk is None:
106
+ return range(self.n_imgs)
107
+ elif isinstance(msk, int):
108
+ return [msk]
109
+ elif isinstance(msk, (tuple, list)):
110
+ return self._get_msk_indices(np.array(msk))
111
+ elif msk.dtype in (bool, torch.bool, np.bool_):
112
+ assert len(msk) == self.n_imgs
113
+ return np.cumsum([0] + msk.tolist())
114
+ elif np.issubdtype(msk.dtype, np.integer):
115
+ return msk
116
+ else:
117
+ raise ValueError(f'bad {msk=}')
118
+
119
+ def _no_grad(self, tensor):
120
+ assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
121
+
122
+ def _set_focal(self, idx, focal, force=False):
123
+ param = self.im_focals[idx]
124
+ if param.requires_grad or force: # can only init a parameter not already initialized
125
+ param.data[:] = self.focal_break * np.log(focal)
126
+ return param
127
+
128
+ def get_focals(self): # 论文中Recovering intrinsics章节:求内参矩阵(即焦距)
129
+ log_focals = torch.stack(list(self.im_focals), dim=0)
130
+ return (log_focals / self.focal_break).exp()
131
+
132
+ def get_known_focal_mask(self):
133
+ return torch.tensor([not (p.requires_grad) for p in self.im_focals])
134
+
135
+ def _set_principal_point(self, idx, pp, force=False):
136
+ param = self.im_pp[idx]
137
+ H, W = self.imshapes[idx]
138
+ if param.requires_grad or force: # can only init a parameter not already initialized
139
+ param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
140
+ return param
141
+
142
+ def get_principal_points(self):
143
+ return self._pp + 10 * self.im_pp # 将图像坐标系和像素坐标系的中心点偏移量
144
+
145
+ def get_intrinsics(self):
146
+ K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
147
+ focals = self.get_focals().flatten()
148
+ K[:, 0, 0] = K[:, 1, 1] = focals
149
+ K[:, :2, 2] = self.get_principal_points()
150
+ K[:, 2, 2] = 1
151
+ return K
152
+
153
+ def get_im_poses(self): # cam to world 外参数矩阵的逆
154
+ cam2world = self._get_poses(self.im_poses)
155
+ return cam2world
156
+
157
+ def _set_depthmap(self, idx, depth, force=False):
158
+ depth = _ravel_hw(depth, self.max_area)
159
+
160
+ param = self.im_depthmaps[idx]
161
+ if param.requires_grad or force: # can only init a parameter not already initialized
162
+ param.data[:] = depth.log().nan_to_num(neginf=0)
163
+ return param
164
+
165
+ def get_depthmaps(self, raw=False): #论文中公式(1)上面的的深度信息D
166
+ res = self.im_depthmaps.exp()
167
+ if not raw:
168
+ res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
169
+ return res
170
+
171
+ def depth_to_pts3d(self): # 这里根据深度信息D计算真实的世界坐标系下的点,即论文中公式(1)上面的公式
172
+ # Get depths and projection params if not provided
173
+ focals = self.get_focals() # 论文中Recovering intrinsics章节:求内参矩阵(即焦距)
174
+ pp = self.get_principal_points() # 图像坐标系和像素坐标系之间的偏移,即照片宽高的一半
175
+ im_poses = self.get_im_poses() # 外参数矩阵
176
+ depth = self.get_depthmaps(raw=True)#论文中公式(1)上面的深度信息D
177
+
178
+ # get pointmaps in camera frame self._grid:输入的所有图像(图像坐标系)
179
+ rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) # 将输入图像的坐标点转成相机坐标系下的点
180
+ # project to world frame
181
+ return geotrf(im_poses, rel_ptmaps) # 再由相机坐标系转成世界坐标系
182
+
183
+ def get_pts3d(self, raw=False): # 计算真实的世界坐标系下的三维点坐标,根据公式(1)上面的深度D计算公式计算
184
+ res = self.depth_to_pts3d()
185
+ if not raw:
186
+ res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
187
+ return res
188
+ # 这里的forward返回的就是公式(5)计算的损失值
189
+ def forward(self): # 论文中: Global optimization
190
+ pw_poses = self.get_pw_poses() # pw_poses cam-to-world 公式(5)的P_e: 外参矩阵的逆,由相机坐标系转成世界坐标系,requires_grad=True
191
+ pw_adapt = self.get_adaptors().unsqueeze(1) # 公式(5)中的比例系数 sigma,requires_grad=False
192
+ proj_pts3d = self.get_pts3d(raw=True) # im_poses 公式(5)的待优化的真实的世界坐标系下的三维点requires_grad=True
193
+
194
+ # rotate pairwise prediction according to pw_poses 根据公式(5)的外参矩阵部分转成世界坐标系requires_grad=True
195
+ aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) # _stacked_pred_i/j表示dest3r预测的三维点云, requires_grad=False
196
+ aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
197
+
198
+ # compute the loss: 转换成世界坐标系后的两张图像分别与待估计世界坐标系下的点(proj_pts3d)计算损失
199
+ li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
200
+ lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
201
+
202
+ return li + lj
203
+
204
+
205
+ def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
206
+ pp = pp.unsqueeze(1)
207
+ focal = focal.unsqueeze(1)
208
+ assert focal.shape == (len(depth), 1, 1)
209
+ assert pp.shape == (len(depth), 1, 2)
210
+ assert pixel_grid.shape == depth.shape + (2,)
211
+ depth = depth.unsqueeze(-1)
212
+ return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) # 公式(1)上面的计算公式,根据内参矩阵和深度D,将图像坐标系的点转成相机坐标系下的三维点
213
+
214
+
215
+ def ParameterStack(params, keys=None, is_param=None, fill=0):
216
+ if keys is not None:
217
+ params = [params[k] for k in keys]
218
+
219
+ if fill > 0:
220
+ params = [_ravel_hw(p, fill) for p in params]
221
+
222
+ requires_grad = params[0].requires_grad
223
+ assert all(p.requires_grad == requires_grad for p in params)
224
+
225
+ params = torch.stack(list(params)).float().detach()
226
+ if is_param or requires_grad:
227
+ params = nn.Parameter(params)
228
+ params.requires_grad_(requires_grad)
229
+ return params
230
+
231
+
232
+ def _ravel_hw(tensor, fill=0):
233
+ # ravel H,W
234
+ tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
235
+
236
+ if len(tensor) < fill:
237
+ tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
238
+ return tensor
239
+
240
+
241
+ def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
242
+ focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
243
+ return minf*focal_base, maxf*focal_base
244
+
245
+
246
+ def apply_mask(img, msk):
247
+ img = img.copy()
248
+ img[msk] = 0
249
+ return img
dust3r/cloud_opt/pair_viewer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Dummy optimizer for visualizing pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import cv2
11
+
12
+ from dust3r.cloud_opt.base_opt import BasePCOptimizer
13
+ from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
14
+ from dust3r.cloud_opt.commons import edge_str
15
+ from dust3r.post_process import estimate_focal_knowing_depth
16
+
17
+
18
+ class PairViewer (BasePCOptimizer):
19
+ """
20
+ This a Dummy Optimizer.
21
+ To use only when the goal is to visualize the results for a pair of images (with is_symmetrized)
22
+ """
23
+
24
+ def __init__(self, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ assert self.is_symmetrized and self.n_edges == 2
27
+ self.has_im_poses = True
28
+
29
+ # compute all parameters directly from raw input
30
+ self.focals = []
31
+ self.pp = []
32
+ rel_poses = []
33
+ confs = []
34
+ for i in range(self.n_imgs):
35
+ conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
36
+ print(f' - {conf=:.3} for edge {i}-{1-i}')
37
+ confs.append(conf)
38
+
39
+ H, W = self.imshapes[i]
40
+ pts3d = self.pred_i[edge_str(i, 1-i)]
41
+ pp = torch.tensor((W/2, H/2))
42
+ focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
43
+ self.focals.append(focal)
44
+ self.pp.append(pp)
45
+
46
+ # estimate the pose of pts1 in image 2
47
+ pixels = np.mgrid[:W, :H].T.astype(np.float32)
48
+ pts3d = self.pred_j[edge_str(1-i, i)].numpy()
49
+ assert pts3d.shape[:2] == (H, W)
50
+ msk = self.get_masks()[i].numpy()
51
+ K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
52
+
53
+ try:
54
+ res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
55
+ iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
56
+ success, R, T, inliers = res
57
+ assert success
58
+
59
+ R = cv2.Rodrigues(R)[0] # world to cam
60
+ pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world
61
+ except:
62
+ pose = np.eye(4)
63
+ rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
64
+
65
+ # let's use the pair with the most confidence
66
+ if confs[0] > confs[1]:
67
+ # ptcloud is expressed in camera1
68
+ self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1
69
+ self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
70
+ else:
71
+ # ptcloud is expressed in camera2
72
+ self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2
73
+ self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
74
+
75
+ self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
76
+ self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
77
+ self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
78
+ self.depth = nn.ParameterList(self.depth)
79
+ for p in self.parameters():
80
+ p.requires_grad = False
81
+
82
+ def _set_depthmap(self, idx, depth, force=False):
83
+ print('_set_depthmap is ignored in PairViewer')
84
+ return
85
+
86
+ def get_depthmaps(self, raw=False):
87
+ depth = [d.to(self.device) for d in self.depth]
88
+ return depth
89
+
90
+ def _set_focal(self, idx, focal, force=False):
91
+ self.focals[idx] = focal
92
+
93
+ def get_focals(self):
94
+ return self.focals
95
+
96
+ def get_known_focal_mask(self):
97
+ return torch.tensor([not (p.requires_grad) for p in self.focals])
98
+
99
+ def get_principal_points(self):
100
+ return self.pp
101
+
102
+ def get_intrinsics(self):
103
+ focals = self.get_focals()
104
+ pps = self.get_principal_points()
105
+ K = torch.zeros((len(focals), 3, 3), device=self.device)
106
+ for i in range(len(focals)):
107
+ K[i, 0, 0] = K[i, 1, 1] = focals[i]
108
+ K[i, :2, 2] = pps[i]
109
+ K[i, 2, 2] = 1
110
+ return K
111
+
112
+ def get_im_poses(self):
113
+ return self.im_poses
114
+
115
+ def depth_to_pts3d(self):
116
+ pts3d = []
117
+ for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
118
+ pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
119
+ intrinsics.cpu().numpy(),
120
+ im_pose.cpu().numpy())
121
+ pts3d.append(torch.from_numpy(pts).to(device=self.device))
122
+ return pts3d
123
+
124
+ def forward(self):
125
+ return float('nan')
dust3r/datasets/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ from .utils.transforms import *
4
+ from .base.batched_sampler import BatchedRandomSampler # noqa: F401
5
+ from .co3d import Co3d # noqa: F401
6
+
7
+
8
+ def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
9
+ import torch
10
+ from croco.utils.misc import get_world_size, get_rank
11
+
12
+ # pytorch dataset
13
+ if isinstance(dataset, str):
14
+ dataset = eval(dataset)
15
+
16
+ world_size = get_world_size()
17
+ rank = get_rank()
18
+
19
+ try:
20
+ sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
21
+ rank=rank, drop_last=drop_last)
22
+ except (AttributeError, NotImplementedError):
23
+ # not avail for this dataset
24
+ if torch.distributed.is_initialized():
25
+ sampler = torch.utils.data.DistributedSampler(
26
+ dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
27
+ )
28
+ elif shuffle:
29
+ sampler = torch.utils.data.RandomSampler(dataset)
30
+ else:
31
+ sampler = torch.utils.data.SequentialSampler(dataset)
32
+
33
+ data_loader = torch.utils.data.DataLoader(
34
+ dataset,
35
+ sampler=sampler,
36
+ batch_size=batch_size,
37
+ num_workers=num_workers,
38
+ pin_memory=pin_mem,
39
+ drop_last=drop_last,
40
+ )
41
+
42
+ return data_loader
dust3r/datasets/base/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/datasets/base/base_stereo_view_dataset.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # base class for implementing datasets
6
+ # --------------------------------------------------------
7
+ import PIL
8
+ import numpy as np
9
+ import torch
10
+
11
+ from dust3r.datasets.base.easy_dataset import EasyDataset
12
+ from dust3r.datasets.utils.transforms import ImgNorm
13
+ from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
14
+ import dust3r.datasets.utils.cropping as cropping
15
+
16
+
17
+ class BaseStereoViewDataset (EasyDataset):
18
+ """ Define all basic options.
19
+
20
+ Usage:
21
+ class MyDataset (BaseStereoViewDataset):
22
+ def _get_views(self, idx, rng):
23
+ # overload here
24
+ views = []
25
+ views.append(dict(img=, ...))
26
+ return views
27
+ """
28
+
29
+ def __init__(self, *, # only keyword arguments
30
+ split=None,
31
+ resolution=None, # square_size or (width, height) or list of [(width,height), ...]
32
+ transform=ImgNorm,
33
+ aug_crop=False,
34
+ seed=None):
35
+ self.num_views = 2
36
+ self.split = split
37
+ self._set_resolutions(resolution)
38
+
39
+ self.transform = transform
40
+ if isinstance(transform, str):
41
+ transform = eval(transform)
42
+
43
+ self.aug_crop = aug_crop
44
+ self.seed = seed
45
+
46
+ def __len__(self):
47
+ return len(self.scenes)
48
+
49
+ def get_stats(self):
50
+ return f"{len(self)} pairs"
51
+
52
+ def __repr__(self):
53
+ resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
54
+ return f"""{type(self).__name__}({self.get_stats()},
55
+ {self.split=},
56
+ {self.seed=},
57
+ resolutions={resolutions_str},
58
+ {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
59
+
60
+ def _get_views(self, idx, resolution, rng):
61
+ raise NotImplementedError()
62
+
63
+ def __getitem__(self, idx):
64
+ if isinstance(idx, tuple):
65
+ # the idx is specifying the aspect-ratio
66
+ idx, ar_idx = idx
67
+ else:
68
+ assert len(self._resolutions) == 1
69
+ ar_idx = 0
70
+
71
+ # set-up the rng
72
+ if self.seed: # reseed for each __getitem__
73
+ self._rng = np.random.default_rng(seed=self.seed + idx)
74
+ elif not hasattr(self, '_rng'):
75
+ seed = torch.initial_seed() # this is different for each dataloader process
76
+ self._rng = np.random.default_rng(seed=seed)
77
+
78
+ # over-loaded code
79
+ resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
80
+ views = self._get_views(idx, resolution, self._rng)
81
+ assert len(views) == self.num_views
82
+
83
+ # check data-types
84
+ for v, view in enumerate(views):
85
+ assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
86
+ view['idx'] = (idx, ar_idx, v)
87
+
88
+ # encode the image
89
+ width, height = view['img'].size
90
+ view['true_shape'] = np.int32((height, width))
91
+ view['img'] = self.transform(view['img'])
92
+
93
+ assert 'camera_intrinsics' in view
94
+ if 'camera_pose' not in view:
95
+ view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
96
+ else:
97
+ assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
98
+ assert 'pts3d' not in view
99
+ assert 'valid_mask' not in view
100
+ assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
101
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
102
+
103
+ view['pts3d'] = pts3d
104
+ view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
105
+
106
+ # check all datatypes
107
+ for key, val in view.items():
108
+ res, err_msg = is_good_type(key, val)
109
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
110
+ K = view['camera_intrinsics']
111
+
112
+ # last thing done!
113
+ for view in views:
114
+ # transpose to make sure all views are the same size
115
+ transpose_to_landscape(view)
116
+ # this allows to check whether the RNG is is the same state each time
117
+ view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
118
+ return views
119
+
120
+ def _set_resolutions(self, resolutions):
121
+ assert resolutions is not None, 'undefined resolution'
122
+
123
+ if not isinstance(resolutions, list):
124
+ resolutions = [resolutions]
125
+
126
+ self._resolutions = []
127
+ for resolution in resolutions:
128
+ if isinstance(resolution, int):
129
+ width = height = resolution
130
+ else:
131
+ width, height = resolution
132
+ assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
133
+ assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
134
+ assert width >= height
135
+ self._resolutions.append((width, height))
136
+
137
+ def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
138
+ """ This function:
139
+ - first downsizes the image with LANCZOS inteprolation,
140
+ which is better than bilinear interpolation in
141
+ """
142
+ if not isinstance(image, PIL.Image.Image):
143
+ image = PIL.Image.fromarray(image)
144
+
145
+ # downscale with lanczos interpolation so that image.size == resolution
146
+ # cropping centered on the principal point
147
+ W, H = image.size
148
+ cx, cy = intrinsics[:2, 2].round().astype(int)
149
+ min_margin_x = min(cx, W-cx)
150
+ min_margin_y = min(cy, H-cy)
151
+ assert min_margin_x > W/5, f'Bad principal point in view={info}'
152
+ assert min_margin_y > H/5, f'Bad principal point in view={info}'
153
+ # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
154
+ l, t = cx - min_margin_x, cy - min_margin_y
155
+ r, b = cx + min_margin_x, cy + min_margin_y
156
+ crop_bbox = (l, t, r, b)
157
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
158
+
159
+ # transpose the resolution if necessary
160
+ W, H = image.size # new size
161
+ assert resolution[0] >= resolution[1]
162
+ if H > 1.1*W:
163
+ # image is portrait mode
164
+ resolution = resolution[::-1]
165
+ elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
166
+ # image is square, so we chose (portrait, landscape) randomly
167
+ if rng.integers(2):
168
+ resolution = resolution[::-1]
169
+
170
+ # high-quality Lanczos down-scaling
171
+ target_resolution = np.array(resolution)
172
+ if self.aug_crop > 1:
173
+ target_resolution += rng.integers(0, self.aug_crop)
174
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
175
+
176
+ # actual cropping (if necessary) with bilinear interpolation
177
+ intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
178
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
179
+ image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
180
+
181
+ return image, depthmap, intrinsics2
182
+
183
+
184
+ def is_good_type(key, v):
185
+ """ returns (is_good, err_msg)
186
+ """
187
+ if isinstance(v, (str, int, tuple)):
188
+ return True, None
189
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
190
+ return False, f"bad {v.dtype=}"
191
+ return True, None
192
+
193
+
194
+ def view_name(view, batch_index=None):
195
+ def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
196
+ db = sel(view['dataset'])
197
+ label = sel(view['label'])
198
+ instance = sel(view['instance'])
199
+ return f"{db}/{label}/{instance}"
200
+
201
+
202
+ def transpose_to_landscape(view):
203
+ height, width = view['true_shape']
204
+
205
+ if width < height:
206
+ # rectify portrait to landscape
207
+ assert view['img'].shape == (3, height, width)
208
+ view['img'] = view['img'].swapaxes(1, 2)
209
+
210
+ assert view['valid_mask'].shape == (height, width)
211
+ view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
212
+
213
+ assert view['depthmap'].shape == (height, width)
214
+ view['depthmap'] = view['depthmap'].swapaxes(0, 1)
215
+
216
+ assert view['pts3d'].shape == (height, width, 3)
217
+ view['pts3d'] = view['pts3d'].swapaxes(0, 1)
218
+
219
+ # transpose x and y pixels
220
+ view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
dust3r/datasets/base/batched_sampler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Random sampling under a constraint
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class BatchedRandomSampler:
12
+ """ Random sampling under a constraint: each sample in the batch has the same feature,
13
+ which is chosen randomly from a known pool of 'features' for each batch.
14
+
15
+ For instance, the 'feature' could be the image aspect-ratio.
16
+
17
+ The index returned is a tuple (sample_idx, feat_idx).
18
+ This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
19
+ """
20
+
21
+ def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
22
+ self.batch_size = batch_size
23
+ self.pool_size = pool_size
24
+
25
+ self.len_dataset = N = len(dataset)
26
+ self.total_size = round_by(N, batch_size*world_size) if drop_last else N
27
+ assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
28
+
29
+ # distributed sampler
30
+ self.world_size = world_size
31
+ self.rank = rank
32
+ self.epoch = None
33
+
34
+ def __len__(self):
35
+ return self.total_size // self.world_size
36
+
37
+ def set_epoch(self, epoch):
38
+ self.epoch = epoch
39
+
40
+ def __iter__(self):
41
+ # prepare RNG
42
+ if self.epoch is None:
43
+ assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
44
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
45
+ else:
46
+ seed = self.epoch + 777
47
+ rng = np.random.default_rng(seed=seed)
48
+
49
+ # random indices (will restart from 0 if not drop_last)
50
+ sample_idxs = np.arange(self.total_size)
51
+ rng.shuffle(sample_idxs)
52
+
53
+ # random feat_idxs (same across each batch)
54
+ n_batches = (self.total_size+self.batch_size-1) // self.batch_size
55
+ feat_idxs = rng.integers(self.pool_size, size=n_batches)
56
+ feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
57
+ feat_idxs = feat_idxs.ravel()[:self.total_size]
58
+
59
+ # put them together
60
+ idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
61
+
62
+ # Distributed sampler: we select a subset of batches
63
+ # make sure the slice for each node is aligned with batch_size
64
+ size_per_proc = self.batch_size * ((self.total_size + self.world_size *
65
+ self.batch_size-1) // (self.world_size * self.batch_size))
66
+ idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]
67
+
68
+ yield from (tuple(idx) for idx in idxs)
69
+
70
+
71
+ def round_by(total, multiple, up=False):
72
+ if up:
73
+ total = total + multiple-1
74
+ return (total//multiple) * multiple
dust3r/datasets/base/easy_dataset.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # A dataset base class that you can easily resize and combine.
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
9
+
10
+
11
+ class EasyDataset:
12
+ """ a dataset that you can easily resize and combine.
13
+ Examples:
14
+ ---------
15
+ 2 * dataset ==> duplicate each element 2x
16
+
17
+ 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
18
+
19
+ dataset1 + dataset2 ==> concatenate datasets
20
+ """
21
+
22
+ def __add__(self, other):
23
+ return CatDataset([self, other])
24
+
25
+ def __rmul__(self, factor):
26
+ return MulDataset(factor, self)
27
+
28
+ def __rmatmul__(self, factor):
29
+ return ResizedDataset(factor, self)
30
+
31
+ def set_epoch(self, epoch):
32
+ pass # nothing to do by default
33
+
34
+ def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
35
+ if not (shuffle):
36
+ raise NotImplementedError() # cannot deal yet
37
+ num_of_aspect_ratios = len(self._resolutions)
38
+ return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
39
+
40
+
41
+ class MulDataset (EasyDataset):
42
+ """ Artifically augmenting the size of a dataset.
43
+ """
44
+ multiplicator: int
45
+
46
+ def __init__(self, multiplicator, dataset):
47
+ assert isinstance(multiplicator, int) and multiplicator > 0
48
+ self.multiplicator = multiplicator
49
+ self.dataset = dataset
50
+
51
+ def __len__(self):
52
+ return self.multiplicator * len(self.dataset)
53
+
54
+ def __repr__(self):
55
+ return f'{self.multiplicator}*{repr(self.dataset)}'
56
+
57
+ def __getitem__(self, idx):
58
+ if isinstance(idx, tuple):
59
+ idx, other = idx
60
+ return self.dataset[idx // self.multiplicator, other]
61
+ else:
62
+ return self.dataset[idx // self.multiplicator]
63
+
64
+ @property
65
+ def _resolutions(self):
66
+ return self.dataset._resolutions
67
+
68
+
69
+ class ResizedDataset (EasyDataset):
70
+ """ Artifically changing the size of a dataset.
71
+ """
72
+ new_size: int
73
+
74
+ def __init__(self, new_size, dataset):
75
+ assert isinstance(new_size, int) and new_size > 0
76
+ self.new_size = new_size
77
+ self.dataset = dataset
78
+
79
+ def __len__(self):
80
+ return self.new_size
81
+
82
+ def __repr__(self):
83
+ size_str = str(self.new_size)
84
+ for i in range((len(size_str)-1) // 3):
85
+ sep = -4*i-3
86
+ size_str = size_str[:sep] + '_' + size_str[sep:]
87
+ return f'{size_str} @ {repr(self.dataset)}'
88
+
89
+ def set_epoch(self, epoch):
90
+ # this random shuffle only depends on the epoch
91
+ rng = np.random.default_rng(seed=epoch+777)
92
+
93
+ # shuffle all indices
94
+ perm = rng.permutation(len(self.dataset))
95
+
96
+ # rotary extension until target size is met
97
+ shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
98
+ self._idxs_mapping = shuffled_idxs[:self.new_size]
99
+
100
+ assert len(self._idxs_mapping) == self.new_size
101
+
102
+ def __getitem__(self, idx):
103
+ assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
104
+ if isinstance(idx, tuple):
105
+ idx, other = idx
106
+ return self.dataset[self._idxs_mapping[idx], other]
107
+ else:
108
+ return self.dataset[self._idxs_mapping[idx]]
109
+
110
+ @property
111
+ def _resolutions(self):
112
+ return self.dataset._resolutions
113
+
114
+
115
+ class CatDataset (EasyDataset):
116
+ """ Concatenation of several datasets
117
+ """
118
+
119
+ def __init__(self, datasets):
120
+ for dataset in datasets:
121
+ assert isinstance(dataset, EasyDataset)
122
+ self.datasets = datasets
123
+ self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
124
+
125
+ def __len__(self):
126
+ return self._cum_sizes[-1]
127
+
128
+ def __repr__(self):
129
+ # remove uselessly long transform
130
+ return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
131
+
132
+ def set_epoch(self, epoch):
133
+ for dataset in self.datasets:
134
+ dataset.set_epoch(epoch)
135
+
136
+ def __getitem__(self, idx):
137
+ other = None
138
+ if isinstance(idx, tuple):
139
+ idx, other = idx
140
+
141
+ if not (0 <= idx < len(self)):
142
+ raise IndexError()
143
+
144
+ db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
145
+ dataset = self.datasets[db_idx]
146
+ new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
147
+
148
+ if other is not None:
149
+ new_idx = (new_idx, other)
150
+ return dataset[new_idx]
151
+
152
+ @property
153
+ def _resolutions(self):
154
+ resolutions = self.datasets[0]._resolutions
155
+ for dataset in self.datasets[1:]:
156
+ assert tuple(dataset._resolutions) == tuple(resolutions)
157
+ return resolutions
dust3r/datasets/co3d.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Dataloader for preprocessed Co3d_v2
6
+ # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
7
+ # See datasets_preprocess/preprocess_co3d.py
8
+ # --------------------------------------------------------
9
+ import os.path as osp
10
+ import json
11
+ import itertools
12
+ from collections import deque
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
18
+ from dust3r.utils.image import imread_cv2
19
+
20
+
21
+ class Co3d(BaseStereoViewDataset):
22
+ def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
23
+ self.ROOT = ROOT
24
+ super().__init__(*args, **kwargs)
25
+ assert mask_bg in (True, False, 'rand')
26
+ self.mask_bg = mask_bg
27
+
28
+ # load all scenes
29
+ with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
30
+ self.scenes = json.load(f)
31
+ self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
32
+ self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
33
+ for k2, v2 in v.items()}
34
+ self.scene_list = list(self.scenes.keys())
35
+
36
+ # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
37
+ # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees
38
+ self.combinations = [(i, j)
39
+ for i, j in itertools.combinations(range(100), 2)
40
+ if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]
41
+
42
+ self.invalidate = {scene: {} for scene in self.scene_list}
43
+
44
+ def __len__(self):
45
+ return len(self.scene_list) * len(self.combinations)
46
+
47
+ def _get_views(self, idx, resolution, rng):
48
+ # choose a scene
49
+ obj, instance = self.scene_list[idx // len(self.combinations)]
50
+ image_pool = self.scenes[obj, instance]
51
+ im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]
52
+
53
+ # add a bit of randomness
54
+ last = len(image_pool)-1
55
+
56
+ if resolution not in self.invalidate[obj, instance]: # flag invalid images
57
+ self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
58
+
59
+ # decide now if we mask the bg
60
+ mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))
61
+
62
+ views = []
63
+ imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
64
+ imgs_idxs = deque(imgs_idxs)
65
+ while len(imgs_idxs) > 0: # some images (few) have zero depth
66
+ im_idx = imgs_idxs.pop()
67
+
68
+ if self.invalidate[obj, instance][resolution][im_idx]:
69
+ # search for a valid image
70
+ random_direction = 2 * rng.choice(2) - 1
71
+ for offset in range(1, len(image_pool)):
72
+ tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
73
+ if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
74
+ im_idx = tentative_im_idx
75
+ break
76
+
77
+ view_idx = image_pool[im_idx]
78
+
79
+ impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')
80
+
81
+ # load camera params
82
+ input_metadata = np.load(impath.replace('jpg', 'npz'))
83
+ camera_pose = input_metadata['camera_pose'].astype(np.float32)
84
+ intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
85
+
86
+ # load image and depth
87
+ rgb_image = imread_cv2(impath)
88
+ depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
89
+ depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
90
+
91
+ if mask_bg:
92
+ # load object mask
93
+ maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
94
+ maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
95
+ maskmap = (maskmap / 255.0) > 0.1
96
+
97
+ # update the depthmap with mask
98
+ depthmap *= maskmap
99
+
100
+ rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
101
+ rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
102
+
103
+ num_valid = (depthmap > 0.0).sum()
104
+ if num_valid == 0:
105
+ # problem, invalidate image and retry
106
+ self.invalidate[obj, instance][resolution][im_idx] = True
107
+ imgs_idxs.append(im_idx)
108
+ continue
109
+
110
+ views.append(dict(
111
+ img=rgb_image,
112
+ depthmap=depthmap,
113
+ camera_pose=camera_pose,
114
+ camera_intrinsics=intrinsics,
115
+ dataset='Co3d_v2',
116
+ label=osp.join(obj, instance),
117
+ instance=osp.split(impath)[1],
118
+ ))
119
+ return views
120
+
121
+
122
+ if __name__ == "__main__":
123
+ from dust3r.datasets.base.base_stereo_view_dataset import view_name
124
+ from dust3r.viz import SceneViz, auto_cam_size
125
+ from dust3r.utils.image import rgb
126
+
127
+ dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)
128
+
129
+ for idx in np.random.permutation(len(dataset)):
130
+ views = dataset[idx]
131
+ assert len(views) == 2
132
+ print(view_name(views[0]), view_name(views[1]))
133
+ viz = SceneViz()
134
+ poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
135
+ cam_size = max(auto_cam_size(poses), 0.001)
136
+ for view_idx in [0, 1]:
137
+ pts3d = views[view_idx]['pts3d']
138
+ valid_mask = views[view_idx]['valid_mask']
139
+ colors = rgb(views[view_idx]['img'])
140
+ viz.add_pointcloud(pts3d, colors, valid_mask)
141
+ viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
142
+ focal=views[view_idx]['camera_intrinsics'][0, 0],
143
+ color=(idx*255, (1 - idx)*255, 0),
144
+ image=colors,
145
+ cam_size=cam_size)
146
+ viz.show()
dust3r/datasets/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/datasets/utils/cropping.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # croppping utilities
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import os
9
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
10
+ import cv2 # noqa
11
+ import numpy as np # noqa
12
+ from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
13
+ try:
14
+ lanczos = PIL.Image.Resampling.LANCZOS
15
+ except AttributeError:
16
+ lanczos = PIL.Image.LANCZOS
17
+
18
+
19
+ class ImageList:
20
+ """ Convenience class to aply the same operation to a whole set of images.
21
+ """
22
+
23
+ def __init__(self, images):
24
+ if not isinstance(images, (tuple, list, set)):
25
+ images = [images]
26
+ self.images = []
27
+ for image in images:
28
+ if not isinstance(image, PIL.Image.Image):
29
+ image = PIL.Image.fromarray(image)
30
+ self.images.append(image)
31
+
32
+ def __len__(self):
33
+ return len(self.images)
34
+
35
+ def to_pil(self):
36
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
37
+
38
+ @property
39
+ def size(self):
40
+ sizes = [im.size for im in self.images]
41
+ assert all(sizes[0] == s for s in sizes)
42
+ return sizes[0]
43
+
44
+ def resize(self, *args, **kwargs):
45
+ return ImageList(self._dispatch('resize', *args, **kwargs))
46
+
47
+ def crop(self, *args, **kwargs):
48
+ return ImageList(self._dispatch('crop', *args, **kwargs))
49
+
50
+ def _dispatch(self, func, *args, **kwargs):
51
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
52
+
53
+
54
+ def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
55
+ """ Jointly rescale a (image, depthmap)
56
+ so that (out_width, out_height) >= output_res
57
+ """
58
+ image = ImageList(image)
59
+ input_resolution = np.array(image.size) # (W,H)
60
+ output_resolution = np.array(output_resolution)
61
+ if depthmap is not None:
62
+ # can also use this with masks instead of depthmaps
63
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
64
+ assert output_resolution.shape == (2,)
65
+ # define output resolution
66
+ scale_final = max(output_resolution / image.size) + 1e-8
67
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
68
+
69
+ # first rescale the image so that it contains the crop
70
+ image = image.resize(output_resolution, resample=lanczos)
71
+ if depthmap is not None:
72
+ depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
73
+ fy=scale_final, interpolation=cv2.INTER_NEAREST)
74
+
75
+ # no offset here; simple rescaling
76
+ camera_intrinsics = camera_matrix_of_crop(
77
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
78
+
79
+ return image.to_pil(), depthmap, camera_intrinsics
80
+
81
+
82
+ def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
83
+ # Margins to offset the origin
84
+ margins = np.asarray(input_resolution) * scaling - output_resolution
85
+ assert np.all(margins >= 0.0)
86
+ if offset is None:
87
+ offset = offset_factor * margins
88
+
89
+ # Generate new camera parameters
90
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
91
+ output_camera_matrix_colmap[:2, :] *= scaling
92
+ output_camera_matrix_colmap[:2, 2] -= offset
93
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
94
+
95
+ return output_camera_matrix
96
+
97
+
98
+ def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
99
+ """
100
+ Return a crop of the input view.
101
+ """
102
+ image = ImageList(image)
103
+ l, t, r, b = crop_bbox
104
+
105
+ image = image.crop((l, t, r, b))
106
+ depthmap = depthmap[t:b, l:r]
107
+
108
+ camera_intrinsics = camera_intrinsics.copy()
109
+ camera_intrinsics[0, 2] -= l
110
+ camera_intrinsics[1, 2] -= t
111
+
112
+ return image.to_pil(), depthmap, camera_intrinsics
113
+
114
+
115
+ def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
116
+ out_width, out_height = output_resolution
117
+ l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
118
+ crop_bbox = (l, t, l+out_width, t+out_height)
119
+ return crop_bbox
dust3r/datasets/utils/transforms.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # DUST3R default transforms
6
+ # --------------------------------------------------------
7
+ import torchvision.transforms as tvf
8
+ from dust3r.utils.image import ImgNorm
9
+
10
+ # define the standard image transforms
11
+ ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
dust3r/heads/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # head factory
6
+ # --------------------------------------------------------
7
+ from .linear_head import LinearPts3d
8
+ from .dpt_head import create_dpt_head
9
+
10
+
11
+ def head_factory(head_type, output_mode, net, has_conf=False):
12
+ """" build a prediction head for the decoder
13
+ """
14
+ if head_type == 'linear' and output_mode == 'pts3d':
15
+ return LinearPts3d(net, has_conf)
16
+ elif head_type == 'dpt' and output_mode == 'pts3d':
17
+ return create_dpt_head(net, has_conf=has_conf)
18
+ else:
19
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
dust3r/heads/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (618 Bytes). View file
 
dust3r/heads/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (621 Bytes). View file
 
dust3r/heads/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (621 Bytes). View file