OzzyGT (HF staff) committed
Commit 337a4ef
1 parent: 8763f8f

Upload pipeline.py

Files changed (1)
  1. pipeline.py +639 -0
pipeline.py ADDED
@@ -0,0 +1,639 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.loaders import (
+     FromSingleFileMixin,
+     IPAdapterMixin,
+     StableDiffusionXLLoraLoaderMixin,
+     TextualInversionLoaderMixin,
+ )
+ from diffusers.models import (
+     AutoencoderKL,
+     ControlNetModel,
+     ImageProjection,
+     UNet2DConditionModel,
+ )
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
+     StableDiffusionXLPipelineOutput,
+ )
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from diffusers.utils.torch_utils import (
+     is_compiled_module,
+     is_torch_version,
+     randn_tensor,
+ )
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPTextModel,
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     CLIPVisionModelWithProjection,
+ )
+
+
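+ # NOTE: simplified variant of the `retrieve_timesteps` helper shared across diffusers
+ # pipelines; it always delegates to `scheduler.set_timesteps` and does not handle
+ # custom `timesteps` or `sigmas` lists.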
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     **kwargs,
+ ):
+     scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+     timesteps = scheduler.timesteps
+
+     return timesteps, num_inference_steps
+
+
+ class StableDiffusionXLRecolorPipeline(
+     DiffusionPipeline,
+     StableDiffusionMixin,
+     TextualInversionLoaderMixin,
+     StableDiffusionXLLoraLoaderMixin,
+     IPAdapterMixin,
+     FromSingleFileMixin,
+ ):
+     # leave controlnet out on purpose because it iterates with unet
+     model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+     _optional_components = [
+         "tokenizer",
+         "tokenizer_2",
+         "text_encoder",
+         "text_encoder_2",
+         "feature_extractor",
+         "image_encoder",
+     ]
+     _callback_tensor_inputs = [
+         "latents",
+         "prompt_embeds",
+         "negative_prompt_embeds",
+         "add_text_embeds",
+         "add_time_ids",
+         "negative_pooled_prompt_embeds",
+         "negative_add_time_ids",
+     ]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         text_encoder_2: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         tokenizer_2: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         controlnet: Union[
+             ControlNetModel,
+             List[ControlNetModel],
+             Tuple[ControlNetModel],
+             MultiControlNetModel,
+         ],
+         scheduler: KarrasDiffusionSchedulers,
+         force_zeros_for_empty_prompt: bool = True,
+         add_watermarker: Optional[bool] = None,
+         feature_extractor: CLIPImageProcessor = None,
+         image_encoder: CLIPVisionModelWithProjection = None,
+     ):
+         super().__init__()
+
+         if isinstance(controlnet, (list, tuple)):
+             controlnet = MultiControlNetModel(controlnet)
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             unet=unet,
+             controlnet=controlnet,
+             scheduler=scheduler,
+             feature_extractor=feature_extractor,
+             image_encoder=image_encoder,
+         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.image_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
+         )
+         self.control_image_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor,
+             do_convert_rgb=True,
+             do_normalize=False,
+         )
+         self.register_to_config(
+             force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
+         )
+
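+     # NOTE: trimmed-down SDXL prompt encoding: the same prompt is fed to both text
+     # encoders (`prompt_2 = prompt`), their penultimate hidden states are concatenated
+     # along the feature dimension, and negative embeddings are only built when
+     # classifier-free guidance is enabled.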
+     def encode_prompt(
+         self,
+         prompt: str,
+         negative_prompt: Optional[str] = None,
+         device: Optional[torch.device] = None,
+         do_classifier_free_guidance: bool = True,
+     ):
+         device = device or self._execution_device
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+
+         if prompt is not None:
+             batch_size = len(prompt)
+
+         # Define tokenizers and text encoders
+         tokenizers = (
+             [self.tokenizer, self.tokenizer_2]
+             if self.tokenizer is not None
+             else [self.tokenizer_2]
+         )
+         text_encoders = (
+             [self.text_encoder, self.text_encoder_2]
+             if self.text_encoder is not None
+             else [self.text_encoder_2]
+         )
+
+         prompt_2 = prompt
+
+         # textual inversion: process multi-vector tokens if necessary
+         prompt_embeds_list = []
+         prompts = [prompt, prompt_2]
+         for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+             text_inputs = tokenizer(
+                 prompt,
+                 padding="max_length",
+                 max_length=tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+
+             text_input_ids = text_inputs.input_ids
+
+             prompt_embeds = text_encoder(
+                 text_input_ids.to(device), output_hidden_states=True
+             )
+
+             # We are only ALWAYS interested in the pooled output of the final text encoder
+             pooled_prompt_embeds = prompt_embeds[0]
+             prompt_embeds = prompt_embeds.hidden_states[-2]
+             prompt_embeds_list.append(prompt_embeds)
+
+         prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+         # get unconditional embeddings for classifier free guidance
+         negative_prompt_embeds = None
+         negative_pooled_prompt_embeds = None
+
+         if do_classifier_free_guidance:
+             negative_prompt = negative_prompt or ""
+
+             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+
+             # normalize str to list
+             negative_prompt = [negative_prompt]
+             negative_prompt_2 = negative_prompt
+
+             uncond_tokens: List[str]
+             uncond_tokens = [negative_prompt, negative_prompt_2]
+
+             negative_prompt_embeds_list = []
+             for negative_prompt, tokenizer, text_encoder in zip(
+                 uncond_tokens, tokenizers, text_encoders
+             ):
+                 max_length = prompt_embeds.shape[1]
+                 uncond_input = tokenizer(
+                     negative_prompt,
+                     padding="max_length",
+                     max_length=max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+
+                 negative_prompt_embeds = text_encoder(
+                     uncond_input.input_ids.to(device),
+                     output_hidden_states=True,
+                 )
+                 # We are only ALWAYS interested in the pooled output of the final text encoder
+                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                 negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+             seq_len = negative_prompt_embeds.shape[1]
+
+             negative_prompt_embeds = negative_prompt_embeds.to(
+                 dtype=self.text_encoder_2.dtype, device=device
+             )
+
+             negative_prompt_embeds = negative_prompt_embeds.view(
+                 batch_size, seq_len, -1
+             )
+
+         pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+
+         if do_classifier_free_guidance:
+             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
+                 bs_embed, -1
+             )
+
+         return (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         )
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+     def encode_image(
+         self, image, device, num_images_per_prompt, output_hidden_states=None
+     ):
+         dtype = next(self.image_encoder.parameters()).dtype
+
+         if not isinstance(image, torch.Tensor):
+             image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+         image = image.to(device=device, dtype=dtype)
+         if output_hidden_states:
+             image_enc_hidden_states = self.image_encoder(
+                 image, output_hidden_states=True
+             ).hidden_states[-2]
+             image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
+                 num_images_per_prompt, dim=0
+             )
+             uncond_image_enc_hidden_states = self.image_encoder(
+                 torch.zeros_like(image), output_hidden_states=True
+             ).hidden_states[-2]
+             uncond_image_enc_hidden_states = (
+                 uncond_image_enc_hidden_states.repeat_interleave(
+                     num_images_per_prompt, dim=0
+                 )
+             )
+             return image_enc_hidden_states, uncond_image_enc_hidden_states
+         else:
+             image_embeds = self.image_encoder(image).image_embeds
+             image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+             uncond_image_embeds = torch.zeros_like(image_embeds)
+
+             return image_embeds, uncond_image_embeds
+
+     def prepare_ip_adapter_image_embeds(
+         self,
+         ip_adapter_image,
+         device,
+         do_classifier_free_guidance,
+     ):
+         image_embeds = []
+         if do_classifier_free_guidance:
+             negative_image_embeds = []
+
+         if not isinstance(ip_adapter_image, list):
+             ip_adapter_image = [ip_adapter_image]
+
+         if len(ip_adapter_image) != len(
+             self.unet.encoder_hid_proj.image_projection_layers
+         ):
+             raise ValueError(
+                 f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+             )
+
+         for single_ip_adapter_image, image_proj_layer in zip(
+             ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+         ):
+             output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+             single_image_embeds, single_negative_image_embeds = self.encode_image(
+                 single_ip_adapter_image, device, 1, output_hidden_state
+             )
+
+             image_embeds.append(single_image_embeds[None, :])
+             if do_classifier_free_guidance:
+                 negative_image_embeds.append(single_negative_image_embeds[None, :])
+
+         ip_adapter_image_embeds = []
+
+         for i, single_image_embeds in enumerate(image_embeds):
+             if do_classifier_free_guidance:
+                 single_image_embeds = torch.cat(
+                     [negative_image_embeds[i], single_image_embeds], dim=0
+                 )
+
+             single_image_embeds = single_image_embeds.to(device=device)
+             ip_adapter_image_embeds.append(single_image_embeds)
+
+         return ip_adapter_image_embeds
+
+     def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
+         image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)
+
+         image_batch_size = image.shape[0]
+
+         image = image.repeat_interleave(image_batch_size, dim=0)
+         image = image.to(device=device, dtype=dtype)
+
+         if do_classifier_free_guidance:
+             image = torch.cat([image] * 2)
+
+         return image
+
+     def prepare_latents(
+         self, batch_size, num_channels_latents, height, width, dtype, device
+     ):
+         shape = (
+             batch_size,
+             num_channels_latents,
+             int(height) // self.vae_scale_factor,
+             int(width) // self.vae_scale_factor,
+         )
+
+         latents = randn_tensor(shape, device=device, dtype=dtype)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+     # corresponds to doing no classifier free guidance.
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+     @property
+     def denoising_end(self):
+         return self._denoising_end
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: PipelineImageInput = None,
+         num_inference_steps: int = 8,
+         guidance_scale: float = 2.0,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+         control_guidance_start: Union[float, List[float]] = 0.0,
+         control_guidance_end: Union[float, List[float]] = 1.0,
+         **kwargs,
+     ):
+         controlnet = self.controlnet
+
+         # align format for control guidance
+         if not isinstance(control_guidance_start, list) and isinstance(
+             control_guidance_end, list
+         ):
+             control_guidance_start = len(control_guidance_end) * [
+                 control_guidance_start
+             ]
+         elif not isinstance(control_guidance_end, list) and isinstance(
+             control_guidance_start, list
+         ):
+             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+         elif not isinstance(control_guidance_start, list) and not isinstance(
+             control_guidance_end, list
+         ):
+             mult = (
+                 len(controlnet.nets)
+                 if isinstance(controlnet, MultiControlNetModel)
+                 else 1
+             )
+             control_guidance_start, control_guidance_end = (
+                 mult * [control_guidance_start],
+                 mult * [control_guidance_end],
+             )
+
+         self._guidance_scale = guidance_scale
+
+         # 2. Define call parameters
+         batch_size = 1
+         device = self._execution_device
+
+         if isinstance(controlnet, MultiControlNetModel) and isinstance(
+             controlnet_conditioning_scale, float
+         ):
+             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
+                 controlnet.nets
+             )
+
+         # 3.2 Encode ip_adapter_image
+         if ip_adapter_image is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 device,
+                 self.do_classifier_free_guidance,
+             )
+
+         # 4. Prepare image
+         if isinstance(controlnet, ControlNetModel):
+             image = self.prepare_image(
+                 image=image,
+                 device=device,
+                 dtype=controlnet.dtype,
+                 do_classifier_free_guidance=self.do_classifier_free_guidance,
+             )
+             height, width = image.shape[-2:]
+         elif isinstance(controlnet, MultiControlNetModel):
+             images = []
+
+             for image_ in image:
+                 image_ = self.prepare_image(
+                     image=image_,
+                     device=device,
+                     dtype=controlnet.dtype,
+                     do_classifier_free_guidance=self.do_classifier_free_guidance,
+                 )
+
+                 images.append(image_)
+
+             image = images
+             height, width = image[0].shape[-2:]
+         else:
+             assert False
+
+         # 5. Prepare timesteps
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler, num_inference_steps, device
+         )
+         self._num_timesteps = len(timesteps)
+
+         # 6. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+         )
+
+         # 7.1 Create tensor stating which controlnets to keep
+         controlnet_keep = []
+         for i in range(len(timesteps)):
+             keeps = [
+                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                 for s, e in zip(control_guidance_start, control_guidance_end)
+             ]
+             controlnet_keep.append(
+                 keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
+             )
+
+         # 7.2 Prepare added time ids & embeddings
+         add_text_embeds = pooled_prompt_embeds
+
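+         # SDXL micro-conditioning: add_time_ids packs (original_size + crops_coords_top_left
+         # + target_size); here both sizes are taken from the prepared control image and the
+         # crop offsets are zero.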
+         add_time_ids = negative_add_time_ids = torch.tensor(
+             image[0].shape[-2:] + torch.Size([0, 0]) + image[0].shape[-2:]
+         ).unsqueeze(0)
+
+         negative_add_time_ids = add_time_ids
+
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             add_text_embeds = torch.cat(
+                 [negative_pooled_prompt_embeds, add_text_embeds], dim=0
+             )
+             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+         prompt_embeds = prompt_embeds.to(device)
+         add_text_embeds = add_text_embeds.to(device)
+         add_time_ids = add_time_ids.to(device)
+
+         added_cond_kwargs = {
+             "text_embeds": add_text_embeds,
+             "time_ids": add_time_ids,
+         }
+
+         # 8. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = (
+                     torch.cat([latents] * 2)
+                     if self.do_classifier_free_guidance
+                     else latents
+                 )
+                 latent_model_input = self.scheduler.scale_model_input(
+                     latent_model_input, t
+                 )
+
+                 # controlnet(s) inference
+                 control_model_input = latent_model_input
+                 controlnet_prompt_embeds = prompt_embeds
+                 controlnet_added_cond_kwargs = added_cond_kwargs
+
+                 if isinstance(controlnet_keep[i], list):
+                     cond_scale = [
+                         c * s
+                         for c, s in zip(
+                             controlnet_conditioning_scale, controlnet_keep[i]
+                         )
+                     ]
+                 else:
+                     controlnet_cond_scale = controlnet_conditioning_scale
+                     if isinstance(controlnet_cond_scale, list):
+                         controlnet_cond_scale = controlnet_cond_scale[0]
+                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                 down_block_res_samples, mid_block_res_sample = self.controlnet(
+                     control_model_input,
+                     t,
+                     encoder_hidden_states=controlnet_prompt_embeds,
+                     controlnet_cond=image,
+                     conditioning_scale=cond_scale,
+                     guess_mode=False,
+                     added_cond_kwargs=controlnet_added_cond_kwargs,
+                     return_dict=False,
+                 )
+
+                 if ip_adapter_image is not None:
+                     added_cond_kwargs["image_embeds"] = image_embeds
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     timestep_cond=None,
+                     cross_attention_kwargs={},
+                     down_block_additional_residuals=down_block_res_samples,
+                     mid_block_additional_residual=mid_block_res_sample,
+                     added_cond_kwargs=added_cond_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, return_dict=False
+                 )[0]
+
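+                 # after the third step, keep only the conditional half of the embeddings and
+                 # control image(s) and zero out the guidance scale, so the remaining steps run
+                 # without classifier-free guidance on a single-sample batch.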
+                 if i == 2:
+                     prompt_embeds = prompt_embeds[-1:]
+                     add_text_embeds = add_text_embeds[-1:]
+                     add_time_ids = add_time_ids[-1:]
+
+                     added_cond_kwargs = {
+                         "text_embeds": add_text_embeds,
+                         "time_ids": add_time_ids,
+                     }
+
+                     controlnet_prompt_embeds = prompt_embeds
+                     controlnet_added_cond_kwargs = added_cond_kwargs
+
+                     image = [single_image[-1:] for single_image in image]
+                     self._guidance_scale = 0.0
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+
+         latents = latents / self.vae.config.scaling_factor
+         image = self.vae.decode(latents, return_dict=False)[0]
+         image = self.image_processor.postprocess(image)[0]
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         return StableDiffusionXLPipelineOutput(images=image)
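
For reference, a minimal, untested loading sketch (not part of this commit). The checkpoint names, the repo or path given to `custom_pipeline`, and the control image are placeholders; it assumes the file is consumed through the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`, and it passes the controlnet as a list because the `__call__` above re-slices `image` as a list of per-controlnet tensors after step 2:

    import torch
    from diffusers import ControlNetModel, DiffusionPipeline
    from diffusers.utils import load_image

    # placeholder checkpoints -- swap in the controlnet and base model actually intended for recoloring
    controlnet = ControlNetModel.from_pretrained("<controlnet-repo>", torch_dtype=torch.float16)
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=[controlnet],
        custom_pipeline="<repo-or-local-dir-containing-pipeline.py>",
        torch_dtype=torch.float16,
    ).to("cuda")

    # grayscale source image used as the controlnet conditioning
    control_image = load_image("<grayscale-image-url-or-path>")

    # this pipeline takes precomputed embeddings instead of a raw prompt
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt="a vivid color photograph", negative_prompt="low quality")

    result = pipe(
        image=[control_image],
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    )
    result.images  # a single PIL image, per the `postprocess(...)[0]` call above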