import numpy as np
import torch
import torch.distributed as dist
import os
import copy
import cv2
import hashlib
import tqdm
import scipy
import uuid
import random
from PIL import Image
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda
from utils.nn.model_utils import not_requires_grad, num_params
from utils.nn.schedulers import NoneSchedule
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint, restore_weights, restore_opt_state
from utils.commons.base_task import BaseTask
from tasks.os_avatar.loss_utils.vgg19_loss import VGG19Loss
from modules.real3d.facev2v_warp.losses import PerceptualLoss
from tasks.os_avatar.dataset_utils.motion2video_dataset import Img2Plane_Dataset
import lpips
from utils.commons.dataset_utils import data_loader
from modules.real3d.img2plane_baseline import OSAvatar_Img2plane
from modules.eg3ds.models.triplane import TriPlaneGenerator
from modules.eg3ds.models.dual_discriminator import DualDiscriminator, filtered_resizing
from modules.eg3ds.torch_utils.ops import conv2d_gradfix
from modules.eg3ds.torch_utils.ops import upfirdn2d


class ScheduleForImg2Plane(NoneSchedule):
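    """
    LR schedule for the img2plane task (a summary of the behavior implemented in `step`):
    - The generator lr warms up linearly over `warmup_updates` steps, then decays by
      `lr_decay_rate` every `lr_decay_interval` updates, floored at 1e-5.
    - The triplane decoder stays frozen (lr=0) until min(2000, start_adv_iters) updates;
      the super-resolution module inherited from pretrained EG3D stays frozen until
      `start_adv_iters` updates.
    - The discriminator lr is held constant at `lr_d`.
    """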

    def __init__(self, optimizer, lr, lr_d, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.lr_d = lr_d
        self.warmup_updates = warmup_updates
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            self.lr = constant_lr
        decayed_lr = max(1e-5, self.lr * hparams.get("lr_decay_rate", 0.95) ** (num_updates // hparams.get("lr_decay_interval", 5_000)))
        for optim_i in range(len(self.optimizer) - 1):
            self.optimizer[optim_i].param_groups[0]['lr'] = decayed_lr  # secc_img2plane backbone
            self.optimizer[optim_i].param_groups[1]['lr'] = decayed_lr if num_updates >= min(2_000, hparams['start_adv_iters']) else 0  # decoder
            # Keep the super-resolution module from the pretrained EG3D frozen, giving img2plane a warmup of `start_adv_iters` (~30,000) steps.
            self.optimizer[optim_i].param_groups[2]['lr'] = decayed_lr if num_updates >= hparams['start_adv_iters'] else 0  # SR module
        self.optimizer[-1].param_groups[0]['lr'] = self.lr_d  # for disc
        return self.lr


class OSAvatarImg2PlaneTask(BaseTask):
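    """
    Trains the one-shot image-to-triplane generator (OSAvatar_Img2plane).
    A frozen, pretrained EG3D (TriPlaneGenerator) synthesizes paired reference/multi-view
    pseudo-ground-truth images of the same identity; the generator learns to reconstruct
    them from the reference view alone, supervised by reconstruction, LPIPS, and density
    regularization losses, plus (after `start_adv_iters` steps) adversarial losses
    against a DualDiscriminator.
    """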

    def __init__(self):
        super().__init__()
        if hparams['lpips_mode'] == 'vgg19':
            self.criterion_lpips = VGG19Loss()
        elif hparams['lpips_mode'] in ['vgg16', 'vgg']:
            hparams['lpips_mode'] = 'vgg'
            self.criterion_lpips = lpips.LPIPS(net=hparams['lpips_mode'], lpips=True)
        elif hparams['lpips_mode'] == 'vgg19_v2':
            self.criterion_lpips = PerceptualLoss()
        else:
            raise NotImplementedError
        self.gen_tmp_output = {}
        if hparams.get("use_kv_dataset", True):
            self.dataset_cls = Img2Plane_Dataset
        else:
            raise NotImplementedError()
        self.start_adv_iters = hparams["start_adv_iters"]
        self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1])

    def train_dataloader(self):
        train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def test_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def build_model(self):
        self.eg3d_model = TriPlaneGenerator()
        load_ckpt(self.eg3d_model, hparams['pretrained_eg3d_ckpt'], strict=True)
        self.model = OSAvatar_Img2plane()
        load_ckpt(self.model, hparams['pretrained_eg3d_ckpt'], strict=False, verbose=False)
        self.disc = DualDiscriminator()
        load_ckpt(self.disc, hparams['pretrained_eg3d_ckpt'], strict=False, model_name='disc')
        if hparams.get('init_from_ckpt', '') != '':
            # load_ckpt(self.model.img2plane_backbone, ckpt_dir, model_name='model.img2plane_backbone', strict=True)
            load_ckpt(self.model, hparams['init_from_ckpt'], model_name='model', strict=False)
            # load_ckpt(self.model, hparams['init_from_ckpt'], model_name='model', strict=True)
            load_ckpt(self.disc, hparams['init_from_ckpt'], model_name='disc', strict=True)
            # restore_weights(self, get_last_checkpoint(hparams.get('init_from_ckpt', ''))[0])
            print(f"restored weights from {hparams.get('init_from_ckpt', '')}")
        # Define parameter groups, used to assign per-module learning rates in the optimizers.
        self.img2plane_backbone_params = [p for k, p in self.model.img2plane_backbone.named_parameters() if p.requires_grad]
        self.decoder_params = [p for p in self.model.decoder.parameters() if p.requires_grad]
        self.upsample_params = [p for p in self.model.superresolution.parameters() if p.requires_grad]
        self.disc_params = [p for k, p in self.disc.named_parameters() if p.requires_grad]
        return self.model

    def build_optimizer(self, model):
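        """
        Builds one Adam optimizer for the generator with three param groups
        (img2plane backbone / triplane decoder / super-resolution) and a separate Adam
        optimizer for the discriminator. Following StyleGAN2-style lazy regularization,
        the discriminator lr and betas are rescaled by reg_interval_d / (reg_interval_d + 1),
        since the R1 penalty only runs every `reg_interval_d` steps.
        """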
        self.optimizer_gen = optimizer_gen = torch.optim.Adam(
            self.img2plane_backbone_params,
            lr=hparams['lr_g'],
            betas=(hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g']))
        self.optimizer_gen.add_param_group({
            'params': self.decoder_params,
            'lr': hparams['lr_g'],
            'betas': (hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g'])
        })
        self.optimizer_gen.add_param_group({
            'params': self.upsample_params,
            'lr': hparams['lr_g'],
            'betas': (hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g'])
        })
        mb_ratio_d = hparams['reg_interval_d'] / (hparams['reg_interval_d'] + 1)
        self.optimizer_disc = optimizer_disc = torch.optim.Adam(
            self.disc_params,
            lr=hparams['lr_d'] * mb_ratio_d,
            betas=(hparams['optimizer_adam_beta1_d'] ** mb_ratio_d, hparams['optimizer_adam_beta2_d'] ** mb_ratio_d))
        optimizers = [optimizer_gen] * 2 + [optimizer_disc]  # optim 0-1: ref/mv sub-steps share the G optimizer; optim 2: disc
        return optimizers

    def build_scheduler(self, optimizer):
        mb_ratio_d = hparams['reg_interval_d'] / (hparams['reg_interval_d'] + 1)
        return ScheduleForImg2Plane(optimizer, hparams['lr_g'], hparams['lr_d'] * mb_ratio_d, hparams['warmup_updates'])

    def on_train_start(self):
        print("==============================")
        num_params(self.model, model_name="Generator")
        for n, m in self.model.named_children():
            num_params(m, model_name="|-- " + n)
        print("==============================")
        num_params(self.disc, model_name="Discriminator")
        for n, m in self.disc.named_children():
            num_params(m, model_name="|-- " + n)
        print("==============================")

    def forward_G(self, img, camera, cond=None, ret=None, update_emas=False, cache_backbone=True, use_cached_backbone=False):
        """
        img: [B, 3, H, W], the reference image
        camera: [B, 25], a 16-dim flattened cam2world matrix plus 9-dim flattened intrinsics
        """
        G = self.model
        gen_output = G.forward(img=img, camera=camera, cond=cond, ret=ret, update_emas=update_emas, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone)
        return gen_output

    def forward_D(self, img, camera, update_emas=False):
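        """
        img: dict with 'image' (full-resolution RGB) and 'image_raw' (a render at
        hparams['neural_rendering_resolution']), as assembled by the callers in this
        file. Returns the discriminator's realism logits (higher = judged more real).
        """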
        D = self.disc
        logits = D.forward(img, camera, update_emas=update_emas)
        return logits

    def prepare_batch(self, batch):
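        """
        Uses the frozen, pretrained EG3D to synthesize pseudo-ground-truth pairs: a
        reference view and a multi-view ("mv") image of the same random identity (one
        shared w latent, truncation_psi=0.7), plus their depth maps, raw (pre-SR)
        renders, and the EG3D triplanes themselves.
        """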
        # Synthesize images with the frozen EG3D to serve as GT data.
        out_batch = {}
        ws_camera = batch['ffhq_ws_cameras']
        camera = batch['ffhq_ref_cameras']
        fake_camera = batch['ffhq_mv_cameras']
        z = torch.randn([camera.shape[0], hparams['z_dim']], device=camera.device)
        with torch.no_grad():
            ws = self.eg3d_model.mapping(z, ws_camera, update_emas=False, truncation_psi=0.7)
            ref_img_gen = self.eg3d_model.synthesis(ws, camera, cache_backbone=True, use_cached_backbone=False)
            mv_img_gen = self.eg3d_model.synthesis(ws, fake_camera, cache_backbone=False, use_cached_backbone=True)
            ref_img_gen['image_raw'] = filtered_resizing(ref_img_gen['image'], size=hparams['neural_rendering_resolution'], f=self.resample_filter, filter_mode='antialiased')
            mv_img_gen['image_raw'] = filtered_resizing(mv_img_gen['image'], size=hparams['neural_rendering_resolution'], f=self.resample_filter, filter_mode='antialiased')
        out_batch.update({
            'ffhq_planes': ref_img_gen['plane'],
            'ffhq_ref_imgs_raw': ref_img_gen['image_raw'],
            'ffhq_ref_imgs': ref_img_gen['image'],
            'ffhq_ref_imgs_depth': ref_img_gen['image_depth'],
            'ffhq_mv_imgs_raw': mv_img_gen['image_raw'],
            'ffhq_mv_imgs': mv_img_gen['image'],
            'ffhq_mv_imgs_depth': mv_img_gen['image_depth'],
            'ffhq_mv_imgs_feature': mv_img_gen['image_feature'],
            'ffhq_ref_cameras': batch['ffhq_ref_cameras'],
            'ffhq_mv_cameras': batch['ffhq_mv_cameras'],
        })
        return out_batch

    def run_G_reference_image(self, batch):
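        """
        Generator sub-step on the reference view: reconstruct the reference image from
        itself. Early in training this sub-step is throttled (see the staging below) so
        the model is not trained on the input view too often, which could collapse it
        into a flat "billboard" solution; late in training it runs every step.
        """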
        losses = {}
        ret = {}
        if self.global_step + 1 < self.start_adv_iters:
            # At the early stage, don't train on the ref camera too frequently, to prevent billboard artifacts.
            if (self.global_step + 1) % 5 != 0:
                return losses
        elif self.global_step < self.start_adv_iters + 5000:
            if (self.global_step + 1) % 2 != 0:
                return losses
        else:
            # After start_adv_iters + 5000 steps (37,500 here), we train on the ref image every step.
            pass
        with torch.autograd.profiler.record_function('G_ref_forward'):
            camera = batch['ffhq_ref_cameras']
            img = batch['ffhq_ref_imgs']
            img_raw = batch['ffhq_ref_imgs_raw']
            img_depth = batch['ffhq_ref_imgs_depth']
            gen_img = self.forward_G(img, camera, cond={'ref_cameras': batch['ffhq_ref_cameras']}, ret=ret)
            if 'losses' in ret:
                losses.update(ret['losses'])
            self.gen_tmp_output['recon_ref_imgs'] = gen_img['image'].detach()
            self.gen_tmp_output['recon_ref_imgs_raw'] = gen_img['image_raw'].detach()
            losses['G_ref_plane_l1_mean'] = gen_img['plane'].detach().abs().mean()
            losses['G_ref_plane_l1_std'] = gen_img['plane'].detach().abs().std()
            if hparams['use_mse']:
                # NOTE: the keys keep the 'mae' name for logging consistency even in MSE mode.
                losses['G_ref_img_mae'] = (gen_img['image'] - img).pow(2).mean()
                losses['G_ref_img_mae_raw'] = (gen_img['image_raw'] - img_raw).pow(2).mean()
            else:
                losses['G_ref_img_mae'] = (gen_img['image'] - img).abs().mean()
                losses['G_ref_img_mae_raw'] = (gen_img['image_raw'] - img_raw).abs().mean()
            pred_img_01 = ((gen_img['image'] + 1) / 2).clamp(0, 1)
            pred_img_01_raw = ((gen_img['image_raw'] + 1) / 2).clamp(0, 1)
            gt_img_01 = ((img + 1) / 2).clamp(0, 1)
            gt_img_01_raw = ((img_raw + 1) / 2).clamp(0, 1)
            losses['G_ref_img_lpips'] = self.criterion_lpips(pred_img_01, gt_img_01).mean()
            losses['G_ref_img_lpips_raw'] = self.criterion_lpips(pred_img_01_raw, gt_img_01_raw).mean()
            losses['G_ref_img_mae_depth'] = (gen_img['image_depth'] - img_depth).detach().abs().mean()  # detached: logged for monitoring only
            disc_inp_img = {
                'image': gen_img['image'],
                'image_raw': gen_img['image_raw'],
            }
            gen_logits = self.forward_D(disc_inp_img, camera)
            losses['G_ref_adv'] = torch.nn.functional.softplus(-gen_logits).mean()
        return losses

    def run_G_multiview_image(self, batch):
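        """
        Generator sub-step on a novel ("multi-view") camera: predict triplanes from the
        reference image, render them from the mv camera, and match EG3D's render of the
        same identity from that camera. This provides supervision beyond the input view.
        """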
        losses = {}
        ret = {}
        with torch.autograd.profiler.record_function('G_mv_forward'):
            camera = batch['ffhq_mv_cameras']
            img = batch['ffhq_mv_imgs']
            img_raw = batch['ffhq_mv_imgs_raw']
            img_depth = batch['ffhq_mv_imgs_depth']
            gen_img = self.forward_G(batch['ffhq_ref_imgs'], camera, cond={'ref_cameras': batch['ffhq_ref_cameras']}, ret=ret)
            if 'losses' in ret:
                losses.update(ret['losses'])
            self.gen_tmp_output['recon_mv_imgs'] = gen_img['image'].detach()
            self.gen_tmp_output['recon_mv_imgs_raw'] = gen_img['image_raw'].detach()
            losses['G_mv_plane_l1_mean'] = gen_img['plane'].detach().abs().mean()
            losses['G_mv_plane_l1_std'] = gen_img['plane'].detach().abs().std()
            if hparams['use_mse']:
                losses['G_mv_img_mae'] = (gen_img['image'] - img).pow(2).mean()
                losses['G_mv_img_mae_raw'] = (gen_img['image_raw'] - img_raw).pow(2).mean()
            else:
                losses['G_mv_img_mae'] = (gen_img['image'] - img).abs().mean()
                losses['G_mv_img_mae_raw'] = (gen_img['image_raw'] - img_raw).abs().mean()
            pred_img_01 = ((gen_img['image'] + 1) / 2).clamp(0, 1)  # clamp to [0, 1] for LPIPS, matching the ref-image branch
            pred_img_01_raw = ((gen_img['image_raw'] + 1) / 2).clamp(0, 1)
            gt_img_01 = ((img + 1) / 2).clamp(0, 1)
            gt_img_01_raw = ((img_raw + 1) / 2).clamp(0, 1)
            losses['G_mv_img_lpips'] = self.criterion_lpips(pred_img_01, gt_img_01).mean()
            losses['G_mv_img_lpips_raw'] = self.criterion_lpips(pred_img_01_raw, gt_img_01_raw).mean()
            losses['G_mv_img_mae_depth'] = (gen_img['image_depth'] - img_depth).detach().abs().mean()  # detached: logged for monitoring only
            disc_inp_img = {
                'image': gen_img['image'],
                'image_raw': gen_img['image_raw'],
            }
            gen_logits = self.forward_D(disc_inp_img, camera)
            losses['G_mv_adv'] = torch.nn.functional.softplus(-gen_logits).mean()
        return losses

    def run_G_reg(self, batch):
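        """
        Density regularization (as in EG3D): sample random 3D points, perturb them
        slightly, and penalize the L1 difference between the densities at the original
        and perturbed points, encouraging a smooth density field. Runs every
        `reg_interval_g` steps.
        """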
        losses = {}
        imgs = batch['ffhq_mv_imgs']
        if (self.global_step + 1) % hparams['reg_interval_g'] == 0:
            with torch.autograd.profiler.record_function('G_regularize_forward'):
                initial_coordinates = torch.rand((imgs.shape[0], 1000, 3), device=imgs.device) - 0.5  # uniform in [-0.5, 0.5]
                perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * 5e-3
                all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
                source_sigma = self.model.sample(coordinates=all_coordinates, directions=torch.randn_like(all_coordinates), img=imgs, cond={'ref_cameras': batch['ffhq_mv_cameras']}, update_emas=False)['sigma']
                source_sigma_initial = source_sigma[:, :source_sigma.shape[1] // 2]
                source_sigma_perturbed = source_sigma[:, source_sigma.shape[1] // 2:]
                # We want the perturbed positions to have similar density.
                density_reg_loss = torch.nn.functional.l1_loss(source_sigma_initial, source_sigma_perturbed)
                losses['G_regularize_density_l1'] = density_reg_loss
        return losses

    def forward_D_main(self, batch):
        """
        Discriminator sub-step: minimize logits on the generator's cached fake images
        and maximize logits on the (EG3D-synthesized or real FFHQ) reference/mv images,
        with a lazy R1 gradient penalty every `reg_interval_d` steps. EMA buffers are
        updated during this sub-step (update_emas=True).
        """
        losses = {}
        with torch.autograd.profiler.record_function('D_minimize_fake_forward'):
            if 'recon_ref_imgs' in self.gen_tmp_output:
                camera = batch['ffhq_ref_cameras']
                disc_inp_img = {
                    'image': self.gen_tmp_output['recon_ref_imgs'],
                    'image_raw': self.gen_tmp_output['recon_ref_imgs_raw'],
                }
                gen_logits = self.forward_D(disc_inp_img, camera, update_emas=True)
                losses['D_minimize_ref_fake'] = torch.nn.functional.softplus(gen_logits).mean()
            if 'recon_mv_imgs' in self.gen_tmp_output:
                camera = batch['ffhq_mv_cameras']
                disc_inp_img = {
                    'image': self.gen_tmp_output['recon_mv_imgs'],
                    'image_raw': self.gen_tmp_output['recon_mv_imgs_raw'],
                }
                gen_logits = self.forward_D(disc_inp_img, camera, update_emas=True)
                losses['D_minimize_mv_fake'] = torch.nn.functional.softplus(gen_logits).mean()
        # Maximize confidence on real samples.
        with torch.autograd.profiler.record_function('D_maximize_true_forward'):
            if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                ref_cameras = batch['ffhq_ref_cameras']
                ref_img_tmp_image = batch['ffhq_ref_imgs'].detach().requires_grad_(True)
                ref_img_tmp_image_raw = batch['ffhq_ref_imgs_raw'].detach().requires_grad_(True)
            elif hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'ffhq':
                ref_cameras = batch['ffhq_gt_cameras']
                ref_img_tmp_image = batch['ffhq_gt_imgs'].detach().requires_grad_(True)
                ref_img_tmp_image_raw = batch['ffhq_gt_imgs_raw'].detach().requires_grad_(True)
            ref_img_tmp = {'image': ref_img_tmp_image, 'image_raw': ref_img_tmp_image_raw}
            ref_logits = self.forward_D(ref_img_tmp, ref_cameras)
            losses['D_maximize_ref'] = torch.nn.functional.softplus(-ref_logits).mean()
            if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                mv_cameras = batch['ffhq_mv_cameras']
                mv_img_tmp_image = batch['ffhq_mv_imgs'].detach().requires_grad_(True)
                mv_img_tmp_image_raw = batch['ffhq_mv_imgs_raw'].detach().requires_grad_(True)
                mv_img_tmp = {'image': mv_img_tmp_image, 'image_raw': mv_img_tmp_image_raw}
                mv_logits = self.forward_D(mv_img_tmp, mv_cameras)
                losses['D_maximize_mv'] = torch.nn.functional.softplus(-mv_logits).mean()
        if (self.global_step + 1) % hparams['reg_interval_d'] == 0 and self.training is True:
            # Lazy R1 regularization: gradient penalty on real images.
            with torch.autograd.profiler.record_function('D_gradient_penalty_on_real_imgs_forward'), conv2d_gradfix.no_weight_gradients():
                ref_r1_grads = torch.autograd.grad(outputs=[ref_logits.sum()], inputs=[ref_img_tmp['image'], ref_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                ref_r1_grads_image = ref_r1_grads[0]
                ref_r1_grads_image_raw = ref_r1_grads[1]
                ref_r1_penalty_raw = ref_r1_grads_image_raw.square().sum([1, 2, 3]).mean()
                ref_r1_penalty_image = ref_r1_grads_image.square().sum([1, 2, 3]).mean()
                losses['D_ref_gradient_penalty'] = (ref_r1_penalty_image + ref_r1_penalty_raw) / 2
                if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                    mv_r1_grads = torch.autograd.grad(outputs=[mv_logits.sum()], inputs=[mv_img_tmp['image'], mv_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                    mv_r1_grads_image = mv_r1_grads[0]
                    mv_r1_grads_image_raw = mv_r1_grads[1]
                    mv_r1_penalty_raw = mv_r1_grads_image_raw.square().sum([1, 2, 3]).mean()
                    mv_r1_penalty_image = mv_r1_grads_image.square().sum([1, 2, 3]).mean()
                    losses['D_mv_gradient_penalty'] = (mv_r1_penalty_image + mv_r1_penalty_raw) / 2
        self.gen_tmp_output = {}
        return losses

    def _training_step(self, sample, batch_idx, optimizer_idx):
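        """
        One training step is split into three optimizer sub-steps:
          optimizer_idx 0: G on the reference view (the batch is prepared and cached here);
          optimizer_idx 1: G on the multi-view camera, plus density regularization;
          optimizer_idx 2: D on the cached fakes and the real/pseudo-real images.
        Returns (total_loss, losses), or None if there is nothing to optimize.
        """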
        if len(sample) == 0:
            return None
        if optimizer_idx == 0:
            sample = self.prepare_batch(sample)
            self.cache_sample = sample
        else:
            sample = self.cache_sample
        # self.update_ema()
        losses = {}
        if optimizer_idx == 0:
            losses.update(self.run_G_reference_image(sample))
            # After the early stage, we decrease the lambda of the error-based losses to prevent over-smoothness.
            loss_weights = {
                'G_ref_img_mae': 1.0,
                'G_ref_img_mae_raw': 1.0,
                'G_ref_img_mae_depth': hparams.get('lambda_mse_depth', 0),
                'G_ref_img_lpips': 0.1,
                'G_ref_img_lpips_raw': 0.1,
                'G_ref_adv': 0.0 if self.global_step < self.start_adv_iters else 0.1,
            }
        elif optimizer_idx == 1:
            losses.update(self.run_G_multiview_image(sample))
            losses.update(self.run_G_reg(sample))
            loss_weights = {
                'G_mv_img_mae': 1.0,
                'G_mv_img_mae_raw': 1.0,
                'G_mv_img_mae_depth': hparams.get('lambda_mse_depth', 0),
                'G_mv_img_lpips': 0.1,
                'G_mv_img_lpips_raw': 0.1,
                'G_mv_adv': 0.0 if self.global_step < self.start_adv_iters else 0.025,
                'G_regularize_density_l1': hparams['lambda_density_reg'],
            }
        elif optimizer_idx == 2:
            losses.update(self.forward_D_main(sample))
            loss_weights = {
                'D_maximize_ref': 1.0,
                'D_minimize_ref_fake': 1.0,
                'D_ref_gradient_penalty': hparams['lambda_gradient_penalty'] * hparams['reg_interval_d'],  # lazy-reg rescaling
                'D_maximize_mv': 1.0,
                'D_minimize_mv_fake': 1.0,
                'D_mv_gradient_penalty': hparams['lambda_gradient_penalty'] * hparams['reg_interval_d'],
            }
        if len(losses) == 0:
            return None
        # Only differentiable losses enter the total; detached entries are logged only.
        total_loss = sum([loss_weights[k] * v for k, v in losses.items() if isinstance(v, torch.Tensor) and v.requires_grad])
        return total_loss, losses

    #####################
    # Validation
    #####################
    def validation_start(self):
        if self.global_step % hparams['valid_infer_interval'] == 0:
            self.gen_dir = os.path.join(hparams['work_dir'], 'validation_results')
            os.makedirs(self.gen_dir, exist_ok=True)

    def validation_step(self, sample, batch_idx):
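        """
        Computes all training losses on a validation batch and, every
        `valid_infer_interval` steps (rank 0 only, for the first `num_valid_plots`
        batches), saves side-by-side RGB [ref | mv-GT | recon-raw | pred-raw | recon | pred]
        and depth [recon | pred] visualizations.
        """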
        outputs = {}
        losses = {}
        if len(sample) == 0:
            return None
        sample = self.prepare_batch(sample)
        rank = 0 if len(set(os.environ['CUDA_VISIBLE_DEVICES'].split(","))) == 1 else dist.get_rank()
        losses.update(self.run_G_reference_image(sample))
        losses.update(self.run_G_multiview_image(sample))
        losses.update(self.forward_D_main(sample))
        outputs['losses'] = losses
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs = tensors_to_scalars(outputs)
        if self.global_step % hparams['valid_infer_interval'] == 0 \
                and batch_idx < hparams['num_valid_plots'] and rank == 0:
            imgs_ref = sample['ffhq_ref_imgs']
            dummy_cond = {
                'cond_src': torch.zeros([imgs_ref.shape[0], 3, 512, 512], dtype=sample['ffhq_mv_cameras'].dtype, device=sample['ffhq_mv_cameras'].device),
                'cond_drv': torch.zeros([imgs_ref.shape[0], 3, 512, 512], dtype=sample['ffhq_mv_cameras'].dtype, device=sample['ffhq_mv_cameras'].device),
                'ref_cameras': sample['ffhq_ref_cameras'],
            }
            gen_img = self.model.forward(imgs_ref, sample['ffhq_mv_cameras'], cond=dummy_cond, noise_mode='const')
            gen_img_recon = self.model.forward(imgs_ref, sample['ffhq_ref_cameras'], cond=dummy_cond, noise_mode='const')
            imgs_recon = gen_img_recon['image'].permute(0, 2, 3, 1)
            imgs_recon_raw = filtered_resizing(gen_img_recon['image_raw'], size=512, f=self.resample_filter, filter_mode='antialiased').permute(0, 2, 3, 1)
            imgs_recon_depth = gen_img_recon['image_depth'].permute(0, 2, 3, 1)
            imgs_pred = gen_img['image'].permute(0, 2, 3, 1)
            imgs_pred_raw = filtered_resizing(gen_img['image_raw'], size=512, f=self.resample_filter, filter_mode='antialiased').permute(0, 2, 3, 1)
            imgs_pred_depth = gen_img['image_depth'].permute(0, 2, 3, 1)
            imgs_ref = imgs_ref.permute(0, 2, 3, 1)
            imgs_mv = sample['ffhq_mv_imgs'].permute(0, 2, 3, 1)  # [B, H, W, 3]
            for i in range(len(imgs_pred)):
                idx_string = format(i + batch_idx * hparams['batch_size'], "05d")
                base_fn = idx_string
                img_ref_mv_recon_pred = torch.cat([imgs_ref[i], imgs_mv[i], imgs_recon_raw[i], imgs_pred_raw[i], imgs_recon[i], imgs_pred[i]], dim=1)
                self.save_rgb_to_fname(img_ref_mv_recon_pred, f"{self.gen_dir}/images_rgb_iter{self.global_step}/ref_mv_recon_pred_{base_fn}.png")
                img_depth_recon_pred = torch.cat([imgs_recon_depth[i], imgs_pred_depth[i]], dim=1)
                self.save_depth_to_fname(img_depth_recon_pred, f"{self.gen_dir}/images_depth_iter{self.global_step}/recon_pred_{base_fn}.png")
        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    def save_rgb_to_fname(self, rgb, fname):
        """
        rgb: [H, W, 3], values in [-1, 1]
        """
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        img = (rgb * 127.5 + 128).clamp(0, 255)
        img = convert_to_np(img).astype(np.uint8)
        Image.fromarray(img, 'RGB').save(fname)

    def save_depth_to_fname(self, depth, fname):
        """
        depth: [H, W, 1], a single-channel depth map; min-max normalized to [0, 255] before saving.
        """
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        low, high = depth.min(), depth.max()
        img = (depth - low) * (255 / (high - low))
        img = convert_to_np(img)
        img = np.rint(img).clip(0, 255).astype(np.uint8)
        Image.fromarray(img[:, :, 0], 'L').save(fname)