import torch
from lib.config import cfg
from .nerf_net_utils import *
from .. import embedder


class Renderer:
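    """Volume renderer: samples points along camera rays, maps them into the
    SMPL/canonical coordinate frame, queries the network for raw density and
    color, and composites the results with raw2outputs."""
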
    def __init__(self, net):
        self.net = net

    def get_sampling_points(self, ray_o, ray_d, near, far):
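        """Sample cfg.N_samples depth values per ray, linearly spaced between
        near and far, with stratified jitter inside each interval during
        training."""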
        # calculate the steps for each ray
        t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
        z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals

        if cfg.perturb > 0. and self.net.training:
            # get intervals between samples
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # stratified samples in those intervals
            t_rand = torch.rand(z_vals.shape).to(upper)
            z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[:, :, None] + ray_d[:, :, None] * z_vals[..., None]

        return pts, z_vals

    def pts_to_can_pts(self, pts, batch):
        """transform pts from the world coordinate to the smpl coordinate"""
        Th = batch['Th'][:, None]
        pts = pts - Th
        R = batch['R']
        sh = pts.shape
        pts = torch.matmul(pts.view(sh[0], -1, sh[3]), R)
        pts = pts.view(*sh)
        return pts

    def transform_sampling_points(self, pts, batch):
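        """Training-time augmentation: rotate the x-z components of the points
        around `center` (a rotation about the vertical axis) and then apply a
        translation."""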
        if not self.net.training:
            return pts
        center = batch['center'][:, None, None]
        pts = pts - center
        rot = batch['rot']
        pts_ = pts[..., [0, 2]].clone()
        sh = pts_.shape
        pts_ = torch.matmul(pts_.view(sh[0], -1, sh[3]), rot.permute(0, 2, 1))
        pts[..., [0, 2]] = pts_.view(*sh)
        pts = pts + center
        trans = batch['trans'][:, None, None]
        pts = pts + trans
        return pts

    def prepare_sp_input(self, batch):
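        """Pack the inputs for the sparse-convolution network: voxel
        coordinates prefixed with batch indices, the output volume shape
        (max over the batch), the batch size, and the index `i` carried
        through from the batch."""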
        # feature, coordinate, shape, batch size
        sp_input = {}

        # coordinate: [N, 4], batch_idx, x, y, z
        sh = batch['tcoord'].shape
        idx = [torch.full([sh[1]], i) for i in range(sh[0])]
        idx = torch.cat(idx).to(batch['tcoord'])
        coord = batch['tcoord'].view(-1, sh[-1])
        sp_input['coord'] = torch.cat([idx[:, None], coord], dim=1)

        out_sh, _ = torch.max(batch['tout_sh'], dim=0)
        sp_input['out_sh'] = out_sh.tolist()
        sp_input['batch_size'] = sh[0]

        sp_input['i'] = batch['i']

        return sp_input

    def get_ptot_grid_coords(self, pts, out_sh, bounds):
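        """Convert points to normalized [-1, 1] coordinates for grid_sample
        over the canonical feature volume defined by `out_sh` and `bounds`."""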
        # pts: [batch_size, x, y, z, 3]; the last dim holds xyz coordinates
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None, None, None]
        pts = pts / torch.tensor(cfg.voxel_size).to(pts)
        # convert the voxel coordinate to [-1, 1]
        out_sh = torch.tensor(out_sh).to(pts)
        pts = pts / out_sh * 2 - 1
        # convert xyz to zyx: the feature volume is indexed by xyz, so
        # grid_sample needs coordinates in reversed (z, y, x) order
        grid_coords = pts[..., [2, 1, 0]]
        return grid_coords

    def get_grid_coords(self, pts, ptot_pts, bounds):
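        """Convert posed points to normalized [-1, 1] coordinates for
        grid_sample over the per-frame feature volume (shape taken from
        `ptot_pts`, bounds from `bounds`)."""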
        out_sh = torch.tensor(ptot_pts.shape[1:-1]).to(pts)
        # pts: [batch_size, N, 3]; the last dim holds xyz coordinates
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None]
        pts = pts / torch.tensor(cfg.ptot_vsize).to(pts)
        # convert the voxel coordinate to [-1, 1]
        pts = pts / out_sh * 2 - 1
        # convert xyz to zyx: the feature volume is indexed by xyz, so
        # grid_sample needs coordinates in reversed (z, y, x) order
        grid_coords = pts[..., [2, 1, 0]]
        return grid_coords

    def batchify_rays(self,
                      sp_input,
                      tgrid_coords,
                      pgrid_coords,
                      viewdir,
                      light_pts,
                      chunk=1024 * 32,
                      net_c=None):
        """Render rays in smaller minibatches to avoid OOM.
        """
        all_ret = []
        for i in range(0, tgrid_coords.shape[1], chunk):
            ret = self.net(sp_input, tgrid_coords[:, i:i + chunk],
                           pgrid_coords[:, i:i + chunk],
                           viewdir[:, i:i + chunk], light_pts[:, i:i + chunk])
            all_ret.append(ret)
        all_ret = torch.cat(all_ret, 1)
        return all_ret

    def render(self, batch):
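        """Full rendering pass for a batch of rays: sample points along each
        ray, embed positions and view directions, query the network (chunked
        when the ray count is large), and composite the raw output into rgb,
        acc and depth maps."""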
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near = batch['near']
        far = batch['far']
        sh = ray_o.shape

        pts, z_vals = self.get_sampling_points(ray_o, ray_d, near, far)
        # embed the world-space sample points (light intensity varies with 3D location)
        light_pts = embedder.xyz_embedder(pts)
        ppts = self.pts_to_can_pts(pts, batch)

        viewdir = ray_d / torch.norm(ray_d, dim=2, keepdim=True)
        viewdir = embedder.view_embedder(viewdir)
        viewdir = viewdir[:, :, None].repeat(1, 1, pts.size(2), 1).contiguous()

        sp_input = self.prepare_sp_input(batch)

        # flatten rays and samples: [batch_size, n_rays * n_samples, C]
        light_pts = light_pts.view(sh[0], -1, embedder.xyz_dim)
        viewdir = viewdir.view(sh[0], -1, embedder.view_dim)
        ppts = ppts.view(sh[0], -1, 3)

        # create grid coords for sampling feature volume at t pose
        ptot_pts = batch['ptot_pts']
        tgrid_coords = self.get_ptot_grid_coords(ptot_pts, sp_input['out_sh'],
                                                 batch['tbounds'])

        # create grid coords for sampling feature volume at i-th frame
        pgrid_coords = self.get_grid_coords(ppts, ptot_pts, batch['pbounds'])

        if ray_o.size(1) <= 2048:
            raw = self.net(sp_input, tgrid_coords, pgrid_coords, viewdir,
                           light_pts)
        else:
            raw = self.batchify_rays(sp_input, tgrid_coords, pgrid_coords,
                                     viewdir, light_pts, 1024 * 32, None)

        # reshape to [batch_size * num_rays, num_samples along ray, 4]
        raw = raw.reshape(-1, z_vals.size(2), 4)
        z_vals = z_vals.view(-1, z_vals.size(2))
        ray_d = ray_d.view(-1, 3)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, ray_d, cfg.raw_noise_std, cfg.white_bkgd)
        rgb_map = rgb_map.view(*sh[:-1], -1)
        acc_map = acc_map.view(*sh[:-1])
        depth_map = depth_map.view(*sh[:-1])

        ret = {'rgb_map': rgb_map, 'acc_map': acc_map, 'depth_map': depth_map}

        return ret