"""Diffusion as Shader web UI entry point.

Downloads the SpatialTracker checkpoint, parses CLI options, probes CUDA,
and prepares the project-local temp/output directories before the Gradio
interface is built.
"""
import os
import sys
import gradio as gr
import torch
import subprocess
import argparse
import glob

# Project root = directory containing this file; used for all temp paths.
project_root = os.path.dirname(os.path.abspath(__file__))
# Must be set before Gradio creates any temp files.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
sys.path.append(project_root)
# NOTE(review): HERE_PATH is the same directory as project_root, so the
# path is inserted into sys.path twice — harmless but redundant.
HERE_PATH = os.path.normpath(os.path.dirname(__file__))
sys.path.insert(0, HERE_PATH)

from huggingface_hub import hf_hub_download
# NOTE(review): network download at import time; fails offline unless the
# checkpoint is already cached in ./checkpoints/.
hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')

# Parse command line arguments
parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--share", action="store_true", help="Share the web UI")
parser.add_argument("--gpu", type=int, default=0, help="GPU device ID")
parser.add_argument("--model_path", type=str, default="EXCAI/Diffusion-As-Shader", help="Path to model checkpoint")
parser.add_argument("--output_dir", type=str, default="tmp", help="Output directory")
args = parser.parse_args()

# Use the original GPU ID throughout the entire code for consistency
GPU_ID = args.gpu

# Set environment variables - this used to remap the GPU, but we're removing this for consistency
# Instead, we'll pass the original GPU ID to all commands
# os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)  # Commented out to ensure consistent GPU ID usage

# Check if CUDA is available
CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    GPU_COUNT = torch.cuda.device_count()
    GPU_NAMES = [f"{i}: {torch.cuda.get_device_name(i)}" for i in range(GPU_COUNT)]
else:
    GPU_COUNT = 0
    GPU_NAMES = ["CPU (CUDA not available)"]
    # NOTE(review): GPU_ID changes type here (int -> str "CPU"); downstream
    # code passes it straight into the demo.py command line.
    GPU_ID = "CPU"

DEFAULT_MODEL_PATH = args.model_path
OUTPUT_DIR = args.output_dir

# Create necessary directories
os.makedirs("outputs", exist_ok=True)
# Create project tmp directory instead of using system temp
os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)


def save_uploaded_file(file):
    """Persist an uploaded file into the project tmp directory.

    Accepts a FileStorage-like object (has ``save``), a file-like object
    (has ``read``), raw bytes, or an existing filesystem path.

    Returns:
        The path to the stored file, the original path when *file* is
        already a string path, or None when nothing was uploaded or the
        save failed.
    """
    if file is None:
        return None

    # Use project tmp directory instead of system temp.
    temp_dir = os.path.join(project_root, "tmp")

    if hasattr(file, 'name'):
        filename = file.name
    else:
        # Generate a unique filename if the name attribute is missing.
        import uuid
        ext = ".tmp"
        if hasattr(file, 'content_type'):
            if "image" in file.content_type:
                ext = ".png"
            elif "video" in file.content_type:
                ext = ".mp4"
        filename = f"{uuid.uuid4()}{ext}"

    temp_path = os.path.join(temp_dir, filename)
    try:
        if hasattr(file, 'save'):
            # FileStorage-style object knows how to save itself.
            file.save(temp_path)
        elif isinstance(file, str):
            # Already a path on disk — nothing to copy.
            return file
        else:
            # Fall back to writing raw bytes / file-like content.
            with open(temp_path, 'wb') as f:
                f.write(file.read() if hasattr(file, 'read') else file)
    except Exception as e:
        print(f"Error saving file: {e}")
        return None

    return temp_path


def create_run_command(args):
    """Build the ``python demo.py`` argv list from a parameter dict.

    Args:
        args: Mapping of demo.py option names to values. Entries whose
            value is None are skipped; booleans are rendered as the
            literal strings "true"/"false" (demo.py expects that form).
            A missing/empty "prompt" defaults to "" and a missing
            "checkpoint_path" defaults to DEFAULT_MODEL_PATH.

    Returns:
        The command as a list of strings suitable for subprocess.
    """
    # Work on a copy so the caller's dict is not mutated.
    params = dict(args)
    if not params.get("prompt"):
        params["prompt"] = ""
    if not params.get("checkpoint_path"):
        params["checkpoint_path"] = DEFAULT_MODEL_PATH

    # Debug output
    print(f"DEBUG: Command args: {params}")

    cmd = ["python", "demo.py"]
    for key, value in params.items():
        if value is None:
            continue
        cmd.append(f"--{key}")
        # Booleans must be passed as lowercase "true"/"false".
        cmd.append(str(value).lower() if isinstance(value, bool) else str(value))
    return cmd


