# CameraCtrl-svd / app.py
# hao he: "Modify the model to svd" (5687730)
# raw | history | blame | 31.5 kB
import spaces
import argparse
import torch
import tempfile
import os
import cv2
import numpy as np
import gradio as gr
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import matplotlib as mpl
from omegaconf import OmegaConf
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from inference_cameractrl import get_relative_pose, ray_condition, get_pipeline
from cameractrl.utils.util import save_videos_grid
# Render matplotlib figures with the non-interactive Agg backend (no display needed on a server).
mpl.use('agg')
# Restrict OpenCV to a single worker thread.
# NOTE(review): presumably to avoid oversubscribing CPU cores alongside torch; confirm.
cv2.setNumThreads(1)
#### Description ####
# Markdown/HTML snippets rendered at the top and bottom of the Gradio page.
title = r"""<h1 align="center">CameraCtrl: Enabling Camera Control for Video Diffusion Models</h1>"""
subtitle = r"""<h2 align="center">CameraCtrl Image2Video with <a href='https://arxiv.org/abs/2311.15127' target='_blank'> <b>Stable Video Diffusion (SVD)</b> </a> <a href='https://huggingface.co/stabilityai/stable-video-diffusion-img2vid' target='_blank'> <b> model </b> </a> </h2>"""
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/hehao13/CameraCtrl' target='_blank'><b>CameraCtrl: Enabling Camera Control for Video Diffusion Models</b></a>.<br>
CameraCtrl is capable of precisely controlling the camera trajectory during the video generation process.<br>
Note that with SVD, CameraCtrl only supports Image2Video for now.<br>
"""
# Blank lines around each `---` matter: a `---` directly under a text line is parsed by
# Markdown as a setext heading underline instead of a horizontal rule.
# NOTE(review): the contact address below looks like scraped/obfuscated placeholder text
# ("[email protected]"); restore the real address.
closing_words = r"""

---

