"""Generate lerp videos using pretrained network pickle.""" |
|
|
|
import os |
|
import re |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import click |
|
import dnnlib |
|
import imageio |
|
import numpy as np |
|
import scipy.interpolate |
|
import torch |
|
from tqdm import tqdm |
|
import mrcfile |
|
|
|
import legacy |
|
|
|
from camera_utils import LookAtPoseSampler |
|
from torch_utils import misc |
|
|
|
|
|
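
# Example invocation (illustrative: the script name, network path, and output
# directory are assumptions, not files guaranteed to exist in this repo):
#   python gen_videos.py --network pretrained/ffhqrebalanced512-128.pkl \
#       --seeds 0-3 --grid 2x2 --trunc 0.7 --outdir out
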
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single (grid_h x grid_w) image."""
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    assert batch_size == grid_w * grid_h
    if float_to_uint8:
        # Map [-1, 1] float images to [0, 255] uint8.
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img
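
# Illustrative use of layout_grid (shapes are examples, not taken from this script):
#   img = torch.randn(6, 3, 256, 256)   # batch of 6 CHW images
#   tile = layout_grid(img, grid_h=2)   # -> (512, 768, 3) uint8 numpy array
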
def create_samples(N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
    """Return a (1, N**3, 3) grid of xyz sample coordinates covering the cube."""
    # NOTE: voxel_origin becomes the (bottom, left, down) corner of the cube, not its center.
    voxel_origin = np.array(voxel_origin) - cube_length / 2
    voxel_size = cube_length / (N - 1)

    overall_index = torch.arange(N ** 3, dtype=torch.long)
    samples = torch.zeros(N ** 3, 3)

    # Transform the first three columns to be the x, y, z voxel indices.
    # Floor division keeps the indices exact integers.
    samples[:, 2] = overall_index % N
    samples[:, 1] = (overall_index // N) % N
    samples[:, 0] = ((overall_index // N) // N) % N

    # Transform the first three columns to be the x, y, z coordinates.
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    return samples.unsqueeze(0), voxel_origin, voxel_size
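
# Sanity check (illustrative): create_samples(N=2, cube_length=2.0) returns the
# 8 corners of the cube [-1, 1]^3, with samples[0, 0] == (-1., -1., -1.).

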
def gen_interp_video(G, w_given, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic',
                     grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1., truncation_cutoff=14,
                     generator_type='ffhq', image_mode='image', gen_shapes=False,
                     device=torch.device('cuda'), **video_kwargs):
    grid_w = grid_dims[0]
    grid_h = grid_dims[1]

    if num_keyframes is None:
        if len(seeds) % (grid_w*grid_h) != 0:
            raise ValueError('Number of input seeds must be divisible by grid W*H')
        num_keyframes = len(seeds) // (grid_w*grid_h)

    all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
    for idx in range(num_keyframes*grid_h*grid_w):
        all_seeds[idx] = seeds[idx % len(seeds)]

    if shuffle_seed is not None:
        rng = np.random.RandomState(seed=shuffle_seed)
        rng.shuffle(all_seeds)

    camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
    zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
    cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point,
                                              radius=G.rendering_kwargs['avg_camera_radius'], device=device)
    focal_length = 4.2647  # focal length in normalized image coordinates
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    # Conditioning vector: flattened 4x4 cam2world extrinsics followed by flattened 3x3 intrinsics.
    c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    c = c.repeat(len(zs), 1)

    if w_given is not None:
        ws = w_given
        if ws.shape[1] != G.backbone.mapping.num_ws:
            # Broadcast a single w vector across all synthesis layers.
            ws = ws.repeat([1, G.backbone.mapping.num_ws, 1])
    else:
        ws = G.mapping(z=zs, c=c, truncation_psi=psi, truncation_cutoff=truncation_cutoff)

    ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
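    # ws now has shape (grid_h, grid_w, num_keyframes, num_ws, w_dim):
    # one per-layer latent stack for every keyframe of every grid cell.
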
    # Build a looping interpolator over the keyframe latents for each grid cell.
    # The keyframes are tiled (wrapped) on both sides so the cubic spline stays
    # smooth and periodic across the video loop.
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
            row.append(interp)
        grid.append(row)

    max_batch = 10_000_000  # max density queries per chunk when extracting shapes
    voxel_resolution = 512
    # Extra keyword arguments (e.g. bitrate='10M') are forwarded to imageio's ffmpeg writer.
    video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)

    if gen_shapes:
        outdir = 'interpolation_{}_{}/'.format(all_seeds[0], all_seeds[1])
        os.makedirs(outdir, exist_ok=True)
    all_poses = []
    for frame_idx in tqdm(range(num_keyframes * w_frames)):
        imgs = []
        for yi in range(grid_h):
            for xi in range(grid_w):
                # Sweep the camera on a small ellipse (yaw/pitch oscillation)
                # around the frontal pose over the course of the video.
                pitch_range = 0.25
                yaw_range = 0.35
                cam2world_pose = LookAtPoseSampler.sample(
                    3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
                    3.14/2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
                    camera_lookat_point, radius=G.rendering_kwargs['avg_camera_radius'], device=device)
                all_poses.append(cam2world_pose.squeeze().cpu().numpy())
                focal_length = 4.2647 if generator_type != 'Shapenet' else 1.7074
                intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
                c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

                interp = grid[yi][xi]
                w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)

                # 'camera' moves only the rendering camera; 'conditioning' moves only the
                # pose conditioning of the mapping network; 'both' moves them together.
                entangle = 'camera'
                if entangle == 'conditioning':
                    c_forward = torch.cat([LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point,
                                                                    radius=G.rendering_kwargs['avg_camera_radius'],
                                                                    device=device).reshape(-1, 16),
                                           intrinsics.reshape(-1, 9)], 1)
                    w_c = G.mapping(z=zs[0:1], c=c[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
                    img = G.synthesis(ws=w_c, c=c_forward, noise_mode='const')[image_mode][0]
                elif entangle == 'camera':
                    img = G.synthesis(ws=w.unsqueeze(0), c=c[0:1], noise_mode='const')[image_mode][0]
                elif entangle == 'both':
                    w_c = G.mapping(z=zs[0:1], c=c[0:1], truncation_psi=psi, truncation_cutoff=truncation_cutoff)
                    img = G.synthesis(ws=w_c, c=c[0:1], noise_mode='const')[image_mode][0]

                if image_mode == 'image_depth':
                    # Flip sign and normalize depth to [-1, 1] for visualization.
                    img = -img
                    img = (img - img.min()) / (img.max() - img.min()) * 2 - 1

                imgs.append(img)

                if gen_shapes:
                    print('Generating shape for frame %d / %d ...' % (frame_idx, num_keyframes * w_frames))

                    samples, voxel_origin, voxel_size = create_samples(N=voxel_resolution, voxel_origin=[0, 0, 0],
                                                                       cube_length=G.rendering_kwargs['box_warp'])
                    samples = samples.to(device)
                    sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1), device=device)
                    transformed_ray_directions_expanded = torch.zeros((samples.shape[0], max_batch, 3), device=device)
                    transformed_ray_directions_expanded[..., -1] = -1

                    # Query the density field in chunks of at most max_batch samples
                    # to bound peak GPU memory.
                    head = 0
                    with tqdm(total=samples.shape[1]) as pbar:
                        with torch.no_grad():
                            while head < samples.shape[1]:
                                torch.manual_seed(0)
                                sigma = G.sample_mixed(samples[:, head:head+max_batch],
                                                       transformed_ray_directions_expanded[:, :samples.shape[1]-head],
                                                       w.unsqueeze(0), truncation_psi=psi, noise_mode='const')['sigma']
                                sigmas[:, head:head+max_batch] = sigma
                                head += max_batch
                                pbar.update(max_batch)

                    sigmas = sigmas.reshape((voxel_resolution, voxel_resolution, voxel_resolution)).cpu().numpy()
                    sigmas = np.flip(sigmas, 0)

                    # Zero out the volume borders to suppress floater artifacts.
                    pad = int(30 * voxel_resolution / 256)
                    pad_top = int(38 * voxel_resolution / 256)
                    sigmas[:pad] = 0
                    sigmas[-pad:] = 0
                    sigmas[:, :pad] = 0
                    sigmas[:, -pad_top:] = 0
                    sigmas[:, :, :pad] = 0
                    sigmas[:, :, -pad:] = 0

                    output_ply = False
                    if output_ply:
                        try:
                            from shape_utils import convert_sdf_samples_to_ply
                            convert_sdf_samples_to_ply(np.transpose(sigmas, (2, 1, 0)), [0, 0, 0], 1,
                                                       os.path.join(outdir, f'{frame_idx:04d}_shape.ply'), level=10)
                        except Exception:
                            pass  # PLY export is optional; fall through silently
                    else:  # output an .mrc density volume instead
                        with mrcfile.new_mmap(outdir + f'{frame_idx:04d}_shape.mrc', overwrite=True,
                                              shape=sigmas.shape, mrc_mode=2) as mrc:
                            mrc.data[:] = sigmas
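                    # (Exported .mrc density volumes can be inspected with a
                    # volume viewer such as UCSF ChimeraX.)
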
        video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    video_out.close()
    all_poses = np.stack(all_poses)

    if gen_shapes:
        print(all_poses.shape)
        # Save the camera trajectory alongside the video.
        with open(mp4.replace('.mp4', '_trajectory.npy'), 'wb') as f:
            np.save(f, all_poses)
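
# Minimal programmatic sketch (assumes `G` is an already-loaded EG3D-style
# generator on the current CUDA device; the argument values are illustrative):
#   gen_interp_video(G, w_given=None, mp4='out.mp4', seeds=[0, 1, 2, 3],
#                    grid_dims=(2, 2), w_frames=60, psi=0.7, bitrate='10M')

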
def parse_range(s: Union[str, List[int]]) -> List[int]:
    '''Parse a comma-separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        if m := range_re.match(p):
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges


def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
    '''Parse a 'M,N' or 'MxN' integer tuple.

    Example:
        '4x2' returns (4,2)
        '0,1' returns (0,1)
    '''
    if isinstance(s, tuple): return s
    if m := re.match(r'^(\d+)[x,](\d+)$', s):
        return (int(m.group(1)), int(m.group(2)))
    raise ValueError(f'cannot parse tuple {s}')


@click.command()
@click.option('--network', help='Network path', multiple=True, required=True)
@click.option('--w_pth', help='Path to a saved w latent (loaded with torch.load); overrides seed-based latents')
@click.option('--generator_type', help='Generator type', type=click.Choice(['ffhq', 'cat']), required=False, metavar='STR', default='ffhq', show_default=True)
@click.option('--model_is_state_dict', type=bool, default=False)
@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determined from the length of the seeds array given by --seeds.', default=None)
@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--trunc-cutoff', 'truncation_cutoff', type=int, help='Truncation cutoff', default=14, show_default=True)
@click.option('--outdir', help='Output directory', type=str, default='../test_runs/manip_3D_recon/4_manip_result', metavar='DIR')
@click.option('--image_mode', help='Image mode', type=click.Choice(['image', 'image_depth', 'image_raw']), required=False, metavar='STR', default='image', show_default=True)
@click.option('--sample_mult', 'sampling_multiplier', type=float, help='Multiplier for depth sampling in volume rendering', default=2, show_default=True)
@click.option('--nrr', type=int, help='Neural rendering resolution override', default=None, show_default=True)
@click.option('--shapes', type=bool, help='Generate shapes (.mrc volumes) for shape interpolation', default=False, show_default=True)
def generate_images(
    network: List[str],
    w_pth: str,
    seeds: List[int],
    shuffle_seed: Optional[int],
    truncation_psi: float,
    truncation_cutoff: int,
    grid: Tuple[int,int],
    num_keyframes: Optional[int],
    w_frames: int,
    outdir: str,
    generator_type: str,
    image_mode: str,
    sampling_multiplier: float,
    nrr: Optional[int],
    shapes: bool,
    model_is_state_dict: bool,
):
    os.makedirs(outdir, exist_ok=True)

    device = torch.device('cuda')

    # Reference pickle used only to instantiate the architecture when the
    # checkpoint is a raw state dict.
    if generator_type == 'ffhq':
        network_pkl_tmp = 'pretrained/ffhqrebalanced512-128.pkl'
    elif generator_type == 'cat':
        network_pkl_tmp = 'pretrained/afhqcats512-128.pkl'
    else:
        raise NotImplementedError(f'unknown generator_type: {generator_type}')

    G_list = []
    outputs = []
    for network_path in network:
        print('Loading networks from "%s"...' % network_path)
        dir_label = network_path.split('/')[-2] + '___' + network_path.split('/')[-1]
        output = os.path.join(outdir, dir_label)
        outputs.append(output)
        if model_is_state_dict:
            # Build the generator from the reference pickle, then load the checkpoint weights.
            with dnnlib.util.open_url(network_pkl_tmp) as f:
                G = legacy.load_network_pkl(f)['G_ema'].to(device)
            ckpt = torch.load(network_path)
            G.load_state_dict(ckpt, strict=False)
        else:
            with dnnlib.util.open_url(network_path) as f:
                G = legacy.load_network_pkl(f)['G_ema'].to(device)

        # Increase ray-marching sample counts for higher-quality renders.
        G.rendering_kwargs['depth_resolution'] = int(G.rendering_kwargs['depth_resolution'] * sampling_multiplier)
        G.rendering_kwargs['depth_resolution_importance'] = int(G.rendering_kwargs['depth_resolution_importance'] * sampling_multiplier)

        if generator_type == 'cat':
            G.rendering_kwargs['avg_camera_pivot'] = [0, 0, -0.06]
        elif generator_type == 'ffhq':
            G.rendering_kwargs['avg_camera_pivot'] = [0, 0, 0.2]

        if nrr is not None:
            G.neural_rendering_resolution = nrr
        G_list.append(G)

    if truncation_cutoff == 0:
        truncation_psi = 1.0  # a cutoff of 0 disables truncation anyway
    if truncation_psi == 1.0:
        truncation_cutoff = 14  # without truncation the cutoff does not matter

    grid_w, grid_h = grid
    seeds = seeds[:grid_w * grid_h]

    # Label like '0_1_2' built from the seeds, used in the output filename.
    seed_idx = '_'.join(str(seed) for seed in seeds)

    for G, output in zip(G_list, outputs):
        if w_pth is not None:
            # A given latent is animated as a single-cell grid: only the camera moves.
            grid = (1, 1)
            w_given = torch.load(w_pth).cuda()
            w_given_id = os.path.split(w_pth)[-1].split('.')[-2]
            output = output + f'__{w_given_id}.mp4'
            gen_interp_video(G=G, w_given=w_given, mp4=output, bitrate='10M', grid_dims=grid,
                             num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds,
                             shuffle_seed=shuffle_seed, psi=truncation_psi,
                             truncation_cutoff=truncation_cutoff, generator_type=generator_type,
                             image_mode=image_mode, gen_shapes=shapes)
        else:
            output = output + f'__{seed_idx}.mp4'
            gen_interp_video(G=G, w_given=None, mp4=output, bitrate='10M', grid_dims=grid,
                             num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds,
                             shuffle_seed=shuffle_seed, psi=truncation_psi,
                             truncation_cutoff=truncation_cutoff, generator_type=generator_type,
                             image_mode=image_mode, gen_shapes=shapes)


if __name__ == "__main__":
    generate_images()