# NOTE(review): the lines "Spaces:" / "Runtime error" that preceded this file were
# hosting-log paste residue, not Python source; replaced with this comment so the
# file can parse.
# Standard library
import os
import time

# Third-party
import numpy as np
import torch
import open3d as o3d
from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
from pytorch3d.transforms import quaternion_apply, quaternion_invert
from rich.console import Console

# Project-local
from sugar.sugar_scene.gs_model import GaussianSplattingWrapper, fetchPly
from sugar.sugar_scene.sugar_model import SuGaR
from sugar.sugar_scene.sugar_optimizer import OptimizationParams, SuGaROptimizer
from sugar.sugar_scene.sugar_densifier import SuGaRDensifier
from sugar.sugar_utils.loss_utils import ssim, l1_loss, l2_loss
def coarse_training_with_density_regularization(args):
    """Run the coarse SuGaR optimization stage with density-based SDF regularization.

    Loads a pretrained 3D Gaussian Splatting checkpoint (or, if configured, an SfM
    point cloud), builds a SuGaR model, and optimizes it with a rendering loss plus
    optional entropy, SDF-estimation and SDF-normal regularizers. Checkpoints are
    written under ``args.output_dir``.

    Args:
        args: Parsed CLI namespace. Fields read: ``gpu``, ``scene_path``,
            ``checkpoint_path``, ``iteration_to_load``, ``output_dir``,
            ``estimation_factor``, ``normal_factor``, ``eval``.

    Returns:
        str: Path of the final saved coarse model checkpoint (``<iteration>.pt``).
    """
    CONSOLE = Console(width=120)

    # ====================Parameters====================

    num_device = args.gpu
    detect_anomaly = False

    # -----Data parameters-----
    downscale_resolution_factor = 1  # 2, 4

    # -----Model parameters-----
    use_eval_split = True
    n_skip_images_for_eval_split = 8

    freeze_gaussians = False
    initialize_from_trained_3dgs = True  # True or False
    if initialize_from_trained_3dgs:
        prune_at_start = False
        start_pruning_threshold = 0.5
    no_rendering = freeze_gaussians

    n_points_at_start = None  # If None, takes all points in the SfM point cloud

    learnable_positions = True  # True in 3DGS
    use_same_scale_in_all_directions = False  # Should be False
    sh_levels = 4

    # -----Radiance Mesh-----
    triangle_scale = 1.

    # -----Rendering parameters-----
    compute_color_in_rasterizer = False  # TODO: Try True

    # -----Optimization parameters-----

    # Learning rates and scheduling
    num_iterations = 15_000  # Changed
    spatial_lr_scale = None
    position_lr_init = 0.00016
    position_lr_final = 0.0000016
    position_lr_delay_mult = 0.01
    position_lr_max_steps = 30_000
    feature_lr = 0.0025
    opacity_lr = 0.05
    scaling_lr = 0.005
    rotation_lr = 0.001

    # Densifier and pruning
    heavy_densification = False
    if initialize_from_trained_3dgs:
        # Effectively disables densification when starting from a trained 3DGS.
        densify_from_iter = 500 + 99999  # 500  # Maybe reduce this, since we have a better initialization?
        densify_until_iter = 7000 - 7000  # 7000
    else:
        densify_from_iter = 500  # 500  # Maybe reduce this, since we have a better initialization?
        densify_until_iter = 7000  # 7000

    if heavy_densification:
        densification_interval = 50  # 100
        opacity_reset_interval = 3000  # 3000
        densify_grad_threshold = 0.0001  # 0.0002
        densify_screen_size_threshold = 20
        prune_opacity_threshold = 0.005
        densification_percent_distinction = 0.01
    else:
        densification_interval = 100  # 100
        opacity_reset_interval = 3000  # 3000
        densify_grad_threshold = 0.0002  # 0.0002
        densify_screen_size_threshold = 20
        prune_opacity_threshold = 0.005
        densification_percent_distinction = 0.01

    # Data processing and batching
    n_images_to_use_for_training = -1  # If -1, uses all images
    train_num_images_per_batch = 1  # 1 for full images

    # Loss functions
    loss_function = 'l1+dssim'  # 'l1' or 'l2' or 'l1+dssim'
    if loss_function == 'l1+dssim':
        dssim_factor = 0.2

    # Regularization
    enforce_entropy_regularization = True
    if enforce_entropy_regularization:
        start_entropy_regularization_from = 7000
        end_entropy_regularization_at = 9000  # TODO: Change
        entropy_regularization_factor = 0.1

    regularize_sdf = True
    if regularize_sdf:
        beta_mode = 'average'  # 'learnable', 'average' or 'weighted_average'

        start_sdf_regularization_from = 9000
        regularize_sdf_only_for_gaussians_with_high_opacity = False
        if regularize_sdf_only_for_gaussians_with_high_opacity:
            sdf_regularization_opacity_threshold = 0.5

        use_sdf_estimation_loss = True
        enforce_samples_to_be_on_surface = False
        if use_sdf_estimation_loss or enforce_samples_to_be_on_surface:
            sdf_estimation_mode = 'density'  # 'sdf' or 'density'
            # sdf_estimation_factor = 0.2  # 0.1 or 0.2?
            samples_on_surface_factor = 0.2  # 0.05
            squared_sdf_estimation_loss = False
            squared_samples_on_surface_loss = False

            normalize_by_sdf_std = False  # False

            start_sdf_estimation_from = 9000  # 7000

            sample_only_in_gaussians_close_to_surface = True
            close_gaussian_threshold = 2.  # 2.

            use_projection_as_estimation = True
            if use_projection_as_estimation:
                # Projection-based estimation does not need the depth-map proximity test.
                sample_only_in_gaussians_close_to_surface = False

            backpropagate_gradients_through_depth = True  # True

        use_sdf_better_normal_loss = True
        if use_sdf_better_normal_loss:
            start_sdf_better_normal_from = 9000
            # sdf_better_normal_factor = 0.2  # 0.1 or 0.2?
            sdf_better_normal_gradient_through_normal_only = True

        density_factor = 1. / 16.  # Should be equal to 1. / regularity_knn
        if (use_sdf_estimation_loss or enforce_samples_to_be_on_surface) and sdf_estimation_mode == 'density':
            density_factor = 1.
        density_threshold = 1.  # 0.5 * density_factor
        n_samples_for_sdf_regularization = 1_000_000  # 300_000
        sdf_sampling_scale_factor = 1.5
        sdf_sampling_proportional_to_volume = False

    bind_to_surface_mesh = False
    if bind_to_surface_mesh:
        learn_surface_mesh_positions = True
        learn_surface_mesh_opacity = True
        learn_surface_mesh_scales = True
        n_gaussians_per_surface_triangle = 6  # 1, 3, 4 or 6

        use_surface_mesh_laplacian_smoothing_loss = True
        if use_surface_mesh_laplacian_smoothing_loss:
            surface_mesh_laplacian_smoothing_method = "uniform"  # "cotcurv", "cot", "uniform"
            surface_mesh_laplacian_smoothing_factor = 5.  # 0.1

        use_surface_mesh_normal_consistency_loss = True
        if use_surface_mesh_normal_consistency_loss:
            surface_mesh_normal_consistency_factor = 0.1  # 0.1

        # When bound to a mesh, densification is disabled and position/scale
        # learning rates are strongly reduced.
        densify_from_iter = 999_999
        densify_until_iter = 0
        position_lr_init = 0.00016 * 0.01
        position_lr_final = 0.0000016 * 0.01
        scaling_lr = 0.005
    else:
        surface_mesh_to_bind_path = None

    if regularize_sdf:
        regularize = True
        regularity_knn = 16  # 8 until now
        # regularity_knn = 8
        regularity_samples = -1  # Retry with 1000, 10000
        reset_neighbors_every = 500  # 500 until now
        regularize_from = 7000  # 0 until now
        start_reset_neighbors_from = 7000 + 1  # 0 until now (should be equal to regularize_from + 1?)
        prune_when_starting_regularization = False
    else:
        regularize = False
        regularity_knn = 0
    if bind_to_surface_mesh:
        regularize = False
        regularity_knn = 0

    # Opacity management
    prune_low_opacity_gaussians_at = [9000]
    if bind_to_surface_mesh:
        prune_low_opacity_gaussians_at = [999_999]
    prune_hard_opacity_threshold = 0.5

    # Warmup
    do_resolution_warmup = False
    if do_resolution_warmup:
        resolution_warmup_every = 500
        current_resolution_factor = downscale_resolution_factor * 4.
    else:
        current_resolution_factor = downscale_resolution_factor

    do_sh_warmup = True  # Should be True
    if initialize_from_trained_3dgs:
        do_sh_warmup = False
        sh_levels = 4  # nerfmodel.gaussians.active_sh_degree + 1
        CONSOLE.print("Changing sh_levels to match the loaded model:", sh_levels)
    if do_sh_warmup:
        sh_warmup_every = 1000
        current_sh_levels = 1
    else:
        current_sh_levels = sh_levels

    # -----Log and save-----
    print_loss_every_n_iterations = 50
    save_model_every_n_iterations = 1_000_000
    save_milestones = [9000, 12_000, 15_000]

    # ====================End of parameters====================

    if args.output_dir is None:
        if len(args.scene_path.split("/")[-1]) > 0:
            args.output_dir = os.path.join("./output/coarse", args.scene_path.split("/")[-1])
        else:
            args.output_dir = os.path.join("./output/coarse", args.scene_path.split("/")[-2])

    source_path = args.scene_path
    gs_checkpoint_path = args.checkpoint_path
    iteration_to_load = args.iteration_to_load

    sdf_estimation_factor = args.estimation_factor
    sdf_better_normal_factor = args.normal_factor

    # Encode the regularization factors into the checkpoint directory name
    # (dots stripped, e.g. 0.2 -> "02").
    sugar_checkpoint_path = f'sugarcoarse_3Dgs{iteration_to_load}_densityestimXX_sdfnormYY/'
    sugar_checkpoint_path = os.path.join(args.output_dir, sugar_checkpoint_path)
    sugar_checkpoint_path = sugar_checkpoint_path.replace(
        'XX', str(sdf_estimation_factor).replace('.', '')
    ).replace(
        'YY', str(sdf_better_normal_factor).replace('.', '')
    )

    use_eval_split = args.eval

    ply_path = os.path.join(source_path, "sparse/0/points3D.ply")

    CONSOLE.print("-----Parsed parameters-----")
    CONSOLE.print("Source path:", source_path)
    CONSOLE.print(" > Content:", len(os.listdir(source_path)))
    CONSOLE.print("Gaussian Splatting checkpoint path:", gs_checkpoint_path)
    CONSOLE.print(" > Content:", len(os.listdir(gs_checkpoint_path)))
    CONSOLE.print("SUGAR checkpoint path:", sugar_checkpoint_path)
    CONSOLE.print("Iteration to load:", iteration_to_load)
    CONSOLE.print("Output directory:", args.output_dir)
    CONSOLE.print("SDF estimation factor:", sdf_estimation_factor)
    CONSOLE.print("SDF better normal factor:", sdf_better_normal_factor)
    CONSOLE.print("Eval split:", use_eval_split)
    CONSOLE.print("---------------------------")

    # Setup device
    torch.cuda.set_device(num_device)
    CONSOLE.print("Using device:", num_device)
    device = torch.device(f'cuda:{num_device}')
    CONSOLE.print(torch.cuda.memory_summary())

    torch.autograd.set_detect_anomaly(detect_anomaly)

    # Creates save directory if it does not exist
    os.makedirs(sugar_checkpoint_path, exist_ok=True)

    # ====================Load NeRF model and training data====================

    # Load Gaussian Splatting checkpoint
    CONSOLE.print(f"\nLoading config {gs_checkpoint_path}...")
    if use_eval_split:
        CONSOLE.print("Performing train/eval split...")
    nerfmodel = GaussianSplattingWrapper(
        source_path=source_path,
        output_path=gs_checkpoint_path,
        iteration_to_load=iteration_to_load,
        load_gt_images=True,
        eval_split=use_eval_split,
        eval_split_interval=n_skip_images_for_eval_split,
    )

    CONSOLE.print(f'{len(nerfmodel.training_cameras)} training images detected.')
    CONSOLE.print(f'The model has been trained for {iteration_to_load} steps.')

    if downscale_resolution_factor != 1:
        nerfmodel.downscale_output_resolution(downscale_resolution_factor)
    CONSOLE.print(f'\nCamera resolution scaled to '
                  f'{nerfmodel.training_cameras.gs_cameras[0].image_height} x '
                  f'{nerfmodel.training_cameras.gs_cameras[0].image_width}'
                  )

    # Point cloud
    if initialize_from_trained_3dgs:
        with torch.no_grad():
            print("Initializing model from trained 3DGS...")
            with torch.no_grad():
                # Number of SH levels is recovered from the loaded feature tensor
                # (features per channel = sh_levels ** 2).
                sh_levels = int(np.sqrt(nerfmodel.gaussians.get_features.shape[1]))

            from sugar.sugar_utils.spherical_harmonics import SH2RGB
            points = nerfmodel.gaussians.get_xyz.detach().float().cuda()
            colors = SH2RGB(nerfmodel.gaussians.get_features[:, 0].detach().float().cuda())

            if prune_at_start:
                with torch.no_grad():
                    start_prune_mask = nerfmodel.gaussians.get_opacity.view(-1) > start_pruning_threshold
                    points = points[start_prune_mask]
                    colors = colors[start_prune_mask]
            n_points = len(points)
    else:
        CONSOLE.print("\nLoading SfM point cloud...")
        pcd = fetchPly(ply_path)
        points = torch.tensor(pcd.points, device=nerfmodel.device).float().cuda()
        colors = torch.tensor(pcd.colors, device=nerfmodel.device).float().cuda()

        if n_points_at_start is not None:
            # Randomly subsample the SfM cloud to the requested size.
            n_points = n_points_at_start
            pts_idx = torch.randperm(len(points))[:n_points]
            points, colors = points.to(device)[pts_idx], colors.to(device)[pts_idx]
        else:
            n_points = len(points)
    CONSOLE.print(f"Point cloud generated. Number of points: {len(points)}")

    # Mesh to bind to if needed  TODO
    if bind_to_surface_mesh:
        surface_mesh_to_bind_full_path = os.path.join('./results/meshes/', surface_mesh_to_bind_path)
        CONSOLE.print(f'\nLoading mesh to bind to: {surface_mesh_to_bind_full_path}...')
        o3d_mesh = o3d.io.read_triangle_mesh(surface_mesh_to_bind_full_path)
        CONSOLE.print("Mesh to bind to loaded.")
    else:
        o3d_mesh = None
        learn_surface_mesh_positions = False
        learn_surface_mesh_opacity = False
        learn_surface_mesh_scales = False
        n_gaussians_per_surface_triangle = 1

    if not regularize_sdf:
        beta_mode = None

    # ====================Initialize SuGaR model====================
    # Construct SuGaR model
    sugar = SuGaR(
        nerfmodel=nerfmodel,
        points=points,  # nerfmodel.gaussians.get_xyz.data,
        colors=colors,  # 0.5 + _C0 * nerfmodel.gaussians.get_features.data[:, 0, :],
        initialize=True,
        sh_levels=sh_levels,
        learnable_positions=learnable_positions,
        triangle_scale=triangle_scale,
        keep_track_of_knn=regularize,
        knn_to_track=regularity_knn,
        beta_mode=beta_mode,
        freeze_gaussians=freeze_gaussians,
        surface_mesh_to_bind=o3d_mesh,
        surface_mesh_thickness=None,
        learn_surface_mesh_positions=learn_surface_mesh_positions,
        learn_surface_mesh_opacity=learn_surface_mesh_opacity,
        learn_surface_mesh_scales=learn_surface_mesh_scales,
        n_gaussians_per_surface_triangle=n_gaussians_per_surface_triangle,
    )

    if initialize_from_trained_3dgs:
        # Copy the raw (pre-activation) parameters of the trained 3DGS model
        # into the freshly constructed SuGaR model.
        with torch.no_grad():
            CONSOLE.print("Initializing 3D gaussians from 3D gaussians...")
            if prune_at_start:
                sugar._scales[...] = nerfmodel.gaussians._scaling.detach()[start_prune_mask]
                sugar._quaternions[...] = nerfmodel.gaussians._rotation.detach()[start_prune_mask]
                sugar.all_densities[...] = nerfmodel.gaussians._opacity.detach()[start_prune_mask]
                sugar._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()[start_prune_mask]
                sugar._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()[start_prune_mask]
            else:
                sugar._scales[...] = nerfmodel.gaussians._scaling.detach()
                sugar._quaternions[...] = nerfmodel.gaussians._rotation.detach()
                sugar.all_densities[...] = nerfmodel.gaussians._opacity.detach()
                sugar._sh_coordinates_dc[...] = nerfmodel.gaussians._features_dc.detach()
                sugar._sh_coordinates_rest[...] = nerfmodel.gaussians._features_rest.detach()

    CONSOLE.print(f'\nSuGaR model has been initialized.')
    CONSOLE.print(sugar)
    CONSOLE.print(f'Number of parameters: {sum(p.numel() for p in sugar.parameters() if p.requires_grad)}')
    CONSOLE.print(f'Checkpoints will be saved in {sugar_checkpoint_path}')

    CONSOLE.print("\nModel parameters:")
    for name, param in sugar.named_parameters():
        CONSOLE.print(name, param.shape, param.requires_grad)

    torch.cuda.empty_cache()

    # Compute scene extent
    cameras_spatial_extent = sugar.get_cameras_spatial_extent()

    # ====================Initialize optimizer====================
    if spatial_lr_scale is None:
        spatial_lr_scale = cameras_spatial_extent
        print("Using camera spatial extent as spatial_lr_scale:", spatial_lr_scale)

    opt_params = OptimizationParams(
        iterations=num_iterations,
        position_lr_init=position_lr_init,
        position_lr_final=position_lr_final,
        position_lr_delay_mult=position_lr_delay_mult,
        position_lr_max_steps=position_lr_max_steps,
        feature_lr=feature_lr,
        opacity_lr=opacity_lr,
        scaling_lr=scaling_lr,
        rotation_lr=rotation_lr,
    )
    optimizer = SuGaROptimizer(sugar, opt_params, spatial_lr_scale=spatial_lr_scale)
    CONSOLE.print("Optimizer initialized.")
    CONSOLE.print("Optimization parameters:")
    CONSOLE.print(opt_params)

    CONSOLE.print("Optimizable parameters:")
    for param_group in optimizer.optimizer.param_groups:
        CONSOLE.print(param_group['name'], param_group['lr'])

    # ====================Initialize densifier====================
    gaussian_densifier = SuGaRDensifier(
        sugar_model=sugar,
        sugar_optimizer=optimizer,
        max_grad=densify_grad_threshold,
        min_opacity=prune_opacity_threshold,
        max_screen_size=densify_screen_size_threshold,
        scene_extent=cameras_spatial_extent,
        percent_dense=densification_percent_distinction,
    )
    CONSOLE.print("Densifier initialized.")

    # ====================Loss function====================
    if loss_function == 'l1':
        loss_fn = l1_loss
    elif loss_function == 'l2':
        loss_fn = l2_loss
    elif loss_function == 'l1+dssim':
        def loss_fn(pred_rgb, gt_rgb):
            return (1.0 - dssim_factor) * l1_loss(pred_rgb, gt_rgb) + dssim_factor * (1.0 - ssim(pred_rgb, gt_rgb))
    CONSOLE.print(f'Using loss function: {loss_function}')

    # ====================Start training====================
    sugar.train()
    epoch = 0
    iteration = 0
    train_losses = []
    t0 = time.time()

    if initialize_from_trained_3dgs:
        # Skip the schedule to just before the 7000-step milestone, since the
        # loaded 3DGS model was already trained that far.
        iteration = 7000 - 1

    for batch in range(9_999_999):
        if iteration >= num_iterations:
            break

        # Shuffle images
        shuffled_idx = torch.randperm(len(nerfmodel.training_cameras))
        train_num_images = len(shuffled_idx)

        # We iterate on images
        for i in range(0, train_num_images, train_num_images_per_batch):
            iteration += 1

            # Update learning rates
            optimizer.update_learning_rate(iteration)

            # Prune low-opacity gaussians for optimizing triangles
            if (
                regularize and prune_when_starting_regularization and iteration == regularize_from + 1
            ) or (
                (iteration - 1) in prune_low_opacity_gaussians_at
            ):
                CONSOLE.print("\nPruning gaussians with low-opacity for further optimization...")
                prune_mask = (gaussian_densifier.model.strengths < prune_hard_opacity_threshold).squeeze()
                gaussian_densifier.prune_points(prune_mask)
                CONSOLE.print(f'Pruning finished: {sugar.n_points} gaussians left.')
                if regularize and iteration >= start_reset_neighbors_from:
                    sugar.reset_neighbors()

            start_idx = i
            end_idx = min(i + train_num_images_per_batch, train_num_images)
            camera_indices = shuffled_idx[start_idx:end_idx]

            # Computing rgb predictions
            if not no_rendering:
                outputs = sugar.render_image_gaussian_rasterizer(
                    camera_indices=camera_indices.item(),
                    verbose=False,
                    bg_color=None,
                    sh_deg=current_sh_levels - 1,
                    sh_rotations=None,
                    compute_color_in_rasterizer=compute_color_in_rasterizer,
                    compute_covariance_in_rasterizer=True,
                    return_2d_radii=True,
                    quaternions=None,
                    use_same_scale_in_all_directions=use_same_scale_in_all_directions,
                    return_opacities=enforce_entropy_regularization,
                )
                pred_rgb = outputs['image'].view(-1,
                                                 sugar.image_height,
                                                 sugar.image_width,
                                                 3)
                radii = outputs['radii']
                viewspace_points = outputs['viewspace_points']
                if enforce_entropy_regularization:
                    opacities = outputs['opacities']

                pred_rgb = pred_rgb.transpose(-1, -2).transpose(-2, -3)  # TODO: Change for torch.permute

                # Gather rgb ground truth
                gt_image = nerfmodel.get_gt_image(camera_indices=camera_indices)
                gt_rgb = gt_image.view(-1, sugar.image_height, sugar.image_width, 3)
                gt_rgb = gt_rgb.transpose(-1, -2).transpose(-2, -3)

                # Compute loss
                loss = loss_fn(pred_rgb, gt_rgb)

                if enforce_entropy_regularization and iteration > start_entropy_regularization_from and iteration < end_entropy_regularization_at:
                    if iteration == start_entropy_regularization_from + 1:
                        CONSOLE.print("\n---INFO---\nStarting entropy regularization.")
                    if iteration == end_entropy_regularization_at - 1:
                        CONSOLE.print("\n---INFO---\nStopping entropy regularization.")
                    visibility_filter = radii > 0
                    if visibility_filter is not None:
                        vis_opacities = opacities[visibility_filter]
                    else:
                        vis_opacities = opacities
                    # Binary entropy of visible opacities: pushes opacities
                    # towards 0 or 1.
                    loss = loss + entropy_regularization_factor * (
                        - vis_opacities * torch.log(vis_opacities + 1e-10)
                        - (1 - vis_opacities) * torch.log(1 - vis_opacities + 1e-10)
                    ).mean()

                if regularize:
                    if iteration == regularize_from:
                        CONSOLE.print("Starting regularization...")
                        # sugar.reset_neighbors()
                    if iteration > regularize_from:
                        visibility_filter = radii > 0

                        if (iteration >= start_reset_neighbors_from) and ((iteration == regularize_from + 1) or (iteration % reset_neighbors_every == 0)):
                            CONSOLE.print("\n---INFO---\nResetting neighbors...")
                            sugar.reset_neighbors()

                        neighbor_idx = sugar.get_neighbors_of_random_points(num_samples=regularity_samples,)
                        if visibility_filter is not None:
                            neighbor_idx = neighbor_idx[visibility_filter]  # TODO: Error here

                        if regularize_sdf and iteration > start_sdf_regularization_from:
                            if iteration == start_sdf_regularization_from + 1:
                                CONSOLE.print("\n---INFO---\nStarting SDF regularization.")

                            sampling_mask = visibility_filter

                            if (use_sdf_estimation_loss or enforce_samples_to_be_on_surface) and iteration > start_sdf_estimation_from:
                                if iteration == start_sdf_estimation_from + 1:
                                    CONSOLE.print("\n---INFO---\nStarting SDF estimation loss.")
                                fov_camera = nerfmodel.training_cameras.p3d_cameras[camera_indices.item()]

                                if use_projection_as_estimation:
                                    # SDF estimated directly from gaussian normals
                                    # later on; no depth map needed.
                                    pass
                                else:
                                    # Render a depth map using gaussian splatting
                                    if backpropagate_gradients_through_depth:
                                        point_depth = fov_camera.get_world_to_view_transform().transform_points(sugar.points)[..., 2:].expand(-1, 3)
                                        max_depth = point_depth.max()
                                        depth = sugar.render_image_gaussian_rasterizer(
                                            camera_indices=camera_indices.item(),
                                            bg_color=max_depth + torch.zeros(3, dtype=torch.float, device=sugar.device),
                                            sh_deg=0,
                                            compute_color_in_rasterizer=False,  # compute_color_in_rasterizer,
                                            compute_covariance_in_rasterizer=True,
                                            return_2d_radii=False,
                                            use_same_scale_in_all_directions=False,
                                            point_colors=point_depth,
                                        )[..., 0]
                                    else:
                                        with torch.no_grad():
                                            point_depth = fov_camera.get_world_to_view_transform().transform_points(sugar.points)[..., 2:].expand(-1, 3)
                                            max_depth = point_depth.max()
                                            depth = sugar.render_image_gaussian_rasterizer(
                                                camera_indices=camera_indices.item(),
                                                bg_color=max_depth + torch.zeros(3, dtype=torch.float, device=sugar.device),
                                                sh_deg=0,
                                                compute_color_in_rasterizer=False,  # compute_color_in_rasterizer,
                                                compute_covariance_in_rasterizer=True,
                                                return_2d_radii=False,
                                                use_same_scale_in_all_directions=False,
                                                point_colors=point_depth,
                                            )[..., 0]

                                    # If needed, compute which gaussians are close to the surface in the depth map.
                                    # Then, we sample points only in these gaussians.
                                    # TODO: Compute projections only for gaussians in visibility filter.
                                    # TODO: Is the torch.no_grad() a good idea?
                                    if sample_only_in_gaussians_close_to_surface:
                                        with torch.no_grad():
                                            gaussian_to_camera = torch.nn.functional.normalize(fov_camera.get_camera_center() - sugar.points, dim=-1)
                                            gaussian_centers_in_camera_space = fov_camera.get_world_to_view_transform().transform_points(sugar.points)
                                            gaussian_centers_z = gaussian_centers_in_camera_space[..., 2] + 0.
                                            gaussian_centers_map_z = sugar.get_points_depth_in_depth_map(fov_camera, depth, gaussian_centers_in_camera_space)
                                            gaussian_standard_deviations = (
                                                sugar.scaling * quaternion_apply(quaternion_invert(sugar.quaternions), gaussian_to_camera)
                                            ).norm(dim=-1)
                                            gaussians_close_to_surface = (gaussian_centers_map_z - gaussian_centers_z).abs() < close_gaussian_threshold * gaussian_standard_deviations
                                            sampling_mask = sampling_mask * gaussians_close_to_surface

                            n_gaussians_in_sampling = sampling_mask.sum()
                            if n_gaussians_in_sampling > 0:
                                sdf_samples, sdf_gaussian_idx = sugar.sample_points_in_gaussians(
                                    num_samples=n_samples_for_sdf_regularization,
                                    sampling_scale_factor=sdf_sampling_scale_factor,
                                    mask=sampling_mask,
                                    probabilities_proportional_to_volume=sdf_sampling_proportional_to_volume,
                                )

                                if use_sdf_estimation_loss or use_sdf_better_normal_loss:
                                    fields = sugar.get_field_values(
                                        sdf_samples, sdf_gaussian_idx,
                                        return_sdf=(use_sdf_estimation_loss or enforce_samples_to_be_on_surface) and (sdf_estimation_mode == 'sdf') and iteration > start_sdf_estimation_from,
                                        density_threshold=density_threshold, density_factor=density_factor,
                                        return_sdf_grad=False, sdf_grad_max_value=10.,
                                        return_closest_gaussian_opacities=use_sdf_better_normal_loss and iteration > start_sdf_better_normal_from,
                                        return_beta=(use_sdf_estimation_loss or enforce_samples_to_be_on_surface) and (sdf_estimation_mode == 'density') and iteration > start_sdf_estimation_from,
                                    )

                                if (use_sdf_estimation_loss or enforce_samples_to_be_on_surface) and iteration > start_sdf_estimation_from:
                                    # Compute the depth of the points in the gaussians
                                    if use_projection_as_estimation:
                                        proj_mask = torch.ones_like(sdf_samples[..., 0], dtype=torch.bool)
                                        samples_gaussian_normals = sugar.get_normals(estimate_from_points=False)[sdf_gaussian_idx]
                                        sdf_estimation = ((sdf_samples - sugar.points[sdf_gaussian_idx]) * samples_gaussian_normals).sum(dim=-1)  # Shape is (n_samples,)
                                    else:
                                        sdf_samples_in_camera_space = fov_camera.get_world_to_view_transform().transform_points(sdf_samples)
                                        sdf_samples_z = sdf_samples_in_camera_space[..., 2] + 0.
                                        proj_mask = sdf_samples_z > fov_camera.znear
                                        sdf_samples_map_z = sugar.get_points_depth_in_depth_map(fov_camera, depth, sdf_samples_in_camera_space[proj_mask])
                                        sdf_estimation = sdf_samples_map_z - sdf_samples_z[proj_mask]

                                    if not sample_only_in_gaussians_close_to_surface:
                                        if normalize_by_sdf_std:
                                            # gaussian_standard_deviations is only computed when sampling
                                            # close to the surface, so normalization must be disabled here.
                                            print("Setting normalize_by_sdf_std to False because sample_only_in_gaussians_close_to_surface is False.")
                                            normalize_by_sdf_std = False

                                    with torch.no_grad():
                                        if normalize_by_sdf_std:
                                            sdf_sample_std = gaussian_standard_deviations[sdf_gaussian_idx][proj_mask]
                                        else:
                                            sdf_sample_std = sugar.get_cameras_spatial_extent() / 10.

                                    if use_sdf_estimation_loss:
                                        if sdf_estimation_mode == 'sdf':
                                            sdf_values = fields['sdf'][proj_mask]
                                            if squared_sdf_estimation_loss:
                                                sdf_estimation_loss = ((sdf_values - sdf_estimation.abs()) / sdf_sample_std).pow(2)
                                            else:
                                                sdf_estimation_loss = (sdf_values - sdf_estimation.abs()).abs() / sdf_sample_std
                                            loss = loss + sdf_estimation_factor * sdf_estimation_loss.clamp(max=10. * sugar.get_cameras_spatial_extent()).mean()
                                        elif sdf_estimation_mode == 'density':
                                            beta = fields['beta'][proj_mask]
                                            densities = fields['density'][proj_mask]
                                            # Ideal density if the sample's signed distance were exact.
                                            target_densities = torch.exp(-0.5 * sdf_estimation.pow(2) / beta.pow(2))
                                            if squared_sdf_estimation_loss:
                                                sdf_estimation_loss = ((densities - target_densities)).pow(2)
                                            else:
                                                sdf_estimation_loss = (densities - target_densities).abs()
                                            loss = loss + sdf_estimation_factor * sdf_estimation_loss.mean()
                                        else:
                                            raise ValueError(f"Unknown sdf_estimation_mode: {sdf_estimation_mode}")

                                    if enforce_samples_to_be_on_surface:
                                        if squared_samples_on_surface_loss:
                                            samples_on_surface_loss = (sdf_estimation / sdf_sample_std).pow(2)
                                        else:
                                            samples_on_surface_loss = sdf_estimation.abs() / sdf_sample_std
                                        loss = loss + samples_on_surface_factor * samples_on_surface_loss.clamp(max=10. * sugar.get_cameras_spatial_extent()).mean()

                                if use_sdf_better_normal_loss and (iteration > start_sdf_better_normal_from):
                                    if iteration == start_sdf_better_normal_from + 1:
                                        CONSOLE.print("\n---INFO---\nStarting SDF better normal loss.")
                                    closest_gaussians_idx = sugar.knn_idx[sdf_gaussian_idx]
                                    # Compute minimum scaling
                                    closest_min_scaling = sugar.scaling.min(dim=-1)[0][closest_gaussians_idx].detach().view(len(sdf_samples), -1)

                                    # Compute normals and flip their sign if needed
                                    closest_gaussian_normals = sugar.get_normals(estimate_from_points=False)[closest_gaussians_idx]
                                    samples_gaussian_normals = sugar.get_normals(estimate_from_points=False)[sdf_gaussian_idx]
                                    closest_gaussian_normals = closest_gaussian_normals * torch.sign(
                                        (closest_gaussian_normals * samples_gaussian_normals[:, None]).sum(dim=-1, keepdim=True)
                                    ).detach()

                                    # Compute weights for normal regularization, based on the gradient of the sdf
                                    closest_gaussian_opacities = fields['closest_gaussian_opacities'].detach()  # Shape is (n_samples, n_neighbors)
                                    normal_weights = ((sdf_samples[:, None] - sugar.points[closest_gaussians_idx]) * closest_gaussian_normals).sum(dim=-1).abs()  # Shape is (n_samples, n_neighbors)
                                    if sdf_better_normal_gradient_through_normal_only:
                                        normal_weights = normal_weights.detach()
                                    normal_weights = closest_gaussian_opacities * normal_weights / closest_min_scaling.clamp(min=1e-6)**2  # Shape is (n_samples, n_neighbors)

                                    # The weights should have a sum of 1 because of the eikonal constraint
                                    normal_weights_sum = normal_weights.sum(dim=-1).detach()  # Shape is (n_samples,)
                                    normal_weights = normal_weights / normal_weights_sum.unsqueeze(-1).clamp(min=1e-6)  # Shape is (n_samples, n_neighbors)

                                    # Compute regularization loss
                                    sdf_better_normal_loss = (samples_gaussian_normals - (normal_weights[..., None] * closest_gaussian_normals).sum(dim=-2)
                                                              ).pow(2).sum(dim=-1)  # Shape is (n_samples,)
                                    loss = loss + sdf_better_normal_factor * sdf_better_normal_loss.mean()
                            else:
                                CONSOLE.log("WARNING: No gaussians available for sampling.")
            else:
                # NOTE(review): with no_rendering, loss stays a plain float and the
                # loss.backward() below would fail; unreachable with the current
                # flags (freeze_gaussians is False) — confirm before enabling.
                loss = 0.

            # Surface mesh optimization
            if bind_to_surface_mesh:
                surface_mesh = sugar.surface_mesh
                if use_surface_mesh_laplacian_smoothing_loss:
                    loss = loss + surface_mesh_laplacian_smoothing_factor * mesh_laplacian_smoothing(
                        surface_mesh, method=surface_mesh_laplacian_smoothing_method)
                if use_surface_mesh_normal_consistency_loss:
                    loss = loss + surface_mesh_normal_consistency_factor * mesh_normal_consistency(surface_mesh)

            # Update parameters
            loss.backward()

            # Densification
            with torch.no_grad():
                if (not no_rendering) and (iteration < densify_until_iter):
                    gaussian_densifier.update_densification_stats(viewspace_points, radii, visibility_filter=radii > 0)

                    if iteration > densify_from_iter and iteration % densification_interval == 0:
                        size_threshold = gaussian_densifier.max_screen_size if iteration > opacity_reset_interval else None
                        gaussian_densifier.densify_and_prune(densify_grad_threshold, prune_opacity_threshold,
                                                             cameras_spatial_extent, size_threshold)
                        CONSOLE.print("Gaussians densified and pruned. New number of gaussians:", len(sugar.points))

                        if regularize and (iteration > regularize_from) and (iteration >= start_reset_neighbors_from):
                            sugar.reset_neighbors()
                            CONSOLE.print("Neighbors reset.")

                    if iteration % opacity_reset_interval == 0:
                        gaussian_densifier.reset_opacity()
                        CONSOLE.print("Opacity reset.")

            # Optimization step
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # Print loss
            if iteration == 1 or iteration % print_loss_every_n_iterations == 0:
                CONSOLE.print(f'\n-------------------\nIteration: {iteration}')
                train_losses.append(loss.detach().item())
                CONSOLE.print(f"loss: {loss:>7f}  [{iteration:>5d}/{num_iterations:>5d}]",
                              "computed in", (time.time() - t0) / 60., "minutes.")
                with torch.no_grad():
                    scales = sugar.scaling.detach()
                    CONSOLE.print("------Stats-----")
                    CONSOLE.print("---Min, Max, Mean, Std")
                    CONSOLE.print("Points:", sugar.points.min().item(), sugar.points.max().item(), sugar.points.mean().item(), sugar.points.std().item(), sep='   ')
                    CONSOLE.print("Scaling factors:", sugar.scaling.min().item(), sugar.scaling.max().item(), sugar.scaling.mean().item(), sugar.scaling.std().item(), sep='   ')
                    CONSOLE.print("Quaternions:", sugar.quaternions.min().item(), sugar.quaternions.max().item(), sugar.quaternions.mean().item(), sugar.quaternions.std().item(), sep='   ')
                    CONSOLE.print("Sh coordinates dc:", sugar._sh_coordinates_dc.min().item(), sugar._sh_coordinates_dc.max().item(), sugar._sh_coordinates_dc.mean().item(), sugar._sh_coordinates_dc.std().item(), sep='   ')
                    CONSOLE.print("Sh coordinates rest:", sugar._sh_coordinates_rest.min().item(), sugar._sh_coordinates_rest.max().item(), sugar._sh_coordinates_rest.mean().item(), sugar._sh_coordinates_rest.std().item(), sep='   ')
                    CONSOLE.print("Opacities:", sugar.strengths.min().item(), sugar.strengths.max().item(), sugar.strengths.mean().item(), sugar.strengths.std().item(), sep='   ')
                    if regularize_sdf and iteration > start_sdf_regularization_from:
                        CONSOLE.print("Number of gaussians used for sampling in SDF regularization:", n_gaussians_in_sampling)
                t0 = time.time()

            # Save model
            if (iteration % save_model_every_n_iterations == 0) or (iteration in save_milestones):
                CONSOLE.print("Saving model...")
                model_path = os.path.join(sugar_checkpoint_path, f'{iteration}.pt')
                sugar.save_model(path=model_path,
                                 train_losses=train_losses,
                                 epoch=epoch,
                                 iteration=iteration,
                                 optimizer_state_dict=optimizer.state_dict(),
                                 )
                # if optimize_triangles and iteration >= optimize_triangles_from:
                #     rm.save_model(os.path.join(rc_checkpoint_path, f'rm_{iteration}.pt'))
                CONSOLE.print("Model saved.")

            if iteration >= num_iterations:
                break

            if do_sh_warmup and (iteration > 0) and (current_sh_levels < sh_levels) and (iteration % sh_warmup_every == 0):
                current_sh_levels += 1
                CONSOLE.print("Increasing number of spherical harmonics levels to", current_sh_levels)

            if do_resolution_warmup and (iteration > 0) and (current_resolution_factor > 1) and (iteration % resolution_warmup_every == 0):
                current_resolution_factor /= 2.
                nerfmodel.downscale_output_resolution(1 / 2)
                CONSOLE.print(f'\nCamera resolution scaled to '
                              f'{nerfmodel.training_cameras.ns_cameras.height[0].item()} x '
                              f'{nerfmodel.training_cameras.ns_cameras.width[0].item()}'
                              )
                sugar.adapt_to_cameras(nerfmodel.training_cameras)
                # TODO: resize GT images

        epoch += 1

    CONSOLE.print(f"Training finished after {num_iterations} iterations with loss={loss.detach().item()}.")
    CONSOLE.print("Saving final model...")
    model_path = os.path.join(sugar_checkpoint_path, f'{iteration}.pt')
    sugar.save_model(path=model_path,
                     train_losses=train_losses,
                     epoch=epoch,
                     iteration=iteration,
                     optimizer_state_dict=optimizer.state_dict(),
                     )

    CONSOLE.print("Final model saved.")
    return model_path