# SkyReels_L / app.py
import spaces
import gradio as gr
import os
import random
from PIL import Image
import numpy as np
# Removed environment-specific lines
from diffusers.utils import export_to_video
from diffusers.utils import load_image
import torch
import logging
from collections import OrderedDict
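# Force full-precision matmul/cuDNN paths; these flags trade speed for
# numerically stable, reproducible outputs.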
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)
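logging.basicConfig(level=logging.INFO)  # Without this, logger.info() output is silently dropped.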
# --- Dummy Classes (Keep for standalone execution) ---
class OffloadConfig:
    def __init__(
        self,
        high_cpu_memory: bool = False,
        parameters_level: bool = False,
        compiler_transformer: bool = False,
        compiler_cache: str = "",
    ):
        self.high_cpu_memory = high_cpu_memory
        self.parameters_level = parameters_level
        self.compiler_transformer = compiler_transformer
        self.compiler_cache = compiler_cache
class TaskType:  # Kept here so infer() can check the task type.
    T2V = 0
    I2V = 1
class LlamaModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return LlamaModel()

    def to(self, device):
        return self


class HunyuanVideoTransformer3DModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return HunyuanVideoTransformer3DModel()

    def to(self, device):
        return self
class SkyreelsVideoPipeline:
    def __init__(self):
        super().__init__()
        self._modules = OrderedDict()
        self.vae = self.VAE()
        self._modules["vae"] = self.vae

    @staticmethod
    def from_pretrained(*args, **kwargs):
        return SkyreelsVideoPipeline()

    def to(self, device):
        return self

    def __call__(self, *args, **kwargs):
        num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
        height = kwargs.get("height", 512)
        width = kwargs.get("width", 512)
        if "image" in kwargs:  # I2V
            image = kwargs["image"]
            # Convert PIL Image to a float tensor normalized to [0, 1].
            image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
            image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
            # Insert a time axis, then repeat the image along it: (1, C, T, H, W).
            frames = image_tensor.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)
            frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
        else:  # T2V
            frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W)
        return type("obj", (object,), {"frames": frames})()

    def named_children(self):
        return self._modules.items()

    class VAE:
        def enable_tiling(self):
            pass


# No-op stand-ins for torchao's quantize_ / float8_weight_only.
def quantize_(*args, **kwargs):
    return


def float8_weight_only():
    return
# --- End Dummy Classes ---
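# The classes above are lightweight stand-ins so this file runs standalone,
# without the real SkyReels/HunyuanVideo weights; an actual deployment would
# import the real LlamaModel, HunyuanVideoTransformer3DModel, and
# SkyreelsVideoPipeline instead.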
class SkyReelsVideoSingleGpuInfer:
    def _load_model(
        self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
    ):
        logger.info(f"load model model_id:{model_id} quant_model:{quant_model}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device="cpu"
        ).to("cpu")
        if quant_model:
            quantize_(text_encoder, float8_weight_only())
            text_encoder.to("cpu")
            torch.cuda.empty_cache()
            quantize_(transformer, float8_weight_only())
            transformer.to("cpu")
            torch.cuda.empty_cache()
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe
    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
        enable_cfg_parallel: bool = True,
    ):
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        self.enable_cfg_parallel = enable_cfg_parallel
        self.pipe = None
        self.is_initialized = False
        self.gpu_device = None
    def initialize(self):
        """Initializes the model and moves it to the GPU."""
        if self.is_initialized:
            return
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. Cannot initialize model.")
        self.gpu_device = "cuda:0"
        self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
        # With offloading enabled, the pipeline stays on CPU; otherwise move it to the GPU.
        if not self.is_offload:
            self.pipe.to(self.gpu_device)
        if self.offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = self.offload_config.compiler_cache
            self.pipe.transformer = torch.compile(
                self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
            )
        # Mark initialized before warm-up, otherwise warm_up()'s guard would always raise.
        self.is_initialized = True
        if self.offload_config.compiler_transformer:
            self.warm_up()
    def warm_up(self):
        if not self.is_initialized:
            raise RuntimeError("Model must be initialized before warm-up.")
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 544,
            "width": 960,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "bad quality",
            "num_frames": 16,
            "generator": torch.Generator(self.gpu_device).manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            # PIL sizes are (width, height), so match the 960x544 request above.
            init_kwargs["image"] = Image.new("RGB", (960, 544), color="black")
        self.pipe(**init_kwargs)
        logger.info("Warm-up complete.")
    def infer(self, **kwargs):
        """Handles inference requests."""
        if not self.is_initialized:
            self.initialize()
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
            del kwargs["seed"]
        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
        return self.pipe(**kwargs).frames  # Return the frames tensor directly.
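
# Module-level singleton: the predictor is created on the first request and
# reused across calls, so the model only loads once per process.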
_predictor = None
@spaces.GPU(duration=90)
def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
    """Generates a video based on the given prompt and seed.

    Args:
        prompt: The text prompt to guide video generation.
        seed: The random seed for reproducibility.
        image: Optional path to an image for Image-to-Video.

    Returns:
        A tuple containing the path to the generated video and the parameters used.
    """
    global _predictor
    if seed == -1:
        random.seed()
        seed = random.randrange(4294967294)
    if image is None:
        task_type = TaskType.T2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
        kwargs = {
            "prompt": prompt,
            "height": 512,
            "width": 512,
            "num_frames": 16,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 7.5,
            "negative_prompt": "bad quality, worst quality",
        }
    else:
        task_type = TaskType.I2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
        kwargs = {
            "prompt": prompt,
            "image": load_image(image),
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "negative_prompt": "Aerial view, low quality, bad hands",
            "cfg_for": False,
        }
    # Recreate the predictor if the requested task type changed (T2V vs. I2V),
    # since each task uses a different checkpoint.
    if _predictor is None or _predictor.task_type != task_type:
        _predictor = SkyReelsVideoSingleGpuInfer(
            task_type=task_type,
            model_id=model_id,
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
                compiler_transformer=False,
            ),
        )
        _predictor.initialize()
        logger.info("Predictor initialized")
    with torch.no_grad():
        output = _predictor.infer(**kwargs)
    # Convert the (1, C, T, H, W) tensor to a (T, H, W, C) float array in [0, 1],
    # which is the layout diffusers' export_to_video expects.
    output = output.detach().float().cpu().numpy()
    output = output.transpose(0, 2, 3, 4, 1)[0]  # (1, C, T, H, W) -> (T, H, W, C)
    output = np.clip(output, 0.0, 1.0)
    save_dir = "./result"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{seed}.mp4"
    logger.info(f"generate video, local path: {video_out_file}")
    export_to_video(output, video_out_file, fps=24)
    return video_out_file, kwargs
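
# Example calls (hypothetical prompt and image path):
#   generate_video("A corgi surfing a wave", seed=42)                    # T2V
#   generate_video("Make the scene move", seed=42, image="./corgi.png")  # I2V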
def create_gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload Image", type="filepath")
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            with gr.Column():
                submit_button = gr.Button("Generate Video")
                output_video = gr.Video(label="Generated Video")
                output_params = gr.Textbox(label="Output Parameters")
        submit_button.click(
            fn=generate_video,
            inputs=[prompt, seed, image],
            outputs=[output_video, output_params],
        )
    return demo
if __name__ == "__main__":
demo = create_gradio_interface()
demo.queue().launch()