import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load

parent_dir = os.path.dirname(os.path.abspath(__file__))
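
# JIT-compile the custom CUDA extensions on first import; torch.utils.cpp_extension
# caches the build, so subsequent runs reuse the compiled modules.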
render_utils_cuda = load(
        name='render_utils_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/render_utils.cpp', 'cuda/render_utils_kernel.cu']],
        verbose=True)

total_variation_cuda = load(
        name='total_variation_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/total_variation.cpp', 'cuda/total_variation_kernel.cu']],
        verbose=True)

def create_grid(type, **kwargs):
    if type == 'DenseGrid':
        return DenseGrid(**kwargs)
    elif type == 'TensoRFGrid':
        return TensoRFGrid(**kwargs)
    else:
        raise NotImplementedError
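
# Example (sketch, hypothetical sizes): a 1-channel density grid over a unit cube.
#     density = create_grid('DenseGrid', channels=1, world_size=[96, 96, 96],
#                           xyz_min=[-1, -1, -1], xyz_max=[1, 1, 1])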

''' Dense 3D grid
'''
class DenseGrid(nn.Module):
    def __init__(self, channels, world_size, xyz_min, xyz_max, **kwargs):
        super(DenseGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        self.grid = nn.Parameter(torch.zeros([1, channels, *world_size]))
        print(self.xyz_min, self.xyz_max, self.world_size)

    def forward(self, xyz):
        '''
        xyz: global coordinates to query
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,1,-1,3)
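        # grid_sample orders coordinates as (W, H, D); the grid is [1, C, X, Y, Z],
        # so flip (x, y, z) -> (z, y, x) after normalizing to [-1, 1].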
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
        out = F.grid_sample(self.grid, ind_norm, mode='bilinear', align_corners=True)
        out = out.reshape(self.channels,-1).T.reshape(*shape,self.channels)
        if self.channels == 1:
            out = out.squeeze(-1)
        return out

    def scale_volume_grid(self, new_world_size):
        if self.channels == 0:
            self.grid = nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
        else:
            self.grid = nn.Parameter(
                F.interpolate(self.grid.data, size=tuple(new_world_size), mode='trilinear', align_corners=True))

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place'''
        total_variation_cuda.total_variation_add_grad(
            self.grid, self.grid.grad, wx, wy, wz, dense_mode)

    def get_dense_grid(self):
        return self.grid

    def __isub__(self, val):
        self.grid.data -= val
        return self

    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}'

# ''' Utilize autograd for 3D mask generation
# '''
# class ConstrainedGrad(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, inp):
#         if inp.requires_grad:
#             ctx.save_for_backward(inp)
#         return inp
#
#     @staticmethod
#     @torch.autograd.function.once_differentiable
#     def backward(ctx, grad_back):
#         '''
#         grad_back should be [0,1]
#         '''
#         val = ctx.saved_tensors[0]
#         # forward takes a single tensor input, so backward returns a single gradient
#         return grad_back * (1 - val)
#
# ''' Dense 3D grid for 3D mask
# '''
# class MaskDenseGrid(nn.Module):
#     def __init__(self, channels, world_size, xyz_min, xyz_max, **kwargs):
#         super(MaskDenseGrid, self).__init__()
#         self.channels = channels
#         self.world_size = world_size
#         self.register_buffer('xyz_min', torch.Tensor(xyz_min))
#         self.register_buffer('xyz_max', torch.Tensor(xyz_max))
#         self.grid = nn.Parameter(torch.zeros([1, channels, *world_size]))
#
#     def forward(self, xyz):
#         '''
#         xyz: global coordinates to query
#         '''
#         shape = xyz.shape[:-1]
#         xyz = xyz.reshape(1,1,1,-1,3)
#         ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
#         # modify the backward gradients
#         out = F.grid_sample(ConstrainedGrad.apply(self.grid), ind_norm, mode='bilinear', align_corners=True)
#         out = out.reshape(self.channels,-1).T.reshape(*shape,self.channels)
#         if self.channels == 1:
#             out = out.squeeze(-1)
#         return out
#
#     @torch.no_grad()
#     def scale_volume_grid(self, new_world_size):
#         if self.channels == 0:
#             self.grid = nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
#         else:
#             self.grid = nn.Parameter(
#                 F.interpolate(self.grid.data, size=tuple(new_world_size), mode='trilinear', align_corners=True))
#         self.world_size = new_world_size
#
#     @torch.no_grad()
#     def total_variation_add_grad(self, wx, wy, wz, dense_mode):
#         '''Add gradients by total variation loss in-place'''
#         total_variation_cuda.total_variation_add_grad(
#             self.grid, self.grid.grad, wx, wy, wz, dense_mode)
#
#     @torch.no_grad()
#     def get_dense_grid(self):
#         return self.grid
#
#     @torch.no_grad()
#     def __isub__(self, val):
#         self.grid.data -= val
#         return self
#
#     def extra_repr(self):
#         return f'channels={self.channels}, world_size={self.world_size.tolist()}'

''' Vector-Matrix decomposed grid
See TensoRF: Tensorial Radiance Fields (https://arxiv.org/abs/2203.09517)
'''
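# Each component r contributes a plane ⊗ vector outer product to the dense volume:
#     G_r = M_r^{XY} ⊗ v_r^Z + M_r^{XZ} ⊗ v_r^Y + M_r^{YZ} ⊗ v_r^X
# For channels > 1, f_vec linearly mixes the 3R component responses into channels.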
class TensoRFGrid(nn.Module):
    def __init__(self, channels, world_size, xyz_min, xyz_max, config):
        super(TensoRFGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        self.config = config
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        X, Y, Z = world_size
        R = config['n_comp']
        Rxy = config.get('n_comp_xy', R)
        self.xy_plane = nn.Parameter(torch.randn([1, Rxy, X, Y]) * 0.1)
        self.xz_plane = nn.Parameter(torch.randn([1, R, X, Z]) * 0.1)
        self.yz_plane = nn.Parameter(torch.randn([1, R, Y, Z]) * 0.1)
        self.x_vec = nn.Parameter(torch.randn([1, R, X, 1]) * 0.1)
        self.y_vec = nn.Parameter(torch.randn([1, R, Y, 1]) * 0.1)
        self.z_vec = nn.Parameter(torch.randn([1, Rxy, Z, 1]) * 0.1)
        if self.channels > 1:
            self.f_vec = nn.Parameter(torch.ones([R+R+Rxy, channels]))
            nn.init.kaiming_uniform_(self.f_vec, a=np.sqrt(5))

    def forward(self, xyz):
        '''
        xyz: global coordinates to query
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,-1,3)
        ind_norm = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min) * 2 - 1
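        # Append a dummy 4th coordinate (always 0): the line factors are stored as
        # [1, R, N, 1] "images", so 2D grid_sample needs a coordinate for the width-1 axis.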
        ind_norm = torch.cat([ind_norm, torch.zeros_like(ind_norm[...,[0]])], dim=-1)
        if self.channels > 1:
            out = compute_tensorf_feat(
                self.xy_plane, self.xz_plane, self.yz_plane,
                self.x_vec, self.y_vec, self.z_vec, self.f_vec, ind_norm)
            out = out.reshape(*shape,self.channels)
        else:
            out = compute_tensorf_val(
                self.xy_plane, self.xz_plane, self.yz_plane,
                self.x_vec, self.y_vec, self.z_vec, ind_norm)
            out = out.reshape(*shape)
        return out

    def scale_volume_grid(self, new_world_size):
        if self.channels == 0:
            return
        X, Y, Z = new_world_size
        self.xy_plane = nn.Parameter(F.interpolate(self.xy_plane.data, size=[X,Y], mode='bilinear', align_corners=True))
        self.xz_plane = nn.Parameter(F.interpolate(self.xz_plane.data, size=[X,Z], mode='bilinear', align_corners=True))
        self.yz_plane = nn.Parameter(F.interpolate(self.yz_plane.data, size=[Y,Z], mode='bilinear', align_corners=True))
        self.x_vec = nn.Parameter(F.interpolate(self.x_vec.data, size=[X,1], mode='bilinear', align_corners=True))
        self.y_vec = nn.Parameter(F.interpolate(self.y_vec.data, size=[Y,1], mode='bilinear', align_corners=True))
        self.z_vec = nn.Parameter(F.interpolate(self.z_vec.data, size=[Z,1], mode='bilinear', align_corners=True))

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place'''
        loss = wx * F.smooth_l1_loss(self.xy_plane[:,:,1:], self.xy_plane[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.xy_plane[:,:,:,1:], self.xy_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.xz_plane[:,:,1:], self.xz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.xz_plane[:,:,:,1:], self.xz_plane[:,:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.yz_plane[:,:,1:], self.yz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.yz_plane[:,:,:,1:], self.yz_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.x_vec[:,:,1:], self.x_vec[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.y_vec[:,:,1:], self.y_vec[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.z_vec[:,:,1:], self.z_vec[:,:,:-1], reduction='sum')
        loss /= 6
        loss.backward()

    def get_dense_grid(self):
        if self.channels > 1:
            feat = torch.cat([
                torch.einsum('rxy,rz->rxyz', self.xy_plane[0], self.z_vec[0,:,:,0]),
                torch.einsum('rxz,ry->rxyz', self.xz_plane[0], self.y_vec[0,:,:,0]),
                torch.einsum('ryz,rx->rxyz', self.yz_plane[0], self.x_vec[0,:,:,0]),
            ])
            grid = torch.einsum('rxyz,rc->cxyz', feat, self.f_vec)[None]
        else:
            grid = torch.einsum('rxy,rz->xyz', self.xy_plane[0], self.z_vec[0,:,:,0]) + \
                   torch.einsum('rxz,ry->xyz', self.xz_plane[0], self.y_vec[0,:,:,0]) + \
                   torch.einsum('ryz,rx->xyz', self.yz_plane[0], self.x_vec[0,:,:,0])
            grid = grid[None,None]
        return grid

    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}, n_comp={self.config["n_comp"]}'

def compute_tensorf_feat(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, f_vec, ind_norm):
    # Interp feature (feat shape: [n_pts, n_comp])
    xy_feat = F.grid_sample(xy_plane, ind_norm[:,:,:,[1,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    xz_feat = F.grid_sample(xz_plane, ind_norm[:,:,:,[2,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    yz_feat = F.grid_sample(yz_plane, ind_norm[:,:,:,[2,1]], mode='bilinear', align_corners=True).flatten(0,2).T
    x_feat = F.grid_sample(x_vec, ind_norm[:,:,:,[3,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    y_feat = F.grid_sample(y_vec, ind_norm[:,:,:,[3,1]], mode='bilinear', align_corners=True).flatten(0,2).T
    z_feat = F.grid_sample(z_vec, ind_norm[:,:,:,[3,2]], mode='bilinear', align_corners=True).flatten(0,2).T
    # Aggregate components
    feat = torch.cat([
        xy_feat * z_feat,
        xz_feat * y_feat,
        yz_feat * x_feat,
    ], dim=-1)
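    # Project the concatenated 3R per-component products onto the output channels.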
    feat = torch.mm(feat, f_vec)
    return feat

def compute_tensorf_val(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, ind_norm):
    # Interp feature (feat shape: [n_pts, n_comp])
    xy_feat = F.grid_sample(xy_plane, ind_norm[:,:,:,[1,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    xz_feat = F.grid_sample(xz_plane, ind_norm[:,:,:,[2,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    yz_feat = F.grid_sample(yz_plane, ind_norm[:,:,:,[2,1]], mode='bilinear', align_corners=True).flatten(0,2).T
    x_feat = F.grid_sample(x_vec, ind_norm[:,:,:,[3,0]], mode='bilinear', align_corners=True).flatten(0,2).T
    y_feat = F.grid_sample(y_vec, ind_norm[:,:,:,[3,1]], mode='bilinear', align_corners=True).flatten(0,2).T
    z_feat = F.grid_sample(z_vec, ind_norm[:,:,:,[3,2]], mode='bilinear', align_corners=True).flatten(0,2).T
    # Aggregate components
    feat = (xy_feat * z_feat).sum(-1) + (xz_feat * y_feat).sum(-1) + (yz_feat * x_feat).sum(-1)
    return feat

''' Mask grid
It supports query for the known free space and unknown space.
'''
class MaskGrid(nn.Module):
    def __init__(self, path=None, mask_cache_thres=None, mask=None, xyz_min=None, xyz_max=None):
        super(MaskGrid, self).__init__()
        if path is not None:
            st = torch.load(path)
            self.mask_cache_thres = mask_cache_thres
            density = F.max_pool3d(st['model_state_dict']['density.grid'], kernel_size=3, padding=1, stride=1)
            alpha = 1 - torch.exp(-F.softplus(density + st['model_state_dict']['act_shift']) * st['model_kwargs']['voxel_size_ratio'])
            mask = (alpha >= self.mask_cache_thres).squeeze(0).squeeze(0)
            xyz_min = torch.Tensor(st['model_kwargs']['xyz_min'])
            xyz_max = torch.Tensor(st['model_kwargs']['xyz_max'])
        else:
            mask = mask.bool()
            xyz_min = torch.Tensor(xyz_min)
            xyz_max = torch.Tensor(xyz_max)
        self.register_buffer('mask', mask)
        xyz_len = xyz_max - xyz_min
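        # Affine map from world coordinates to voxel indices: ijk = xyz * scale + shift.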
        self.register_buffer('xyz2ijk_scale', (torch.Tensor(list(mask.shape)) - 1) / xyz_len)
        self.register_buffer('xyz2ijk_shift', -xyz_min * self.xyz2ijk_scale)

    def forward(self, xyz):
        '''Skip known free space.
        @xyz: [..., 3] the xyz in global coordinate.
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(-1, 3)
        mask = render_utils_cuda.maskcache_lookup(self.mask, xyz, self.xyz2ijk_scale, self.xyz2ijk_shift)
        mask = mask.reshape(shape)
        return mask

    def extra_repr(self):
        return f'mask.shape={list(self.mask.shape)}'

def get_dense_grid_batch_processing(tensorf: TensoRFGrid):
    '''
    Expects the TensoRF factors to already be on the target device and evaluates
    them there batchwise, so nothing is repeatedly shuttled between CPU and device.
    The reconstructed grid is accumulated on CPU (a full [1, C, X, Y, Z] volume can
    exceed device memory); slice assignment copies each device batch across.
    '''
    start_time = time.time()
    result_grid = torch.zeros([1, tensorf.channels, *tensorf.world_size], dtype=tensorf.x_vec.dtype)
    print("Time taken for initializing the grid", time.time() - start_time)
    # Batch along all three axes so each outer-product block stays small on device.
    batch_size_x = 35
    batch_size_y = 35
    batch_size_z = 35
    for start_x in range(0, tensorf.world_size[0], batch_size_x):
        end_x = start_x + batch_size_x
        for start_y in range(0, tensorf.world_size[1], batch_size_y):
            end_y = start_y + batch_size_y
            for start_z in range(0, tensorf.world_size[2], batch_size_z):
                end_z = start_z + batch_size_z
                # Same einsums as TensoRFGrid.get_dense_grid, restricted to the sub-block.
                feat = torch.cat([
                    torch.einsum('rxy,rz->rxyz', tensorf.xy_plane[0, :, start_x:end_x, start_y:end_y], tensorf.z_vec[0, :, start_z:end_z, 0]),
                    torch.einsum('rxz,ry->rxyz', tensorf.xz_plane[0, :, start_x:end_x, start_z:end_z], tensorf.y_vec[0, :, start_y:end_y, 0]),
                    torch.einsum('ryz,rx->rxyz', tensorf.yz_plane[0, :, start_y:end_y, start_z:end_z], tensorf.x_vec[0, :, start_x:end_x, 0]),
                ])
                sub_grid = torch.einsum('rxyz,rc->cxyz', feat, tensorf.f_vec)[None]
                result_grid[:, :, start_x:end_x, start_y:end_y, start_z:end_z] = sub_grid
    return result_grid
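
# Example (sketch): materialize a large TensoRF without building the dense grid on GPU.
#     tensorf = TensoRFGrid(64, torch.tensor([320, 320, 320]), [0, 0, 0], [1, 1, 1], {'n_comp': 64}).cuda()
#     dense = get_dense_grid_batch_processing(tensorf)   # [1, 64, 320, 320, 320] on CPU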

def reconstruct_feature_grid(render_viewpoints_kwargs):
    model = render_viewpoints_kwargs['model']
    f_k0 = model.f_k0.cuda()
    fg = get_dense_grid_batch_processing(f_k0).cuda()
    # Flatten the volume to [n_voxels, channels] rows (on CPU) for clustering.
    fg_kmeans = fg.clone()
    fg_kmeans = fg_kmeans.squeeze(0).permute(1, 2, 3, 0)  # x, y, z, channels
    fg_kmeans = fg_kmeans.reshape(-1, f_k0.channels)
    fg_kmeans = fg_kmeans.cpu().contiguous()
    # Zero-pad one voxel on every side of the three spatial dims.
    return F.pad(fg, [1] * 6), fg_kmeans

if __name__ == "__main__":
    with torch.no_grad():
        print("Testing whether the batched reconstruction matches the full-GPU grid.")
        tensorf = TensoRFGrid(64, torch.tensor([100, 100, 100]), [0, 0, 0], [1, 1, 1], {'n_comp': 64})
        tensorf = tensorf.cuda()
        start_time = time.time()
        grid1 = tensorf.get_dense_grid().cpu()
        print("Time taken for full gpu implementation", time.time() - start_time)
        grid2 = get_dense_grid_batch_processing(tensorf)
        assert grid1.isclose(grid2, atol=1e-7).all()
        del grid1, grid2, tensorf
        torch.cuda.empty_cache()

        # A 320^3 x 64-channel fp32 grid is ~8.4 GB, too large to build on most GPUs at once.
        tensorf = TensoRFGrid(64, torch.tensor([320, 320, 320]), [0, 0, 0], [1, 1, 1], {'n_comp': 64})
        tensorf = tensorf.cuda()
        start_time = time.time()
        grid = get_dense_grid_batch_processing(tensorf)
        print("Time taken to reconstruct the grid", time.time() - start_time)
        print("Program over.")