"""Gradio app for SkyReels-V1 Hunyuan image-to-video generation.

Denoising is split into 8 segments so that each call fits inside the
@spaces.GPU(duration=90) window: segments 1-8 each run one eighth of the
timestep schedule and save intermediate latents to disk, and segment 9
decodes the final latents into an MP4 with the VAE.
"""
import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
from skyreelsinfer.offload import Offload, OffloadConfig
from skyreelsinfer.pipelines import SkyreelsVideoPipeline
from skyreelsinfer import TaskType
#from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from PIL import Image
import numpy as np
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
import torch

# Force full-precision matmul paths and pick explicit cuBLAS/cuSOLVER backends.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.preferred_blas_library = "cublas"
torch.backends.cuda.preferred_linalg_library = "cusolver"
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep it.

os.putenv("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ["SAFETENSORS_FAST_GPU"] = "1"
os.putenv("TOKENIZERS_PARALLELISM", "False")

model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
base_model_id = "hunyuanvideo-community/HunyuanVideo"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

offload_config = OffloadConfig(
    high_cpu_memory=True,
    parameters_level=True,
    compiler_transformer=False,
)


def init_predictor():
    """Load the text encoder, transformer, and pipeline once at startup."""
    global pipe
    text_encoder = LlamaModel.from_pretrained(
        base_model_id,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    ).to("cpu")
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id,
        # subfolder="transformer",
        torch_dtype=torch.bfloat16,
        # device="cpu",
    ).to("cuda").eval()
    #quantize_(text_encoder, float8_weight_only(), device="cpu")
    #text_encoder.to("cpu")
    #torch.cuda.empty_cache()
    #quantize_(transformer, float8_weight_only(), device="cpu")
    #transformer.to("cuda")
    #torch.cuda.empty_cache()
    pipe = SkyreelsVideoPipeline.from_pretrained(
        base_model_id,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )  # .to("cpu")
    pipe.vae.to("cpu")
    pipe.vae.enable_tiling()
    torch.cuda.empty_cache()


negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"


@spaces.GPU(duration=90)
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True)):
    if segment == 1:
        # Segment 1: pick a fresh seed, encode the prompt and image, prepare latents,
        # and run the first eighth of the denoising schedule.
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
        #Offload.offload(
        #    pipeline=pipe,
        #    config=offload_config,
        #)
        pipe.text_encoder.to("cuda")
        pipe.text_encoder_2.to("cuda")
        with torch.no_grad():
            (
                prompt_embeds,
                prompt_attention_mask,
                negative_prompt_embeds,
                negative_attention_mask,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt=prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
                device=device,
            )
        # The text encoders are only needed for this step; move them back to CPU.
        pipe.text_encoder.to("cpu")
        pipe.text_encoder_2.to("cpu")
        #pipe.transformer.to("cuda")
        torch.cuda.empty_cache()
        generator = torch.Generator(device="cuda").manual_seed(seed)
        transformer_dtype = pipe.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        negative_attention_mask = negative_attention_mask.to(transformer_dtype)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
        # Concatenate negative and positive embeddings for classifier-free guidance.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        # Build the full timestep schedule and keep only the first of 8 segments.
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipe.scheduler.timesteps
        all_timesteps_cpu = timesteps.cpu()
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
        # Half of the transformer's input channels are video latents; the other half are image latents.
        num_channels_latents = pipe.transformer.config.in_channels
        num_channels_latents = int(num_channels_latents / 2)
        image = Image.open(image).convert("RGB")
        image = image.resize((size, size), Image.LANCZOS)
        pipe.vae.to("cuda")
        with torch.no_grad():
            image = pipe.video_processor.preprocess(image, height=size, width=size).to(
                device, dtype=prompt_embeds.dtype
            )
            num_latent_frames = (frames - 1) // pipe.vae_scale_factor_temporal + 1
            latents = pipe.prepare_latents(
                batch_size=1,
                num_channels_latents=num_channels_latents,
                height=size,
                width=size,
                num_frames=frames,
                dtype=torch.float32,
                device=device,
                generator=generator,
                latents=None,
            )
            image_latents = pipe.image_latents(
                image, 1, size, size, device, torch.float32, num_channels_latents, num_latent_frames
            )
            image_latents = image_latents.to("cuda", pipe.transformer.dtype)
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
    else:
        # Segments 2-9: resume from the state file written by the previous segment.
        pipe.vae.to("cpu")
        torch.cuda.empty_cache()
        transformer_dtype = pipe.transformer.dtype
        state_file = f"SkyReel_{segment-1}_{seed}.pt"
        state = torch.load(state_file, weights_only=False)
        generator = torch.Generator(device="cuda").manual_seed(seed)
        latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
        guidance_scale = state["guidance_scale"]
        all_timesteps_cpu = state["all_timesteps"]
        # Height and width were saved from the same square "size" value.
        size = state["width"]
        pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
        prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
        image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
        if segment == 9:
            # Final step: decode the accumulated latents into a video with the VAE.
            pipe.transformer.to("cpu")
            torch.cuda.empty_cache()
            pipe.vae.to("cuda")
            latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
            #with torch.no_grad():
            video = pipe.vae.decode(latents, return_dict=False)[0]
            video = pipe.video_processor.postprocess_video(video)
            # return HunyuanVideoPipelineOutput(frames=video)
            save_dir = f"./"
            video_out_file = f"{save_dir}/{seed}.mp4"
            print(f"generate video, local path: {video_out_file}")
            export_to_video(video[0], video_out_file, fps=24)
            return video_out_file, seed
        else:
            segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

    # Denoise this segment's slice of the timestep schedule.
    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
        latents = latents.to(transformer_dtype)
        latent_model_input = torch.cat([latents] * 2)
        latent_image_input = torch.cat([image_latents] * 2)
        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
        timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
        with torch.no_grad():
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                # attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        # Classifier-free guidance: combine unconditional and conditional predictions.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Persist everything the next segment needs to resume the run.
    intermediate_latents_cpu = latents.detach().cpu()
    original_prompt_embeds_cpu = prompt_embeds.cpu()
    original_image_latents_cpu = image_latents.cpu()
    original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
    original_prompt_attention_mask_cpu = prompt_attention_mask.cpu()
    timesteps = pipe.scheduler.timesteps
    all_timesteps_cpu = timesteps.cpu()  # Move to CPU
    state = {
        "intermediate_latents": intermediate_latents_cpu,
        "all_timesteps": all_timesteps_cpu,  # Save full list generated by scheduler
        "prompt_embeds": original_prompt_embeds_cpu,  # Save ORIGINAL embeds
        "image_latents": original_image_latents_cpu,
        "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
        "prompt_attention_mask": original_prompt_attention_mask_cpu,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "prompt": prompt,  # Save originals for reference/verification
        "negative_prompt": negative_prompt,
        "height": size,  # Save dimensions used
        "width": size,
    }
    state_file = f"SkyReel_{segment}_{seed}.pt"
    torch.save(state, state_file)
    return None, seed


