1inkusFace committed
Commit 526cdb6 · verified · 1 Parent(s): 5f4f000

Update skyreelsinfer/pipelines/pipeline_skyreels_video.py
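Beyond the one-line message, the diff adds explicit device management around the denoising loop: the text encoder is moved to the GPU before prompt encoding, the text encoder and VAE are offloaded to the CPU (and the CUDA cache is cleared) before the transformer runs, and the VAE is moved back to the GPU for decoding. The helper below is a minimal, generic sketch of that offload pattern; the name move_and_free is illustrative and does not appear in the committed file.

import torch
from torch import nn

def move_and_free(module: nn.Module, device: str) -> None:
    # Move a pipeline submodule; when offloading to CPU, also release
    # cached GPU memory so the transformer has room during denoising.
    module.to(device)
    if device == "cpu" and torch.cuda.is_available():
        torch.cuda.empty_cache()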

skyreelsinfer/pipelines/pipeline_skyreels_video.py CHANGED
@@ -1,425 +1,430 @@
1
- from typing import Any
2
- from typing import Callable
3
- from typing import Dict
4
- from typing import List
5
- from typing import Optional
6
- from typing import Union
7
-
8
- import numpy as np
9
- import torch
10
- from diffusers import HunyuanVideoPipeline
11
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
12
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import HunyuanVideoPipelineOutput
13
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import MultiPipelineCallbacks
14
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import PipelineCallback
15
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
16
- from PIL import Image
17
-
18
-
19
- def resizecrop(image, th, tw):
20
- w, h = image.size
21
- if h / w > th / tw:
22
- new_w = int(w)
23
- new_h = int(new_w * th / tw)
24
- else:
25
- new_h = int(h)
26
- new_w = int(new_h * tw / th)
27
- left = (w - new_w) / 2
28
- top = (h - new_h) / 2
29
- right = (w + new_w) / 2
30
- bottom = (h + new_h) / 2
31
- image = image.crop((left, top, right, bottom))
32
- return image
33
-
34
-
35
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
36
- """
37
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
38
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
39
- """
40
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
41
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
42
- # rescale the results from guidance (fixes overexposure)
43
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
44
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
45
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
46
- return noise_cfg
47
-
48
-
49
- class SkyreelsVideoPipeline(HunyuanVideoPipeline):
50
- """
51
- support i2v and t2v
52
- support true_cfg
53
- """
54
-
55
- @property
56
- def guidance_rescale(self):
57
- return self._guidance_rescale
58
-
59
- @property
60
- def clip_skip(self):
61
- return self._clip_skip
62
-
63
- @property
64
- def do_classifier_free_guidance(self):
65
- # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
66
- return self._guidance_scale > 1
67
-
68
- def encode_prompt(
69
- self,
70
- prompt: Union[str, List[str]],
71
- do_classifier_free_guidance: bool,
72
- negative_prompt: str = "",
73
- prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
74
- num_videos_per_prompt: int = 1,
75
- prompt_embeds: Optional[torch.Tensor] = None,
76
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
77
- prompt_attention_mask: Optional[torch.Tensor] = None,
78
- negative_prompt_embeds: Optional[torch.Tensor] = None,
79
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
80
- negative_attention_mask: Optional[torch.Tensor] = None,
81
- device: Optional[torch.device] = None,
82
- dtype: Optional[torch.dtype] = None,
83
- max_sequence_length: int = 256,
84
- ):
85
- num_hidden_layers_to_skip = self.clip_skip if self.clip_skip is not None else 0
86
- print(f"num_hidden_layers_to_skip: {num_hidden_layers_to_skip}")
87
- if prompt_embeds is None:
88
- prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
89
- prompt,
90
- prompt_template,
91
- num_videos_per_prompt,
92
- device=device,
93
- dtype=dtype,
94
- num_hidden_layers_to_skip=num_hidden_layers_to_skip,
95
- max_sequence_length=max_sequence_length,
96
- )
97
- if negative_prompt_embeds is None and do_classifier_free_guidance:
98
- negative_prompt_embeds, negative_attention_mask = self._get_llama_prompt_embeds(
99
- negative_prompt,
100
- prompt_template,
101
- num_videos_per_prompt,
102
- device=device,
103
- dtype=dtype,
104
- num_hidden_layers_to_skip=num_hidden_layers_to_skip,
105
- max_sequence_length=max_sequence_length,
106
- )
107
- if self.text_encoder_2 is not None and pooled_prompt_embeds is None:
108
- pooled_prompt_embeds = self._get_clip_prompt_embeds(
109
- prompt,
110
- num_videos_per_prompt,
111
- device=device,
112
- dtype=dtype,
113
- max_sequence_length=77,
114
- )
115
- if negative_pooled_prompt_embeds is None and do_classifier_free_guidance:
116
- negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
117
- negative_prompt,
118
- num_videos_per_prompt,
119
- device=device,
120
- dtype=dtype,
121
- max_sequence_length=77,
122
- )
123
- return (
124
- prompt_embeds,
125
- prompt_attention_mask,
126
- negative_prompt_embeds,
127
- negative_attention_mask,
128
- pooled_prompt_embeds,
129
- negative_pooled_prompt_embeds,
130
- )
131
-
132
- def image_latents(
133
- self,
134
- initial_image,
135
- batch_size,
136
- height,
137
- width,
138
- device,
139
- dtype,
140
- num_channels_latents,
141
- video_length,
142
- ):
143
- initial_image = initial_image.unsqueeze(2)
144
- image_latents = self.vae.encode(initial_image).latent_dist.sample()
145
- if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor:
146
- image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
147
- else:
148
- image_latents = image_latents * self.vae.config.scaling_factor
149
- padding_shape = (
150
- batch_size,
151
- num_channels_latents,
152
- video_length - 1,
153
- int(height) // self.vae_scale_factor_spatial,
154
- int(width) // self.vae_scale_factor_spatial,
155
- )
156
- latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
157
- image_latents = torch.cat([image_latents, latent_padding], dim=2)
158
- return image_latents
159
-
160
- @torch.no_grad()
161
- def __call__(
162
- self,
163
- prompt: str,
164
- negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
165
- height: int = 720,
166
- width: int = 1280,
167
- num_frames: int = 129,
168
- num_inference_steps: int = 50,
169
- sigmas: List[float] = None,
170
- guidance_scale: float = 1.0,
171
- num_videos_per_prompt: Optional[int] = 1,
172
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
173
- latents: Optional[torch.Tensor] = None,
174
- prompt_embeds: Optional[torch.Tensor] = None,
175
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
176
- prompt_attention_mask: Optional[torch.Tensor] = None,
177
- negative_prompt_embeds: Optional[torch.Tensor] = None,
178
- negative_attention_mask: Optional[torch.Tensor] = None,
179
- output_type: Optional[str] = "pil",
180
- return_dict: bool = True,
181
- attention_kwargs: Optional[Dict[str, Any]] = None,
182
- guidance_rescale: float = 0.0,
183
- clip_skip: Optional[int] = 2,
184
- callback_on_step_end: Optional[
185
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
186
- ] = None,
187
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
188
- prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
189
- max_sequence_length: int = 256,
190
- embedded_guidance_scale: Optional[float] = 6.0,
191
- image: Optional[Union[torch.Tensor, Image.Image]] = None,
192
- cfg_for: bool = False,
193
- ):
194
- if hasattr(self, "text_encoder_to_gpu"):
195
- self.text_encoder_to_gpu()
196
-
197
- if image is not None and isinstance(image, Image.Image):
198
- image = resizecrop(image, height, width)
199
-
200
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
201
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
202
-
203
- # 1. Check inputs. Raise error if not correct
204
- self.check_inputs(
205
- prompt,
206
- None,
207
- height,
208
- width,
209
- prompt_embeds,
210
- callback_on_step_end_tensor_inputs,
211
- prompt_template,
212
- )
213
- # add negative prompt check
214
- if negative_prompt is not None and negative_prompt_embeds is not None:
215
- raise ValueError(
216
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
217
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
218
- )
219
-
220
- if prompt_embeds is not None and negative_prompt_embeds is not None:
221
- if prompt_embeds.shape != negative_prompt_embeds.shape:
222
- raise ValueError(
223
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
224
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
225
- f" {negative_prompt_embeds.shape}."
226
- )
227
-
228
- self._guidance_scale = guidance_scale
229
- self._guidance_rescale = guidance_rescale
230
- self._clip_skip = clip_skip
231
- self._attention_kwargs = attention_kwargs
232
- self._interrupt = False
233
-
234
- device = self._execution_device
235
-
236
- # 2. Define call parameters
237
- if prompt is not None and isinstance(prompt, str):
238
- batch_size = 1
239
- elif prompt is not None and isinstance(prompt, list):
240
- batch_size = len(prompt)
241
- else:
242
- batch_size = prompt_embeds.shape[0]
243
-
244
- # 3. Encode input prompt
245
- (
246
- prompt_embeds,
247
- prompt_attention_mask,
248
- negative_prompt_embeds,
249
- negative_attention_mask,
250
- pooled_prompt_embeds,
251
- negative_pooled_prompt_embeds,
252
- ) = self.encode_prompt(
253
- prompt=prompt,
254
- do_classifier_free_guidance=self.do_classifier_free_guidance,
255
- negative_prompt=negative_prompt,
256
- prompt_template=prompt_template,
257
- num_videos_per_prompt=num_videos_per_prompt,
258
- prompt_embeds=prompt_embeds,
259
- prompt_attention_mask=prompt_attention_mask,
260
- negative_prompt_embeds=negative_prompt_embeds,
261
- negative_attention_mask=negative_attention_mask,
262
- device=device,
263
- max_sequence_length=max_sequence_length,
264
- )
265
-
266
- transformer_dtype = self.transformer.dtype
267
- prompt_embeds = prompt_embeds.to(transformer_dtype)
268
- prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
269
- if pooled_prompt_embeds is not None:
270
- pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
271
-
272
- ## Embeddings are concatenated to form a batch.
273
- if self.do_classifier_free_guidance:
274
- negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
275
- negative_attention_mask = negative_attention_mask.to(transformer_dtype)
276
- if negative_pooled_prompt_embeds is not None:
277
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
278
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
279
- if prompt_attention_mask is not None:
280
- prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
281
- if pooled_prompt_embeds is not None:
282
- pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
283
-
284
- # 4. Prepare timesteps
285
- sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
286
- timesteps, num_inference_steps = retrieve_timesteps(
287
- self.scheduler,
288
- num_inference_steps,
289
- device,
290
- sigmas=sigmas,
291
- )
292
-
293
- # 5. Prepare latent variables
294
- num_channels_latents = self.transformer.config.in_channels
295
- if image is not None:
296
- num_channels_latents = int(num_channels_latents / 2)
297
- image = self.video_processor.preprocess(image, height=height, width=width).to(
298
- device, dtype=prompt_embeds.dtype
299
- )
300
- num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
301
- latents = self.prepare_latents(
302
- batch_size * num_videos_per_prompt,
303
- num_channels_latents,
304
- height,
305
- width,
306
- num_latent_frames,
307
- torch.float32,
308
- device,
309
- generator,
310
- latents,
311
- )
312
- # add image latents
313
- if image is not None:
314
- image_latents = self.image_latents(
315
- image, batch_size, height, width, device, torch.float32, num_channels_latents, num_latent_frames
316
- )
317
-
318
- image_latents = image_latents.to(transformer_dtype)
319
- else:
320
- image_latents = None
321
-
322
- # 6. Prepare guidance condition
323
- if self.do_classifier_free_guidance:
324
- guidance = (
325
- torch.tensor([embedded_guidance_scale] * latents.shape[0] * 2, dtype=transformer_dtype, device=device)
326
- * 1000.0
327
- )
328
- else:
329
- guidance = (
330
- torch.tensor([embedded_guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device)
331
- * 1000.0
332
- )
333
-
334
- # 7. Denoising loop
335
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
336
- self._num_timesteps = len(timesteps)
337
-
338
- if hasattr(self, "text_encoder_to_cpu"):
339
- self.text_encoder_to_cpu()
340
-
341
- with self.progress_bar(total=num_inference_steps) as progress_bar:
342
- for i, t in enumerate(timesteps):
343
- if self.interrupt:
344
- continue
345
-
346
- latents = latents.to(transformer_dtype)
347
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
348
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
349
- # timestep = t.expand(latents.shape[0]).to(latents.dtype)
350
- if image_latents is not None:
351
- latent_image_input = (
352
- torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
353
- )
354
- latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
355
- timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
356
- if cfg_for and self.do_classifier_free_guidance:
357
- noise_pred_list = []
358
- for idx in range(latent_model_input.shape[0]):
359
- noise_pred_uncond = self.transformer(
360
- hidden_states=latent_model_input[idx].unsqueeze(0),
361
- timestep=timestep[idx].unsqueeze(0),
362
- encoder_hidden_states=prompt_embeds[idx].unsqueeze(0),
363
- encoder_attention_mask=prompt_attention_mask[idx].unsqueeze(0),
364
- pooled_projections=pooled_prompt_embeds[idx].unsqueeze(0),
365
- guidance=guidance[idx].unsqueeze(0),
366
- attention_kwargs=attention_kwargs,
367
- return_dict=False,
368
- )[0]
369
- noise_pred_list.append(noise_pred_uncond)
370
- noise_pred = torch.cat(noise_pred_list, dim=0)
371
- else:
372
- noise_pred = self.transformer(
373
- hidden_states=latent_model_input,
374
- timestep=timestep,
375
- encoder_hidden_states=prompt_embeds,
376
- encoder_attention_mask=prompt_attention_mask,
377
- pooled_projections=pooled_prompt_embeds,
378
- guidance=guidance,
379
- attention_kwargs=attention_kwargs,
380
- return_dict=False,
381
- )[0]
382
-
383
- # perform guidance
384
- if self.do_classifier_free_guidance:
385
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
386
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
387
-
388
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
389
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
390
- noise_pred = rescale_noise_cfg(
391
- noise_pred,
392
- noise_pred_text,
393
- guidance_rescale=self.guidance_rescale,
394
- )
395
-
396
- # compute the previous noisy sample x_t -> x_t-1
397
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
398
-
399
- if callback_on_step_end is not None:
400
- callback_kwargs = {}
401
- for k in callback_on_step_end_tensor_inputs:
402
- callback_kwargs[k] = locals()[k]
403
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
404
-
405
- latents = callback_outputs.pop("latents", latents)
406
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
407
-
408
- # call the callback, if provided
409
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
410
- progress_bar.update()
411
-
412
- if not output_type == "latent":
413
- latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
414
- video = self.vae.decode(latents, return_dict=False)[0]
415
- video = self.video_processor.postprocess_video(video, output_type=output_type)
416
- else:
417
- video = latents
418
-
419
- # Offload all models
420
- self.maybe_free_model_hooks()
421
-
422
- if not return_dict:
423
- return (video,)
424
-
425
- return HunyuanVideoPipelineOutput(frames=video)
 
 
 
 
 
 
1
+ from typing import Any
2
+ from typing import Callable
3
+ from typing import Dict
4
+ from typing import List
5
+ from typing import Optional
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers import HunyuanVideoPipeline
11
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
12
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import HunyuanVideoPipelineOutput
13
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import MultiPipelineCallbacks
14
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import PipelineCallback
15
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps
16
+ from PIL import Image
17
+
18
+
19
+ def resizecrop(image, th, tw):
20
+ w, h = image.size
21
+ if h / w > th / tw:
22
+ new_w = int(w)
23
+ new_h = int(new_w * th / tw)
24
+ else:
25
+ new_h = int(h)
26
+ new_w = int(new_h * tw / th)
27
+ left = (w - new_w) / 2
28
+ top = (h - new_h) / 2
29
+ right = (w + new_w) / 2
30
+ bottom = (h + new_h) / 2
31
+ image = image.crop((left, top, right, bottom))
32
+ return image
33
+
34
+
35
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
36
+ """
37
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
38
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
39
+ """
40
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
41
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
42
+ # rescale the results from guidance (fixes overexposure)
43
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
44
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
45
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
46
+ return noise_cfg
47
+
48
+
49
+ class SkyreelsVideoPipeline(HunyuanVideoPipeline):
50
+ """
51
+ support i2v and t2v
52
+ support true_cfg
53
+ """
54
+
55
+ @property
56
+ def guidance_rescale(self):
57
+ return self._guidance_rescale
58
+
59
+ @property
60
+ def clip_skip(self):
61
+ return self._clip_skip
62
+
63
+ @property
64
+ def do_classifier_free_guidance(self):
65
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
66
+ return self._guidance_scale > 1
67
+
68
+ def encode_prompt(
69
+ self,
70
+ prompt: Union[str, List[str]],
71
+ do_classifier_free_guidance: bool,
72
+ negative_prompt: str = "",
73
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
74
+ num_videos_per_prompt: int = 1,
75
+ prompt_embeds: Optional[torch.Tensor] = None,
76
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
77
+ prompt_attention_mask: Optional[torch.Tensor] = None,
78
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
79
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
80
+ negative_attention_mask: Optional[torch.Tensor] = None,
81
+ device: Optional[torch.device] = None,
82
+ dtype: Optional[torch.dtype] = None,
83
+ max_sequence_length: int = 256,
84
+ ):
85
+ num_hidden_layers_to_skip = self.clip_skip if self.clip_skip is not None else 0
86
+ print(f"num_hidden_layers_to_skip: {num_hidden_layers_to_skip}")
87
+ if prompt_embeds is None:
88
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
89
+ prompt,
90
+ prompt_template,
91
+ num_videos_per_prompt,
92
+ device=device,
93
+ dtype=dtype,
94
+ num_hidden_layers_to_skip=num_hidden_layers_to_skip,
95
+ max_sequence_length=max_sequence_length,
96
+ )
97
+ if negative_prompt_embeds is None and do_classifier_free_guidance:
98
+ negative_prompt_embeds, negative_attention_mask = self._get_llama_prompt_embeds(
99
+ negative_prompt,
100
+ prompt_template,
101
+ num_videos_per_prompt,
102
+ device=device,
103
+ dtype=dtype,
104
+ num_hidden_layers_to_skip=num_hidden_layers_to_skip,
105
+ max_sequence_length=max_sequence_length,
106
+ )
107
+ if self.text_encoder_2 is not None and pooled_prompt_embeds is None:
108
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
109
+ prompt,
110
+ num_videos_per_prompt,
111
+ device=device,
112
+ dtype=dtype,
113
+ max_sequence_length=77,
114
+ )
115
+ if negative_pooled_prompt_embeds is None and do_classifier_free_guidance:
116
+ negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
117
+ negative_prompt,
118
+ num_videos_per_prompt,
119
+ device=device,
120
+ dtype=dtype,
121
+ max_sequence_length=77,
122
+ )
123
+ return (
124
+ prompt_embeds,
125
+ prompt_attention_mask,
126
+ negative_prompt_embeds,
127
+ negative_attention_mask,
128
+ pooled_prompt_embeds,
129
+ negative_pooled_prompt_embeds,
130
+ )
131
+
132
+ def image_latents(
133
+ self,
134
+ initial_image,
135
+ batch_size,
136
+ height,
137
+ width,
138
+ device,
139
+ dtype,
140
+ num_channels_latents,
141
+ video_length,
142
+ ):
143
+ initial_image = initial_image.unsqueeze(2)
144
+ image_latents = self.vae.encode(initial_image).latent_dist.sample()
145
+ if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor:
146
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
147
+ else:
148
+ image_latents = image_latents * self.vae.config.scaling_factor
149
+ padding_shape = (
150
+ batch_size,
151
+ num_channels_latents,
152
+ video_length - 1,
153
+ int(height) // self.vae_scale_factor_spatial,
154
+ int(width) // self.vae_scale_factor_spatial,
155
+ )
156
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
157
+ image_latents = torch.cat([image_latents, latent_padding], dim=2)
158
+ return image_latents
159
+
160
+ @torch.no_grad()
161
+ def __call__(
162
+ self,
163
+ prompt: str,
164
+ negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
165
+ height: int = 720,
166
+ width: int = 1280,
167
+ num_frames: int = 129,
168
+ num_inference_steps: int = 50,
169
+ sigmas: List[float] = None,
170
+ guidance_scale: float = 1.0,
171
+ num_videos_per_prompt: Optional[int] = 1,
172
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
173
+ latents: Optional[torch.Tensor] = None,
174
+ prompt_embeds: Optional[torch.Tensor] = None,
175
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
176
+ prompt_attention_mask: Optional[torch.Tensor] = None,
177
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
178
+ negative_attention_mask: Optional[torch.Tensor] = None,
179
+ output_type: Optional[str] = "pil",
180
+ return_dict: bool = True,
181
+ attention_kwargs: Optional[Dict[str, Any]] = None,
182
+ guidance_rescale: float = 0.0,
183
+ clip_skip: Optional[int] = 2,
184
+ callback_on_step_end: Optional[
185
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
186
+ ] = None,
187
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
188
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
189
+ max_sequence_length: int = 256,
190
+ embedded_guidance_scale: Optional[float] = 6.0,
191
+ image: Optional[Union[torch.Tensor, Image.Image]] = None,
192
+ cfg_for: bool = False,
193
+ ):
194
+ if hasattr(self, "text_encoder_to_gpu"):
195
+ self.text_encoder_to_gpu()
196
+
197
+ if image is not None and isinstance(image, Image.Image):
198
+ image = resizecrop(image, height, width)
199
+
200
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
201
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
202
+
203
+ # 1. Check inputs. Raise error if not correct
204
+ self.check_inputs(
205
+ prompt,
206
+ None,
207
+ height,
208
+ width,
209
+ prompt_embeds,
210
+ callback_on_step_end_tensor_inputs,
211
+ prompt_template,
212
+ )
213
+ # add negative prompt check
214
+ if negative_prompt is not None and negative_prompt_embeds is not None:
215
+ raise ValueError(
216
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
217
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
218
+ )
219
+
220
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
221
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
222
+ raise ValueError(
223
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
224
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
225
+ f" {negative_prompt_embeds.shape}."
226
+ )
227
+
228
+ self._guidance_scale = guidance_scale
229
+ self._guidance_rescale = guidance_rescale
230
+ self._clip_skip = clip_skip
231
+ self._attention_kwargs = attention_kwargs
232
+ self._interrupt = False
233
+
234
+ device = self._execution_device
235
+
236
+ # 2. Define call parameters
237
+ if prompt is not None and isinstance(prompt, str):
238
+ batch_size = 1
239
+ elif prompt is not None and isinstance(prompt, list):
240
+ batch_size = len(prompt)
241
+ else:
242
+ batch_size = prompt_embeds.shape[0]
243
+ self.text_encoder.to("cuda")  # move the text encoder to the GPU for prompt encoding
244
+
245
+ # 3. Encode input prompt
246
+ (
247
+ prompt_embeds,
248
+ prompt_attention_mask,
249
+ negative_prompt_embeds,
250
+ negative_attention_mask,
251
+ pooled_prompt_embeds,
252
+ negative_pooled_prompt_embeds,
253
+ ) = self.encode_prompt(
254
+ prompt=prompt,
255
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
256
+ negative_prompt=negative_prompt,
257
+ prompt_template=prompt_template,
258
+ num_videos_per_prompt=num_videos_per_prompt,
259
+ prompt_embeds=prompt_embeds,
260
+ prompt_attention_mask=prompt_attention_mask,
261
+ negative_prompt_embeds=negative_prompt_embeds,
262
+ negative_attention_mask=negative_attention_mask,
263
+ device=device,
264
+ max_sequence_length=max_sequence_length,
265
+ )
266
+
267
+ transformer_dtype = self.transformer.dtype
268
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
269
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
270
+ if pooled_prompt_embeds is not None:
271
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
272
+
273
+ ## Embeddings are concatenated to form a batch.
274
+ if self.do_classifier_free_guidance:
275
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
276
+ negative_attention_mask = negative_attention_mask.to(transformer_dtype)
277
+ if negative_pooled_prompt_embeds is not None:
278
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
279
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
280
+ if prompt_attention_mask is not None:
281
+ prompt_attention_mask = torch.cat([negative_attention_mask, prompt_attention_mask])
282
+ if pooled_prompt_embeds is not None:
283
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
284
+
285
+ # 4. Prepare timesteps
286
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
287
+ timesteps, num_inference_steps = retrieve_timesteps(
288
+ self.scheduler,
289
+ num_inference_steps,
290
+ device,
291
+ sigmas=sigmas,
292
+ )
293
+
294
+ # 5. Prepare latent variables
295
+ num_channels_latents = self.transformer.config.in_channels
296
+ if image is not None:
297
+ num_channels_latents = int(num_channels_latents / 2)
298
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
299
+ device, dtype=prompt_embeds.dtype
300
+ )
301
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
302
+ latents = self.prepare_latents(
303
+ batch_size * num_videos_per_prompt,
304
+ num_channels_latents,
305
+ height,
306
+ width,
307
+ num_latent_frames,
308
+ torch.float32,
309
+ device,
310
+ generator,
311
+ latents,
312
+ )
313
+ # add image latents
314
+ if image is not None:
315
+ image_latents = self.image_latents(
316
+ image, batch_size, height, width, device, torch.float32, num_channels_latents, num_latent_frames
317
+ )
318
+
319
+ image_latents = image_latents.to(transformer_dtype)
320
+ else:
321
+ image_latents = None
322
+
323
+ # 6. Prepare guidance condition
324
+ if self.do_classifier_free_guidance:
325
+ guidance = (
326
+ torch.tensor([embedded_guidance_scale] * latents.shape[0] * 2, dtype=transformer_dtype, device=device)
327
+ * 1000.0
328
+ )
329
+ else:
330
+ guidance = (
331
+ torch.tensor([embedded_guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device)
332
+ * 1000.0
333
+ )
334
+
335
+ # 7. Denoising loop
336
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
337
+ self._num_timesteps = len(timesteps)
338
+
339
+ if hasattr(self, "text_encoder_to_cpu"):
340
+ self.text_encoder_to_cpu()
341
+ self.text_encoder.to("cpu")  # offload the text encoder once encoding is done
342
+ self.vae.to("cpu")  # offload the VAE during denoising
343
+ torch.cuda.empty_cache()
344
+
345
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
346
+ for i, t in enumerate(timesteps):
347
+ if self.interrupt:
348
+ continue
349
+
350
+ latents = latents.to(transformer_dtype)
351
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
352
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
353
+ # timestep = t.expand(latents.shape[0]).to(latents.dtype)
354
+ if image_latents is not None:
355
+ latent_image_input = (
356
+ torch.cat([image_latents] * 2) if self.do_classifier_free_guidance else image_latents
357
+ )
358
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
359
+ timestep = t.repeat(latent_model_input.shape[0]).to(torch.float32)
360
+ if cfg_for and self.do_classifier_free_guidance:
361
+ noise_pred_list = []
362
+ for idx in range(latent_model_input.shape[0]):
363
+ noise_pred_uncond = self.transformer(
364
+ hidden_states=latent_model_input[idx].unsqueeze(0),
365
+ timestep=timestep[idx].unsqueeze(0),
366
+ encoder_hidden_states=prompt_embeds[idx].unsqueeze(0),
367
+ encoder_attention_mask=prompt_attention_mask[idx].unsqueeze(0),
368
+ pooled_projections=pooled_prompt_embeds[idx].unsqueeze(0),
369
+ guidance=guidance[idx].unsqueeze(0),
370
+ attention_kwargs=attention_kwargs,
371
+ return_dict=False,
372
+ )[0]
373
+ noise_pred_list.append(noise_pred_uncond)
374
+ noise_pred = torch.cat(noise_pred_list, dim=0)
375
+ else:
376
+ noise_pred = self.transformer(
377
+ hidden_states=latent_model_input,
378
+ timestep=timestep,
379
+ encoder_hidden_states=prompt_embeds,
380
+ encoder_attention_mask=prompt_attention_mask,
381
+ pooled_projections=pooled_prompt_embeds,
382
+ guidance=guidance,
383
+ attention_kwargs=attention_kwargs,
384
+ return_dict=False,
385
+ )[0]
386
+
387
+ # perform guidance
388
+ if self.do_classifier_free_guidance:
389
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
390
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
391
+
392
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
393
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
394
+ noise_pred = rescale_noise_cfg(
395
+ noise_pred,
396
+ noise_pred_text,
397
+ guidance_rescale=self.guidance_rescale,
398
+ )
399
+
400
+ # compute the previous noisy sample x_t -> x_t-1
401
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
402
+
403
+ if callback_on_step_end is not None:
404
+ callback_kwargs = {}
405
+ for k in callback_on_step_end_tensor_inputs:
406
+ callback_kwargs[k] = locals()[k]
407
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
408
+
409
+ latents = callback_outputs.pop("latents", latents)
410
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
411
+
412
+ # call the callback, if provided
413
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
414
+ progress_bar.update()
415
+
416
+ if not output_type == "latent":
417
+ self.vae.to("cuda")  # bring the VAE back to the GPU for decoding
418
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
419
+ video = self.vae.decode(latents, return_dict=False)[0]
420
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
421
+ else:
422
+ video = latents
423
+
424
+ # Offload all models
425
+ self.maybe_free_model_hooks()
426
+
427
+ if not return_dict:
428
+ return (video,)
429
+
430
+ return HunyuanVideoPipelineOutput(frames=video)
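For reference, a hedged usage sketch of the updated pipeline. It assumes a SkyreelsVideoPipeline instance named pipe has already been constructed (checkpoint loading is outside the scope of this diff); the input image path, resolution, frame count, and guidance values below are illustrative, not values taken from this commit.

import torch
from PIL import Image
from diffusers.utils import export_to_video

# `pipe` is an already-built SkyreelsVideoPipeline; `reference.png` is a
# hypothetical conditioning image (omit `image=` for pure text-to-video).
image = Image.open("reference.png")

output = pipe(
    prompt="a person walking along a beach at sunset",
    image=image,
    height=544,
    width=960,
    num_frames=97,
    num_inference_steps=30,
    guidance_scale=6.0,             # > 1 enables classifier-free guidance
    embedded_guidance_scale=1.0,
    generator=torch.Generator("cuda").manual_seed(42),
)
export_to_video(output.frames[0], "skyreels_sample.mp4", fps=24)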