import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import Offload, OffloadConfig
from skyreelsinfer.pipelines import SkyreelsVideoPipeline
from skyreelsinfer import TaskType
#from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from PIL import Image
import numpy as np
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
import torch
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.preferred_blas_library = "cublas"
torch.backends.cuda.preferred_linalg_library = "cusolver"
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep it.
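# The flags above pin CUDA matmuls and cuDNN to full-precision paths
# (no TF32, no reduced-precision bf16/fp16 reductions), trading some speed
# for numerically stable, more reproducible outputs across runs.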
os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1") | |
os.environ["SAFETENSORS_FAST_GPU"] = "1" | |
os.putenv("TOKENIZERS_PARALLELISM","False") | |
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
base_model_id = "hunyuanvideo-community/HunyuanVideo"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
offload_config = OffloadConfig(
    high_cpu_memory=True,
    parameters_level=True,
    compiler_transformer=False,
)
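# OffloadConfig is prepared for the (currently commented-out) Offload.offload()
# call in generate(); presumably it would swap transformer parameters between
# CPU and GPU on demand. This app instead moves whole submodules manually with
# .to("cuda") / .to("cpu") around each stage.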
def init_predictor():
    global pipe
    text_encoder = LlamaModel.from_pretrained(
        base_model_id,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    ).to("cpu")
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id,
        # subfolder="transformer",
        torch_dtype=torch.bfloat16,
        # device="cpu",
    ).to("cuda").eval()
    # Optional float8 weight-only quantization (disabled):
    # quantize_(text_encoder, float8_weight_only(), device="cpu")
    # quantize_(transformer, float8_weight_only(), device="cpu")
    pipe = SkyreelsVideoPipeline.from_pretrained(
        base_model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )  # .to("cpu")
    pipe.vae.to("cpu")
    pipe.vae.enable_tiling()
    torch.cuda.empty_cache()
negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
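# Design note: ZeroGPU Spaces cap the duration of a single GPU call, so the
# denoising loop is split into 8 segments driven by separate button clicks.
# Segment 1 encodes the prompt and image and runs the first chunk of
# timesteps; segments 2-8 reload the checkpointed state from disk and
# continue; segment 9 decodes the final latents to an MP4 with the VAE.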
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True)):
    if segment == 1:
        # Fresh run: draw a random seed so each generation is unique.
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
        # Offload.offload(
        #     pipeline=pipe,
        #     config=offload_config,
        # )
        # Encode the prompt with the text encoders on GPU, then move them
        # back to CPU to free VRAM for the transformer.
        pipe.text_encoder.to("cuda")
        pipe.text_encoder_2.to("cuda")
        with torch.no_grad():
            prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
                prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
            )
        pipe.text_encoder.to("cpu")
        pipe.text_encoder_2.to("cpu")
        # pipe.transformer.to("cuda")
        torch.cuda.empty_cache()
        generator = torch.Generator(device="cuda").manual_seed(seed)
        transformer_dtype = pipe.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        negative_attention_mask = negative_attention_mask.to(transformer_dtype)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
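        # Classifier-free guidance: the negative and positive embeddings are
        # stacked along the batch dimension so a single transformer forward
        # pass yields both the unconditional and conditional noise
        # predictions (split again with .chunk(2) in the denoise loop).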
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipe.scheduler.timesteps
        all_timesteps_cpu = timesteps.cpu()
        # Split the full schedule into 8 chunks, one per segment/button.
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
        num_channels_latents = pipe.transformer.config.in_channels
        num_channels_latents = int(num_channels_latents / 2)
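        # The I2V transformer's in_channels counts both the noise latents and
        # the conditioning image latents, which are concatenated along the
        # channel dimension before each forward pass, so each half gets
        # in_channels / 2 channels.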
        image = Image.open(image).convert("RGB")
        image = image.resize((size, size), Image.LANCZOS)  # PIL resize returns a new image; assign it
        pipe.vae.to("cuda")
        with torch.no_grad():
            image = pipe.video_processor.preprocess(image, height=size, width=size).to(
                device, dtype=prompt_embeds.dtype
            )
            num_latent_frames = (frames - 1) // pipe.vae_scale_factor_temporal + 1
            latents = pipe.prepare_latents(
                batch_size=1, num_channels_latents=num_channels_latents, height=size, width=size, num_frames=frames,
                dtype=torch.float32, device=device, generator=generator, latents=None,
            )
            image_latents = pipe.image_latents(
                image, 1, size, size, device, torch.float32, num_channels_latents, num_latent_frames
            )
        image_latents = image_latents.to("cuda", pipe.transformer.dtype)
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
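        # HunyuanVideo is guidance-distilled: the guidance scale is also fed to
        # the transformer as a conditioning embedding, scaled by 1000 as in the
        # diffusers HunyuanVideo pipelines, on top of the explicit CFG mix
        # applied in the denoise loop below.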
    else:
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        transformer_dtype = pipe.transformer.dtype
        # Resume from the checkpoint written by the previous segment.
        state_file = f"SkyReel_{segment-1}_{seed}.pt"
        state = torch.load(state_file, weights_only=False)
        generator = torch.Generator(device="cuda").manual_seed(seed)
        latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
        guidance_scale = state["guidance_scale"]
        all_timesteps_cpu = state["all_timesteps"]
        size = state["height"]  # height == width in this app, so one value suffices
        pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
        image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
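        # Everything needed to continue denoising travels through the .pt
        # checkpoint; the timestep schedule is re-derived deterministically
        # from the saved full list, so each segment picks up exactly where the
        # previous one stopped.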
        if segment == 9:
            # Final segment: free the transformer, then decode latents to video.
            pipe.transformer.to("cpu")
            torch.cuda.empty_cache()
            pipe.vae.to("cuda")
            latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
            with torch.no_grad():
                video = pipe.vae.decode(latents, return_dict=False)[0]
            video = pipe.video_processor.postprocess_video(video)
            # return HunyuanVideoPipelineOutput(frames=video)
            save_dir = "./"
            video_out_file = f"{save_dir}/{seed}.mp4"
            print(f"generate video, local path: {video_out_file}")
            export_to_video(video[0], video_out_file, fps=24)  # fix: was exporting undefined `output`
            return video_out_file, seed
        else:
            segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
    # Denoise this segment's share of the timesteps (segments 1-8).
    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
        latents = latents.to(transformer_dtype)
        # Duplicate latents for the CFG batch and concatenate the image
        # conditioning latents along the channel dimension.
        latent_model_input = torch.cat([latents] * 2)
        latent_image_input = torch.cat([image_latents] * 2)
        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
        timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
        with torch.no_grad():
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                # attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
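    # Hand-off point: the partially denoised latents plus everything needed to
    # resume are checkpointed to disk so the next button click (a fresh GPU
    # allocation on ZeroGPU) can continue the schedule.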
    intermediate_latents_cpu = latents.detach().cpu()
    original_prompt_embeds_cpu = prompt_embeds.cpu()
    original_image_latents_cpu = image_latents.cpu()
    original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
    original_prompt_attention_mask_cpu = prompt_attention_mask.cpu()
    timesteps = pipe.scheduler.timesteps
    all_timesteps_cpu = timesteps.cpu()  # Move to CPU
    state = {
        "intermediate_latents": intermediate_latents_cpu,
        "all_timesteps": all_timesteps_cpu,  # Full list generated by the scheduler
        "prompt_embeds": original_prompt_embeds_cpu,  # Original (CFG-stacked) embeds
        "image_latents": original_image_latents_cpu,
        "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
        "prompt_attention_mask": original_prompt_attention_mask_cpu,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "prompt": prompt,  # Originals kept for reference/verification
        "negative_prompt": negative_prompt,
        "height": size,  # Dimensions used (height == width)
        "width": size,
    }
    state_file = f"SkyReel_{segment}_{seed}.pt"
    torch.save(state, state_file)
    return None, seed
def update_ranges(total_steps):
    """Recompute the starting step shown on each of the 8 segment sliders."""
    step_size = total_steps // 8  # Size of each segment
    # Return one scalar per slider; the original appended [lower_bound] lists,
    # which gr.Slider outputs cannot accept as values.
    return [i * step_size for i in range(8)]
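# For example, update_ranges(24) returns [0, 3, 6, 9, 12, 15, 18, 21]:
# with 24 steps, each of the 8 segments starts 3 steps after the previous one.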
with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(label="Upload Image", type="filepath")
        prompt = gr.Textbox(label="Input Prompt")
        run_button_1 = gr.Button("Run Segment 1", scale=0)
        run_button_2 = gr.Button("Run Segment 2", scale=0)
        run_button_3 = gr.Button("Run Segment 3", scale=0)
        run_button_4 = gr.Button("Run Segment 4", scale=0)
        run_button_5 = gr.Button("Run Segment 5", scale=0)
        run_button_6 = gr.Button("Run Segment 6", scale=0)
        run_button_7 = gr.Button("Run Segment 7", scale=0)
        run_button_8 = gr.Button("Run Segment 8", scale=0)
        run_button_9 = gr.Button("Run Decode Video", scale=0)
        result = gr.Gallery(label="Result", columns=1, show_label=False)
        seed = gr.Number(value=1, label="Seed")
        size = gr.Slider(
            label="Size",
            minimum=256,
            maximum=1024,
            step=16,
            value=368,
        )
        frames = gr.Slider(
            label="Number of Frames",
            minimum=16,
            maximum=256,
            step=8,
            value=64,
        )
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=96,
            step=1,
            value=25,
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=1.0,
            maximum=16.0,
            step=0.1,
            value=6.0,
        )
    submit_button = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    range_sliders = []
    for i in range(8):
        slider = gr.Slider(
            minimum=1,
            maximum=250,
            value=i * (steps.value // 8),  # Scalar value, matching update_ranges output
            step=1,
            label=f"Range {i + 1}",
        )
        range_sliders.append(slider)
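    # Whenever the step count changes, update_ranges fans new starting values
    # out to all 8 range sliders (one output component per returned value).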
    steps.change(
        update_ranges,
        inputs=steps,
        outputs=range_sliders,
    )
    # The nine buttons share one handler; a constant gr.Number carries the
    # segment index, exactly as the original per-button wiring did.
    run_buttons = [
        run_button_1, run_button_2, run_button_3, run_button_4, run_button_5,
        run_button_6, run_button_7, run_button_8, run_button_9,
    ]
    for segment_number, run_button in enumerate(run_buttons, start=1):
        run_button.click(
            fn=generate,
            inputs=[
                gr.Number(value=segment_number),
                image,
                prompt,
                size,
                guidance_scale,
                steps,
                frames,
                seed,
            ],
            outputs=[result, seed],
        )
if __name__ == "__main__":
    init_predictor()
    demo.launch()