# SkyReels_L / app.py
import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import Offload, OffloadConfig
from skyreelsinfer.pipelines import SkyreelsVideoPipeline
from skyreelsinfer import TaskType
#from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from PIL import Image
import numpy as np
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
import torch
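# Prefer full-precision math: disable TF32 and reduced-precision reductions in cuBLAS/cuDNN.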
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
# Select cuBLAS and cuSOLVER as the preferred backends.
torch.backends.cuda.preferred_blas_library(backend="cublas")
torch.backends.cuda.preferred_linalg_library(backend="cusolver")
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.enable_cudnn_sdp(False) # Still a good idea to keep it.
os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1")
os.environ["SAFETENSORS_FAST_GPU"] = "1"
os.putenv("TOKENIZERS_PARALLELISM","False")
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
base_model_id = "hunyuanvideo-community/HunyuanVideo"
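# The transformer comes from the SkyReels I2V checkpoint; the remaining components
# (text encoders, VAE, scheduler) are loaded from the base HunyuanVideo repository.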
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
offload_config=OffloadConfig(
high_cpu_memory=True,
parameters_level=True,
compiler_transformer=False,
)
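# NOTE: offload_config is defined but currently unused; the Offload.offload(...) call in
# generate() is commented out, and components are moved between CPU and GPU manually instead.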
def init_predictor():
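    """Load the text encoder (CPU), the transformer (GPU) and the assembled pipeline once at startup."""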
global pipe
text_encoder = LlamaModel.from_pretrained(
base_model_id,
subfolder="text_encoder",
torch_dtype=torch.bfloat16,
).to("cpu")
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id,
# subfolder="transformer",
torch_dtype=torch.bfloat16,
#device="cpu",
).to("cuda").eval()
#quantize_(text_encoder, float8_weight_only(), device="cpu")
#text_encoder.to("cpu")
#torch.cuda.empty_cache()
#quantize_(transformer, float8_weight_only(), device="cpu")
#transformer.to("cuda")
#torch.cuda.empty_cache()
pipe = SkyreelsVideoPipeline.from_pretrained(
base_model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16,
) #.to("cpu")
pipe.vae.to('cpu')
pipe.vae.enable_tiling()
torch.cuda.empty_cache()
negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@spaces.GPU(duration=90)
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
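    """Run one segment of the denoising schedule.

    The schedule is split into 8 chunks so each ZeroGPU call stays within its 90 s budget:
    segment 1 encodes the prompt and image and runs the first chunk, segments 2-8 resume
    from the state saved on disk, and segment 9 decodes the finished latents to an MP4.
    """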
if segment==1:
random.seed(time.time())
seed = int(random.randrange(4294967294))
#Offload.offload(
# pipeline=pipe,
# config=offload_config,
#)
pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
with torch.no_grad():
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
)
pipe.text_encoder.to("cpu")
pipe.text_encoder_2.to("cpu")
        #pipe.transformer.to('cuda')
torch.cuda.empty_cache()
generator = torch.Generator(device='cuda').manual_seed(seed)
transformer_dtype = pipe.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
negative_attention_mask = negative_attention_mask.to(transformer_dtype)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
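        # Split the full schedule into 8 roughly equal chunks; segment 1 runs the first chunk.
        # e.g. 25 steps -> chunk sizes [4, 3, 3, 3, 3, 3, 3, 3].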
all_timesteps_cpu = timesteps.cpu()
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
        # The I2V transformer's in_channels covers both the video latents and the
        # concatenated image-condition latents, so the latent channel count is half of it.
        num_channels_latents = pipe.transformer.config.in_channels // 2
        image = Image.open(image).convert('RGB')
        image = image.resize((size, size), Image.LANCZOS)  # PIL resize returns a new image
pipe.vae.to("cuda")
with torch.no_grad():
image = pipe.video_processor.preprocess(image, height=size, width=size).to(
device, dtype=prompt_embeds.dtype
)
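        # The VAE compresses time by pipe.vae_scale_factor_temporal (4 for HunyuanVideo),
        # so e.g. 64 input frames map to (64 - 1) // 4 + 1 = 16 latent frames.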
num_latent_frames = (frames - 1) // pipe.vae_scale_factor_temporal + 1
latents = pipe.prepare_latents(
batch_size=1, num_channels_latents=num_channels_latents, height=size, width=size, num_frames=frames,
dtype=torch.float32, device=device, generator=generator, latents=None,
)
image_latents = pipe.image_latents(
image, 1, size, size, device, torch.float32, num_channels_latents, num_latent_frames
)
image_latents = image_latents.to("cuda", pipe.transformer.dtype)
pipe.vae.to("cpu")
torch.cuda.empty_cache()
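        # The guidance-distilled Hunyuan transformer takes the CFG scale as an embedded input, scaled by 1000.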
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
else:
pipe.vae.to("cpu")
torch.cuda.empty_cache()
transformer_dtype = pipe.transformer.dtype
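        # Segments 2-9 resume from the intermediate state saved to disk by the previous segment.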
state_file = f"SkyReel_{segment-1}_{seed}.pt"
state = torch.load(state_file, weights_only=False)
generator = torch.Generator(device='cuda').manual_seed(seed)
latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
guidance_scale = state["guidance_scale"]
all_timesteps_cpu = state["all_timesteps"]
size = state["height"]
size = state["width"]
pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
        if segment == 9:
            # Final segment: free the transformer, move the VAE to the GPU, and decode the finished latents.
            pipe.transformer.to('cpu')
            torch.cuda.empty_cache()
            pipe.vae.to("cuda")
            latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
            with torch.no_grad():
                video = pipe.vae.decode(latents, return_dict=False)[0]
                video = pipe.video_processor.postprocess_video(video)
            # return HunyuanVideoPipelineOutput(frames=video)
            save_dir = "."
            video_out_file = f"{save_dir}/{seed}.mp4"
            print(f"generated video, local path: {video_out_file}")
            export_to_video(video[0], video_out_file, fps=24)
            return video_out_file, seed
else:
segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
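    # Denoising loop for this segment: duplicate the latents for the unconditional/conditional
    # branches, concatenate the image-condition latents along the channel dim, and combine the
    # two noise predictions with classifier-free guidance.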
for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
latents = latents.to(transformer_dtype)
latent_model_input = torch.cat([latents] * 2)
latent_image_input = (torch.cat([image_latents] * 2))
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
with torch.no_grad():
noise_pred = pipe.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
# attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
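    # Persist everything the next segment needs so denoising can resume in a fresh GPU session.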
intermediate_latents_cpu = latents.detach().cpu()
original_prompt_embeds_cpu = prompt_embeds.cpu()
original_image_latents_cpu = image_latents.cpu()
original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
original_prompt_attention_mask_cpu = prompt_attention_mask.cpu()
timesteps = pipe.scheduler.timesteps
all_timesteps_cpu = timesteps.cpu() # Move to CPU
state = {
"intermediate_latents": intermediate_latents_cpu,
"all_timesteps": all_timesteps_cpu, # Save full list generated by scheduler
"prompt_embeds": original_prompt_embeds_cpu, # Save ORIGINAL embeds
"image_latents": original_image_latents_cpu,
"pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
"prompt_attention_mask": original_prompt_attention_mask_cpu,
"guidance_scale": guidance_scale,
"seed": seed,
"prompt": prompt, # Save originals for reference/verification
"negative_prompt": negative_prompt,
"height": size, # Save dimensions used
"width": size
}
state_file = f"SkyReel_{segment}_{seed}.pt"
torch.save(state, state_file)
return None, seed
def update_ranges(total_steps):
    """Compute the starting step index for each of the 8 segment sliders."""
    step_size = total_steps // 8  # number of denoising steps per segment
    ranges = []
    for i in range(8):
        lower_bound = i * step_size
        ranges.append(lower_bound)  # one value per slider output
    return ranges
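# Example: with 25 total steps, step_size = 25 // 8 = 3, giving starting offsets 0, 3, 6, ..., 21.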
with gr.Blocks() as demo:
with gr.Row():
image = gr.Image(label="Upload Image", type="filepath")
prompt = gr.Textbox(label="Input Prompt")
run_button_1 = gr.Button("Run Segment 1", scale=0)
run_button_2 = gr.Button("Run Segment 2", scale=0)
run_button_3 = gr.Button("Run Segment 3", scale=0)
run_button_4 = gr.Button("Run Segment 4", scale=0)
run_button_5 = gr.Button("Run Segment 5", scale=0)
run_button_6 = gr.Button("Run Segment 6", scale=0)
run_button_7 = gr.Button("Run Segment 7", scale=0)
run_button_8 = gr.Button("Run Segment 8", scale=0)
run_button_9 = gr.Button("Run Decode Video", scale=0)
result = gr.Gallery(label="Result", columns=1, show_label=False)
seed = gr.Number(value=1, label="Seed")
size = gr.Slider(
label="Size",
minimum=256,
maximum=1024,
step=16,
value=368,
)
frames = gr.Slider(
label="Number of Frames",
minimum=16,
maximum=256,
step=8,
value=64,
)
steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=96,
step=1,
value=25,
)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1.0,
maximum=16.0,
step=.1,
value=6.0,
)
submit_button = gr.Button("Generate Video")
output_video = gr.Video(label="Generated Video")
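    # NOTE: submit_button and output_video are not currently wired to an event handler;
    # generation is driven by the per-segment buttons above.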
    range_sliders = []
    for i in range(8):
        slider = gr.Slider(
            minimum=0,
            maximum=250,
            value=i * (steps.value // 8),
            step=1,
            label=f"Range {i + 1}",
        )
        range_sliders.append(slider)
steps.change(
update_ranges,
inputs=steps,
outputs=range_sliders,
)
    # Wire each segment button to generate(); the leading gr.Number input carries the
    # segment index (1-8 run denoising chunks, 9 decodes the video).
    run_buttons = [
        run_button_1, run_button_2, run_button_3, run_button_4,
        run_button_5, run_button_6, run_button_7, run_button_8,
        run_button_9,
    ]
    for segment_index, run_button in enumerate(run_buttons, start=1):
        gr.on(
            triggers=[run_button.click],
            fn=generate,
            inputs=[
                gr.Number(value=segment_index),
                image,
                prompt,
                size,
                guidance_scale,
                steps,
                frames,
                seed,
            ],
            outputs=[result, seed],
        )
if __name__ == "__main__":
init_predictor()
demo.launch()