import glob
import os
import os.path as osp

import fire
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from seva.data_io import get_parser
from seva.eval import (
    IS_TORCH_NIGHTLY,
    compute_relative_inds,
    create_transforms_simple,
    infer_prior_inds,
    infer_prior_stats,
    run_one_scene,
)
from seva.geometry import (
    generate_interpolated_path,
    generate_spiral_path,
    get_arc_horizontal_w2cs,
    get_default_intrinsics,
    get_lookat,
    get_preset_pose_fov,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"

WORK_DIR = "work_dirs/demo"
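
# torch.compile is only enabled on nightly PyTorch builds; the Inductor
# cache flags below speed up recompilation across runs.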
if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False
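
# All model components are built once at import time and shared across scenes.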
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
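
# H/W: output resolution; T: frames per model forward pass; C and f: latent
# channel count and autoencoder downsampling factor.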
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}

if COMPILE:
    MODEL = torch.compile(MODEL, dynamic=False)
    CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
    AE = torch.compile(AE, dynamic=False)


def parse_task(
    task,
    scene,
    num_inputs,
    T,
    version_dict,
):
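    """Resolve `task` into image paths, camera parameters, and frame indices.

    Returns a tuple (all_imgs_path, num_inputs, num_targets, input_indices,
    anchor_indices, c2ws, Ks, anchor_c2ws, anchor_Ks); the pose/intrinsic
    entries are float tensors, and the anchor entries may be None.
    """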
    options = version_dict["options"]

    anchor_indices = None
    anchor_c2ws = None
    anchor_Ks = None
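
    # Single-view conditioning: camera poses and intrinsics are synthesized
    # along a preset trajectory instead of being read from a dataset.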
    if task == "img2trajvid_s-prob":
        if num_inputs is not None:
            assert (
                num_inputs == 1
            ), "Task `img2trajvid_s-prob` only supports 1-view conditioning."
        else:
            num_inputs = 1
        num_targets = options.get("num_targets", T - 1)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )

        input_indices = [0]
        anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()

        all_imgs_path = [scene] + [None] * num_targets

        c2ws, fovs = get_preset_pose_fov(
            option=options.get("traj_prior", "orbit"),
            num_frames=num_targets + 1,
            start_w2c=torch.eye(4),
            look_at=torch.Tensor([0, 0, 10]),
        )

        with Image.open(scene) as img:
            W, H = img.size
            aspect_ratio = W / H
        Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio)
        # Scale intrinsics from normalized to pixel units.
        Ks[:, :2] *= (
            torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
        )
        Ks = Ks.numpy()

        anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]
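
    # All other tasks read images, poses, and intrinsics from a scene on disk
    # via the "reconfusion"-format parser.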
    else:
        parser = get_parser(
            parser_type="reconfusion",
            data_dir=scene,
            normalize=False,
        )
        all_imgs_path = parser.image_paths
        c2ws = parser.camtoworlds
        camera_ids = parser.camera_ids
        Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)
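
        # `num_inputs` selects the dataset split: an int key, a string key
        # whose leading integer is the input-frame count, or None (in which
        # case the parser must define exactly one split).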
        if num_inputs is None:
            assert len(parser.splits_per_num_input_frames.keys()) == 1
            num_inputs = list(parser.splits_per_num_input_frames.keys())[0]
            split_dict = parser.splits_per_num_input_frames[num_inputs]
        elif isinstance(num_inputs, str):
            split_dict = parser.splits_per_num_input_frames[num_inputs]
            num_inputs = int(num_inputs.split("-")[0])
        else:
            split_dict = parser.splits_per_num_input_frames[num_inputs]

        num_targets = len(split_dict["test_ids"])
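
        # Sparse-view NVS: the held-out test views are the targets, optionally
        # with anchor frames from a prior trajectory (see below).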
        if task == "img2img":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            sampled_indices = np.sort(
                np.array(split_dict["train_ids"] + split_dict["test_ids"])
            )
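
            # Optional prior trajectory from which anchor cameras are drawn.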
            traj_prior = options.get("traj_prior", None)
            if traj_prior == "spiral":
                assert parser.bounds is not None
                anchor_c2ws = generate_spiral_path(
                    c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
                    parser.bounds[sampled_indices],
                    n_frames=num_anchors + 1,
                    n_rots=2,
                    zrate=0.5,
                    endpoint=False,
                )[1:] @ np.diagflat([1, -1, -1, 1])
            elif traj_prior == "interpolated":
                assert num_inputs > 1
                anchor_c2ws = generate_interpolated_path(
                    c2ws[split_dict["train_ids"], :3],
                    round((num_anchors + 1) / (num_inputs - 1)),
                    endpoint=False,
                )[1 : num_anchors + 1]
            elif traj_prior == "orbit":
                c2ws_th = torch.as_tensor(c2ws)
                lookat = get_lookat(
                    c2ws_th[sampled_indices, :3, 3],
                    c2ws_th[sampled_indices, :3, 2],
                )
                anchor_c2ws = torch.linalg.inv(
                    get_arc_horizontal_w2cs(
                        torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
                        lookat,
                        -F.normalize(
                            c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
                            dim=-1,
                        ),
                        num_frames=num_anchors + 1,
                        endpoint=False,
                    )
                ).numpy()[1:, :3]
            else:
                anchor_c2ws = None
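
            # Keep only the train+test frames and remap indices into this
            # subset; anchors are indexed after the sampled frames.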
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            input_indices = compute_relative_inds(
                sampled_indices,
                np.array(split_dict["train_ids"]),
            )
            anchor_indices = np.arange(
                sampled_indices.shape[0],
                sampled_indices.shape[0] + num_anchors,
            ).tolist()
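
        # All non-input dataset frames are targets; anchor frames are inferred
        # from the camera poses relative to the input frames.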
        elif task == "img2vid":
            num_targets = len(all_imgs_path) - num_inputs
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            input_indices = split_dict["train_ids"]
            anchor_indices = infer_prior_inds(
                c2ws,
                num_prior_frames=num_anchors,
                input_frame_indices=input_indices,
                options=options,
            ).tolist()
            num_anchors = len(anchor_indices)
            anchor_c2ws = c2ws[anchor_indices, :3]
            anchor_Ks = Ks[anchor_indices]
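
        # Input views plus a target camera trajectory: anchors are evenly
        # spaced samples of the target cameras.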
        elif task == "img2trajvid":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            target_c2ws = c2ws[split_dict["test_ids"], :3]
            target_Ks = Ks[split_dict["test_ids"]]
            anchor_c2ws = target_c2ws[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]
            anchor_Ks = target_Ks[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]

            sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            input_indices = np.arange(num_inputs).tolist()
            anchor_indices = np.linspace(
                num_inputs, num_inputs + num_targets - 1, num_anchors
            ).tolist()

        else:
            raise ValueError(f"Unknown task: {task}")

    return (
        all_imgs_path,
        num_inputs,
        num_targets,
        input_indices,
        anchor_indices,
        torch.tensor(c2ws[:, :3]).float(),
        torch.tensor(Ks).float(),
        (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
        (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
    )


def main(
    data_path,
    data_items=None,
    task="img2img",
    save_subdir="",
    H=None,
    W=None,
    T=None,
    use_traj_prior=False,
    **overwrite_options,
):
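    """Run `task` on every scene found under `data_path`.

    Any extra keyword arguments are merged into VERSION_DICT["options"],
    overriding the defaults set below.
    """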
    if H is not None:
        VERSION_DICT["H"] = H
    if W is not None:
        VERSION_DICT["W"] = W
    if T is not None:
        VERSION_DICT["T"] = [int(t) for t in T.split(",")] if isinstance(T, str) else T
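
    # Default sampling hyperparameters; **overwrite_options takes precedence.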
    options = VERSION_DICT["options"]
    options["chunk_strategy"] = "nearest-gt"
    options["video_save_fps"] = 30.0
    options["beta_linear_start"] = 5e-6
    options["log_snr_shift"] = 2.4
    options["guider_types"] = 1
    options["cfg"] = 2.0
    options["camera_scale"] = 2.0
    options["num_steps"] = 50
    options["cfg_min"] = 1.2
    options["encoding_t"] = 1
    options["decoding_t"] = 1
    options["num_inputs"] = None
    options["seed"] = 23
    options.update(overwrite_options)

    num_inputs = options["num_inputs"]
    seed = options["seed"]
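
    # Scene list: explicit comma-separated items, or every entry under
    # `data_path`.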
    if data_items is not None:
        if not isinstance(data_items, (list, tuple)):
            data_items = data_items.split(",")
        scenes = [os.path.join(data_path, item) for item in data_items]
    else:
        scenes = glob.glob(osp.join(data_path, "*"))

    for scene in tqdm(scenes):
        save_path_scene = os.path.join(
            WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0]
        )
        if options.get("skip_saved", False) and os.path.exists(
            os.path.join(save_path_scene, "transforms.json")
        ):
            print(f"Skipping {scene} as it is already sampled.")
            continue
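
        # Resolve paths, cameras, and frame indices for this scene.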
        (
            all_imgs_path,
            num_inputs,
            num_targets,
            input_indices,
            anchor_indices,
            c2ws,
            Ks,
            anchor_c2ws,
            anchor_Ks,
        ) = parse_task(
            task,
            scene,
            num_inputs,
            VERSION_DICT["T"],
            VERSION_DICT,
        )
        assert num_inputs is not None
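
        # Conditioning inputs: images for the input frames (anchors act as
        # prior frames), plus cameras for every frame.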
        image_cond = {
            "img": all_imgs_path,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }

        camera_cond = {
            "c2w": c2ws.clone(),
            "K": Ks.clone(),
            "input_indices": list(range(num_inputs + num_targets)),
        }
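
        # `run_one_scene` is a generator that yields video paths as they are
        # written; drain it to run the full sampling loop.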
        video_path_generator = run_one_scene(
            task,
            VERSION_DICT,
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=save_path_scene,
            use_traj_prior=use_traj_prior,
            traj_prior_Ks=anchor_Ks,
            traj_prior_c2ws=anchor_c2ws,
            seed=seed,
        )
        for _ in video_path_generator:
            pass
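
        # Negate the cameras' y- and z-axes (OpenCV <-> OpenGL convention)
        # before exporting poses to transforms.json.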
        c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float()
        img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png")))
        if len(img_paths) != len(c2ws):
            # Input frames were saved separately; interleave them with the
            # sampled target frames in pose order.
            input_img_paths = sorted(
                glob.glob(osp.join(save_path_scene, "input", "*.png"))
            )
            assert len(img_paths) == num_targets
            assert len(input_img_paths) == num_inputs
            assert c2ws.shape[0] == num_inputs + num_targets
            target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices]
            img_paths = [
                input_img_paths[input_indices.index(i)]
                if i in input_indices
                else img_paths[target_indices.index(i)]
                for i in range(c2ws.shape[0])
            ]
        create_transforms_simple(
            save_path=save_path_scene,
            img_paths=img_paths,
            img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat(
                num_inputs + num_targets, 0
            ),
            c2ws=c2ws,
            Ks=Ks,
        )
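

# Example invocation (script name and paths are illustrative; extra flags are
# forwarded into VERSION_DICT["options"]):
#   python demo.py --data_path=data/scenes --task=img2trajvid --use_traj_prior=True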
if __name__ == "__main__":
    fire.Fire(main)