Spaces:
Paused
Paused
File size: 6,271 Bytes
10b581c ef5add3 9dab6c2 10b581c 9dab6c2 10b581c 70f2266 10b581c 5694315 10b581c 5694315 9dab6c2 5694315 9dab6c2 5694315 ef5add3 5694315 9dab6c2 45a9d7f dfaa5fc 5694315 cc5ea83 10b581c cc5ea83 10b581c cc5ea83 10b581c bf99f41 5694315 192f60f 5694315 374b969 10b581c 5694315 10b581c 5694315 f57c553 10b581c 5694315 f57c553 10b581c 5694315 10b581c 5694315 10b581c 5694315 ef5add3 10b581c 5694315 66d038e 10b581c 6a76f54 10b581c 192f60f 10b581c ecf6d80 6a76f54 e65bce3 6a76f54 10b581c ecf6d80 10b581c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import gradio as gr
import os
import torch
import gc
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import T5EncoderModel, T5Tokenizer
from datetime import datetime
import random
from huggingface_hub import hf_hub_download
# Make sure the local "checkpoints" directory exists before downloading into it.
os.makedirs("checkpoints", exist_ok=True)

# Fetch the DimensionX orbit LoRA weights (left orbit first, then up orbit).
for _lora_filename in (
    "orbit_left_lora_weights.safetensors",
    "orbit_up_lora_weights.safetensors",
):
    hf_hub_download(
        repo_id="wenqsun/DimensionX",
        filename=_lora_filename,
        local_dir="checkpoints",
    )

# Load the CogVideoX-5B image-to-video components once, at import time,
# in float16 and placed on the CPU.
model_id = "THUDM/CogVideoX-5b-I2V"
_half = torch.float16
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=_half
).to("cpu")
text_encoder = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=_half
).to("cpu")
vae = AutoencoderKLCogVideoX.from_pretrained(
    model_id, subfolder="vae", torch_dtype=_half
).to("cpu")
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")

# Build the pipeline, reusing the components loaded above.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    model_id,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    torch_dtype=_half,
)
def find_and_move_object_to_cpu():
    """Move every live ``torch.nn.Module`` holding CUDA tensors back to the CPU.

    Scans all objects tracked by the garbage collector, so modules that are
    reachable only through other references are also caught.

    Returns:
        int: number of modules on which ``.to("cpu")`` was called.
    """
    import itertools
    import logging

    moved = 0
    for obj in gc.get_objects():
        try:
            if not isinstance(obj, torch.nn.Module):
                continue
            # One combined check over parameters and buffers, so each module
            # receives at most one .to() call.
            tensors = itertools.chain(obj.parameters(), obj.buffers())
            if any(t.is_cuda for t in tensors):
                obj.to("cpu")
                moved += 1
        except Exception:
            # Best-effort sweep: skip objects that misbehave when inspected or
            # moved, but leave a trace instead of swallowing the error silently.
            logging.getLogger(__name__).debug(
                "Could not move %s to CPU", type(obj).__name__, exc_info=True
            )
    return moved
def clear_gpu():
    """Free Python garbage and release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by unreachable
    reference cycles are freed before the CUDA caching allocator is asked to
    release its unused blocks. The CUDA calls are skipped when no CUDA device
    is available, because ``torch.cuda.synchronize()`` would raise.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels before dropping cached blocks.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
    """Render an orbiting-camera video from one image using a DimensionX LoRA.

    Args:
        image_path: Input image path or URL, passed to ``load_image``.
        prompt: User prompt; a fixed quality/camera suffix is appended to it.
        orbit_type: ``"Left"`` or ``"Up"``; selects which LoRA weight file to load.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress bars in the UI.

    Returns:
        Path of the exported MP4 (``output_<timestamp>.mp4`` in the working dir).

    Raises:
        gr.Error: if ``orbit_type`` is unknown or no image was provided.
    """
    orbit_weights = {
        "Left": "orbit_left_lora_weights.safetensors",
        "Up": "orbit_up_lora_weights.safetensors",
    }
    weight_name = orbit_weights.get(orbit_type)
    if weight_name is None:
        # An unknown value used to leave `weight_name` unbound and crash with
        # UnboundLocalError; report it to the user instead.
        raise gr.Error(f"Unknown orbit type: {orbit_type!r}")
    if not image_path:
        raise gr.Error("Please provide an input image.")

    lora_path = "checkpoints/"
    lora_rank = 256

    # Defensive: drop any adapter that may still be attached from an earlier call.
    pipe.unload_lora_weights()
    try:
        # Timestamp-based adapter name.
        adapter_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Load the LoRA while the pipeline is still on CPU, then move it to GPU.
        pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=adapter_name)
        pipe.fuse_lora(lora_scale=1 / lora_rank)
        pipe.to("cuda")

        full_prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
        image = load_image(image_path)
        # NOTE(review): only 256 distinct seeds; widen the range if more variety is wanted.
        seed = random.randint(0, 2**8 - 1)
        video = pipe(
            image,
            full_prompt,
            num_inference_steps=50,
            guidance_scale=7.0,
            use_dynamic_cfg=True,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        )

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = f"output_{timestamp}.mp4"
        export_to_video(video.frames[0], output_path, fps=8)
    finally:
        # fuse_lora merges the LoRA delta into the base weights, and
        # unload_lora_weights alone keeps that merged delta. Unfuse first so the
        # next request starts from clean base weights instead of stacking every
        # previously used LoRA on top of each other.
        pipe.unfuse_lora()
        pipe.unload_lora_weights()
        # Return models to CPU and free GPU memory even if inference failed.
        find_and_move_object_to_cpu()
        clear_gpu()
    return output_path
# Set up Gradio UI

# Header links: GitHub repo, project page, paper, and HF duplicate/follow badges.
_LINKS_HTML = """
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/wenqsun/DimensionX">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://chenshuo20.github.io/DimensionX/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2411.04928">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/DimensionX?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
"""

# Example rows: [image, prompt, orbit type, pre-rendered output video].
# The output video is listed among the example inputs, so selecting an example
# also loads its pre-rendered result into the output player.
_EXAMPLE_IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
_EXAMPLE_PROMPT = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background."
_EXAMPLE_ROWS = [
    [_EXAMPLE_IMAGE, _EXAMPLE_PROMPT, orbit, f"./examples/output_astronaut_{orbit.lower()}.mp4"]
    for orbit in ("Left", "Up")
]

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# DimensionX")
        gr.Markdown("### Create Any 3D and 4D Scenes from a Single Image with Controllable Video Diffusion")
        gr.HTML(_LINKS_HTML)
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Image Input", type="filepath")
                prompt = gr.Textbox(label="Prompt")
                orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_out = gr.Video(label="Video output")
                examples = gr.Examples(
                    examples=_EXAMPLE_ROWS,
                    inputs=[image_in, prompt, orbit_type, video_out],
                )
    submit_btn.click(
        fn=infer,
        inputs=[image_in, prompt, orbit_type],
        outputs=[video_out],
    )

demo.queue().launch(show_error=True, show_api=False)