# Hugging Face Space status banner: "Running on Zero" (page residue, not part of the program)
import os | |
import sys | |
import gradio as gr | |
import torch | |
import argparse | |
from PIL import Image | |
import numpy as np | |
import torchvision.transforms as transforms | |
from moviepy.editor import VideoFileClip | |
from diffusers.utils import load_image, load_video | |
import spaces | |
# Directory containing this script; temp files and the MoGe submodule path
# below are anchored to it.
project_root = os.path.dirname(os.path.abspath(__file__))
# Point Gradio's temp directory at <project_root>/tmp/gradio.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
sys.path.append(project_root)

# Put the vendored MoGe checkout on sys.path. Appending a path never raises,
# so the missing-submodule warning has to come from an explicit directory
# check rather than from an exception handler.
_moge_dir = os.path.join(project_root, "submodules/MoGe")
sys.path.append(_moge_dir)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if not os.path.isdir(_moge_dir):
    print("Warning: MoGe not found, motion transfer will not be applied")

# Normalised script directory, placed first on sys.path so it takes
# precedence over the entries appended above.
HERE_PATH = os.path.normpath(os.path.dirname(__file__))
sys.path.insert(0, HERE_PATH)
from huggingface_hub import hf_hub_download
# Download 'spatracker/spaT_final.pth' from the EXCAI/Diffusion-As-Shader repo
# into <HERE_PATH>/checkpoints/. This call executes at import time, before the
# project modules below are imported.
hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')
# Project-local modules (presumably resolved through the sys.path entries set
# up earlier in this file - confirm against the repository layout).
from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
from submodules.MoGe.moge.model import MoGeModel
# Parse command line arguments.
# NOTE: parse_args() runs at module import time, so importing this file from
# a process with unrelated command-line arguments will make argparse exit.
parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--share", action="store_true", help="Share the web UI")
parser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint")
parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory")
args = parser.parse_args()

# Use the original GPU ID throughout the entire code for consistency
GPU_ID = args.gpu
# Passed as `checkpoint_path` to every das.apply_tracking(...) call in this file.
DEFAULT_MODEL_PATH = args.model_path
# Passed as `output_dir` to the pipeline and repainter constructors.
OUTPUT_DIR = args.output_dir

# Create necessary directories ("outputs" is relative to the current working
# directory, not to project_root).
os.makedirs("outputs", exist_ok=True)
# Create project tmp directory instead of using system temp
os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
def load_media(media_path, max_frames=49, transform=None):
    """Load video or image frames and convert them to a fixed-length tensor.

    The file is treated as a video when its extension is .mp4, .avi or .mov;
    anything else is loaded as a single image. The frame list is truncated or
    padded (by repeating the last frame) to exactly ``max_frames`` entries
    before the transform is applied.

    Args:
        media_path (str): Path to video or image file.
        max_frames (int): Exact number of frames in the returned tensor.
        transform (callable): Transform applied to every frame. Defaults to
            resizing to (480, 720) followed by ``ToTensor``.

    Returns:
        Tuple[torch.Tensor, float, bool]: Stacked frames ([T, C, H, W] with the
        default transform), the estimated FPS, and the is_video flag.

    Raises:
        ValueError: If no frame could be read from ``media_path``.
    """
    default_fps = 8  # FPS reported for still images and unusable durations

    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor()
        ])

    # Determine if input is video or image based on extension
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ['.mp4', '.avi', '.mov']

    if is_video:
        frames = load_video(media_path)
        # The clip is opened only to read its duration; close it right away so
        # the underlying reader is not left open for the process lifetime.
        clip = VideoFileClip(media_path)
        try:
            duration = clip.duration
        finally:
            clip.close()
        # A zero or missing duration would otherwise crash the division.
        fps = len(frames) / duration if duration else default_fps
    else:
        # Handle image as single frame
        frames = [load_image(media_path)]
        fps = default_fps

    if not frames:
        # Without this guard the padding step fails with an opaque IndexError.
        raise ValueError(f"No frames could be read from {media_path}")

    # Ensure we have exactly max_frames
    if len(frames) > max_frames:
        frames = frames[:max_frames]
    elif len(frames) < max_frames:
        last_frame = frames[-1]
        frames.extend(last_frame.copy() for _ in range(max_frames - len(frames)))

    # Convert frames to tensor
    video_tensor = torch.stack([transform(frame) for frame in frames])
    return video_tensor, fps, is_video
