import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal

import gradio as gr
import httpx
import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange
from gradio import networking
from gradio.context import LocalContext
from gradio.tunneling import CERTIFICATE_PATH, Tunnel

from seva.eval import (
    IS_TORCH_NIGHTLY,
    chunk_input_and_test,
    create_transforms_simple,
    infer_prior_stats,
    run_one_scene,
    transform_img_and_K,
)
from seva.geometry import (
    DEFAULT_FOV_RAD,
    get_default_intrinsics,
    get_preset_pose_fov,
    normalize_scene,
)
from seva.gui import define_gui
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"


WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1
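# Example scenes for the "Advanced" tab: each entry pairs the gallery
# thumbnail with the full list of input views for that scene.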
ADVANCE_EXAMPLE_MAP = [
    (
        "assets/advance/blue-car.jpg",
        ["assets/advance/blue-car.jpg"],
    ),
    (
        "assets/advance/garden-4_0.jpg",
        [
            "assets/advance/garden-4_0.jpg",
            "assets/advance/garden-4_1.jpg",
            "assets/advance/garden-4_2.jpg",
            "assets/advance/garden-4_3.jpg",
        ],
    ),
    (
        "assets/advance/vgg-lab-4_0.png",
        [
            "assets/advance/vgg-lab-4_0.png",
            "assets/advance/vgg-lab-4_1.png",
            "assets/advance/vgg-lab-4_2.png",
            "assets/advance/vgg-lab-4_3.png",
        ],
    ),
    (
        "assets/advance/telebooth-2_0.jpg",
        [
            "assets/advance/telebooth-2_0.jpg",
            "assets/advance/telebooth-2_1.jpg",
        ],
    ),
    (
        "assets/advance/backyard-7_0.jpg",
        [
            "assets/advance/backyard-7_0.jpg",
            "assets/advance/backyard-7_1.jpg",
            "assets/advance/backyard-7_2.jpg",
            "assets/advance/backyard-7_3.jpg",
            "assets/advance/backyard-7_4.jpg",
            "assets/advance/backyard-7_5.jpg",
            "assets/advance/backyard-7_6.jpg",
        ],
    ),
]
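
# torch.compile is only enabled on nightly builds, where the Inductor
# autograd/FX-graph caches can be turned on to amortize compilation cost.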
if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

DUST3R = Dust3rPipeline(device=device)
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
SERVERS = {}
ABORT_EVENTS = {}

if COMPILE:
    MODEL = torch.compile(MODEL)
    CONDITIONER = torch.compile(CONDITIONER)
    AE = torch.compile(AE)


class SevaRenderer(object):
    def __init__(self, server: viser.ViserServer):
        self.server = server
        self.gui_state = None

    def preprocess(
        self, input_img_path_or_tuples: list[tuple[str, None]] | str
    ) -> tuple[dict, dict, dict]:
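        """Prepare input views for rendering.

        A single image path (the "Basic" tab) is resized and paired with an
        identity pose and default intrinsics. A gallery list of (path, _)
        tuples (the "Advanced" tab) is run through DUSt3R to recover
        per-view cameras and a point cloud, which are then normalized to a
        canonical scene scale.
        """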
        # Shorter side of the processed images, rounded to a multiple of 64
        # (the size stride expected by transform_img_and_K).
        shorter: int = 576
        shorter = round(shorter / 64) * 64
        if isinstance(input_img_path_or_tuples, str):
            # Single image ("Basic" tab): skip DUSt3R and use an identity
            # pose with default intrinsics.
            input_imgs = torch.as_tensor(
                iio.imread(input_img_path_or_tuples) / 255.0, dtype=torch.float32
            )[None, ..., :3]
            input_imgs = transform_img_and_K(
                input_imgs.permute(0, 3, 1, 2),
                shorter,
                K=None,
                size_stride=64,
            )[0].permute(0, 2, 3, 1)
            input_Ks = get_default_intrinsics(
                aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
            )
            input_c2ws = torch.eye(4)[None]
            time.sleep(0.1)
            return (
                {
                    "input_imgs": input_imgs,
                    "input_Ks": input_Ks,
                    "input_c2ws": input_c2ws,
                    "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
                    "points": [np.zeros((0, 3))],
                    "point_colors": [np.zeros((0, 3))],
                    "scene_scale": 1.0,
                },
                gr.update(visible=False),
                gr.update(),
            )
        else:
            # Multiple images ("Advanced" tab): recover cameras and points
            # with DUSt3R, then normalize the scene.
            img_paths = [p for (p, _) in input_img_path_or_tuples]
            (
                input_imgs,
                input_Ks,
                input_c2ws,
                points,
                point_colors,
            ) = DUST3R.infer_cameras_and_points(img_paths)
            num_inputs = len(img_paths)
            if num_inputs == 1:
                input_imgs, input_Ks, input_c2ws, points, point_colors = (
                    input_imgs[:1],
                    input_Ks[:1],
                    input_c2ws[:1],
                    points[:1],
                    point_colors[:1],
                )
            input_imgs = [img[..., :3] for img in input_imgs]
            # Normalize the scene around the camera centers.
            point_chunks = [p.shape[0] for p in points]
            point_indices = np.cumsum(point_chunks)[:-1]
            input_c2ws, points, _ = normalize_scene(
                input_c2ws,
                np.concatenate(points, 0),
                camera_center_method="poses",
            )
            points = np.split(points, point_indices, 0)
            # Rescale cameras and points by the median extent of the scene.
            scene_scale = np.median(
                np.ptp(np.concatenate([input_c2ws[:, :3, 3], *points], 0), -1)
            )
            input_c2ws[:, :3, 3] /= scene_scale
            points = [point / scene_scale for point in points]
            input_imgs = [
                torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_imgs
            ]
            input_Ks = torch.as_tensor(input_Ks)
            input_c2ws = torch.as_tensor(input_c2ws)
            new_input_imgs, new_input_Ks = [], []
            for img, K in zip(input_imgs, input_Ks):
                img = rearrange(img, "h w c -> 1 c h w")
                img, K = transform_img_and_K(img, shorter, K=K[None], size_stride=64)
                assert isinstance(K, torch.Tensor)
                # Normalize intrinsics by (W, H) so they are resolution-free.
                K = K / K.new_tensor([img.shape[-1], img.shape[-2], 1])[:, None]
                new_input_imgs.append(img)
                new_input_Ks.append(K)
            input_imgs = torch.cat(new_input_imgs, 0)
            input_imgs = rearrange(input_imgs, "b c h w -> b h w c")[..., :3]
            input_Ks = torch.cat(new_input_Ks, 0)
            return (
                {
                    "input_imgs": input_imgs,
                    "input_Ks": input_Ks,
                    "input_c2ws": input_c2ws,
                    "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
                    "points": points,
                    "point_colors": point_colors,
                    "scene_scale": scene_scale,
                },
                gr.update(visible=False),
                # With more than 10 input views, only the "interp" chunk
                # strategy is allowed.
                gr.update()
                if num_inputs <= 10
                else gr.update(choices=["interp"], value="interp"),
            )

    def visualize_scene(self, preprocessed: dict):
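        """Populate the viser scene with camera frustums and the recovered
        point clouds, plus sliders for rescaling their display size."""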
        server = self.server
        server.scene.reset()
        server.gui.reset()
        set_bkgd_color(server)

        (
            input_imgs,
            input_Ks,
            input_c2ws,
            input_wh,
            points,
            point_colors,
            scene_scale,
        ) = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
            preprocessed["points"],
            preprocessed["point_colors"],
            preprocessed["scene_scale"],
        )
        W, H = input_wh

        # Use the mean of the input cameras' up vectors as the world up.
        server.scene.set_up_direction(-input_c2ws[..., :3, 1].mean(0).numpy())

        # Initialize the client FOV from the first input camera, along the
        # longer image axis.
        assert input_imgs[0].shape[:2] == (H, W)
        if H > W:
            init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 0, 0].item()))
        else:
            init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 1, 1].item()))
        init_fov_deg = float(init_fov / np.pi * 180.0)

        frustum_nodes, pcd_nodes = [], []
        for i in range(len(input_imgs)):
            K = input_Ks[i]
            frustum = server.scene.add_camera_frustum(
                f"/scene_assets/cameras/{i}",
                fov=2 * np.arctan(1 / (2 * K[1, 1].item())),
                aspect=W / H,
                scale=0.1 * scene_scale,
                image=(input_imgs[i].numpy() * 255.0).astype(np.uint8),
                wxyz=vt.SO3.from_matrix(input_c2ws[i, :3, :3].numpy()).wxyz,
                position=input_c2ws[i, :3, 3].numpy(),
            )

            def get_handler(frustum):
                def handler(event: viser.GuiEvent) -> None:
                    assert event.client_id is not None
                    client = server.get_clients()[event.client_id]
                    with client.atomic():
                        client.camera.position = frustum.position
                        client.camera.wxyz = frustum.wxyz
                        # Aim the view along the frustum's forward axis,
                        # roughly toward the scene origin.
                        look_direction = vt.SO3(frustum.wxyz).as_matrix()[:, 2]
                        position_origin = -frustum.position
                        client.camera.look_at = (
                            frustum.position
                            + np.dot(look_direction, position_origin)
                            / np.linalg.norm(position_origin)
                            * look_direction
                        )

                return handler

            frustum.on_click(get_handler(frustum))
            frustum_nodes.append(frustum)

            pcd = server.scene.add_point_cloud(
                f"/scene_assets/points/{i}",
                points[i],
                point_colors[i],
                point_size=0.01 * scene_scale,
                point_shape="circle",
            )
            pcd_nodes.append(pcd)

        with server.gui.add_folder("Scene scale", expand_by_default=False, order=200):
            camera_scale_slider = server.gui.add_slider(
                "Log camera scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
            )

            @camera_scale_slider.on_update
            def _(_) -> None:
                for i in range(len(frustum_nodes)):
                    frustum_nodes[i].scale = (
                        0.1 * scene_scale * 10**camera_scale_slider.value
                    )

            point_scale_slider = server.gui.add_slider(
                "Log point scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
            )

            @point_scale_slider.on_update
            def _(_) -> None:
                for i in range(len(pcd_nodes)):
                    pcd_nodes[i].point_size = (
                        0.01 * scene_scale * 10**point_scale_slider.value
                    )

        self.gui_state = define_gui(
            server,
            init_fov=init_fov_deg,
            img_wh=input_wh,
            scene_scale=scene_scale,
        )

    def get_target_c2ws_and_Ks_from_gui(self, preprocessed: dict):
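        """Collect the key-frame trajectory drawn in the viser GUI, returning
        target c2w matrices and resolution-normalized intrinsics."""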
        input_wh = preprocessed["input_wh"]
        W, H = input_wh
        gui_state = self.gui_state
        assert gui_state is not None and gui_state.camera_traj_list is not None
        target_c2ws, target_Ks = [], []
        for item in gui_state.camera_traj_list:
            target_c2ws.append(item["w2c"])
            assert item["img_wh"] == input_wh
            K = np.array(item["K"]).reshape(3, 3) / np.array([W, H, 1])[:, None]
            target_Ks.append(K)
        # The GUI stores w2c matrices; invert them to get c2w.
        target_c2ws = torch.as_tensor(
            np.linalg.inv(np.array(target_c2ws).reshape(-1, 4, 4))
        )
        target_Ks = torch.as_tensor(np.array(target_Ks).reshape(-1, 3, 3))
        return target_c2ws, target_Ks

    def get_target_c2ws_and_Ks_from_preset(
        self,
        preprocessed: dict,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
    ):
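        """Generate a target camera trajectory from a named preset, anchored
        at the first input camera."""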
        img_wh = preprocessed["input_wh"]
        start_c2w = preprocessed["input_c2ws"][0]
        start_w2c = torch.linalg.inv(start_c2w)
        look_at = torch.tensor([0, 0, 10])
        start_fov = DEFAULT_FOV_RAD
        target_c2ws, target_fovs = get_preset_pose_fov(
            preset_traj,
            num_frames,
            start_w2c,
            look_at,
            -start_c2w[:3, 1],  # camera up direction (-y axis in OpenCV convention)
            start_fov,
            spiral_radii=[1.0, 1.0, 0.5],
            zoom_factor=zoom_factor,
        )
        target_c2ws = torch.as_tensor(target_c2ws)
        target_fovs = torch.as_tensor(target_fovs)
        target_Ks = get_default_intrinsics(
            target_fovs,
            aspect_ratio=img_wh[0] / img_wh[1],
        )
        return target_c2ws, target_Ks

    def export_output_data(self, preprocessed: dict, output_dir: str):
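        """Export input views, target cameras, and a camera transforms file
        to `output_dir`; target views are written as black placeholders."""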
        input_imgs, input_Ks, input_c2ws, input_wh = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
        )
        target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)

        num_inputs = len(input_imgs)
        num_targets = len(target_c2ws)

        input_imgs = (input_imgs.cpu().numpy() * 255.0).astype(np.uint8)
        input_c2ws = input_c2ws.cpu().numpy()
        input_Ks = input_Ks.cpu().numpy()
        target_c2ws = target_c2ws.cpu().numpy()
        target_Ks = target_Ks.cpu().numpy()
        img_whs = np.array(input_wh)[None].repeat(len(input_imgs) + len(target_Ks), 0)

        os.makedirs(output_dir, exist_ok=True)
        img_paths = []
        for i, img in enumerate(input_imgs):
            iio.imwrite(img_path := osp.join(output_dir, f"{i:03d}.png"), img)
            img_paths.append(img_path)
        # Write black placeholder frames for the target views.
        for i in range(num_targets):
            iio.imwrite(
                img_path := osp.join(output_dir, f"{i + num_inputs:03d}.png"),
                np.zeros((input_wh[1], input_wh[0], 3), dtype=np.uint8),
            )
            img_paths.append(img_path)

        # Flip the y and z axes (OpenCV -> OpenGL camera convention) before
        # writing the transforms file.
        all_c2ws = np.concatenate([input_c2ws, target_c2ws])
        all_Ks = np.concatenate([input_Ks, target_Ks])
        all_c2ws = all_c2ws @ np.diag([1, -1, -1, 1])
        create_transforms_simple(output_dir, img_paths, img_whs, all_c2ws, all_Ks)
        split_dict = {
            "train_ids": list(range(num_inputs)),
            "test_ids": list(range(num_inputs, num_inputs + num_targets)),
        }
        with open(
            osp.join(output_dir, f"train_test_split_{num_inputs}.json"), "w"
        ) as f:
            json.dump(split_dict, f, indent=4)
        gr.Info(f"Output data saved to {output_dir}", duration=1)

    def render(
        self,
        preprocessed: dict,
        session_hash: str,
        seed: int,
        chunk_strategy: str,
        cfg: float,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ]
        | None,
        num_frames: int | None,
        zoom_factor: float | None,
        camera_scale: float,
    ):
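        """Run two-pass sampling for the requested trajectory and stream the
        resulting video paths back to the UI.

        The first pass generates anchor frames conditioned on the input
        views; the second pass interpolates the full set of target frames
        between those anchors.
        """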
        render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        render_dir = osp.join(WORK_DIR, render_name)

        input_imgs, input_Ks, input_c2ws, (W, H) = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
        )
        num_inputs = len(input_imgs)
        if preset_traj is None:
            target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
        else:
            assert num_frames is not None
            assert num_inputs == 1
            # Use an identity pose for the single input view when following
            # a preset trajectory.
            input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
            target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
                preprocessed, preset_traj, num_frames, zoom_factor
            )
        all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
        all_Ks = (
            torch.cat([input_Ks, target_Ks], 0)
            * input_Ks.new_tensor([W, H, 1])[:, None]
        )
        num_targets = len(target_c2ws)
        input_indices = list(range(num_inputs))
        target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()

        # Pick anchor cameras for the first pass, spaced evenly over the
        # target frames.
        T = VERSION_DICT["T"]
        version_dict = copy.deepcopy(VERSION_DICT)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )
        # infer_prior_stats may update version_dict["T"].
        T = version_dict["T"]
        assert isinstance(num_anchors, int)
        anchor_indices = np.linspace(
            num_inputs,
            num_inputs + num_targets - 1,
            num_anchors,
        ).tolist()
        anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]

        # Pad the input images with zero frames for the target views.
        all_imgs_np = (
            F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
            * 255.0
        ).astype(np.uint8)
        image_cond = {
            "img": all_imgs_np,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        camera_cond = {
            "c2w": all_c2ws,
            "K": all_Ks,
            "input_indices": list(range(num_inputs + num_targets)),
        }

        # Sampling options.
        num_steps = 50
        options_ori = VERSION_DICT["options"]
        options = copy.deepcopy(options_ori)
        options["chunk_strategy"] = chunk_strategy
        options["video_save_fps"] = 30.0
        options["beta_linear_start"] = 5e-6
        options["log_snr_shift"] = 2.4
        options["guider_types"] = [1, 2]
        options["cfg"] = [
            float(cfg),
            3.0 if num_inputs >= 9 else 2.0,  # cfg for the second pass
        ]
        options["camera_scale"] = camera_scale
        options["num_steps"] = num_steps
        options["cfg_min"] = 1.2
        options["encoding_t"] = 1
        options["decoding_t"] = 1
        assert session_hash in ABORT_EVENTS
        abort_event = ABORT_EVENTS[session_hash]
        abort_event.clear()
        options["abort_event"] = abort_event
        task = "img2trajvid"

        # Compute the number of sampling chunks in each pass so the progress
        # bars have correct totals.
        T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
        chunk_strategy_first_pass = options.get(
            "chunk_strategy_first_pass", "gt-nearest"
        )
        num_chunks_0 = len(
            chunk_input_and_test(
                T_first_pass,
                input_c2ws,
                anchor_c2ws,
                input_indices,
                image_cond["prior_indices"],
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy_first_pass,
                gt_input_inds=list(range(input_c2ws.shape[0])),
            )[1]
        )

        # In the second pass, the anchors generated by the first pass join
        # the inputs as conditioning frames, sorted by frame index.
        anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
        anchor_indices = np.array(input_indices + anchor_indices)[
            anchor_argsort
        ].tolist()
        gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
        anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
            anchor_argsort
        ]
        T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
        chunk_strategy = options.get("chunk_strategy", "nearest")
        num_chunks_1 = len(
            chunk_input_and_test(
                T_second_pass,
                anchor_c2ws_second_pass,
                target_c2ws,
                anchor_indices,
                target_indices,
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy,
                gt_input_inds=gt_input_inds,
            )[1]
        )
        second_pass_pbar = gr.Progress().tqdm(
            iterable=None,
            desc="Second pass sampling",
            total=num_chunks_1 * num_steps,
        )
        first_pass_pbar = gr.Progress().tqdm(
            iterable=None,
            desc="First pass sampling",
            total=num_chunks_0 * num_steps,
        )
        video_path_generator = run_one_scene(
            task=task,
            version_dict={
                "H": H,
                "W": W,
                "T": T,
                "C": VERSION_DICT["C"],
                "f": VERSION_DICT["f"],
                "options": options,
            },
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=render_dir,
            use_traj_prior=True,
            traj_prior_c2ws=anchor_c2ws,
            traj_prior_Ks=anchor_Ks,
            seed=seed,
            gradio=True,
            first_pass_pbar=first_pass_pbar,
            second_pass_pbar=second_pass_pbar,
            abort_event=abort_event,
        )
        output_queue = queue.Queue()

        # Capture the Gradio context from the request thread.
        blocks = LocalContext.blocks.get()
        event_id = LocalContext.event_id.get()

        def worker():
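            # Gradio tracks the active blocks/event id in context variables,
            # which do not propagate to new threads. Restore them so that
            # progress updates from run_one_scene reach the right session.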
            LocalContext.blocks.set(blocks)
            LocalContext.event_id.set(event_id)
            for i, video_path in enumerate(video_path_generator):
                if i == 0:
                    # First-pass (anchor) video.
                    output_queue.put(
                        (
                            video_path,
                            gr.update(),
                            gr.update(),
                            gr.update(),
                        )
                    )
                elif i == 1:
                    # Final video: restore the render button and hide the
                    # abort button and progress textbox.
                    output_queue.put(
                        (
                            video_path,
                            gr.update(visible=True),
                            gr.update(visible=False),
                            gr.update(visible=False),
                        )
                    )
                else:
                    raise gr.Error("More than two passes during rendering.")

        thread = threading.Thread(target=worker, daemon=True)
        thread.start()

        # Poll the worker and stream its outputs back to the UI, bailing out
        # early if the abort event fires.
        while thread.is_alive() or not output_queue.empty():
            if abort_event.is_set():
                thread.join()
                abort_event.clear()
                yield (
                    gr.update(),
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False),
                )
            time.sleep(0.1)
            while not output_queue.empty():
                yield output_queue.get()


def setup_tunnel(
    local_host: str, local_port: int, share_token: str, share_server_address: str | None
) -> tuple[str, Tunnel]:
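    """Expose a local server through a Gradio share tunnel, so the viser
    viewport can be embedded in an iframe even when the demo runs remotely."""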
    share_server_address = (
        networking.GRADIO_SHARE_SERVER_ADDRESS
        if share_server_address is None
        else share_server_address
    )
    if share_server_address is None:
        # Fall back to asking the Gradio API server for a share server and
        # its root certificate.
        try:
            response = httpx.get(networking.GRADIO_API_SERVER, timeout=30)
            payload = response.json()[0]
            remote_host, remote_port = payload["host"], int(payload["port"])
            certificate = payload["root_ca"]
            Path(CERTIFICATE_PATH).parent.mkdir(parents=True, exist_ok=True)
            with open(CERTIFICATE_PATH, "w") as f:
                f.write(certificate)
        except Exception as e:
            raise RuntimeError(
                "Could not get share link from Gradio API Server."
            ) from e
    else:
        remote_host, remote_port = share_server_address.split(":")
        remote_port = int(remote_port)
    tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
    address = tunnel.start_tunnel()
    return address, tunnel


def set_bkgd_color(server: viser.ViserServer | viser.ClientHandle):
    server.scene.set_background_image(np.array([[[39, 39, 42]]], dtype=np.uint8))


def start_server_and_abort_event(request: gr.Request):
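    """Create a per-session viser server and abort event, returning the
    renderer together with an iframe viewport pointing at the tunnel URL."""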
    server = viser.ViserServer()

    @server.on_client_connect
    def _(client: viser.ClientHandle):
        client.gui.configure_theme(
            dark_mode=True,
            show_share_button=False,
            control_layout="collapsible",
        )
        set_bkgd_color(client)

    print(f"Starting server {server.get_port()}")
    server_url, tunnel = setup_tunnel(
        local_host=server.get_host(),
        local_port=server.get_port(),
        share_token=secrets.token_urlsafe(32),
        share_server_address=None,
    )
    SERVERS[request.session_hash] = (server, tunnel)
    if server_url is None:
        raise gr.Error(
            "Failed to get a viewport URL. Please check your network connection."
        )
    # Give the tunnel a moment to come up before embedding it.
    time.sleep(1)

    ABORT_EVENTS[request.session_hash] = threading.Event()

    return (
        SevaRenderer(server),
        gr.HTML(
            f'<iframe src="{server_url}" style="display: block; margin: auto; width: 100%; height: min(60vh, 600px);" frameborder="0"></iframe>',
            container=True,
        ),
        request.session_hash,
    )


def stop_server_and_abort_event(request: gr.Request):
    if request.session_hash in SERVERS:
        print(f"Stopping server {request.session_hash}")
        server, tunnel = SERVERS.pop(request.session_hash)
        server.stop()
        tunnel.kill()

    if request.session_hash in ABORT_EVENTS:
        print(f"Setting abort event {request.session_hash}")
        ABORT_EVENTS[request.session_hash].set()
        # Give any in-flight render a chance to notice the event before
        # dropping it.
        time.sleep(5)
        ABORT_EVENTS.pop(request.session_hash)


def set_abort_event(request: gr.Request):
    if request.session_hash in ABORT_EVENTS:
        print(f"Setting abort event {request.session_hash}")
        ABORT_EVENTS[request.session_hash].set()


def get_advance_examples(selection: gr.SelectData):
    index = selection.index
    return (
        gr.Gallery(ADVANCE_EXAMPLE_MAP[index][1], visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.Gallery(visible=False),
    )


def get_preamble():
    gr.Markdown("""
# Stable Virtual Camera
<span style="display: flex; flex-wrap: wrap; gap: 5px;">
<a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
<a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
<a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
<a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
<a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
<a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
</span>

Welcome to the demo of <strong>Stable Virtual Camera (Seva)</strong>! Given any number of input views and their cameras, this demo lets you generate novel views of a scene at any target cameras of interest.

We provide two ways to use our demo (selected by the tab below, documented [here](https://github.com/Stability-AI/stable-virtual-camera/blob/main/docs/GR_USAGE.md)):
1. **[Basic](https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705)**: Given a single image, generate a video following one of our preset camera trajectories.
2. **[Advanced](https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba)**: Given any number of input images, generate a video following any camera trajectory of your choice, defined through our key-frame-based interface.

> This is a research preview and comes with a few [limitations](https://stable-virtual-camera.github.io/#limitations):
> - Limited quality in certain subjects due to training data, including humans, animals, and dynamic textures.
> - Limited quality in some highly ambiguous scenes and camera trajectories, including extreme views and collision into objects.
""")


# Force the page into dark mode so it matches the viser viewport theme.
_APP_JS = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""


def main(server_port: int | None = None, share: bool = True):
    with gr.Blocks(js=_APP_JS) as app:
        renderer = gr.State()
        session_hash = gr.State()
        _ = get_preamble()
        with gr.Tabs():
            with gr.Tab("Basic"):
                render_btn = gr.Button("Render video", interactive=False, render=False)
                with gr.Row():
                    with gr.Column():
                        with gr.Group():
                            preprocess_btn = gr.Button(
                                "Preprocess images", interactive=False
                            )
                            preprocess_progress = gr.Textbox(
                                label="",
                                visible=False,
                                interactive=False,
                            )
                        with gr.Group():
                            input_imgs = gr.Image(
                                type="filepath",
                                label="Input",
                                height=200,
                            )
                            _ = gr.Examples(
                                examples=sorted(glob("assets/basic/*")),
                                inputs=[input_imgs],
                                label="Example",
                            )
                        chunk_strategy = gr.Dropdown(
                            ["interp", "interp-gt"],
                            label="Chunk strategy",
                            render=False,
                        )
                        preprocessed = gr.State()

                        input_imgs.change(
                            lambda img: gr.update(interactive=bool(img)),
                            inputs=input_imgs,
                            outputs=preprocess_btn,
                        )
                        preprocess_btn.click(
                            lambda r, *args: [
                                *r.preprocess(*args),
                                gr.update(interactive=True),
                            ],
                            inputs=[renderer, input_imgs],
                            outputs=[
                                preprocessed,
                                preprocess_progress,
                                chunk_strategy,
                                render_btn,
                            ],
                            show_progress_on=[preprocess_progress],
                            concurrency_limit=1,
                            concurrency_id="gpu_queue",
                        )
                        # A second handler on the same click reveals the
                        # progress textbox immediately.
                        preprocess_btn.click(
                            lambda: gr.update(visible=True),
                            outputs=[preprocess_progress],
                        )
                        with gr.Row():
                            preset_traj = gr.Dropdown(
                                choices=[
                                    "orbit",
                                    "spiral",
                                    "lemniscate",
                                    "zoom-in",
                                    "zoom-out",
                                    "dolly zoom-in",
                                    "dolly zoom-out",
                                    "move-forward",
                                    "move-backward",
                                    "move-up",
                                    "move-down",
                                    "move-left",
                                    "move-right",
                                ],
                                label="Preset trajectory",
                                value="orbit",
                            )
                            num_frames = gr.Slider(30, 150, 80, label="#Frames")
                            zoom_factor = gr.Slider(
                                step=0.01, label="Zoom factor", visible=False
                            )
                        with gr.Row():
                            seed = gr.Number(value=23, label="Random seed")
                            chunk_strategy.render()
                            cfg = gr.Slider(1.0, 7.0, value=4.0, label="CFG value")
                        with gr.Row():
                            camera_scale = gr.Slider(
                                0.1,
                                15.0,
                                value=2.0,
                                label="Camera scale",
                            )

                        def default_cfg_preset_traj(traj):
                            # Trajectory-dependent CFG defaults.
                            if traj in ["zoom-out", "move-down"]:
                                value = 5.0
                            elif traj in [
                                "orbit",
                                "dolly zoom-out",
                                "move-backward",
                                "move-up",
                                "move-left",
                                "move-right",
                            ]:
                                value = 4.0
                            else:
                                value = 3.0
                            return value

                        preset_traj.change(
                            default_cfg_preset_traj,
                            inputs=[preset_traj],
                            outputs=[cfg],
                        )
                        preset_traj.change(
                            lambda traj: gr.update(
                                value=(
                                    10.0 if "dolly" in traj or "pan" in traj else 2.0
                                )
                            ),
                            inputs=[preset_traj],
                            outputs=[camera_scale],
                        )

                        def zoom_factor_preset_traj(traj):
                            # The zoom factor slider only applies to zoom
                            # trajectories; its range depends on the direction.
                            visible = traj in [
                                "zoom-in",
                                "zoom-out",
                                "dolly zoom-in",
                                "dolly zoom-out",
                            ]
                            is_zoomin = traj.endswith("zoom-in")
                            if is_zoomin:
                                minimum = 0.1
                                maximum = 0.5
                                value = 0.28
                            else:
                                minimum = 1.2
                                maximum = 3
                                value = 1.5
                            return gr.update(
                                visible=visible,
                                minimum=minimum,
                                maximum=maximum,
                                value=value,
                            )

                        preset_traj.change(
                            zoom_factor_preset_traj,
                            inputs=[preset_traj],
                            outputs=[zoom_factor],
                        )
                    with gr.Column():
                        with gr.Group():
                            abort_btn = gr.Button("Abort rendering", visible=False)
                            render_btn.render()
                            render_progress = gr.Textbox(
                                label="", visible=False, interactive=False
                            )
                        output_video = gr.Video(
                            label="Output", interactive=False, autoplay=True, loop=True
                        )
                        render_btn.click(
                            lambda r, *args: (yield from r.render(*args)),
                            inputs=[
                                renderer,
                                preprocessed,
                                session_hash,
                                seed,
                                chunk_strategy,
                                cfg,
                                preset_traj,
                                num_frames,
                                zoom_factor,
                                camera_scale,
                            ],
                            outputs=[
                                output_video,
                                render_btn,
                                abort_btn,
                                render_progress,
                            ],
                            show_progress_on=[render_progress],
                            concurrency_id="gpu_queue",
                        )
                        render_btn.click(
                            lambda: [
                                gr.update(visible=False),
                                gr.update(visible=True),
                                gr.update(visible=True),
                            ],
                            outputs=[render_btn, abort_btn, render_progress],
                        )
                        abort_btn.click(set_abort_event)
            with gr.Tab("Advanced"):
                render_btn = gr.Button("Render video", interactive=False, render=False)
                viewport = gr.HTML(container=True, render=False)
                # Poll the renderer so the render button unlocks once a
                # camera trajectory has been defined in the viser GUI.
                gr.Timer(0.1).tick(
                    lambda renderer: gr.update(
                        interactive=renderer is not None
                        and renderer.gui_state is not None
                        and renderer.gui_state.camera_traj_list is not None
                    ),
                    inputs=[renderer],
                    outputs=[render_btn],
                )
                with gr.Row():
                    viewport.render()
                with gr.Row():
                    with gr.Column():
                        with gr.Group():
                            preprocess_btn = gr.Button(
                                "Preprocess images", interactive=False
                            )
                            preprocess_progress = gr.Textbox(
                                label="",
                                visible=False,
                                interactive=False,
                            )
                        with gr.Group():
                            input_imgs = gr.Gallery(
                                interactive=True,
                                label="Input",
                                columns=4,
                                height=200,
                            )
                            # Selecting an example expands it into a preview
                            # gallery with "Confirm"/"Go back" buttons.
                            example_imgs = gr.Gallery(
                                [e[0] for e in ADVANCE_EXAMPLE_MAP],
                                allow_preview=False,
                                preview=False,
                                label="Example",
                                columns=20,
                                rows=1,
                                height=115,
                            )
                            example_imgs_expander = gr.Gallery(
                                visible=False,
                                interactive=False,
                                label="Example",
                                preview=True,
                                columns=20,
                                rows=1,
                            )
                            chunk_strategy = gr.Dropdown(
                                ["interp-gt", "interp"],
                                label="Chunk strategy",
                                value="interp-gt",
                                render=False,
                            )
                            with gr.Row():
                                example_imgs_backer = gr.Button(
                                    "Go back", visible=False
                                )
                                example_imgs_confirmer = gr.Button(
                                    "Confirm", visible=False
                                )
                            example_imgs.select(
                                get_advance_examples,
                                outputs=[
                                    example_imgs_expander,
                                    example_imgs_confirmer,
                                    example_imgs_backer,
                                    example_imgs,
                                ],
                            )
                            example_imgs_confirmer.click(
                                lambda x: (
                                    x,
                                    gr.update(visible=False),
                                    gr.update(visible=False),
                                    gr.update(visible=False),
                                    gr.update(visible=True),
                                    gr.update(interactive=bool(x)),
                                ),
                                inputs=[example_imgs_expander],
                                outputs=[
                                    input_imgs,
                                    example_imgs_expander,
                                    example_imgs_confirmer,
                                    example_imgs_backer,
                                    example_imgs,
                                    preprocess_btn,
                                ],
                            )
                            example_imgs_backer.click(
                                lambda: (
                                    gr.update(visible=False),
                                    gr.update(visible=False),
                                    gr.update(visible=False),
                                    gr.update(visible=True),
                                ),
                                outputs=[
                                    example_imgs_expander,
                                    example_imgs_confirmer,
                                    example_imgs_backer,
                                    example_imgs,
                                ],
                            )
                            preprocessed = gr.State()
                            preprocess_btn.click(
                                lambda r, *args: r.preprocess(*args),
                                inputs=[renderer, input_imgs],
                                outputs=[
                                    preprocessed,
                                    preprocess_progress,
                                    chunk_strategy,
                                ],
                                show_progress_on=[preprocess_progress],
                                concurrency_id="gpu_queue",
                            )
                            preprocess_btn.click(
                                lambda: gr.update(visible=True),
                                outputs=[preprocess_progress],
                            )
                            preprocessed.change(
                                lambda r, *args: r.visualize_scene(*args),
                                inputs=[renderer, preprocessed],
                            )
                        with gr.Row():
                            seed = gr.Number(value=23, label="Random seed")
                            chunk_strategy.render()
                            cfg = gr.Slider(1.0, 7.0, value=3.0, label="CFG value")
                        with gr.Row():
                            camera_scale = gr.Slider(
                                0.1,
                                15.0,
                                value=2.0,
                                label="Camera scale (useful for single-view input)",
                            )
                        with gr.Group():
                            output_data_dir = gr.Textbox(label="Output data directory")
                            output_data_btn = gr.Button("Export output data")
                            output_data_btn.click(
                                lambda r, *args: r.export_output_data(*args),
                                inputs=[renderer, preprocessed, output_data_dir],
                            )
                    with gr.Column():
                        with gr.Group():
                            abort_btn = gr.Button("Abort rendering", visible=False)
                            render_btn.render()
                            render_progress = gr.Textbox(
                                label="", visible=False, interactive=False
                            )
                        output_video = gr.Video(
                            label="Output", interactive=False, autoplay=True, loop=True
                        )
                        render_btn.click(
                            lambda r, *args: (yield from r.render(*args)),
                            inputs=[
                                renderer,
                                preprocessed,
                                session_hash,
                                seed,
                                chunk_strategy,
                                cfg,
                                # Dummy states stand in for preset_traj,
                                # num_frames, and zoom_factor, which only
                                # apply to the "Basic" tab.
                                gr.State(),
                                gr.State(),
                                gr.State(),
                                camera_scale,
                            ],
                            outputs=[
                                output_video,
                                render_btn,
                                abort_btn,
                                render_progress,
                            ],
                            show_progress_on=[render_progress],
                            concurrency_id="gpu_queue",
                        )
                        render_btn.click(
                            lambda: [
                                gr.update(visible=False),
                                gr.update(visible=True),
                                gr.update(visible=True),
                            ],
                            outputs=[render_btn, abort_btn, render_progress],
                        )
                        abort_btn.click(set_abort_event)

        # Per-session setup/teardown of the viser server and abort event.
        app.load(
            start_server_and_abort_event,
            outputs=[renderer, viewport, session_hash],
        )
        app.unload(stop_server_and_abort_event)

    app.queue(max_size=5).launch(
        share=share,
        server_port=server_port,
        show_error=True,
        allowed_paths=[WORK_DIR],
        ssr_mode=False,
    )


if __name__ == "__main__":
    tyro.cli(main)