If you are interested in this demo or find CameraCtrl helpful, please give the <a href='https://github.com/hehao13/CameraCtrl' target='_blank'>CameraCtrl</a> GitHub repo a ⭐!
[![GitHub Stars](https://img.shields.io/github/stars/hehao13/CameraCtrl)](https://github.com/hehao13/CameraCtrl)

---

📝 **Citation**
<br>
If you find our paper or code useful for your research, please consider citing:
```bibtex
@article{he2024cameractrl,
  title={CameraCtrl: Enabling Camera Control for Text-to-Video Generation},
  author={Hao He and Yinghao Xu and Yuwei Guo and Gordon Wetzstein and Bo Dai and Hongsheng Li and Ceyuan Yang},
  journal={arXiv preprint arXiv:2404.02101},
  year={2024}
}
```

📧 **Contact**
<br>
If you have any questions, please feel free to contact me at <b>[email protected]</b>.

**Acknowledgement**
<br>
We thank <a href='https://wzhouxiff.github.io/projects/MotionCtrl/' target='_blank'><b>MotionCtrl</b></a> and <a href='https://huggingface.co/spaces/lllyasviel/IC-Light' target='_blank'><b>IC-Light</b></a> for their Gradio code.<br>
"""
# Button labels double as mode identifiers: each button passes its own label to the
# handler, which compares it against these lists.
RESIZE_MODES = ['Resize then Center Crop', 'Directly resize']
CAMERA_TRAJECTORY_MODES = ["Provided Camera Trajectories", "Custom Camera Trajectories"]

# Output resolution and clip length used for every generation in this demo.
height, width = 320, 576
num_frames = 14
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Training config, base SVD weights and the local CameraCtrl checkpoint path.
# NOTE(review): the local filename says "svdxt" although the download URL in this file fetches
# CameraCtrl_svd.ckpt; confirm the name is intentional.
config = "configs/train_cameractrl/svd_320_576_cameractrl.yaml"
model_id = "stabilityai/stable-video-diffusion-img2vid"
ckpt = "checkpoints/CameraCtrl_svdxt.ckpt"
if not os.path.exists(ckpt):
    # Fetch the CameraCtrl weights on first launch.
    # * The command runs without a shell, so the "?" in the URL is never glob-expanded.
    # * check=True aborts startup if wget fails. Previously the exit code was ignored and the
    #   following `mv` could install a partial download as the checkpoint, which later starts
    #   would then accept because `ckpt` exists.
    # * wget -c keeps its resume behaviour: an interrupted file stays at wget's default output
    #   name in the working directory and is continued on the next launch.
    import shutil
    import subprocess

    _ckpt_url = "https://huggingface.co/hehao13/CameraCtrl_SVD_ckpts/resolve/main/CameraCtrl_svd.ckpt?download=true"
    _wget_output = "CameraCtrl_svd.ckpt?download=true"  # wget names the file after the last URL segment
    os.makedirs(os.path.dirname(ckpt) or ".", exist_ok=True)
    subprocess.run(["wget", "-c", _ckpt_url], check=True)
    shutil.move(_wget_output, ckpt)

# Build the CameraCtrl SVD pipeline from the YAML config and the checkpoint above.
# NOTE(review): get_pipeline is defined in inference_cameractrl (not visible here). The
# meaning of the positional "unet" and True arguments is assumed to be a UNet subfolder
# name and a boolean flag; confirm against its signature.
model_config = OmegaConf.load(config)
pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
                        model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
                        ckpt, True, device)
# Quick-start examples shown under the UI.
# Each row is [condition image path, pose file path, provided-trajectory name]. The column
# order matches gr.Examples(inputs=[input_image, camera_args, provided_trajectories]) in main().
# The pose file for each "Trajectory N" is the same one update_trajectory_vis_plot maps that
# name to. The cheese-house example deliberately reuses Trajectory 4.
examples = [
    [
        "assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png",
        "assets/pose_files/0bf152ef84195293.txt",
        "Trajectory 1"
    ],
    [
        "assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png",
        "assets/pose_files/0c9b371cc6225682.txt",
        "Trajectory 2"
    ],
    [
        "assets/example_condition_images/Rocky_coastline_with_crashing_waves..png",
        "assets/pose_files/0c11dbe781b1c11c.txt",
        "Trajectory 3"
    ],
    [
        "assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png",
        "assets/pose_files/0f47577ab3441480.txt",
        "Trajectory 4"
    ],
    [
        "assets/example_condition_images/An_exploding_cheese_house..png",
        "assets/pose_files/0f47577ab3441480.txt",
        "Trajectory 4"
    ],
    [
        "assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png",
        "assets/pose_files/0f68374b76390082.txt",
        "Trajectory 5"
    ],
    [
        "assets/example_condition_images/Leaves_are_falling_from_trees..png",
        "assets/pose_files/2c80f9eb0d3b2bb4.txt",
        "Trajectory 6"
    ],
    [
        "assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png",
        "assets/pose_files/2f25826f0d0ef09a.txt",
        "Trajectory 7"
    ],
    [
        "assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png",
        "assets/pose_files/3f79dc32d575bcdc.txt",
        "Trajectory 8"
    ],
    [
        "assets/example_condition_images/A_car_running_on_Mars..png",
        "assets/pose_files/4a2d6753676df096.txt",
        "Trajectory 9"
    ],
]
class Camera:
    """Camera intrinsics and extrinsics parsed from one line of a pose file.

    Attributes:
        fx, fy, cx, cy: intrinsics taken from entries 1-4 of the line. They are mutable:
            sample_video rescales fx/fy in place. They are presumably normalized by image
            size, since callers multiply them by width/height; confirm.
        w2c_mat: 4x4 world-to-camera matrix.
        c2w_mat: 4x4 camera-to-world matrix (inverse of ``w2c_mat``).
    """

    def __init__(self, entry):
        """Build a camera from a parsed pose line.

        Args:
            entry: sequence of numbers. Index 0 is ignored (NOTE(review): presumably a
                timestamp; confirm). Indices 1-4 are fx, fy, cx, cy. Indices 5-6 are
                unused here. Indices 7 onward hold a row-major 3x4 world-to-camera matrix.
        """
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Promote the 3x4 [R | t] block to a homogeneous 4x4 matrix.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
class CameraPoseVisualizer:
    """Draw camera poses as small pyramids (frusta) in a 3D matplotlib figure."""

    def __init__(self, xlim, ylim, zlim):
        """Create an 18x7-inch figure with a single 3D axes and fixed axis limits.

        Args:
            xlim, ylim, zlim: (min, max) limits for the x, y and z axes.
        """
        self.fig = plt.figure(figsize=(18, 7))
        self.ax = self.fig.add_subplot(projection='3d')
        self.plotly_data = None  # plotly data traces (not used anywhere in this file)
        self.ax.set_aspect("auto")
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        self.ax.set_zlim(zlim)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')

    def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3):
        """Add one camera pyramid to the axes.

        Args:
            extrinsic: 4x4 pose matrix applied to the pyramid's vertices (callers pass
                camera-to-world matrices).
            color_map: a matplotlib color string, or a float mapped through the rainbow
                colormap (colormaps expect values in [0, 1]).
            hw_ratio: height/width ratio of the pyramid's rectangular base.
            base_xval: half-width of the base.
            zval: distance of the base from the apex along the camera's z axis.
        """
        # Apex at the camera center plus the four base corners, in homogeneous coordinates.
        vertex_std = np.array([[0, 0, 0, 1],
                               [base_xval, -base_xval * hw_ratio, zval, 1],
                               [base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, -base_xval * hw_ratio, zval, 1]])
        vertex_transformed = vertex_std @ extrinsic.T
        # Drop the homogeneous coordinate; split into the apex and the four base corners.
        apex, *base = vertex_transformed[:, :-1]
        # Four triangular side faces (apex + consecutive base corners), then the base quad.
        meshes = [[apex, base[i], base[(i + 1) % 4]] for i in range(4)] + [base]
        color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))

    def colorbar(self, max_frame_length):
        """Attach a vertical rainbow colorbar labelled with frame indexes 0..max_frame_length."""
        cmap = mpl.cm.rainbow
        norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
        self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical',
                          label='Frame Indexes')

    def show(self):
        """Title this visualizer's axes and display the figure through pyplot."""
        # Title our own axes explicitly. plt.title() would target pyplot's "current" axes,
        # which belongs to a different figure once another one has been created.
        self.ax.set_title('Camera Trajectory')
        plt.show()
def get_c2w(w2cs):
    """Turn world-to-camera matrices into normalized camera-to-world poses for plotting.

    All poses are expressed relative to the first camera, which becomes the identity.
    Camera positions are then rescaled so the largest per-axis extent of the path is 1,
    and a fixed axis permutation (y' = z, z' = -y) is applied.

    Args:
        w2cs: non-empty sequence of 4x4 world-to-camera matrices, one per frame.

    Returns:
        float32 array of shape (n_frames, 4, 4).

    Raises:
        ValueError: if ``w2cs`` is empty.
    """
    if len(w2cs) == 0:
        raise ValueError("get_c2w needs at least one world-to-camera matrix")

    # Relative camera-to-world poses: first_w2c @ inv(w2c_i); the first pose is the identity.
    first_w2c = w2cs[0]
    ret_poses = [np.eye(4)] + [first_w2c @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
    camera_positions = np.asarray([c2w[:3, 3] for c2w in ret_poses])  # [n_frame, 3]

    # Rescale translations (about the first camera) so the largest axis extent equals 1.
    max_range = np.max(camera_positions.max(axis=0) - camera_positions.min(axis=0))
    expected_xyz_ranges = 1
    # A static camera, or a single frame, has zero extent. Leave it unscaled instead of
    # dividing by zero, which turned every position into NaN.
    scale_ratio = expected_xyz_ranges / max_range if max_range > 0 else 1.0
    origin = camera_positions[0]
    scaled_camera_positions = origin + (camera_positions - origin) * scale_ratio

    # Re-assemble [R | t_scaled] homogeneous poses.
    normalized_poses = []
    for pose, position in zip(ret_poses, scaled_camera_positions):
        scaled_pose = np.eye(4)
        scaled_pose[:3, :3] = pose[:3, :3]
        scaled_pose[:3, 3] = position
        normalized_poses.append(scaled_pose)

    transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    return np.array([transform_matrix @ pose for pose in normalized_poses], dtype=np.float32)
def visualize_trajectory(trajectory_file):
    """Render a camera-trajectory file as a 3D figure of colour-coded camera pyramids.

    Args:
        trajectory_file: path to a pose text file. The first line is skipped, and blank
            lines are ignored. In every other line, the whitespace-separated values from
            index 7 onward form a row-major 3x4 world-to-camera matrix.

    Returns:
        The matplotlib Figure, already released from pyplot's global figure registry.

    Raises:
        ValueError: if the file contains no pose lines.
    """
    with open(trajectory_file, 'r') as f:
        pose_lines = f.readlines()[1:]
    last_row = np.array([[0.0, 0.0, 0.0, 1.0]])
    # split() (not split(' ')) tolerates repeated spaces/tabs; blank lines such as a
    # trailing newline in an uploaded file are skipped instead of crashing reshape().
    w2cs = [np.concatenate((np.asarray([float(p) for p in line.split()[7:]]).reshape(3, 4), last_row), axis=0)
            for line in pose_lines if line.strip()]
    if not w2cs:
        raise ValueError(f"No camera poses found in trajectory file: {trajectory_file}")
    n_frames = len(w2cs)  # local name; does not shadow the module-level num_frames
    c2ws = get_c2w(w2cs)
    visualizer = CameraPoseVisualizer([-1.2, 1.2], [-1.2, 1.2], [-1.2, 1.2])
    for frame_idx, c2w in enumerate(c2ws):
        visualizer.extrinsic2pyramid(c2w, frame_idx / n_frames, hw_ratio=9 / 16, base_xval=0.02, zval=0.1)
    visualizer.colorbar(n_frames)
    # This runs on every "Visualize" click in a long-lived server. pyplot would otherwise keep
    # every figure alive (and warn past figure.max_open_warning). The returned Figure can
    # still be rendered/saved after being closed.
    plt.close(visualizer.fig)
    return visualizer.fig


# Trajectory plot shown before the user visualizes their own selection.
vis_traj = visualize_trajectory('assets/pose_files/0bf152ef84195293.txt')
@torch.inference_mode()
def process_input_image(input_image, resize_mode):
    """Resize the uploaded condition image towards the model resolution ``height`` x ``width``.

    The image is scaled, keeping its aspect ratio, so that it covers the target box (one
    side matches exactly and the other is at least as large). With ``RESIZE_MODES[0]`` it is
    then center-cropped to exactly ``height`` x ``width``. Otherwise the resized image is
    returned as is.

    Args:
        input_image: the uploaded PIL image, or ``None`` if nothing has been uploaded.
        resize_mode: label of the clicked resize button (one of ``RESIZE_MODES``).

    Returns:
        Five Gradio updates: the processed-image preview, then four visibility updates that
        reveal the Step 2 widgets.

    Raises:
        gr.Error: if no image has been uploaded yet.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before choosing a resize mode.")
    expected_hw_ratio = height / width
    inp_w, inp_h = input_image.size
    inp_hw_ratio = inp_h / inp_w
    # PIL's resize requires integer pixel sizes, so round the scaled side. round() cannot
    # drop below the target because the unrounded value is >= the (integer) target size.
    if inp_hw_ratio > expected_hw_ratio:
        # Relatively taller than the target: match the width and let the height overflow.
        resized_height, resized_width = round(inp_hw_ratio * width), width
    else:
        # Relatively wider (or equal): match the height and let the width overflow.
        resized_height, resized_width = height, round(height / inp_hw_ratio)
    resized_image = F.resize(input_image, size=[resized_height, resized_width])
    if resize_mode == RESIZE_MODES[0]:
        return_image = F.center_crop(resized_image, output_size=[height, width])
    else:
        return_image = resized_image
    return (gr.update(visible=True, value=return_image, height=height, width=width),
            gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True))
def update_camera_trajectories(trajectory_mode):
    """Switch between the provided-trajectory and custom-trajectory widgets.

    Args:
        trajectory_mode: label of the clicked mode button (one of ``CAMERA_TRAJECTORY_MODES``).

    Returns:
        Nine visibility updates, in order: three provided-trajectory widgets, three
        custom-trajectory widgets, then three shared widgets that are always shown.
        Returns ``None`` for an unrecognised mode.
    """
    if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
        show_provided = True
    elif trajectory_mode == CAMERA_TRAJECTORY_MODES[1]:
        show_provided = False
    else:
        return None
    flags = [show_provided] * 3 + [not show_provided] * 3 + [True] * 3
    return tuple(gr.update(visible=flag) for flag in flags)
def update_camera_args(trajectory_mode, provided_camera_trajectory, customized_trajectory_file):
    """Build the text shown in the "Camera Trajectory Name" box.

    Args:
        trajectory_mode: label of the mode button wired to this event.
        provided_camera_trajectory: selected dropdown entry, e.g. ``"Trajectory 1"``.
        customized_trajectory_file: the uploaded trajectory file, or ``None``. Depending on
            the Gradio version this is a file wrapper with ``.name`` or a plain path string;
            both are accepted.

    Returns:
        ``"Provided <name>"`` in provided mode, ``"Customized trajectory file <basename>"``
        when a file was uploaded, and ``" "`` otherwise.
    """
    if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
        return "Provided " + str(provided_camera_trajectory)
    if customized_trajectory_file is None:
        return " "
    file_path = getattr(customized_trajectory_file, 'name', customized_trajectory_file)
    # os.path.basename also handles platform-specific separators, unlike split('/').
    return f"Customized trajectory file {os.path.basename(file_path)}"
def update_camera_args_reset():
    """Clear the "Camera Trajectory Name" textbox.

    A single space is returned rather than an empty string, the same value the textbox
    is created with.
    """
    blank = " "
    return blank
# Pose file behind each entry of the "Provided Trajectories" dropdown.
_PROVIDED_TRAJECTORY_FILES = {
    "Trajectory 1": "assets/pose_files/0bf152ef84195293.txt",
    "Trajectory 2": "assets/pose_files/0c9b371cc6225682.txt",
    "Trajectory 3": "assets/pose_files/0c11dbe781b1c11c.txt",
    "Trajectory 4": "assets/pose_files/0f47577ab3441480.txt",
    "Trajectory 5": "assets/pose_files/0f68374b76390082.txt",
    "Trajectory 6": "assets/pose_files/2c80f9eb0d3b2bb4.txt",
    "Trajectory 7": "assets/pose_files/2f25826f0d0ef09a.txt",
    "Trajectory 8": "assets/pose_files/3f79dc32d575bcdc.txt",
    "Trajectory 9": "assets/pose_files/4a2d6753676df096.txt",
}


def update_trajectory_vis_plot(camera_trajectory_args, provided_camera_trajectory, customized_trajectory_file):
    """Plot the selected camera trajectory and reveal the Step 3 widgets.

    Args:
        camera_trajectory_args: content of the "Camera Trajectory Name" box. Provided
            trajectories are written there as ``"Provided <name>"``.
        provided_camera_trajectory: selected dropdown entry. Unknown values fall back to
            Trajectory 9.
        customized_trajectory_file: uploaded trajectory file (wrapper with ``.name`` or a
            path string), or ``None``.

    Returns:
        Eleven values: a visibility update and the figure for the trajectory plot, eight
        visibility updates for the Step 3 widgets, and the trajectory file path.

    Raises:
        gr.Error: when no provided trajectory is selected and no file was uploaded.
    """
    # startswith (not substring): a custom file whose name contains "Provided" must not be
    # routed to the provided-trajectory branch.
    if (camera_trajectory_args or "").startswith("Provided "):
        trajectory_file_path = _PROVIDED_TRAJECTORY_FILES.get(
            provided_camera_trajectory, _PROVIDED_TRAJECTORY_FILES["Trajectory 9"])
    else:
        if customized_trajectory_file is None:
            raise gr.Error("Please select a provided trajectory or upload a trajectory file first.")
        trajectory_file_path = getattr(customized_trajectory_file, 'name', customized_trajectory_file)
    vis_traj = visualize_trajectory(trajectory_file_path)
    return gr.update(visible=True), vis_traj, gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
        gr.update(visible=True), gr.update(visible=True), trajectory_file_path
def update_set_button():
    """Reveal the four Step 4 widgets (title, description, start button, video output)."""
    return tuple(gr.update(visible=True) for _ in range(4))
def update_buttons_for_example(example_image, example_traj_path, provided_traj_name):
    """Populate the UI after an example row is clicked.

    ``example_traj_path`` and ``provided_traj_name`` are unused. They are accepted only
    because gr.Examples passes one argument per example column.

    Returns:
        Fourteen updates: the processed-image preview showing ``example_image``, then
        visibility updates that show the Step 2 and provided-trajectory widgets, hide the
        three custom-trajectory widgets, and show the trajectory name box and its two
        buttons.
    """
    visibility = (True,) * 7 + (False,) * 3 + (True,) * 3
    preview = gr.update(visible=True, value=example_image, height=height, width=width)
    return (preview, *(gr.update(visible=flag) for flag in visibility))
def _load_camera_params(trajectory_file):
    """Parse a pose text file into Camera objects (first line skipped, blank lines ignored)."""
    with open(trajectory_file, 'r') as f:
        pose_lines = f.readlines()[1:]
    return [Camera([float(x) for x in line.split()]) for line in pose_lines if line.strip()]


def _plucker_embedding(cam_params, target_height, target_width):
    """Build the Plücker-ray pose embedding of ``cam_params`` at the target resolution.

    Rescales each camera's fx or fy in place to account for the aspect-ratio change.
    The result is laid out as [1, n_frame, 6, h, w] (permuted from ray_condition's
    "b f h w 6" output).
    """
    sample_wh_ratio = target_width / target_height
    # NOTE(review): assumes normalized intrinsics (fx = f / W, fy = f / H), so fy / fx is
    # the W / H of the frames the poses were estimated on. Confirm for custom pose files.
    pose_wh_ratio = cam_params[0].fy / cam_params[0].fx
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = target_height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / target_width
    else:
        resized_ori_h = target_width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / target_height
    intrinsic = np.asarray([[cam_param.fx * target_width,
                             cam_param.fy * target_height,
                             cam_param.cx * target_width,
                             cam_param.cy * target_height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, target_height, target_width, device='cpu')  # b f h w 6
    return plucker_embedding.permute(0, 1, 4, 2, 3).contiguous()


@spaces.GPU(duration=80)
def sample_video(condition_image, trajectory_file, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, seed):
    """Generate a camera-controlled video from the condition image and trajectory file.

    Args:
        condition_image: the processed condition image (PIL).
        trajectory_file: path to the pose text file selected in Step 2.
        num_inference_step: number of denoising steps (cast to int).
        min_guidance_scale, max_guidance_scale: guidance-scale range passed to the pipeline.
        fps_id: value forwarded as the pipeline's ``fps`` argument.
        seed: random seed (cast to int).

    Returns:
        Path to a freshly created .mp4 file containing the generated frames.

    Raises:
        gr.Error: if the trajectory file contains no poses.
    """
    cam_params = _load_camera_params(trajectory_file)
    if not cam_params:
        raise gr.Error("The selected camera trajectory file contains no camera poses.")
    # NOTE(review): the number of poses is not reconciled with num_frames here; confirm the
    # pipeline handles trajectories whose length differs from num_frames.
    plucker_embedding = _plucker_embedding(cam_params, height, width).to(device=device)
    generator = torch.Generator(device=device)
    generator.manual_seed(int(seed))
    with torch.no_grad():
        video = pipeline(
            image=condition_image,
            pose_embedding=plucker_embedding,
            height=height,
            width=width,
            num_frames=num_frames,
            # gr.Number can hand over a float; the step count must be an integer.
            num_inference_steps=int(num_inference_step),
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            fps=fps_id,
            do_image_process=True,
            generator=generator,
            output_type='pt'
        ).frames[0].transpose(0, 1).cpu()
    # mkstemp reserves a unique path atomically. NamedTemporaryFile(...).name deleted the file
    # immediately and only returned a name (the insecure mktemp pattern).
    fd, temporal_video_path = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)
    save_videos_grid(video[None], temporal_video_path, rescale=False)
    return temporal_video_path
def main(args):
    """Build the four-step CameraCtrl Gradio interface and launch it.

    The page walks through: input image -> camera trajectory -> inference
    hyper-parameters -> generation. Most widgets start hidden and are revealed by the
    event handlers wired up below.

    Args:
        args: dict of keyword arguments forwarded to ``demo.launch`` (the ``__main__``
            block passes ``server_name``, ``inbrowser`` and ``share``).
    """
    demo = gr.Blocks().queue()
    with demo:
        gr.Markdown(title)
        gr.Markdown(subtitle)
        gr.Markdown(description)

        with gr.Column():
            # step1: Input condition image
            step1_title = gr.Markdown("---\n## Step 1: Input an Image", show_label=False, visible=True)
            step1_dec = gr.Markdown(
                f"\n 1. Upload an image by dragging it onto the input box or clicking it; "
                f"\n 2. Click `{RESIZE_MODES[0]}` or `{RESIZE_MODES[1]}` to select the image resize mode. "
                f"\n - `{RESIZE_MODES[0]}`: First resize the input image, then center crop it to a resolution of {height} x {width}. "
                f"\n - `{RESIZE_MODES[1]}`: Only resize the input image, keeping its original aspect ratio.",
                show_label=False, visible=True)
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    input_image = gr.Image(type='pil', interactive=True, elem_id='condition_image',
                                           elem_classes='image',
                                           visible=True)
                    with gr.Row():
                        resize_crop_button = gr.Button(RESIZE_MODES[0], visible=True)
                        directly_resize_button = gr.Button(RESIZE_MODES[1], visible=True)
                with gr.Column(scale=2):
                    processed_image = gr.Image(type='pil', interactive=False, elem_id='processed_image',
                                               elem_classes='image', visible=False)

            # step2: Select camera trajectory
            step2_camera_trajectory = gr.Markdown("---\n## Step 2: Select the camera trajectory", show_label=False,
                                                  visible=False)
            step2_camera_trajectory_des = gr.Markdown(
                f"\n - `{CAMERA_TRAJECTORY_MODES[0]}`: Includes 9 camera trajectories extracted from the test set of the RealEstate10K dataset, each with 25 frames. "
                f"\n - `{CAMERA_TRAJECTORY_MODES[1]}`: Provide your own camera trajectory as a .txt file.",
                show_label=False, visible=False)
            with gr.Row(equal_height=True):
                # Each mode button passes its own label to update_camera_trajectories.
                provide_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[0], visible=False)
                customized_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[1], visible=False)
            with gr.Row():
                with gr.Column():
                    provided_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[0]}", show_label=False,
                                                             visible=False)
                    provided_camera_trajectory_des = gr.Markdown(
                        "\n 1. Select one of the provided camera trajectories, such as `Trajectory 1`; "
                        "\n 2. Click `Visualize Camera Trajectory` to visualize the camera trajectory; "
                        "\n 3. Click `Reset Camera Trajectory` to reset the camera trajectory. ",
                        show_label=False, visible=False)
                    customized_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[1]}",
                                                               show_label=False,
                                                               visible=False)
                    customized_run_status = gr.Markdown(
                        "\n 1. Upload the .txt file containing the camera trajectory; "
                        "\n 2. Click `Visualize Camera Trajectory` to visualize the camera trajectory; "
                        "\n 3. Click `Reset Camera Trajectory` to reset the camera trajectory. ",
                        show_label=False, visible=False)
                    with gr.Row():
                        provided_trajectories = gr.Dropdown(
                            ["Trajectory 1", "Trajectory 2", "Trajectory 3", "Trajectory 4", "Trajectory 5",
                             "Trajectory 6", "Trajectory 7", "Trajectory 8", "Trajectory 9"],
                            label="Provided Trajectories", interactive=True, visible=False)
                    with gr.Row():
                        customized_camera_trajectory_file = gr.File(
                            label="Upload customized camera trajectory (in .txt format).", visible=False, interactive=True)
                    with gr.Row():
                        camera_args = gr.Textbox(value=" ", label="Camera Trajectory Name", visible=False)
                        # Hidden: carries the resolved pose-file path from visualization to generation.
                        camera_trajectory_path = gr.Textbox(value=" ", visible=False)
                    with gr.Row():
                        camera_trajectory_vis = gr.Button(value="Visualize Camera Trajectory", visible=False)
                        camera_trajectory_reset = gr.Button(value="Reset Camera Trajectory", visible=False)
                with gr.Column():
                    vis_camera_trajectory = gr.Plot(vis_traj, label='Camera Trajectory', visible=False)

            # step3: Set inference parameters
            with gr.Row():
                with gr.Column():
                    step3_title = gr.Markdown("---\n## Step 3: Set the inference hyper-parameters", visible=False)
                    step3_des = gr.Markdown(
                        "\n 1. Set the number of inference steps; "
                        "\n 2. Set the seed; "
                        "\n 3. Set the minimum and maximum guidance scales; "
                        "\n 4. Set the fps; "
                        "\n 5. Click `Set` to continue. "
                        "\n - Please refer to the SVD paper for the meaning of the guidance scales and fps.",
                        visible=False)
            with gr.Row():
                with gr.Column():
                    num_inference_steps = gr.Number(value=25, label='Number of Inference Steps', step=1, interactive=True,
                                                    visible=False)
                with gr.Column():
                    seed = gr.Number(value=42, label='Seed', minimum=1, interactive=True, visible=False, step=1)
                with gr.Column():
                    min_guidance_scale = gr.Number(value=1.0, label='Minimum Guidance Scale', minimum=1.0, step=0.5,
                                                   interactive=True, visible=False)
                with gr.Column():
                    max_guidance_scale = gr.Number(value=3.0, label='Maximum Guidance Scale', minimum=1.0, step=0.5,
                                                   interactive=True, visible=False)
                with gr.Column():
                    fps = gr.Number(value=7, label='FPS', minimum=1, step=1, interactive=True, visible=False)
                # Never-shown placeholder buttons (NOTE(review): presumably column padding; confirm).
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
                with gr.Column():
                    _ = gr.Button("Seed", visible=False)
            with gr.Row():
                with gr.Column():
                    _ = gr.Button("Set", visible=False)
                with gr.Column():
                    set_button = gr.Button("Set", visible=False)
                with gr.Column():
                    _ = gr.Button("Set", visible=False)

            # step 4: Generate video
            with gr.Row():
                with gr.Column():
                    step4_title = gr.Markdown("---\n## Step 4: Generate the video", show_label=False, visible=False)
                    step4_des = gr.Markdown(
                        "\n - Click the `Start generation !` button to generate the video. "
                        "\n - If the content of the generated video is not well aligned with the condition image, try increasing `Minimum Guidance Scale` and `Maximum Guidance Scale`. "
                        "\n - If the generated video is distorted, try increasing `FPS`.",
                        visible=False)
                    start_button = gr.Button(value="Start generation !", visible=False)
                with gr.Column():
                    generate_video = gr.Video(value=None, label="Generated Video", visible=False)

        # Event wiring. The output lists must match each handler's return order exactly.
        resize_crop_button.click(fn=process_input_image, inputs=[input_image, resize_crop_button],
                                 outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                          provide_trajectory_button, customized_trajectory_button])
        directly_resize_button.click(fn=process_input_image, inputs=[input_image, directly_resize_button],
                                     outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
                                              provide_trajectory_button, customized_trajectory_button])
        provide_trajectory_button.click(fn=update_camera_trajectories, inputs=[provide_trajectory_button],
                                        outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                 provided_trajectories,
                                                 customized_camera_trajectory, customized_run_status,
                                                 customized_camera_trajectory_file,
                                                 camera_args, camera_trajectory_vis, camera_trajectory_reset])
        customized_trajectory_button.click(fn=update_camera_trajectories, inputs=[customized_trajectory_button],
                                           outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
                                                    provided_trajectories,
                                                    customized_camera_trajectory, customized_run_status,
                                                    customized_camera_trajectory_file,
                                                    camera_args, camera_trajectory_vis, camera_trajectory_reset])
        provided_trajectories.change(fn=update_camera_args, inputs=[provide_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
                                     outputs=[camera_args])
        customized_camera_trajectory_file.change(fn=update_camera_args, inputs=[customized_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
                                                 outputs=[camera_args])
        camera_trajectory_reset.click(fn=update_camera_args_reset, inputs=None, outputs=[camera_args])
        camera_trajectory_vis.click(fn=update_trajectory_vis_plot, inputs=[camera_args, provided_trajectories, customized_camera_trajectory_file],
                                    outputs=[vis_camera_trajectory, vis_camera_trajectory, step3_title, step3_des,
                                             num_inference_steps, min_guidance_scale, max_guidance_scale, fps,
                                             seed, set_button, camera_trajectory_path])
        set_button.click(fn=update_set_button, inputs=None, outputs=[step4_title, step4_des, start_button, generate_video])
        start_button.click(fn=sample_video, inputs=[processed_image, camera_trajectory_path, num_inference_steps,
                                                    min_guidance_scale, max_guidance_scale, fps, seed],
                           outputs=[generate_video])

        # set example
        gr.Markdown("## Examples")
        gr.Markdown("\n Choose one of the following examples for a quick start. Selecting an example "
                    "sets the condition image and camera trajectory automatically. "
                    "Then click the `Visualize Camera Trajectory` button to visualize the camera trajectory.")
        gr.Examples(
            fn=update_buttons_for_example,
            run_on_click=True,
            cache_examples=False,
            examples=examples,
            inputs=[input_image, camera_args, provided_trajectories],
            outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des, provide_trajectory_button,
                     customized_trajectory_button,
                     provided_camera_trajectory, provided_camera_trajectory_des, provided_trajectories,
                     customized_camera_trajectory, customized_run_status, customized_camera_trajectory_file,
                     camera_args, camera_trajectory_vis, camera_trajectory_reset]
        )
        with gr.Row():
            gr.Markdown(closing_words)

    demo.launch(**args)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Launch the CameraCtrl (SVD) Gradio demo.")
    parser.add_argument('--listen', default='0.0.0.0', help="Address the Gradio server binds to.")
    # '--broswer' was the original (misspelled) flag. Keep it as an alias so existing launch
    # commands keep working.
    parser.add_argument('--browser', '--broswer', dest='browser', action='store_true',
                        help="Open the demo in a browser tab after launching.")
    parser.add_argument('--share', action='store_true', help="Create a public Gradio share link.")
    args = parser.parse_args()
    launch_kwargs = {'server_name': args.listen,
                     'inbrowser': args.browser,
                     'share': args.share}
    main(launch_kwargs)