import glob
import os
import os.path as osp

import fire
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from seva.data_io import get_parser
from seva.eval import (
    IS_TORCH_NIGHTLY,
    compute_relative_inds,
    create_transforms_simple,
    infer_prior_inds,
    infer_prior_stats,
    run_one_scene,
)
from seva.geometry import (
    generate_interpolated_path,
    generate_spiral_path,
    get_arc_horizontal_w2cs,
    get_default_intrinsics,
    get_lookat,
    get_preset_pose_fov,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"

# Constants.
WORK_DIR = "work_dirs/demo"

if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}

if COMPILE:
    MODEL = torch.compile(MODEL, dynamic=False)
    CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
    AE = torch.compile(AE, dynamic=False)


def parse_task(
    task,
    scene,
    num_inputs,
    T,
    version_dict,
):
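    """Resolve inputs, targets, and anchor cameras for the given task.

    `task` is one of "img2trajvid_s-prob", "img2img", "img2vid", or
    "img2trajvid". Returns (all_imgs_path, num_inputs, num_targets,
    input_indices, anchor_indices, c2ws, Ks, anchor_c2ws, anchor_Ks), where the
    pose/intrinsic entries are float tensors and the anchor entries may be None.
    """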
    options = version_dict["options"]

    anchor_indices = None
    anchor_c2ws = None
    anchor_Ks = None

    if task == "img2trajvid_s-prob":
        if num_inputs is not None:
            assert (
                num_inputs == 1
            ), "Task `img2trajvid_s-prob` only supports 1-view conditioning."
        else:
            num_inputs = 1
        num_targets = options.get("num_targets", T - 1)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )

        input_indices = [0]
        anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()

        all_imgs_path = [scene] + [None] * num_targets

        c2ws, fovs = get_preset_pose_fov(
            option=options.get("traj_prior", "orbit"),
            num_frames=num_targets + 1,
            start_w2c=torch.eye(4),
            look_at=torch.Tensor([0, 0, 10]),
        )

        with Image.open(scene) as img:
            W, H = img.size
            aspect_ratio = W / H
        # `get_default_intrinsics` returns normalized intrinsics; scale the
        # first two rows by the image size to get pixel-unit intrinsics.
        Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio)  # normalized
        Ks[:, :2] *= (
            torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
        )  # unnormalized
        Ks = Ks.numpy()

        anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]
    else:
        parser = get_parser(
            parser_type="reconfusion",
            data_dir=scene,
            normalize=False,
        )
        all_imgs_path = parser.image_paths
        c2ws = parser.camtoworlds
        camera_ids = parser.camera_ids
        Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)

        if num_inputs is None:
            assert len(parser.splits_per_num_input_frames.keys()) == 1
            num_inputs = list(parser.splits_per_num_input_frames.keys())[0]
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore
        elif isinstance(num_inputs, str):
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore
            num_inputs = int(num_inputs.split("-")[0])  # e.g. "1-from32"
        else:
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore

        num_targets = len(split_dict["test_ids"])

        if task == "img2img":
            # Note: in this setting we should refrain from using any camera
            # info except that of `sampled_indices`, and most importantly,
            # should preserve their order.
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )
            sampled_indices = np.sort(
                np.array(split_dict["train_ids"] + split_dict["test_ids"])
            )  # we always sort all indices first

            traj_prior = options.get("traj_prior", None)
            if traj_prior == "spiral":
                assert parser.bounds is not None
                # The diagflat flips between the OpenCV and OpenGL conventions.
                anchor_c2ws = generate_spiral_path(
                    c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
                    parser.bounds[sampled_indices],
                    n_frames=num_anchors + 1,
                    n_rots=2,
                    zrate=0.5,
                    endpoint=False,
                )[1:] @ np.diagflat([1, -1, -1, 1])
            elif traj_prior == "interpolated":
                assert num_inputs > 1
                anchor_c2ws = generate_interpolated_path(
                    c2ws[split_dict["train_ids"], :3],
                    round((num_anchors + 1) / (num_inputs - 1)),
                    endpoint=False,
                )[1 : num_anchors + 1]
            elif traj_prior == "orbit":
                c2ws_th = torch.as_tensor(c2ws)
                lookat = get_lookat(
                    c2ws_th[sampled_indices, :3, 3],
                    c2ws_th[sampled_indices, :3, 2],
                )
                anchor_c2ws = torch.linalg.inv(
                    get_arc_horizontal_w2cs(
                        torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
                        lookat,
                        -F.normalize(
                            c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
                            dim=-1,
                        ),
                        num_frames=num_anchors + 1,
                        endpoint=False,
                    )
                ).numpy()[1:, :3]
            else:
                anchor_c2ws = None
            # anchor_Ks stays None and defaults downstream to the first of the
            # target Ks.

            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            # Map absolute indices to relative ones.
            input_indices = compute_relative_inds(
                sampled_indices,
                np.array(split_dict["train_ids"]),
            )
            anchor_indices = np.arange(
                sampled_indices.shape[0],
                sampled_indices.shape[0] + num_anchors,
            ).tolist()  # the order has no meaning here
        elif task == "img2vid":
            num_targets = len(all_imgs_path) - num_inputs
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            input_indices = split_dict["train_ids"]
            anchor_indices = infer_prior_inds(
                c2ws,
                num_prior_frames=num_anchors,
                input_frame_indices=input_indices,
                options=options,
            ).tolist()
            num_anchors = len(anchor_indices)
            anchor_c2ws = c2ws[anchor_indices, :3]
            anchor_Ks = Ks[anchor_indices]
        elif task == "img2trajvid":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            target_c2ws = c2ws[split_dict["test_ids"], :3]
            target_Ks = Ks[split_dict["test_ids"]]
            anchor_c2ws = target_c2ws[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]
            anchor_Ks = target_Ks[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]

            sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            input_indices = np.arange(num_inputs).tolist()
            anchor_indices = np.linspace(
                num_inputs, num_inputs + num_targets - 1, num_anchors
            ).tolist()
        else:
            raise ValueError(f"Unknown task: {task}")

    return (
        all_imgs_path,
        num_inputs,
        num_targets,
        input_indices,
        anchor_indices,
        torch.tensor(c2ws[:, :3]).float(),
        torch.tensor(Ks).float(),
        (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
        (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
    )


def main(
    data_path,
    data_items=None,
    task="img2img",
    save_subdir="",
    H=None,
    W=None,
    T=None,
    use_traj_prior=False,
    **overwrite_options,
):
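    """Run sampling for every scene under `data_path` and save outputs in WORK_DIR.

    Any extra keyword arguments are merged into VERSION_DICT["options"] and
    override the sampling defaults set below (e.g. num_steps, cfg, seed).
    """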
options["chunk_strategy"] = "nearest-gt" options["video_save_fps"] = 30.0 options["beta_linear_start"] = 5e-6 options["log_snr_shift"] = 2.4 options["guider_types"] = 1 options["cfg"] = 2.0 options["camera_scale"] = 2.0 options["num_steps"] = 50 options["cfg_min"] = 1.2 options["encoding_t"] = 1 options["decoding_t"] = 1 options["num_inputs"] = None options["seed"] = 23 options.update(overwrite_options) num_inputs = options["num_inputs"] seed = options["seed"] if data_items is not None: if not isinstance(data_items, (list, tuple)): data_items = data_items.split(",") scenes = [os.path.join(data_path, item) for item in data_items] else: scenes = glob.glob(osp.join(data_path, "*")) for scene in tqdm(scenes): save_path_scene = os.path.join( WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0] ) if options.get("skip_saved", False) and os.path.exists( os.path.join(save_path_scene, "transforms.json") ): print(f"Skipping {scene} as it is already sampled.") continue # parse_task -> infer_prior_stats modifies VERSION_DICT["T"] in-place. ( all_imgs_path, num_inputs, num_targets, input_indices, anchor_indices, c2ws, Ks, anchor_c2ws, anchor_Ks, ) = parse_task( task, scene, num_inputs, VERSION_DICT["T"], VERSION_DICT, ) assert num_inputs is not None # Create image conditioning. image_cond = { "img": all_imgs_path, "input_indices": input_indices, "prior_indices": anchor_indices, } # Create camera conditioning. camera_cond = { "c2w": c2ws.clone(), "K": Ks.clone(), "input_indices": list(range(num_inputs + num_targets)), } # run_one_scene -> transform_img_and_K modifies VERSION_DICT["H"] and VERSION_DICT["W"] in-place. video_path_generator = run_one_scene( task, VERSION_DICT, # H, W maybe updated in run_one_scene model=MODEL, ae=AE, conditioner=CONDITIONER, denoiser=DENOISER, image_cond=image_cond, camera_cond=camera_cond, save_path=save_path_scene, use_traj_prior=use_traj_prior, traj_prior_Ks=anchor_Ks, traj_prior_c2ws=anchor_c2ws, seed=seed, # to ensure sampled video can be reproduced in regardless of start and i ) for _ in video_path_generator: pass # Convert from OpenCV to OpenGL camera format. c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float() img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png"))) if len(img_paths) != len(c2ws): input_img_paths = sorted( glob.glob(osp.join(save_path_scene, "input", "*.png")) ) assert len(img_paths) == num_targets assert len(input_img_paths) == num_inputs assert c2ws.shape[0] == num_inputs + num_targets target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices] img_paths = [ input_img_paths[input_indices.index(i)] if i in input_indices else img_paths[target_indices.index(i)] for i in range(c2ws.shape[0]) ] create_transforms_simple( save_path=save_path_scene, img_paths=img_paths, img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat( num_inputs + num_targets, 0 ), c2ws=c2ws, Ks=Ks, ) if __name__ == "__main__": fire.Fire(main)