bghira committed on
Commit
c5d814d
1 Parent(s): 6f61d4c

Create custom_pipeline.py

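This adds a standalone Flux pipeline that layers true classifier-free guidance (`negative_prompt`, `guidance_scale_real`, `no_cfg_until_timestep`) on top of the distilled guidance embedding. A minimal usage sketch, assuming the file is importable from your working directory and is paired with a guidance-distilled checkpoint such as `black-forest-labs/FLUX.1-dev` (adjust names to your setup):

```py
import torch

from custom_pipeline import FluxPipeline  # the class defined in this file

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
    guidance_scale=3.5,        # distilled guidance embedding
    guidance_scale_real=3.0,   # true CFG against the negative prompt
    no_cfg_until_timestep=2,   # skip the extra CFG pass for the first steps
    num_inference_steps=28,
).images[0]
image.save("flux-real-cfg.png")
```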
Files changed (1)
  1. custom_pipeline.py +942 -0
custom_pipeline.py ADDED
@@ -0,0 +1,942 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPTextModel,
22
+ CLIPTokenizer,
23
+ T5EncoderModel,
24
+ T5TokenizerFast,
25
+ )
26
+
27
+ from diffusers.image_processor import VaeImageProcessor
28
+ from diffusers.loaders import FluxLoraLoaderMixin
29
+ from diffusers.models.autoencoders import AutoencoderKL
30
+ from diffusers.models.transformers import FluxTransformer2DModel
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
+ from diffusers.utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import randn_tensor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+
43
+
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from diffusers import FluxPipeline
59
+
60
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
61
+ >>> pipe.to("cuda")
62
+ >>> prompt = "A cat holding a sign that says hello world"
63
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
64
+ >>> # Refer to the pipeline documentation for more details.
65
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
66
+ >>> image.save("flux.png")
67
+ ```
68
+ """
69
+
70
+
71
+ def calculate_shift(
72
+ image_seq_len,
73
+ base_seq_len: int = 256,
74
+ max_seq_len: int = 4096,
75
+ base_shift: float = 0.5,
76
+ max_shift: float = 1.16,
77
+ ):
78
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
79
+ b = base_shift - m * base_seq_len
80
+ mu = image_seq_len * m + b
81
+ return mu
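+ # Worked example with the defaults above: m = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4 and
+ # b = 0.5 - m * 256 ≈ 0.456, so a 1024x1024 image (packed sequence length 4096) gets
+ # mu ≈ 1.16 (the maximum shift), while a 256-token sequence gets the base shift of 0.5.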
82
+
83
+
84
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
85
+ def retrieve_timesteps(
86
+ scheduler,
87
+ num_inference_steps: Optional[int] = None,
88
+ device: Optional[Union[str, torch.device]] = None,
89
+ timesteps: Optional[List[int]] = None,
90
+ sigmas: Optional[List[float]] = None,
91
+ **kwargs,
92
+ ):
93
+ """
94
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
95
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
96
+
97
+ Args:
98
+ scheduler (`SchedulerMixin`):
99
+ The scheduler to get timesteps from.
100
+ num_inference_steps (`int`):
101
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
102
+ must be `None`.
103
+ device (`str` or `torch.device`, *optional*):
104
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
105
+ timesteps (`List[int]`, *optional*):
106
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
107
+ `num_inference_steps` and `sigmas` must be `None`.
108
+ sigmas (`List[float]`, *optional*):
109
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
110
+ `num_inference_steps` and `timesteps` must be `None`.
111
+
112
+ Returns:
113
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
114
+ second element is the number of inference steps.
115
+ """
116
+ if timesteps is not None and sigmas is not None:
117
+ raise ValueError(
118
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
119
+ )
120
+ if timesteps is not None:
121
+ accepts_timesteps = "timesteps" in set(
122
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
123
+ )
124
+ if not accepts_timesteps:
125
+ raise ValueError(
126
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
127
+ f" timestep schedules. Please check whether you are using the correct scheduler."
128
+ )
129
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
130
+ timesteps = scheduler.timesteps
131
+ num_inference_steps = len(timesteps)
132
+ elif sigmas is not None:
133
+ accept_sigmas = "sigmas" in set(
134
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
135
+ )
136
+ if not accept_sigmas:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ else:
145
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
146
+ timesteps = scheduler.timesteps
147
+ return timesteps, num_inference_steps
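+ # Usage sketch: in `__call__` below this is invoked as
+ # `retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu)`,
+ # i.e. with custom sigmas plus the resolution-dependent `mu` shift from `calculate_shift`.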
148
+
149
+
150
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
151
+ r"""
152
+ The Flux pipeline for text-to-image generation.
153
+
154
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
155
+
156
+ Args:
157
+ transformer ([`FluxTransformer2DModel`]):
158
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
159
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
160
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
161
+ vae ([`AutoencoderKL`]):
162
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
163
+ text_encoder ([`CLIPTextModel`]):
164
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
165
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
166
+ variant. Only its pooled output is used to condition
167
+ the transformer.
168
+ text_encoder_2 ([`T5EncoderModel`]):
169
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel),
170
+ specifically the
171
+ [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl)
172
+ variant.
173
+ tokenizer (`CLIPTokenizer`):
174
+ Tokenizer of class
175
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
176
+ tokenizer_2 (`T5TokenizerFast`):
177
+ Second tokenizer, of class
178
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
179
+ """
180
+
181
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
182
+ _optional_components = []
183
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
184
+
185
+ def __init__(
186
+ self,
187
+ scheduler: FlowMatchEulerDiscreteScheduler,
188
+ vae: AutoencoderKL,
189
+ text_encoder: CLIPTextModel,
190
+ tokenizer: CLIPTokenizer,
191
+ text_encoder_2: T5EncoderModel,
192
+ tokenizer_2: T5TokenizerFast,
193
+ transformer: FluxTransformer2DModel,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.register_modules(
198
+ vae=vae,
199
+ text_encoder=text_encoder,
200
+ text_encoder_2=text_encoder_2,
201
+ tokenizer=tokenizer,
202
+ tokenizer_2=tokenizer_2,
203
+ transformer=transformer,
204
+ scheduler=scheduler,
205
+ )
206
+ self.vae_scale_factor = (
207
+ 2 ** (len(self.vae.config.block_out_channels))
208
+ if hasattr(self, "vae") and self.vae is not None
209
+ else 16
210
+ )
211
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
212
+ self.tokenizer_max_length = (
213
+ self.tokenizer.model_max_length
214
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
215
+ else 77
216
+ )
217
+ self.default_sample_size = 64
218
+
219
+ def _get_t5_prompt_embeds(
220
+ self,
221
+ prompt: Union[str, List[str]] = None,
222
+ num_images_per_prompt: int = 1,
223
+ max_sequence_length: int = 512,
224
+ device: Optional[torch.device] = None,
225
+ dtype: Optional[torch.dtype] = None,
226
+ ):
227
+ device = device or self._execution_device
228
+ dtype = dtype or self.text_encoder.dtype
229
+
230
+ prompt = [prompt] if isinstance(prompt, str) else prompt
231
+ batch_size = len(prompt)
232
+
233
+ text_inputs = self.tokenizer_2(
234
+ prompt,
235
+ padding="max_length",
236
+ max_length=max_sequence_length,
237
+ truncation=True,
238
+ return_length=False,
239
+ return_overflowing_tokens=False,
240
+ return_tensors="pt",
241
+ )
242
+ prompt_attention_mask = text_inputs.attention_mask
243
+ text_input_ids = text_inputs.input_ids
244
+ untruncated_ids = self.tokenizer_2(
245
+ prompt, padding="longest", return_tensors="pt"
246
+ ).input_ids
247
+
248
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
249
+ text_input_ids, untruncated_ids
250
+ ):
251
+ removed_text = self.tokenizer_2.batch_decode(
252
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
253
+ )
254
+ # logger.warning(
255
+ # "The following part of your input was truncated because `max_sequence_length` is set to "
256
+ # f" {max_sequence_length} tokens: {removed_text}"
257
+ # )
258
+
259
+ prompt_embeds = self.text_encoder_2(
260
+ text_input_ids.to(device), output_hidden_states=False
261
+ )[0]
262
+
263
+ dtype = self.text_encoder_2.dtype
264
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
265
+
266
+ _, seq_len, _ = prompt_embeds.shape
267
+
268
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ batch_size * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ return prompt_embeds, prompt_attention_mask
275
+
276
+ def _get_clip_prompt_embeds(
277
+ self,
278
+ prompt: Union[str, List[str]],
279
+ num_images_per_prompt: int = 1,
280
+ device: Optional[torch.device] = None,
281
+ ):
282
+ device = device or self._execution_device
283
+
284
+ prompt = [prompt] if isinstance(prompt, str) else prompt
285
+ batch_size = len(prompt)
286
+
287
+ text_inputs = self.tokenizer(
288
+ prompt,
289
+ padding="max_length",
290
+ max_length=self.tokenizer_max_length,
291
+ truncation=True,
292
+ return_overflowing_tokens=False,
293
+ return_length=False,
294
+ return_tensors="pt",
295
+ )
296
+
297
+ text_input_ids = text_inputs.input_ids
298
+ untruncated_ids = self.tokenizer(
299
+ prompt, padding="longest", return_tensors="pt"
300
+ ).input_ids
301
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
302
+ text_input_ids, untruncated_ids
303
+ ):
304
+ removed_text = self.tokenizer.batch_decode(
305
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
306
+ )
307
+ # logger.warning(
308
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
309
+ # f" {self.tokenizer_max_length} tokens: {removed_text}"
310
+ # )
311
+ prompt_embeds = self.text_encoder(
312
+ text_input_ids.to(device), output_hidden_states=False
313
+ )
314
+
315
+ # Use pooled output of CLIPTextModel
316
+ prompt_embeds = prompt_embeds.pooler_output
317
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
318
+
319
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
320
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
321
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
322
+
323
+ return prompt_embeds
324
+
325
+ def encode_prompt(
326
+ self,
327
+ prompt: Union[str, List[str]],
328
+ prompt_2: Union[str, List[str]],
329
+ device: Optional[torch.device] = None,
330
+ num_images_per_prompt: int = 1,
331
+ prompt_embeds: Optional[torch.FloatTensor] = None,
332
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ max_sequence_length: int = 512,
334
+ lora_scale: Optional[float] = None,
335
+ ):
336
+ r"""
337
+
338
+ Args:
339
+ prompt (`str` or `List[str]`, *optional*):
340
+ prompt to be encoded
341
+ prompt_2 (`str` or `List[str]`, *optional*):
342
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
343
+ used in all text-encoders
344
+ device: (`torch.device`):
345
+ torch device
346
+ num_images_per_prompt (`int`):
347
+ number of images that should be generated per prompt
348
+ prompt_embeds (`torch.FloatTensor`, *optional*):
349
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
350
+ provided, text embeddings will be generated from `prompt` input argument.
351
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
352
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
353
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
354
+ max_sequence_length (`int`, *optional*, defaults to 512):
355
+ Maximum sequence length used when tokenizing the prompt for the T5 text encoder
356
+ (`text_encoder_2`). Longer prompts are truncated to this length.
357
+ lora_scale (`float`, *optional*):
358
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
359
+ """
360
+ device = device or self._execution_device
361
+
362
+ # set lora scale so that monkey patched LoRA
363
+ # function of text encoder can correctly access it
364
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
365
+ self._lora_scale = lora_scale
366
+
367
+ # dynamically adjust the LoRA scale
368
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
369
+ scale_lora_layers(self.text_encoder, lora_scale)
370
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
371
+ scale_lora_layers(self.text_encoder_2, lora_scale)
372
+
373
+ prompt = [prompt] if isinstance(prompt, str) else prompt
374
+ if prompt is not None:
375
+ batch_size = len(prompt)
376
+ else:
377
+ batch_size = prompt_embeds.shape[0]
378
+
379
+ prompt_attention_mask = None
380
+ if prompt_embeds is None:
381
+ prompt_2 = prompt_2 or prompt
382
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
383
+
384
+ # We only use the pooled prompt output from the CLIPTextModel
385
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
386
+ prompt=prompt,
387
+ device=device,
388
+ num_images_per_prompt=num_images_per_prompt,
389
+ )
390
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
391
+ prompt=prompt_2,
392
+ num_images_per_prompt=num_images_per_prompt,
393
+ max_sequence_length=max_sequence_length,
394
+ device=device,
395
+ )
396
+
397
+ if self.text_encoder is not None:
398
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
399
+ # Retrieve the original scale by scaling back the LoRA layers
400
+ unscale_lora_layers(self.text_encoder, lora_scale)
401
+
402
+ if self.text_encoder_2 is not None:
403
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
404
+ # Retrieve the original scale by scaling back the LoRA layers
405
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
406
+
407
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(
408
+ device=device, dtype=prompt_embeds.dtype
409
+ )
410
+
411
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
412
+
413
+ def check_inputs(
414
+ self,
415
+ prompt,
416
+ prompt_2,
417
+ height,
418
+ width,
419
+ prompt_embeds=None,
420
+ pooled_prompt_embeds=None,
421
+ callback_on_step_end_tensor_inputs=None,
422
+ max_sequence_length=None,
423
+ ):
424
+ if height % 8 != 0 or width % 8 != 0:
425
+ raise ValueError(
426
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
427
+ )
428
+
429
+ if callback_on_step_end_tensor_inputs is not None and not all(
430
+ k in self._callback_tensor_inputs
431
+ for k in callback_on_step_end_tensor_inputs
432
+ ):
433
+ raise ValueError(
434
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
435
+ )
436
+
437
+ if prompt is not None and prompt_embeds is not None:
438
+ raise ValueError(
439
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
440
+ " only forward one of the two."
441
+ )
442
+ elif prompt_2 is not None and prompt_embeds is not None:
443
+ raise ValueError(
444
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
445
+ " only forward one of the two."
446
+ )
447
+ elif prompt is None and prompt_embeds is None:
448
+ raise ValueError(
449
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
450
+ )
451
+ elif prompt is not None and (
452
+ not isinstance(prompt, str) and not isinstance(prompt, list)
453
+ ):
454
+ raise ValueError(
455
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
456
+ )
457
+ elif prompt_2 is not None and (
458
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
459
+ ):
460
+ raise ValueError(
461
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
462
+ )
463
+
464
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
465
+ raise ValueError(
466
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
467
+ )
468
+
469
+ if max_sequence_length is not None and max_sequence_length > 512:
470
+ raise ValueError(
471
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
472
+ )
473
+
474
+ @staticmethod
475
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
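+ # Builds per-token position ids for the packed latent grid: each 2x2 latent patch becomes one
+ # token, channel 0 stays 0, channel 1 holds the row index and channel 2 the column index.
+ # They are passed to the transformer as `img_ids` for its positional (rotary) embeddings.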
476
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
477
+ latent_image_ids[..., 1] = (
478
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
479
+ )
480
+ latent_image_ids[..., 2] = (
481
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
482
+ )
483
+
484
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
485
+ latent_image_ids.shape
486
+ )
487
+
488
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
489
+ latent_image_ids = latent_image_ids.reshape(
490
+ batch_size,
491
+ latent_image_id_height * latent_image_id_width,
492
+ latent_image_id_channels,
493
+ )
494
+
495
+ return latent_image_ids.to(device=device, dtype=dtype)
496
+
497
+ @staticmethod
498
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
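+ # Packs (batch, channels, height, width) latents into a sequence of 2x2 patches with shape
+ # (batch, (height // 2) * (width // 2), channels * 4), e.g. (1, 16, 128, 128) -> (1, 4096, 64).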
499
+ latents = latents.view(
500
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
501
+ )
502
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
503
+ latents = latents.reshape(
504
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
505
+ )
506
+
507
+ return latents
508
+
509
+ @staticmethod
510
+ def _unpack_latents(latents, height, width, vae_scale_factor):
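+ # Inverse of `_pack_latents`: `height` and `width` are pixel dimensions here, reduced by
+ # `vae_scale_factor` and then doubled again by the 2x2 un-shuffle, giving
+ # (batch, channels // 4, height // 8, width // 8) latents ready for the VAE decoder.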
511
+ batch_size, num_patches, channels = latents.shape
512
+
513
+ height = height // vae_scale_factor
514
+ width = width // vae_scale_factor
515
+
516
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
517
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
518
+
519
+ latents = latents.reshape(
520
+ batch_size, channels // (2 * 2), height * 2, width * 2
521
+ )
522
+
523
+ return latents
524
+
525
+ def prepare_latents(
526
+ self,
527
+ batch_size,
528
+ num_channels_latents,
529
+ height,
530
+ width,
531
+ dtype,
532
+ device,
533
+ generator,
534
+ latents=None,
535
+ ):
536
+ height = 2 * (int(height) // self.vae_scale_factor)
537
+ width = 2 * (int(width) // self.vae_scale_factor)
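+ # The factor of 2 compensates for the 2x2 packing below: the packed sequence length becomes
+ # (height // vae_scale_factor) * (width // vae_scale_factor), e.g. 4096 tokens at 1024x1024.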
538
+
539
+ shape = (batch_size, num_channels_latents, height, width)
540
+
541
+ if latents is not None:
542
+ latent_image_ids = self._prepare_latent_image_ids(
543
+ batch_size, height, width, device, dtype
544
+ )
545
+ return latents.to(device=device, dtype=dtype), latent_image_ids
546
+
547
+ if isinstance(generator, list) and len(generator) != batch_size:
548
+ raise ValueError(
549
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
550
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
551
+ )
552
+
553
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
554
+ latents = self._pack_latents(
555
+ latents, batch_size, num_channels_latents, height, width
556
+ )
557
+
558
+ latent_image_ids = self._prepare_latent_image_ids(
559
+ batch_size, height, width, device, dtype
560
+ )
561
+
562
+ return latents, latent_image_ids
563
+
564
+ @property
565
+ def guidance_scale(self):
566
+ return self._guidance_scale
567
+
568
+ @property
569
+ def joint_attention_kwargs(self):
570
+ return self._joint_attention_kwargs
571
+
572
+ @property
573
+ def num_timesteps(self):
574
+ return self._num_timesteps
575
+
576
+ @property
577
+ def interrupt(self):
578
+ return self._interrupt
579
+
580
+ @torch.no_grad()
581
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
582
+ def __call__(
583
+ self,
584
+ prompt: Union[str, List[str]] = None,
585
+ prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
586
+ negative_mask: Optional[
587
+ Union[torch.FloatTensor, List[torch.FloatTensor]]
588
+ ] = None,
589
+ prompt_2: Optional[Union[str, List[str]]] = None,
590
+ height: Optional[int] = None,
591
+ width: Optional[int] = None,
592
+ num_inference_steps: int = 28,
593
+ timesteps: List[int] = None,
594
+ guidance_scale: float = 3.5,
595
+ num_images_per_prompt: Optional[int] = 1,
596
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
597
+ latents: Optional[torch.FloatTensor] = None,
598
+ prompt_embeds: Optional[torch.FloatTensor] = None,
599
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
600
+ output_type: Optional[str] = "pil",
601
+ return_dict: bool = True,
602
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
603
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
604
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
605
+ max_sequence_length: int = 512,
606
+ guidance_scale_real: float = 1.0,
607
+ negative_prompt: Union[str, List[str]] = "",
608
+ negative_prompt_2: Union[str, List[str]] = "",
609
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
610
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
611
+ no_cfg_until_timestep: int = 2,
612
+ ):
613
+ r"""
614
+ Function invoked when calling the pipeline for generation.
615
+
616
+ Args:
617
+ prompt (`str` or `List[str]`, *optional*):
618
+ The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed
619
+ instead.
620
+ prompt_mask (`torch.FloatTensor` or `List[torch.FloatTensor]`, *optional*):
621
+ Attention mask for the T5 prompt embeddings, forwarded to the transformer as `attention_mask`.
622
+ If not provided, no mask is applied.
623
+ prompt_2 (`str` or `List[str]`, *optional*):
624
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
625
+ will be used instead.
626
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
627
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
628
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
629
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
630
+ num_inference_steps (`int`, *optional*, defaults to 28):
631
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
632
+ expense of slower inference.
633
+ timesteps (`List[int]`, *optional*):
634
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
635
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
636
+ passed will be used. Must be in descending order.
637
+ guidance_scale (`float`, *optional*, defaults to 3.5):
638
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
639
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
640
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
641
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
642
+ usually at the expense of lower image quality.
643
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
644
+ The number of images to generate per prompt.
645
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
646
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
647
+ to make generation deterministic.
648
+ latents (`torch.FloatTensor`, *optional*):
649
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
650
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
651
+ tensor will be generated by sampling using the supplied random `generator`.
652
+ prompt_embeds (`torch.FloatTensor`, *optional*):
653
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
654
+ provided, text embeddings will be generated from `prompt` input argument.
655
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
656
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
657
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
658
+ output_type (`str`, *optional*, defaults to `"pil"`):
659
+ The output format of the generated image. Choose between
660
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
661
+ return_dict (`bool`, *optional*, defaults to `True`):
662
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
663
+ joint_attention_kwargs (`dict`, *optional*):
664
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
665
+ `self.processor` in
666
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
667
+ callback_on_step_end (`Callable`, *optional*):
668
+ A function that is called at the end of each denoising step during inference. The function is called
669
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
670
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
671
+ `callback_on_step_end_tensor_inputs`.
672
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
673
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
674
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
675
+ `._callback_tensor_inputs` attribute of your pipeline class.
676
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
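+ guidance_scale_real (`float`, *optional*, defaults to 1.0):
+ Strength of true classifier-free guidance using `negative_prompt`. Values greater than 1.0 run an
+ additional unconditional transformer pass per step and combine the two predictions.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to steer the generation away from. Only used when `guidance_scale_real` > 1.0.
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ Negative prompt for `tokenizer_2` and `text_encoder_2`. Defaults to `negative_prompt` when empty.
+ no_cfg_until_timestep (`int`, *optional*, defaults to 2):
+ Number of initial denoising steps for which the extra classifier-free guidance pass is skipped.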
677
+
678
+ Examples:
679
+
680
+ Returns:
681
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
682
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
683
+ images.
684
+ """
685
+
686
+ height = height or self.default_sample_size * self.vae_scale_factor
687
+ width = width or self.default_sample_size * self.vae_scale_factor
688
+
689
+ # 1. Check inputs. Raise error if not correct
690
+ self.check_inputs(
691
+ prompt,
692
+ prompt_2,
693
+ height,
694
+ width,
695
+ prompt_embeds=prompt_embeds,
696
+ pooled_prompt_embeds=pooled_prompt_embeds,
697
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
698
+ max_sequence_length=max_sequence_length,
699
+ )
700
+
701
+ self._guidance_scale = guidance_scale
702
+ self._guidance_scale_real = guidance_scale_real
703
+ self._joint_attention_kwargs = joint_attention_kwargs
704
+ self._interrupt = False
705
+
706
+ # 2. Define call parameters
707
+ if prompt is not None and isinstance(prompt, str):
708
+ batch_size = 1
709
+ elif prompt is not None and isinstance(prompt, list):
710
+ batch_size = len(prompt)
711
+ else:
712
+ batch_size = prompt_embeds.shape[0]
713
+
714
+ device = self._execution_device
715
+
716
+ lora_scale = (
717
+ self.joint_attention_kwargs.get("scale", None)
718
+ if self.joint_attention_kwargs is not None
719
+ else None
720
+ )
721
+ (
722
+ prompt_embeds,
723
+ pooled_prompt_embeds,
724
+ text_ids,
725
+ _,
726
+ ) = self.encode_prompt(
727
+ prompt=prompt,
728
+ prompt_2=prompt_2,
729
+ prompt_embeds=prompt_embeds,
730
+ pooled_prompt_embeds=pooled_prompt_embeds,
731
+ device=device,
732
+ num_images_per_prompt=num_images_per_prompt,
733
+ max_sequence_length=max_sequence_length,
734
+ lora_scale=lora_scale,
735
+ )
736
+
737
+ if negative_prompt_2 == "" and negative_prompt != "":
738
+ negative_prompt_2 = negative_prompt
739
+
740
+ negative_text_ids = text_ids
741
+ if guidance_scale_real > 1.0 and (
742
+ negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
743
+ ):
744
+ (
745
+ negative_prompt_embeds,
746
+ negative_pooled_prompt_embeds,
747
+ negative_text_ids,
748
+ _,
749
+ ) = self.encode_prompt(
750
+ prompt=negative_prompt,
751
+ prompt_2=negative_prompt_2,
752
+ prompt_embeds=None,
753
+ pooled_prompt_embeds=None,
754
+ device=device,
755
+ num_images_per_prompt=num_images_per_prompt,
756
+ max_sequence_length=max_sequence_length,
757
+ lora_scale=lora_scale,
758
+ )
759
+
760
+ # 4. Prepare latent variables
761
+ num_channels_latents = self.transformer.config.in_channels // 4
762
+ latents, latent_image_ids = self.prepare_latents(
763
+ batch_size * num_images_per_prompt,
764
+ num_channels_latents,
765
+ height,
766
+ width,
767
+ prompt_embeds.dtype,
768
+ device,
769
+ generator,
770
+ latents,
771
+ )
772
+
773
+ # 5. Prepare timesteps
774
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
775
+ image_seq_len = latents.shape[1]
776
+ mu = calculate_shift(
777
+ image_seq_len,
778
+ self.scheduler.config.base_image_seq_len,
779
+ self.scheduler.config.max_image_seq_len,
780
+ self.scheduler.config.base_shift,
781
+ self.scheduler.config.max_shift,
782
+ )
783
+ timesteps, num_inference_steps = retrieve_timesteps(
784
+ self.scheduler,
785
+ num_inference_steps,
786
+ device,
787
+ timesteps,
788
+ sigmas,
789
+ mu=mu,
790
+ )
791
+ num_warmup_steps = max(
792
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
793
+ )
794
+ self._num_timesteps = len(timesteps)
795
+
796
+ latents = latents.to(self.transformer.device)
797
+ latent_image_ids = latent_image_ids.to(self.transformer.device)[0]
798
+ timesteps = timesteps.to(self.transformer.device)
799
+ text_ids = text_ids.to(self.transformer.device)[0]
800
+
801
+ # 6. Denoising loop
802
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
803
+ for i, t in enumerate(timesteps):
804
+ if self.interrupt:
805
+ continue
806
+
807
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
808
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
809
+
810
+ # handle guidance
811
+ if self.transformer.config.guidance_embeds:
812
+ guidance = torch.tensor(
813
+ [guidance_scale], device=self.transformer.device
814
+ )
815
+ guidance = guidance.expand(latents.shape[0])
816
+ else:
817
+ guidance = None
818
+
819
+ extra_transformer_args = {}
820
+ if prompt_mask is not None:
821
+ extra_transformer_args["attention_mask"] = prompt_mask.to(
822
+ device=self.transformer.device
823
+ )
824
+
825
+ noise_pred = self.transformer(
826
+ hidden_states=latents.to(
827
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
828
+ ),
829
+ # YiYi notes: divide it by 1000 for now because the transformer model scales it by 1000 internally (we should not keep this, but it keeps the inputs the same for the model during testing)
830
+ timestep=timestep / 1000,
831
+ guidance=guidance,
832
+ pooled_projections=pooled_prompt_embeds.to(
833
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
834
+ ),
835
+ encoder_hidden_states=prompt_embeds.to(
836
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
837
+ ),
838
+ txt_ids=text_ids,
839
+ img_ids=latent_image_ids,
840
+ joint_attention_kwargs=self.joint_attention_kwargs,
841
+ return_dict=False,
842
+ **extra_transformer_args,
843
+ )[0]
844
+
845
+ # TODO optionally use batch prediction to speed this up.
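+ # When `guidance_scale_real` > 1.0 (and the current step has reached `no_cfg_until_timestep`),
+ # a second pass is run with the negative prompt embeddings and the result is combined as
+ # uncond + guidance_scale_real * (cond - uncond), i.e. standard classifier-free guidance
+ # layered on top of the distilled guidance embedding above.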
846
+ if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
847
+ noise_pred_uncond = self.transformer(
848
+ hidden_states=latents.to(
849
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
850
+ ),
851
+ # YiYi notes: divide it by 1000 for now because the transformer model scales it by 1000 internally (we should not keep this, but it keeps the inputs the same for the model during testing)
852
+ timestep=timestep / 1000,
853
+ guidance=guidance,
854
+ pooled_projections=negative_pooled_prompt_embeds.to(
855
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
856
+ ),
857
+ encoder_hidden_states=negative_prompt_embeds.to(
858
+ device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
859
+ ),
860
+ txt_ids=negative_text_ids.to(device=self.transformer.device),
861
+ img_ids=latent_image_ids.to(device=self.transformer.device),
862
+ joint_attention_kwargs=self.joint_attention_kwargs,
863
+ return_dict=False,
864
+ )[0]
865
+
866
+ noise_pred = noise_pred_uncond + guidance_scale_real * (
867
+ noise_pred - noise_pred_uncond
868
+ )
869
+
870
+ # compute the previous noisy sample x_t -> x_t-1
871
+ latents_dtype = latents.dtype
872
+ latents = self.scheduler.step(
873
+ noise_pred, t, latents, return_dict=False
874
+ )[0]
875
+
876
+ if latents.dtype != latents_dtype:
877
+ if torch.backends.mps.is_available():
878
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
879
+ latents = latents.to(latents_dtype)
880
+
881
+ if callback_on_step_end is not None:
882
+ callback_kwargs = {}
883
+ for k in callback_on_step_end_tensor_inputs:
884
+ callback_kwargs[k] = locals()[k]
885
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
886
+
887
+ latents = callback_outputs.pop("latents", latents)
888
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
889
+
890
+ # call the callback, if provided
891
+ if i == len(timesteps) - 1 or (
892
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
893
+ ):
894
+ progress_bar.update()
895
+
896
+ if XLA_AVAILABLE:
897
+ xm.mark_step()
898
+
899
+ if output_type == "latent":
900
+ image = latents
901
+
902
+ else:
903
+ latents = self._unpack_latents(
904
+ latents, height, width, self.vae_scale_factor
905
+ )
906
+ latents = (
907
+ latents / self.vae.config.scaling_factor
908
+ ) + self.vae.config.shift_factor
909
+
910
+ image = self.vae.decode(
911
+ latents.to(device=self.vae.device, dtype=self.vae.dtype),
912
+ return_dict=False,
913
+ )[0]
914
+ image = self.image_processor.postprocess(image, output_type=output_type)
915
+
916
+ # Offload all models
917
+ self.maybe_free_model_hooks()
918
+
919
+ if not return_dict:
920
+ return (image,)
921
+
922
+ return FluxPipelineOutput(images=image)
923
+
924
+
925
+ from dataclasses import dataclass
926
+ from typing import List, Union
927
+ import PIL.Image
928
+ from diffusers.utils import BaseOutput
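+ # Note: although these imports and the output dataclass live at the bottom of the file, the module
+ # is executed top to bottom before `FluxPipeline.__call__` can run, so `FluxPipelineOutput` is
+ # defined by the time it is referenced above.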
929
+
930
+
931
+ @dataclass
932
+ class FluxPipelineOutput(BaseOutput):
933
+ """
934
+ Output class for Flux pipelines.
935
+
936
+ Args:
937
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
938
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
939
+ num_channels)`. PIL images or NumPy arrays representing the denoised images of the diffusion pipeline.
940
+ """
941
+
942
+ images: Union[List[PIL.Image.Image], np.ndarray]