import time

import numpy as np
import torch

from . import seg_dvgo as dvgo
from .utils import load_model

# NOTE(review): Python only honours the lowercase ``__all__``; this upper-case
# name has no effect on ``from module import *``.  It is left unchanged so the
# set of names exposed by a star-import stays exactly as it was.
__ALL__ = ['_compute_bbox_by_cam_frustrm_bounded',
           '_compute_bbox_by_cam_frustrm_unbounded',
           'compute_bbox_by_cam_frustrm',
           'compute_bbox_by_coarse_geo']


def _iter_train_rays(cfg, HW, Ks, poses, i_train):
    """Yield ``(rays_o, rays_d, viewdirs)`` for each training view.

    ``HW``, ``Ks`` and ``poses`` are indexed with ``i_train`` and walked in
    lockstep; every view is passed to ``dvgo.get_rays_of_a_view`` together
    with the ``ndc`` / ``inverse_y`` / ``flip_x`` / ``flip_y`` flags read
    from ``cfg.data``.
    """
    for (H, W), K, c2w in zip(HW[i_train], Ks[i_train], poses[i_train]):
        yield dvgo.get_rays_of_a_view(
            H=H, W=W, K=K, c2w=c2w,
            ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
            flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)


def _compute_bbox_by_cam_frustrm_bounded(cfg, HW, Ks, poses, i_train, near, far):
    """Return the axis-aligned box enclosing the near/far points of all rays.

    For every training view the ray origins are pushed to the ``near`` and
    ``far`` distances, and the running element-wise min / max over all those
    points is kept.

    Args:
        cfg: Config object; ``cfg.data.ndc`` selects which direction tensor
            is used (``rays_d`` when true, ``viewdirs`` otherwise).
        HW, Ks, poses: Per-view image sizes, intrinsics and camera-to-world
            matrices, indexable by ``i_train``.
        i_train: Index selecting the training views.
        near, far: Distances along each ray at which points are sampled.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of tensors.  If ``i_train`` selects no
        views the initial ``+inf`` / ``-inf`` sentinels are returned as-is.
    """
    xyz_min = torch.Tensor([np.inf, np.inf, np.inf])
    xyz_max = -xyz_min
    for rays_o, rays_d, viewdirs in _iter_train_rays(cfg, HW, Ks, poses, i_train):
        directions = rays_d if cfg.data.ndc else viewdirs
        pts_nf = torch.stack([rays_o + directions * near,
                              rays_o + directions * far])
        # Reduce every axis but the last one; assumes the rays are laid out
        # as (H, W, 3) so that the stack is (2, H, W, 3) -- TODO confirm.
        xyz_min = torch.minimum(xyz_min, pts_nf.amin((0, 1, 2)))
        xyz_max = torch.maximum(xyz_max, pts_nf.amax((0, 1, 2)))
    return xyz_min, xyz_max


def _compute_bbox_by_cam_frustrm_unbounded(cfg, HW, Ks, poses, i_train, near_clip):
    """Return a cube around the near-clip points of all training rays.

    The tightest axis-aligned box around ``rays_o + rays_d * near_clip`` is
    found first.  It is then replaced by a cube centred on that box whose
    half-side is the box's largest half-extent multiplied by
    ``cfg.data.unbounded_inner_r``.

    Args:
        cfg: Config object providing the ray flags and ``unbounded_inner_r``.
        HW, Ks, poses: Per-view image sizes, intrinsics and camera-to-world
            matrices, indexable by ``i_train``.
        i_train: Index selecting the training views.
        near_clip: Distance along each ray at which the bounded points are
            taken.  Must not be ``None``.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of tensors describing the cube.

    Raises:
        TypeError: If ``near_clip`` is ``None``.
    """
    if near_clip is None:
        # Previously this surfaced as an opaque "Tensor * NoneType" TypeError
        # from inside the loop; fail early with an actionable message.
        raise TypeError(
            "near_clip must be a number for unbounded scenes, got None "
            "(pass near_clip=... to compute_bbox_by_cam_frustrm)")

    xyz_min = torch.Tensor([np.inf, np.inf, np.inf])
    xyz_max = -xyz_min
    for rays_o, rays_d, _ in _iter_train_rays(cfg, HW, Ks, poses, i_train):
        pts = rays_o + rays_d * near_clip
        xyz_min = torch.minimum(xyz_min, pts.amin((0, 1)))
        xyz_max = torch.maximum(xyz_max, pts.amax((0, 1)))

    # Turn the tight box into a cube scaled by the configured inner radius.
    center = (xyz_min + xyz_max) * 0.5
    radius = (center - xyz_min).max() * cfg.data.unbounded_inner_r
    xyz_min = center - radius
    xyz_max = center + radius
    return xyz_min, xyz_max