def run_process(cmd):
    """Run *cmd*, streaming its output to the console as it arrives.

    stderr is merged into stdout: reading only one pipe while the child
    fills the other can deadlock once the OS pipe buffer is full.

    Returns:
        The captured output as one string (lines keep their newlines).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero; the
            captured output is attached to the exception.
    """
    print(f"Running command: {' '.join(cmd)}")
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True
    )

    output = []
    # iter(..., "") stops on EOF by itself; no extra break needed.
    for line in iter(process.stdout.readline, ""):
        print(line, end="")
        output.append(line)

    process.stdout.close()
    return_code = process.wait()
    # Lines already end with '\n'; joining with "" avoids double spacing.
    captured = "".join(output)
    if return_code:
        print(f"Error: command exited with code {return_code}")
        raise subprocess.CalledProcessError(return_code, cmd, output=captured)
    return captured


def _latest_output_video():
    """Return the most recently modified .mp4 in OUTPUT_DIR, or None."""
    output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
    if not output_files:
        return None
    return max(output_files, key=os.path.getmtime)


# Process functions for each tab

def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
    """Process video motion transfer task; returns the generated video path or None."""
    try:
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        print(f"DEBUG: Repaint option: {mt_repaint_option}")
        print(f"DEBUG: Repaint image: {mt_repaint_image}")

        args = {
            "input_path": input_video_path,
            # Pass the prompt verbatim: demo.py is launched without a shell,
            # so wrapping it in literal quote characters would make the
            # quotes part of the prompt text (the other tasks pass it raw).
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID
        }

        # Priority: Custom Image > Yes > No
        if mt_repaint_image is not None:
            repaint_path = save_uploaded_file(mt_repaint_image)
            print(f"DEBUG: Repaint path: {repaint_path}")
            args["repaint"] = repaint_path
        elif mt_repaint_option == "Yes":
            args["repaint"] = "true"

        run_process(create_run_command(args))
        return _latest_output_video()
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None


def process_camera_control(source, prompt, camera_motion, tracking_method):
    """Process camera control task; returns the generated video path or None."""
    try:
        input_media_path = save_uploaded_file(source)
        if input_media_path is None:
            return None

        print(f"DEBUG: Camera motion: '{camera_motion}'")
        print(f"DEBUG: Tracking method: '{tracking_method}'")

        args = {
            "input_path": input_media_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "tracking_method": tracking_method
        }
        # Only forward a camera motion sequence when one was composed.
        if camera_motion and camera_motion.strip():
            args["camera_motion"] = camera_motion

        run_process(create_run_command(args))
        return _latest_output_video()
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None


def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
    """Process object manipulation task; returns the generated video path or None."""
    try:
        input_image_path = save_uploaded_file(source)
        if input_image_path is None:
            return None

        object_mask_path = save_uploaded_file(object_mask)

        args = {
            "input_path": input_image_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "object_motion": object_motion,
            "object_mask": object_mask_path,
            "tracking_method": tracking_method
        }

        run_process(create_run_command(args))
        return _latest_output_video()
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None


def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
    """Process mesh animation task; returns the generated video path or None."""
    try:
        input_video_path = save_uploaded_file(source)
        if input_video_path is None:
            return None

        tracking_video_path = save_uploaded_file(tracking_video)
        if tracking_video_path is None:
            return None

        args = {
            "input_path": input_video_path,
            "prompt": prompt,
            "checkpoint_path": DEFAULT_MODEL_PATH,
            "output_dir": OUTPUT_DIR,
            "gpu": GPU_ID,
            "tracking_path": tracking_video_path
        }

        # Priority: Custom Image > Yes > No
        if ma_repaint_image is not None:
            repaint_path = save_uploaded_file(ma_repaint_image)
            args["repaint"] = repaint_path
        elif ma_repaint_option == "Yes":
            args["repaint"] = "true"

        run_process(create_run_command(args))
        return _latest_output_video()
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
        return None


