rashidvyro committed on
Commit aadb993 · 1 Parent(s): 3e442e7

Delete stable_diffusion_custom_v4_1.py

Files changed (1)
  1. stable_diffusion_custom_v4_1.py +0 -795
stable_diffusion_custom_v4_1.py DELETED
@@ -1,795 +0,0 @@
- import random
- from diffusers import StableDiffusionPipeline
- # from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput, AutoencoderKL, CLIPTextModel, CLIPTokenizer, UNet2DConditionModel, KarrasDiffusionSchedulers, StableDiffusionSafetyChecker, CLIPImageProcessor
- from compel import Compel
- from onediff.utils.tokenizer import TextualInversionLoaderMixin, MultiTokenCLIPTokenizer
- import torch
- from typing import Any, Callable, Dict, List, Optional, Union
- from dynamicprompts.generators import RandomPromptGenerator
- import time
- from diffusers.utils import logging
- from onediff.utils.prompt_parser import ScheduledPromptConditioning
- from onediff.utils.prompt_parser import get_learned_conditioning_prompt_schedules
- import tqdm
- from cachetools import LRUCache
- from onediff.utils.image_processor import VaeImageProcessor
-
- logger = logging.get_logger(__name__)
-
- class CustomStableDiffusionPipeline4_1(TextualInversionLoaderMixin, StableDiffusionPipeline):
-     def __init__(
-         self,
-         vae: AutoencoderKL,
-         text_encoder: CLIPTextModel,
-         tokenizer: CLIPTokenizer,
-         unet: UNet2DConditionModel,
-         scheduler: KarrasDiffusionSchedulers,
-         safety_checker: StableDiffusionSafetyChecker,
-         feature_extractor: CLIPImageProcessor,
-         requires_safety_checker: bool = True,
-         prompt_cache_size: int = 1024,
-         prompt_cache_ttl: int = 60 * 2,
-     ) -> None:
-         super().__init__(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
-                          safety_checker=safety_checker, feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker)
-
-         self.vae_scale_factor = 2 ** (
-             len(self.vae.config.block_out_channels) - 1)
-         self.image_processor = VaeImageProcessor(
-             vae_scale_factor=self.vae_scale_factor)
-         self.register_to_config(
-             requires_safety_checker=requires_safety_checker)
-
-         self.compel = Compel(tokenizer=self.tokenizer,
-                              text_encoder=self.text_encoder, truncate_long_prompts=False)
-         self.cache = LRUCache(maxsize=prompt_cache_size)
-
-         self.cached_uc = [None, None]
-         self.cached_c = [None, None]
-
-         self.prompt_handler = None
-
-     def build_scheduled_cond(self, prompt, steps, key):
-         prompt_schedule = get_learned_conditioning_prompt_schedules([prompt], steps)[
-             0]
-
-         cached = self.cache.get(key, None)
-         if cached is not None:
-             return cached
-
-         texts = [x[1] for x in prompt_schedule]
-         conds = [self.compel.build_conditioning_tensor(
-             text).to('cpu') for text in texts]
-
-         cond_schedule = []
-         for i, s in enumerate(prompt_schedule):
-             cond_schedule.append(ScheduledPromptConditioning(s[0], conds[i]))
-
-         self.cache[key] = cond_schedule
-         return cond_schedule
-
-     def initialize_magic_prompt_cache(self, pos_prompt_template: str, plain_prompt_template: str, neg_prompt_template: str, num_to_generate: int, steps: int):
-         r"""
-         Initializes the magic prompt cache for the forward pass.
-         Must be called immediately after Compel is loaded and the embeds are initialized.
-         """
-         rpg = RandomPromptGenerator(ignore_whitespace=True, seed=555)
-         positive_prompts = rpg.generate(
-             template=pos_prompt_template, num_images=num_to_generate)
-         scheduled_conds = []
-         with torch.no_grad():
-             # Use the prompt text itself as the cache key (keys must be hashable).
-             for i in tqdm.tqdm(range(len(positive_prompts))):
-                 scheduled_conds.append(self.build_scheduled_cond(
-                     positive_prompts[i], steps, positive_prompts[i]))
-
-             plain_scheduled_cond = self.build_scheduled_cond(
-                 plain_prompt_template, steps, plain_prompt_template)
-
-             scheduled_uncond = self.build_scheduled_cond(
-                 neg_prompt_template, steps, neg_prompt_template)
-
-         self.scheduled_conds = scheduled_conds
-         self.plain_scheduled_cond = plain_scheduled_cond
-         self.scheduled_uncond = scheduled_uncond
-
-     def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
-         r"""
-         Encodes the prompt into text encoder hidden states.
-
-         Args:
-             prompt (`str` or `List[str]`):
-                 prompt to be encoded
-             device: (`torch.device`):
-                 torch device
-             num_images_per_prompt (`int`):
-                 number of images that should be generated per prompt
-             do_classifier_free_guidance (`bool`):
-                 whether to use classifier free guidance or not
-             negative_prompt (`str` or `List[str]`):
-                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                 if `guidance_scale` is less than `1`).
-         """
-         batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-         text_inputs = self.tokenizer(
-             prompt,
-             padding="max_length",
-             max_length=self.tokenizer.model_max_length,
-             truncation=True,
-             return_tensors="np",
-         )
-         text_input_ids = text_inputs.input_ids
-         text_input_ids = torch.from_numpy(text_input_ids)
-         untruncated_ids = self.tokenizer(
-             prompt, padding="max_length", return_tensors="np").input_ids
-         untruncated_ids = torch.from_numpy(untruncated_ids)
-
-         if (
-             text_input_ids.shape == untruncated_ids.shape
-             and text_input_ids.numel() == untruncated_ids.numel()
-             and not torch.equal(text_input_ids, untruncated_ids)
-         ):
-             removed_text = self.tokenizer.batch_decode(
-                 untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
-             logger.warning(
-                 "The following part of your input was truncated because CLIP can only handle sequences up to"
-                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-             )
-
-         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-             attention_mask = text_inputs.attention_mask.to(device)
-         else:
-             attention_mask = None
-
-         text_embeddings = self.text_encoder(
-             text_input_ids.to(device), attention_mask=attention_mask)
-         text_embeddings = text_embeddings[0]
-
-         # duplicate text embeddings for each generation per prompt, using mps friendly method
-         bs_embed, seq_len, _ = text_embeddings.shape
-         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-         text_embeddings = text_embeddings.view(
-             bs_embed * num_images_per_prompt, seq_len, -1)
-
-         # get unconditional embeddings for classifier free guidance
-         if do_classifier_free_guidance:
-             uncond_tokens: List[str]
-             if negative_prompt is None:
-                 uncond_tokens = [""] * batch_size
-             elif type(prompt) is not type(negative_prompt):
-                 raise TypeError(
-                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                     f" {type(prompt)}."
-                 )
-             elif isinstance(negative_prompt, str):
-                 uncond_tokens = [negative_prompt]
-             elif batch_size != len(negative_prompt):
-                 raise ValueError(
-                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                     " the batch size of `prompt`."
-                 )
-             else:
-                 uncond_tokens = negative_prompt
-
-             max_length = text_input_ids.shape[-1]
-             uncond_input = self.tokenizer(
-                 uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np",
-             )
-
-             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                 attention_mask = torch.from_numpy(
-                     uncond_input.attention_mask).to(device)
-             else:
-                 attention_mask = None
-
-             uncond_embeddings = self.text_encoder(
-                 torch.from_numpy(uncond_input.input_ids).to(device), attention_mask=attention_mask,
-             )
-             uncond_embeddings = uncond_embeddings[0]
-
-             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-             seq_len = uncond_embeddings.shape[1]
-             uncond_embeddings = uncond_embeddings.repeat(
-                 1, num_images_per_prompt, 1)
-             uncond_embeddings = uncond_embeddings.view(
-                 batch_size * num_images_per_prompt, seq_len, -1)
-
-             # For classifier free guidance, we need to do two forward passes.
-             # Here we concatenate the unconditional and text embeddings into a single batch
-             # to avoid doing two forward passes
-             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-         return text_embeddings
-
-     def _encode_promptv2(
-         self,
-         prompt,
-         device,
-         num_images_per_prompt,
-         do_classifier_free_guidance,
-         negative_prompt=None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-     ):
-
-         if prompt is not None and isinstance(prompt, str):
-             batch_size = 1
-         elif prompt is not None and isinstance(prompt, list):
-             batch_size = len(prompt)
-         else:
-             batch_size = prompt_embeds.shape[0]
-
-         if prompt_embeds is None:
-             text_inputs = self.tokenizer(
-                 prompt,
-                 padding="max_length",
-                 max_length=self.tokenizer.model_max_length,
-                 truncation=True,
-                 return_tensors="pt",
-             )
-             text_input_ids = text_inputs.input_ids
-             untruncated_ids = self.tokenizer(
-                 prompt, padding="longest", return_tensors="pt").input_ids
-
-             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                 text_input_ids, untruncated_ids
-             ):
-                 removed_text = self.tokenizer.batch_decode(
-                     untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
-                 )
-
-             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                 attention_mask = text_inputs.attention_mask.to(device)
-             else:
-                 attention_mask = None
-
-             prompt_embeds = self.text_encoder(
-                 text_input_ids.to(device),
-                 attention_mask=attention_mask,
-             )
-             prompt_embeds = prompt_embeds[0]
-
-         prompt_embeds = prompt_embeds.to(
-             dtype=self.text_encoder.dtype, device=device)
-
-         bs_embed, seq_len, _ = prompt_embeds.shape
-         # duplicate text embeddings for each generation per prompt, using mps friendly method
-         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-         prompt_embeds = prompt_embeds.view(
-             bs_embed * num_images_per_prompt, seq_len, -1)
-
-         # get unconditional embeddings for classifier free guidance
-         if do_classifier_free_guidance and negative_prompt_embeds is None:
-             uncond_tokens: List[str]
-             if negative_prompt is None:
-                 uncond_tokens = [""] * batch_size
-             elif type(prompt) is not type(negative_prompt):
-                 raise TypeError(
-                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                     f" {type(prompt)}."
-                 )
-             elif isinstance(negative_prompt, str):
-                 uncond_tokens = [negative_prompt]
-             elif batch_size != len(negative_prompt):
-                 raise ValueError(
-                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                     " the batch size of `prompt`."
-                 )
-             else:
-                 uncond_tokens = negative_prompt
-
-             max_length = prompt_embeds.shape[1]
-             uncond_input = self.tokenizer(
-                 uncond_tokens,
-                 padding="max_length",
-                 max_length=max_length,
-                 truncation=True,
-                 return_tensors="pt",
-             )
-
-             if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                 attention_mask = uncond_input.attention_mask.to(device)
-             else:
-                 attention_mask = None
-
-             negative_prompt_embeds = self.text_encoder(
-                 uncond_input.input_ids.to(device),
-                 attention_mask=attention_mask,
-             )
-             negative_prompt_embeds = negative_prompt_embeds[0]
-
-         if do_classifier_free_guidance:
-             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-             seq_len = negative_prompt_embeds.shape[1]
-
-             negative_prompt_embeds = negative_prompt_embeds.to(
-                 dtype=self.text_encoder.dtype, device=device)
-
-             negative_prompt_embeds = negative_prompt_embeds.repeat(
-                 1, num_images_per_prompt, 1)
-             negative_prompt_embeds = negative_prompt_embeds.view(
-                 batch_size * num_images_per_prompt, seq_len, -1)
-
-             negative_prompt_embeds, prompt_embeds = self.compel.pad_conditioning_tensors_to_same_length(
-                 [negative_prompt_embeds, prompt_embeds])
-             # For classifier free guidance, we need to do two forward passes.
-             # Here we concatenate the unconditional and text embeddings into a single batch
-             # to avoid doing two forward passes
-             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-         return prompt_embeds
-
-     def _pyramid_noise_like(self, noise, device, seed, iterations=6, discount=0.4):
-         gen = torch.manual_seed(seed)
-         # EDIT: w and h get over-written, rename for a different variant!
-         b, c, w, h = noise.shape
-         u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
-         for i in range(iterations):
-             r = random.random() * 2 + 2  # Rather than always going 2x,
-             wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
-             noise += u(torch.randn(b, c, wn, hn,
-                        generator=gen).to(device)) * discount**i
-             if wn == 1 or hn == 1:
-                 break  # Lowest resolution is 1x1
-         return noise / noise.std()  # Scaled back to roughly unit variance
-
-     @torch.no_grad()
-     def inferV4(
-         self,
-         prompt: Union[str, List[str]],
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 50,
-         guidance_scale: float = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         eta: float = 0.0,
-         generator: Optional[torch.Generator] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[
-             int, int, torch.FloatTensor], None]] = None,
-         callback_steps: Optional[int] = 1,
-         compile_unet: bool = True,
-         compile_vae: bool = True,
-         compile_tenc: bool = True,
-         max_tokens=0,
-         seed=-1,
-         flags=[],
-         og_prompt=None,
-         og_neg_prompt=None,
-         disc=0.4,
-         iter=6,
-         pyramid=0,  # disabled by default unless specified
-     ):
-         r"""
-         Function invoked when calling the pipeline for generation.
-
-         Args:
-             prompt (`str` or `List[str]`):
-                 The prompt or prompts to guide the image generation.
-             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                 The height in pixels of the generated image.
-             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                 The width in pixels of the generated image.
-             num_inference_steps (`int`, *optional*, defaults to 50):
-                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                 expense of slower inference.
-             guidance_scale (`float`, *optional*, defaults to 7.5):
-                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
-                 usually at the expense of lower image quality.
-             negative_prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                 if `guidance_scale` is less than `1`).
-             num_images_per_prompt (`int`, *optional*, defaults to 1):
-                 The number of images to generate per prompt.
-             eta (`float`, *optional*, defaults to 0.0):
-                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-                 [`schedulers.DDIMScheduler`], will be ignored for others.
-             generator (`torch.Generator`, *optional*):
-                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
-                 deterministic.
-             latents (`torch.FloatTensor`, *optional*):
-                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                 tensor will be generated by sampling using the supplied random `generator`.
-             output_type (`str`, *optional*, defaults to `"pil"`):
-                 The output format of the generated image. Choose between
-                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                 plain tuple.
-             callback (`Callable`, *optional*):
-                 A function that will be called every `callback_steps` steps during inference. The function will be
-                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-             callback_steps (`int`, *optional*, defaults to 1):
-                 The frequency at which the `callback` function will be called. If not specified, the callback will be
-                 called at every step.
-
-         Returns:
-             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
-             When returning a tuple, the first element is a list with the generated images, and the second element is a
-             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-             (nsfw) content, according to the `safety_checker`.
-         """
-         # 0. Default height and width to unet
-
-         height = height or self.unet.config.sample_size * self.vae_scale_factor
-         width = width or self.unet.config.sample_size * self.vae_scale_factor
-
-         self.check_inputs(prompt, height, width, callback_steps)
-         if negative_prompt is None:
-             negative_prompt = ['']
-         # 2. Define call parameters
-         batch_size = 1 if isinstance(prompt, str) else len(prompt)
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale > 1.0
-
-         # # 3. Encode input prompt
-
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps = self.scheduler.timesteps
-
-         # Cache key for flags
-         plain = "plain" in flags
-         flair = None
-         for flag in flags:
-             if "flair" in flag:
-                 flair = flag
-                 break
-
-         with torch.no_grad():
-             c_time = time.time()
-             user_cond = self.build_scheduled_cond(
-                 prompt[0], num_inference_steps, ('pos', og_prompt, seed, plain, flair))
-             c_time = time.time()
-             user_uncond = self.build_scheduled_cond(
-                 negative_prompt[0], num_inference_steps, ('neg', negative_prompt[0], 0))
-
-             c = []
-             c.extend(user_cond)
-             uc = []
-             uc.extend(user_uncond)
-             max_token_count = 0
-
-             for cond in uc:
-                 if cond.cond.shape[1] > max_token_count:
-                     max_token_count = cond.cond.shape[1]
-             for cond in c:
-                 if cond.cond.shape[1] > max_token_count:
-                     max_token_count = cond.cond.shape[1]
-
-             def pad_tensor(conditionings: List[ScheduledPromptConditioning], max_token_count: int) -> List[ScheduledPromptConditioning]:
-
-                 c0_shape = conditionings[0].cond.shape
-                 if not all([len(c.cond.shape) == len(c0_shape) for c in conditionings]):
-                     raise ValueError(
-                         "Conditioning tensors must all have either 2 dimensions (unbatched) or 3 dimensions (batched)")
-
-                 if len(c0_shape) == 2:
-                     # need to be unsqueezed
-                     for c in conditionings:
-                         c.cond = c.cond.unsqueeze(0)
-                     c0_shape = conditionings[0].cond.shape
-                 if len(c0_shape) != 3:
-                     raise ValueError(
-                         f"All conditioning tensors must have the same number of dimensions (2 or 3)")
-
-                 if not all([c.cond.shape[0] == c0_shape[0] and c.cond.shape[2] == c0_shape[2] for c in conditionings]):
-                     raise ValueError(
-                         f"All conditioning tensors must have the same batch size ({c0_shape[0]}) and number of embeddings per token ({c0_shape[1]}")
-
-                 # if necessary, pad shorter tensors out with an emptystring tensor
-                 empty_z = torch.cat(
-                     [self.compel.build_conditioning_tensor("")] * c0_shape[0])
-                 for i, c in enumerate(conditionings):
-                     cond = c.cond.to(self.device)
-                     while cond.shape[1] < max_token_count:
-                         cond = torch.cat([cond, empty_z], dim=1)
-                     conditionings[i] = ScheduledPromptConditioning(
-                         c.end_at_step, cond)
-                 return conditionings
-
-             uc = pad_tensor(uc, max_token_count)
-             c = pad_tensor(c, max_token_count)
-
-             next_uc = uc.pop(0)
-             next_c = c.pop(0)
-             prompt_embeds = None
-             new_embeds = True
-             embed_per_step = []
-             for i in range(len(timesteps)):
-                 if i > next_uc.end_at_step:
-                     next_uc = uc.pop(0)
-                     new_embeds = True
-                 if i > next_c.end_at_step:
-                     next_c = c.pop(0)
-                     new_embeds = True
-
-                 if new_embeds:
-                     negative_prompt_embeds, prompt_embeds = self.compel.pad_conditioning_tensors_to_same_length([
-                         next_uc.cond, next_c.cond])
-                     prompt_embeds = torch.cat(
-                         [negative_prompt_embeds, prompt_embeds])
-                     new_embeds = False
-
-                 embed_per_step.append(prompt_embeds)
-
-         # 5. Prepare latent variables
-         num_channels_latents = self.unet.in_channels
-         latents = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # 7. Denoising loop
-         num_warmup_steps = len(timesteps) - \
-             num_inference_steps * self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat(
-                     [latents] * 2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(
-                     latent_model_input, t)
-
-                 prompt_embeds = embed_per_step[i]
-                 # predict the noise residual
-
-                 noise_pred = self.unet(
-                     latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * \
-                         (noise_pred_text - noise_pred_uncond)
-
-                 if i < pyramid * num_inference_steps:
-                     noise_pred = self._pyramid_noise_like(
-                         noise_pred, device, seed, iterations=iter, discount=disc)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(
-                     noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 # call the callback, if provided
-                 if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
-                     progress_bar.update()
-                     if callback is not None and i % callback_steps == 0:
-                         callback(i, t, latents)
-
-         if not output_type == "latent":
-             image = self.vae.decode(
-                 latents / self.vae.config.scaling_factor, return_dict=False)[0]
-             image, has_nsfw_concept = self.run_safety_checker(
-                 image, device, prompt_embeds.dtype)
-         else:
-             image = latents
-             has_nsfw_concept = None
-
-         if has_nsfw_concept is None:
-             do_denormalize = [True] * image.shape[0]
-         else:
-             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
-         image = self.image_processor.postprocess(
-             image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, has_nsfw_concept)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-
-     @torch.no_grad()
-     def inferPipe(
-         self,
-         prompt: Union[str, List[str]] = None,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 50,
-         guidance_scale: float = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         eta: float = 0.0,
-         generator: Optional[Union[torch.Generator,
-                                   List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         prompt_embeds: Optional[torch.FloatTensor] = None,
-         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[
-             int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-     ):
-         r"""
-         Function invoked when calling the pipeline for generation.
-
-         Args:
-             prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
-                 instead.
-             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                 The height in pixels of the generated image.
-             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                 The width in pixels of the generated image.
-             num_inference_steps (`int`, *optional*, defaults to 50):
-                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                 expense of slower inference.
-             guidance_scale (`float`, *optional*, defaults to 7.5):
-                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
-                 usually at the expense of lower image quality.
-             negative_prompt (`str` or `List[str]`, *optional*):
-                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                 less than `1`).
-             num_images_per_prompt (`int`, *optional*, defaults to 1):
-                 The number of images to generate per prompt.
-             eta (`float`, *optional*, defaults to 0.0):
-                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-                 [`schedulers.DDIMScheduler`], will be ignored for others.
-             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                 to make generation deterministic.
-             latents (`torch.FloatTensor`, *optional*):
-                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                 tensor will be generated by sampling using the supplied random `generator`.
-             prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                 provided, text embeddings will be generated from `prompt` input argument.
-             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                 argument.
-             output_type (`str`, *optional*, defaults to `"pil"`):
-                 The output format of the generated image. Choose between
-                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                 plain tuple.
-             callback (`Callable`, *optional*):
-                 A function that will be called every `callback_steps` steps during inference. The function will be
-                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-             callback_steps (`int`, *optional*, defaults to 1):
-                 The frequency at which the `callback` function will be called. If not specified, the callback will be
-                 called at every step.
-             cross_attention_kwargs (`dict`, *optional*):
-                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                 `self.processor` in
-                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-
-         Examples:
-
-         Returns:
-             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
-             When returning a tuple, the first element is a list with the generated images, and the second element is a
-             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-             (nsfw) content, according to the `safety_checker`.
-         """
-         # 0. Default height and width to unet
-         height = height or self.unet.config.sample_size * self.vae_scale_factor
-         width = width or self.unet.config.sample_size * self.vae_scale_factor
-
-         # 1. Check inputs. Raise error if not correct
-         self.check_inputs(prompt, height, width, callback_steps)
-
-         # 2. Define call parameters
-         batch_size = 1 if isinstance(prompt, str) else len(prompt)
-         device = self._execution_device
-         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-         # corresponds to doing no classifier free guidance.
-         do_classifier_free_guidance = guidance_scale > 1.0
-
-         # 3. Encode input prompt
-         text_embeddings = self._encode_prompt(
-             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
-         )
-
-         # 4. Prepare timesteps
-         self.scheduler.set_timesteps(num_inference_steps)
-         timesteps = self.scheduler.timesteps
-
-         # 5. Prepare latent variables
-         num_channels_latents = self.unet.in_channels
-         latents = self.prepare_latents(
-             batch_size * num_images_per_prompt,
-             num_channels_latents,
-             height,
-             width,
-             text_embeddings.dtype,
-             device,
-             generator,
-             latents,
-         )
-
-         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-         # 7. Denoising loop
-         num_warmup_steps = len(timesteps) - \
-             num_inference_steps * self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat(
-                     [latents] * 2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(
-                     latent_model_input, t)
-
-                 noise_pred = self.unet(
-                     latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * \
-                         (noise_pred_text - noise_pred_uncond)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(
-                     noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 # call the callback, if provided
-                 if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
-                     progress_bar.update()
-                     if callback is not None and i % callback_steps == 0:
-                         callback(i, t, latents)
-
-         if not output_type == "latent":
-             image = self.vae.decode(
-                 latents / self.vae.config.scaling_factor, return_dict=False)[0]
-             image, has_nsfw_concept = self.run_safety_checker(
-                 image, device, text_embeddings.dtype)
-         else:
-             image = latents
-             has_nsfw_concept = None
-
-         if has_nsfw_concept is None:
-             do_denormalize = [True] * image.shape[0]
-         else:
-             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
-         image = self.image_processor.postprocess(
-             image, output_type=output_type, do_denormalize=do_denormalize)
-
-         # Offload last model to CPU
-         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-             self.final_offload_hook.offload()
-
-         if not return_dict:
-             return (image, has_nsfw_concept)
-
-         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
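
For reference, a minimal usage sketch of the pipeline class this commit removes. It is not part of the commit; the checkpoint id, prompts, and device below are illustrative assumptions, and the onediff/compel dependencies imported at the top of the deleted file must be installed for the class to import at all.

import torch
from diffusers import StableDiffusionPipeline

# Any SD 1.x-style checkpoint with the standard components should work; this id is only an example.
base = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Rebuild the custom pipeline from the base pipeline's components.
pipe = CustomStableDiffusionPipeline4_1(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
    safety_checker=base.safety_checker,
    feature_extractor=base.feature_extractor,
).to("cuda")

# inferV4 expects list-valued prompts (it indexes prompt[0] internally) and returns
# a StableDiffusionPipelineOutput when return_dict is left at its default.
result = pipe.inferV4(
    prompt=["a lighthouse at dusk, oil painting"],
    negative_prompt=["blurry, low quality"],
    num_inference_steps=30,
    guidance_scale=7.5,
    seed=1234,
)
result.images[0].save("out.png")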