|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
from mini_dust3r.utils.image import ImageDict |
|
|
|
|
|
def make_pairs(
    imgs: list[ImageDict],
    scene_graph: str = "complete",
    prefilter=None,
    symmetrize=True,
) -> list[tuple[ImageDict, ImageDict]]:
    """Build the list of image pairs to feed to the network.

    Args:
        imgs: Images to pair up.
        scene_graph: Pairing strategy:
            - ``"complete"``: every unordered pair ``(imgs[i], imgs[j])`` with
              ``j < i``.
            - ``"swin"`` / ``"swin-K"``: each image paired with the next ``K``
              images (default 3), wrapping around the end of the list. An
              image is never paired with itself, even if ``K >= len(imgs)``.
            - ``"oneref"`` / ``"oneref-R"``: image ``R`` (default 0) paired
              with every other image.
        prefilter: Optional string ``"seqN"`` or ``"cycN"``. It keeps only the
            pairs whose images' ``"idx"`` entries are at most ``N`` apart.
            ``"cyc"`` measures that distance cyclically. Any other value
            (including ``None``) disables filtering.
        symmetrize: If True, append the reversed copy ``(b, a)`` of every
            pair ``(a, b)``.

    Returns:
        List of ``(img1, img2)`` tuples.

    Raises:
        ValueError: If ``scene_graph`` is not a supported strategy, or if a
            numeric suffix (in ``scene_graph`` or ``prefilter``) is not an
            integer.
    """
    n_imgs = len(imgs)
    pairs = []

    if scene_graph == "complete":
        for i in range(n_imgs):
            for j in range(i):
                pairs.append((imgs[i], imgs[j]))

    elif scene_graph.startswith("swin"):
        winsize = int(scene_graph.split("-")[1]) if "-" in scene_graph else 3
        pairsid = set()
        for i in range(n_imgs):
            for j in range(1, winsize + 1):
                idx = (i + j) % n_imgs
                # When the window is at least as long as the image list, the
                # wrap-around comes back to ``i`` itself; skip that self-pair.
                if idx == i:
                    continue
                pairsid.add((i, idx) if i < idx else (idx, i))
        # Sort so the pair order is deterministic rather than set-order.
        for i, j in sorted(pairsid):
            pairs.append((imgs[i], imgs[j]))

    elif scene_graph.startswith("oneref"):
        refid = int(scene_graph.split("-")[1]) if "-" in scene_graph else 0
        for j in range(n_imgs):
            if j != refid:
                pairs.append((imgs[refid], imgs[j]))

    else:
        # Previously an unknown graph silently produced zero pairs.
        raise ValueError(
            f"Unknown scene_graph {scene_graph!r}; expected 'complete', "
            "'swin[-K]' or 'oneref[-R]'"
        )

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # Optionally remove edges between images that are too far apart.
    if isinstance(prefilter, str):
        if prefilter.startswith("seq"):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
        elif prefilter.startswith("cyc"):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs
|
|
|
|
|
def sel(x, kept):
    """Recursively select the entries at positions ``kept`` from ``x``.

    Args:
        x: One of the following:
            - a dict, handled recursively value by value;
            - a ``torch.Tensor`` or ``np.ndarray``, indexed with ``x[kept]``;
            - a list or tuple, rebuilt with the same container type;
            - any other value, returned unchanged.
        kept: Sequence of integer positions to keep.

    Returns:
        The selected subset, with the same container structure as ``x``.
    """
    if isinstance(x, dict):
        return {k: sel(v, kept) for k, v in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[k] for k in kept])
    # Values that cannot be indexed per edge (strings, scalars, None, ...) are
    # passed through instead of being silently replaced by an implicit None.
    return x
|
|
|
|
|
def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): |
|
|
|
n = max(max(e) for e in edges) + 1 |
|
|
|
kept = [] |
|
for e, (i, j) in enumerate(edges): |
|
dis = abs(i - j) |
|
if cyclic: |
|
dis = min(dis, abs(i + n - j), abs(i - n - j)) |
|
if dis <= seq_dis_thr: |
|
kept.append(e) |
|
return kept |
|
|
|
|
|
def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    """Keep only the image pairs whose ``"idx"`` entries are close enough.

    Each pair ``(img1, img2)`` becomes an edge ``(img1["idx"], img2["idx"])``.
    ``_filter_edges_seq`` then decides which edges are within
    ``seq_dis_thr`` (optionally measured cyclically).

    Returns:
        The surviving pairs, in their original order.
    """
    edges = []
    for first, second in pairs:
        edges.append((first["idx"], second["idx"]))

    keep_positions = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[pos] for pos in keep_positions]
|
|
|
|
|
def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    """Drop batched edges whose image indices are more than ``seq_dis_thr`` apart.

    Args:
        view1, view2: Mappings whose ``"idx"`` entry holds, per edge, the
            index of the first / second image (each element is cast with
            ``int``).
        pred1, pred2: Per-edge predictions, subset at the same positions.
        seq_dis_thr: Maximum allowed index distance.
        cyclic: If True, measure the distance cyclically.

    Returns:
        ``(view1, view2, pred1, pred2)``, each restricted via ``sel`` to the
        kept edge positions.
    """
    edges = []
    for a, b in zip(view1["idx"], view2["idx"]):
        edges.append((int(a), int(b)))

    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(
        f">> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges"
    )
    return tuple(sel(item, kept) for item in (view1, view2, pred1, pred2))