Our3D / lib /bbox_utils.py
yansong1616's picture
Upload 384 files
b177539 verified
raw
history blame
3.56 kB
import torch
import numpy as np
from . import seg_dvgo as dvgo
import time
from .utils import load_model
# Names exported by ``from <this module> import *``. Python only recognises the
# lowercase ``__all__``; the original uppercase spelling was silently ignored.
__all__ = ['_compute_bbox_by_cam_frustrm_bounded',
           '_compute_bbox_by_cam_frustrm_unbounded',
           'compute_bbox_by_cam_frustrm',
           'compute_bbox_by_coarse_geo']
# Kept as an alias so any code that reads the old attribute name keeps working.
__ALL__ = __all__
def _compute_bbox_by_cam_frustrm_bounded(cfg, HW, Ks, poses, i_train, near, far):
    """Axis-aligned box enclosing the near/far end points of all training rays.

    For every training view, rays are generated with
    ``dvgo.get_rays_of_a_view``. Each ray is evaluated at parameters ``near``
    and ``far``, and the element-wise running min/max of those points is kept.

    Args:
        cfg: config object; reads ``cfg.data.ndc``, ``cfg.data.inverse_y``,
            ``cfg.data.flip_x`` and ``cfg.data.flip_y``.
        HW: per-view ``(H, W)`` pairs, indexed with ``i_train``.
        Ks: per-view intrinsics, indexed with ``i_train``.
        poses: per-view camera-to-world matrices, indexed with ``i_train``.
        i_train: indices of the training views.
        near: ray parameter of the first sampled end point.
        far: ray parameter of the second sampled end point.

    Returns:
        ``(xyz_min, xyz_max)`` tensors. They stay at +inf / -inf when
        ``i_train`` selects no views.
    """
    lo = torch.Tensor([np.inf, np.inf, np.inf])
    hi = -lo
    views = zip(HW[i_train], Ks[i_train], poses[i_train])
    for (height, width), intrinsic, cam2world in views:
        origins, directions, unit_dirs = dvgo.get_rays_of_a_view(
            H=height, W=width, K=intrinsic, c2w=cam2world,
            ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
            flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
        # NDC rays are stepped along the raw direction. Otherwise the
        # view directions are used (presumably unit-length, so near/far
        # behave like distances -- TODO confirm in get_rays_of_a_view).
        step = directions if cfg.data.ndc else unit_dirs
        endpoints = torch.stack([origins + step * near, origins + step * far])
        # Reduce over the stacked near/far axis and the (assumed H, W) pixel
        # axes, leaving one value per coordinate.
        lo = torch.minimum(lo, endpoints.amin((0, 1, 2)))
        hi = torch.maximum(hi, endpoints.amax((0, 1, 2)))
    return lo, hi
def _compute_bbox_by_cam_frustrm_unbounded(cfg, HW, Ks, poses, i_train, near_clip):
    """Cube built from the points ``near_clip`` along every training ray.

    Every training ray is advanced by ``near_clip`` along its direction, and
    the axis-aligned box of those points is taken. The returned cube shares
    that box's centre. Its half-side is the largest half-extent of the box,
    multiplied by ``cfg.data.unbounded_inner_r``.

    Args:
        cfg: config object; reads ``cfg.data.ndc``, ``cfg.data.inverse_y``,
            ``cfg.data.flip_x``, ``cfg.data.flip_y`` and
            ``cfg.data.unbounded_inner_r``.
        HW: per-view ``(H, W)`` pairs, indexed with ``i_train``.
        Ks: per-view intrinsics, indexed with ``i_train``.
        poses: per-view camera-to-world matrices, indexed with ``i_train``.
        i_train: indices of the training views.
        near_clip: ray parameter at which each ray is sampled; must not be
            ``None``.

    Returns:
        ``(xyz_min, xyz_max)`` tensors: the opposite corners of the cube.
    """
    lo = torch.Tensor([np.inf, np.inf, np.inf])
    hi = -lo
    for (height, width), intrinsic, cam2world in zip(HW[i_train], Ks[i_train], poses[i_train]):
        origins, directions, _unused_viewdirs = dvgo.get_rays_of_a_view(
            H=height, W=width, K=intrinsic, c2w=cam2world,
            ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
            flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
        clipped = origins + directions * near_clip
        # Reduce over the (assumed H, W) pixel axes; one value per coordinate.
        lo = torch.minimum(lo, clipped.amin((0, 1)))
        hi = torch.maximum(hi, clipped.amax((0, 1)))
    mid = (lo + hi) * 0.5
    half_side = (mid - lo).max() * cfg.data.unbounded_inner_r
    return mid - half_side, mid + half_side
def compute_bbox_by_cam_frustrm(args, cfg, HW, Ks, poses, i_train, near, far, **kwargs):
    """Compute the scene bounding box from the training camera frusta.

    Dispatches on ``cfg.data.unbounded_inward``:

    * true  -> ``_compute_bbox_by_cam_frustrm_unbounded`` using the
      ``near_clip`` keyword argument (required in this mode);
    * false -> ``_compute_bbox_by_cam_frustrm_bounded`` using ``near``/``far``.

    Args:
        args: unused here; kept for interface compatibility.
        cfg: config object; ``cfg.data`` is forwarded to the helpers.
        HW, Ks, poses, i_train: per-view sizes, intrinsics, poses and the
            indices of the training views.
        near, far: ray parameters for the bounded case.
        **kwargs: ``near_clip`` is read in the unbounded case.

    Returns:
        ``(xyz_min, xyz_max)`` as returned by the selected helper.

    Raises:
        TypeError: if ``cfg.data.unbounded_inward`` is set and ``near_clip``
            is missing or ``None``.
    """
    print('compute_bbox_by_cam_frustrm: start')
    if cfg.data.unbounded_inward:
        near_clip = kwargs.get('near_clip', None)
        if near_clip is None:
            # Fail fast with a clear message. Otherwise the helper fails on
            # ``rays_d * None`` with an unhelpful operand-type error.
            raise TypeError(
                "compute_bbox_by_cam_frustrm: the 'near_clip' keyword argument "
                "is required when cfg.data.unbounded_inward is set")
        xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_unbounded(
            cfg, HW, Ks, poses, i_train, near_clip)
    else:
        xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_bounded(
            cfg, HW, Ks, poses, i_train, near, far)
    print('compute_bbox_by_cam_frustrm: xyz_min', xyz_min)
    print('compute_bbox_by_cam_frustrm: xyz_max', xyz_max)
    print('compute_bbox_by_cam_frustrm: finish')
    return xyz_min, xyz_max
@torch.no_grad()
def compute_bbox_by_coarse_geo(model_class, model_path, thres):
    """Shrink the bbox to the voxels of a trained coarse model with alpha > ``thres``.

    Loads the model via ``load_model(model_class, model_path)``. It then
    evaluates the model's density on a dense grid of ``model.world_size``
    points spanning ``[model.xyz_min, model.xyz_max]``, converts density to
    alpha with ``model.activate_density``, and returns the min/max
    coordinates of the grid points whose alpha exceeds ``thres``.

    Args:
        model_class: class passed to ``load_model``.
        model_path: checkpoint path passed to ``load_model``.
        thres: alpha threshold; a grid point counts as occupied when
            ``alpha > thres``.

    Returns:
        ``(xyz_min, xyz_max)`` tensors. If no grid point exceeds ``thres``,
        the coarse model's own ``xyz_min``/``xyz_max`` are returned instead,
        and a warning is printed.
    """
    print('compute_bbox_by_coarse_geo: start')
    eps_time = time.time()
    model = load_model(model_class, model_path)
    # Normalised [0, 1] coordinates of every grid point. meshgrid's default
    # (ij) indexing keeps axis k aligned with world_size[k].
    interp = torch.stack(torch.meshgrid(
        torch.linspace(0, 1, model.world_size[0]),
        torch.linspace(0, 1, model.world_size[1]),
        torch.linspace(0, 1, model.world_size[2]),
    ), -1)
    dense_xyz = model.xyz_min * (1-interp) + model.xyz_max * interp
    density = model.density(dense_xyz)
    alpha = model.activate_density(density)
    mask = (alpha > thres)
    active_xyz = dense_xyz[mask]
    if active_xyz.numel() == 0:
        # amin/amax on an empty selection raise a cryptic reduction error.
        # Fall back to the coarse bbox, which encloses every grid point.
        print('compute_bbox_by_coarse_geo: WARNING no voxel has alpha >', thres,
              '- falling back to the coarse model bbox')
        xyz_min = torch.as_tensor(model.xyz_min).clone()
        xyz_max = torch.as_tensor(model.xyz_max).clone()
    else:
        xyz_min = active_xyz.amin(0)
        xyz_max = active_xyz.amax(0)
    print('compute_bbox_by_coarse_geo: xyz_min', xyz_min)
    print('compute_bbox_by_coarse_geo: xyz_max', xyz_max)
    eps_time = time.time() - eps_time
    print('compute_bbox_by_coarse_geo: finish (eps time:', eps_time, 'secs)')
    return xyz_min, xyz_max