def update_ranges(total_steps):
    """Calculate the starting step of each of the 8 segment sliders."""
    step_size = total_steps // 8  # size of each denoising segment
    return [i * step_size for i in range(8)]


with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(label="Upload Image", type="filepath")
        prompt = gr.Textbox(label="Input Prompt")
        run_button_1 = gr.Button("Run Segment 1", scale=0)
        run_button_2 = gr.Button("Run Segment 2", scale=0)
        run_button_3 = gr.Button("Run Segment 3", scale=0)
        run_button_4 = gr.Button("Run Segment 4", scale=0)
        run_button_5 = gr.Button("Run Segment 5", scale=0)
        run_button_6 = gr.Button("Run Segment 6", scale=0)
        run_button_7 = gr.Button("Run Segment 7", scale=0)
        run_button_8 = gr.Button("Run Segment 8", scale=0)
        run_button_9 = gr.Button("Run Decode Video", scale=0)
    result = gr.Gallery(label="Result", columns=1, show_label=False)
    seed = gr.Number(value=1, label="Seed")
    size = gr.Slider(
        label="Size",
        minimum=256,
        maximum=1024,
        step=16,
        value=368,
    )
    frames = gr.Slider(
        label="Number of Frames",
        minimum=16,
        maximum=256,
        step=8,
        value=64,
    )
    steps = gr.Slider(
        label="Number of Steps",
        minimum=1,
        maximum=96,
        step=1,
        value=25,
    )
    guidance_scale = gr.Slider(
        label="Guidance Scale",
        minimum=1.0,
        maximum=16.0,
        step=0.1,
        value=6.0,
    )
    submit_button = gr.Button("Generate Video")
    output_video = gr.Video(label="Generated Video")
    # Sliders showing where each denoising segment starts; updated when the step count changes.
    range_sliders = []
    for i in range(8):
        slider = gr.Slider(
            minimum=1,
            maximum=250,
            value=i * (steps.value // 8),
            step=1,
            label=f"Range {i + 1}",
        )
        range_sliders.append(slider)
    steps.change(
        update_ranges,
        inputs=steps,
        outputs=range_sliders,
    )
    # Each button calls generate() with its segment index (1-8 denoise, 9 decodes the video).
    run_buttons = [
        run_button_1,
        run_button_2,
        run_button_3,
        run_button_4,
        run_button_5,
        run_button_6,
        run_button_7,
        run_button_8,
        run_button_9,
    ]
    for segment_index, button in enumerate(run_buttons, start=1):
        gr.on(
            triggers=[button.click],
            fn=generate,
            inputs=[
                gr.Number(value=segment_index),
                image,
                prompt,
                size,
                guidance_scale,
                steps,
                frames,
                seed,
            ],
            outputs=[result, seed],
        )

if __name__ == "__main__":
    init_predictor()
    demo.launch()