import torch
from lib.config import cfg
from .nerf_net_utils import *
from .. import embedder


class Renderer:
    def __init__(self, net):
        self.net = net

    def get_sampling_points(self, ray_o, ray_d, near, far):
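        """Sample cfg.N_samples depths per ray between near and far.

        Depths are linearly interpolated from the per-ray near/far bounds
        and, during training, jittered inside each interval (stratified
        sampling). Returns the sampled world-space points and their depths
        z_vals.
        """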
        # calculate the steps for each ray
        t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
        z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals

        if cfg.perturb > 0. and self.net.training:
            # get intervals between samples
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # stratified samples in those intervals
            t_rand = torch.rand(z_vals.shape).to(upper)
            z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[:, :, None] + ray_d[:, :, None] * z_vals[..., None]

        return pts, z_vals

    def pts_to_can_pts(self, pts, batch):
        """Transform pts from world coordinates to the SMPL coordinate frame."""
        Th = batch['Th'][:, None]
        pts = pts - Th
        R = batch['R']
        sh = pts.shape
        pts = torch.matmul(pts.view(sh[0], -1, sh[3]), R)
        pts = pts.view(*sh)
        return pts

    def transform_sampling_points(self, pts, batch):
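        """Training-time augmentation of the sampled points.

        The points are rotated in the x-z plane about batch['center'] by
        batch['rot'] and then shifted by batch['trans']; at test time they
        are returned unchanged.
        """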
        if not self.net.training:
            return pts
        center = batch['center'][:, None, None]
        pts = pts - center
        rot = batch['rot']
        pts_ = pts[..., [0, 2]].clone()
        sh = pts_.shape
        pts_ = torch.matmul(pts_.view(sh[0], -1, sh[3]), rot.permute(0, 2, 1))
        pts[..., [0, 2]] = pts_.view(*sh)
        pts = pts + center
        trans = batch['trans'][:, None, None]
        pts = pts + trans
        return pts

    def prepare_sp_input(self, batch):
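        """Pack the sparse-convolution input for the network.

        'coord' stacks the voxel coordinates in batch['tcoord'] with their
        batch index ([N, 4] rows of batch_idx, x, y, z), 'out_sh' is the
        element-wise maximum of batch['tout_sh'] over the batch, and 'i'
        is the frame index.
        """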
        # feature, coordinate, shape, batch size
        sp_input = {}

        # coordinate: [N, 4], batch_idx, x, y, z
        sh = batch['tcoord'].shape
        idx = [torch.full([sh[1]], i) for i in range(sh[0])]
        idx = torch.cat(idx).to(batch['tcoord'])
        coord = batch['tcoord'].view(-1, sh[-1])
        sp_input['coord'] = torch.cat([idx[:, None], coord], dim=1)

        out_sh, _ = torch.max(batch['tout_sh'], dim=0)
        sp_input['out_sh'] = out_sh.tolist()
        sp_input['batch_size'] = sh[0]

        sp_input['i'] = batch['i']

        return sp_input

    def get_ptot_grid_coords(self, pts, out_sh, bounds):
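        """Map points to normalized coordinates of the T-pose feature volume.

        Points are shifted by the volume's minimum corner (bounds[:, 0]),
        scaled into voxel units by cfg.voxel_size, normalized to [-1, 1]
        with the volume shape out_sh, and reordered from xyz to zyx to
        match how the volume is indexed.
        """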
        # pts: [batch_size, x, y, z, 3]; the last dim is (x, y, z)
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None, None, None]
        pts = pts / torch.tensor(cfg.voxel_size).to(pts)

        # convert the voxel coordinate to [-1, 1]
        out_sh = torch.tensor(out_sh).to(pts)
        pts = pts / out_sh * 2 - 1
        # convert xyz to zyx, since the occupancy is indexed by xyz
        grid_coords = pts[..., [2, 1, 0]]

        return grid_coords

    def get_grid_coords(self, pts, ptot_pts, bounds):
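        """Map posed points to normalized coordinates of the frame's volume.

        Same normalization as get_ptot_grid_coords, but for flattened points
        of shape [batch_size, N, 3], using cfg.ptot_vsize as the voxel size
        and the spatial shape of ptot_pts as the volume shape.
        """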
        out_sh = torch.tensor(ptot_pts.shape[1:-1]).to(pts)

        # pts: [batch_size, N, 3]; the last dim is (x, y, z)
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None]
        pts = pts / torch.tensor(cfg.ptot_vsize).to(pts)

        # convert the voxel coordinate to [-1, 1]
        pts = pts / out_sh * 2 - 1
        # convert xyz to zyx, since the occupancy is indexed by xyz
        grid_coords = pts[..., [2, 1, 0]]

        return grid_coords

    # def batchify_rays(self, rays_flat, chunk=1024 * 32, net_c=None):
    def batchify_rays(self,
                      sp_input,
                      tgrid_coords,
                      pgrid_coords,
                      viewdir,
                      light_pts,
                      chunk=1024 * 32,
                      net_c=None):
        """Render rays in smaller minibatches to avoid OOM."""
        all_ret = []
        for i in range(0, tgrid_coords.shape[1], chunk):
            # ret = self.render_rays(rays_flat[i:i + chunk], net_c)
            ret = self.net(sp_input, tgrid_coords[:, i:i + chunk],
                           pgrid_coords[:, i:i + chunk],
                           viewdir[:, i:i + chunk], light_pts[:, i:i + chunk])
            # for k in ret:
            #     if k not in all_ret:
            #         all_ret[k] = []
            #     all_ret[k].append(ret[k])
            all_ret.append(ret)
        # all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
        all_ret = torch.cat(all_ret, 1)
        return all_ret

    def render(self, batch):
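        """Volume-render one batch of rays.

        Uses batch['ray_o'], batch['ray_d'], batch['near'] and batch['far']
        for the rays, batch['Th'] and batch['R'] for the world-to-SMPL
        transform, and batch['tcoord'], batch['tout_sh'], batch['i'],
        batch['ptot_pts'], batch['tbounds'] and batch['pbounds'] for the
        feature volumes. Returns a dict with 'rgb_map', 'acc_map' and
        'depth_map'.
        """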
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near = batch['near']
        far = batch['far']
        sh = ray_o.shape

        pts, z_vals = self.get_sampling_points(ray_o, ray_d, near, far)

        # light intensity varies with 3D location
        light_pts = embedder.xyz_embedder(pts)

        ppts = self.pts_to_can_pts(pts, batch)

        ray_d0 = batch['ray_d']
        viewdir = ray_d0 / torch.norm(ray_d0, dim=2, keepdim=True)
        viewdir = embedder.view_embedder(viewdir)
        viewdir = viewdir[:, :, None].repeat(1, 1, pts.size(2), 1).contiguous()

        sp_input = self.prepare_sp_input(batch)

        # reshape to [batch_size, n, dim]
        light_pts = light_pts.view(sh[0], -1, embedder.xyz_dim)
        viewdir = viewdir.view(sh[0], -1, embedder.view_dim)
        ppts = ppts.view(sh[0], -1, 3)

        # create grid coords for sampling feature volume at t pose
        ptot_pts = batch['ptot_pts']
        tgrid_coords = self.get_ptot_grid_coords(ptot_pts, sp_input['out_sh'],
                                                 batch['tbounds'])
        # create grid coords for sampling feature volume at i-th frame
        pgrid_coords = self.get_grid_coords(ppts, ptot_pts, batch['pbounds'])
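
        # Render in a single forward pass when the ray count is small enough;
        # otherwise fall back to chunked rendering to bound peak memory.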
        if ray_o.size(1) <= 2048:
            raw = self.net(sp_input, tgrid_coords, pgrid_coords, viewdir,
                           light_pts)
        else:
            raw = self.batchify_rays(sp_input, tgrid_coords, pgrid_coords,
                                     viewdir, light_pts, 1024 * 32, None)

        # reshape to [num_rays, num_samples along ray, 4]
        raw = raw.reshape(-1, z_vals.size(2), 4)
        z_vals = z_vals.view(-1, z_vals.size(2))
        ray_d = ray_d.view(-1, 3)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, ray_d, cfg.raw_noise_std, cfg.white_bkgd)

        rgb_map = rgb_map.view(*sh[:-1], -1)
        acc_map = acc_map.view(*sh[:-1])
        depth_map = depth_map.view(*sh[:-1])

        ret = {'rgb_map': rgb_map, 'acc_map': acc_map, 'depth_map': depth_map}

        return ret
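

# A minimal usage sketch (hedged: `net` and `batch` come from the rest of
# the codebase and are only assumed here):
#   renderer = Renderer(net)
#   with torch.no_grad():
#       out = renderer.render(batch)
#   rgb = out['rgb_map']  # [batch_size, n_rays, rgb]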