import spaces
import gradio as gr
import argparse
import sys
import os
import random
import subprocess
from PIL import Image
import numpy as np
# Removed environment-specific lines
from diffusers.utils import export_to_video
from diffusers.utils import load_image
import torch
import logging
from collections import OrderedDict
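
# Precision settings: disable TF32 and reduced-precision matmul/reduction paths and
# keep float32 matmul at the highest precision.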
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)

# --- Dummy Classes (Keep for standalone execution) ---
class OffloadConfig:
    def __init__(
        self,
        high_cpu_memory: bool = False,
        parameters_level: bool = False,
        compiler_transformer: bool = False,
        compiler_cache: str = "",
    ):
        self.high_cpu_memory = high_cpu_memory
        self.parameters_level = parameters_level
        self.compiler_transformer = compiler_transformer
        self.compiler_cache = compiler_cache

class TaskType:  # Keep here for infer
    T2V = 0
    I2V = 1

class LlamaModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return LlamaModel()

    def to(self, device):
        return self

class HunyuanVideoTransformer3DModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return HunyuanVideoTransformer3DModel()

    def to(self, device):
        return self

class SkyreelsVideoPipeline:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return SkyreelsVideoPipeline()

    def to(self, device):
        return self

    def __call__(self, *args, **kwargs):
        num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
        height = kwargs.get("height", 512)
        width = kwargs.get("width", 512)
        if "image" in kwargs:  # I2V
            image = kwargs["image"]
            # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
            image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
            image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
            # Create video by repeating the image along a new time dimension
            frames = image_tensor.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
            frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
            # Correct shape: (1, C, T, H, W) - NO PERMUTE HERE
        else:  # T2V
            frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W) - Correct!
        return type("obj", (object,), {"frames": frames})()  # No longer a list!

    def __init__(self):
        super().__init__()
        self._modules = OrderedDict()
        self.vae = self.VAE()
        self._modules["vae"] = self.vae

    def named_children(self):
        return self._modules.items()

    class VAE:
        def enable_tiling(self):
            pass

# No-op stand-ins for the quantization helpers (e.g. torchao's quantize_ / float8_weight_only)
def quantize_(*args, **kwargs):
    return

def float8_weight_only():
    return

# --- End Dummy Classes ---
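
# Single-GPU inference wrapper: lazily loads the text encoder, transformer and pipeline,
# optionally float8-quantizes them, and optionally compiles the transformer before use.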
class SkyReelsVideoSingleGpuInfer:
    def _load_model(
        self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
    ):
        logger.info(f"load model model_id:{model_id} quant_model:{quant_model}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device="cpu"
        ).to("cpu")
        if quant_model:
            quantize_(text_encoder, float8_weight_only())
            text_encoder.to("cpu")
            torch.cuda.empty_cache()
            quantize_(transformer, float8_weight_only())
            transformer.to("cpu")
            torch.cuda.empty_cache()
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe

    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
        enable_cfg_parallel: bool = True,
    ):
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        self.enable_cfg_parallel = enable_cfg_parallel
        self.pipe = None
        self.is_initialized = False
        self.gpu_device = None

    def initialize(self):
        """Initializes the model and moves it to the GPU."""
        if self.is_initialized:
            return
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. Cannot initialize model.")
        self.gpu_device = "cuda:0"
        self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
        if not self.is_offload:
            self.pipe.to(self.gpu_device)
        if self.offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
            self.pipe.transformer = torch.compile(
                self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
            )
        self.is_initialized = True
        # Warm up only after the initialized flag is set, since warm_up() checks it.
        if self.offload_config.compiler_transformer:
            self.warm_up()

    def warm_up(self):
        if not self.is_initialized:
            raise RuntimeError("Model must be initialized before warm-up.")
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 544,
            "width": 960,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "bad quality",
            "num_frames": 16,
            "generator": torch.Generator(self.gpu_device).manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            # PIL takes (width, height); 960x544 matches the kwargs above
            init_kwargs["image"] = Image.new("RGB", (960, 544), color="black")
        self.pipe(**init_kwargs)
        logger.info("Warm-up complete.")

    def infer(self, **kwargs):
        """Handles inference requests."""
        if not self.is_initialized:
            self.initialize()
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
            del kwargs["seed"]
        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
        result = self.pipe(**kwargs).frames  # Return the tensor directly
        return result
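
# Module-level predictor: created lazily on the first request and reused for later ones.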
_predictor = None

def generate_video(prompt: str, seed: int, image: str | None = None) -> tuple[str, dict]:
    """Generates a video based on the given prompt and seed.

    Args:
        prompt: The text prompt to guide video generation.
        seed: The random seed for reproducibility (-1 picks a random seed).
        image: Optional path to an image for Image-to-Video.

    Returns:
        A tuple containing the path to the generated video and the parameters used.
    """
    global _predictor
    if seed == -1:
        random.seed()
        seed = int(random.randrange(4294967294))
    if image is None:
        task_type = TaskType.T2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
        kwargs = {
            "prompt": prompt,
            "height": 512,
            "width": 512,
            "num_frames": 16,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 7.5,
            "negative_prompt": "bad quality, worst quality",
        }
    else:
        task_type = TaskType.I2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
        kwargs = {
            "prompt": prompt,
            "image": load_image(image),
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "negative_prompt": "Aerial view, low quality, bad hands",
            "cfg_for": False,
        }
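
    # NOTE: the predictor is created only once; the task type and model chosen by the
    # first request are reused for every later request in this process.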
    if _predictor is None:
        _predictor = SkyReelsVideoSingleGpuInfer(
            task_type=task_type,
            model_id=model_id,
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
                compiler_transformer=False,
            ),
        )
        _predictor.initialize()
        logger.info("Predictor initialized")
    with torch.no_grad():
        output = _predictor.infer(**kwargs)
    '''
    output = (output.numpy() * 255).astype(np.uint8)
    # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
    output = output.transpose(0, 2, 3, 4, 1)
    output = output[0]  # Remove batch dimension: (T, H, W, C)
    '''
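    # The conversion block above is left disabled in the source. With the dummy pipeline,
    # `output` is a (1, C, T, H, W) tensor, while diffusers' export_to_video expects
    # per-frame (T, H, W, C) data; presumably the real SkyreelsVideoPipeline already
    # returns frames in a layout export_to_video accepts.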
    save_dir = "./result"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs
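
# Gradio UI: optional image upload (switches to image-to-video), prompt and seed inputs;
# outputs the generated video and the parameters that were used.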
def create_gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload Image", type="filepath")
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            with gr.Column():
                submit_button = gr.Button("Generate Video")
                output_video = gr.Video(label="Generated Video")
                output_params = gr.Textbox(label="Output Parameters")
        submit_button.click(
            fn=generate_video,
            inputs=[prompt, seed, image],
            outputs=[output_video, output_params],
        )
    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue().launch()