def save_uploaded_file(file):
    """Persist an uploaded file under <project_root>/tmp and return its path.

    Args:
        file: ``None``, a filesystem path (``str``), an object exposing
            ``save(path)``, a readable object exposing ``read()``, or raw
            bytes.

    Returns:
        str | None: ``file`` itself when it is already a path; otherwise the
        path of the copy written to the project tmp directory. ``None`` when
        no file was given or saving failed (the error is printed).
    """
    if file is None:
        return None
    if isinstance(file, str):
        # It's already a path
        return file

    import uuid

    # Use project tmp directory instead of system temp
    temp_dir = os.path.join(project_root, "tmp")

    # Keep only the final path component of the client-supplied name. Joining
    # an absolute or "../" name would escape temp_dir, and a name equal to the
    # object's own backing path would make the write below truncate the source.
    raw_name = getattr(file, 'name', None)
    filename = os.path.basename(raw_name) if isinstance(raw_name, str) else ""
    if not filename:
        # Generate a unique filename if no usable name is available
        ext = ".tmp"
        content_type = getattr(file, 'content_type', None) or ""
        if "image" in content_type:
            ext = ".png"
        elif "video" in content_type:
            ext = ".mp4"
        filename = f"{uuid.uuid4()}{ext}"

    temp_path = os.path.join(temp_dir, filename)

    try:
        os.makedirs(temp_dir, exist_ok=True)
        if hasattr(file, 'save'):
            # FileStorage-like object
            file.save(temp_path)
        else:
            # Read the payload BEFORE opening the destination for writing, so
            # the source is never truncated if both refer to the same file.
            payload = file.read() if hasattr(file, 'read') else file
            with open(temp_path, 'wb') as f:
                f.write(payload)
    except Exception as e:
        print(f"Error saving file: {e}")
        return None

    return temp_path
