|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import NamedTuple |
|
import torch.nn as nn |
|
import torch |
|
from . import _C |
|
|
|
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple in which every tensor of ``input_tuple`` is copied to the CPU.

    Tensors are moved with ``.cpu()`` and then ``.clone()``d, so the result never
    shares storage with the input. Non-tensor items are passed through unchanged.
    """
    def _snapshot(value):
        # Only torch tensors are copied; scalars, bools, etc. are kept as-is.
        if isinstance(value, torch.Tensor):
            return value.cpu().clone()
        return value

    return tuple(map(_snapshot, input_tuple))
|
|
|
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Functional entry point for the differentiable Gaussian rasterizer.

    Every argument is handed positionally, in the order received, to
    ``_RasterizeGaussians.apply``. Its result is returned unchanged; its
    ``forward`` yields ``(color, radii)``.
    """
    # Order matters: it must line up with _RasterizeGaussians.forward's parameters.
    inputs = (
        means3D, means2D, sh, colors_precomp, opacities,
        scales, rotations, cov3Ds_precomp, raster_settings,
    )
    return _RasterizeGaussians.apply(*inputs)
|
|
|
class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the compiled ``_C`` Gaussian rasterizer.

    ``forward`` calls ``_C.rasterize_gaussians`` and stashes the buffers it
    returns. ``backward`` hands those buffers to ``_C.rasterize_gaussians_backward``.
    When ``raster_settings.debug`` is truthy, a CPU copy of each kernel's
    arguments is taken first and written to a ``.dump`` file if the call raises.
    """

    @staticmethod
    def _call_with_snapshot(kernel, args, debug, dump_path, message):
        """Return ``kernel(*args)``; in debug mode, dump ``args`` if it raises.

        Args:
            kernel: the extension function to invoke.
            args: positional arguments for ``kernel``.
            debug: when falsy, ``kernel`` is called directly with no snapshot.
            dump_path: file that receives ``torch.save`` of the CPU copy on failure.
            message: text printed after the dump is written.

        Raises:
            Whatever ``kernel`` raises, re-raised unchanged after the dump.
        """
        if not debug:
            return kernel(*args)
        # Copy before the call so the dump holds the arguments as passed in.
        cpu_args = cpu_deep_copy_tuple(args)
        try:
            return kernel(*args)
        except Exception:
            torch.save(cpu_args, dump_path)
            print(message)
            raise

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Run the forward rasterization kernel.

        ``means2D`` is not passed to the kernel. It is an input here only so that
        ``backward`` can return a gradient for it.

        Returns:
            ``(color, radii)`` as produced by ``_C.rasterize_gaussians``.
        """
        # The order of this tuple must match the _C.rasterize_gaussians binding exactly.
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _RasterizeGaussians._call_with_snapshot(
            _C.rasterize_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw.dump",
            "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.",
        )

        # Everything backward needs: non-tensor state on ctx, tensors via save_for_backward.
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
        return color, radii

    @staticmethod
    def backward(ctx, grad_out_color, _):
        """Run the backward kernel and map its gradients onto forward's inputs.

        The second incoming gradient (for ``radii``) is ignored.

        Returns:
            One gradient per ``forward`` input, in ``forward``'s parameter order.
            ``None`` is returned for ``raster_settings``.
        """
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors

        # The order of this tuple must match the _C.rasterize_gaussians_backward binding exactly.
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            raster_settings.debug,
        )

        (
            grad_means2D,
            grad_colors_precomp,
            grad_opacities,
            grad_means3D,
            grad_cov3Ds_precomp,
            grad_sh,
            grad_scales,
            grad_rotations,
        ) = _RasterizeGaussians._call_with_snapshot(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw.dump",
            "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n",
        )

        # Reorder from the kernel's output order to forward()'s input order.
        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,  # raster_settings is not differentiable
        )
        return grads
|
|
|
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-view options consumed by the rasterizer.

    Every field is required (there are no defaults). The fields are read by
    the rasterization entry points and forwarded to the compiled kernels.
    """

    # Output image size.
    image_height: int
    image_width: int

    # Tangents of the horizontal / vertical fields of view.
    tanfovx: float
    tanfovy: float

    # Background tensor passed to the kernels.
    bg: torch.Tensor
    scale_modifier: float

    # Camera transforms.
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor

    # Degree forwarded alongside the SH coefficients, and the camera position.
    sh_degree: int
    campos: torch.Tensor

    # Flags forwarded to the kernels; debug also enables argument snapshots on failure.
    prefiltered: bool
    debug: bool
|
|
|
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the differentiable Gaussian rasterizer.

    Stores a settings object (e.g. ``GaussianRasterizationSettings``) and, on
    each call, validates the optional inputs before forwarding everything to
    ``rasterize_gaussians``.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return ``_C.mark_visible(positions, viewmatrix, projmatrix)``.

        The call runs under ``torch.no_grad()``. The result is presumably a
        per-point visibility mask; confirm against the extension's binding.
        """
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Rasterize Gaussians with the stored settings.

        Exactly one colour source must be given: ``shs`` or ``colors_precomp``.
        Exactly one covariance source must be given: the ``scales`` +
        ``rotations`` pair or ``cov3D_precomp``.

        Returns:
            Whatever ``rasterize_gaussians`` returns.

        Raises:
            ValueError: if either "exactly one of" rule above is violated.
                ``ValueError`` subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        raster_settings = self.raster_settings

        # Colour source: SH coefficients XOR precomputed colours.
        if (shs is None) == (colors_precomp is None):
            raise ValueError('Please provide exactly one of either SHs or precomputed colors!')

        # Covariance source: a complete scale/rotation pair XOR a precomputed covariance.
        # Supplying any part of the pair together with cov3D_precomp is rejected too.
        pair_complete = scales is not None and rotations is not None
        pair_partial = scales is not None or rotations is not None
        has_cov = cov3D_precomp is not None
        if (not pair_complete and not has_cov) or (pair_partial and has_cov):
            raise ValueError('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        def _or_empty(value):
            # Absent optional inputs are replaced with an empty tensor placeholder
            # (the kernels presumably test for zero size; TODO confirm).
            return torch.Tensor([]) if value is None else value

        shs = _or_empty(shs)
        colors_precomp = _or_empty(colors_precomp)
        scales = _or_empty(scales)
        rotations = _or_empty(rotations)
        cov3D_precomp = _or_empty(cov3D_precomp)

        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )
|
|
|
|