import torch
import torch.nn as nn
from pytorch3d.transforms import quaternion_to_matrix

from sugar.sugar_scene.sugar_model import SuGaR
from sugar.sugar_scene.sugar_optimizer import SuGaROptimizer
from sugar.sugar_utils.general_utils import inverse_sigmoid


class SuGaRDensifier:
    """Wrapper of the densification functions used for Gaussian Splatting and SuGaR optimization.
    Largely inspired by the original implementation of the 3D Gaussian Splatting paper:
    https://github.com/graphdeco-inria/gaussian-splatting
    """
    def __init__(
        self,
        sugar_model: SuGaR,
        sugar_optimizer: SuGaROptimizer,
        max_grad: float = 0.0002,
        min_opacity: float = 0.005,
        max_screen_size: int = 20,
        scene_extent: float = None,
        percent_dense: float = 0.01,
    ) -> None:
        self.model = sugar_model
        self.optimizer = sugar_optimizer.optimizer

        # Running statistics used to decide which Gaussians to densify or prune
        self.points_gradient_accum = torch.zeros((self.model.points.shape[0], 1), device=self.model.device)
        self.denom = torch.zeros((self.model.points.shape[0], 1), device=self.model.device)
        self.max_radii2D = torch.zeros((self.model.points.shape[0]), device=self.model.device)

        self.max_grad = max_grad
        self.min_opacity = min_opacity
        self.max_screen_size = max_screen_size
        if scene_extent is None:
            self.scene_extent = sugar_model.get_cameras_spatial_extent()
        else:
            self.scene_extent = scene_extent
        self.percent_dense = percent_dense

        self.params_to_densify = []
        if not self.model.freeze_gaussians:
            self.params_to_densify.extend(["points", "all_densities", "scales", "quaternions"])
            self.params_to_densify.extend(["sh_coordinates_dc", "sh_coordinates_rest"])

    def _prune_optimizer(self, mask):
        # Keep only the masked entries of every densifiable parameter, together
        # with the corresponding entries of its Adam running moments.
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            name = group["name"]
            if name in self.params_to_densify:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                if stored_state is not None:
                    stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                    stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                    del self.optimizer.state[group['params'][0]]
                    group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                    self.optimizer.state[group['params'][0]] = stored_state

                    optimizable_tensors[group["name"]] = group["params"][0]
                else:
                    group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                    optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        # Remove the Gaussians selected by the mask from the model, the
        # optimizer, and the densification statistics.
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        if "points" in self.params_to_densify:
            self.model._points = optimizable_tensors["points"]
        if "all_densities" in self.params_to_densify:
            self.model.all_densities = optimizable_tensors["all_densities"]
        if "scales" in self.params_to_densify:
            self.model._scales = optimizable_tensors["scales"]
        if "quaternions" in self.params_to_densify:
            self.model._quaternions = optimizable_tensors["quaternions"]
        if "sh_coordinates_dc" in self.params_to_densify:
            self.model._sh_coordinates_dc = optimizable_tensors["sh_coordinates_dc"]
        if "sh_coordinates_rest" in self.params_to_densify:
            self.model._sh_coordinates_rest = optimizable_tensors["sh_coordinates_rest"]

        self.points_gradient_accum = self.points_gradient_accum[valid_points_mask]
        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]
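    # NOTE: Like _prune_optimizer above, cat_tensors_to_optimizer and
    # replace_tensor_to_optimizer below perform "optimizer surgery" in the
    # style of the original 3D Gaussian Splatting implementation: whenever
    # Gaussians are added or removed, the nn.Parameter inside each optimizer
    # param group is rebuilt, and Adam's running moments (exp_avg, exp_avg_sq)
    # are masked or zero-extended in lockstep, so surviving Gaussians keep
    # their optimizer state while newly created ones start from zero moments.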
    def cat_tensors_to_optimizer(self, tensors_dict):
        # Concatenate new values to every densifiable parameter, zero-padding
        # the Adam moments for the newly added entries.
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            name = group["name"]
            if name in self.params_to_densify:
                assert len(group["params"]) == 1
                extension_tensor = tensors_dict[group["name"]]
                stored_state = self.optimizer.state.get(group['params'][0], None)
                dim_to_cat = 0
                if stored_state is not None:
                    stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=dim_to_cat)
                    stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=dim_to_cat)

                    del self.optimizer.state[group['params'][0]]
                    group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=dim_to_cat).requires_grad_(True))
                    self.optimizer.state[group['params'][0]] = stored_state

                    optimizable_tensors[group["name"]] = group["params"][0]
                else:
                    group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=dim_to_cat).requires_grad_(True))
                    optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def replace_tensor_to_optimizer(self, tensor, name):
        # Replace a parameter with a new tensor of the same shape and reset
        # its Adam moments to zero.
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                if stored_state is not None:
                    stored_state["exp_avg"] = torch.zeros_like(tensor)
                    stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                    del self.optimizer.state[group['params'][0]]
                    group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                    self.optimizer.state[group['params'][0]] = stored_state
                else:
                    group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def densification_postfix(
        self, new_points, new_densities, new_scales, new_quaternions,
        new_sh_coordinates_dc=None, new_sh_coordinates_rest=None,
    ):
        tensors_dict = {
            "points": new_points,
            "all_densities": new_densities,
            "scales": new_scales,
            "quaternions": new_quaternions,
            "sh_coordinates_dc": new_sh_coordinates_dc,
            "sh_coordinates_rest": new_sh_coordinates_rest,
        }

        optimizable_tensors = self.cat_tensors_to_optimizer(tensors_dict)
        self.model._points = optimizable_tensors["points"]
        self.model.all_densities = optimizable_tensors["all_densities"]
        self.model._scales = optimizable_tensors["scales"]
        self.model._quaternions = optimizable_tensors["quaternions"]
        self.model._sh_coordinates_dc = optimizable_tensors["sh_coordinates_dc"]
        self.model._sh_coordinates_rest = optimizable_tensors["sh_coordinates_rest"]

        # The number of Gaussians changed, so the densification statistics
        # must be reset.
        self.points_gradient_accum = torch.zeros((self.model.points.shape[0], 1), device=self.model.device)
        self.denom = torch.zeros((self.model.points.shape[0], 1), device=self.model.device)
        self.max_radii2D = torch.zeros((self.model.points.shape[0]), device=self.model.device)

    def update_densification_stats(self, viewspace_point_tensor, radii, visibility_filter):
        # Update the maximum observed 2D radius of every visible Gaussian
        self.max_radii2D[visibility_filter] = torch.max(self.max_radii2D[visibility_filter], radii[visibility_filter])
        # Accumulate the viewspace positional gradient magnitude of every visible Gaussian
        self.points_gradient_accum[visibility_filter] += torch.norm(viewspace_point_tensor.grad[visibility_filter, :2], dim=-1, keepdim=True)
        # Count the number of updates for each point
        self.denom[visibility_filter] += 1
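    # NOTE: Following the 3D Gaussian Splatting densification heuristic,
    # Gaussians whose accumulated viewspace gradient exceeds max_grad are
    # densified in one of two ways: small Gaussians (max scale no larger than
    # percent_dense * scene extent) are cloned in place, while large ones are
    # split into N smaller Gaussians sampled from the original covariance.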
    def densify_and_clone(self, grads, max_grad, extent):
        max_grad = self.max_grad if max_grad is None else max_grad
        extent = self.scene_extent if extent is None else extent

        # Extract small points that satisfy the gradient condition
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= max_grad, True, False)
        selected_pts_mask = torch.logical_and(
            selected_pts_mask,
            torch.max(self.model.scaling, dim=1).values <= self.percent_dense * extent)

        new_points = self.model._points[selected_pts_mask]
        new_densities = self.model.all_densities[selected_pts_mask]
        new_scales = self.model._scales[selected_pts_mask]
        new_quaternions = self.model._quaternions[selected_pts_mask]
        new_sh_coordinates_dc = self.model._sh_coordinates_dc[selected_pts_mask]
        new_sh_coordinates_rest = self.model._sh_coordinates_rest[selected_pts_mask]

        self.densification_postfix(
            new_points=new_points,
            new_densities=new_densities,
            new_scales=new_scales,
            new_quaternions=new_quaternions,
            new_sh_coordinates_dc=new_sh_coordinates_dc,
            new_sh_coordinates_rest=new_sh_coordinates_rest,
        )

    def densify_and_split(self, grads, max_grad, extent, N=2):
        max_grad = self.max_grad if max_grad is None else max_grad
        extent = self.scene_extent if extent is None else extent
        n_init_points = self.model._points.shape[0]

        # Extract large points that satisfy the gradient condition
        padded_grad = torch.zeros((n_init_points), device=self.model.device)
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= max_grad, True, False)
        selected_pts_mask = torch.logical_and(
            selected_pts_mask,
            torch.max(self.model.scaling, dim=1).values > self.percent_dense * extent)

        # Sample N new centers per selected Gaussian from its covariance
        stds = self.model.scaling[selected_pts_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device=self.model.device)
        samples = torch.normal(mean=means, std=stds)
        rots = quaternion_to_matrix(self.model.quaternions[selected_pts_mask]).repeat(N, 1, 1)

        new_points = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.model.points[selected_pts_mask].repeat(N, 1)
        new_scales = self.model.scale_inverse_activation(self.model.scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N))
        new_quaternions = self.model._quaternions[selected_pts_mask].repeat(N, 1)
        new_densities = self.model.all_densities[selected_pts_mask].repeat(N, 1)
        new_sh_coordinates_dc = self.model._sh_coordinates_dc[selected_pts_mask].repeat(N, 1, 1)
        new_sh_coordinates_rest = self.model._sh_coordinates_rest[selected_pts_mask].repeat(N, 1, 1)

        self.densification_postfix(
            new_points=new_points,
            new_densities=new_densities,
            new_scales=new_scales,
            new_quaternions=new_quaternions,
            new_sh_coordinates_dc=new_sh_coordinates_dc,
            new_sh_coordinates_rest=new_sh_coordinates_rest,
        )

        # Remove the original (too large) Gaussians that were just split
        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=self.model.device, dtype=torch.bool)))
        self.prune_points(prune_filter)

    def densify_and_prune(self, max_grad: float = None, min_opacity: float = None,
                          extent: float = None, max_screen_size: int = None):
        max_grad = self.max_grad if max_grad is None else max_grad
        min_opacity = self.min_opacity if min_opacity is None else min_opacity
        extent = self.scene_extent if extent is None else extent

        grads = self.points_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        # Prune transparent Gaussians, as well as Gaussians that are too large
        # in view space or in world space
        prune_mask = (self.model.strengths < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.model.scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def reset_opacity(self):
        # Clamp all opacities to at most 0.01 (in density space)
        opacities_new = inverse_sigmoid(torch.min(
            self.model.strengths,
            torch.ones_like(self.model.all_densities.view(-1, 1)) * 0.01))
        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "all_densities")
        self.model.all_densities = optimizable_tensors["all_densities"]
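
# Example usage (a minimal sketch, not part of the original module): the loop
# below assumes a 3DGS-style training setup in which the renderer output gives
# access to the viewspace point tensor, the visibility filter, and the
# per-Gaussian 2D radii. `render`, `compute_loss`, `camera`, `gt_image`, and
# the two interval variables are hypothetical placeholders, not actual SuGaR
# functions; only the SuGaRDensifier calls are taken from this module.
#
#     densifier = SuGaRDensifier(sugar_model=sugar, sugar_optimizer=sugar_optimizer)
#     for iteration in range(n_iterations):
#         outputs = render(sugar, camera)          # hypothetical renderer call
#         loss = compute_loss(outputs, gt_image)   # hypothetical loss
#         loss.backward()
#         with torch.no_grad():
#             # Viewspace gradients are only available after backward()
#             densifier.update_densification_stats(
#                 outputs["viewspace_points"], outputs["radii"], outputs["visibility_filter"])
#             if iteration % densification_interval == 0:
#                 densifier.densify_and_prune()
#             if iteration % opacity_reset_interval == 0:
#                 densifier.reset_opacity()
#         densifier.optimizer.step()      # the wrapped torch optimizer
#         densifier.optimizer.zero_grad()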