# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
""" Volumetric autoencoder (image -> encoding -> volume -> image) """
import inspect
import time
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.utils
from extensions.utils.utils import compute_raydirs
@torch.jit.script
def compute_raydirs_ref(pixelcoords : torch.Tensor, viewrot : torch.Tensor, focal : torch.Tensor, princpt : torch.Tensor):
raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
raydir = F.normalize(raydir, dim=-1)
return raydir
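# Illustrative helper (added for exposition, not part of the original interface):
# compute_raydirs_ref expects per-pixel (x, y) coordinates of shape [B, H, W, 2].
# The `pixelcoords` provided by the dataloaders is typically such a dense grid,
# or a random subset of it when training on subsampled rays.
def make_dense_pixelcoords(batchsize : int, height : int, width : int) -> torch.Tensor:
    """Build a dense [B, H, W, 2] grid of (x, y) pixel coordinates."""
    py, px = torch.meshgrid(
        torch.arange(height, dtype=torch.float32),
        torch.arange(width, dtype=torch.float32),
        indexing="ij")
    return torch.stack([px, py], dim=-1)[None].repeat(batchsize, 1, 1, 1)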
@torch.jit.script
def compute_rmbounds(viewpos : torch.Tensor, raydir : torch.Tensor, volradius : float):
viewpos = viewpos / volradius
    # compute raymarching start/end points by intersecting each normalized ray
    # with the [-1, 1]^3 volume (slab method)
with torch.no_grad():
t1 = (-1. - viewpos[:, None, None, :]) / raydir
t2 = ( 1. - viewpos[:, None, None, :]) / raydir
tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
torch.max(torch.min(t1[..., 1], t2[..., 1]),
torch.min(t1[..., 2], t2[..., 2])))
tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
torch.min(torch.max(t1[..., 1], t2[..., 1]),
torch.max(t1[..., 2], t2[..., 2])))
intersections = tmin < tmax
        tmin = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.)
        tmax = torch.where(intersections, tmax, torch.zeros_like(tmin))
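    # ray origins are the normalized view position broadcast to one ray per
    # pixel; the raymarcher walks each ray from tminmax[..., 0] to tminmax[..., 1]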
raypos = viewpos[:, None, None, :] + raydir * 0.
tminmax = torch.stack([tmin, tmax], dim=-1)
return raypos, tminmax
class Autoencoder(nn.Module):
def __init__(self, dataset, encoder, decoder, raymarcher, colorcal,
volradius, bgmodel=None, encoderinputs=[], topology=None,
imagemean=0., imagestd=1., vertmask=None, cudaraydirs=True):
super(Autoencoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.raymarcher = raymarcher
self.colorcal = colorcal
self.volradius = volradius
self.bgmodel = bgmodel
self.encoderinputs = encoderinputs
if hasattr(dataset, 'vertmean'):
self.register_buffer("vertmean", torch.from_numpy(dataset.vertmean), persistent=False)
self.vertstd = dataset.vertstd
if hasattr(dataset, 'texmean'):
self.register_buffer("texmean", torch.from_numpy(dataset.texmean), persistent=False)
self.texstd = dataset.texstd
self.imagemean = imagemean
self.imagestd = imagestd
self.cudaraydirs = cudaraydirs
if vertmask is not None:
self.register_buffer("vertmask", torch.from_numpy(vertmask), persistent=False)
self.irgbmsestart = -1
def forward(self,
camrot : torch.Tensor,
campos : torch.Tensor,
focal : torch.Tensor,
princpt : torch.Tensor,
camindex : Optional[torch.Tensor] = None,
pixelcoords : Optional[torch.Tensor]=None,
modelmatrix : Optional[torch.Tensor]=None,
modelmatrixinv : Optional[torch.Tensor]=None,
modelmatrix_next : Optional[torch.Tensor]=None,
modelmatrixinv_next : Optional[torch.Tensor]=None,
validinput : Optional[torch.Tensor]=None,
avgtex : Optional[torch.Tensor]=None,
avgtex_next : Optional[torch.Tensor]=None,
verts : Optional[torch.Tensor]=None,
verts_next : Optional[torch.Tensor]=None,
fixedcamimage : Optional[torch.Tensor]=None,
encoding : Optional[torch.Tensor]=None,
image : Optional[torch.Tensor]=None,
imagemask : Optional[torch.Tensor]=None,
imagevalid : Optional[torch.Tensor]=None,
            bg : Optional[torch.Tensor]=None,
            tex : Optional[torch.Tensor]=None,
            texmask : Optional[torch.Tensor]=None,
renderoptions : dict ={},
trainiter : int=-1,
evaliter : Optional[torch.Tensor]=None,
outputlist : list=[],
losslist : list=[],
**kwargs):
"""
Parameters
----------
camrot : torch.Tensor [B, 3, 3]
Rotation matrix of target view camera
campos : torch.Tensor [B, 3]
Position of target view camera
focal : torch.Tensor [B, 2]
Focal length of target view camera
princpt : torch.Tensor [B, 2]
            Principal point of target view camera
camindex : torch.Tensor[int32], optional [B]
Camera index within the list of all cameras
pixelcoords : torch.Tensor, optional [B, H', W', 2]
            Pixel coordinates of the target view to render
        modelmatrix : torch.Tensor, optional [B, 4, 4]
            Relative transform from the 'neutral' pose of the object
validinput : torch.Tensor, optional [B]
Whether the current batch element is valid (used for missing images)
avgtex : torch.Tensor, optional [B, 3, 1024, 1024]
Texture map averaged from all viewpoints
verts : torch.Tensor, optional [B, 7306, 3]
Mesh vertex positions
fixedcamimage : torch.Tensor, optional [B, 3, 512, 334]
            Camera images from one or more fixed cameras that are always the
            same (i.e., unrelated to the target view)
encoding : torch.Tensor, optional [B, 256]
Direct encodings (overrides encoder)
image : torch.Tensor, optional [B, 3, H, W]
Target image
imagemask : torch.Tensor, optional [B, 1, H, W]
Target image mask for computing reconstruction loss
        imagevalid : torch.Tensor, optional [B]
            Per-element weight/validity for the image reconstruction loss
        bg : torch.Tensor, optional [B, 3, H, W]
            Background image composited behind the rendered volume
        tex : torch.Tensor, optional [B, 3, 1024, 1024]
            Target texture map for the texture reconstruction loss
        texmask : torch.Tensor, optional [B, 1024, 1024]
            Mask of valid texels for the texture reconstruction loss
renderoptions : dict
Rendering/raymarching options (e.g., stepsize, whether to output debug images, etc.)
trainiter : int
Training iteration number
outputlist : list
Values to return (e.g., image reconstruction, debug output)
losslist : list
Losses to output (e.g., image reconstruction loss, priors)
Returns
-------
result : dict
Contains outputs specified in outputlist (e.g., image rgb
reconstruction "irgbrec")
losses : dict
Losses to optimize
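        Example
        -------
        A hypothetical call, assuming a constructed model and inputs with the
        shapes documented above::

            out, losses = model(camrot, campos, focal, princpt,
                pixelcoords=pixelcoords, validinput=validinput,
                verts=verts, avgtex=avgtex, image=image,
                outputlist=["irgbrec"], losslist=["irgbmse"])
            irgbrec = out["irgbrec"]              # [B, 3, H', W'] rendering
            irgbmse, weight = losses["irgbmse"]   # per-element loss sum and weight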
"""
resultout = {}
resultlosses = {}
aestart = time.time()
# encode/get encoding
# verts [6, 7306, 3]
# avgtex [6, 3, 256, 256]
if encoding is None:
if "enctime" in outputlist:
torch.cuda.synchronize()
encstart = time.time()
encout, enclosses = self.encoder(
*[dict(verts=verts, avgtex=avgtex, fixedcamimage=fixedcamimage)[k] for k in self.encoderinputs],
losslist=losslist)
if "enctime" in outputlist:
torch.cuda.synchronize()
encend = time.time()
resultout["enctime"] = encend - encstart
# encoding [6, 256]
encoding = encout["encoding"]
resultlosses.update(enclosses)
# compute relative viewing position
if modelmatrixinv is not None:
viewrot = torch.bmm(camrot, modelmatrixinv[:, :3, :3])
viewpos = torch.bmm((campos[:, :] - modelmatrixinv[:, :3, 3])[:, None, :], modelmatrixinv[:, :3, :3])[:, 0, :]
else:
viewrot = camrot
viewpos = campos
# decode volumetric representation
if "dectime" in outputlist:
torch.cuda.synchronize()
decstart = time.time()
if isinstance(self.decoder, torch.jit.ScriptModule):
# torchscript requires statically typed dict
renderoptionstyped : Dict[str, str] = {k: str(v) for k, v in renderoptions.items()}
else:
renderoptionstyped = renderoptions
decout, declosses = self.decoder(
encoding,
viewpos,
renderoptions=renderoptionstyped,
trainiter=trainiter,
evaliter=evaliter,
losslist=losslist)
if "dectime" in outputlist:
torch.cuda.synchronize()
decend = time.time()
resultout["dectime"] = decend - decstart
resultlosses.update(declosses)
# compute vertex loss
if "vertmse" in losslist:
weight = validinput[:, None, None].expand_as(verts)
if hasattr(self, "vertmask"):
weight = weight * self.vertmask[None, :, None]
vertsrecstd = (decout["verts"] - self.vertmean) / self.vertstd
vertsqerr = weight * (verts - vertsrecstd) ** 2
vertmse = torch.sum(vertsqerr.view(vertsqerr.size(0), -1), dim=-1)
vertmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)
resultlosses["vertmse"] = (vertmse, vertmse_weight)
# compute texture loss
if "trgbmse" in losslist or "trgbsqerr" in outputlist:
            weight = (validinput[:, None, None, None] * texmask[:, None, :, :].float()).expand_as(tex).contiguous()
            # re-standardize (texmean/texstd buffers live on the module's device)
            texrecstd = (decout["tex"] - self.texmean) / self.texstd
            texstd = (tex - self.texmean) / self.texstd
texsqerr = weight * (texstd - texrecstd) ** 2
if "trgbsqerr" in outputlist:
resultout["trgbsqerr"] = texsqerr
# texture rgb mean-squared-error
if "trgbmse" in losslist:
texmse = torch.sum(texsqerr.view(texsqerr.size(0), -1), dim=-1)
texmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)
resultlosses["trgbmse"] = (texmse, texmse_weight)
# subsample depth, imagerec, imagerecmask
if image is not None and pixelcoords.size()[1:3] != image.size()[2:4]:
imagesize = torch.tensor(image.size()[3:1:-1], dtype=torch.float32, device=pixelcoords.device)
else:
imagesize = torch.tensor(pixelcoords.size()[2:0:-1], dtype=torch.float32, device=pixelcoords.device)
samplecoords = pixelcoords * 2. / (imagesize[None, None, None, :] - 1.) - 1.
# compute ray directions
if self.cudaraydirs:
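            # the CUDA extension returns ray origins, directions, and entry/exit
            # t values in a single call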
raypos, raydir, tminmax = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, self.volradius)
else:
raydir = compute_raydirs_ref(pixelcoords, viewrot, focal, princpt)
raypos, tminmax = compute_rmbounds(viewpos, raydir, self.volradius)
if "dtstd" in renderoptions:
renderoptions["dt"] = renderoptions["dt"] * \
torch.exp(torch.randn(1) * renderoptions.get("dtstd")).item()
if renderoptions.get("unbiastminmax", False):
stepsize = renderoptions["dt"] / self.volradius
tminmax = torch.floor(tminmax / stepsize) * stepsize
if renderoptions.get("tminmaxblocks", False):
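            # pool the t bounds over (by, bx) tiles of rays so that every ray
            # in a tile shares one raymarching interval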
bx, by = renderoptions.get("blocksize", (8, 16))
H, W = tminmax.size(1), tminmax.size(2)
tminmax = tminmax.view(tminmax.size(0), H // by, by, W // bx, bx, 2)
tminmax = tminmax.amin(dim=[2, 4], keepdim=True)
tminmax = tminmax.repeat(1, 1, by, 1, bx, 1)
tminmax = tminmax.view(tminmax.size(0), H, W, 2)
# raymarch
if "rmtime" in outputlist:
torch.cuda.synchronize()
rmstart = time.time()
# rayrgba [6, 4, 384, 384]
rayrgba, rmlosses = self.raymarcher(raypos, raydir, tminmax,
decout=decout, renderoptions=renderoptions,
trainiter=trainiter, evaliter=evaliter, losslist=losslist)
resultlosses.update(rmlosses)
if "rmtime" in outputlist:
torch.cuda.synchronize()
rmend = time.time()
resultout["rmtime"] = rmend - rmstart
if isinstance(rayrgba, tuple):
rayrgb, rayalpha = rayrgba
else:
rayrgb, rayalpha = rayrgba[:, :3, :, :].contiguous(), rayrgba[:, 3:4, :, :].contiguous()
# beta distribution prior on final opacity
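        # (the constant 2.20727 = -(log(0.1) + log(1.1)) shifts the prior so it
        # is exactly zero at alpha = 0 or alpha = 1 and positive in between)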
if "alphapr" in losslist:
            alphaprior = torch.mean(
                torch.log(0.1 + rayalpha.view(rayalpha.size(0), -1)) +
                torch.log(0.1 + 1. - rayalpha.view(rayalpha.size(0), -1)) + 2.20727, dim=-1)
resultlosses["alphapr"] = alphaprior
# color correction
if camindex is not None and not renderoptions.get("nocolcorrect", False):
rayrgb = self.colorcal(rayrgb, camindex)
# background decoder
if self.bgmodel is not None and not renderoptions.get("nobg", False):
if "bgtime" in outputlist:
torch.cuda.synchronize()
bgstart = time.time()
raypos, raydir, tminmax = compute_raydirs(campos, camrot, focal, princpt, pixelcoords, self.volradius)
rayposbeg = raypos + raydir * tminmax[..., 0:1]
rayposend = raypos + raydir * tminmax[..., 1:2]
bg = self.bgmodel(bg, camindex, campos, rayposend, raydir, samplecoords, trainiter=trainiter)
# alpha matting
if bg is not None:
rayrgb = rayrgb + (1. - rayalpha) * bg
if "bg" in outputlist:
resultout["bg"] = bg
if "bgtime" in outputlist:
torch.cuda.synchronize()
bgend = time.time()
resultout["bgtime"] = bgend - bgstart
if "irgbrec" in outputlist:
resultout["irgbrec"] = rayrgb
if "irgbarec" in outputlist:
resultout["irgbarec"] = torch.cat([rayrgb, rayalpha], dim=1)
if "irgbflip" in outputlist:
resultout["irgbflip"] = torch.cat([rayrgb[i:i+1] if i % 4 < 2 else image[i:i+1]
for i in range(image.size(0))], dim=0)
# image rgb loss
if image is not None and trainiter > self.irgbmsestart:
# subsample image
if pixelcoords.size()[1:3] != image.size()[2:4]:
image = F.grid_sample(image, samplecoords, align_corners=True)
if imagemask is not None:
imagemask = F.grid_sample(imagemask, samplecoords, align_corners=True)
# compute reconstruction loss weighting
weight = torch.ones_like(image) * validinput[:, None, None, None]
if imagevalid is not None:
weight = weight * imagevalid[:, None, None, None]
if imagemask is not None:
weight = weight * imagemask
if "irgbsqerr" in outputlist:
irgbsqerr_nonorm = (weight * (image - rayrgb) ** 2).contiguous()
resultout["irgbsqerr"] = torch.sqrt(irgbsqerr_nonorm.mean(dim=1, keepdim=True))
# standardize
rayrgb = (rayrgb - self.imagemean) / self.imagestd
image = (image - self.imagemean) / self.imagestd
irgbsqerr = (weight * (image - rayrgb) ** 2).contiguous()
if "irgbmse" in losslist:
irgbmse = torch.sum(irgbsqerr.view(irgbsqerr.size(0), -1), dim=-1)
irgbmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)
resultlosses["irgbmse"] = (irgbmse, irgbmse_weight)
aeend = time.time()
if "aetime" in outputlist:
resultout["aetime"] = aeend - aestart
return resultout, resultlosses
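# Minimal smoke test for the reference ray utilities above (a sketch with
# made-up camera parameters; exercising the full Autoencoder additionally
# requires real encoder/decoder/raymarcher modules).
if __name__ == "__main__":
    batchsize, height, width = 2, 8, 8
    volradius = 256.
    camrot = torch.eye(3)[None].repeat(batchsize, 1, 1)                    # axis-aligned camera
    campos = torch.tensor([0., 0., -3. * volradius]).expand(batchsize, 3)  # outside the volume
    focal = torch.full((batchsize, 2), 8.)
    princpt = torch.full((batchsize, 2), 4.)
    pixelcoords = make_dense_pixelcoords(batchsize, height, width)
    raydir = compute_raydirs_ref(pixelcoords, camrot, focal, princpt)
    raypos, tminmax = compute_rmbounds(campos, raydir, volradius)
    print(raydir.shape, raypos.shape, tminmax.shape)
    # rays through the volume must have a nonempty interval (tmin < tmax)
    assert bool((tminmax[..., 0] < tminmax[..., 1]).any())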