|
from enum import Enum |
|
|
|
from .optimizer import PointCloudOptimizer |
|
from .modular_optimizer import ModularPointCloudOptimizer |
|
from .pair_viewer import PairViewer |
|
from mini_dust3r.inference import Dust3rResult |
|
from typing import Literal |
|
|
|
class GlobalAlignerMode(Enum):
    """Selects which global-alignment model ``global_aligner`` builds.

    Each member's value is the member name as a string, so a member can also
    be looked up by value, e.g. ``GlobalAlignerMode("PairViewer")``.
    """

    # Builds a PointCloudOptimizer (the default used by global_aligner).
    PointCloudOptimizer = "PointCloudOptimizer"
    # Builds a ModularPointCloudOptimizer.
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
    # Builds a PairViewer.
    PairViewer = "PairViewer"
|
|
|
def global_aligner(
    dust3r_output: Dust3rResult,
    device: Literal["cpu", "cuda", "mps"],
    mode: GlobalAlignerMode = GlobalAlignerMode.PointCloudOptimizer,
    **optim_kw,
):
    """Build a global-alignment model from pairwise DUSt3R outputs.

    Args:
        dust3r_output: Inference result, subscripted with the keys
            ``"view1"``, ``"view2"``, ``"pred1"`` and ``"pred2"``.
        device: Device passed to the constructed model's ``.to(device)``.
        mode: Which aligner to build. Either a ``GlobalAlignerMode`` member
            or its string value (e.g. ``"PairViewer"``).
        **optim_kw: Extra keyword arguments forwarded unchanged to the
            aligner's constructor.

    Returns:
        The constructed ``PointCloudOptimizer``,
        ``ModularPointCloudOptimizer`` or ``PairViewer`` instance, as
        returned by its ``.to(device)`` call.

    Raises:
        NotImplementedError: If ``mode`` does not name a known aligner.
    """
    # Pull the four inputs first, preserving the original evaluation order
    # (a missing key surfaces before any mode validation).
    view1, view2, pred1, pred2 = (
        dust3r_output[key] for key in ("view1", "view2", "pred1", "pred2")
    )

    # Accept the enum's string values as well as its members. Enum lookup by
    # value raises ValueError for anything unknown; report it the same way
    # the original else-branch did.
    try:
        resolved_mode = GlobalAlignerMode(mode)
    except ValueError as err:
        raise NotImplementedError(f"Unknown mode {mode}") from err

    aligner_cls = {
        GlobalAlignerMode.PointCloudOptimizer: PointCloudOptimizer,
        GlobalAlignerMode.ModularPointCloudOptimizer: ModularPointCloudOptimizer,
        GlobalAlignerMode.PairViewer: PairViewer,
    }.get(resolved_mode)
    if aligner_cls is None:
        # Guards against an enum member being added without a mapping here.
        raise NotImplementedError(f"Unknown mode {mode}")

    return aligner_cls(view1, view2, pred1, pred2, **optim_kw).to(device)