Stanislaw Szymanowicz committed on
Commit 7c548b3
1 Parent(s): 2613050

Add util files

gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/graphdeco-inria/gaussian-splatting/tree/main
2
+ # to take in a predicted dictionary with 3D Gaussian parameters.
3
+
4
+ import math
5
+ import torch
6
+ import numpy as np
7
+
8
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
9
+ from utils.graphics_utils import focal2fov
10
+
11
+ def render_predicted(pc : dict,
12
+ world_view_transform,
13
+ full_proj_transform,
14
+ camera_center,
15
+ bg_color : torch.Tensor,
16
+ cfg,
17
+ scaling_modifier = 1.0,
18
+ override_color = None,
19
+ focals_pixels = None):
20
+ """
21
+ Render the scene as specified by pc dictionary.
22
+
23
+ Background tensor (bg_color) must be on GPU!
24
+ """
25
+
26
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
27
+ screenspace_points = torch.zeros_like(pc["xyz"], dtype=pc["xyz"].dtype, requires_grad=True, device="cuda") + 0
28
+ try:
29
+ screenspace_points.retain_grad()
30
+ except Exception:
31
+ pass
32
+
33
+ if focals_pixels is None:
34
+ tanfovx = math.tan(cfg.data.fov * np.pi / 360)
35
+ tanfovy = math.tan(cfg.data.fov * np.pi / 360)
36
+ else:
37
+ tanfovx = math.tan(0.5 * focal2fov(focals_pixels[0].item(), cfg.data.training_resolution))
38
+ tanfovy = math.tan(0.5 * focal2fov(focals_pixels[1].item(), cfg.data.training_resolution))
39
+
40
+ # Set up rasterization configuration
41
+ raster_settings = GaussianRasterizationSettings(
42
+ image_height=int(cfg.data.training_resolution),
43
+ image_width=int(cfg.data.training_resolution),
44
+ tanfovx=tanfovx,
45
+ tanfovy=tanfovy,
46
+ bg=bg_color,
47
+ scale_modifier=scaling_modifier,
48
+ viewmatrix=world_view_transform,
49
+ projmatrix=full_proj_transform,
50
+ sh_degree=cfg.model.max_sh_degree,
51
+ campos=camera_center,
52
+ prefiltered=False,
53
+ debug=False
54
+ )
55
+
56
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
57
+
58
+ means3D = pc["xyz"]
59
+ means2D = screenspace_points
60
+ opacity = pc["opacity"]
61
+
62
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
63
+ # scaling / rotation by the rasterizer.
64
+ scales = None
65
+ rotations = None
66
+ cov3D_precomp = None
67
+
68
+ scales = pc["scaling"]
69
+ rotations = pc["rotation"]
70
+
71
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
72
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
73
+ shs = None
74
+ colors_precomp = None
75
+ if override_color is None:
76
+ if "features_rest" in pc.keys():
77
+ shs = torch.cat([pc["features_dc"], pc["features_rest"]], dim=1).contiguous()
78
+ else:
79
+ shs = pc["features_dc"]
80
+ else:
81
+ colors_precomp = override_color
82
+
83
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
84
+ rendered_image, radii = rasterizer(
85
+ means3D = means3D,
86
+ means2D = means2D,
87
+ shs = shs,
88
+ colors_precomp = colors_precomp,
89
+ opacities = opacity,
90
+ scales = scales,
91
+ rotations = rotations,
92
+ cov3D_precomp = cov3D_precomp)
93
+
94
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
95
+ # They will be excluded from value updates used in the splitting criteria.
96
+ return {"render": rendered_image,
97
+ "viewspace_points": screenspace_points,
98
+ "visibility_filter" : radii > 0,
99
+ "radii": radii}
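For reference, a minimal sketch of driving render_predicted with a dictionary of Gaussians and the loop cameras from utils/app_utils.py (illustrative only, not part of the diff). It assumes a CUDA device, the installed diff_gaussian_rasterization extension, cfg.model.max_sh_degree == 1, and that cfg is the project's Hydra config supplying cfg.data.fov and cfg.data.training_resolution; the random tensors only stand in for a real prediction.

import torch
from utils.app_utils import get_target_cameras
from gaussian_renderer import render_predicted

N = 1024  # number of Gaussians, arbitrary for this sketch
pc = {
    "xyz": 0.5 * torch.randn(N, 3, device="cuda"),
    "opacity": torch.rand(N, 1, device="cuda"),
    "scaling": 0.02 * torch.ones(N, 3, device="cuda"),
    "rotation": torch.nn.functional.normalize(torch.randn(N, 4, device="cuda"), dim=-1),
    "features_dc": torch.randn(N, 1, 3, device="cuda"),
    "features_rest": torch.zeros(N, 3, 3, device="cuda"),  # SH degree-1 rest coefficients
}
bg = torch.tensor([1.0, 1.0, 1.0], device="cuda")  # background tensor must be on the GPU
w2v, full_proj, centers = get_target_cameras()
# cfg: the project's Hydra config, loaded elsewhere
out = render_predicted(pc, w2v[0].cuda(), full_proj[0].cuda(), centers[0].cuda(), bg, cfg)
image = out["render"]  # (3, training_resolution, training_resolution)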
requirements.txt CHANGED
@@ -1,4 +1,5 @@
torch
+ torchvision
tqdm
hydra-core
omegaconf
@@ -7,4 +8,6 @@ einops
imageio
moviepy
markupsafe==2.0.1
- gradio
+ gradio
+ rembg
+ git+https://github.com/graphdeco-inria/diff-gaussian-rasterization
scene/__init__.py ADDED
File without changes
scene/gaussian_predictor.py ADDED
@@ -0,0 +1,789 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import numpy as np
5
+
6
+ from torch.nn.functional import silu
7
+
8
+ from einops import rearrange
9
+
10
+ from utils.general_utils import quaternion_raw_multiply
11
+ from utils.graphics_utils import fov2focal
12
+
13
+ # U-Net implementation from EDM
14
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
15
+ #
16
+ # This work is licensed under a Creative Commons
17
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
18
+ # You should have received a copy of the license along with this
19
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
20
+
21
+ """Model architectures and preconditioning schemes used in the paper
22
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
23
+
24
+ #----------------------------------------------------------------------------
25
+ # Unified routine for initializing weights and biases.
26
+
27
+ def weight_init(shape, mode, fan_in, fan_out):
28
+ if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
29
+ if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
30
+ if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
31
+ if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape)
32
+ raise ValueError(f'Invalid init mode "{mode}"')
33
+
34
+ #----------------------------------------------------------------------------
35
+ # Fully-connected layer.
36
+
37
+ class Linear(torch.nn.Module):
38
+ def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
39
+ super().__init__()
40
+ self.in_features = in_features
41
+ self.out_features = out_features
42
+ init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
43
+ self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
44
+ self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
45
+
46
+ def forward(self, x):
47
+ x = x @ self.weight.to(x.dtype).t()
48
+ if self.bias is not None:
49
+ x = x.add_(self.bias.to(x.dtype))
50
+ return x
51
+
52
+ #----------------------------------------------------------------------------
53
+ # Convolutional layer with optional up/downsampling.
54
+
55
+ class Conv2d(torch.nn.Module):
56
+ def __init__(self,
57
+ in_channels, out_channels, kernel, bias=True, up=False, down=False,
58
+ resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0,
59
+ ):
60
+ assert not (up and down)
61
+ super().__init__()
62
+ self.in_channels = in_channels
63
+ self.out_channels = out_channels
64
+ self.up = up
65
+ self.down = down
66
+ self.fused_resample = fused_resample
67
+ init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel)
68
+ self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None
69
+ self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None
70
+ f = torch.as_tensor(resample_filter, dtype=torch.float32)
71
+ f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
72
+ self.register_buffer('resample_filter', f if up or down else None)
73
+
74
+ def forward(self, x, N_views_xa=1):
75
+ w = self.weight.to(x.dtype) if self.weight is not None else None
76
+ b = self.bias.to(x.dtype) if self.bias is not None else None
77
+ f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None
78
+ w_pad = w.shape[-1] // 2 if w is not None else 0
79
+ f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0
80
+
81
+ if self.fused_resample and self.up and w is not None:
82
+ x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0))
83
+ x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
84
+ elif self.fused_resample and self.down and w is not None:
85
+ x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad)
86
+ x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2)
87
+ else:
88
+ if self.up:
89
+ x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
90
+ if self.down:
91
+ x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad)
92
+ if w is not None:
93
+ x = torch.nn.functional.conv2d(x, w, padding=w_pad)
94
+ if b is not None:
95
+ x = x.add_(b.reshape(1, -1, 1, 1))
96
+ return x
97
+
98
+ #----------------------------------------------------------------------------
99
+ # Group normalization.
100
+
101
+ class GroupNorm(torch.nn.Module):
102
+ def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
103
+ super().__init__()
104
+ self.num_groups = min(num_groups, num_channels // min_channels_per_group)
105
+ self.eps = eps
106
+ self.weight = torch.nn.Parameter(torch.ones(num_channels))
107
+ self.bias = torch.nn.Parameter(torch.zeros(num_channels))
108
+
109
+ def forward(self, x, N_views_xa=1):
110
+ x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps)
111
+ return x.to(memory_format=torch.channels_last)
112
+
113
+ #----------------------------------------------------------------------------
114
+ # Attention weight computation, i.e., softmax(Q^T * K).
115
+ # Performs all computation using FP32, but uses the original datatype for
116
+ # inputs/outputs/gradients to conserve memory.
117
+
118
+ class AttentionOp(torch.autograd.Function):
119
+ @staticmethod
120
+ def forward(ctx, q, k):
121
+ w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype)
122
+ ctx.save_for_backward(q, k, w)
123
+ return w
124
+
125
+ @staticmethod
126
+ def backward(ctx, dw):
127
+ q, k, w = ctx.saved_tensors
128
+ db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32)
129
+ dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1])
130
+ dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1])
131
+ return dq, dk
132
+
133
+ #----------------------------------------------------------------------------
134
+ # Timestep embedding used in the DDPM++ and ADM architectures.
135
+
136
+ class PositionalEmbedding(torch.nn.Module):
137
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
138
+ super().__init__()
139
+ self.num_channels = num_channels
140
+ self.max_positions = max_positions
141
+ self.endpoint = endpoint
142
+
143
+ def forward(self, x):
144
+ b, c = x.shape
145
+ x = rearrange(x, 'b c -> (b c)')
146
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
147
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
148
+ freqs = (1 / self.max_positions) ** freqs
149
+ x = x.ger(freqs.to(x.dtype))
150
+ x = torch.cat([x.cos(), x.sin()], dim=1)
151
+ x = rearrange(x, '(b c) emb_ch -> b (c emb_ch)', b=b)
152
+ return x
153
+
154
+ #----------------------------------------------------------------------------
155
+ # Timestep embedding used in the NCSN++ architecture.
156
+
157
+ class FourierEmbedding(torch.nn.Module):
158
+ def __init__(self, num_channels, scale=16):
159
+ super().__init__()
160
+ self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)
161
+
162
+ def forward(self, x):
163
+ b, c = x.shape
164
+ x = rearrange(x, 'b c -> (b c)')
165
+ x = x.ger((2 * np.pi * self.freqs).to(x.dtype))
166
+ x = torch.cat([x.cos(), x.sin()], dim=1)
167
+ x = rearrange(x, '(b c) emb_ch -> b (c emb_ch)', b=b)
168
+ return x
169
+
170
+ class CrossAttentionBlock(torch.nn.Module):
171
+ def __init__(self, num_channels, num_heads = 1, eps=1e-5):
172
+ super().__init__()
173
+
174
+ self.num_heads = 1
175
+ init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
176
+ init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
177
+
178
+ self.norm = GroupNorm(num_channels=num_channels, eps=eps)
179
+
180
+ self.q_proj = Conv2d(in_channels=num_channels, out_channels=num_channels, kernel=1, **init_attn)
181
+ self.kv_proj = Conv2d(in_channels=num_channels, out_channels=num_channels*2, kernel=1, **init_attn)
182
+
183
+ self.out_proj = Conv2d(in_channels=num_channels, out_channels=num_channels, kernel=3, **init_zero)
184
+
185
+ def forward(self, q, kv):
186
+ q_proj = self.q_proj(self.norm(q)).reshape(q.shape[0] * self.num_heads, q.shape[1] // self.num_heads, -1)
187
+ k_proj, v_proj = self.kv_proj(self.norm(kv)).reshape(kv.shape[0] * self.num_heads,
188
+ kv.shape[1] // self.num_heads, 2, -1).unbind(2)
189
+ w = AttentionOp.apply(q_proj, k_proj)
190
+ a = torch.einsum('nqk,nck->ncq', w, v_proj)
191
+ x = self.out_proj(a.reshape(*q.shape)).add_(q)
192
+
193
+ return x
194
+
195
+ #----------------------------------------------------------------------------
196
+ # Unified U-Net block with optional up/downsampling and self-attention.
197
+ # Represents the union of all features employed by the DDPM++, NCSN++, and
198
+ # ADM architectures.
199
+
200
+ class UNetBlock(torch.nn.Module):
201
+ def __init__(self,
202
+ in_channels, out_channels, emb_channels, up=False, down=False, attention=False,
203
+ num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5,
204
+ resample_filter=[1,1], resample_proj=False, adaptive_scale=True,
205
+ init=dict(), init_zero=dict(init_weight=0), init_attn=None,
206
+ ):
207
+ super().__init__()
208
+ self.in_channels = in_channels
209
+ self.out_channels = out_channels
210
+ if emb_channels is not None:
211
+ self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init)
212
+ self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head
213
+ self.dropout = dropout
214
+ self.skip_scale = skip_scale
215
+ self.adaptive_scale = adaptive_scale
216
+
217
+ self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
218
+ self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init)
219
+ self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
220
+ self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero)
221
+
222
+ self.skip = None
223
+ if out_channels != in_channels or up or down:
224
+ kernel = 1 if resample_proj or out_channels != in_channels else 0
225
+ self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init)
226
+
227
+ if self.num_heads:
228
+ self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
229
+ self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init))
230
+ self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero)
231
+
232
+ def forward(self, x, emb=None, N_views_xa=1):
233
+ orig = x
234
+ x = self.conv0(silu(self.norm0(x)))
235
+
236
+ if emb is not None:
237
+ params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
238
+ if self.adaptive_scale:
239
+ scale, shift = params.chunk(chunks=2, dim=1)
240
+ x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
241
+ else:
242
+ x = silu(self.norm1(x.add_(params)))
243
+
244
+ x = silu(self.norm1(x))
245
+
246
+ x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
247
+ x = x.add_(self.skip(orig) if self.skip is not None else orig)
248
+ x = x * self.skip_scale
249
+
250
+ if self.num_heads:
251
+ if N_views_xa != 1:
252
+ B, C, H, W = x.shape
253
+ # (B, C, H, W) -> (B/N, N, C, H, W) -> (B/N, N, H, W, C)
254
+ x = x.reshape(B // N_views_xa, N_views_xa, *x.shape[1:]).permute(0, 1, 3, 4, 2)
255
+ # (B/N, N, H, W, C) -> (B/N, N*H, W, C) -> (B/N, C, N*H, W)
256
+ x = x.reshape(B // N_views_xa, N_views_xa * x.shape[2], *x.shape[3:]).permute(0, 3, 1, 2)
257
+ q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
258
+ w = AttentionOp.apply(q, k)
259
+ a = torch.einsum('nqk,nck->ncq', w, v)
260
+ x = self.proj(a.reshape(*x.shape)).add_(x)
261
+ x = x * self.skip_scale
262
+ if N_views_xa != 1:
263
+ # (B/N, C, N*H, W) -> (B/N, N*H, W, C)
264
+ x = x.permute(0, 2, 3, 1)
265
+ # (B/N, N*H, W, C) -> (B/N, N, H, W, C) -> (B/N, N, C, H, W)
266
+ x = x.reshape(B // N_views_xa, N_views_xa, H, W, C).permute(0, 1, 4, 2, 3)
267
+ # (B/N, N, C, H, W) -> (B, C, H, W)
268
+ x = x.reshape(B, C, H, W)
269
+ return x
270
+
271
+ #----------------------------------------------------------------------------
272
+ # Reimplementation of the DDPM++ and NCSN++ architectures from the paper
273
+ # "Score-Based Generative Modeling through Stochastic Differential
274
+ # Equations". Equivalent to the original implementation by Song et al.,
275
+ # available at https://github.com/yang-song/score_sde_pytorch
276
+ # taken from EDM repository https://github.com/NVlabs/edm/blob/main/training/networks.py#L372
277
+
278
+ class SongUNet(nn.Module):
279
+ def __init__(self,
280
+ img_resolution, # Image resolution at input/output.
281
+ in_channels, # Number of color channels at input.
282
+ out_channels, # Number of color channels at output.
283
+ emb_dim_in = 0, # Input embedding dim.
284
+ augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
285
+
286
+ model_channels = 128, # Base multiplier for the number of channels.
287
+ channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels.
288
+ channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
289
+ num_blocks = 4, # Number of residual blocks per resolution.
290
+ attn_resolutions = [16], # List of resolutions with self-attention.
291
+ dropout = 0.10, # Dropout probability of intermediate activations.
292
+ label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
293
+
294
+ embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
295
+ channel_mult_noise = 0, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
296
+ encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
297
+ decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
298
+ resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
299
+ ):
300
+ assert embedding_type in ['fourier', 'positional']
301
+ assert encoder_type in ['standard', 'skip', 'residual']
302
+ assert decoder_type in ['standard', 'skip']
303
+
304
+ super().__init__()
305
+ self.label_dropout = label_dropout
306
+ self.emb_dim_in = emb_dim_in
307
+ if emb_dim_in > 0:
308
+ emb_channels = model_channels * channel_mult_emb
309
+ else:
310
+ emb_channels = None
311
+ noise_channels = model_channels * channel_mult_noise
312
+ init = dict(init_mode='xavier_uniform')
313
+ init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
314
+ init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
315
+ block_kwargs = dict(
316
+ emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6,
317
+ resample_filter=resample_filter, resample_proj=True, adaptive_scale=False,
318
+ init=init, init_zero=init_zero, init_attn=init_attn,
319
+ )
320
+
321
+ # Mapping.
322
+ # self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None
323
+ # self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None
324
+ # self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
325
+ # self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
326
+ if emb_dim_in > 0:
327
+ self.map_layer0 = Linear(in_features=emb_dim_in, out_features=emb_channels, **init)
328
+ self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
329
+
330
+ if noise_channels > 0:
331
+ self.noise_map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
332
+ self.noise_map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
333
+
334
+ # Encoder.
335
+ self.enc = torch.nn.ModuleDict()
336
+ cout = in_channels
337
+ caux = in_channels
338
+ for level, mult in enumerate(channel_mult):
339
+ res = img_resolution >> level
340
+ if level == 0:
341
+ cin = cout
342
+ cout = model_channels
343
+ self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
344
+ else:
345
+ self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
346
+ if encoder_type == 'skip':
347
+ self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter)
348
+ self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init)
349
+ if encoder_type == 'residual':
350
+ self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init)
351
+ caux = cout
352
+ for idx in range(num_blocks):
353
+ cin = cout
354
+ cout = model_channels * mult
355
+ attn = (res in attn_resolutions)
356
+ self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
357
+ skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name]
358
+
359
+ # Decoder.
360
+ self.dec = torch.nn.ModuleDict()
361
+ for level, mult in reversed(list(enumerate(channel_mult))):
362
+ res = img_resolution >> level
363
+ if level == len(channel_mult) - 1:
364
+ self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
365
+ self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
366
+ else:
367
+ self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
368
+ for idx in range(num_blocks + 1):
369
+ cin = cout + skips.pop()
370
+ cout = model_channels * mult
371
+ attn = (idx == num_blocks and res in attn_resolutions)
372
+ self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
373
+ if decoder_type == 'skip' or level == 0:
374
+ if decoder_type == 'skip' and level < len(channel_mult) - 1:
375
+ self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter)
376
+ self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6)
377
+ self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, init_weight=0.2, **init)  # changed from EDM's **init_zero
378
+
379
+ def forward(self, x, film_camera_emb=None, N_views_xa=1):
380
+
381
+ emb = None
382
+
383
+ if film_camera_emb is not None:
384
+ if self.emb_dim_in != 1:
385
+ film_camera_emb = film_camera_emb.reshape(
386
+ film_camera_emb.shape[0], 2, -1).flip(1).reshape(*film_camera_emb.shape) # swap sin/cos
387
+ film_camera_emb = silu(self.map_layer0(film_camera_emb))
388
+ film_camera_emb = silu(self.map_layer1(film_camera_emb))
389
+ emb = film_camera_emb
390
+
391
+ # Encoder.
392
+ skips = []
393
+ aux = x
394
+ for name, block in self.enc.items():
395
+ if 'aux_down' in name:
396
+ aux = block(aux, N_views_xa)
397
+ elif 'aux_skip' in name:
398
+ x = skips[-1] = x + block(aux, N_views_xa)
399
+ elif 'aux_residual' in name:
400
+ x = skips[-1] = aux = (x + block(aux, N_views_xa)) / np.sqrt(2)
401
+ else:
402
+ x = block(x, emb=emb, N_views_xa=N_views_xa) if isinstance(block, UNetBlock) \
403
+ else block(x, N_views_xa=N_views_xa)
404
+ skips.append(x)
405
+
406
+ # Decoder.
407
+ aux = None
408
+ tmp = None
409
+ for name, block in self.dec.items():
410
+ if 'aux_up' in name:
411
+ aux = block(aux, N_views_xa)
412
+ elif 'aux_norm' in name:
413
+ tmp = block(x, N_views_xa)
414
+ elif 'aux_conv' in name:
415
+ tmp = block(silu(tmp), N_views_xa)
416
+ aux = tmp if aux is None else tmp + aux
417
+ else:
418
+ if x.shape[1] != block.in_channels:
419
+ # skip connection is pixel-aligned which is good for
420
+ # foreground features
421
+ # but it's not good for gradient flow and background features
422
+ x = torch.cat([x, skips.pop()], dim=1)
423
+ x = block(x, emb=emb, N_views_xa=N_views_xa)
424
+ return aux
425
+
426
+ class SingleImageSongUNetPredictor(nn.Module):
427
+ def __init__(self, cfg, out_channels, bias, scale):
428
+ super(SingleImageSongUNetPredictor, self).__init__()
429
+ self.out_channels = out_channels
430
+ self.cfg = cfg
431
+ if cfg.cam_embd.embedding is None:
432
+ in_channels = 3
433
+ emb_dim_in = 0
434
+ else:
435
+ in_channels = 3
436
+ emb_dim_in = 6 * cfg.cam_embd.dimension
437
+
438
+ self.encoder = SongUNet(cfg.data.training_resolution,
439
+ in_channels,
440
+ sum(out_channels),
441
+ model_channels=cfg.model.base_dim,
442
+ num_blocks=cfg.model.num_blocks,
443
+ emb_dim_in=emb_dim_in,
444
+ channel_mult_noise=0,
445
+ attn_resolutions=cfg.model.attention_resolutions)
446
+ self.out = nn.Conv2d(in_channels=sum(out_channels),
447
+ out_channels=sum(out_channels),
448
+ kernel_size=1)
449
+
450
+ start_channels = 0
451
+ for out_channel, b, s in zip(out_channels, bias, scale):
452
+ nn.init.xavier_uniform_(
453
+ self.out.weight[start_channels:start_channels+out_channel,
454
+ :, :, :], s)
455
+ nn.init.constant_(
456
+ self.out.bias[start_channels:start_channels+out_channel], b)
457
+ start_channels += out_channel
458
+
459
+ def forward(self, x, film_camera_emb=None, N_views_xa=1):
460
+ x = self.encoder(x,
461
+ film_camera_emb=film_camera_emb,
462
+ N_views_xa=N_views_xa)
463
+
464
+ return self.out(x)
465
+
466
+ def networkCallBack(cfg, name, out_channels, **kwargs):
467
+ assert name == "SingleUNet"
468
+ return SingleImageSongUNetPredictor(cfg, out_channels, **kwargs)
469
+
470
+ class GaussianSplatPredictor(nn.Module):
471
+ def __init__(self, cfg):
472
+ super(GaussianSplatPredictor, self).__init__()
473
+ self.cfg = cfg
474
+ assert cfg.model.network_with_offset or cfg.model.network_without_offset, \
475
+ "Need at least one network"
476
+
477
+ if cfg.model.network_with_offset:
478
+ split_dimensions, scale_inits, bias_inits = self.get_splits_and_inits(True, cfg)
479
+ self.network_with_offset = networkCallBack(cfg,
480
+ cfg.model.name,
481
+ split_dimensions,
482
+ scale = scale_inits,
483
+ bias = bias_inits)
484
+ assert not cfg.model.network_without_offset, "Can only have one network"
485
+ if cfg.model.network_without_offset:
486
+ split_dimensions, scale_inits, bias_inits = self.get_splits_and_inits(False, cfg)
487
+ self.network_wo_offset = networkCallBack(cfg,
488
+ cfg.model.name,
489
+ split_dimensions,
490
+ scale = scale_inits,
491
+ bias = bias_inits)
492
+ assert not cfg.model.network_with_offset, "Can only have one network"
493
+
494
+ self.init_ray_dirs()
495
+
496
+ # Activation functions for different parameters
497
+ self.depth_act = nn.Sigmoid()
498
+ self.scaling_activation = torch.exp
499
+ self.opacity_activation = torch.sigmoid
500
+ self.rotation_activation = torch.nn.functional.normalize
501
+
502
+ if self.cfg.model.max_sh_degree > 0:
503
+ self.init_sh_transform_matrices()
504
+
505
+ if self.cfg.cam_embd.embedding is not None:
506
+ if self.cfg.cam_embd.encode_embedding is None:
507
+ self.cam_embedding_map = nn.Identity()
508
+ elif self.cfg.cam_embd.encode_embedding == "positional":
509
+ self.cam_embedding_map = PositionalEmbedding(self.cfg.cam_embd.dimension)
510
+
511
+ def init_sh_transform_matrices(self):
512
+ v_to_sh_transform = torch.tensor([[ 0, 0,-1],
513
+ [-1, 0, 0],
514
+ [ 0, 1, 0]], dtype=torch.float32)
515
+ sh_to_v_transform = v_to_sh_transform.transpose(0, 1)
516
+ self.register_buffer('sh_to_v_transform', sh_to_v_transform.unsqueeze(0))
517
+ self.register_buffer('v_to_sh_transform', v_to_sh_transform.unsqueeze(0))
518
+
519
+ def init_ray_dirs(self):
520
+ x = torch.linspace(-self.cfg.data.training_resolution // 2 + 0.5,
521
+ self.cfg.data.training_resolution // 2 - 0.5,
522
+ self.cfg.data.training_resolution)
523
+ y = torch.linspace( self.cfg.data.training_resolution // 2 - 0.5,
524
+ -self.cfg.data.training_resolution // 2 + 0.5,
525
+ self.cfg.data.training_resolution)
526
+ if self.cfg.model.inverted_x:
527
+ x = -x
528
+ if self.cfg.model.inverted_y:
529
+ y = -y
530
+ grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
531
+ ones = torch.ones_like(grid_x, dtype=grid_x.dtype)
532
+ ray_dirs = torch.stack([grid_x, grid_y, ones]).unsqueeze(0)
533
+
534
+ # for cars and chairs the focal length is fixed across dataset
535
+ # so we can preprocess it
536
+ # for co3d this is done on the fly
537
+ if self.cfg.data.category not in ["hydrants", "teddybears"]:
538
+ ray_dirs[:, :2, ...] /= fov2focal(self.cfg.data.fov * np.pi / 180,
539
+ self.cfg.data.training_resolution)
540
+ self.register_buffer('ray_dirs', ray_dirs)
541
+
542
+ def get_splits_and_inits(self, with_offset, cfg):
543
+ # Gets channel split dimensions and last layer initialisation
544
+ split_dimensions = []
545
+ scale_inits = []
546
+ bias_inits = []
547
+
548
+ if with_offset:
549
+ split_dimensions = split_dimensions + [1, 3, 1, 3, 4, 3]
550
+ scale_inits = scale_inits + [cfg.model.depth_scale,
551
+ cfg.model.xyz_scale,
552
+ cfg.model.opacity_scale,
553
+ cfg.model.scale_scale,
554
+ 1.0,
555
+ 5.0]
556
+ bias_inits = [cfg.model.depth_bias,
557
+ cfg.model.xyz_bias,
558
+ cfg.model.opacity_bias,
559
+ np.log(cfg.model.scale_bias),
560
+ 0.0,
561
+ 0.0]
562
+ else:
563
+ split_dimensions = split_dimensions + [1, 1, 3, 4, 3]
564
+ scale_inits = scale_inits + [cfg.model.depth_scale,
565
+ cfg.model.opacity_scale,
566
+ cfg.model.scale_scale,
567
+ 1.0,
568
+ 5.0]
569
+ bias_inits = bias_inits + [cfg.model.depth_bias,
570
+ cfg.model.opacity_bias,
571
+ np.log(cfg.model.scale_bias),
572
+ 0.0,
573
+ 0.0]
574
+
575
+ if cfg.model.max_sh_degree != 0:
576
+ sh_num = (self.cfg.model.max_sh_degree + 1) ** 2 - 1
577
+ sh_num_rgb = sh_num * 3
578
+ split_dimensions.append(sh_num_rgb)
579
+ scale_inits.append(0.0)
580
+ bias_inits.append(0.0)
581
+
582
+ if with_offset:
583
+ self.split_dimensions_with_offset = split_dimensions
584
+ else:
585
+ self.split_dimensions_without_offset = split_dimensions
586
+
587
+ return split_dimensions, scale_inits, bias_inits
588
+
589
+ def flatten_vector(self, x):
590
+ # Gets rid of the image dimensions and flattens to a point list
591
+ # B x C x H x W -> B x C x N -> B x N x C
592
+ return x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
593
+
594
+ def make_contiguous(self, tensor_dict):
595
+ return {k: v.contiguous() for k, v in tensor_dict.items()}
596
+
597
+ def multi_view_union(self, tensor_dict, B, N_view):
598
+ for t_name, t in tensor_dict.items():
599
+ t = t.reshape(B, N_view, *t.shape[1:])
600
+ tensor_dict[t_name] = t.reshape(B, N_view * t.shape[2], *t.shape[3:])
601
+ return tensor_dict
602
+
603
+ def get_camera_embeddings(self, cameras):
604
+ # get embedding
605
+ # pass through encoding
606
+ b, n_view = cameras.shape[:2]
607
+ if self.cfg.cam_embd.embedding == "index":
608
+ cam_embedding = torch.arange(n_view,
609
+ dtype=cameras.dtype,
610
+ device=cameras.device,
611
+ ).unsqueeze(0).expand(b, n_view).unsqueeze(2)
612
+ if self.cfg.cam_embd.embedding == "pose":
613
+ # concatenate origin and z-vector. cameras are in row-major order
614
+ cam_embedding = torch.cat([cameras[:, :, 3, :3], cameras[:, :, 2, :3]], dim=2)
615
+
616
+ cam_embedding = rearrange(cam_embedding, 'b n_view c -> (b n_view) c')
617
+ cam_embedding = self.cam_embedding_map(cam_embedding)
618
+ cam_embedding = rearrange(cam_embedding, '(b n_view) c -> b n_view c', b=b, n_view=n_view)
619
+
620
+ return cam_embedding
621
+
622
+ def transform_SHs(self, shs, source_cameras_to_world):
623
+ # shs: B x N x SH_num x 3
624
+ # source_cameras_to_world: B 4 4
625
+ assert shs.shape[2] == 3, "Can only process shs order 1"
626
+ shs = rearrange(shs, 'b n sh_num rgb -> b (n rgb) sh_num')
627
+ transforms = torch.bmm(
628
+ self.sh_to_v_transform.expand(source_cameras_to_world.shape[0], 3, 3),
629
+ # transpose is because source_cameras_to_world is
630
+ # in row major order
631
+ source_cameras_to_world[:, :3, :3])
632
+ transforms = torch.bmm(transforms,
633
+ self.v_to_sh_transform.expand(source_cameras_to_world.shape[0], 3, 3))
634
+
635
+ shs_transformed = torch.bmm(shs, transforms)
636
+ shs_transformed = rearrange(shs_transformed, 'b (n rgb) sh_num -> b n sh_num rgb', rgb=3)
637
+
638
+ return shs_transformed
639
+
640
+ def transform_rotations(self, rotations, source_cv2wT_quat):
641
+ """
642
+ Applies a transform that rotates the predicted rotations from
643
+ camera space to world space.
644
+ Args:
645
+ rotations: predicted in-camera rotation quaternions (B x N x 4)
646
+ source_cv2wT_quat: transformation quaternions from
647
+ camera-to-world matrices, transposed (B x 4)
648
+ Returns:
649
+ rotations with appropriately applied transform to world space
650
+ """
651
+
652
+ Mq = source_cv2wT_quat.unsqueeze(1).expand(*rotations.shape)
653
+
654
+ rotations = quaternion_raw_multiply(Mq, rotations)
655
+
656
+ return rotations
657
+
658
+ def get_pos_from_network_output(self, depth_network, offset, focals_pixels, const_offset=None):
659
+
660
+ # expands ray dirs along the batch dimension
661
+ # adjust ray directions according to fov if not done already
662
+ ray_dirs_xy = self.ray_dirs.expand(depth_network.shape[0], 3, *self.ray_dirs.shape[2:])
663
+ if self.cfg.data.category in ["hydrants", "teddybears"]:
664
+ assert torch.all(focals_pixels > 0)
665
+ ray_dirs_xy = ray_dirs_xy.clone()
666
+ ray_dirs_xy[:, :2, ...] = ray_dirs_xy[:, :2, ...] / focals_pixels.unsqueeze(2).unsqueeze(3)
667
+
668
+ # depth and offsets are shaped as (b 3 h w)
669
+ if const_offset is not None:
670
+ depth = self.depth_act(depth_network) * (self.cfg.data.zfar - self.cfg.data.znear) + self.cfg.data.znear + const_offset
671
+ else:
672
+ depth = self.depth_act(depth_network) * (self.cfg.data.zfar - self.cfg.data.znear) + self.cfg.data.znear
673
+
674
+ pos = ray_dirs_xy * depth + offset
675
+
676
+ return pos
677
+
678
+ def forward(self, x,
679
+ source_cameras_view_to_world,
680
+ source_cv2wT_quat=None,
681
+ focals_pixels=None,
682
+ activate_output=True):
683
+
684
+ B = x.shape[0]
685
+ N_views = x.shape[1]
686
+ # UNet attention will reshape outputs so that there is cross-view attention
687
+ if self.cfg.model.cross_view_attention:
688
+ N_views_xa = N_views
689
+ else:
690
+ N_views_xa = 1
691
+
692
+ if self.cfg.cam_embd.embedding is not None:
693
+ cam_embedding = self.get_camera_embeddings(source_cameras_view_to_world)
694
+ assert self.cfg.cam_embd.method == "film"
695
+ film_camera_emb = cam_embedding.reshape(B*N_views, cam_embedding.shape[2])
696
+ else:
697
+ film_camera_emb = None
698
+
699
+ if self.cfg.data.category in ["hydrants", "teddybears"]:
700
+ assert focals_pixels is not None
701
+ focals_pixels = focals_pixels.reshape(B*N_views, *focals_pixels.shape[2:])
702
+ else:
703
+ assert focals_pixels is None, "Unexpected argument for non-co3d dataset"
704
+
705
+ x = x.reshape(B*N_views, *x.shape[2:])
706
+ if self.cfg.data.origin_distances:
707
+ const_offset = x[:, 3:, ...]
708
+ x = x[:, :3, ...]
709
+ else:
710
+ const_offset = None
711
+
712
+ source_cameras_view_to_world = source_cameras_view_to_world.reshape(B*N_views, *source_cameras_view_to_world.shape[2:])
713
+ x = x.contiguous(memory_format=torch.channels_last)
714
+
715
+ if self.cfg.model.network_with_offset:
716
+
717
+ split_network_outputs = self.network_with_offset(x,
718
+ film_camera_emb=film_camera_emb,
719
+ N_views_xa=N_views_xa
720
+ )
721
+
722
+ split_network_outputs = split_network_outputs.split(self.split_dimensions_with_offset, dim=1)
723
+ depth, offset, opacity, scaling, rotation, features_dc = split_network_outputs[:6]
724
+ if self.cfg.model.max_sh_degree > 0:
725
+ features_rest = split_network_outputs[6]
726
+
727
+ pos = self.get_pos_from_network_output(depth, offset, focals_pixels, const_offset=const_offset)
728
+
729
+ else:
730
+ split_network_outputs = self.network_wo_offset(x,
731
+ film_camera_emb=film_camera_emb,
732
+ N_views_xa=N_views_xa
733
+ ).split(self.split_dimensions_without_offset, dim=1)
734
+
735
+ depth, opacity, scaling, rotation, features_dc = split_network_outputs[:5]
736
+ if self.cfg.model.max_sh_degree > 0:
737
+ features_rest = split_network_outputs[5]
738
+
739
+ pos = self.get_pos_from_network_output(depth, 0.0, focals_pixels, const_offset=const_offset)
740
+
741
+ if self.cfg.model.isotropic:
742
+ scaling_out = torch.cat([scaling[:, :1, ...], scaling[:, :1, ...], scaling[:, :1, ...]], dim=1)
743
+ else:
744
+ scaling_out = scaling
745
+
746
+ # Pos prediction is in camera space - compute the positions in the world space
747
+ pos = self.flatten_vector(pos)
748
+ pos = torch.cat([pos,
749
+ torch.ones((pos.shape[0], pos.shape[1], 1), device="cuda", dtype=torch.float32)
750
+ ], dim=2)
751
+ pos = torch.bmm(pos, source_cameras_view_to_world)
752
+ pos = pos[:, :, :3] / (pos[:, :, 3:] + 1e-10)
753
+
754
+ out_dict = {
755
+ "xyz": pos,
756
+ "rotation": self.flatten_vector(self.rotation_activation(rotation)),
757
+ "features_dc": self.flatten_vector(features_dc).unsqueeze(2)
758
+ }
759
+
760
+ if activate_output:
761
+ out_dict["opacity"] = self.flatten_vector(self.opacity_activation(opacity))
762
+ out_dict["scaling"] = self.flatten_vector(self.scaling_activation(scaling_out))
763
+ else:
764
+ out_dict["opacity"] = self.flatten_vector(opacity)
765
+ out_dict["scaling"] = self.flatten_vector(scaling_out)
766
+
767
+ assert source_cv2wT_quat is not None
768
+ source_cv2wT_quat = source_cv2wT_quat.reshape(B*N_views, *source_cv2wT_quat.shape[2:])
769
+ out_dict["rotation"] = self.transform_rotations(out_dict["rotation"],
770
+ source_cv2wT_quat=source_cv2wT_quat)
771
+
772
+ if self.cfg.model.max_sh_degree > 0:
773
+ features_rest = self.flatten_vector(features_rest)
774
+ # Channel dimension holds SH_num * RGB(3) -> renderer expects split across RGB
775
+ # Split channel dimension B x N x C -> B x N x SH_num x 3
776
+ out_dict["features_rest"] = features_rest.reshape(*features_rest.shape[:2], -1, 3)
777
+ assert self.cfg.model.max_sh_degree == 1, "Only accepting degree 1"
778
+ out_dict["features_rest"] = self.transform_SHs(out_dict["features_rest"],
779
+ source_cameras_view_to_world)
780
+ else:
781
+ out_dict["features_rest"] = torch.zeros((out_dict["features_dc"].shape[0],
782
+ out_dict["features_dc"].shape[1],
783
+ (self.cfg.model.max_sh_degree + 1) ** 2 - 1,
784
+ 3), dtype=out_dict["features_dc"].dtype, device="cuda")
785
+
786
+ out_dict = self.multi_view_union(out_dict, B, N_views)
787
+ out_dict = self.make_contiguous(out_dict)
788
+
789
+ return out_dict
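A rough sketch of the intended single-image inference flow, tying GaussianSplatPredictor to the renderer and the helpers in utils/app_utils.py (illustrative, not part of the diff). cfg is assumed to be the Hydra config used for training, and image a preprocessed (1, 1, 3, H, W) source view already on the GPU.

import torch
from scene.gaussian_predictor import GaussianSplatPredictor
from utils.app_utils import get_source_camera_v2w_rmo_and_quats

model = GaussianSplatPredictor(cfg).cuda().eval()
view_to_world, cv2wT_quat = get_source_camera_v2w_rmo_and_quats()
with torch.no_grad():
    gaussians = model(image,
                      view_to_world.cuda().float(),
                      source_cv2wT_quat=cv2wT_quat.cuda().float())
# gaussians holds "xyz", "opacity", "scaling", "rotation", "features_dc" and
# "features_rest", each with a leading batch dimension, ready for render_predicted.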
utils/app_utils.py ADDED
@@ -0,0 +1,175 @@
1
+ from PIL import Image
2
+ from typing import Any
3
+ import rembg
4
+ import numpy as np
5
+ from torchvision import transforms
6
+ from plyfile import PlyData, PlyElement
7
+ import os
8
+ import torch
9
+ from .camera_utils import get_loop_cameras
10
+ from .graphics_utils import getProjectionMatrix
11
+ from .general_utils import matrix_to_quaternion
12
+
13
+ def remove_background(image, rembg_session):
14
+ do_remove = True
15
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
16
+ do_remove = False
17
+ if do_remove:
18
+ image = rembg.remove(image, session=rembg_session)
19
+ return image
20
+
21
+ def set_white_background(image):
22
+ image = np.array(image).astype(np.float32) / 255.0
23
+ mask = image[:, :, 3:4]
24
+ image = image[:, :, :3] * mask + (1 - mask)
25
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
26
+ return image
27
+
28
+ def resize_foreground(image, ratio):
29
+ image = np.array(image)
30
+ assert image.shape[-1] == 4
31
+ alpha = np.where(image[..., 3] > 0)
32
+ # modify so that cropping doesn't change the world center
33
+ y1, y2, x1, x2 = (
34
+ alpha[0].min(),
35
+ alpha[0].max(),
36
+ alpha[1].min(),
37
+ alpha[1].max(),
38
+ )
39
+
40
+ # crop the foreground
41
+ fg = image[y1: y2,
42
+ x1: x2]
43
+ # pad to square
44
+ size = max(fg.shape[0], fg.shape[1])
45
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
46
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
47
+ new_image = np.pad(
48
+ fg,
49
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
50
+ mode="constant",
51
+ constant_values=((255, 255), (255, 255), (0, 0)),
52
+ )
53
+
54
+ # compute padding according to the ratio
55
+ new_size = int(new_image.shape[0] / ratio)
56
+ # pad to size, double side
57
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
58
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
59
+ new_image = np.pad(
60
+ new_image,
61
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
62
+ mode="constant",
63
+ constant_values=((255, 255), (255, 255), (0, 0)),
64
+ )
65
+
66
+ new_image = Image.fromarray(new_image)
67
+
68
+ return new_image
69
+
70
+ def resize_to_128(img):
71
+ img = transforms.functional.resize(img, 128,
72
+ interpolation=transforms.InterpolationMode.LANCZOS)
73
+ return img
74
+
75
+ def to_tensor(img):
76
+ img = torch.tensor(img).permute(2, 0, 1) / 255.0
77
+ return img
78
+
79
+ def get_source_camera_v2w_rmo_and_quats(num_imgs_in_loop=200):
80
+ source_camera = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop)[0]
81
+ source_camera = torch.from_numpy(source_camera).transpose(0, 1).unsqueeze(0)
82
+
83
+ qs = []
84
+ for c_idx in range(source_camera.shape[0]):
85
+ qs.append(matrix_to_quaternion(source_camera[c_idx, :3, :3].transpose(0, 1)))
86
+
87
+ return source_camera.unsqueeze(0), torch.stack(qs, dim=0).unsqueeze(0)
88
+
89
+ def get_target_cameras(num_imgs_in_loop=200):
90
+ """
91
+ Returns camera parameters for rendering a loop around the object:
92
+ world_to_view_transforms,
93
+ full_proj_transforms,
94
+ camera_centers
95
+ """
96
+
97
+ projection_matrix = getProjectionMatrix(
98
+ znear=0.8, zfar=3.2,
99
+ fovX=49.134342641202636 * 2 * np.pi / 360,
100
+ fovY=49.134342641202636 * 2 * np.pi / 360).transpose(0,1)
101
+
102
+ target_cameras = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop,
103
+ max_elevation=np.pi/4,
104
+ elevation_freq=1.5)
105
+ world_view_transforms = []
106
+ view_world_transforms = []
107
+ camera_centers = []
108
+
109
+ for loop_camera_c2w_cmo in target_cameras:
110
+ view_world_transform = torch.from_numpy(loop_camera_c2w_cmo).transpose(0, 1)
111
+ world_view_transform = torch.from_numpy(loop_camera_c2w_cmo).inverse().transpose(0, 1)
112
+ camera_center = view_world_transform[3, :3].clone()
113
+
114
+ world_view_transforms.append(world_view_transform)
115
+ view_world_transforms.append(view_world_transform)
116
+ camera_centers.append(camera_center)
117
+
118
+ world_view_transforms = torch.stack(world_view_transforms)
119
+ view_world_transforms = torch.stack(view_world_transforms)
120
+ camera_centers = torch.stack(camera_centers)
121
+
122
+ full_proj_transforms = world_view_transforms.bmm(projection_matrix.unsqueeze(0).expand(
123
+ world_view_transforms.shape[0], 4, 4))
124
+
125
+ return world_view_transforms, full_proj_transforms, camera_centers
126
+
127
+ def construct_list_of_attributes():
128
+ # taken from gaussian splatting repo.
129
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
130
+ # All channels except the 3 DC
131
+ # 3 channels for DC
132
+ for i in range(3):
133
+ l.append('f_dc_{}'.format(i))
134
+ # 9 channels for SH order 1
135
+ for i in range(9):
136
+ l.append('f_rest_{}'.format(i))
137
+ l.append('opacity')
138
+ for i in range(3):
139
+ l.append('scale_{}'.format(i))
140
+ for i in range(4):
141
+ l.append('rot_{}'.format(i))
142
+ return l
143
+
144
+ def export_to_obj(reconstruction, ply_out_path):
145
+ """
146
+ Args:
147
+ reconstruction: dict with xyz, opacity, features_dc, etc., each with a leading batch dimension
148
+ ply_out_path: file path where to save the output
149
+ """
150
+ os.makedirs(os.path.dirname(ply_out_path), exist_ok=True)
151
+
152
+ for k, v in reconstruction.items():
153
+ # check dimensions
154
+ if k not in ["features_dc", "features_rest"]:
155
+ assert len(v.shape) == 3, "Unexpected size for {}".format(k)
156
+ else:
157
+ assert len(v.shape) == 4, "Unexpected size for {}".format(k)
158
+ assert v.shape[0] == 1, "Expected batch size to be 1"
159
+ reconstruction[k] = v[0]
160
+
161
+ xyz = reconstruction["xyz"].detach().cpu().numpy()
162
+ normals = np.zeros_like(xyz)
163
+ f_dc = reconstruction["features_dc"].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
164
+ f_rest = reconstruction["features_rest"].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
165
+ opacities = reconstruction["opacity"].detach().cpu().numpy()
166
+ scale = reconstruction["scaling"].detach().cpu().numpy()
167
+ rotation = reconstruction["rotation"].detach().cpu().numpy()
168
+
169
+ dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()]
170
+
171
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
172
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
173
+ elements[:] = list(map(tuple, attributes))
174
+ el = PlyElement.describe(elements, 'vertex')
175
+ PlyData([el]).write(ply_out_path)
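A sketch of how these helpers might be chained to prepare a user photo for the predictor (illustrative, not part of the diff); the input file name is hypothetical and the foreground ratio is an arbitrary choice for illustration.

import numpy as np
import rembg
from PIL import Image
from utils.app_utils import (remove_background, resize_foreground,
                             set_white_background, resize_to_128, to_tensor)

session = rembg.new_session()
img = Image.open("input.png").convert("RGBA")
img = remove_background(img, session)        # segment out the original background
img = resize_foreground(img, ratio=0.65)     # crop to the object and pad to a square
img = set_white_background(img)              # composite the RGBA cut-out onto white
img = resize_to_128(img)                     # match the training resolution
x = to_tensor(np.array(img)).unsqueeze(0).unsqueeze(0).float()  # (1, 1, 3, 128, 128)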
utils/camera_utils.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+ def get_loop_cameras(num_imgs_in_loop, radius=2.0,
4
+ max_elevation=np.pi/6, elevation_freq=0.5,
5
+ azimuth_freq=2.0):
6
+
7
+ all_cameras_c2w_cmo = []
8
+
9
+ for i in range(num_imgs_in_loop):
10
+ azimuth_angle = np.pi * 2 * azimuth_freq * i / num_imgs_in_loop
11
+ elevation_angle = max_elevation * np.sin(
12
+ np.pi * i * 2 * elevation_freq / num_imgs_in_loop)
13
+ x = np.cos(azimuth_angle) * radius * np.cos(elevation_angle)
14
+ y = np.sin(azimuth_angle) * radius * np.cos(elevation_angle)
15
+ z = np.sin(elevation_angle) * radius
16
+
17
+ camera_T_c2w = np.array([x, y, z], dtype=np.float32)
18
+
19
+ # in COLMAP / OpenCV convention: z away from camera, y down, x right
20
+ camera_z = - camera_T_c2w / radius
21
+ up = np.array([0, 0, -1], dtype=np.float32)
22
+ camera_x = np.cross(up, camera_z)
23
+ camera_x = camera_x / np.linalg.norm(camera_x)
24
+ camera_y = np.cross(camera_z, camera_x)
25
+
26
+ camera_c2w_cmo = np.hstack([camera_x[:, None],
27
+ camera_y[:, None],
28
+ camera_z[:, None],
29
+ camera_T_c2w[:, None]])
30
+ camera_c2w_cmo = np.vstack([camera_c2w_cmo, np.array([0, 0, 0, 1], dtype=np.float32)[None, :]])
31
+
32
+ all_cameras_c2w_cmo.append(camera_c2w_cmo)
33
+
34
+ return all_cameras_c2w_cmo
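A small sanity check for the loop cameras, assuming only the function above (illustrative, not part of the diff): each returned 4x4 camera-to-world matrix should have an orthonormal rotation block and place the camera at the requested radius.

import numpy as np
from utils.camera_utils import get_loop_cameras

cams = get_loop_cameras(num_imgs_in_loop=8, radius=2.0)
for c2w in cams:
    R, t = c2w[:3, :3], c2w[:3, 3]
    assert np.allclose(R.T @ R, np.eye(3), atol=1e-5)     # columns are orthonormal camera axes
    assert np.isclose(np.linalg.norm(t), 2.0, atol=1e-5)  # camera centre sits on the loop radius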
utils/general_utils.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+
3
+ def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
4
+ """
5
+ From Pytorch3d
6
+ Multiply two quaternions.
7
+ Usual torch rules for broadcasting apply.
8
+
9
+ Args:
10
+ a: Quaternions as tensor of shape (..., 4), real part first.
11
+ b: Quaternions as tensor of shape (..., 4), real part first.
12
+
13
+ Returns:
14
+ The product of a and b, a tensor of quaternions shape (..., 4).
15
+ """
16
+ aw, ax, ay, az = torch.unbind(a, -1)
17
+ bw, bx, by, bz = torch.unbind(b, -1)
18
+ ow = aw * bw - ax * bx - ay * by - az * bz
19
+ ox = aw * bx + ax * bw + ay * bz - az * by
20
+ oy = aw * by - ax * bz + ay * bw + az * bx
21
+ oz = aw * bz + ax * by - ay * bx + az * bw
22
+ return torch.stack((ow, ox, oy, oz), -1)
23
+
24
+ # Written by Stan Szymanowicz 2023
25
+ def matrix_to_quaternion(M: torch.Tensor) -> torch.Tensor:
26
+ """
27
+ Matrix-to-quaternion conversion method. Equation taken from
28
+ https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/index.htm
29
+ Args:
30
+ M: rotation matrices, (3 x 3)
31
+ Returns:
32
+ q: quaternion of shape (4)
33
+ """
34
+ tr = 1 + M[ 0, 0] + M[ 1, 1] + M[ 2, 2]
35
+
36
+ if tr > 0:
37
+ r = torch.sqrt(tr) / 2.0
38
+ x = ( M[ 2, 1] - M[ 1, 2] ) / ( 4 * r )
39
+ y = ( M[ 0, 2] - M[ 2, 0] ) / ( 4 * r )
40
+ z = ( M[ 1, 0] - M[ 0, 1] ) / ( 4 * r )
41
+ elif ( M[ 0, 0] > M[ 1, 1]) and (M[ 0, 0] > M[ 2, 2]):
42
+ S = torch.sqrt(1.0 + M[ 0, 0] - M[ 1, 1] - M[ 2, 2]) * 2 # S=4*qx
43
+ r = (M[ 2, 1] - M[ 1, 2]) / S
44
+ x = 0.25 * S
45
+ y = (M[ 0, 1] + M[ 1, 0]) / S
46
+ z = (M[ 0, 2] + M[ 2, 0]) / S
47
+ elif M[ 1, 1] > M[ 2, 2]:
48
+ S = torch.sqrt(1.0 + M[ 1, 1] - M[ 0, 0] - M[ 2, 2]) * 2 # S=4*qy
49
+ r = (M[ 0, 2] - M[ 2, 0]) / S
50
+ x = (M[ 0, 1] + M[ 1, 0]) / S
51
+ y = 0.25 * S
52
+ z = (M[ 1, 2] + M[ 2, 1]) / S
53
+ else:
54
+ S = torch.sqrt(1.0 + M[ 2, 2] - M[ 0, 0] - M[ 1, 1]) * 2 # S=4*qz
55
+ r = (M[ 1, 0] - M[ 0, 1]) / S
56
+ x = (M[ 0, 2] + M[ 2, 0]) / S
57
+ y = (M[ 1, 2] + M[ 2, 1]) / S
58
+ z = 0.25 * S
59
+
60
+ return torch.stack([r, x, y, z], dim=-1)
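A quick consistency check for the two helpers above (illustrative, not part of the diff): a 90-degree rotation about z should map to the quaternion (cos 45°, 0, 0, sin 45°) in (w, x, y, z) order, and multiplying it by the identity quaternion should leave it unchanged.

import math
import torch
from utils.general_utils import matrix_to_quaternion, quaternion_raw_multiply

Rz90 = torch.tensor([[0., -1., 0.],
                     [1.,  0., 0.],
                     [0.,  0., 1.]])
q = matrix_to_quaternion(Rz90)
expected = torch.tensor([math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)])
assert torch.allclose(q, expected, atol=1e-6)

identity = torch.tensor([1.0, 0.0, 0.0, 0.0])  # real part first
assert torch.allclose(quaternion_raw_multiply(identity, q), q)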
utils/graphics_utils.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import math
3
+
4
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
5
+ tanHalfFovY = math.tan((fovY / 2))
6
+ tanHalfFovX = math.tan((fovX / 2))
7
+
8
+ top = tanHalfFovY * znear
9
+ bottom = -top
10
+ right = tanHalfFovX * znear
11
+ left = -right
12
+
13
+ P = torch.zeros(4, 4)
14
+
15
+ z_sign = 1.0
16
+
17
+ P[0, 0] = 2.0 * znear / (right - left)
18
+ P[1, 1] = 2.0 * znear / (top - bottom)
19
+ P[0, 2] = (right + left) / (right - left)
20
+ P[1, 2] = (top + bottom) / (top - bottom)
21
+ P[3, 2] = z_sign
22
+ P[2, 2] = z_sign * zfar / (zfar - znear)
23
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
24
+ return P
25
+
26
+ def fov2focal(fov, pixels):
27
+ return pixels / (2 * math.tan(fov / 2))
28
+
29
+ def focal2fov(focal, pixels):
30
+ return 2*math.atan(pixels/(2*focal))
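Two quick checks for the helpers above (illustrative, not part of the diff): fov2focal and focal2fov should be inverses, and getProjectionMatrix should map a point at znear to depth 0 and a point at zfar to depth 1 after the perspective divide (P[3, 2] = 1 puts the view-space z into the w component).

import math
import torch
from utils.graphics_utils import fov2focal, focal2fov, getProjectionMatrix

fov = math.radians(49.1)
assert math.isclose(focal2fov(fov2focal(fov, 128), 128), fov, rel_tol=1e-9)

P = getProjectionMatrix(znear=0.8, zfar=3.2, fovX=fov, fovY=fov)
for z, expected in [(0.8, 0.0), (3.2, 1.0)]:
    clip = P @ torch.tensor([0.0, 0.0, z, 1.0])
    assert math.isclose((clip[2] / clip[3]).item(), expected, abs_tol=1e-6)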