Linoy Tsaban committed on
Commit b7b2a49 · 1 Parent(s): 39f441b

Delete modified_pipeline_semantic_stable_diffusion.py

modified_pipeline_semantic_stable_diffusion.py DELETED
@@ -1,1312 +0,0 @@
1
-
2
- import inspect
3
- import warnings
4
- from itertools import repeat
5
- from typing import Callable, List, Optional, Union
6
-
7
- import torch
8
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
9
-
10
- from diffusers.image_processor import VaeImageProcessor
11
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
- from diffusers.models.attention_processor import AttnProcessor, Attention
13
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
- from diffusers.schedulers import DDIMScheduler, KarrasDiffusionSchedulers
15
- from diffusers.utils import logging
16
- from diffusers.utils.torch_utils import randn_tensor
17
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
18
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
19
- # from . import SemanticStableDiffusionPipelineOutput
20
-
21
- import numpy as np
22
- from PIL import Image
23
- from tqdm import tqdm
24
- import torch.nn.functional as F
25
- import math
26
- from collections.abc import Iterable
27
-
28
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
-
30
- class AttentionStore():
31
- @staticmethod
32
- def get_empty_store():
33
- return {"down_cross": [], "mid_cross": [], "up_cross": [],
34
- "down_self": [], "mid_self": [], "up_self": []}
35
-
36
- def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP):
37
- # attn.shape = (batch_size * head_size, seq_len_query, seq_len_key)
38
- bs = 2 + int(PnP) + editing_prompts
39
- source_batch_size = int(attn.shape[0] // bs)
40
- skip = 2 if PnP else 1 # skip PnP & unconditional
41
- self.forward(
42
- attn[skip*source_batch_size:],
43
- is_cross,
44
- place_in_unet)
45
-
46
- def forward(self, attn, is_cross: bool, place_in_unet: str):
47
- key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
48
- if attn.shape[1] <= 32 ** 2: # avoid memory overhead
49
- self.step_store[key].append(attn)
50
-
51
- def between_steps(self, store_step=True):
52
- if store_step:
53
- if self.average:
54
- if len(self.attention_store) == 0:
55
- self.attention_store = self.step_store
56
- else:
57
- for key in self.attention_store:
58
- for i in range(len(self.attention_store[key])):
59
- self.attention_store[key][i] += self.step_store[key][i]
60
- else:
61
- if len(self.attention_store) == 0:
62
- self.attention_store = [self.step_store]
63
- else:
64
- self.attention_store.append(self.step_store)
65
-
66
- self.cur_step += 1
67
- self.step_store = self.get_empty_store()
68
-
69
- def get_attention(self, step: int):
70
- if self.average:
71
- attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
72
- else:
73
- assert(step is not None)
74
- attention = self.attention_store[step]
75
- return attention
76
-
77
- def aggregate_attention(self, attention_maps, prompts, res: int,
78
- from_where: List[str], is_cross: bool, select: int
79
- ):
80
- out = []
81
- num_pixels = res ** 2
82
- for location in from_where:
83
- for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
84
- if item.shape[1] == num_pixels:
85
- cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
86
- out.append(cross_maps)
87
- out = torch.cat(out, dim=0)
88
- # average over heads
89
- out = out.sum(0) / out.shape[0]
90
- return out
91
-
92
- def __init__(self, average: bool):
93
- self.step_store = self.get_empty_store()
94
- self.attention_store = []
95
- self.cur_step = 0
96
- self.average = average
97
-
98
- class CrossAttnProcessor:
99
-
100
- def __init__(self, attention_store, place_in_unet, PnP, editing_prompts):
101
- self.attnstore = attention_store
102
- self.place_in_unet = place_in_unet
103
- self.editing_prompts = editing_prompts
104
- self.PnP = PnP
105
-
106
- def __call__(
107
- self,
108
- attn: Attention,
109
- hidden_states,
110
- encoder_hidden_states=None,
111
- attention_mask=None,
112
- temb=None,
113
- ):
114
- assert(not attn.residual_connection)
115
- assert(attn.spatial_norm is None)
116
- assert(attn.group_norm is None)
117
- assert(hidden_states.ndim != 4)
118
- assert(encoder_hidden_states is not None) # is cross
119
-
120
- batch_size, sequence_length, _ = (
121
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
122
- )
123
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
124
-
125
- query = attn.to_q(hidden_states)
126
-
127
- if encoder_hidden_states is None:
128
- encoder_hidden_states = hidden_states
129
- elif attn.norm_cross:
130
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
131
-
132
- key = attn.to_k(encoder_hidden_states)
133
- value = attn.to_v(encoder_hidden_states)
134
-
135
- query = attn.head_to_batch_dim(query)
136
- key = attn.head_to_batch_dim(key)
137
- value = attn.head_to_batch_dim(value)
138
-
139
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
140
- self.attnstore(attention_probs,
141
- is_cross=True,
142
- place_in_unet=self.place_in_unet,
143
- editing_prompts=self.editing_prompts,
144
- PnP=self.PnP)
145
-
146
- hidden_states = torch.bmm(attention_probs, value)
147
- hidden_states = attn.batch_to_head_dim(hidden_states)
148
-
149
- # linear proj
150
- hidden_states = attn.to_out[0](hidden_states)
151
- # dropout
152
- hidden_states = attn.to_out[1](hidden_states)
153
-
154
- hidden_states = hidden_states / attn.rescale_output_factor
155
- return hidden_states
156
-
157
- # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing
158
- class GaussianSmoothing():
159
-
160
- def __init__(self, device):
161
- kernel_size = [3, 3]
162
- sigma = [0.5, 0.5]
163
-
164
- # The gaussian kernel is the product of the gaussian function of each dimension.
165
- kernel = 1
166
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
167
- for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
168
- mean = (size - 1) / 2
169
- kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
170
-
171
- # Make sure sum of values in gaussian kernel equals 1.
172
- kernel = kernel / torch.sum(kernel)
173
-
174
- # Reshape to depthwise convolutional weight
175
- kernel = kernel.view(1, 1, *kernel.size())
176
- kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
177
-
178
- self.weight = kernel.to(device)
179
-
180
- def __call__(self, input):
181
- """
182
- Apply gaussian filter to input.
183
- Arguments:
184
- input (torch.Tensor): Input to apply gaussian filter on.
185
- Returns:
186
- filtered (torch.Tensor): Filtered output.
187
- """
188
- return F.conv2d(input, weight=self.weight.to(input.dtype))
189
-
190
- def load_512(image_path, size, left=0, right=0, top=0, bottom=0, device=None, dtype=None):
191
- def pre_process(im, size, left=0, right=0, top=0, bottom=0):
192
- if type(im) is str:
193
- image = np.array(Image.open(im).convert('RGB'))[:, :, :3]
194
- elif isinstance(im, Image.Image):
195
- image = np.array((im).convert('RGB'))[:, :, :3]
196
- else:
197
- image = im
198
- h, w, c = image.shape
199
- left = min(left, w - 1)
200
- right = min(right, w - left - 1)
201
- top = min(top, h - left - 1)
202
- bottom = min(bottom, h - top - 1)
203
- image = image[top:h - bottom, left:w - right]
204
- h, w, c = image.shape
205
- if h < w:
206
- offset = (w - h) // 2
207
- image = image[:, offset:offset + h]
208
- elif w < h:
209
- offset = (h - w) // 2
210
- image = image[offset:offset + w]
211
- image = np.array(Image.fromarray(image).resize((size, size)))
212
- image = torch.from_numpy(image).float().permute(2, 0, 1)
213
- return image
214
-
215
- tmps = []
216
- if isinstance(image_path, list):
217
- for item in image_path:
218
- tmps.append(pre_process(item, size, left, right, top, bottom))
219
- else:
220
- tmps.append(pre_process(image_path, size, left, right, top, bottom))
221
- image = torch.stack(tmps) / 127.5 - 1
222
-
223
- image = image.to(device=device, dtype=dtype)
224
- return image
225
-
226
- # Resets the DPM-Solver++ multistep state (cached model outputs and lower-order counter) before a new sampling run
227
- def reset_dpm(scheduler):
228
- if isinstance(scheduler, DPMSolverMultistepSchedulerInject):
229
- scheduler.model_outputs = [
230
- None,
231
- ] * scheduler.config.solver_order
232
- scheduler.lower_order_nums = 0
233
-
234
- class SemanticStableDiffusionPipeline(DiffusionPipeline):
235
- r"""
236
- Pipeline for text-to-image generation with latent editing.
237
-
238
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
239
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
240
-
241
- This model builds on the implementation of [`StableDiffusionPipeline`].
242
-
243
- Args:
244
- vae ([`AutoencoderKL`]):
245
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
246
- text_encoder ([`CLIPTextModel`]):
247
- Frozen text-encoder. Stable Diffusion uses the text portion of
248
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
249
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
250
- tokenizer (`CLIPTokenizer`):
251
- Tokenizer of class
252
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
253
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
254
- scheduler ([`SchedulerMixin`]):
255
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
256
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
257
- safety_checker ([`Q16SafetyChecker`]):
258
- Classification module that estimates whether generated images could be considered offensive or harmful.
259
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
260
- feature_extractor ([`CLIPImageProcessor`]):
261
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
262
- """
263
-
264
- _optional_components = ["safety_checker", "feature_extractor"]
265
-
266
- def __init__(
267
- self,
268
- vae: AutoencoderKL,
269
- text_encoder: CLIPTextModel,
270
- tokenizer: CLIPTokenizer,
271
- unet: UNet2DConditionModel,
272
- scheduler: KarrasDiffusionSchedulers,
273
- safety_checker: StableDiffusionSafetyChecker,
274
- feature_extractor: CLIPImageProcessor,
275
- requires_safety_checker: bool = True,
276
- ):
277
- super().__init__()
278
-
279
- if safety_checker is None and requires_safety_checker:
280
- logger.warning(
281
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
282
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
283
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
284
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
285
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
286
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
287
- )
288
-
289
- if safety_checker is not None and feature_extractor is None:
290
- raise ValueError(
291
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
292
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
293
- )
294
-
295
- self.register_modules(
296
- vae=vae,
297
- text_encoder=text_encoder,
298
- tokenizer=tokenizer,
299
- unet=unet,
300
- scheduler=scheduler,
301
- safety_checker=safety_checker,
302
- feature_extractor=feature_extractor,
303
- )
304
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
305
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
306
- self.register_to_config(requires_safety_checker=requires_safety_checker)
307
-
308
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
309
- def run_safety_checker(self, image, device, dtype):
310
- if self.safety_checker is None:
311
- has_nsfw_concept = None
312
- else:
313
- if torch.is_tensor(image):
314
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
315
- else:
316
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
317
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
318
- image, has_nsfw_concept = self.safety_checker(
319
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
320
- )
321
- return image, has_nsfw_concept
322
-
323
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
324
- def decode_latents(self, latents):
325
- warnings.warn(
326
- "The decode_latents method is deprecated and will be removed in a future version. Please"
327
- " use VaeImageProcessor instead",
328
- FutureWarning,
329
- )
330
- latents = 1 / self.vae.config.scaling_factor * latents
331
- image = self.vae.decode(latents, return_dict=False)[0]
332
- image = (image / 2 + 0.5).clamp(0, 1)
333
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
334
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
335
- return image
336
-
337
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
338
- def prepare_extra_step_kwargs(self, generator, eta):
339
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
340
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
341
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
342
- # and should be between [0, 1]
343
-
344
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
345
- extra_step_kwargs = {}
346
- if accepts_eta:
347
- extra_step_kwargs["eta"] = eta
348
-
349
- # check if the scheduler accepts generator
350
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
351
- if accepts_generator:
352
- extra_step_kwargs["generator"] = generator
353
- return extra_step_kwargs
354
-
355
-
356
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
357
- def check_inputs(
358
- self,
359
- prompt,
360
- height,
361
- width,
362
- callback_steps,
363
- negative_prompt=None,
364
- prompt_embeds=None,
365
- negative_prompt_embeds=None,
366
- ):
367
- if height % 8 != 0 or width % 8 != 0:
368
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
369
-
370
- if (callback_steps is None) or (
371
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
372
- ):
373
- raise ValueError(
374
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
375
- f" {type(callback_steps)}."
376
- )
377
-
378
- if prompt is not None and prompt_embeds is not None:
379
- raise ValueError(
380
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
381
- " only forward one of the two."
382
- )
383
- elif prompt is None and prompt_embeds is None:
384
- raise ValueError(
385
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
386
- )
387
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
388
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
389
-
390
- if negative_prompt is not None and negative_prompt_embeds is not None:
391
- raise ValueError(
392
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
393
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
394
- )
395
-
396
- if prompt_embeds is not None and negative_prompt_embeds is not None:
397
- if prompt_embeds.shape != negative_prompt_embeds.shape:
398
- raise ValueError(
399
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
400
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
401
- f" {negative_prompt_embeds.shape}."
402
- )
403
-
404
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
405
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
406
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
407
- if isinstance(generator, list) and len(generator) != batch_size:
408
- raise ValueError(
409
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
410
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
411
- )
412
-
413
- if latents is None:
414
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
415
- else:
416
- latents = latents.to(device)
417
-
418
- # scale the initial noise by the standard deviation required by the scheduler
419
- latents = latents * self.scheduler.init_noise_sigma
420
- return latents
421
-
422
- def prepare_unet(self, attention_store, PnP: bool):
423
- attn_procs = {}
424
- for name in self.unet.attn_processors.keys():
425
- if name.startswith("mid_block"):
426
- place_in_unet = "mid"
427
- elif name.startswith("up_blocks"):
428
- place_in_unet = "up"
429
- elif name.startswith("down_blocks"):
430
- place_in_unet = "down"
431
- else:
432
- continue
433
-
434
- if "attn2" in name:
435
- attn_procs[name] = CrossAttnProcessor(
436
- attention_store=attention_store,
437
- place_in_unet=place_in_unet,
438
- PnP=PnP,
439
- editing_prompts=self.enabled_editing_prompts)
440
- else:
441
- attn_procs[name] = AttnProcessor()
442
-
443
- self.unet.set_attn_processor(attn_procs)
444
-
445
- @torch.no_grad()
446
- def __call__(
447
- self,
448
- prompt: Union[str, List[str]],
449
- height: Optional[int] = None,
450
- width: Optional[int] = None,
451
- num_inference_steps: int = 50,
452
- guidance_scale: float = 7.5,
453
- negative_prompt: Optional[Union[str, List[str]]] = None,
454
- num_images_per_prompt: int = 1,
455
- eta: float = 0.0,
456
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
457
- latents: Optional[torch.FloatTensor] = None,
458
- output_type: Optional[str] = "pil",
459
- return_dict: bool = True,
460
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
461
- callback_steps: int = 1,
462
- editing_prompt: Optional[Union[str, List[str]]] = None,
463
- editing_prompt_embeddings: Optional[torch.Tensor] = None,
464
- reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
465
- edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
466
- edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
467
- edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
468
- edit_threshold: Optional[Union[float, List[float]]] = 0.9,
469
- edit_momentum_scale: Optional[float] = 0.1,
470
- edit_mom_beta: Optional[float] = 0.4,
471
- edit_weights: Optional[List[float]] = None,
472
- sem_guidance: Optional[List[torch.Tensor]] = None,
473
-
474
- # masking
475
- use_cross_attn_mask: bool = False,
476
- use_intersect_mask: bool = True,
477
- edit_tokens_for_attn_map: List[str] = None,
478
-
479
- # Attention store (just for visualization purposes)
480
- attn_store_steps: Optional[List[int]] = [],
481
- store_averaged_over_steps: bool = True,
482
-
483
- # DDPM additions
484
- use_ddpm: bool = False,
485
- wts: Optional[List[torch.Tensor]] = None,
486
- zs: Optional[List[torch.Tensor]] = None
487
- ):
488
- r"""
489
- Function invoked when calling the pipeline for generation.
490
-
491
- Args:
492
- prompt (`str` or `List[str]`):
493
- The prompt or prompts to guide the image generation.
494
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
495
- The height in pixels of the generated image.
496
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
497
- The width in pixels of the generated image.
498
- num_inference_steps (`int`, *optional*, defaults to 50):
499
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
500
- expense of slower inference.
501
- guidance_scale (`float`, *optional*, defaults to 7.5):
502
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
503
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
504
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
505
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
506
- usually at the expense of lower image quality.
507
- negative_prompt (`str` or `List[str]`, *optional*):
508
- The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
509
- if `guidance_scale` is less than `1`).
510
- num_images_per_prompt (`int`, *optional*, defaults to 1):
511
- The number of images to generate per prompt.
512
- eta (`float`, *optional*, defaults to 0.0):
513
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
514
- [`schedulers.DDIMScheduler`], will be ignored for others.
515
- generator (`torch.Generator`, *optional*):
516
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
517
- to make generation deterministic.
518
- latents (`torch.FloatTensor`, *optional*):
519
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
520
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
521
- tensor will be generated by sampling using the supplied random `generator`.
522
- output_type (`str`, *optional*, defaults to `"pil"`):
523
- The output format of the generated image. Choose between
524
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
525
- return_dict (`bool`, *optional*, defaults to `True`):
526
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
527
- plain tuple.
528
- callback (`Callable`, *optional*):
529
- A function that will be called every `callback_steps` steps during inference. The function will be
530
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
531
- callback_steps (`int`, *optional*, defaults to 1):
532
- The frequency at which the `callback` function will be called. If not specified, the callback will be
533
- called at every step.
534
- editing_prompt (`str` or `List[str]`, *optional*):
535
- The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting
536
- `editing_prompt = None`. Guidance direction of prompt should be specified via
537
- `reverse_editing_direction`.
538
- editing_prompt_embeddings (`torch.Tensor`, *optional*):
539
- Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be
540
- specified via `reverse_editing_direction`.
541
- reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
542
- Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
543
- edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
544
- Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`.
545
- `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA
546
- Paper](https://arxiv.org/pdf/2301.12247.pdf).
547
- edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10):
548
- Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum
549
- will still be calculated for those steps and applied once all warmup periods are over.
550
- `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf).
551
- edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to `None`):
552
- Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied.
553
- edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
554
- Threshold of semantic guidance.
555
- edit_momentum_scale (`float`, *optional*, defaults to 0.1):
556
- Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0
557
- momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
558
- than `edit_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are
559
- finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA
560
- Paper](https://arxiv.org/pdf/2301.12247.pdf).
561
- edit_mom_beta (`float`, *optional*, defaults to 0.4):
562
- Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous
563
- momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
564
- than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA
565
- Paper](https://arxiv.org/pdf/2301.12247.pdf).
566
- edit_weights (`List[float]`, *optional*, defaults to `None`):
567
- Indicates how much each individual concept should influence the overall guidance. If no weights are
568
- provided all concepts are applied equally. `edit_weights` is defined as `g_i` of equation 9 of [SEGA
569
- Paper](https://arxiv.org/pdf/2301.12247.pdf).
570
- sem_guidance (`List[torch.Tensor]`, *optional*):
571
- List of pre-generated guidance vectors to be applied at generation. Length of the list has to
572
- correspond to `num_inference_steps`.
573
-
574
- Returns:
575
- [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`:
576
- [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True,
577
- otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the
578
- second element is a list of `bool`s denoting whether the corresponding generated image likely represents
579
- "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
580
- """
581
- if use_intersect_mask:
582
- use_cross_attn_mask = True
583
-
584
- if use_cross_attn_mask:
585
- self.smoothing = GaussianSmoothing(self.device)
586
-
587
- # 0. Default height and width to unet
588
- height = height or self.unet.config.sample_size * self.vae_scale_factor
589
- width = width or self.unet.config.sample_size * self.vae_scale_factor
590
-
591
- # 1. Check inputs. Raise error if not correct
592
- self.check_inputs(prompt, height, width, callback_steps)
593
-
594
- if use_ddpm:
595
- reset_dpm(self.scheduler)
596
-
597
- # 2. Define call parameters
598
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
599
-
600
- if editing_prompt:
601
- enable_edit_guidance = True
602
- if isinstance(editing_prompt, str):
603
- editing_prompt = [editing_prompt]
604
- self.enabled_editing_prompts = len(editing_prompt)
605
- elif editing_prompt_embeddings is not None:
606
- enable_edit_guidance = True
607
- self.enabled_editing_prompts = editing_prompt_embeddings.shape[0]
608
- else:
609
- self.enabled_editing_prompts = 0
610
- enable_edit_guidance = False
611
-
612
- # get prompt text embeddings
613
- text_inputs = self.tokenizer(
614
- prompt,
615
- padding="max_length",
616
- max_length=self.tokenizer.model_max_length,
617
- truncation=True,
618
- return_tensors="pt",
619
- )
620
- text_input_ids = text_inputs.input_ids
621
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
622
-
623
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
624
- text_input_ids, untruncated_ids
625
- ):
626
- removed_text = self.tokenizer.batch_decode(
627
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
628
- )
629
- logger.warning(
630
- "The following part of your input was truncated because CLIP can only handle sequences up to"
631
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
632
- )
633
-
634
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
635
-
636
- # duplicate text embeddings for each generation per prompt, using mps friendly method
637
- bs_embed, seq_len, _ = text_embeddings.shape
638
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
639
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
640
-
641
- if enable_edit_guidance:
642
- # get edit concept text embeddings
643
- if editing_prompt_embeddings is None:
644
- if edit_tokens_for_attn_map is not None:
645
- edit_tokens = [[word.replace("</w>", "") for word in self.tokenizer.tokenize(item)] for item in editing_prompt]
646
- #print(f"edit_tokens: {edit_tokens}")
647
-
648
- edit_concepts_input = self.tokenizer(
649
- [x for item in editing_prompt for x in repeat(item, batch_size)],
650
- padding="max_length",
651
- max_length=self.tokenizer.model_max_length,
652
- truncation=True,
653
- return_tensors="pt",
654
- return_length=True
655
- )
656
-
657
- num_edit_tokens = edit_concepts_input.length - 2  # not counting startoftext and endoftext
658
- edit_concepts_input_ids = edit_concepts_input.input_ids
659
- untruncated_ids = self.tokenizer(
660
- [x for item in editing_prompt for x in repeat(item, batch_size)],
661
- padding="longest",
662
- return_tensors="pt").input_ids
663
-
664
- if untruncated_ids.shape[-1] >= edit_concepts_input_ids.shape[-1] and not torch.equal(
665
- edit_concepts_input_ids, untruncated_ids
666
- ):
667
- removed_text = self.tokenizer.batch_decode(
668
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
669
- )
670
- logger.warning(
671
- "The following part of your input was truncated because CLIP can only handle sequences up to"
672
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
673
- )
674
-
675
- edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
676
- else:
677
- edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
678
-
679
- # duplicate text embeddings for each generation per prompt, using mps friendly method
680
- bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
681
- edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1)
682
- edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1)
683
-
684
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
685
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
686
- # corresponds to doing no classifier free guidance.
687
- do_classifier_free_guidance = guidance_scale > 1.0
688
- # get unconditional embeddings for classifier free guidance
689
-
690
- if do_classifier_free_guidance:
691
- uncond_tokens: List[str]
692
- if negative_prompt is None:
693
- uncond_tokens = [""]
694
- elif type(prompt) is not type(negative_prompt):
695
- raise TypeError(
696
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
697
- f" {type(prompt)}."
698
- )
699
- elif isinstance(negative_prompt, str):
700
- uncond_tokens = [negative_prompt]
701
- elif batch_size != len(negative_prompt):
702
- raise ValueError(
703
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
704
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
705
- " the batch size of `prompt`."
706
- )
707
- else:
708
- uncond_tokens = negative_prompt
709
-
710
- max_length = text_input_ids.shape[-1]
711
- uncond_input = self.tokenizer(
712
- uncond_tokens,
713
- padding="max_length",
714
- max_length=max_length,
715
- truncation=True,
716
- return_tensors="pt",
717
- )
718
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
719
-
720
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
721
- seq_len = uncond_embeddings.shape[1]
722
- uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
723
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
724
-
725
- # For classifier free guidance, we need to do two forward passes.
726
- # Here we concatenate the unconditional and text embeddings into a single batch
727
- # to avoid doing two forward passes
728
- self.text_cross_attention_maps = [prompt] if isinstance(prompt, str) else prompt
729
- if enable_edit_guidance:
730
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
731
- self.text_cross_attention_maps += \
732
- ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
733
- else:
734
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
735
- # get the initial random noise unless the user supplied it
736
-
737
- # 4. Prepare timesteps
738
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
739
- timesteps = self.scheduler.timesteps
740
- if use_ddpm:
741
- t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
742
- timesteps = timesteps[-zs.shape[0]:]
743
-
744
- self.attention_store = AttentionStore(average=store_averaged_over_steps)
745
- self.prepare_unet(self.attention_store, False)
746
-
747
- # 5. Prepare latent variables
748
- num_channels_latents = self.unet.config.in_channels
749
- latents = self.prepare_latents(
750
- batch_size * num_images_per_prompt,
751
- num_channels_latents,
752
- height,
753
- width,
754
- text_embeddings.dtype,
755
- self.device,
756
- generator,
757
- latents,
758
- )
759
-
760
- # 6. Prepare extra step kwargs.
761
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
762
-
763
- # Initialize edit_momentum to None
764
- edit_momentum = None
765
-
766
- self.uncond_estimates = None
767
- self.text_estimates = None
768
- self.edit_estimates = None
769
- self.sem_guidance = None
770
-
771
- for i, t in enumerate(self.progress_bar(timesteps)):
772
- # expand the latents if we are doing classifier free guidance
773
- latent_model_input = (
774
- torch.cat([latents] * (2 + self.enabled_editing_prompts)) if do_classifier_free_guidance else latents
775
- )
776
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
777
-
778
- # predict the noise residual
779
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
780
-
781
- # perform guidance
782
- if do_classifier_free_guidance:
783
- noise_pred_out = noise_pred.chunk(2 + self.enabled_editing_prompts) # [b,4, 64, 64]
784
- noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
785
- noise_pred_edit_concepts = noise_pred_out[2:]
786
-
787
- # default text guidance
788
- noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
789
- # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0])
790
-
791
- if self.uncond_estimates is None:
792
- self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape))
793
- self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
794
-
795
- if self.text_estimates is None:
796
- self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
797
- self.text_estimates[i] = noise_pred_text.detach().cpu()
798
-
799
- if self.edit_estimates is None and enable_edit_guidance:
800
- self.edit_estimates = torch.zeros(
801
- (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
802
- )
803
-
804
- if self.sem_guidance is None:
805
- self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
806
-
807
- if edit_momentum is None:
808
- edit_momentum = torch.zeros_like(noise_guidance)
809
-
810
- if enable_edit_guidance:
811
- concept_weights = torch.zeros(
812
- (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
813
- device=self.device,
814
- dtype=noise_guidance.dtype,
815
- )
816
- noise_guidance_edit = torch.zeros(
817
- (len(noise_pred_edit_concepts), *noise_guidance.shape),
818
- device=self.device,
819
- dtype=noise_guidance.dtype,
820
- )
821
- # noise_guidance_edit = torch.zeros_like(noise_guidance)
822
- warmup_inds = []
823
- for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
824
- self.edit_estimates[i, c] = noise_pred_edit_concept
825
- if isinstance(edit_guidance_scale, list):
826
- edit_guidance_scale_c = edit_guidance_scale[c]
827
- else:
828
- edit_guidance_scale_c = edit_guidance_scale
829
-
830
- if isinstance(edit_threshold, list):
831
- edit_threshold_c = edit_threshold[c]
832
- else:
833
- edit_threshold_c = edit_threshold
834
- if isinstance(reverse_editing_direction, list):
835
- reverse_editing_direction_c = reverse_editing_direction[c]
836
- else:
837
- reverse_editing_direction_c = reverse_editing_direction
838
- if edit_weights:
839
- edit_weight_c = edit_weights[c]
840
- else:
841
- edit_weight_c = 1.0
842
- if isinstance(edit_warmup_steps, list):
843
- edit_warmup_steps_c = edit_warmup_steps[c]
844
- else:
845
- edit_warmup_steps_c = edit_warmup_steps
846
-
847
- if isinstance(edit_cooldown_steps, list):
848
- edit_cooldown_steps_c = edit_cooldown_steps[c]
849
- elif edit_cooldown_steps is None:
850
- edit_cooldown_steps_c = i + 1
851
- else:
852
- edit_cooldown_steps_c = edit_cooldown_steps
853
- if i >= edit_warmup_steps_c:
854
- warmup_inds.append(c)
855
- if i >= edit_cooldown_steps_c:
856
- noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
857
- continue
858
-
859
- noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
860
- # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
861
- tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
862
-
863
- tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
864
- if reverse_editing_direction_c:
865
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
866
- concept_weights[c, :] = tmp_weights
867
-
868
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
869
-
870
- if use_cross_attn_mask:
871
- out = self.attention_store.aggregate_attention(
872
- attention_maps=self.attention_store.step_store,
873
- prompts=self.text_cross_attention_maps,
874
- res=16,
875
- from_where=["up","down"],
876
- is_cross=True,
877
- select=self.text_cross_attention_maps.index(editing_prompt[c]),
878
- )
879
-
880
- attn_map = out[:, :, 1:] # 0 -> startoftext
881
- attn_map *= 100
882
- attn_map = torch.nn.functional.softmax(attn_map, dim=-1)
883
- attn_map = attn_map[:,:,:num_edit_tokens[c]] # -1 -> endoftext
884
-
885
- assert(attn_map.shape[2]==num_edit_tokens[c])
886
- if edit_tokens_for_attn_map is not None:
887
- # select attn_map for specified tokens
888
- token_idx = [edit_tokens[c].index(item) for item in edit_tokens_for_attn_map[c]]
889
- attn_map = attn_map[:,:,token_idx]
890
- assert(attn_map.shape[2] == len(edit_tokens_for_attn_map[c]))
891
-
892
- # average over tokens
893
- attn_map = torch.sum(attn_map, dim=2)
894
-
895
- # gaussian_smoothing
896
- attn_map = F.pad(attn_map.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
897
- attn_map = self.smoothing(attn_map).squeeze(0).squeeze(0)
898
-
899
- # torch.quantile function expects float32
900
- if attn_map.dtype == torch.float32:
901
- tmp = torch.quantile(
902
- attn_map.flatten(),
903
- edit_threshold_c
904
- )
905
- else:
906
- tmp = torch.quantile(
907
- attn_map.flatten().to(torch.float32),
908
- edit_threshold_c
909
- ).to(attn_map.dtype)
910
-
911
- attn_mask = torch.where(attn_map >= tmp, 1.0, 0.0)
912
-
913
- # resolution must match latent space dimension
914
- attn_mask = F.interpolate(
915
- attn_mask.unsqueeze(0).unsqueeze(0),
916
- noise_guidance_edit_tmp.shape[-2:] # 64,64
917
- )[0,0,:,:]
918
-
919
- if not use_intersect_mask:
920
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
921
-
922
- if use_intersect_mask:
923
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
924
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True)
925
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,4,1,1)
926
-
927
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
928
- tmp = torch.quantile(
929
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
930
- edit_threshold_c,
931
- dim=2,
932
- keepdim=False,
933
- )
934
- else:
935
- tmp = torch.quantile(
936
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
937
- edit_threshold_c,
938
- dim=2,
939
- keepdim=False,
940
- ).to(noise_guidance_edit_tmp_quantile.dtype)
941
-
942
- sega_mask = torch.where(
943
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
944
- torch.ones_like(noise_guidance_edit_tmp),
945
- torch.zeros_like(noise_guidance_edit_tmp),
946
- )
947
-
948
- intersect_mask = sega_mask * attn_mask
949
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
950
-
951
- elif not use_cross_attn_mask:
952
- # calculate quantile
953
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
954
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1, keepdim=True)
955
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1,4,1,1)
956
-
957
- # torch.quantile function expects float32
958
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
959
- tmp = torch.quantile(
960
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
961
- edit_threshold_c,
962
- dim=2,
963
- keepdim=False,
964
- )
965
- else:
966
- tmp = torch.quantile(
967
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
968
- edit_threshold_c,
969
- dim=2,
970
- keepdim=False,
971
- ).to(noise_guidance_edit_tmp_quantile.dtype)
972
-
973
- noise_guidance_edit_tmp = torch.where(
974
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
975
- noise_guidance_edit_tmp,
976
- torch.zeros_like(noise_guidance_edit_tmp),
977
- )
978
-
979
- noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
980
-
981
- # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
982
-
983
- warmup_inds = torch.tensor(warmup_inds).to(self.device)
984
- if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
985
- concept_weights = concept_weights.to("cpu") # Offload to cpu
986
- noise_guidance_edit = noise_guidance_edit.to("cpu")
987
-
988
- concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
989
- concept_weights_tmp = torch.where(
990
- concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
991
- )
992
- concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
993
- # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
994
-
995
- noise_guidance_edit_tmp = torch.index_select(
996
- noise_guidance_edit.to(self.device), 0, warmup_inds
997
- )
998
- noise_guidance_edit_tmp = torch.einsum(
999
- "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
1000
- )
1001
- noise_guidance_edit_tmp = noise_guidance_edit_tmp
1002
- noise_guidance = noise_guidance + noise_guidance_edit_tmp
1003
-
1004
- self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
1005
-
1006
- del noise_guidance_edit_tmp
1007
- del concept_weights_tmp
1008
- concept_weights = concept_weights.to(self.device)
1009
- noise_guidance_edit = noise_guidance_edit.to(self.device)
1010
-
1011
- concept_weights = torch.where(
1012
- concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
1013
- )
1014
-
1015
- concept_weights = torch.nan_to_num(concept_weights)
1016
-
1017
- noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
1018
-
1019
- noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
1020
-
1021
- edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
1022
-
1023
- if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
1024
- noise_guidance = noise_guidance + noise_guidance_edit
1025
- self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
1026
-
1027
- if sem_guidance is not None:
1028
- edit_guidance = sem_guidance[i].to(self.device)
1029
- noise_guidance = noise_guidance + edit_guidance
1030
-
1031
- noise_pred = noise_pred_uncond + noise_guidance
1032
- ## ddpm ###########################################################
1033
- if use_ddpm:
1034
- idx = t_to_idx[int(t)]
1035
- latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx],
1036
- **extra_step_kwargs).prev_sample
1037
-
1038
- ## ddpm ##########################################################
1039
- # compute the previous noisy sample x_t -> x_t-1
1040
- else: #if not use_ddpm:
1041
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1042
-
1043
- # step callback
1044
- store_step = i in attn_store_steps
1045
- if store_step:
1046
- print("storing attention")
1047
- self.attention_store.between_steps(store_step)
1048
-
1049
- # call the callback, if provided
1050
- if callback is not None and i % callback_steps == 0:
1051
- callback(i, t, latents)
1052
-
1053
- # 8. Post-processing
1054
- if not output_type == "latent":
1055
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1056
- image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
1057
- else:
1058
- image = latents
1059
- has_nsfw_concept = None
1060
-
1061
- if has_nsfw_concept is None:
1062
- do_denormalize = [True] * image.shape[0]
1063
- else:
1064
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1065
-
1066
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1067
-
1068
- if not return_dict:
1069
- return (image, has_nsfw_concept)
1070
-
1071
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1072
-
1073
- def encode_text(self, prompts):
1074
- text_inputs = self.tokenizer(
1075
- prompts,
1076
- padding="max_length",
1077
- max_length=self.tokenizer.model_max_length,
1078
- return_tensors="pt",
1079
- )
1080
- text_input_ids = text_inputs.input_ids
1081
-
1082
- if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
1083
- removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
1084
- logger.warning(
1085
- "The following part of your input was truncated because CLIP can only handle sequences up to"
1086
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
1087
- )
1088
- text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
1089
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
1090
-
1091
- return text_embeddings
1092
-
1093
- @torch.no_grad()
1094
- def invert(self,
1095
- image_path: str,
1096
- source_prompt: str = "",
1097
- source_guidance_scale=3.5,
1098
- num_inversion_steps: int = 30,
1099
- skip: float = 0.15,
1100
- eta: float = 1.0,
1101
- generator: Optional[torch.Generator] = None,
1102
- verbose=True,
1103
- ):
1104
- """
1105
- Inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
1106
- based on the code in https://github.com/inbarhub/DDPM_inversion
1107
-
1108
- returns:
1109
- zs - noise maps
1110
- xts - intermediate inverted latents
1111
- """
1112
-
1113
- # self.eta = eta
1114
- # assert (self.eta > 0)
1115
-
1116
- train_steps = self.scheduler.config.num_train_timesteps
1117
- timesteps = torch.from_numpy(
1118
- np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device)
1119
-
1120
-
1121
- num_inversion_steps = timesteps.shape[0]
1122
- self.scheduler.num_inference_steps = timesteps.shape[0]
1123
- self.scheduler.timesteps = timesteps
1124
-
1125
-
1126
- # 1. get embeddings
1127
-
1128
- uncond_embedding = self.encode_text("")
1129
-
1130
- # 2. encode image
1131
- x0 = self.encode_image(image_path, dtype=uncond_embedding.dtype)
1132
- batch_size = x0.shape[0]
1133
-
1134
- if not source_prompt == "":
1135
- text_embeddings = self.encode_text(source_prompt).repeat((batch_size, 1, 1))
1136
- uncond_embedding = uncond_embedding.repeat((batch_size, 1, 1))
1137
- # autoencoder reconstruction
1138
- # image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0]
1139
- # image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
1140
-
1141
- # 3. find zs and xts
1142
- variance_noise_shape = (
1143
- num_inversion_steps,
1144
- batch_size,
1145
- self.unet.config.in_channels,
1146
- self.unet.sample_size,
1147
- self.unet.sample_size)
1148
-
1149
- # intermediate latents
1150
- t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1151
- xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)
1152
-
1153
- for t in reversed(timesteps):
1154
- idx = num_inversion_steps-t_to_idx[int(t)] - 1
1155
- noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
1156
- xts[idx] = self.scheduler.add_noise(x0, noise, t)
1157
- xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
1158
-
1159
- reset_dpm(self.scheduler)
1160
- # noise maps
1161
- zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)
1162
-
1163
- for t in tqdm(timesteps, disable=not verbose):  # DiffusionPipeline.progress_bar has no `verbose` kwarg
1164
-
1165
- idx = num_inversion_steps-t_to_idx[int(t)]-1
1166
- # 1. predict noise residual
1167
- xt = xts[idx+1]
1168
-
1169
- noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample
1170
-
1171
- if not source_prompt == "":
1172
- noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample
1173
- noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred)
1174
-
1175
- xtm1 = xts[idx]
1176
- z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, eta)
1177
- zs[idx] = z
1178
-
1179
- # correction to avoid error accumulation
1180
- xts[idx] = xtm1_corrected
1181
-
1182
- # TODO: I don't think that the noise map for the last step should be discarded ?!
1183
- # if not zs is None:
1184
- # zs[-1] = torch.zeros_like(zs[-1])
1185
- # self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
1186
- zs = zs.flip(0)
1187
- # self.zs = zs
1188
-
1189
-
1190
- return zs, xts
1191
- # return zs, xts, image_rec
1192
-
1193
- @torch.no_grad()
1194
- def encode_image(self, image_path, dtype=None):
1195
- image = load_512(image_path,
1196
- size=self.unet.sample_size * self.vae_scale_factor,
1197
- device=self.device,
1198
- dtype=dtype)
1199
- x0 = self.vae.encode(image).latent_dist.mode()
1200
- x0 = self.vae.config.scaling_factor * x0
1201
- return x0
1202
-
1203
- def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1204
- # 1. get previous step value (=t-1)
1205
- prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
1206
-
1207
- # 2. compute alphas, betas
1208
- alpha_prod_t = scheduler.alphas_cumprod[timestep]
1209
- alpha_prod_t_prev = (
1210
- scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
1211
- )
1212
-
1213
- beta_prod_t = 1 - alpha_prod_t
1214
-
1215
- # 3. compute predicted original sample from predicted noise also called
1216
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1217
- pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
1218
-
1219
- # 4. Clip "predicted x_0"
1220
- if scheduler.config.clip_sample:
1221
- pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
1222
-
1223
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
1224
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
1225
- variance = scheduler._get_variance(timestep, prev_timestep)
1226
- std_dev_t = eta * variance ** (0.5)
1227
-
1228
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1229
- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * noise_pred
1230
-
1231
- # modified so that the updated xtm1 is returned as well (to avoid error accumulation)
1232
- mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
1233
- noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
1234
-
1235
- return noise, mu_xt + (eta * variance ** 0.5) * noise
1236
-
1237
- # Modified from diffusers.pipelines.stable_diffusion.pipeline_cycle_diffusion.compute_noise
1238
- def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1239
-
1240
- def first_order_update(model_output, timestep, prev_timestep, sample):
1241
- lambda_t, lambda_s = scheduler.lambda_t[prev_timestep], scheduler.lambda_t[timestep]
1242
- alpha_t, alpha_s = scheduler.alpha_t[prev_timestep], scheduler.alpha_t[timestep]
1243
- sigma_t, sigma_s = scheduler.sigma_t[prev_timestep], scheduler.sigma_t[timestep]
1244
- h = lambda_t - lambda_s
1245
-
1246
- mu_xt = (
1247
- (sigma_t / sigma_s * torch.exp(-h)) * sample
1248
- + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
1249
- )
1250
- sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1251
-
1252
- noise = (prev_latents - mu_xt) / sigma
1253
-
1254
- prev_sample = mu_xt + sigma * noise
1255
-
1256
- return noise, prev_sample
1257
- def second_order_update(model_output_list, timestep_list, prev_timestep, sample):
1258
- t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
1259
- m0, m1 = model_output_list[-1], model_output_list[-2]
1260
- lambda_t, lambda_s0, lambda_s1 = scheduler.lambda_t[t], scheduler.lambda_t[s0], scheduler.lambda_t[s1]
1261
- alpha_t, alpha_s0 = scheduler.alpha_t[t], scheduler.alpha_t[s0]
1262
- sigma_t, sigma_s0 = scheduler.sigma_t[t], scheduler.sigma_t[s0]
1263
- h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
1264
- r0 = h_0 / h
1265
- D0, D1 = m0, (1.0 / r0) * (m0 - m1)
1266
-
1267
- mu_xt = (
1268
- (sigma_t / sigma_s0 * torch.exp(-h)) * sample
1269
- + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
1270
- + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
1271
- )
1272
- sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1273
-
1274
- noise = (prev_latents - mu_xt) / sigma
1275
-
1276
- prev_sample = mu_xt + sigma * noise
1277
-
1278
- return noise, prev_sample
1279
-
1280
- step_index = (scheduler.timesteps == timestep).nonzero()
1281
- if len(step_index) == 0:
1282
- step_index = len(scheduler.timesteps) - 1
1283
- else:
1284
- step_index = step_index.item()
1285
-
1286
- prev_timestep = 0 if step_index == len(scheduler.timesteps) - 1 else scheduler.timesteps[step_index + 1]
1287
-
1288
- model_output = scheduler.convert_model_output(noise_pred, timestep, latents)
1289
-
1290
- for i in range(scheduler.config.solver_order - 1):
1291
- scheduler.model_outputs[i] = scheduler.model_outputs[i + 1]
1292
- scheduler.model_outputs[-1] = model_output
1293
-
1294
- if scheduler.lower_order_nums < 1:
1295
- noise, prev_sample = first_order_update(model_output, timestep, prev_timestep, latents)
1296
- else:
1297
- timestep_list = [scheduler.timesteps[step_index - 1], timestep]
1298
- noise, prev_sample = second_order_update(scheduler.model_outputs, timestep_list, prev_timestep, latents)
1299
-
1300
- if scheduler.lower_order_nums < scheduler.config.solver_order:
1301
- scheduler.lower_order_nums += 1
1302
-
1303
- return noise, prev_sample
1304
-
1305
- def compute_noise(scheduler, *args):
1306
- if isinstance(scheduler, DDIMScheduler):
1307
- return compute_noise_ddim(scheduler, *args)
1308
- elif isinstance(scheduler, DPMSolverMultistepSchedulerInject) and scheduler.config.algorithm_type == 'sde-dpmsolver++'\
1309
- and scheduler.config.solver_order == 2:
1310
- return compute_noise_sde_dpm_pp_2nd(scheduler, *args)
1311
- else:
1312
- raise NotImplementedError
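
For reference, below is a minimal usage sketch of the pipeline this commit removes. Everything in it is an assumption rather than part of the deleted file: the import path supposes the module is restored from an earlier revision of this repository under its original filename, "runwayml/stable-diffusion-v1-5" and "input.jpg" are placeholder checkpoint and image names, the scheduler is swapped to a stock DDIMScheduler so that compute_noise dispatches to its implemented DDIM branch, and all prompts, step counts, and guidance parameters are illustrative values only.

import torch
from diffusers import DDIMScheduler
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pipeline (checkpoint id is a placeholder; any SD 1.x checkpoint should work).
pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    requires_safety_checker=False,
).to(device)

# Use a plain DDIM scheduler so `compute_noise` takes the compute_noise_ddim branch.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# 1. DDPM-style inversion of a real image (Algorithm 1 of arXiv:2304.06140):
#    returns per-step noise maps `zs` and intermediate inverted latents `xts`.
zs, xts = pipe.invert(
    "input.jpg",                       # placeholder image path
    source_prompt="a photo of a cat",  # illustrative source prompt
    num_inversion_steps=50,
    skip=0.15,
)

# 2. Edit with semantic (SEGA-style) guidance while re-injecting the inverted
#    noise maps. With `use_ddpm=True` only the last len(zs) scheduler timesteps
#    are used, so `num_inference_steps` is chosen here as roughly
#    num_inversion_steps / (1 - skip) to cover the non-skipped range.
out = pipe(
    prompt="a photo of a cat",
    latents=xts[-1],                   # start from the most-noised inverted latent
    eta=1.0,                           # eta > 0 is required for the variance-noise injection
    editing_prompt=["sunglasses"],
    reverse_editing_direction=[False],
    edit_guidance_scale=[7.0],
    edit_warmup_steps=[2],
    edit_threshold=[0.95],
    use_ddpm=True,
    wts=xts,
    zs=zs,
    num_inference_steps=58,
)
out.images[0].save("edited.png")

With eta=1.0 and use_ddpm=True, the recovered noise maps zs are passed to scheduler.step as variance_noise, so the same call without any editing_prompt should approximately reconstruct the input image; the edit concepts then steer that reconstruction.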