def compute_bbox_by_cam_frustrm(args, cfg, HW, Ks, poses, i_train, near, far, **kwargs):
    """Compute a scene bounding box from the training cameras.

    Dispatches on ``cfg.data.unbounded_inward``: the unbounded variant uses
    ``kwargs['near_clip']``, the bounded variant uses ``near`` and ``far``.
    Progress and the resulting corners are printed to stdout.

    Args:
        args: Unused by this function; kept for call-site compatibility.
        cfg: Config object (see the two private helpers for the fields read).
        HW, Ks, poses: Per-view image sizes, intrinsics and camera-to-world
            matrices, indexable by ``i_train``.
        i_train: Index selecting the training views.
        near, far: Ray distances used by the bounded variant.
        **kwargs: ``near_clip`` is read for the unbounded variant; any other
            keyword is ignored.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of tensors.
    """
    print('compute_bbox_by_cam_frustrm: start')
    if cfg.data.unbounded_inward:
        xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_unbounded(
            cfg, HW, Ks, poses, i_train, kwargs.get('near_clip'))
    else:
        xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_bounded(
            cfg, HW, Ks, poses, i_train, near, far)
    print('compute_bbox_by_cam_frustrm: xyz_min', xyz_min)
    print('compute_bbox_by_cam_frustrm: xyz_max', xyz_max)
    print('compute_bbox_by_cam_frustrm: finish')
    return xyz_min, xyz_max


@torch.no_grad()
def compute_bbox_by_coarse_geo(model_class, model_path, thres):
    """Compute a bounding box from the occupied region of a saved model.

    The model is loaded with ``load_model``, a regular grid with
    ``model.world_size`` points per axis is laid out between
    ``model.xyz_min`` and ``model.xyz_max``, and the box enclosing every grid
    point whose activated density exceeds ``thres`` is returned.  Progress,
    the resulting corners and the elapsed time are printed to stdout.

    Args:
        model_class: Forwarded to ``load_model``.
        model_path: Forwarded to ``load_model``.
        thres: Threshold compared (strictly greater-than) against the output
            of ``model.activate_density``.

    Returns:
        Tuple ``(xyz_min, xyz_max)`` of tensors.
    """
    print('compute_bbox_by_coarse_geo: start')
    eps_time = time.time()
    model = load_model(model_class, model_path)
    # Per-axis interpolation weights in [0, 1], stacked on a trailing axis.
    interp = torch.stack(torch.meshgrid(
        torch.linspace(0, 1, model.world_size[0]),
        torch.linspace(0, 1, model.world_size[1]),
        torch.linspace(0, 1, model.world_size[2]),
    ), -1)
    dense_xyz = model.xyz_min * (1 - interp) + model.xyz_max * interp
    density = model.density(dense_xyz)
    alpha = model.activate_density(density)
    mask = (alpha > thres)
    # NOTE(review): if no grid point passes the threshold the selection is
    # empty and the reductions below are expected to raise -- confirm whether
    # callers need a friendlier failure here.
    active_xyz = dense_xyz[mask]
    xyz_min = active_xyz.amin(0)
    xyz_max = active_xyz.amax(0)
    print('compute_bbox_by_coarse_geo: xyz_min', xyz_min)
    print('compute_bbox_by_coarse_geo: xyz_max', xyz_max)
    eps_time = time.time() - eps_time
    print('compute_bbox_by_coarse_geo: finish (eps time:', eps_time, 'secs)')
    return xyz_min, xyz_max