junnyu commited on
Commit
eecac3d
·
1 Parent(s): a4a6c30

Delete webui_stable_diffusion_controlnet.py

Browse files
Files changed (1) hide show
  1. webui_stable_diffusion_controlnet.py +0 -1837
webui_stable_diffusion_controlnet.py DELETED
@@ -1,1837 +0,0 @@
1
- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
- # Copyright 2023 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- # modified from https://github.com/AUTOMATIC1111/stable-diffusion-webui
17
- # Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt
18
- from ppdiffusers.utils import check_min_version
19
- check_min_version("0.14.1")
20
-
21
- import inspect
22
- from typing import Any, Callable, Dict, List, Optional, Union
23
-
24
- import paddle
25
- import paddle.nn as nn
26
- import PIL
27
- import PIL.Image
28
-
29
- from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
30
- from ppdiffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
31
- from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
32
- from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
- from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
34
- StableDiffusionSafetyChecker,
35
- )
36
- from ppdiffusers.schedulers import KarrasDiffusionSchedulers
37
- from ppdiffusers.utils import (
38
- PIL_INTERPOLATION,
39
- logging,
40
- randn_tensor,
41
- safetensors_load,
42
- torch_load,
43
- )
44
-
45
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
-
47
-
48
- class WebUIStableDiffusionControlNetPipeline(DiffusionPipeline):
49
- r"""
50
- Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
51
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
52
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
53
- Args:
54
- vae ([`AutoencoderKL`]):
55
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
56
- text_encoder ([`CLIPTextModel`]):
57
- Frozen text-encoder. Stable Diffusion uses the text portion of
58
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
59
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
60
- tokenizer (`CLIPTokenizer`):
61
- Tokenizer of class
62
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
63
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
64
- controlnet ([`ControlNetModel`]):
65
- Provides additional conditioning to the unet during the denoising process.
66
- scheduler ([`SchedulerMixin`]):
67
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
68
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
69
- safety_checker ([`StableDiffusionSafetyChecker`]):
70
- Classification module that estimates whether generated images could be considered offensive or harmful.
71
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
72
- feature_extractor ([`CLIPFeatureExtractor`]):
73
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
74
- """
75
- _optional_components = ["safety_checker", "feature_extractor"]
76
- enable_emphasis = True
77
- comma_padding_backtrack = 20
78
-
79
- def __init__(
80
- self,
81
- vae: AutoencoderKL,
82
- text_encoder: CLIPTextModel,
83
- tokenizer: CLIPTokenizer,
84
- unet: UNet2DConditionModel,
85
- controlnet: ControlNetModel,
86
- scheduler: KarrasDiffusionSchedulers,
87
- safety_checker: StableDiffusionSafetyChecker,
88
- feature_extractor: CLIPFeatureExtractor,
89
- requires_safety_checker: bool = True,
90
- ):
91
- super().__init__()
92
-
93
- if safety_checker is None and requires_safety_checker:
94
- logger.warning(
95
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
96
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
97
- " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
98
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
99
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
100
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
101
- )
102
-
103
- if safety_checker is not None and feature_extractor is None:
104
- raise ValueError(
105
- f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
106
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
107
- )
108
-
109
- self.register_modules(
110
- vae=vae,
111
- text_encoder=text_encoder,
112
- tokenizer=tokenizer,
113
- unet=unet,
114
- controlnet=controlnet,
115
- scheduler=scheduler,
116
- safety_checker=safety_checker,
117
- feature_extractor=feature_extractor,
118
- )
119
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
120
- self.register_to_config(requires_safety_checker=requires_safety_checker)
121
-
122
- # custom data
123
- clip_model = FrozenCLIPEmbedder(text_encoder, tokenizer)
124
- self.sj = StableDiffusionModelHijack(clip_model)
125
- self.orginal_scheduler_config = self.scheduler.config
126
- self.supported_scheduler = [
127
- "pndm",
128
- "lms",
129
- "euler",
130
- "euler-ancestral",
131
- "dpm-multi",
132
- "dpm-single",
133
- "unipc-multi",
134
- "ddim",
135
- "ddpm",
136
- "deis-multi",
137
- "heun",
138
- "kdpm2-ancestral",
139
- "kdpm2",
140
- ]
141
-
142
- def add_ti_embedding_dir(self, embeddings_dir):
143
- self.sj.embedding_db.add_embedding_dir(embeddings_dir)
144
- self.sj.embedding_db.load_textual_inversion_embeddings()
145
-
146
- def clear_ti_embedding(self):
147
- self.sj.embedding_db.clear_embedding_dirs()
148
- self.sj.embedding_db.load_textual_inversion_embeddings(True)
149
-
150
- def switch_scheduler(self, scheduler_type="ddim"):
151
- scheduler_type = scheduler_type.lower()
152
- from ppdiffusers import (
153
- DDIMScheduler,
154
- DDPMScheduler,
155
- DEISMultistepScheduler,
156
- DPMSolverMultistepScheduler,
157
- DPMSolverSinglestepScheduler,
158
- EulerAncestralDiscreteScheduler,
159
- EulerDiscreteScheduler,
160
- HeunDiscreteScheduler,
161
- KDPM2AncestralDiscreteScheduler,
162
- KDPM2DiscreteScheduler,
163
- LMSDiscreteScheduler,
164
- PNDMScheduler,
165
- UniPCMultistepScheduler,
166
- )
167
-
168
- if scheduler_type == "pndm":
169
- scheduler = PNDMScheduler.from_config(self.orginal_scheduler_config, skip_prk_steps=True)
170
- elif scheduler_type == "lms":
171
- scheduler = LMSDiscreteScheduler.from_config(self.orginal_scheduler_config)
172
- elif scheduler_type == "heun":
173
- scheduler = HeunDiscreteScheduler.from_config(self.orginal_scheduler_config)
174
- elif scheduler_type == "euler":
175
- scheduler = EulerDiscreteScheduler.from_config(self.orginal_scheduler_config)
176
- elif scheduler_type == "euler-ancestral":
177
- scheduler = EulerAncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
178
- elif scheduler_type == "dpm-multi":
179
- scheduler = DPMSolverMultistepScheduler.from_config(self.orginal_scheduler_config)
180
- elif scheduler_type == "dpm-single":
181
- scheduler = DPMSolverSinglestepScheduler.from_config(self.orginal_scheduler_config)
182
- elif scheduler_type == "kdpm2-ancestral":
183
- scheduler = KDPM2AncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
184
- elif scheduler_type == "kdpm2":
185
- scheduler = KDPM2DiscreteScheduler.from_config(self.orginal_scheduler_config)
186
- elif scheduler_type == "unipc-multi":
187
- scheduler = UniPCMultistepScheduler.from_config(self.orginal_scheduler_config)
188
- elif scheduler_type == "ddim":
189
- scheduler = DDIMScheduler.from_config(
190
- self.orginal_scheduler_config,
191
- steps_offset=1,
192
- clip_sample=False,
193
- set_alpha_to_one=False,
194
- )
195
- elif scheduler_type == "ddpm":
196
- scheduler = DDPMScheduler.from_config(
197
- self.orginal_scheduler_config,
198
- )
199
- elif scheduler_type == "deis-multi":
200
- scheduler = DEISMultistepScheduler.from_config(
201
- self.orginal_scheduler_config,
202
- )
203
- else:
204
- raise ValueError(
205
- f"Scheduler of type {scheduler_type} doesn't exist! Please choose in {self.supported_scheduler}!"
206
- )
207
- self.scheduler = scheduler
208
-
209
- @paddle.no_grad()
210
- def _encode_prompt(
211
- self,
212
- prompt: str,
213
- do_classifier_free_guidance: float = 7.5,
214
- negative_prompt: str = None,
215
- num_inference_steps: int = 50,
216
- ):
217
- if do_classifier_free_guidance:
218
- assert isinstance(negative_prompt, str)
219
- negative_prompt = [negative_prompt]
220
- uc = get_learned_conditioning(self.sj.clip, negative_prompt, num_inference_steps)
221
- else:
222
- uc = None
223
-
224
- c = get_multicond_learned_conditioning(self.sj.clip, prompt, num_inference_steps)
225
- return c, uc
226
-
227
- def run_safety_checker(self, image, dtype):
228
- if self.safety_checker is not None:
229
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
230
- image, has_nsfw_concept = self.safety_checker(
231
- images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
232
- )
233
- else:
234
- has_nsfw_concept = None
235
- return image, has_nsfw_concept
236
-
237
- def decode_latents(self, latents):
238
- latents = 1 / self.vae.config.scaling_factor * latents
239
- image = self.vae.decode(latents).sample
240
- image = (image / 2 + 0.5).clip(0, 1)
241
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
242
- image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
243
- return image
244
-
245
- def prepare_extra_step_kwargs(self, generator, eta):
246
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
247
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
248
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
249
- # and should be between [0, 1]
250
-
251
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
252
- extra_step_kwargs = {}
253
- if accepts_eta:
254
- extra_step_kwargs["eta"] = eta
255
-
256
- # check if the scheduler accepts generator
257
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
258
- if accepts_generator:
259
- extra_step_kwargs["generator"] = generator
260
- return extra_step_kwargs
261
-
262
- def check_inputs(
263
- self,
264
- prompt,
265
- image,
266
- height,
267
- width,
268
- callback_steps,
269
- negative_prompt=None,
270
- controlnet_conditioning_scale=1.0,
271
- ):
272
- if height % 8 != 0 or width % 8 != 0:
273
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
274
-
275
- if (callback_steps is None) or (
276
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
277
- ):
278
- raise ValueError(
279
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
280
- f" {type(callback_steps)}."
281
- )
282
-
283
- if prompt is not None and not isinstance(prompt, str):
284
- raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}")
285
-
286
- if negative_prompt is not None and not isinstance(negative_prompt, str):
287
- raise ValueError(f"`negative_prompt` has to be of type `str` but is {type(negative_prompt)}")
288
-
289
- # Check `image`
290
-
291
- if isinstance(self.controlnet, ControlNetModel):
292
- self.check_image(image, prompt)
293
- else:
294
- assert False
295
-
296
- # Check `controlnet_conditioning_scale`
297
- if isinstance(self.controlnet, ControlNetModel):
298
- if not isinstance(controlnet_conditioning_scale, (float, list, tuple)):
299
- raise TypeError(
300
- "For single controlnet: `controlnet_conditioning_scale` must be type `float, list(float) or tuple(float)`."
301
- )
302
-
303
- def check_image(self, image, prompt):
304
- image_is_pil = isinstance(image, PIL.Image.Image)
305
- image_is_tensor = isinstance(image, paddle.Tensor)
306
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
307
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], paddle.Tensor)
308
-
309
- if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
310
- raise TypeError(
311
- "image must be one of PIL image, paddle tensor, list of PIL images, or list of paddle tensors"
312
- )
313
-
314
- if image_is_pil:
315
- image_batch_size = 1
316
- elif image_is_tensor:
317
- image_batch_size = image.shape[0]
318
- elif image_is_pil_list:
319
- image_batch_size = len(image)
320
- elif image_is_tensor_list:
321
- image_batch_size = len(image)
322
-
323
- if prompt is not None and isinstance(prompt, str):
324
- prompt_batch_size = 1
325
- elif prompt is not None and isinstance(prompt, list):
326
- prompt_batch_size = len(prompt)
327
-
328
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
329
- raise ValueError(
330
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
331
- )
332
-
333
- def prepare_image(self, image, width, height, dtype):
334
- if not isinstance(image, paddle.Tensor):
335
- if isinstance(image, PIL.Image.Image):
336
- image = [image]
337
-
338
- if isinstance(image[0], PIL.Image.Image):
339
- images = []
340
- for image_ in image:
341
- image_ = image_.convert("RGB")
342
- image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
343
- image_ = np.array(image_)
344
- image_ = image_[None, :]
345
- images.append(image_)
346
-
347
- image = np.concatenate(images, axis=0)
348
- image = np.array(image).astype(np.float32) / 255.0
349
- image = image.transpose(0, 3, 1, 2)
350
- image = paddle.to_tensor(image)
351
- elif isinstance(image[0], paddle.Tensor):
352
- image = paddle.concat(image, axis=0)
353
-
354
- image = image.cast(dtype)
355
- return image
356
-
357
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
358
- shape = [batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor]
359
- if isinstance(generator, list) and len(generator) != batch_size:
360
- raise ValueError(
361
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
362
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
363
- )
364
-
365
- if latents is None:
366
- latents = randn_tensor(shape, generator=generator, dtype=dtype)
367
-
368
- # scale the initial noise by the standard deviation required by the scheduler
369
- latents = latents * self.scheduler.init_noise_sigma
370
- return latents
371
-
372
- def _default_height_width(self, height, width, image):
373
- while isinstance(image, list):
374
- image = image[0]
375
-
376
- if height is None:
377
- if isinstance(image, PIL.Image.Image):
378
- height = image.height
379
- elif isinstance(image, paddle.Tensor):
380
- height = image.shape[3]
381
-
382
- height = (height // 8) * 8 # round down to nearest multiple of 8
383
-
384
- if width is None:
385
- if isinstance(image, PIL.Image.Image):
386
- width = image.width
387
- elif isinstance(image, paddle.Tensor):
388
- width = image.shape[2]
389
-
390
- width = (width // 8) * 8 # round down to nearest multiple of 8
391
-
392
- return height, width
393
-
394
- @paddle.no_grad()
395
- def __call__(
396
- self,
397
- prompt: str = None,
398
- image: PIL.Image.Image = None,
399
- height: Optional[int] = None,
400
- width: Optional[int] = None,
401
- num_inference_steps: int = 50,
402
- guidance_scale: float = 7.5,
403
- negative_prompt: str = None,
404
- eta: float = 0.0,
405
- generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
406
- latents: Optional[paddle.Tensor] = None,
407
- output_type: Optional[str] = "pil",
408
- return_dict: bool = True,
409
- callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
410
- callback_steps: Optional[int] = 1,
411
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
412
- clip_skip: int = 0,
413
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
414
- ):
415
- r"""
416
- Function invoked when calling the pipeline for generation.
417
-
418
- Args:
419
- prompt (`str`, *optional*):
420
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
421
- instead.
422
- image (`paddle.Tensor`, `PIL.Image.Image`):
423
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
424
- the type is specified as `paddle.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
425
- also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
426
- height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
427
- specified in init, images must be passed as a list such that each element of the list can be correctly
428
- batched for input to a single controlnet.
429
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
430
- The height in pixels of the generated image.
431
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
432
- The width in pixels of the generated image.
433
- num_inference_steps (`int`, *optional*, defaults to 50):
434
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
435
- expense of slower inference.
436
- guidance_scale (`float`, *optional*, defaults to 7.5):
437
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
438
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
439
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
440
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
441
- usually at the expense of lower image quality.
442
- negative_prompt (`str`, *optional*):
443
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
444
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
445
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
446
- eta (`float`, *optional*, defaults to 0.0):
447
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
448
- [`schedulers.DDIMScheduler`], will be ignored for others.
449
- generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
450
- One or a list of paddle generator(s) to make generation deterministic.
451
- latents (`paddle.Tensor`, *optional*):
452
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
453
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
454
- tensor will ge generated by sampling using the supplied random `generator`.
455
- output_type (`str`, *optional*, defaults to `"pil"`):
456
- The output format of the generate image. Choose between
457
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
458
- return_dict (`bool`, *optional*, defaults to `True`):
459
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
460
- plain tuple.
461
- callback (`Callable`, *optional*):
462
- A function that will be called every `callback_steps` steps during inference. The function will be
463
- called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
464
- callback_steps (`int`, *optional*, defaults to 1):
465
- The frequency at which the `callback` function will be called. If not specified, the callback will be
466
- called at every step.
467
- cross_attention_kwargs (`dict`, *optional*):
468
- A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
469
- `self.processor` in
470
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
471
- clip_skip (`int`, *optional*, defaults to 0):
472
- CLIP_stop_at_last_layers, if clip_skip < 1, we will use the last_hidden_state from text_encoder.
473
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
474
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
475
- to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
476
- corresponding scale as a list.
477
- Examples:
478
-
479
- Returns:
480
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
481
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
482
- When returning a tuple, the first element is a list with the generated images, and the second element is a
483
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
484
- (nsfw) content, according to the `safety_checker`.
485
- """
486
- # 0. Default height and width to unet
487
- height, width = self._default_height_width(height, width, image)
488
-
489
- # 1. Check inputs. Raise error if not correct
490
- self.check_inputs(
491
- prompt,
492
- image,
493
- height,
494
- width,
495
- callback_steps,
496
- negative_prompt,
497
- controlnet_conditioning_scale,
498
- )
499
-
500
- batch_size = 1
501
-
502
- image = self.prepare_image(
503
- image=image,
504
- width=width,
505
- height=height,
506
- dtype=self.controlnet.dtype,
507
- )
508
-
509
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
510
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
511
- # corresponds to doing no classifier free guidance.
512
- do_classifier_free_guidance = guidance_scale > 1.0
513
-
514
- prompts, extra_network_data = parse_prompts([prompt])
515
-
516
- self.sj.clip.CLIP_stop_at_last_layers = clip_skip
517
- # 3. Encode input prompt
518
- prompt_embeds, negative_prompt_embeds = self._encode_prompt(
519
- prompts,
520
- do_classifier_free_guidance,
521
- negative_prompt,
522
- num_inference_steps=num_inference_steps,
523
- )
524
-
525
- # 4. Prepare timesteps
526
- self.scheduler.set_timesteps(num_inference_steps)
527
- timesteps = self.scheduler.timesteps
528
-
529
- # 5. Prepare latent variables
530
- num_channels_latents = self.unet.in_channels
531
- latents = self.prepare_latents(
532
- batch_size,
533
- num_channels_latents,
534
- height,
535
- width,
536
- self.unet.dtype,
537
- generator,
538
- latents,
539
- )
540
-
541
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
542
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
543
-
544
- # 7. Denoising loop
545
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
546
- with self.progress_bar(total=num_inference_steps) as progress_bar:
547
- for i, t in enumerate(timesteps):
548
- step = i // self.scheduler.order
549
- do_batch = False
550
- conds_list, cond_tensor = reconstruct_multicond_batch(prompt_embeds, step)
551
- try:
552
- weight = conds_list[0][0][1]
553
- except Exception:
554
- weight = 1.0
555
- if do_classifier_free_guidance:
556
- uncond_tensor = reconstruct_cond_batch(negative_prompt_embeds, step)
557
- do_batch = cond_tensor.shape[1] == uncond_tensor.shape[1]
558
-
559
- # expand the latents if we are doing classifier free guidance
560
- latent_model_input = paddle.concat([latents] * 2) if do_batch else latents
561
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
562
-
563
- if do_batch:
564
- encoder_hidden_states = paddle.concat([uncond_tensor, cond_tensor])
565
- down_block_res_samples, mid_block_res_sample = self.controlnet(
566
- latent_model_input,
567
- t,
568
- encoder_hidden_states=encoder_hidden_states,
569
- controlnet_cond=paddle.concat([image, image]),
570
- conditioning_scale=controlnet_conditioning_scale,
571
- return_dict=False,
572
- )
573
- noise_pred = self.unet(
574
- latent_model_input,
575
- t,
576
- encoder_hidden_states=encoder_hidden_states,
577
- cross_attention_kwargs=cross_attention_kwargs,
578
- down_block_additional_residuals=down_block_res_samples,
579
- mid_block_additional_residual=mid_block_res_sample,
580
- ).sample
581
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
582
- noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred_text - noise_pred_uncond)
583
- else:
584
- down_block_res_samples, mid_block_res_sample = self.controlnet(
585
- latent_model_input,
586
- t,
587
- encoder_hidden_states=cond_tensor,
588
- controlnet_cond=image,
589
- conditioning_scale=controlnet_conditioning_scale,
590
- return_dict=False,
591
- )
592
- noise_pred = self.unet(
593
- latent_model_input,
594
- t,
595
- encoder_hidden_states=cond_tensor,
596
- cross_attention_kwargs=cross_attention_kwargs,
597
- down_block_additional_residuals=down_block_res_samples,
598
- mid_block_additional_residual=mid_block_res_sample,
599
- ).sample
600
-
601
- if do_classifier_free_guidance:
602
- down_block_res_samples, mid_block_res_sample = self.controlnet(
603
- latent_model_input,
604
- t,
605
- encoder_hidden_states=uncond_tensor,
606
- controlnet_cond=image,
607
- conditioning_scale=controlnet_conditioning_scale,
608
- return_dict=False,
609
- )
610
- noise_pred_uncond = self.unet(
611
- latent_model_input,
612
- t,
613
- encoder_hidden_states=uncond_tensor,
614
- cross_attention_kwargs=cross_attention_kwargs,
615
- down_block_additional_residuals=down_block_res_samples,
616
- mid_block_additional_residual=mid_block_res_sample,
617
- ).sample
618
- noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred - noise_pred_uncond)
619
-
620
- # compute the previous noisy sample x_t -> x_t-1
621
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
622
-
623
- # call the callback, if provided
624
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
625
- progress_bar.update()
626
- if callback is not None and i % callback_steps == 0:
627
- callback(i, t, latents)
628
-
629
- if output_type == "latent":
630
- image = latents
631
- has_nsfw_concept = None
632
- elif output_type == "pil":
633
- # 8. Post-processing
634
- image = self.decode_latents(latents)
635
-
636
- # 9. Run safety checker
637
- image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
638
-
639
- # 10. Convert to PIL
640
- image = self.numpy_to_pil(image)
641
- else:
642
- # 8. Post-processing
643
- image = self.decode_latents(latents)
644
-
645
- # 9. Run safety checker
646
- image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
647
-
648
- if not return_dict:
649
- return (image, has_nsfw_concept)
650
-
651
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
652
-
653
-
654
- # clip.py
655
- import math
656
- from collections import namedtuple
657
-
658
-
659
- class PromptChunk:
660
- """
661
- This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
662
- If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
663
- Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
664
- so just 75 tokens from prompt.
665
- """
666
-
667
- def __init__(self):
668
- self.tokens = []
669
- self.multipliers = []
670
- self.fixes = []
671
-
672
-
673
- PromptChunkFix = namedtuple("PromptChunkFix", ["offset", "embedding"])
674
- """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
675
- chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
676
- are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
677
-
678
-
679
- class FrozenCLIPEmbedder(nn.Layer):
680
- """Uses the CLIP transformer encoder for text (from huggingface)"""
681
-
682
- LAYERS = ["last", "pooled", "hidden"]
683
-
684
- def __init__(self, text_encoder, tokenizer, freeze=True, layer="last", layer_idx=None):
685
- super().__init__()
686
- assert layer in self.LAYERS
687
- self.tokenizer = tokenizer
688
- self.text_encoder = text_encoder
689
- if freeze:
690
- self.freeze()
691
- self.layer = layer
692
- self.layer_idx = layer_idx
693
- if layer == "hidden":
694
- assert layer_idx is not None
695
- assert 0 <= abs(layer_idx) <= 12
696
-
697
- def freeze(self):
698
- self.text_encoder.eval()
699
- for param in self.parameters():
700
- param.stop_gradient = False
701
-
702
- def forward(self, text):
703
- batch_encoding = self.tokenizer(
704
- text,
705
- truncation=True,
706
- max_length=self.tokenizer.model_max_length,
707
- padding="max_length",
708
- return_tensors="pd",
709
- )
710
- tokens = batch_encoding["input_ids"]
711
- outputs = self.text_encoder(input_ids=tokens, output_hidden_states=self.layer == "hidden", return_dict=True)
712
- if self.layer == "last":
713
- z = outputs.last_hidden_state
714
- elif self.layer == "pooled":
715
- z = outputs.pooler_output[:, None, :]
716
- else:
717
- z = outputs.hidden_states[self.layer_idx]
718
- return z
719
-
720
- def encode(self, text):
721
- return self(text)
722
-
723
-
724
- class FrozenCLIPEmbedderWithCustomWordsBase(nn.Layer):
725
- """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
726
- have unlimited prompt length and assign weights to tokens in prompt.
727
- """
728
-
729
- def __init__(self, wrapped, hijack):
730
- super().__init__()
731
-
732
- self.wrapped = wrapped
733
- """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
734
- depending on model."""
735
-
736
- self.hijack = hijack
737
- self.chunk_length = 75
738
-
739
- def empty_chunk(self):
740
- """creates an empty PromptChunk and returns it"""
741
-
742
- chunk = PromptChunk()
743
- chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
744
- chunk.multipliers = [1.0] * (self.chunk_length + 2)
745
- return chunk
746
-
747
- def get_target_prompt_token_count(self, token_count):
748
- """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
749
-
750
- return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
751
-
752
- def tokenize(self, texts):
753
- """Converts a batch of texts into a batch of token ids"""
754
-
755
- raise NotImplementedError
756
-
757
- def encode_with_text_encoder(self, tokens):
758
- """
759
- converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
760
- All python lists with tokens are assumed to have same length, usually 77.
761
- if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
762
- model - can be 768 and 1024.
763
- Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
764
- """
765
-
766
- raise NotImplementedError
767
-
768
- def encode_embedding_init_text(self, init_text, nvpt):
769
- """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
770
- transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
771
-
772
- raise NotImplementedError
773
-
774
- def tokenize_line(self, line):
775
- """
776
- this transforms a single prompt into a list of PromptChunk objects - as many as needed to
777
- represent the prompt.
778
- Returns the list and the total number of tokens in the prompt.
779
- """
780
-
781
- if WebUIStableDiffusionControlNetPipeline.enable_emphasis:
782
- parsed = parse_prompt_attention(line)
783
- else:
784
- parsed = [[line, 1.0]]
785
-
786
- tokenized = self.tokenize([text for text, _ in parsed])
787
-
788
- chunks = []
789
- chunk = PromptChunk()
790
- token_count = 0
791
- last_comma = -1
792
-
793
- def next_chunk(is_last=False):
794
- """puts current chunk into the list of results and produces the next one - empty;
795
- if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
796
- nonlocal token_count
797
- nonlocal last_comma
798
- nonlocal chunk
799
-
800
- if is_last:
801
- token_count += len(chunk.tokens)
802
- else:
803
- token_count += self.chunk_length
804
-
805
- to_add = self.chunk_length - len(chunk.tokens)
806
- if to_add > 0:
807
- chunk.tokens += [self.id_end] * to_add
808
- chunk.multipliers += [1.0] * to_add
809
-
810
- chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
811
- chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
812
-
813
- last_comma = -1
814
- chunks.append(chunk)
815
- chunk = PromptChunk()
816
-
817
- for tokens, (text, weight) in zip(tokenized, parsed):
818
- if text == "BREAK" and weight == -1:
819
- next_chunk()
820
- continue
821
-
822
- position = 0
823
- while position < len(tokens):
824
- token = tokens[position]
825
-
826
- if token == self.comma_token:
827
- last_comma = len(chunk.tokens)
828
-
829
- # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
830
- # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
831
- elif (
832
- WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack != 0
833
- and len(chunk.tokens) == self.chunk_length
834
- and last_comma != -1
835
- and len(chunk.tokens) - last_comma
836
- <= WebUIStableDiffusionControlNetPipeline.comma_padding_backtrack
837
- ):
838
- break_location = last_comma + 1
839
-
840
- reloc_tokens = chunk.tokens[break_location:]
841
- reloc_mults = chunk.multipliers[break_location:]
842
-
843
- chunk.tokens = chunk.tokens[:break_location]
844
- chunk.multipliers = chunk.multipliers[:break_location]
845
-
846
- next_chunk()
847
- chunk.tokens = reloc_tokens
848
- chunk.multipliers = reloc_mults
849
-
850
- if len(chunk.tokens) == self.chunk_length:
851
- next_chunk()
852
-
853
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(
854
- tokens, position
855
- )
856
- if embedding is None:
857
- chunk.tokens.append(token)
858
- chunk.multipliers.append(weight)
859
- position += 1
860
- continue
861
-
862
- emb_len = int(embedding.vec.shape[0])
863
- if len(chunk.tokens) + emb_len > self.chunk_length:
864
- next_chunk()
865
-
866
- chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
867
-
868
- chunk.tokens += [0] * emb_len
869
- chunk.multipliers += [weight] * emb_len
870
- position += embedding_length_in_tokens
871
-
872
- if len(chunk.tokens) > 0 or len(chunks) == 0:
873
- next_chunk(is_last=True)
874
-
875
- return chunks, token_count
876
-
877
- def process_texts(self, texts):
878
- """
879
- Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
880
- length, in tokens, of all texts.
881
- """
882
-
883
- token_count = 0
884
-
885
- cache = {}
886
- batch_chunks = []
887
- for line in texts:
888
- if line in cache:
889
- chunks = cache[line]
890
- else:
891
- chunks, current_token_count = self.tokenize_line(line)
892
- token_count = max(current_token_count, token_count)
893
-
894
- cache[line] = chunks
895
-
896
- batch_chunks.append(chunks)
897
-
898
- return batch_chunks, token_count
899
-
900
- def forward(self, texts):
901
- """
902
- Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
903
- Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
904
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
905
- An example shape returned by this function can be: (2, 77, 768).
906
- Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
907
- is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
908
- """
909
-
910
- batch_chunks, token_count = self.process_texts(texts)
911
-
912
- used_embeddings = {}
913
- chunk_count = max([len(x) for x in batch_chunks])
914
-
915
- zs = []
916
- for i in range(chunk_count):
917
- batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
918
-
919
- tokens = [x.tokens for x in batch_chunk]
920
- multipliers = [x.multipliers for x in batch_chunk]
921
- self.hijack.fixes = [x.fixes for x in batch_chunk]
922
-
923
- for fixes in self.hijack.fixes:
924
- for position, embedding in fixes:
925
- used_embeddings[embedding.name] = embedding
926
-
927
- z = self.process_tokens(tokens, multipliers)
928
- zs.append(z)
929
-
930
- if len(used_embeddings) > 0:
931
- embeddings_list = ", ".join(
932
- [f"{name} [{embedding.checksum()}]" for name, embedding in used_embeddings.items()]
933
- )
934
- self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
935
-
936
- return paddle.concat(zs, axis=1)
937
-
938
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
939
- """
940
- sends one single prompt chunk to be encoded by transformers neural network.
941
- remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
942
- there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
943
- Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
944
- corresponds to one token.
945
- """
946
- tokens = paddle.to_tensor(remade_batch_tokens)
947
-
948
- # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
949
- if self.id_end != self.id_pad:
950
- for batch_pos in range(len(remade_batch_tokens)):
951
- index = remade_batch_tokens[batch_pos].index(self.id_end)
952
- tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
953
-
954
- z = self.encode_with_text_encoder(tokens)
955
-
956
- # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
957
- batch_multipliers = paddle.to_tensor(batch_multipliers)
958
- original_mean = z.mean()
959
- z = z * batch_multipliers.reshape(
960
- batch_multipliers.shape
961
- + [
962
- 1,
963
- ]
964
- ).expand(z.shape)
965
- new_mean = z.mean()
966
- z = z * (original_mean / new_mean)
967
-
968
- return z
969
-
970
-
971
- class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
972
- def __init__(self, wrapped, hijack, CLIP_stop_at_last_layers=-1):
973
- super().__init__(wrapped, hijack)
974
- self.CLIP_stop_at_last_layers = CLIP_stop_at_last_layers
975
- self.tokenizer = wrapped.tokenizer
976
-
977
- vocab = self.tokenizer.get_vocab()
978
-
979
- self.comma_token = vocab.get(",</w>", None)
980
-
981
- self.token_mults = {}
982
- tokens_with_parens = [(k, v) for k, v in vocab.items() if "(" in k or ")" in k or "[" in k or "]" in k]
983
- for text, ident in tokens_with_parens:
984
- mult = 1.0
985
- for c in text:
986
- if c == "[":
987
- mult /= 1.1
988
- if c == "]":
989
- mult *= 1.1
990
- if c == "(":
991
- mult *= 1.1
992
- if c == ")":
993
- mult /= 1.1
994
-
995
- if mult != 1.0:
996
- self.token_mults[ident] = mult
997
-
998
- self.id_start = self.wrapped.tokenizer.bos_token_id
999
- self.id_end = self.wrapped.tokenizer.eos_token_id
1000
- self.id_pad = self.id_end
1001
-
1002
- def tokenize(self, texts):
1003
- tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
1004
-
1005
- return tokenized
1006
-
1007
- def encode_with_text_encoder(self, tokens):
1008
- output_hidden_states = self.CLIP_stop_at_last_layers > 1
1009
- outputs = self.wrapped.text_encoder(
1010
- input_ids=tokens, output_hidden_states=output_hidden_states, return_dict=True
1011
- )
1012
-
1013
- if output_hidden_states:
1014
- z = outputs.hidden_states[-self.CLIP_stop_at_last_layers]
1015
- z = self.wrapped.text_encoder.text_model.ln_final(z)
1016
- else:
1017
- z = outputs.last_hidden_state
1018
-
1019
- return z
1020
-
1021
- def encode_embedding_init_text(self, init_text, nvpt):
1022
- embedding_layer = self.wrapped.text_encoder.text_model
1023
- ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pd", add_special_tokens=False)[
1024
- "input_ids"
1025
- ]
1026
- embedded = embedding_layer.token_embedding.wrapped(ids).squeeze(0)
1027
-
1028
- return embedded
1029
-
1030
-
1031
- # extra_networks.py
1032
- import re
1033
- from collections import defaultdict
1034
-
1035
-
1036
- class ExtraNetworkParams:
1037
- def __init__(self, items=None):
1038
- self.items = items or []
1039
-
1040
-
1041
- re_extra_net = re.compile(r"<(\w+):([^>]+)>")
1042
-
1043
-
1044
- def parse_prompt(prompt):
1045
- res = defaultdict(list)
1046
-
1047
- def found(m):
1048
- name = m.group(1)
1049
- args = m.group(2)
1050
-
1051
- res[name].append(ExtraNetworkParams(items=args.split(":")))
1052
-
1053
- return ""
1054
-
1055
- prompt = re.sub(re_extra_net, found, prompt)
1056
-
1057
- return prompt, res
1058
-
1059
-
1060
- def parse_prompts(prompts):
1061
- res = []
1062
- extra_data = None
1063
-
1064
- for prompt in prompts:
1065
- updated_prompt, parsed_extra_data = parse_prompt(prompt)
1066
-
1067
- if extra_data is None:
1068
- extra_data = parsed_extra_data
1069
-
1070
- res.append(updated_prompt)
1071
-
1072
- return res, extra_data
1073
-
1074
-
1075
- # image_embeddings.py
1076
-
1077
- import base64
1078
- import json
1079
- import zlib
1080
-
1081
- import numpy as np
1082
- from PIL import Image
1083
-
1084
-
1085
- class EmbeddingDecoder(json.JSONDecoder):
1086
- def __init__(self, *args, **kwargs):
1087
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
1088
-
1089
- def object_hook(self, d):
1090
- if "TORCHTENSOR" in d:
1091
- return paddle.to_tensor(np.array(d["TORCHTENSOR"]))
1092
- return d
1093
-
1094
-
1095
- def embedding_from_b64(data):
1096
- d = base64.b64decode(data)
1097
- return json.loads(d, cls=EmbeddingDecoder)
1098
-
1099
-
1100
- def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
1101
- while True:
1102
- seed = (a * seed + c) % m
1103
- yield seed % 255
1104
-
1105
-
1106
- def xor_block(block):
1107
- g = lcg()
1108
- randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
1109
- return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
1110
-
1111
-
1112
- def crop_black(img, tol=0):
1113
- mask = (img > tol).all(2)
1114
- mask0, mask1 = mask.any(0), mask.any(1)
1115
- col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
1116
- row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
1117
- return img[row_start:row_end, col_start:col_end]
1118
-
1119
-
1120
- def extract_image_data_embed(image):
1121
- d = 3
1122
- outarr = (
1123
- crop_black(np.array(image.convert("RGB").getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8))
1124
- & 0x0F
1125
- )
1126
- black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
1127
- if black_cols[0].shape[0] < 2:
1128
- print("No Image data blocks found.")
1129
- return None
1130
-
1131
- data_block_lower = outarr[:, : black_cols[0].min(), :].astype(np.uint8)
1132
- data_block_upper = outarr[:, black_cols[0].max() + 1 :, :].astype(np.uint8)
1133
-
1134
- data_block_lower = xor_block(data_block_lower)
1135
- data_block_upper = xor_block(data_block_upper)
1136
-
1137
- data_block = (data_block_upper << 4) | (data_block_lower)
1138
- data_block = data_block.flatten().tobytes()
1139
-
1140
- data = zlib.decompress(data_block)
1141
- return json.loads(data, cls=EmbeddingDecoder)
1142
-
1143
-
1144
- # prompt_parser.py
1145
- import re
1146
- from collections import namedtuple
1147
- from typing import List
1148
-
1149
- import lark
1150
-
1151
- # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
1152
- # will be represented with prompt_schedule like this (assuming steps=100):
1153
- # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
1154
- # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
1155
- # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
1156
- # [75, 'fantasy landscape with a lake and an oak in background masterful']
1157
- # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
1158
-
1159
- schedule_parser = lark.Lark(
1160
- r"""
1161
- !start: (prompt | /[][():]/+)*
1162
- prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
1163
- !emphasized: "(" prompt ")"
1164
- | "(" prompt ":" prompt ")"
1165
- | "[" prompt "]"
1166
- scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
1167
- alternate: "[" prompt ("|" prompt)+ "]"
1168
- WHITESPACE: /\s+/
1169
- plain: /([^\\\[\]():|]|\\.)+/
1170
- %import common.SIGNED_NUMBER -> NUMBER
1171
- """
1172
- )
1173
-
1174
-
1175
- def get_learned_conditioning_prompt_schedules(prompts, steps):
1176
- """
1177
- >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
1178
- >>> g("test")
1179
- [[10, 'test']]
1180
- >>> g("a [b:3]")
1181
- [[3, 'a '], [10, 'a b']]
1182
- >>> g("a [b: 3]")
1183
- [[3, 'a '], [10, 'a b']]
1184
- >>> g("a [[[b]]:2]")
1185
- [[2, 'a '], [10, 'a [[b]]']]
1186
- >>> g("[(a:2):3]")
1187
- [[3, ''], [10, '(a:2)']]
1188
- >>> g("a [b : c : 1] d")
1189
- [[1, 'a b d'], [10, 'a c d']]
1190
- >>> g("a[b:[c:d:2]:1]e")
1191
- [[1, 'abe'], [2, 'ace'], [10, 'ade']]
1192
- >>> g("a [unbalanced")
1193
- [[10, 'a [unbalanced']]
1194
- >>> g("a [b:.5] c")
1195
- [[5, 'a c'], [10, 'a b c']]
1196
- >>> g("a [{b|d{:.5] c") # not handling this right now
1197
- [[5, 'a c'], [10, 'a {b|d{ c']]
1198
- >>> g("((a][:b:c [d:3]")
1199
- [[3, '((a][:b:c '], [10, '((a][:b:c d']]
1200
- >>> g("[a|(b:1.1)]")
1201
- [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
1202
- """
1203
-
1204
- def collect_steps(steps, tree):
1205
- l = [steps]
1206
-
1207
- class CollectSteps(lark.Visitor):
1208
- def scheduled(self, tree):
1209
- tree.children[-1] = float(tree.children[-1])
1210
- if tree.children[-1] < 1:
1211
- tree.children[-1] *= steps
1212
- tree.children[-1] = min(steps, int(tree.children[-1]))
1213
- l.append(tree.children[-1])
1214
-
1215
- def alternate(self, tree):
1216
- l.extend(range(1, steps + 1))
1217
-
1218
- CollectSteps().visit(tree)
1219
- return sorted(set(l))
1220
-
1221
- def at_step(step, tree):
1222
- class AtStep(lark.Transformer):
1223
- def scheduled(self, args):
1224
- before, after, _, when = args
1225
- yield before or () if step <= when else after
1226
-
1227
- def alternate(self, args):
1228
- yield next(args[(step - 1) % len(args)])
1229
-
1230
- def start(self, args):
1231
- def flatten(x):
1232
- if type(x) == str:
1233
- yield x
1234
- else:
1235
- for gen in x:
1236
- yield from flatten(gen)
1237
-
1238
- return "".join(flatten(args))
1239
-
1240
- def plain(self, args):
1241
- yield args[0].value
1242
-
1243
- def __default__(self, data, children, meta):
1244
- for child in children:
1245
- yield child
1246
-
1247
- return AtStep().transform(tree)
1248
-
1249
- def get_schedule(prompt):
1250
- try:
1251
- tree = schedule_parser.parse(prompt)
1252
- except lark.exceptions.LarkError:
1253
- if 0:
1254
- import traceback
1255
-
1256
- traceback.print_exc()
1257
- return [[steps, prompt]]
1258
- return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
1259
-
1260
- promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
1261
- return [promptdict[prompt] for prompt in prompts]
1262
-
1263
-
1264
- ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
1265
-
1266
-
1267
- def get_learned_conditioning(model, prompts, steps):
1268
- """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
1269
- and the sampling step at which this condition is to be replaced by the next one.
1270
-
1271
- Input:
1272
- (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
1273
-
1274
- Output:
1275
- [
1276
- [
1277
- ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
1278
- ],
1279
- [
1280
- ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
1281
- ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
1282
- ]
1283
- ]
1284
- """
1285
- res = []
1286
-
1287
- prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
1288
- cache = {}
1289
-
1290
- for prompt, prompt_schedule in zip(prompts, prompt_schedules):
1291
-
1292
- cached = cache.get(prompt, None)
1293
- if cached is not None:
1294
- res.append(cached)
1295
- continue
1296
-
1297
- texts = [x[1] for x in prompt_schedule]
1298
- conds = model(texts)
1299
-
1300
- cond_schedule = []
1301
- for i, (end_at_step, text) in enumerate(prompt_schedule):
1302
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
1303
-
1304
- cache[prompt] = cond_schedule
1305
- res.append(cond_schedule)
1306
-
1307
- return res
1308
-
1309
-
1310
- re_AND = re.compile(r"\bAND\b")
1311
- re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
1312
-
1313
-
1314
- def get_multicond_prompt_list(prompts):
1315
- res_indexes = []
1316
-
1317
- prompt_flat_list = []
1318
- prompt_indexes = {}
1319
-
1320
- for prompt in prompts:
1321
- subprompts = re_AND.split(prompt)
1322
-
1323
- indexes = []
1324
- for subprompt in subprompts:
1325
- match = re_weight.search(subprompt)
1326
-
1327
- text, weight = match.groups() if match is not None else (subprompt, 1.0)
1328
-
1329
- weight = float(weight) if weight is not None else 1.0
1330
-
1331
- index = prompt_indexes.get(text, None)
1332
- if index is None:
1333
- index = len(prompt_flat_list)
1334
- prompt_flat_list.append(text)
1335
- prompt_indexes[text] = index
1336
-
1337
- indexes.append((index, weight))
1338
-
1339
- res_indexes.append(indexes)
1340
-
1341
- return res_indexes, prompt_flat_list, prompt_indexes
1342
-
1343
-
1344
- class ComposableScheduledPromptConditioning:
1345
- def __init__(self, schedules, weight=1.0):
1346
- self.schedules: List[ScheduledPromptConditioning] = schedules
1347
- self.weight: float = weight
1348
-
1349
-
1350
- class MulticondLearnedConditioning:
1351
- def __init__(self, shape, batch):
1352
- self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
1353
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
1354
-
1355
-
1356
- def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
1357
- """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
1358
- For each prompt, the list is obtained by splitting the prompt using the AND separator.
1359
-
1360
- https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
1361
- """
1362
-
1363
- res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
1364
-
1365
- learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
1366
-
1367
- res = []
1368
- for indexes in res_indexes:
1369
- res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
1370
-
1371
- return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
1372
-
1373
-
1374
- def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
1375
- param = c[0][0].cond
1376
- res = paddle.zeros(
1377
- [
1378
- len(c),
1379
- ]
1380
- + param.shape,
1381
- dtype=param.dtype,
1382
- )
1383
- for i, cond_schedule in enumerate(c):
1384
- target_index = 0
1385
- for current, (end_at, cond) in enumerate(cond_schedule):
1386
- if current_step <= end_at:
1387
- target_index = current
1388
- break
1389
- res[i] = cond_schedule[target_index].cond
1390
-
1391
- return res
1392
-
1393
-
1394
- def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
1395
- param = c.batch[0][0].schedules[0].cond
1396
-
1397
- tensors = []
1398
- conds_list = []
1399
-
1400
- for batch_no, composable_prompts in enumerate(c.batch):
1401
- conds_for_batch = []
1402
-
1403
- for cond_index, composable_prompt in enumerate(composable_prompts):
1404
- target_index = 0
1405
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
1406
- if current_step <= end_at:
1407
- target_index = current
1408
- break
1409
-
1410
- conds_for_batch.append((len(tensors), composable_prompt.weight))
1411
- tensors.append(composable_prompt.schedules[target_index].cond)
1412
-
1413
- conds_list.append(conds_for_batch)
1414
-
1415
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
1416
- # and won't be able to torch.stack them. So this fixes that.
1417
- token_count = max([x.shape[0] for x in tensors])
1418
- for i in range(len(tensors)):
1419
- if tensors[i].shape[0] != token_count:
1420
- last_vector = tensors[i][-1:]
1421
- last_vector_repeated = last_vector.tile([token_count - tensors[i].shape[0], 1])
1422
- tensors[i] = paddle.concat([tensors[i], last_vector_repeated], axis=0)
1423
-
1424
- return conds_list, paddle.stack(tensors).cast(dtype=param.dtype)
1425
-
1426
-
1427
- re_attention = re.compile(
1428
- r"""
1429
- \\\(|
1430
- \\\)|
1431
- \\\[|
1432
- \\]|
1433
- \\\\|
1434
- \\|
1435
- \(|
1436
- \[|
1437
- :([+-]?[.\d]+)\)|
1438
- \)|
1439
- ]|
1440
- [^\\()\[\]:]+|
1441
- :
1442
- """,
1443
- re.X,
1444
- )
1445
-
1446
- re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
1447
-
1448
-
1449
- def parse_prompt_attention(text):
1450
- """
1451
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
1452
- Accepted tokens are:
1453
- (abc) - increases attention to abc by a multiplier of 1.1
1454
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
1455
- [abc] - decreases attention to abc by a multiplier of 1.1
1456
- \( - literal character '('
1457
- \[ - literal character '['
1458
- \) - literal character ')'
1459
- \] - literal character ']'
1460
- \\ - literal character '\'
1461
- anything else - just text
1462
-
1463
- >>> parse_prompt_attention('normal text')
1464
- [['normal text', 1.0]]
1465
- >>> parse_prompt_attention('an (important) word')
1466
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
1467
- >>> parse_prompt_attention('(unbalanced')
1468
- [['unbalanced', 1.1]]
1469
- >>> parse_prompt_attention('\(literal\]')
1470
- [['(literal]', 1.0]]
1471
- >>> parse_prompt_attention('(unnecessary)(parens)')
1472
- [['unnecessaryparens', 1.1]]
1473
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
1474
- [['a ', 1.0],
1475
- ['house', 1.5730000000000004],
1476
- [' ', 1.1],
1477
- ['on', 1.0],
1478
- [' a ', 1.1],
1479
- ['hill', 0.55],
1480
- [', sun, ', 1.1],
1481
- ['sky', 1.4641000000000006],
1482
- ['.', 1.1]]
1483
- """
1484
-
1485
- res = []
1486
- round_brackets = []
1487
- square_brackets = []
1488
-
1489
- round_bracket_multiplier = 1.1
1490
- square_bracket_multiplier = 1 / 1.1
1491
-
1492
- def multiply_range(start_position, multiplier):
1493
- for p in range(start_position, len(res)):
1494
- res[p][1] *= multiplier
1495
-
1496
- for m in re_attention.finditer(text):
1497
- text = m.group(0)
1498
- weight = m.group(1)
1499
-
1500
- if text.startswith("\\"):
1501
- res.append([text[1:], 1.0])
1502
- elif text == "(":
1503
- round_brackets.append(len(res))
1504
- elif text == "[":
1505
- square_brackets.append(len(res))
1506
- elif weight is not None and len(round_brackets) > 0:
1507
- multiply_range(round_brackets.pop(), float(weight))
1508
- elif text == ")" and len(round_brackets) > 0:
1509
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
1510
- elif text == "]" and len(square_brackets) > 0:
1511
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
1512
- else:
1513
- parts = re.split(re_break, text)
1514
- for i, part in enumerate(parts):
1515
- if i > 0:
1516
- res.append(["BREAK", -1])
1517
- res.append([part, 1.0])
1518
-
1519
- for pos in round_brackets:
1520
- multiply_range(pos, round_bracket_multiplier)
1521
-
1522
- for pos in square_brackets:
1523
- multiply_range(pos, square_bracket_multiplier)
1524
-
1525
- if len(res) == 0:
1526
- res = [["", 1.0]]
1527
-
1528
- # merge runs of identical weights
1529
- i = 0
1530
- while i + 1 < len(res):
1531
- if res[i][1] == res[i + 1][1]:
1532
- res[i][0] += res[i + 1][0]
1533
- res.pop(i + 1)
1534
- else:
1535
- i += 1
1536
-
1537
- return res
1538
-
1539
-
1540
- # sd_hijack.py
1541
-
1542
-
1543
- class StableDiffusionModelHijack:
1544
- fixes = None
1545
- comments = []
1546
- layers = None
1547
- circular_enabled = False
1548
-
1549
- def __init__(self, clip_model, embeddings_dir=None, CLIP_stop_at_last_layers=-1):
1550
- model_embeddings = clip_model.text_encoder.text_model
1551
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
1552
- clip_model = FrozenCLIPEmbedderWithCustomWords(
1553
- clip_model, self, CLIP_stop_at_last_layers=CLIP_stop_at_last_layers
1554
- )
1555
-
1556
- self.embedding_db = EmbeddingDatabase(clip_model)
1557
- self.embedding_db.add_embedding_dir(embeddings_dir)
1558
-
1559
- # hack this!
1560
- self.clip = clip_model
1561
-
1562
- def flatten(el):
1563
- flattened = [flatten(children) for children in el.children()]
1564
- res = [el]
1565
- for c in flattened:
1566
- res += c
1567
- return res
1568
-
1569
- self.layers = flatten(clip_model)
1570
-
1571
- def clear_comments(self):
1572
- self.comments = []
1573
-
1574
- def get_prompt_lengths(self, text):
1575
- _, token_count = self.clip.process_texts([text])
1576
-
1577
- return token_count, self.clip.get_target_prompt_token_count(token_count)
1578
-
1579
-
1580
- class EmbeddingsWithFixes(nn.Layer):
1581
- def __init__(self, wrapped, embeddings):
1582
- super().__init__()
1583
- self.wrapped = wrapped
1584
- self.embeddings = embeddings
1585
-
1586
- def forward(self, input_ids):
1587
- batch_fixes = self.embeddings.fixes
1588
- self.embeddings.fixes = None
1589
-
1590
- inputs_embeds = self.wrapped(input_ids)
1591
-
1592
- if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
1593
- return inputs_embeds
1594
-
1595
- vecs = []
1596
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
1597
- for offset, embedding in fixes:
1598
- emb = embedding.vec.cast(self.wrapped.dtype)
1599
- emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
1600
- tensor = paddle.concat([tensor[0 : offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len :]])
1601
-
1602
- vecs.append(tensor)
1603
-
1604
- return paddle.stack(vecs)
1605
-
1606
-
1607
- # textual_inversion.py
1608
-
1609
- import os
1610
- import sys
1611
- import traceback
1612
-
1613
-
1614
- class Embedding:
1615
- def __init__(self, vec, name, step=None):
1616
- self.vec = vec
1617
- self.name = name
1618
- self.step = step
1619
- self.shape = None
1620
- self.vectors = 0
1621
- self.cached_checksum = None
1622
- self.sd_checkpoint = None
1623
- self.sd_checkpoint_name = None
1624
- self.optimizer_state_dict = None
1625
- self.filename = None
1626
-
1627
- def save(self, filename):
1628
- embedding_data = {
1629
- "string_to_token": {"*": 265},
1630
- "string_to_param": {"*": self.vec},
1631
- "name": self.name,
1632
- "step": self.step,
1633
- "sd_checkpoint": self.sd_checkpoint,
1634
- "sd_checkpoint_name": self.sd_checkpoint_name,
1635
- }
1636
-
1637
- paddle.save(embedding_data, filename)
1638
-
1639
- def checksum(self):
1640
- if self.cached_checksum is not None:
1641
- return self.cached_checksum
1642
-
1643
- def const_hash(a):
1644
- r = 0
1645
- for v in a:
1646
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
1647
- return r
1648
-
1649
- self.cached_checksum = f"{const_hash(self.vec.flatten() * 100) & 0xffff:04x}"
1650
- return self.cached_checksum
1651
-
1652
-
1653
- class DirWithTextualInversionEmbeddings:
1654
- def __init__(self, path):
1655
- self.path = path
1656
- self.mtime = None
1657
-
1658
- def has_changed(self):
1659
- if not os.path.isdir(self.path):
1660
- return False
1661
-
1662
- mt = os.path.getmtime(self.path)
1663
- if self.mtime is None or mt > self.mtime:
1664
- return True
1665
-
1666
- def update(self):
1667
- if not os.path.isdir(self.path):
1668
- return
1669
-
1670
- self.mtime = os.path.getmtime(self.path)
1671
-
1672
-
1673
- class EmbeddingDatabase:
1674
- def __init__(self, clip):
1675
- self.clip = clip
1676
- self.ids_lookup = {}
1677
- self.word_embeddings = {}
1678
- self.skipped_embeddings = {}
1679
- self.expected_shape = -1
1680
- self.embedding_dirs = {}
1681
- self.previously_displayed_embeddings = ()
1682
-
1683
- def add_embedding_dir(self, path):
1684
- if path is not None:
1685
- self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
1686
-
1687
- def clear_embedding_dirs(self):
1688
- self.embedding_dirs.clear()
1689
-
1690
- def register_embedding(self, embedding, model):
1691
- self.word_embeddings[embedding.name] = embedding
1692
-
1693
- ids = model.tokenize([embedding.name])[0]
1694
-
1695
- first_id = ids[0]
1696
- if first_id not in self.ids_lookup:
1697
- self.ids_lookup[first_id] = []
1698
-
1699
- self.ids_lookup[first_id] = sorted(
1700
- self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True
1701
- )
1702
-
1703
- return embedding
1704
-
1705
- def get_expected_shape(self):
1706
- vec = self.clip.encode_embedding_init_text(",", 1)
1707
- return vec.shape[1]
1708
-
1709
- def load_from_file(self, path, filename):
1710
- name, ext = os.path.splitext(filename)
1711
- ext = ext.upper()
1712
-
1713
- if ext in [".PNG", ".WEBP", ".JXL", ".AVIF"]:
1714
- _, second_ext = os.path.splitext(name)
1715
- if second_ext.upper() == ".PREVIEW":
1716
- return
1717
-
1718
- embed_image = Image.open(path)
1719
- if hasattr(embed_image, "text") and "sd-ti-embedding" in embed_image.text:
1720
- data = embedding_from_b64(embed_image.text["sd-ti-embedding"])
1721
- name = data.get("name", name)
1722
- else:
1723
- data = extract_image_data_embed(embed_image)
1724
- if data:
1725
- name = data.get("name", name)
1726
- else:
1727
- # if data is None, means this is not an embeding, just a preview image
1728
- return
1729
- elif ext in [".BIN", ".PT"]:
1730
- data = torch_load(path)
1731
- elif ext in [".SAFETENSORS"]:
1732
- data = safetensors_load(path)
1733
- else:
1734
- return
1735
-
1736
- # textual inversion embeddings
1737
- if "string_to_param" in data:
1738
- param_dict = data["string_to_param"]
1739
- if hasattr(param_dict, "_parameters"):
1740
- param_dict = getattr(param_dict, "_parameters")
1741
- assert len(param_dict) == 1, "embedding file has multiple terms in it"
1742
- emb = next(iter(param_dict.items()))[1]
1743
- # diffuser concepts
1744
- elif type(data) == dict and type(next(iter(data.values()))) == paddle.Tensor:
1745
- assert len(data.keys()) == 1, "embedding file has multiple terms in it"
1746
-
1747
- emb = next(iter(data.values()))
1748
- if len(emb.shape) == 1:
1749
- emb = emb.unsqueeze(0)
1750
- else:
1751
- raise Exception(
1752
- f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept."
1753
- )
1754
-
1755
- with paddle.no_grad():
1756
- if hasattr(emb, "detach"):
1757
- emb = emb.detach()
1758
- if hasattr(emb, "cpu"):
1759
- emb = emb.cpu()
1760
- if hasattr(emb, "numpy"):
1761
- emb = emb.numpy()
1762
- emb = paddle.to_tensor(emb)
1763
- vec = emb.detach().cast(paddle.float32)
1764
- embedding = Embedding(vec, name)
1765
- embedding.step = data.get("step", None)
1766
- embedding.sd_checkpoint = data.get("sd_checkpoint", None)
1767
- embedding.sd_checkpoint_name = data.get("sd_checkpoint_name", None)
1768
- embedding.vectors = vec.shape[0]
1769
- embedding.shape = vec.shape[-1]
1770
- embedding.filename = path
1771
-
1772
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
1773
- self.register_embedding(embedding, self.clip)
1774
- else:
1775
- self.skipped_embeddings[name] = embedding
1776
-
1777
- def load_from_dir(self, embdir):
1778
- if not os.path.isdir(embdir.path):
1779
- return
1780
-
1781
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
1782
- for fn in fns:
1783
- try:
1784
- fullfn = os.path.join(root, fn)
1785
-
1786
- if os.stat(fullfn).st_size == 0:
1787
- continue
1788
-
1789
- self.load_from_file(fullfn, fn)
1790
- except Exception:
1791
- print(f"Error loading embedding {fn}:", file=sys.stderr)
1792
- print(traceback.format_exc(), file=sys.stderr)
1793
- continue
1794
-
1795
- def load_textual_inversion_embeddings(self, force_reload=False):
1796
- if not force_reload:
1797
- need_reload = False
1798
- for path, embdir in self.embedding_dirs.items():
1799
- if embdir.has_changed():
1800
- need_reload = True
1801
- break
1802
-
1803
- if not need_reload:
1804
- return
1805
-
1806
- self.ids_lookup.clear()
1807
- self.word_embeddings.clear()
1808
- self.skipped_embeddings.clear()
1809
- self.expected_shape = self.get_expected_shape()
1810
-
1811
- for path, embdir in self.embedding_dirs.items():
1812
- self.load_from_dir(embdir)
1813
- embdir.update()
1814
-
1815
- displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
1816
- if self.previously_displayed_embeddings != displayed_embeddings:
1817
- self.previously_displayed_embeddings = displayed_embeddings
1818
- print(
1819
- f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}"
1820
- )
1821
- if len(self.skipped_embeddings) > 0:
1822
- print(
1823
- f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}"
1824
- )
1825
-
1826
- def find_embedding_at_position(self, tokens, offset):
1827
- token = tokens[offset]
1828
- possible_matches = self.ids_lookup.get(token, None)
1829
-
1830
- if possible_matches is None:
1831
- return None, None
1832
-
1833
- for ids, embedding in possible_matches:
1834
- if tokens[offset : offset + len(ids)] == ids:
1835
- return embedding, len(ids)
1836
-
1837
- return None, None