# Module-level caches for the heavyweight models. Both start as None and are
# filled in on first use by get_das_pipeline() / get_moge_model().
das_pipeline = None
moge_model = None
def get_das_pipeline():
    """Return the shared DiffusionAsShaderPipeline, building it on first call.

    The instance is stored in the module-level ``das_pipeline`` variable and
    is constructed with the configured GPU_ID and OUTPUT_DIR.
    """
    global das_pipeline
    if das_pipeline is not None:
        return das_pipeline
    das_pipeline = DiffusionAsShaderPipeline(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
    return das_pipeline
def get_moge_model():
    """Return the shared MoGe model, loading it on first call.

    The model is loaded from "Ruicheng/moge-vitl", moved to the device of the
    shared pipeline, and stored in the module-level ``moge_model`` variable.
    """
    global moge_model
    if moge_model is not None:
        return moge_model
    # Resolve the pipeline first so the model lands on the same device.
    target_device = get_das_pipeline().device
    pretrained = MoGeModel.from_pretrained("Ruicheng/moge-vitl")
    moge_model = pretrained.to(target_device)
    return moge_model
def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
    """Process video motion transfer task.

    Args:
        source: Uploaded source media (video or image).
        prompt: Text prompt forwarded to the repainter and to apply_tracking.
        mt_repaint_option: "Yes" to repaint the first frame from the prompt
            when no custom repaint image is supplied.
        mt_repaint_image: Optional uploaded image used as the repainted first
            frame; takes precedence over ``mt_repaint_option``.

    Returns:
        The value returned by ``das.apply_tracking`` (the output path), or
        ``None`` if the source could not be saved or any step failed.
    """
    try:
        # Save uploaded files
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        print(f"DEBUG: Repaint option: {mt_repaint_option}")
        print(f"DEBUG: Repaint image: {mt_repaint_image}")

        das = get_das_pipeline()
        # The FPS reported by load_media is not used: output is written at 8 fps.
        video_tensor, _, is_video = load_media(input_video_path)
        # Single source of truth for the frame count used below.
        num_frames = video_tensor.shape[0]

        if not is_video:
            tracking_method = "moge"
            print("Image input detected, using MoGe for tracking video generation.")
        else:
            tracking_method = "spatracker"

        repaint_img_tensor = None
        if mt_repaint_image is not None:
            repaint_path = save_uploaded_file(mt_repaint_image)
            if repaint_path is None:
                raise ValueError("Could not save the custom repaint image")
            # Only the first frame is used, so load exactly one frame instead
            # of padding and transforming the default 49 copies.
            repaint_frames, _, _ = load_media(repaint_path, max_frames=1)
            repaint_img_tensor = repaint_frames[0]
        elif mt_repaint_option == "Yes":
            repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=prompt,
                depth_path=None
            )

        tracking_tensor = None
        if tracking_method == "moge":
            moge = get_moge_model()
            infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
            H, W = infer_result["points"].shape[0:2]
            # Replicate the single-frame point map across all frames
            pred_tracks = infer_result["points"].unsqueeze(0).repeat(num_frames, 1, 1, 1)  # [T, H, W, 3]

            # Identity poses: no camera motion is applied in this task
            poses = torch.eye(4).unsqueeze(0).repeat(num_frames, 1, 1)

            pred_tracks_flatten = pred_tracks.reshape(num_frames, H * W, 3)

            cam_motion = CameraMotionGenerator(None)
            cam_motion.set_intr(infer_result["intrinsics"])
            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3])  # [T, H, W, 3]

            _, tracking_tensor = das.visualize_tracking_moge(
                pred_tracks.cpu().numpy(),
                infer_result["mask"].cpu().numpy()
            )
            print('Export tracking video via MoGe')
        else:
            pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)

            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
            print('Export tracking video via SpaTracker')

        output_path = das.apply_tracking(
            video_tensor=video_tensor,
            fps=8,
            tracking_tensor=tracking_tensor,
            img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )

        return output_path
    except Exception as e:
        # UI boundary: report the failure and show an empty result.
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None
def process_camera_control(source, prompt, camera_motion, tracking_method):
    """Process camera control task.

    Builds a tracking video for the source media (optionally moved by the
    ``camera_motion`` sequence) with either MoGe or SpaTracker, then hands it
    to the pipeline's ``apply_tracking``. Returns that call's result, or
    ``None`` when the upload could not be saved or any step raised.
    """
    frame_count = 49
    try:
        # Persist the upload first; nothing else can proceed without it.
        media_path = save_uploaded_file(source)
        if media_path is None:
            return None

        print(f"DEBUG: Camera motion: '{camera_motion}'")
        print(f"DEBUG: Tracking method: '{tracking_method}'")

        pipeline = get_das_pipeline()
        frames, _fps, from_video = load_media(media_path)

        # SpaTracker is only used on real videos; still images fall back to MoGe.
        if tracking_method == "spatracker" and not from_video:
            tracking_method = "moge"
            print("Image input detected with spatracker selected, switching to MoGe")

        motion_gen = CameraMotionGenerator(camera_motion)

        if tracking_method != "moge":
            tracks, visibility, first_transforms = pipeline.generate_tracking_spatracker(frames)
            if camera_motion:
                cam_poses = motion_gen.get_default_motion()  # shape: [49, 4, 4]
                tracks = motion_gen.apply_motion_on_pts(tracks, cam_poses)
                print("Camera motion applied")
            _, tracking = pipeline.visualize_tracking_spatracker(frames, tracks, visibility, first_transforms)
            print('Export tracking video via SpaTracker')
        else:
            depth_model = get_moge_model()
            first_frame = frames[0].to(pipeline.device)  # [C, H, W] in range [0,1]
            moge_out = depth_model.infer(first_frame)
            point_map = moge_out["points"]
            H, W = point_map.shape[0:2]
            tracks = point_map.unsqueeze(0).repeat(frame_count, 1, 1, 1)  # [T, H, W, 3]
            motion_gen.set_intr(moge_out["intrinsics"])

            if camera_motion:
                cam_poses = motion_gen.get_default_motion()  # shape: [49, 4, 4]
                print("Camera motion applied")
            else:
                # No motion requested: identity pose for every frame.
                cam_poses = torch.eye(4).unsqueeze(0).repeat(frame_count, 1, 1)

            total = frames.shape[0]
            flat_tracks = tracks.reshape(total, H * W, 3)
            tracks = motion_gen.w2s(flat_tracks, cam_poses).reshape([total, H, W, 3])  # [T, H, W, 3]

            _, tracking = pipeline.visualize_tracking_moge(
                tracks.cpu().numpy(),
                moge_out["mask"].cpu().numpy()
            )
            print('Export tracking video via MoGe')

        # This task never repaints the first frame, so no image condition is passed.
        return pipeline.apply_tracking(
            video_tensor=frames,
            fps=8,
            tracking_tensor=tracking,
            img_cond_tensor=None,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None
def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
    """Process object manipulation task.

    Args:
        source: Uploaded source media (image or video).
        prompt: Text prompt forwarded to apply_tracking.
        object_motion: Motion type passed to ObjectMotionGenerator.apply_motion.
        object_mask: Uploaded mask image; pixels brighter than 127 (after
            grayscale conversion and resize to 480x720) select the object.
        tracking_method: "spatracker" or "moge"; image inputs always use MoGe.

    Returns:
        The value returned by ``das.apply_tracking`` (the output path), or
        ``None`` if an upload is missing or any step failed.
    """
    try:
        # Save uploaded files
        input_image_path = save_uploaded_file(source)
        if input_image_path is None:
            return None

        object_mask_path = save_uploaded_file(object_mask)
        if object_mask_path is None:
            print("Object mask not provided")
            return None

        das = get_das_pipeline()
        # The FPS reported by load_media is not used: output is written at 8 fps.
        video_tensor, _, is_video = load_media(input_image_path)
        # Single source of truth for the frame count used below.
        num_frames = video_tensor.shape[0]

        if not is_video and tracking_method == "spatracker":
            tracking_method = "moge"
            print("Image input detected with spatracker selected, switching to MoGe")

        # Open the mask in a context manager so the file handle is released;
        # convert() returns an independent, fully loaded image.
        with Image.open(object_mask_path) as mask_file:
            mask_image = mask_file.convert('L')
        mask_image = transforms.Resize((480, 720))(mask_image)
        # Binarise: True where the mask is brighter than mid-gray
        mask = torch.from_numpy(np.array(mask_image) > 127)

        motion_generator = ObjectMotionGenerator(device=das.device)
        repaint_img_tensor = None
        tracking_tensor = None
        if tracking_method == "moge":
            moge = get_moge_model()
            infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
            H, W = infer_result["points"].shape[0:2]
            # Replicate the single-frame point map across all frames
            pred_tracks = infer_result["points"].unsqueeze(0).repeat(num_frames, 1, 1, 1)  # [T, H, W, 3]

            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks,
                mask=mask,
                motion_type=object_motion,
                distance=50,
                num_frames=num_frames,
                tracking_method="moge"
            )
            print(f"Object motion '{object_motion}' applied using provided mask")

            # Identity poses: the camera stays fixed in this task
            poses = torch.eye(4).unsqueeze(0).repeat(num_frames, 1, 1)
            pred_tracks_flatten = pred_tracks.reshape(num_frames, H * W, 3)

            cam_motion = CameraMotionGenerator(None)
            cam_motion.set_intr(infer_result["intrinsics"])
            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3])  # [T, H, W, 3]

            _, tracking_tensor = das.visualize_tracking_moge(
                pred_tracks.cpu().numpy(),
                infer_result["mask"].cpu().numpy()
            )
            print('Export tracking video via MoGe')
        else:
            pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)

            # apply_motion is given the squeezed tracks; the leading dimension
            # is restored afterwards for the visualiser.
            pred_tracks = motion_generator.apply_motion(
                pred_tracks=pred_tracks.squeeze(),
                mask=mask,
                motion_type=object_motion,
                distance=50,
                num_frames=num_frames,
                tracking_method="spatracker"
            ).unsqueeze(0)
            print(f"Object motion '{object_motion}' applied using provided mask")

            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
            print('Export tracking video via SpaTracker')

        output_path = das.apply_tracking(
            video_tensor=video_tensor,
            fps=8,
            tracking_tensor=tracking_tensor,
            img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )

        return output_path
    except Exception as e:
        # UI boundary: report the failure and show an empty result.
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None
def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
    """Process mesh animation task.

    Args:
        source: Uploaded source media.
        prompt: Text prompt forwarded to the repainter and to apply_tracking.
        tracking_video: Uploaded tracking video, used directly as the
            tracking tensor.
        ma_repaint_option: "Yes" to repaint the first frame from the prompt
            when no custom repaint image is supplied.
        ma_repaint_image: Optional uploaded image used as the repainted first
            frame; takes precedence over ``ma_repaint_option``.

    Returns:
        The value returned by ``das.apply_tracking`` (the output path), or
        ``None`` if an upload is missing or any step failed.
    """
    try:
        # Save uploaded files
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        tracking_video_path = save_uploaded_file(tracking_video)
        if tracking_video_path is None:
            return None

        das = get_das_pipeline()
        # The FPS values reported by load_media are not used: output is
        # written at 8 fps.
        video_tensor, _, _ = load_media(input_video_path)
        tracking_tensor, _, _ = load_media(tracking_video_path)

        repaint_img_tensor = None
        if ma_repaint_image is not None:
            repaint_path = save_uploaded_file(ma_repaint_image)
            if repaint_path is None:
                raise ValueError("Could not save the custom repaint image")
            # Only the first frame is used, so load exactly one frame instead
            # of padding and transforming the default 49 copies.
            repaint_frames, _, _ = load_media(repaint_path, max_frames=1)
            repaint_img_tensor = repaint_frames[0]  # take the first frame
        elif ma_repaint_option == "Yes":
            repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
            repaint_img_tensor = repainter.repaint(
                video_tensor[0],
                prompt=prompt,
                depth_path=None
            )

        output_path = das.apply_tracking(
            video_tensor=video_tensor,
            fps=8,
            tracking_tensor=tracking_tensor,
            img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )

        return output_path
    except Exception as e:
        # UI boundary: report the failure and show an empty result.
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None
# Create Gradio interface with updated layout.
# Layout: a two-column row. The left column holds the shared inputs (source
# file + prompt) and one tab per task; the right column holds the single
# output video that every task's run button writes to.
with gr.Blocks(title="Diffusion as Shader") as demo:
    gr.Markdown("# Diffusion as Shader Web UI")
    gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")

    with gr.Row():
        left_column = gr.Column(scale=1)
        right_column = gr.Column(scale=1)

    with right_column:
        output_video = gr.Video(label="Generated Video")

    with left_column:
        # Inputs shared by all four task tabs
        source = gr.File(label="Source", file_types=["image", "video"])
        common_prompt = gr.Textbox(label="Prompt", lines=2)
        gr.Markdown(f"**Using GPU: {GPU_ID}**")

        with gr.Tabs() as task_tabs:
            # Motion Transfer tab
            with gr.TabItem("Motion Transfer"):
                gr.Markdown("## Motion Transfer")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    mt_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                mt_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Motion Transfer tab
                mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg")

                # Connect to process function
                mt_run_btn.click(
                    fn=process_motion_transfer,
                    inputs=[
                        source, common_prompt,
                        mt_repaint_option, mt_repaint_image
                    ],
                    outputs=[output_video]
                )

            # Camera Control tab
            with gr.TabItem("Camera Control"):
                gr.Markdown("## Camera Control")

                # Read-only textbox holding the accumulated motion sequence;
                # it is only modified through the add/clear buttons below.
                cc_camera_motion = gr.Textbox(
                    label="Current Camera Motion Sequence",
                    placeholder="Your camera motion sequence will appear here...",
                    interactive=False
                )

                # Use tabs for different motion types
                with gr.Tabs() as cc_motion_tabs:
                    # Translation tab
                    with gr.TabItem("Translation (trans)"):
                        with gr.Row():
                            cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
                            cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
                            cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")

                        with gr.Row():
                            cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_trans_note = gr.Markdown("""
                        **Translation Notes:**
                        - Positive X: Move right, Negative X: Move left
                        - Positive Y: Move down, Negative Y: Move up
                        - Positive Z: Zoom in, Negative Z: Zoom out
                        """)

                        # Add translation button in the Translation tab
                        cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")

                        # Function to add translation motion
                        def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
                            """Append a "trans dx dy dz [start end]" command to the motion sequence.

                            Returns the updated sequence string; commands are separated by "; ".
                            """
                            # Format: trans dx dy dz [start_frame end_frame]
                            # The frame range is omitted when it is the full default span (0..48).
                            frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
                            new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion

                            return updated_motion

                        # Connect translation button
                        cc_add_trans.click(
                            fn=add_translation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                    # Rotation tab
                    with gr.TabItem("Rotation (rot)"):
                        with gr.Row():
                            cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
                            cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")

                        with gr.Row():
                            cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_rot_note = gr.Markdown("""
                        **Rotation Notes:**
                        - X-axis rotation: Tilt camera up/down
                        - Y-axis rotation: Pan camera left/right
                        - Z-axis rotation: Roll camera
                        """)

                        # Add rotation button in the Rotation tab
                        cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")

                        # Function to add rotation motion
                        def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
                            """Append a "rot axis angle [start end]" command to the motion sequence.

                            Returns the updated sequence string; commands are separated by "; ".
                            """
                            # Format: rot axis angle [start_frame end_frame]
                            # The frame range is omitted when it is the full default span (0..48).
                            frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
                            new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion

                            return updated_motion

                        # Connect rotation button
                        cc_add_rot.click(
                            fn=add_rotation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                # Add a clear button to reset the motion sequence
                cc_clear_motion = gr.Button("Clear All Motions", variant="stop")

                def clear_camera_motion():
                    """Return an empty string, which resets the motion sequence textbox."""
                    return ""

                cc_clear_motion.click(
                    fn=clear_camera_motion,
                    inputs=[],
                    outputs=[cc_camera_motion]
                )

                cc_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Camera Control tab
                cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg")

                # Connect to process function
                cc_run_btn.click(
                    fn=process_camera_control,
                    inputs=[
                        source, common_prompt,
                        cc_camera_motion, cc_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Object Manipulation tab
            with gr.TabItem("Object Manipulation"):
                gr.Markdown("## Object Manipulation")
                om_object_mask = gr.File(
                    label="Object Mask Image",
                    file_types=["image"]
                )
                gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
                om_object_motion = gr.Dropdown(
                    label="Object Motion Type",
                    choices=["up", "down", "left", "right", "front", "back", "rot"],
                    value="up"
                )
                om_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Object Manipulation tab
                om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg")

                # Connect to process function
                om_run_btn.click(
                    fn=process_object_manipulation,
                    inputs=[
                        source, common_prompt,
                        om_object_motion, om_object_mask, om_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Animating meshes to video tab
            with gr.TabItem("Animating meshes to video"):
                gr.Markdown("## Mesh Animation to Video")
                gr.Markdown("""
                Note: Currently only supports tracking videos generated with Blender (version > 4.0).
                Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
                """)
                ma_tracking_video = gr.File(
                    label="Tracking Video",
                    file_types=["video"]
                )
                gr.Markdown("Tracking video needs to be generated from Blender")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    ma_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                ma_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Mesh Animation tab
                ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg")

                # Connect to process function
                ma_run_btn.click(
                    fn=process_mesh_animation,
                    inputs=[
                        source, common_prompt,
                        ma_tracking_video, ma_repaint_option, ma_repaint_image
                    ],
                    outputs=[output_video]
                )
# Launch interface (only when executed as a script, not when imported)
if __name__ == "__main__":
    print(f"Using GPU: {GPU_ID}")
    print(f"Web UI will start on port {args.port}")
    if args.share:
        print("Creating public link for remote access")

    # Launch interface using the --share and --port command-line values
    demo.launch(share=args.share, server_port=args.port)