# Create Gradio interface with updated layout
with gr.Blocks(title="Diffusion as Shader") as demo:
    gr.Markdown("# Diffusion as Shader Web UI")
    gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")

    with gr.Row():
        left_column = gr.Column(scale=1)
        right_column = gr.Column(scale=1)

    with right_column:
        output_video = gr.Video(label="Generated Video")

    with left_column:
        source = gr.File(label="Source", file_types=["image", "video"])
        common_prompt = gr.Textbox(label="Prompt", lines=2)
        gr.Markdown(f"**Using GPU: {GPU_ID}**")

        with gr.Tabs() as task_tabs:
            # Motion Transfer tab
            with gr.TabItem("Motion Transfer"):
                gr.Markdown("## Motion Transfer")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    mt_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                mt_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Motion Transfer tab
                mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg")

                # Connect to process function
                mt_run_btn.click(
                    fn=process_motion_transfer,
                    inputs=[
                        source, common_prompt,
                        mt_repaint_option, mt_repaint_image
                    ],
                    outputs=[output_video]
                )

            # Camera Control tab
            with gr.TabItem("Camera Control"):
                gr.Markdown("## Camera Control")

                cc_camera_motion = gr.Textbox(
                    label="Current Camera Motion Sequence",
                    placeholder="Your camera motion sequence will appear here...",
                    interactive=False
                )

                # Use tabs for different motion types
                with gr.Tabs() as cc_motion_tabs:
                    # Translation tab
                    with gr.TabItem("Translation (trans)"):
                        with gr.Row():
                            cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
                            cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
                            cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")

                        with gr.Row():
                            cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_trans_note = gr.Markdown("""
                        **Translation Notes:**
                        - Positive X: Move right, Negative X: Move left
                        - Positive Y: Move down, Negative Y: Move up
                        - Positive Z: Zoom in, Negative Z: Zoom out
                        """)

                        # Add translation button in the Translation tab
                        cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")

                        # Function to add translation motion
                        def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
                            # Format: trans dx dy dz [start_frame end_frame]
                            frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
                            new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion
                            return updated_motion

                        # Connect translation button
                        cc_add_trans.click(
                            fn=add_translation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_trans_x, cc_trans_y, cc_trans_z,
                                cc_trans_start, cc_trans_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                    # Rotation tab
                    with gr.TabItem("Rotation (rot)"):
                        with gr.Row():
                            cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
                            cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")

                        with gr.Row():
                            cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
                            cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)

                        cc_rot_note = gr.Markdown("""
                        **Rotation Notes:**
                        - X-axis rotation: Tilt camera up/down
                        - Y-axis rotation: Pan camera left/right
                        - Z-axis rotation: Roll camera
                        """)

                        # Add rotation button in the Rotation tab
                        cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")

                        # Function to add rotation motion
                        def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
                            # Format: rot axis angle [start_frame end_frame]
                            frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
                            new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"

                            # Append to existing motion string with semicolon separator if needed
                            if current_motion and current_motion.strip():
                                updated_motion = f"{current_motion}; {new_motion}"
                            else:
                                updated_motion = new_motion
                            return updated_motion

                        # Connect rotation button
                        cc_add_rot.click(
                            fn=add_rotation_motion,
                            inputs=[
                                cc_camera_motion,
                                cc_rot_axis, cc_rot_angle,
                                cc_rot_start, cc_rot_end
                            ],
                            outputs=[cc_camera_motion]
                        )

                # Add a clear button to reset the motion sequence
                cc_clear_motion = gr.Button("Clear All Motions", variant="stop")

                def clear_camera_motion():
                    return ""

                cc_clear_motion.click(
                    fn=clear_camera_motion,
                    inputs=[],
                    outputs=[cc_camera_motion]
                )

                cc_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Camera Control tab
                cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg")

                # Connect to process function
                cc_run_btn.click(
                    fn=process_camera_control,
                    inputs=[
                        source, common_prompt,
                        cc_camera_motion, cc_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Object Manipulation tab
            with gr.TabItem("Object Manipulation"):
                gr.Markdown("## Object Manipulation")

                om_object_mask = gr.File(
                    label="Object Mask Image",
                    file_types=["image"]
                )
                gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")

                om_object_motion = gr.Dropdown(
                    label="Object Motion Type",
                    choices=["up", "down", "left", "right", "front", "back", "rot"],
                    value="up"
                )
                om_tracking_method = gr.Radio(
                    label="Tracking Method",
                    choices=["spatracker", "moge"],
                    value="moge"
                )

                # Add run button for Object Manipulation tab
                om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg")

                # Connect to process function
                om_run_btn.click(
                    fn=process_object_manipulation,
                    inputs=[
                        source, common_prompt,
                        om_object_motion, om_object_mask, om_tracking_method
                    ],
                    outputs=[output_video]
                )

            # Animating meshes to video tab
            with gr.TabItem("Animating meshes to video"):
                gr.Markdown("## Mesh Animation to Video")
                gr.Markdown("""
                Note: Currently only supports tracking videos generated with Blender (version > 4.0).
                Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
                """)

                ma_tracking_video = gr.File(
                    label="Tracking Video",
                    file_types=["video"]
                )
                gr.Markdown("Tracking video needs to be generated from Blender")

                # Simplified controls - Radio buttons for Yes/No and separate file upload
                with gr.Row():
                    ma_repaint_option = gr.Radio(
                        label="Repaint First Frame",
                        choices=["No", "Yes"],
                        value="No"
                    )
                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
                # Custom image uploader (always visible)
                ma_repaint_image = gr.File(
                    label="Custom Repaint Image",
                    file_types=["image"]
                )

                # Add run button for Mesh Animation tab
                ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg")

                # Connect to process function
                ma_run_btn.click(
                    fn=process_mesh_animation,
                    inputs=[
                        source, common_prompt,
                        ma_tracking_video,
                        ma_repaint_option, ma_repaint_image
                    ],
                    outputs=[output_video]
                )

# Launch interface
if __name__ == "__main__":
    print(f"Using GPU: {GPU_ID}")
    print(f"Web UI will start on port {args.port}")
    if args.share:
        print("Creating public link for remote access")

    # Launch interface
    demo.launch(share=args.share, server_port=args.port)