1inkusFace committed on
Commit ed78ea9 · verified · 1 Parent(s): 2ddb684

Update skyreelsinfer/pipelines/pipeline_skyreels_video.py

skyreelsinfer/pipelines/pipeline_skyreels_video.py CHANGED

@@ -7,12 +7,12 @@ from typing import Union
 
 import numpy as np
 import torch
-from diffusers import HunyuanSkyreelsImageToVideoPipeline
-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_skyreels_image2video import DEFAULT_PROMPT_TEMPLATE
-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_skyreels_image2video import HunyuanVideoPipelineOutput
-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_skyreels_image2video import MultiPipelineCallbacks
-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_skyreels_image2video import PipelineCallback
-from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_skyreels_image2video import retrieve_timesteps
+from diffusers import HunyuanVideoPipeline
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import HunyuanVideoPipelineOutput
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import MultiPipelineCallbacks
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import PipelineCallback
+from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
 from PIL import Image
 #import gc
 
@@ -46,7 +46,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
-class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
+class SkyreelsVideoPipeline(HunyuanVideoPipeline):
     """
     support i2v and t2v
     support true_cfg
@@ -300,8 +300,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
             device, dtype=prompt_embeds.dtype
         )
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
-        latents, image_latents = self.prepare_latents(
-            image,
+        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
@@ -316,7 +315,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
         self.text_encoder.to("cpu")
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        '''
+
         # add image latents
         if image is not None:
             image_latents = self.image_latents(
@@ -326,7 +325,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
             image_latents = image_latents.to(transformer_dtype)
         else:
             image_latents = None
-        '''
+
         # 6. Prepare guidance condition
         if self.do_classifier_free_guidance:
             guidance = (
@@ -361,7 +360,7 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
                 latent_image_input = (
                     torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
                 )
-                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1).to('cuda',torch.bfloat16)
+                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
                 timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
                 if cfg_for and self.do_classifier_free_guidance:
                     noise_pred_list = []
@@ -434,4 +433,4 @@ class SkyreelsVideoPipeline(HunyuanSkyreelsImageToVideoPipeline):
         if not return_dict:
             return (video,)
 
-        return HunyuanVideoPipelineOutput(frames=video)
+        return HunyuanVideoPipelineOutput(frames=video)
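After this change, SkyreelsVideoPipeline subclasses the stock text-to-video HunyuanVideoPipeline and handles the optional image condition itself instead of relying on the image-to-video variant. A minimal usage sketch, assuming a HunyuanVideo-format checkpoint (the checkpoint id below is illustrative, not part of this commit):

import torch
from skyreelsinfer.pipelines.pipeline_skyreels_video import SkyreelsVideoPipeline

# Checkpoint id is an assumption for illustration; any weights loadable by
# HunyuanVideoPipeline.from_pretrained should work the same way.
pipe = SkyreelsVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Text-to-video call; per the class docstring, i2v is also supported by
# additionally passing a PIL image to the pipeline.
output = pipe(prompt="a red fox running through fresh snow", num_frames=97)
frames = output.frames  # HunyuanVideoPipelineOutput.frames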
 
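The restored "add image latents" block calls a self.image_latents(...) helper that is not shown in this diff. A hedged sketch of what such a helper typically does in HunyuanVideo-style i2v code, assuming the standard diffusers VAE and video-processor APIs (the function name and preprocessing details are assumptions, not this repo's implementation):

import torch
from PIL import Image

def image_latents_sketch(pipe, image: Image.Image, height: int, width: int) -> torch.Tensor:
    # Assumed helper, not the repo's actual code. Preprocess to [-1, 1],
    # shape [1, 3, H, W], then add a frame axis for the 3D VAE.
    pixels = pipe.video_processor.preprocess(image, height=height, width=width)
    pixels = pixels.unsqueeze(2).to(pipe.vae.device, pipe.vae.dtype)  # [1, 3, 1, H, W]
    # Encode with the causal video VAE and apply its scaling factor.
    latents = pipe.vae.encode(pixels).latent_dist.sample()
    return latents * pipe.vae.config.scaling_factor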
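In the denoising loop, the image condition is concatenated on the channel axis (dim=1) after the CFG batch duplication; the commit also drops the explicit .to('cuda', torch.bfloat16) cast there, leaving device and dtype handling to the surrounding pipeline. A shape-only sketch under assumed HunyuanVideo latent conventions ([B, C, F, H/8, W/8] latents, 16 channels, temporal VAE stride 4):

import torch

# Assumed latent geometry for a 97-frame, 544x960 request; the stride of 4
# matches num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1.
batch, channels = 1, 16
num_latent_frames = (97 - 1) // 4 + 1  # 25
latents = torch.randn(batch, channels, num_latent_frames, 68, 120)
image_latents = torch.randn_like(latents)

# CFG duplicates the batch; the image condition doubles the channel count,
# so the transformer sees 2 * channels input channels.
latent_model_input = torch.cat([latents] * 2)
latent_image_input = torch.cat([image_latents] * 2)
model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
assert model_input.shape == (2 * batch, 2 * channels, num_latent_frames, 68, 120)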