File size: 7,562 Bytes
f12ab4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from torch_utils import persistence
from training.networks_stylegan2 import Generator as StyleGAN2Backbone
from training.volumetric_rendering.renderer import ImportanceRenderer
from training.volumetric_rendering.ray_sampler import RaySampler
import dnnlib
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
sr_num_fp16_res = 0,
mapping_kwargs = {}, # Arguments for MappingNetwork.
rendering_kwargs = {},
sr_kwargs = {},
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim=z_dim
self.c_dim=c_dim
self.w_dim=w_dim
self.img_resolution=img_resolution
self.img_channels=img_channels
self.renderer = ImportanceRenderer()
self.ray_sampler = RaySampler()
self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': 32})
self.neural_rendering_resolution = 64
self.rendering_kwargs = rendering_kwargs
self._last_planes = None
def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
cam2world_matrix = c[:, :16].view(-1, 4, 4)
intrinsics = c[:, 16:25].view(-1, 3, 3)
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
else:
self.neural_rendering_resolution = neural_rendering_resolution
# Create a batch of rays for volume rendering
ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
# Create triplanes by running StyleGAN backbone
N, M, _ = ray_origins.shape
if use_cached_backbone and self._last_planes is not None:
planes = self._last_planes
else:
planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
if cache_backbone:
self._last_planes = planes
# Reshape output into three 32-channel planes
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
# Perform volume rendering
feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
# Reshape into 'raw' neural-rendered image
H = W = self.neural_rendering_resolution
feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# Run superresolution to get final image
rgb_image = feature_image[:, :3]
sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
# Render a batch of generated images.
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
return self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
from training.networks_stylegan2 import FullyConnectedLayer
class OSGDecoder(torch.nn.Module):
def __init__(self, n_features, options):
super().__init__()
self.hidden_dim = 64
self.net = torch.nn.Sequential(
FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
torch.nn.Softplus(),
FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])
)
def forward(self, sampled_features, ray_directions):
# Aggregate features
sampled_features = sampled_features.mean(1)
x = sampled_features
N, M, C = x.shape
x = x.view(N*M, C)
x = self.net(x)
x = x.view(N, M, -1)
rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
sigma = x[..., 0:1]
return {'rgb': rgb, 'sigma': sigma}
|