Spaces: Running on Zero

Update skyreelsinfer/pipelines/pipeline_skyreels_video.py
skyreelsinfer/pipelines/pipeline_skyreels_video.py
CHANGED
@@ -7,12 +7,12 @@ from typing import Union
 
 import numpy as np
 import torch
-from diffusers import
-from diffusers.pipelines.hunyuan_video.
-from diffusers.pipelines.hunyuan_video.
-from diffusers.pipelines.hunyuan_video.
-from diffusers.pipelines.hunyuan_video.
-from diffusers.pipelines.hunyuan_video.
+from diffusers import HunyuanVideoPipeline
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import HunyuanVideoPipelineOutput
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import MultiPipelineCallbacks
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import PipelineCallback
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
 from PIL import Image
 #import gc
 
@@ -46,7 +46,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
-class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
+class SkyreelsVideoPipeline(HunyuanVideoPipeline):
     """
     support i2v and t2v
     support true_cfg
@@ -300,8 +300,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
             device, dtype=prompt_embeds.dtype
         )
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
-        latents
-            image,
+        latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
@@ -316,7 +315,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
         self.text_encoder.to("cpu")
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-
+
         # add image latents
         if image is not None:
             image_latents = self.image_latents(
@@ -326,7 +325,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
             image_latents = image_latents.to(transformer_dtype)
         else:
             image_latents = None
-
+
         # 6. Prepare guidance condition
         if self.do_classifier_free_guidance:
             guidance = (
@@ -361,7 +360,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
                 latent_image_input = (
                     torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
                 )
-                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
+                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
                 timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
                 if cfg_for and self.do_classifier_free_guidance:
                     noise_pred_list = []
@@ -434,4 +433,4 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
         if not return_dict:
             return (video,)
 
-        return HunyuanVideoPipelineOutput(frames=video)
+        return HunyuanVideoPipelineOutput(frames=video)
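In short, the commit rebases SkyreelsVideoPipeline from HunyuanSkyreelsImageToVideoPipeline onto the stock HunyuanVideoPipeline: latents are now prepared without the conditioning image, and the encoded image latents are instead concatenated onto the channel dimension of the latent input at each denoising step (duplicated under classifier-free guidance). A minimal usage sketch of the updated class follows; the checkpoint id, dtype, and the image=/num_frames= keywords are assumptions inferred from the diff and the HunyuanVideoPipeline API, not confirmed by this page:

# Hypothetical sketch, not part of the commit: exercising SkyreelsVideoPipeline
# for both t2v and i2v after the base-class change.
import torch
from PIL import Image

from skyreelsinfer.pipelines.pipeline_skyreels_video import SkyreelsVideoPipeline

# Checkpoint id and dtype are assumptions.
pipe = SkyreelsVideoPipeline.from_pretrained(
    "Skywork/SkyReels-V1-Hunyuan-I2V",
    torch_dtype=torch.bfloat16,
).to("cuda")

# t2v: omit `image`, so image_latents stays None and no channel concat happens.
out = pipe(prompt="FPS-24, a red fox running through snow", num_frames=97)

# i2v: pass `image`; per the diff, its latents are torch.cat'ed onto the
# latent input along dim=1 (doubled when classifier-free guidance is on).
out = pipe(
    prompt="FPS-24, a red fox running through snow",
    image=Image.open("first_frame.png"),
    num_frames=97,
)
video = out.frames  # HunyuanVideoPipelineOutput(frames=video)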