# Our3D / lib / render_utils.py
# (Hugging Face upload metadata: uploader yansong1616, "Upload 384 files",
#  commit b177539 verified, 8.8 kB)
import torch
from tqdm import tqdm, trange
import numpy as np
from .dvgo import get_rays_of_a_view
import os
import imageio
from .utils import to8b, rgb_lpips, rgb_ssim, gen_rand_colors
import matplotlib.pyplot as plt
@torch.no_grad()
def render_viewpoints(model, render_poses, HW, Ks, ndc, render_kwargs,
                      gt_imgs=None, savedir=None, dump_images=False, cfg=None,
                      render_factor=0, render_video_flipy=False, render_video_rot90=0,
                      eval_ssim=False, eval_lpips_alex=False, eval_lpips_vgg=False,
                      seg_mask=True, render_fct=0.0, seg_type='seg_density'):
    '''Render images for the given viewpoints; run evaluation if gt given.

    Args:
        model: renderer invoked as ``model(rays_o, rays_d, viewdirs,
            render_fct=..., **render_kwargs)``; must return a dict containing
            'rgb_marched', 'depth', 'alphainv_last' (and 'seg_mask_marched'
            when ``seg_mask`` is True).
        render_poses: per-view camera-to-world matrices.
        HW: per-view (H, W) image sizes; Ks: per-view 3x3 intrinsics.
        ndc: whether rays are generated in NDC space.
        gt_imgs: optional ground-truth images; enables PSNR (and, per flags,
            SSIM / LPIPS) evaluation, but only at full resolution
            (render_factor == 0).
        render_factor: if nonzero, downscale both image size and intrinsics
            by this factor (faster preview renders).
        render_video_flipy / render_video_rot90: post-hoc flip / rotation of
            every rendered frame.
        seg_mask: also march the segmentation logits per view.
        seg_type: label used only for progress bars and image dump subdirs.

    Returns:
        (rgbs, depths, bgmaps, segs): stacked numpy arrays; ``segs`` is the
        (possibly empty) list/stack of segmentation logits.
    '''
    assert len(render_poses) == len(HW) and len(HW) == len(Ks)

    if render_factor != 0:
        # Low-res preview: scale image sizes and the intrinsics' focal/center.
        HW = np.copy(HW)
        Ks = np.copy(Ks)
        HW = (HW / render_factor).astype(int)
        Ks[:, :2, :3] /= render_factor

    rgbs, segs, depths, bgmaps = [], [], [], []
    psnrs, ssims, lpips_alex, lpips_vgg = [], [], [], []

    for i, c2w in enumerate(tqdm(render_poses, desc='Render {}...'.format(seg_type))):
        H, W = HW[i]
        K = Ks[i]
        c2w = torch.Tensor(c2w)
        rays_o, rays_d, viewdirs = get_rays_of_a_view(
                H, W, K, c2w, ndc, inverse_y=render_kwargs['inverse_y'],
                flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
        keys = ['rgb_marched', 'depth', 'alphainv_last']
        if seg_mask:
            keys.append('seg_mask_marched')
        rays_o = rays_o.flatten(0, -2)
        rays_d = rays_d.flatten(0, -2)
        viewdirs = viewdirs.flatten(0, -2)
        # Render in ray chunks to bound peak GPU memory, then stitch back
        # into full (H, W, C) images.
        render_result_chunks = [
            {k: v for k, v in model(ro, rd, vd, render_fct=render_fct, **render_kwargs).items() if k in keys}
            for ro, rd, vd in zip(rays_o.split(8192, 0), rays_d.split(8192, 0), viewdirs.split(8192, 0))
        ]
        render_result = {
            k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H, W, -1)
            for k in render_result_chunks[0].keys()
        }
        rgb = render_result['rgb_marched'].cpu().numpy()
        depth = render_result['depth'].cpu().numpy()
        bgmap = render_result['alphainv_last'].cpu().numpy()
        rgbs.append(rgb)
        depths.append(depth)
        bgmaps.append(bgmap)
        if seg_mask:
            # Kept as a CPU tensor (not .numpy()) as in the original code;
            # downstream np.* calls accept it via the array protocol.
            segs.append(render_result['seg_mask_marched'].cpu())
        if i == 0:
            print('Testing, rgb shape: ', rgb.shape)
        if gt_imgs is not None and render_factor == 0:
            # PSNR from MSE in dB; optional SSIM / LPIPS metrics per flags.
            p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
            psnrs.append(p)
            if eval_ssim:
                ssims.append(rgb_ssim(rgb, gt_imgs[i], max_val=1))
            if eval_lpips_alex:
                lpips_alex.append(rgb_lpips(rgb, gt_imgs[i], net_name='alex', device=c2w.device))
            if eval_lpips_vgg:
                lpips_vgg.append(rgb_lpips(rgb, gt_imgs[i], net_name='vgg', device=c2w.device))

    if len(psnrs):
        print('Testing psnr', np.mean(psnrs), '(avg)')
        if eval_ssim: print('Testing ssim', np.mean(ssims), '(avg)')
        if eval_lpips_vgg: print('Testing lpips (vgg)', np.mean(lpips_vgg), '(avg)')
        if eval_lpips_alex: print('Testing lpips (alex)', np.mean(lpips_alex), '(avg)')

    if render_video_flipy:
        for i in range(len(rgbs)):
            rgbs[i] = np.flip(rgbs[i], axis=0)
            depths[i] = np.flip(depths[i], axis=0)
            bgmaps[i] = np.flip(bgmaps[i], axis=0)
            # BUGFIX: segs is empty when seg_mask=False — the original
            # indexed segs[i] unconditionally and raised IndexError.
            if i < len(segs):
                segs[i] = np.flip(segs[i], axis=0)

    if render_video_rot90 != 0:
        for i in range(len(rgbs)):
            rgbs[i] = np.rot90(rgbs[i], k=render_video_rot90, axes=(0, 1))
            depths[i] = np.rot90(depths[i], k=render_video_rot90, axes=(0, 1))
            bgmaps[i] = np.rot90(bgmaps[i], k=render_video_rot90, axes=(0, 1))
            # Same guard as above: skip when no segmentation was rendered.
            if i < len(segs):
                segs[i] = np.rot90(segs[i], k=render_video_rot90, axes=(0, 1))

    if savedir is not None and dump_images:
        if seg_type == 'seg_density':
            img_dir = 'seged_img'
        elif seg_type == 'seg_img':
            img_dir = 'ori_img'
        else:
            raise NotImplementedError
        img_dir = os.path.join(savedir, img_dir)
        os.makedirs(img_dir, exist_ok=True)
        for i in trange(len(rgbs), desc='dumping images...'):
            rgb8 = to8b(rgbs[i])
            filename = os.path.join(img_dir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.array(rgbs)
    depths = np.array(depths)
    bgmaps = np.array(bgmaps)
    if len(segs): segs = np.stack(segs)
    return rgbs, depths, bgmaps, segs
def fetch_render_params(render_type, data_dict):
    '''Pick poses, image sizes, intrinsics and ground-truth images for a pass.

    ``render_type`` selects which split of ``data_dict`` is rendered:
    'train' / 'test' use the corresponding index lists and return CPU numpy
    ground-truth images; 'video' uses the precomputed fly-through poses,
    broadcasting the first test view's size/intrinsics to every pose, with
    no ground truth. Any other value raises NotImplementedError.
    '''
    if render_type in ('train', 'test'):
        split_idx = data_dict['i_' + render_type]
        poses = data_dict['poses'][split_idx]
        sizes = data_dict['HW'][split_idx]
        intrinsics = data_dict['Ks'][split_idx]
        gt = [data_dict['images'][j].cpu().numpy() for j in split_idx]
    elif render_type == 'video':
        poses = data_dict['render_poses']
        n_frames = len(poses)
        first_test = data_dict['i_test']
        # Reuse the first test camera's size/intrinsics for every video frame.
        sizes = data_dict['HW'][first_test][[0]].repeat(n_frames, 0)
        intrinsics = data_dict['Ks'][first_test][[0]].repeat(n_frames, 0)
        gt = None
    else:
        raise NotImplementedError
    return poses, sizes, intrinsics, gt
@torch.no_grad()
def render_fn(args, cfg, ckpt_name, flag, e_flag, num_obj, data_dict, render_viewpoints_kwargs, seg_type='seg_density'):
    '''Render a full pass over the selected split and dump videos to disk.

    Renders RGB / segmentation / depth videos into
    ``<basedir>/<expname>/render_<opt>_<ckpt>``. When ``seg_type == 'seg_img'``
    it additionally colorizes the winner-takes-all segmentation on top of the
    RGB frames (and, if ``args.dump_images``, writes per-frame masked images
    and binary masks), returning those overlay frames as uint8; otherwise it
    returns the plain rendered RGB frames as uint8.
    '''
    rand_colors = gen_rand_colors(num_obj)
    testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_{args.render_opt}_{ckpt_name}')
    os.makedirs(testsavedir, exist_ok=True)
    print('All results are dumped into', testsavedir)
    render_poses, HW, Ks, gt_imgs = fetch_render_params(args.render_opt, data_dict)
    rgbs, depths, bgmaps, segs = render_viewpoints(
            render_poses=render_poses,
            HW=HW, Ks=Ks, gt_imgs=gt_imgs,
            cfg=cfg, savedir=testsavedir, dump_images=args.dump_images,
            eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
            seg_type=seg_type,
            **render_viewpoints_kwargs)
    imageio.mimwrite(os.path.join(testsavedir, 'video.rgb'+flag+e_flag+'_'+seg_type+'.mp4'), to8b(rgbs), fps=30, quality=8)
    imageio.mimwrite(os.path.join(testsavedir, 'video.seg'+flag+e_flag+'_'+seg_type+'.mp4'), to8b(segs>0), fps=30, quality=8)
    # Depth video: normalize to [0, 1], invert so near is warm, colormap it.
    depth_vis = plt.get_cmap('rainbow')(1 - depths / np.max(depths)).squeeze()[..., :3]
    imageio.mimwrite(os.path.join(testsavedir, 'video.depth'+flag+e_flag+'_'+seg_type+'.mp4'), to8b(depth_vis), fps=30, quality=8)
    if seg_type == 'seg_img':
        seg_on_rgb = []
        if args.dump_images:
            masked_img_dir = os.path.join(testsavedir, 'masked_img')
            os.makedirs(masked_img_dir, exist_ok=True)
            masks_dir = os.path.join(testsavedir, 'masks')
            os.makedirs(masks_dir, exist_ok=True)
        for i, (rgb, seg) in enumerate(zip(rgbs, segs)):
            # Winner takes all: each pixel gets the object with the highest
            # logit; pixels whose best logit is <= 0.1 fall to background
            # (index num_obj in rand_colors).
            max_logit = np.expand_dims(np.max(seg, axis=-1), -1)
            tmp_seg = np.argmax(seg, axis=-1)
            tmp_seg[max_logit[:, :, 0] <= 0.1] = num_obj
            # Blend the per-object colors over the RGB frame.
            recolored_rgb = 0.3*rgb + 0.7*(rand_colors[tmp_seg])
            seg_on_rgb.append(recolored_rgb)
            if args.dump_images:
                imageio.imwrite(os.path.join(masked_img_dir, 'rgb_{:07d}.png'.format(i)), to8b(recolored_rgb))
                imageio.imwrite(os.path.join(masks_dir, 'mask_{:07d}.png'.format(i)), to8b(seg>0))
        imageio.mimwrite(os.path.join(testsavedir, 'video.seg_on_rgb'+e_flag+'_'+seg_type+'.mp4'), to8b(seg_on_rgb), fps=30, quality=8)
        return to8b(np.stack(seg_on_rgb))
    return to8b(rgbs)