import os
import time
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_scatter import segment_coo

from . import grid
from .dvgo import Raw2Alpha, Alphas2Weights
from .dmpigo import create_full_step_id

from torch.utils.cpp_extension import load

parent_dir = os.path.dirname(os.path.abspath(__file__))
# JIT-compile the CUDA helpers used below (cumdist_thres, segment_cumsum).
ub360_utils_cuda = load(
        name='ub360_utils_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/ub360_utils.cpp', 'cuda/ub360_utils_kernel.cu']],
        verbose=True)

#TODO ORIGINAL bg_len=0.2


'''Model'''
class DirectContractedVoxGO(nn.Module):
    '''Voxel-grid radiance field for unbounded scenes with a contracted background.

    Points inside the (normalised) foreground unit ball/cube are used as-is;
    points outside are contracted into a shell of width ``bg_len`` (see
    ``sample_ray``).  Besides density (``self.density``) and colour
    (``self.k0`` [+ ``self.rgbnet``]) grids, the model holds two segmentation
    volumes (``seg_mask_grid`` and ``dual_seg_mask_grid``) with
    ``num_objects`` channels each, which are volume-rendered per ray.
    '''

    def __init__(self, xyz_min, xyz_max,
                 num_voxels=0, num_voxels_base=0, num_objects=1,
                 alpha_init=None,
                 mask_cache_world_size=None,
                 fast_color_thres=0, bg_len=0.2,
                 contracted_norm='inf',
                 density_type='DenseGrid', k0_type='DenseGrid',
                 density_config=None, k0_config=None,
                 rgbnet_dim=0,
                 rgbnet_depth=3, rgbnet_width=128,
                 viewbase_pe=4,
                 **kwargs):
        '''Build the model.

        @xyz_min, xyz_max: corners of the foreground scene bbox (must be a cube).
        @num_voxels: target voxel count of the (contracted) grids.
        @num_voxels_base: voxel count used as reference for ``voxel_size_ratio``.
        @num_objects: number of channels of the segmentation grids.
        @alpha_init: initial alpha used to derive the density bias ``act_shift``.
        @mask_cache_world_size: resolution of the occupancy cache (defaults to world_size).
        @fast_color_thres: scalar threshold, or dict {global_step: threshold}
            (the dict must contain key 0).
        @bg_len: width of the contracted background shell.
        @contracted_norm: 'inf' or 'l2'; norm used for the contraction.
        @density_config, k0_config: extra grid configs (``None`` -> ``{}``).
        @rgbnet_dim: <= 0 -> direct RGB grid; > 0 -> feature grid + MLP.
        @kwargs: ignored (lets ``get_kwargs()`` output, e.g. 'voxel_size_ratio', round-trip).
        '''
        super(DirectContractedVoxGO, self).__init__()
        # Fresh dicts per instance instead of shared mutable defaults.
        density_config = {} if density_config is None else density_config
        k0_config = {} if k0_config is None else k0_config

        # xyz_min/max are the boundary that separates fg and bg scene
        xyz_min = torch.Tensor(xyz_min)
        xyz_max = torch.Tensor(xyz_max)
        # sample_ray() divides rays_o by scene_radius per axis but only
        # normalises rays_d, so the fg bbox must be a cube.  A relative
        # tolerance avoids false failures from float round-off of
        # ``center +/- radius``.  (The former ``assert len(...unique())``
        # was always truthy and therefore never checked anything.)
        bbox_extent = xyz_max - xyz_min
        if not torch.allclose(bbox_extent, bbox_extent.max().expand_as(bbox_extent), rtol=1e-3, atol=0.0):
            raise ValueError('scene bbox must be a cube in DirectContractedVoxGO')
        self.register_buffer('scene_center', (xyz_min + xyz_max) * 0.5)
        self.register_buffer('scene_radius', (xyz_max - xyz_min) * 0.5)
        # the grids live in contracted space: [-1-bg_len, 1+bg_len]^3
        self.register_buffer('xyz_min', torch.Tensor([-1,-1,-1]) - bg_len)
        self.register_buffer('xyz_max', torch.Tensor([1,1,1]) + bg_len)
        if isinstance(fast_color_thres, dict):
            self._fast_color_thres = fast_color_thres
            self.fast_color_thres = fast_color_thres[0]
        else:
            self._fast_color_thres = None
            self.fast_color_thres = fast_color_thres
        self.bg_len = bg_len
        self.contracted_norm = contracted_norm

        # determine based grid resolution
        self.num_voxels_base = num_voxels_base
        self.voxel_size_base = ((self.xyz_max - self.xyz_min).prod() / self.num_voxels_base).pow(1/3)

        # determine init grid resolution
        self._set_grid_resolution(num_voxels)

        # determine the density bias shift
        self.alpha_init = alpha_init
        self.register_buffer('act_shift', torch.FloatTensor([np.log(1/(1-alpha_init) - 1)]))
        print('dcvgo: set density bias shift to', self.act_shift)

        # init density voxel grid
        self.density_type = density_type
        self.density_config = density_config
        self.density = grid.create_grid(
            density_type, channels=1, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)

        # init segmentation volumes
        self.mode = 'coarse'
        self.num_objects = num_objects
        self.seg_mask_grid = grid.create_grid(
            density_type, channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        # NOTE(review): plain attribute, not a registered buffer, so it is not
        # moved by .to()/.cuda() nor saved in state_dict, and it is not
        # resized by change_num_objects()/scale_volume_grid() — confirm intended.
        self.mask_view_counts = torch.zeros_like(self.seg_mask_grid.grid, requires_grad=False)
        self.dual_seg_mask_grid = grid.create_grid(
            density_type, channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)

        # init color representation
        self.rgbnet_kwargs = {
            'rgbnet_dim': rgbnet_dim,
            'rgbnet_depth': rgbnet_depth, 'rgbnet_width': rgbnet_width,
            'viewbase_pe': viewbase_pe,
        }
        self.k0_type = k0_type
        self.k0_config = k0_config
        if rgbnet_dim <= 0:
            # color voxel grid  (coarse stage)
            self.k0_dim = 3
            self.k0 = grid.create_grid(
                k0_type, channels=self.k0_dim, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.k0_config)
            self.rgbnet = None
        else:
            # feature voxel grid + shallow MLP  (fine stage)
            self.k0_dim = rgbnet_dim
            self.k0 = grid.create_grid(
                k0_type, channels=self.k0_dim, world_size=self.world_size,
                xyz_min=self.xyz_min, xyz_max=self.xyz_max,
                config=self.k0_config)
            self.register_buffer('viewfreq', torch.FloatTensor([(2**i) for i in range(viewbase_pe)]))
            # input = raw viewdir (3) + sin/cos PE (3*pe*2) + grid features
            dim0 = (3+3*viewbase_pe*2)
            dim0 += self.k0_dim
            self.rgbnet = nn.Sequential(
                nn.Linear(dim0, rgbnet_width), nn.ReLU(inplace=True),
                *[
                    nn.Sequential(nn.Linear(rgbnet_width, rgbnet_width), nn.ReLU(inplace=True))
                    for _ in range(rgbnet_depth-2)
                ],
                nn.Linear(rgbnet_width, 3),
            )
            nn.init.constant_(self.rgbnet[-1].bias, 0)

        print('dcvgo: feature voxel grid', self.k0)
        print('dcvgo: mlp', self.rgbnet)

        # Using the coarse geometry if provided (used to determine known free space and unknown space)
        # Re-implement as occupancy grid (2021/1/31)
        if mask_cache_world_size is None:
            mask_cache_world_size = self.world_size
        mask = torch.ones(list(mask_cache_world_size), dtype=torch.bool)
        self.mask_cache = grid.MaskGrid(
            path=None, mask=mask,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max)

    def _set_grid_resolution(self, num_voxels):
        '''Derive voxel_size / world_size / voxel_size_ratio from a target voxel count.'''
        # Determine grid resolution
        self.num_voxels = num_voxels
        self.voxel_size = ((self.xyz_max - self.xyz_min).prod() / num_voxels).pow(1/3)
        self.world_size = ((self.xyz_max - self.xyz_min) / self.voxel_size).long()
        # the contracted domain is a cube, so one side length describes it
        self.world_len = self.world_size[0].item()
        self.voxel_size_ratio = self.voxel_size / self.voxel_size_base
        print('dcvgo: voxel_size      ', self.voxel_size)
        print('dcvgo: world_size      ', self.world_size)
        print('dcvgo: voxel_size_base ', self.voxel_size_base)
        print('dcvgo: voxel_size_ratio', self.voxel_size_ratio)

    def get_kwargs(self):
        '''Constructor kwargs used to re-instantiate this model (e.g. from a checkpoint).

        NOTE(review): ``num_objects`` and ``bg_len`` are not included, so a
        model rebuilt from this dict uses their defaults — confirm callers
        restore them another way before loading a state_dict.
        '''
        return {
            'xyz_min': self.xyz_min.cpu().numpy(),
            'xyz_max': self.xyz_max.cpu().numpy(),
            'num_voxels': self.num_voxels,
            'num_voxels_base': self.num_voxels_base,
            'alpha_init': self.alpha_init,
            'voxel_size_ratio': self.voxel_size_ratio,
            'mask_cache_world_size': list(self.mask_cache.mask.shape),
            'fast_color_thres': self.fast_color_thres,
            'contracted_norm': self.contracted_norm,
            'density_type': self.density_type,
            'k0_type': self.k0_type,
            'density_config': self.density_config,
            'k0_config': self.k0_config,
            **self.rgbnet_kwargs,
        }

    @torch.no_grad()
    def change_num_objects(self, num_obj):
        '''Re-create both segmentation grids (freshly initialised) with ``num_obj`` channels.

        The new grids are always 'DenseGrid' (not ``self.density_type``) and are
        moved to the device of the previous seg_mask_grid.
        '''
        self.num_objects = num_obj
        device = self.seg_mask_grid.grid.device
        self.seg_mask_grid = grid.create_grid(
            'DenseGrid', channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        self.dual_seg_mask_grid = grid.create_grid(
            'DenseGrid', channels=self.num_objects, world_size=self.world_size,
            xyz_min=self.xyz_min, xyz_max=self.xyz_max,
            config=self.density_config)
        self.seg_mask_grid.to(device)
        self.dual_seg_mask_grid.to(device)
        print("Reset the seg_mask_grid with num_objects =", num_obj)

    @torch.no_grad()
    def segmentation_to_density(self):
        '''Carve the density grid with the single-object segmentation grid.

        Density is zeroed where the seg value is <= 0, then every voxel whose
        density is exactly 0 is set to a very negative raw density (-1e7).
        '''
        assert self.seg_mask_grid.grid.shape[1] == 1, "multi-object seg label cannot be applied directly to the density grid"
        mask_grid = torch.zeros_like(self.seg_mask_grid.grid)
        mask_grid[self.seg_mask_grid.grid > 0] = 1
        self.density.grid *= mask_grid
        self.density.grid[self.density.grid == 0] = -1e7

    @torch.no_grad()
    def segmentation_only(self):
        '''Placeholder: only checks that the segmentation grid is single-object.'''
        assert self.seg_mask_grid.grid.shape[1] == 1, "multi-object seg label cannot be applied directly to the density grid"
        pass

    @torch.no_grad()
    def change_to_fine_mode(self):
        '''Switch to 'fine' mode, which also renders ``dual_seg_mask_grid`` in forward().'''
        self.mode = 'fine'

    @torch.no_grad()
    def scale_volume_grid(self, num_voxels):
        '''Resample all grids to the resolution implied by ``num_voxels``.

        When the new grid has at most 256**3 voxels, the occupancy cache is
        also rebuilt at the new resolution and intersected with
        (3x3x3 max-pooled alpha > fast_color_thres).
        '''
        print('dcvgo: scale_volume_grid start')
        ori_world_size = self.world_size
        self._set_grid_resolution(num_voxels)
        print('dcvgo: scale_volume_grid scale world_size from', ori_world_size.tolist(), 'to', self.world_size.tolist())

        self.density.scale_volume_grid(self.world_size)
        self.seg_mask_grid.scale_volume_grid(self.world_size)
        self.dual_seg_mask_grid.scale_volume_grid(self.world_size)
        self.k0.scale_volume_grid(self.world_size)

        if np.prod(self.world_size.tolist()) <= 256**3:
            self_grid_xyz = torch.stack(torch.meshgrid(
                torch.linspace(self.xyz_min[0], self.xyz_max[0], self.world_size[0]),
                torch.linspace(self.xyz_min[1], self.xyz_max[1], self.world_size[1]),
                torch.linspace(self.xyz_min[2], self.xyz_max[2], self.world_size[2]),
            ), -1)
            self_alpha = F.max_pool3d(self.activate_density(self.density.get_dense_grid()), kernel_size=3, padding=1, stride=1)[0,0]
            self.mask_cache = grid.MaskGrid(
                path=None, mask=self.mask_cache(self_grid_xyz) & (self_alpha>self.fast_color_thres),
                xyz_min=self.xyz_min, xyz_max=self.xyz_max)

        print('dcvgo: scale_volume_grid finish')

    @torch.no_grad()
    def update_occupancy_cache(self):
        '''Shrink mask_cache to voxels whose 3x3x3 max-pooled alpha exceeds fast_color_thres.'''
        ori_p = self.mask_cache.mask.float().mean().item()
        cache_grid_xyz = torch.stack(torch.meshgrid(
            torch.linspace(self.xyz_min[0], self.xyz_max[0], self.mask_cache.mask.shape[0]),
            torch.linspace(self.xyz_min[1], self.xyz_max[1], self.mask_cache.mask.shape[1]),
            torch.linspace(self.xyz_min[2], self.xyz_max[2], self.mask_cache.mask.shape[2]),
        ), -1)
        cache_grid_density = self.density(cache_grid_xyz)[None,None]
        cache_grid_alpha = self.activate_density(cache_grid_density)
        cache_grid_alpha = F.max_pool3d(cache_grid_alpha, kernel_size=3, padding=1, stride=1)[0,0]
        self.mask_cache.mask &= (cache_grid_alpha > self.fast_color_thres)
        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')

    def update_occupancy_cache_lt_nviews(self, rays_o_tr, rays_d_tr, imsz, render_kwargs, maskout_lt_nviews):
        '''Mask out voxels "seen" by fewer than ``maskout_lt_nviews`` training images.

        For each image (rays split by ``imsz``) the ray samples are splatted
        into an all-ones grid via backward(); a voxel counts as seen by that
        image when its accumulated gradient exceeds 1.  Requires autograd to
        be enabled (it is intentionally not wrapped in no_grad).
        '''
        print('dcvgo: update mask_cache lt_nviews start')
        eps_time = time.time()
        count = torch.zeros_like(self.density.get_dense_grid()).long()
        device = count.device
        for rays_o_, rays_d_ in zip(rays_o_tr.split(imsz), rays_d_tr.split(imsz)):
            ones = grid.DenseGrid(1, self.world_size, self.xyz_min, self.xyz_max)
            for rays_o, rays_d in zip(rays_o_.split(8192), rays_d_.split(8192)):
                ray_pts, _, _ = self.sample_ray(
                        ori_rays_o=rays_o.to(device), ori_rays_d=rays_d.to(device),
                        **render_kwargs)
                ones(ray_pts).sum().backward()
            count.data += (ones.grid.grad > 1)
        ori_p = self.mask_cache.mask.float().mean().item()
        self.mask_cache.mask &= (count >= maskout_lt_nviews)[0,0]
        new_p = self.mask_cache.mask.float().mean().item()
        print(f'dcvgo: update mask_cache {ori_p:.4f} => {new_p:.4f}')
        eps_time = time.time() - eps_time
        print(f'dcvgo: update mask_cache lt_nviews finish (eps time:', eps_time, 'sec)')

    def density_total_variation_add_grad(self, weight, dense_mode):
        '''Add TV gradient to the density grid, scaled by world_size.max()/128.'''
        w = weight * self.world_size.max() / 128
        self.density.total_variation_add_grad(w, w, w, dense_mode)

    def k0_total_variation_add_grad(self, weight, dense_mode):
        '''Add TV gradient to the colour/feature grid, scaled by world_size.max()/128.'''
        w = weight * self.world_size.max() / 128
        self.k0.total_variation_add_grad(w, w, w, dense_mode)

    def activate_density(self, density, interval=None):
        '''Map raw density to alpha via Raw2Alpha (with act_shift); keeps the input shape.'''
        interval = interval if interval is not None else self.voxel_size_ratio
        shape = density.shape
        return Raw2Alpha.apply(density.flatten(), self.act_shift, interval).reshape(shape)

    def sample_ray(self, ori_rays_o, ori_rays_d, stepsize, is_train=False, **render_kwargs):
        '''Sample query points on rays in contracted space, sorted from near to far.

        Inputs:
            ori_rays_o, ori_rays_d: [N, 3] ray origins / directions in world space.
            stepsize: sampling step, in voxels, for the inner (foreground) part.
            is_train: unused.
        Outputs:
            ray_pts:    [N, S, 3] contracted sample points.
            inner_mask: [N, S] True where the point was inside the unit region (norm <= 1).
            t:          [S] sample distances shared by all rays (normalised units).
        '''
        rays_o = (ori_rays_o - self.scene_center) / self.scene_radius
        rays_d = ori_rays_d / ori_rays_d.norm(dim=-1, keepdim=True)
        N_inner = int(2 / (2+2*self.bg_len) * self.world_len / stepsize) + 1
        N_outer = N_inner
        # uniform steps in [0, 2] for the inner part, disparity-spaced steps beyond
        b_inner = torch.linspace(0, 2, N_inner+1)
        b_outer = 2 / torch.linspace(1, 1/128, N_outer+1)
        t = torch.cat([
            (b_inner[1:] + b_inner[:-1]) * 0.5,
            (b_outer[1:] + b_outer[:-1]) * 0.5,
        ])
        ray_pts = rays_o[:,None,:] + rays_d[:,None,:] * t[None,:,None]
        if self.contracted_norm == 'inf':
            norm = ray_pts.abs().amax(dim=-1, keepdim=True)
        elif self.contracted_norm == 'l2':
            norm = ray_pts.norm(dim=-1, keepdim=True)
        else:
            raise NotImplementedError
        inner_mask = (norm<=1)
        # contract outer points into the shell [1, 1+bg_len)
        ray_pts = torch.where(
            inner_mask,
            ray_pts,
            ray_pts / norm * ((1+self.bg_len) - self.bg_len/norm)
        )
        return ray_pts, inner_mask.squeeze(-1), t

    def _march_seg_mask(self, seg_grid, ray_pts, weights, ray_id, N):
        '''Query a segmentation grid at the sample points and composite it per ray.

        Autograd is enabled only when ``self.seg_mask_grid`` requires grad, and
        the compositing weights are detached, so only the segmentation grid
        receives gradients.

        @seg_grid: ``self.seg_mask_grid`` or ``self.dual_seg_mask_grid``.
        @ray_pts:  [M, 3] surviving sample points.
        @weights:  compositing weight of each point (one per point).
        @ray_id:   [M] index of the ray each point belongs to.
        @N:        number of rays.
        Return: [N, num_objects] accumulated mask values.
        '''
        with torch.set_grad_enabled(self.seg_mask_grid.grid.requires_grad):
            pred = seg_grid(ray_pts)
            if self.num_objects == 1:
                # assumes a single-channel grid returns [M] (channel squeezed);
                # restore the channel axis so the output is [N, 1]
                pred = pred.unsqueeze(-1)
            return segment_coo(
                    src=(weights.detach().unsqueeze(-1) * pred),
                    index=ray_id,
                    out=torch.zeros([N, self.num_objects]),
                    reduce='sum')

    @torch.no_grad()
    def forward(self, rays_o, rays_d, viewdirs, global_step=None, is_train=False, render_fct=0.0, **render_kwargs):
        '''Volume rendering
        @rays_o:   [N, 3] the starting point of the N shooting rays.
        @rays_d:   [N, 3] the shooting direction of the N rays.
        @viewdirs: [N, 3] viewing direction to compute positional embedding for MLP.
        @render_fct: extra alpha/weight pruning threshold (max'ed with fast_color_thres).

        Runs without autograd except for the segmentation volumes, which get
        gradients when ``seg_mask_grid`` requires grad.  Returns a dict with
        the rendered colour ('rgb_marched'), per-sample quantities, and
        'seg_mask_marched' / 'dual_seg_mask_marched' ([N, num_objects]; the
        latter is None unless mode == 'fine'); 'depth' and 'distance' are
        added when render_kwargs['render_depth'] is set.
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'
        if isinstance(self._fast_color_thres, dict) and global_step in self._fast_color_thres:
            print(f'dcvgo: update fast_color_thres {self.fast_color_thres} => {self._fast_color_thres[global_step]}')
            self.fast_color_thres = self._fast_color_thres[global_step]

        ret_dict = {}
        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, t = self.sample_ray(
                ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=global_step is not None, **render_kwargs)
        n_max = len(t)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, step_id = create_full_step_id(ray_pts.shape[:2])

        # per-axis cumulative |delta| between consecutive samples of a ray;
        # reduced to a scalar with .norm() at the end
        ray_distance = torch.zeros_like(ray_pts)
        ray_distance[:, 1:] = torch.abs(ray_pts[:, 1:] - ray_pts[:, :-1])
        ray_distance = torch.cumsum(ray_distance, dim=1)

        # skip oversampled points outside scene bbox
        mask = inner_mask.clone()
        dist_thres = (2+2*self.bg_len) / self.world_len * render_kwargs['stepsize'] * 0.95
        dist = (ray_pts[:,1:] - ray_pts[:,:-1]).norm(dim=-1)
        mask[:, 1:] |= ub360_utils_cuda.cumdist_thres(dist, dist_thres)
        ray_pts = ray_pts[mask]
        ray_distance = ray_distance[mask]
        inner_mask = inner_mask[mask]
        t = t[None].repeat(N,1)[mask]
        ray_id = ray_id[mask.flatten()]
        step_id = step_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts = ray_pts[mask]
        ray_distance = ray_distance[mask]
        inner_mask = inner_mask[mask]
        t = t[mask]
        ray_id = ray_id[mask]
        step_id = step_id[mask]

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        density = self.density(ray_pts)
        alpha = self.activate_density(density, interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts = ray_pts[mask]
            ray_distance = ray_distance[mask]
            inner_mask = inner_mask[mask]
            t = t[mask]
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            density = density[mask]
            alpha = alpha[mask]

        # compute accumulated transmittance
        weights, alphainv_last = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts = ray_pts[mask]
            ray_distance = ray_distance[mask]
            inner_mask = inner_mask[mask]
            t = t[mask]
            ray_id = ray_id[mask]
            step_id = step_id[mask]
            density = density[mask]
            alpha = alpha[mask]
            weights = weights[mask]

        # query for color
        k0 = self.k0(ray_pts)
        if self.rgbnet is None:
            # no view-depend effect
            rgb = torch.sigmoid(k0)
        else:
            # view-dependent color emission
            viewdirs_emb = (viewdirs.unsqueeze(-1) * self.viewfreq).flatten(-2)
            viewdirs_emb = torch.cat([viewdirs, viewdirs_emb.sin(), viewdirs_emb.cos()], -1)
            viewdirs_emb = viewdirs_emb.flatten(0,-2)[ray_id]
            rgb_feat = torch.cat([k0, viewdirs_emb], -1)
            rgb_logit = self.rgbnet(rgb_feat)
            rgb = torch.sigmoid(rgb_logit)

        # Ray marching
        rgb_marched = segment_coo(
                src=(weights.unsqueeze(-1) * rgb),
                index=ray_id,
                out=torch.zeros([N, 3]),
                reduce='sum')

        # segmentation masks; only the mask volumes are optimized
        seg_mask_marched = self._march_seg_mask(self.seg_mask_grid, ray_pts, weights, ray_id, N)
        dual_seg_mask_marched = None
        if self.mode == 'fine':
            dual_seg_mask_marched = self._march_seg_mask(self.dual_seg_mask_grid, ray_pts, weights, ray_id, N)

        if render_kwargs.get('rand_bkgd', False) and is_train:
            rgb_marched += (alphainv_last.unsqueeze(-1) * torch.rand_like(rgb_marched))
        else:
            rgb_marched += (alphainv_last.unsqueeze(-1) * render_kwargs['bg'])
        wsum_mid = segment_coo(
                src=weights[inner_mask],
                index=ray_id[inner_mask],
                out=torch.zeros([N]),
                reduce='sum')
        s = 1 - 1/(1+t)  # [0, inf] => [0, 1]
        ray_distance = ray_distance.norm(dim=-1)
        ret_dict.update({
            'alphainv_last': alphainv_last,
            'weights': weights,
            'wsum_mid': wsum_mid,
            'rgb_marched': rgb_marched,
            'raw_density': density,
            'raw_alpha': alpha,
            'raw_rgb': rgb,
            'ray_id': ray_id,
            'step_id': step_id,
            'n_max': n_max,
            't': t,
            's': s,
            'seg_mask_marched': seg_mask_marched,
            'dual_seg_mask_marched': dual_seg_mask_marched,
            'ray_distance': ray_distance
        })

        if render_kwargs.get('render_depth', False):
            with torch.no_grad():
                depth = segment_coo(
                        src=(weights * s),
                        index=ray_id,
                        out=torch.zeros([N]),
                        reduce='sum')
                distance = segment_coo(
                        src=(weights * ray_distance),
                        index=ray_id,
                        out=torch.zeros([N]),
                        reduce='sum')
            ret_dict.update({'depth': depth})
            ret_dict.update({'distance': distance})
        return ret_dict

    @torch.no_grad()
    def forward_mask(self, rays_o, rays_d, render_fct=0.0, **render_kwargs):
        '''Render only the segmentation mask (no colour).
        @rays_o: [N, 3] the starting point of the N shooting rays.
        @rays_d: [N, 3] the shooting direction of the N rays.
        @render_fct: extra alpha/weight pruning threshold (max'ed with fast_color_thres).
        Return: {'seg_mask_marched': [N, num_objects]}
        '''
        assert len(rays_o.shape)==2 and rays_o.shape[-1]==3, 'Only suuport point queries in [N, 3] format'
        N = len(rays_o)

        # sample points on rays
        ray_pts, inner_mask, _ = self.sample_ray(
                ori_rays_o=rays_o, ori_rays_d=rays_d, is_train=False, **render_kwargs)
        interval = render_kwargs['stepsize'] * self.voxel_size_ratio
        ray_id, _ = create_full_step_id(ray_pts.shape[:2])

        # skip oversampled points outside scene bbox
        mask = inner_mask.clone()
        dist_thres = (2+2*self.bg_len) / self.world_len * render_kwargs['stepsize'] * 0.95
        dist = (ray_pts[:,1:] - ray_pts[:,:-1]).norm(dim=-1)
        mask[:, 1:] |= ub360_utils_cuda.cumdist_thres(dist, dist_thres)
        ray_pts = ray_pts[mask]
        ray_id = ray_id[mask.flatten()]

        # skip known free space
        mask = self.mask_cache(ray_pts)
        ray_pts = ray_pts[mask]
        ray_id = ray_id[mask]

        render_fct = max(render_fct, self.fast_color_thres)

        # query for alpha w/ post-activation
        density = self.density(ray_pts)
        alpha = self.activate_density(density, interval)
        if render_fct > 0:
            mask = (alpha > render_fct)
            ray_pts = ray_pts[mask]
            ray_id = ray_id[mask]
            alpha = alpha[mask]

        # compute accumulated transmittance
        weights, _ = Alphas2Weights.apply(alpha, ray_id, N)
        if render_fct > 0:
            mask = (weights > render_fct)
            ray_pts = ray_pts[mask]
            ray_id = ray_id[mask]
            weights = weights[mask]

        # same compositing as forward(), including the num_objects == 1 case
        seg_mask_marched = self._march_seg_mask(self.seg_mask_grid, ray_pts, weights, ray_id, N)
        return {
            'seg_mask_marched': seg_mask_marched,
        }


class DistortionLoss(torch.autograd.Function):
    '''Per-ray distortion regulariser on weights ``w`` at normalised positions ``s``.

    Appears to implement the Mip-NeRF 360 distortion loss: a bi-term
    2 * sum_i w_i * sum_{j<i} w_j (s_i - s_j) plus a uni-term
    (1/3) * interval * sum_i w_i^2, averaged over rays.
    '''

    @staticmethod
    def forward(ctx, w, s, n_max, ray_id):
        n_rays = ray_id.max()+1
        interval = 1/n_max
        w_prefix, w_total, ws_prefix, ws_total = ub360_utils_cuda.segment_cumsum(w, s, ray_id)
        loss_uni = (1/3) * interval * w.pow(2)
        loss_bi = 2 * w * (s * w_prefix - ws_prefix)
        ctx.save_for_backward(w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id)
        ctx.interval = interval
        return (loss_bi.sum() + loss_uni.sum()) / n_rays

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_back):
        # NOTE(review): forward() divides by n_rays but this gradient does not,
        # so it is n_rays times the derivative of the returned value.  Left
        # as-is because loss weights may have been tuned against it.
        w, s, w_prefix, w_total, ws_prefix, ws_total, ray_id = ctx.saved_tensors
        interval = ctx.interval
        grad_uni = (1/3) * interval * 2 * w
        w_suffix = w_total[ray_id] - (w_prefix + w)
        ws_suffix = ws_total[ray_id] - (ws_prefix + w*s)
        grad_bi = 2 * (s * (w_prefix - w_suffix) + (ws_suffix - ws_prefix))
        grad = grad_back * (grad_bi + grad_uni)
        return grad, None, None, None

distortion_loss = DistortionLoss.apply