|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import math |
|
import random |
|
|
|
import numpy as np |
|
from dust3r.inference import inference |
|
from dust3r.model import AsymmetricCroCo3DStereo |
|
from dust3r.utils.geometry import find_reciprocal_matches, geotrf, xy_grid |
|
from dust3r_visloc.datasets import * |
|
from dust3r_visloc.evaluation import aggregate_stats, export_results, get_pose_error |
|
from dust3r_visloc.localization import run_pnp |
|
from tqdm import tqdm |
|
|
|
|
|
def get_args_parser(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--dataset", type=str, required=True, help="visloc dataset to eval" |
|
) |
|
parser_weights = parser.add_mutually_exclusive_group(required=True) |
|
parser_weights.add_argument( |
|
"--weights", type=str, help="path to the model weights", default=None |
|
) |
|
parser_weights.add_argument( |
|
"--model_name", |
|
type=str, |
|
help="name of the model weights", |
|
choices=[ |
|
"DUSt3R_ViTLarge_BaseDecoder_512_dpt", |
|
"DUSt3R_ViTLarge_BaseDecoder_512_linear", |
|
"DUSt3R_ViTLarge_BaseDecoder_224_linear", |
|
], |
|
) |
|
parser.add_argument( |
|
"--confidence_threshold", |
|
type=float, |
|
default=3.0, |
|
help="confidence values higher than threshold are invalid", |
|
) |
|
parser.add_argument("--device", type=str, default="cuda", help="pytorch device") |
|
parser.add_argument( |
|
"--pnp_mode", |
|
type=str, |
|
default="cv2", |
|
choices=["cv2", "poselib", "pycolmap"], |
|
help="pnp lib to use", |
|
) |
|
parser_reproj = parser.add_mutually_exclusive_group() |
|
parser_reproj.add_argument( |
|
"--reprojection_error", type=float, default=5.0, help="pnp reprojection error" |
|
) |
|
parser_reproj.add_argument( |
|
"--reprojection_error_diag_ratio", |
|
type=float, |
|
default=None, |
|
help="pnp reprojection error as a ratio of the diagonal of the image", |
|
) |
|
|
|
parser.add_argument( |
|
"--pnp_max_points", |
|
type=int, |
|
default=100_000, |
|
help="pnp maximum number of points kept", |
|
) |
|
parser.add_argument("--viz_matches", type=int, default=0, help="debug matches") |
|
|
|
parser.add_argument("--output_dir", type=str, default=None, help="output path") |
|
parser.add_argument( |
|
"--output_label", type=str, default="", help="prefix for results files" |
|
) |
|
return parser |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = get_args_parser() |
|
args = parser.parse_args() |
|
conf_thr = args.confidence_threshold |
|
device = args.device |
|
pnp_mode = args.pnp_mode |
|
reprojection_error = args.reprojection_error |
|
reprojection_error_diag_ratio = args.reprojection_error_diag_ratio |
|
pnp_max_points = args.pnp_max_points |
|
viz_matches = args.viz_matches |
|
|
|
if args.weights is not None: |
|
weights_path = args.weights |
|
else: |
|
weights_path = "naver/" + args.model_name |
|
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) |
|
|
|
dataset = eval(args.dataset) |
|
dataset.set_resolution(model) |
|
|
|
query_names = [] |
|
poses_pred = [] |
|
pose_errors = [] |
|
angular_errors = [] |
|
for idx in tqdm(range(len(dataset))): |
|
views = dataset[(idx)] |
|
query_view = views[0] |
|
map_views = views[1:] |
|
query_names.append(query_view["image_name"]) |
|
|
|
query_pts2d = [] |
|
query_pts3d = [] |
|
for map_view in map_views: |
|
|
|
imgs = [] |
|
for idx, img in enumerate( |
|
[query_view["rgb_rescaled"], map_view["rgb_rescaled"]] |
|
): |
|
imgs.append( |
|
dict( |
|
img=img.unsqueeze(0), |
|
true_shape=np.int32([img.shape[1:]]), |
|
idx=idx, |
|
instance=str(idx), |
|
) |
|
) |
|
output = inference( |
|
[tuple(imgs)], model, device, batch_size=1, verbose=False |
|
) |
|
pred1, pred2 = output["pred1"], output["pred2"] |
|
confidence_masks = [ |
|
pred1["conf"].squeeze(0) >= conf_thr, |
|
(pred2["conf"].squeeze(0) >= conf_thr) & map_view["valid_rescaled"], |
|
] |
|
pts3d = [pred1["pts3d"].squeeze(0), pred2["pts3d_in_other_view"].squeeze(0)] |
|
|
|
|
|
pts2d_list, pts3d_list = [], [] |
|
for i in range(2): |
|
conf_i = confidence_masks[i].cpu().numpy() |
|
true_shape_i = imgs[i]["true_shape"][0] |
|
pts2d_list.append(xy_grid(true_shape_i[1], true_shape_i[0])[conf_i]) |
|
pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) |
|
|
|
PQ, PM = pts3d_list[0], pts3d_list[1] |
|
if len(PQ) == 0 or len(PM) == 0: |
|
continue |
|
reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM) |
|
if viz_matches > 0: |
|
print(f"found {num_matches} matches") |
|
matches_im1 = pts2d_list[1][reciprocal_in_PM] |
|
matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM] |
|
valid_pts3d = map_view["pts3d_rescaled"][ |
|
matches_im1[:, 1], matches_im1[:, 0] |
|
] |
|
|
|
|
|
matches_im0 = matches_im0.astype(np.float64) |
|
matches_im1 = matches_im1.astype(np.float64) |
|
matches_im0[:, 0] += 0.5 |
|
matches_im0[:, 1] += 0.5 |
|
matches_im1[:, 0] += 0.5 |
|
matches_im1[:, 1] += 0.5 |
|
|
|
matches_im0 = geotrf(query_view["to_orig"], matches_im0, norm=True) |
|
matches_im1 = geotrf(query_view["to_orig"], matches_im1, norm=True) |
|
|
|
matches_im0[:, 0] -= 0.5 |
|
matches_im0[:, 1] -= 0.5 |
|
matches_im1[:, 0] -= 0.5 |
|
matches_im1[:, 1] -= 0.5 |
|
|
|
|
|
if viz_matches > 0: |
|
viz_imgs = [np.array(query_view["rgb"]), np.array(map_view["rgb"])] |
|
from matplotlib import pyplot as pl |
|
|
|
n_viz = viz_matches |
|
match_idx_to_viz = np.round( |
|
np.linspace(0, num_matches - 1, n_viz) |
|
).astype(int) |
|
viz_matches_im0, viz_matches_im1 = ( |
|
matches_im0[match_idx_to_viz], |
|
matches_im1[match_idx_to_viz], |
|
) |
|
|
|
H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2] |
|
img0 = np.pad( |
|
viz_imgs[0], |
|
((0, max(H1 - H0, 0)), (0, 0), (0, 0)), |
|
"constant", |
|
constant_values=0, |
|
) |
|
img1 = np.pad( |
|
viz_imgs[1], |
|
((0, max(H0 - H1, 0)), (0, 0), (0, 0)), |
|
"constant", |
|
constant_values=0, |
|
) |
|
img = np.concatenate((img0, img1), axis=1) |
|
pl.figure() |
|
pl.imshow(img) |
|
cmap = pl.get_cmap("jet") |
|
for i in range(n_viz): |
|
(x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T |
|
pl.plot( |
|
[x0, x1 + W0], |
|
[y0, y1], |
|
"-+", |
|
color=cmap(i / (n_viz - 1)), |
|
scalex=False, |
|
scaley=False, |
|
) |
|
pl.show(block=True) |
|
|
|
if len(valid_pts3d) == 0: |
|
pass |
|
else: |
|
query_pts3d.append(valid_pts3d.cpu().numpy()) |
|
query_pts2d.append(matches_im0) |
|
|
|
if len(query_pts2d) == 0: |
|
success = False |
|
pr_querycam_to_world = None |
|
else: |
|
query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32) |
|
query_pts3d = np.concatenate(query_pts3d, axis=0) |
|
if len(query_pts2d) > pnp_max_points: |
|
idxs = random.sample(range(len(query_pts2d)), pnp_max_points) |
|
query_pts3d = query_pts3d[idxs] |
|
query_pts2d = query_pts2d[idxs] |
|
|
|
W, H = query_view["rgb"].size |
|
if reprojection_error_diag_ratio is not None: |
|
reprojection_error_img = reprojection_error_diag_ratio * math.sqrt( |
|
W**2 + H**2 |
|
) |
|
else: |
|
reprojection_error_img = reprojection_error |
|
success, pr_querycam_to_world = run_pnp( |
|
query_pts2d, |
|
query_pts3d, |
|
query_view["intrinsics"], |
|
query_view["distortion"], |
|
pnp_mode, |
|
reprojection_error_img, |
|
img_size=[W, H], |
|
) |
|
|
|
if not success: |
|
abs_transl_error = float("inf") |
|
abs_angular_error = float("inf") |
|
else: |
|
abs_transl_error, abs_angular_error = get_pose_error( |
|
pr_querycam_to_world, query_view["cam_to_world"] |
|
) |
|
|
|
pose_errors.append(abs_transl_error) |
|
angular_errors.append(abs_angular_error) |
|
poses_pred.append(pr_querycam_to_world) |
|
|
|
xp_label = f"tol_conf_{conf_thr}" |
|
if args.output_label: |
|
xp_label = args.output_label + "_" + xp_label |
|
if reprojection_error_diag_ratio is not None: |
|
xp_label = xp_label + f"_reproj_diag_{reprojection_error_diag_ratio}" |
|
else: |
|
xp_label = xp_label + f"_reproj_err_{reprojection_error}" |
|
export_results(args.output_dir, xp_label, query_names, poses_pred) |
|
out_string = aggregate_stats(f"{args.dataset}", pose_errors, angular_errors) |
|
print(out_string) |
|
|