junnyu committed on
Commit
9dc30c0
·
1 Parent(s): 41d8155

Upload pipeline.py

Files changed (1)
  1. pipeline.py +1801 -0
pipeline.py ADDED
@@ -0,0 +1,1801 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # modified from https://github.com/AUTOMATIC1111/stable-diffusion-webui
17
+ # Here is the AGPL-3.0 license https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/LICENSE.txt
18
+
19
+ import inspect
+ import os  # needed below: __call__ checks os.path.exists(lora_dir) before loading LoRA weights
20
+ from typing import Any, Callable, Dict, List, Optional, Union
21
+
22
+ import paddle
23
+ import paddle.nn as nn
24
+
25
+ from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
26
+ from ppdiffusers.models import AutoencoderKL, UNet2DConditionModel
27
+ from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
29
+ from ppdiffusers.pipelines.stable_diffusion.safety_checker import (
30
+ StableDiffusionSafetyChecker,
31
+ )
32
+ from ppdiffusers.schedulers import KarrasDiffusionSchedulers
33
+ from ppdiffusers.utils import logging, randn_tensor, safetensors_load, torch_load, smart_load
34
+
35
+ from pathlib import Path
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+ @paddle.no_grad()
40
+ def load_lora(pipeline,
41
+ state_dict: dict,
42
+ LORA_PREFIX_UNET: str = "lora_unet",
43
+ LORA_PREFIX_TEXT_ENCODER: str = "lora_te",
44
+ ratio: float = 1.0):
45
+ ratio = float(ratio)
46
+ visited = []
47
+ for key in state_dict:
48
+ if ".alpha" in key or ".lora_up" in key or key in visited:
49
+ continue
50
+
51
+ if "text" in key:
52
+ tmp_layer_infos = key.split(".")[0].split(
53
+ LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
54
+ hf_to_ppnlp = {
55
+ "encoder": "transformer",
56
+ "fc1": "linear1",
57
+ "fc2": "linear2",
58
+ }
59
+ layer_infos = []
60
+ for layer_info in tmp_layer_infos:
61
+ if layer_info == "mlp": continue
62
+ layer_infos.append(hf_to_ppnlp.get(layer_info, layer_info))
63
+ curr_layer: paddle.nn.Linear = pipeline.text_encoder
64
+ else:
65
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET +
66
+ "_")[-1].split("_")
67
+ curr_layer: paddle.nn.Linear = pipeline.unet
68
+
69
+ temp_name = layer_infos.pop(0)
70
+ while len(layer_infos) > -1:
71
+ try:
72
+ if temp_name == "to":
73
+ raise ValueError()
74
+ curr_layer = curr_layer.__getattr__(temp_name)
75
+ if len(layer_infos) > 0:
76
+ temp_name = layer_infos.pop(0)
77
+ elif len(layer_infos) == 0:
78
+ break
79
+ except Exception:
80
+ if len(temp_name) > 0:
81
+ temp_name += "_" + layer_infos.pop(0)
82
+ else:
83
+ temp_name = layer_infos.pop(0)
84
+
85
+ triplet_keys = [
86
+ key,
87
+ key.replace("lora_down", "lora_up"),
88
+ key.replace("lora_down.weight", "alpha")
89
+ ]
90
+ dtype: paddle.dtype = curr_layer.weight.dtype
91
+ weight_down: paddle.Tensor = state_dict[triplet_keys[0]].cast(
92
+ dtype)
93
+ weight_up: paddle.Tensor = state_dict[triplet_keys[1]].cast(dtype)
94
+ rank: float = float(weight_down.shape[0])
95
+ if triplet_keys[2] in state_dict:
96
+ alpha: float = state_dict[triplet_keys[2]].cast(dtype).item()
97
+ scale: float = alpha / rank
98
+ else:
99
+ scale = 1.0
100
+
101
+ if not hasattr(curr_layer, "backup_weights"):
102
+ curr_layer.backup_weights = curr_layer.weight.clone()
103
+
104
+ if len(weight_down.shape) == 4:
105
+ if weight_down.shape[2:4] == [1, 1]:
106
+ # conv2d 1x1
107
+ curr_layer.weight.copy_(
108
+ curr_layer.weight +
109
+ ratio * paddle.matmul(weight_up.squeeze(
110
+ [-1, -2]), weight_down.squeeze([-1, -2])).unsqueeze(
111
+ [-1, -2]) * scale, True)
112
+ else:
113
+ # conv2d 3x3
114
+ curr_layer.weight.copy_(
115
+ curr_layer.weight + ratio * paddle.nn.functional.conv2d(
116
+ weight_down.transpose([1, 0, 2, 3]),
117
+ weight_up).transpose([1, 0, 2, 3]) * scale, True)
118
+ else:
119
+ # linear
120
+ curr_layer.weight.copy_(
121
+ curr_layer.weight +
122
+ ratio * paddle.matmul(weight_up, weight_down).T * scale, True)
123
+
124
+ # update visited list
125
+ visited.extend(triplet_keys)
126
+ return pipeline
127
+
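+ # Editorial note (comment added for clarity, not part of the original upload): with scale = alpha / rank,
+ # the in-place update applied above is, per matched layer,
+ #   Linear:   W += ratio * (lora_up @ lora_down).T * scale
+ #   Conv 1x1: W += ratio * (lora_up @ lora_down) * scale   (factors squeezed to 2-D, result unsqueezed)
+ #   Conv 3x3: W += ratio * conv2d(lora_down, lora_up) * scale   (channel axes swapped before and after)
+ # The original weight is stashed in `backup_weights` so that __call__ can restore it after generation.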
128
+ class WebUIStableDiffusionPipeline(DiffusionPipeline):
129
+ r"""
130
+ Pipeline for text-to-image generation using Stable Diffusion.
131
+
132
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
133
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
134
+
135
+ Args:
136
+ vae ([`AutoencoderKL`]):
137
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
138
+ text_encoder ([`CLIPTextModel`]):
139
+ Frozen text-encoder. Stable Diffusion uses the text portion of
140
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
141
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
142
+ tokenizer (`CLIPTokenizer`):
143
+ Tokenizer of class
144
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
145
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
146
+ scheduler ([`SchedulerMixin`]):
147
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
148
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
149
+ or [`DPMSolverMultistepScheduler`].
150
+ safety_checker ([`StableDiffusionSafetyChecker`]):
151
+ Classification module that estimates whether generated images could be considered offensive or harmful.
152
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
153
+ feature_extractor ([`CLIPFeatureExtractor`]):
154
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
155
+ """
156
+ _optional_components = ["safety_checker", "feature_extractor"]
157
+ enable_emphasis = True
158
+ comma_padding_backtrack = 20
159
+
160
+ def __init__(
161
+ self,
162
+ vae: AutoencoderKL,
163
+ text_encoder: CLIPTextModel,
164
+ tokenizer: CLIPTokenizer,
165
+ unet: UNet2DConditionModel,
166
+ scheduler: KarrasDiffusionSchedulers,
167
+ safety_checker: StableDiffusionSafetyChecker,
168
+ feature_extractor: CLIPFeatureExtractor,
169
+ requires_safety_checker: bool = True,
170
+ ):
171
+ super().__init__()
172
+
173
+ if safety_checker is None and requires_safety_checker:
174
+ logger.warning(
175
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
176
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
177
+ " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
178
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
179
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
180
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
181
+ )
182
+
183
+ if safety_checker is not None and feature_extractor is None:
184
+ raise ValueError(
185
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
186
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
187
+ )
188
+
189
+ self.register_modules(
190
+ vae=vae,
191
+ text_encoder=text_encoder,
192
+ tokenizer=tokenizer,
193
+ unet=unet,
194
+ scheduler=scheduler,
195
+ safety_checker=safety_checker,
196
+ feature_extractor=feature_extractor,
197
+ )
198
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
199
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
200
+
201
+ # custom data
202
+ clip_model = FrozenCLIPEmbedder(text_encoder, tokenizer)
203
+ self.sj = StableDiffusionModelHijack(clip_model)
204
+ self.orginal_scheduler_config = self.scheduler.config
205
+ self.supported_scheduler = [
206
+ "pndm",
207
+ "lms",
208
+ "euler",
209
+ "euler-ancestral",
210
+ "dpm-multi",
211
+ "dpm-single",
212
+ "unipc-multi",
213
+ "ddim",
214
+ "ddpm",
215
+ "deis-multi",
216
+ "heun",
217
+ "kdpm2-ancestral",
218
+ "kdpm2",
219
+ ]
220
+ self.weights_has_changed = False
221
+
222
+ def add_ti_embedding_dir(self, embeddings_dir):
223
+ self.sj.embedding_db.add_embedding_dir(embeddings_dir)
224
+ self.sj.embedding_db.load_textual_inversion_embeddings()
225
+
226
+ def clear_ti_embedding(self):
227
+ self.sj.embedding_db.clear_embedding_dirs()
228
+ self.sj.embedding_db.load_textual_inversion_embeddings(True)
229
+
230
+ def change_scheduler(self, scheduler_type="ddim"):
231
+ self.switch_scheduler(scheduler_type)
232
+
233
+ def switch_scheduler(self, scheduler_type="ddim"):
234
+ scheduler_type = scheduler_type.lower()
235
+ from ppdiffusers import (
236
+ DDIMScheduler,
237
+ DDPMScheduler,
238
+ DEISMultistepScheduler,
239
+ DPMSolverMultistepScheduler,
240
+ DPMSolverSinglestepScheduler,
241
+ EulerAncestralDiscreteScheduler,
242
+ EulerDiscreteScheduler,
243
+ HeunDiscreteScheduler,
244
+ KDPM2AncestralDiscreteScheduler,
245
+ KDPM2DiscreteScheduler,
246
+ LMSDiscreteScheduler,
247
+ PNDMScheduler,
248
+ UniPCMultistepScheduler,
249
+ )
250
+
251
+ if scheduler_type == "pndm":
252
+ scheduler = PNDMScheduler.from_config(self.orginal_scheduler_config, skip_prk_steps=True)
253
+ elif scheduler_type == "lms":
254
+ scheduler = LMSDiscreteScheduler.from_config(self.orginal_scheduler_config)
255
+ elif scheduler_type == "heun":
256
+ scheduler = HeunDiscreteScheduler.from_config(self.orginal_scheduler_config)
257
+ elif scheduler_type == "euler":
258
+ scheduler = EulerDiscreteScheduler.from_config(self.orginal_scheduler_config)
259
+ elif scheduler_type == "euler-ancestral":
260
+ scheduler = EulerAncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
261
+ elif scheduler_type == "dpm-multi":
262
+ scheduler = DPMSolverMultistepScheduler.from_config(self.orginal_scheduler_config)
263
+ elif scheduler_type == "dpm-single":
264
+ scheduler = DPMSolverSinglestepScheduler.from_config(self.orginal_scheduler_config)
265
+ elif scheduler_type == "kdpm2-ancestral":
266
+ scheduler = KDPM2AncestralDiscreteScheduler.from_config(self.orginal_scheduler_config)
267
+ elif scheduler_type == "kdpm2":
268
+ scheduler = KDPM2DiscreteScheduler.from_config(self.orginal_scheduler_config)
269
+ elif scheduler_type == "unipc-multi":
270
+ scheduler = UniPCMultistepScheduler.from_config(self.orginal_scheduler_config)
271
+ elif scheduler_type == "ddim":
272
+ scheduler = DDIMScheduler.from_config(
273
+ self.orginal_scheduler_config,
274
+ steps_offset=1,
275
+ clip_sample=False,
276
+ set_alpha_to_one=False,
277
+ )
278
+ elif scheduler_type == "ddpm":
279
+ scheduler = DDPMScheduler.from_config(
280
+ self.orginal_scheduler_config,
281
+ )
282
+ elif scheduler_type == "deis-multi":
283
+ scheduler = DEISMultistepScheduler.from_config(
284
+ self.orginal_scheduler_config,
285
+ )
286
+ else:
287
+ raise ValueError(
288
+ f"Scheduler of type {scheduler_type} doesn't exist! Please choose in {self.supported_scheduler}!"
289
+ )
290
+ self.scheduler = scheduler
291
+
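+ # Usage sketch (illustrative comment, not in the original file): the scheduler can be swapped between calls,
+ # e.g. pipe.switch_scheduler("dpm-multi") or pipe.change_scheduler("ddim"); the name must be one of
+ # self.supported_scheduler, and the replacement scheduler is rebuilt from the original scheduler config.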
292
+ @paddle.no_grad()
293
+ def _encode_prompt(
294
+ self,
295
+ prompt: str,
296
+ do_classifier_free_guidance: bool = True,
297
+ negative_prompt: str = None,
298
+ num_inference_steps: int = 50,
299
+ ):
300
+ if do_classifier_free_guidance:
301
+ assert isinstance(negative_prompt, str)
302
+ negative_prompt = [negative_prompt]
303
+ uc = get_learned_conditioning(self.sj.clip, negative_prompt, num_inference_steps)
304
+ else:
305
+ uc = None
306
+
307
+ c = get_multicond_learned_conditioning(self.sj.clip, prompt, num_inference_steps)
308
+ return c, uc
309
+
310
+ def run_safety_checker(self, image, dtype):
311
+ if self.safety_checker is not None:
312
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
313
+ image, has_nsfw_concept = self.safety_checker(
314
+ images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
315
+ )
316
+ else:
317
+ has_nsfw_concept = None
318
+ return image, has_nsfw_concept
319
+
320
+ def decode_latents(self, latents):
321
+ latents = 1 / self.vae.config.scaling_factor * latents
322
+ image = self.vae.decode(latents).sample
323
+ image = (image / 2 + 0.5).clip(0, 1)
324
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
325
+ image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
326
+ return image
327
+
328
+ def prepare_extra_step_kwargs(self, generator, eta):
329
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
330
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
331
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
332
+ # and should be between [0, 1]
333
+
334
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
335
+ extra_step_kwargs = {}
336
+ if accepts_eta:
337
+ extra_step_kwargs["eta"] = eta
338
+
339
+ # check if the scheduler accepts generator
340
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
341
+ if accepts_generator:
342
+ extra_step_kwargs["generator"] = generator
343
+ return extra_step_kwargs
344
+
345
+ def check_inputs(
346
+ self,
347
+ prompt,
348
+ height,
349
+ width,
350
+ callback_steps,
351
+ negative_prompt=None,
352
+ ):
353
+ if height % 8 != 0 or width % 8 != 0:
354
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
355
+
356
+ if (callback_steps is None) or (
357
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
358
+ ):
359
+ raise ValueError(
360
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
361
+ f" {type(callback_steps)}."
362
+ )
363
+
364
+ if prompt is not None and not isinstance(prompt, str):
365
+ raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}")
366
+
367
+ if negative_prompt is not None and not isinstance(negative_prompt, str):
368
+ raise ValueError(f"`negative_prompt` has to be of type `str` but is {type(negative_prompt)}")
369
+
370
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
371
+ shape = [batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor]
372
+ if isinstance(generator, list) and len(generator) != batch_size:
373
+ raise ValueError(
374
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
375
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
376
+ )
377
+
378
+ if latents is None:
379
+ latents = randn_tensor(shape, generator=generator, dtype=dtype)
380
+
381
+ # scale the initial noise by the standard deviation required by the scheduler
382
+ latents = latents * self.scheduler.init_noise_sigma
383
+ return latents
384
+
385
+ @paddle.no_grad()
386
+ def __call__(
387
+ self,
388
+ prompt: str = None,
389
+ height: Optional[int] = None,
390
+ width: Optional[int] = None,
391
+ num_inference_steps: int = 50,
392
+ guidance_scale: float = 7.5,
393
+ negative_prompt: str = None,
394
+ eta: float = 0.0,
395
+ generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
396
+ latents: Optional[paddle.Tensor] = None,
397
+ output_type: Optional[str] = "pil",
398
+ return_dict: bool = True,
399
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
400
+ callback_steps: Optional[int] = 1,
401
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
402
+ clip_skip: int = 0,
403
+ lora_dir: str = "./loras",
404
+ ):
405
+ r"""
406
+ Function invoked when calling the pipeline for generation.
407
+
408
+ Args:
409
+ prompt (`str`, *optional*):
410
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
411
+ instead.
412
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
413
+ The height in pixels of the generated image.
414
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
415
+ The width in pixels of the generated image.
416
+ num_inference_steps (`int`, *optional*, defaults to 50):
417
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
418
+ expense of slower inference.
419
+ guidance_scale (`float`, *optional*, defaults to 7.5):
420
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
421
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
422
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
423
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
424
+ usually at the expense of lower image quality.
425
+ negative_prompt (`str`, *optional*):
426
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
427
+ `negative_prompt_embeds` instead.
428
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
429
+ eta (`float`, *optional*, defaults to 0.0):
430
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
431
+ [`schedulers.DDIMScheduler`], will be ignored for others.
432
+ generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
433
+ One or a list of paddle generator(s) to make generation deterministic.
434
+ latents (`paddle.Tensor`, *optional*):
435
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
436
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
437
+ tensor will be generated by sampling using the supplied random `generator`.
438
+ output_type (`str`, *optional*, defaults to `"pil"`):
439
+ The output format of the generate image. Choose between
440
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
441
+ return_dict (`bool`, *optional*, defaults to `True`):
442
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
443
+ plain tuple.
444
+ callback (`Callable`, *optional*):
445
+ A function that will be called every `callback_steps` steps during inference. The function will be
446
+ called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
447
+ callback_steps (`int`, *optional*, defaults to 1):
448
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
449
+ called at every step.
450
+ cross_attention_kwargs (`dict`, *optional*):
451
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
452
+ `self.processor` in
453
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
454
+ clip_skip (`int`, *optional*, defaults to 0):
455
+ CLIP_stop_at_last_layers: if clip_skip <= 1, the last_hidden_state from the text_encoder is used.
456
+ Examples:
457
+
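+ A minimal, illustrative sketch (added editorially, not part of the original upload); it assumes the
+ pipeline object `pipe` has already been instantiated, e.g. loaded as a community pipeline via
+ `DiffusionPipeline.from_pretrained`:
+
+ ```python
+ pipe.switch_scheduler("euler-ancestral")  # any name listed in pipe.supported_scheduler
+ result = pipe(
+     prompt="masterpiece, best quality, an astronaut riding a horse",
+     negative_prompt="lowres, bad anatomy, bad hands",
+     num_inference_steps=30,
+     guidance_scale=7.5,
+     clip_skip=1,
+ )
+ result.images[0].save("astronaut.png")
+ ```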
458
+ Returns:
459
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
460
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
461
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
462
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
463
+ (nsfw) content, according to the `safety_checker`.
464
+ """
465
+ try:
466
+ # 0. Default height and width to unet
467
+ height = height or max(self.unet.config.sample_size * self.vae_scale_factor, 512)
468
+ width = width or max(self.unet.config.sample_size * self.vae_scale_factor, 512)
469
+
470
+ # 1. Check inputs. Raise error if not correct
471
+ self.check_inputs(
472
+ prompt,
473
+ height,
474
+ width,
475
+ callback_steps,
476
+ negative_prompt,
477
+ )
478
+
479
+ batch_size = 1
480
+
481
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
482
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
483
+ # corresponds to doing no classifier free guidance.
484
+ do_classifier_free_guidance = guidance_scale > 1.0
485
+
486
+ prompts, extra_network_data = parse_prompts([prompt])
487
+
488
+ if lora_dir is not None and os.path.exists(lora_dir):
489
+ lora_mapping = {p.stem: p.absolute() for p in Path(lora_dir).glob("*.safetensors")}
490
+ for params in extra_network_data["lora"]:
491
+ assert len(params.items) > 0
492
+ name = params.items[0]
493
+ if name in lora_mapping:
494
+ ratio = float(params.items[1]) if len(params.items) > 1 else 1.0
495
+ lora_state_dict = smart_load(lora_mapping[name], map_location=paddle.get_device())
496
+ self.weights_has_changed = True
497
+ load_lora(self, state_dict=lora_state_dict, ratio=ratio)
498
+ del lora_state_dict
499
+ else:
500
+ print(f"We can't find lora weight: {name}! Please make sure it exists!")
501
+
502
+ self.sj.clip.CLIP_stop_at_last_layers = clip_skip
503
+ # 3. Encode input prompt
504
+ prompt_embeds, negative_prompt_embeds = self._encode_prompt(
505
+ prompts,
506
+ do_classifier_free_guidance,
507
+ negative_prompt,
508
+ num_inference_steps=num_inference_steps,
509
+ )
510
+
511
+ # 4. Prepare timesteps
512
+ self.scheduler.set_timesteps(num_inference_steps)
513
+ timesteps = self.scheduler.timesteps
514
+
515
+ # 5. Prepare latent variables
516
+ num_channels_latents = self.unet.in_channels
517
+ latents = self.prepare_latents(
518
+ batch_size,
519
+ num_channels_latents,
520
+ height,
521
+ width,
522
+ self.unet.dtype,
523
+ generator,
524
+ latents,
525
+ )
526
+
527
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
528
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
529
+
530
+ # 7. Denoising loop
531
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
532
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
533
+ for i, t in enumerate(timesteps):
534
+ step = i // self.scheduler.order
535
+ do_batch = False
536
+ conds_list, cond_tensor = reconstruct_multicond_batch(prompt_embeds, step)
537
+ try:
538
+ weight = conds_list[0][0][1]
539
+ except Exception:
540
+ weight = 1.0
541
+ if do_classifier_free_guidance:
542
+ uncond_tensor = reconstruct_cond_batch(negative_prompt_embeds, step)
543
+ do_batch = cond_tensor.shape[1] == uncond_tensor.shape[1]
544
+
545
+ # expand the latents if we are doing classifier free guidance
546
+ latent_model_input = paddle.concat([latents] * 2) if do_batch else latents
547
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
548
+
549
+ if do_batch:
550
+ noise_pred = self.unet(
551
+ latent_model_input,
552
+ t,
553
+ encoder_hidden_states=paddle.concat([uncond_tensor, cond_tensor]),
554
+ cross_attention_kwargs=cross_attention_kwargs,
555
+ ).sample
556
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
557
+ noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred_text - noise_pred_uncond)
558
+ else:
559
+ noise_pred = self.unet(
560
+ latent_model_input,
561
+ t,
562
+ encoder_hidden_states=cond_tensor,
563
+ cross_attention_kwargs=cross_attention_kwargs,
564
+ ).sample
565
+
566
+ if do_classifier_free_guidance:
567
+ noise_pred_uncond = self.unet(
568
+ latent_model_input,
569
+ t,
570
+ encoder_hidden_states=uncond_tensor,
571
+ cross_attention_kwargs=cross_attention_kwargs,
572
+ ).sample
573
+ noise_pred = noise_pred_uncond + weight * guidance_scale * (noise_pred - noise_pred_uncond)
574
+
575
+ # compute the previous noisy sample x_t -> x_t-1
576
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
577
+
578
+ # call the callback, if provided
579
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
580
+ progress_bar.update()
581
+ if callback is not None and i % callback_steps == 0:
582
+ callback(i, t, latents)
583
+
584
+ if output_type == "latent":
585
+ image = latents
586
+ has_nsfw_concept = None
587
+ elif output_type == "pil":
588
+ # 8. Post-processing
589
+ image = self.decode_latents(latents)
590
+
591
+ # 9. Run safety checker
592
+ image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
593
+
594
+ # 10. Convert to PIL
595
+ image = self.numpy_to_pil(image)
596
+ else:
597
+ # 8. Post-processing
598
+ image = self.decode_latents(latents)
599
+
600
+ # 9. Run safety checker
601
+ image, has_nsfw_concept = self.run_safety_checker(image, self.unet.dtype)
602
+
603
+ if not return_dict:
604
+ return (image, has_nsfw_concept)
605
+
606
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
607
+ except Exception as e:
608
+ raise ValueError(e)
609
+ finally:
610
+ if self.weights_has_changed:
611
+ for sub_layer in self.text_encoder.sublayers(include_self=True):
612
+ if hasattr(sub_layer, "backup_weights"):
613
+ sub_layer.weight.copy_(sub_layer.backup_weights, True)
614
+ for sub_layer in self.unet.sublayers(include_self=True):
615
+ if hasattr(sub_layer, "backup_weights"):
616
+ sub_layer.weight.copy_(sub_layer.backup_weights, True)
617
+ self.weights_has_changed = False
618
+
619
+ # clip.py
620
+ import math
621
+ from collections import namedtuple
622
+
623
+
624
+ class PromptChunk:
625
+ """
626
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
627
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
628
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
629
+ so just 75 tokens from prompt.
630
+ """
631
+
632
+ def __init__(self):
633
+ self.tokens = []
634
+ self.multipliers = []
635
+ self.fixes = []
636
+
637
+
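+ # Layout reminder (comment added editorially): after padding, one chunk holds [id_start] + 75 prompt/padding
+ # token ids + [id_end] = 77 ids, with one weight multiplier per id; prompts longer than 75 tokens become
+ # several chunks whose encodings are concatenated along the token axis in forward().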
638
+ PromptChunkFix = namedtuple("PromptChunkFix", ["offset", "embedding"])
639
+ """An object of this type is a marker showing that a textual inversion embedding's vectors have to be placed at `offset` in the prompt
640
+ chunk. Those objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
641
+ are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
642
+
643
+
644
+ class FrozenCLIPEmbedder(nn.Layer):
645
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
646
+
647
+ LAYERS = ["last", "pooled", "hidden"]
648
+
649
+ def __init__(self, text_encoder, tokenizer, freeze=True, layer="last", layer_idx=None):
650
+ super().__init__()
651
+ assert layer in self.LAYERS
652
+ self.tokenizer = tokenizer
653
+ self.text_encoder = text_encoder
654
+ if freeze:
655
+ self.freeze()
656
+ self.layer = layer
657
+ self.layer_idx = layer_idx
658
+ if layer == "hidden":
659
+ assert layer_idx is not None
660
+ assert 0 <= abs(layer_idx) <= 12
661
+
662
+ def freeze(self):
663
+ self.text_encoder.eval()
664
+ for param in self.parameters():
665
+ param.stop_gradient = True  # freeze: stop_gradient=True disables gradient computation in Paddle (mirrors requires_grad=False)
666
+
667
+ def forward(self, text):
668
+ batch_encoding = self.tokenizer(
669
+ text,
670
+ truncation=True,
671
+ max_length=self.tokenizer.model_max_length,
672
+ padding="max_length",
673
+ return_tensors="pd",
674
+ )
675
+ tokens = batch_encoding["input_ids"]
676
+ outputs = self.text_encoder(input_ids=tokens, output_hidden_states=self.layer == "hidden", return_dict=True)
677
+ if self.layer == "last":
678
+ z = outputs.last_hidden_state
679
+ elif self.layer == "pooled":
680
+ z = outputs.pooler_output[:, None, :]
681
+ else:
682
+ z = outputs.hidden_states[self.layer_idx]
683
+ return z
684
+
685
+ def encode(self, text):
686
+ return self(text)
687
+
688
+
689
+ class FrozenCLIPEmbedderWithCustomWordsBase(nn.Layer):
690
+ """A paddle module that wraps a FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
691
+ have unlimited prompt length and assign weights to tokens in prompt.
692
+ """
693
+
694
+ def __init__(self, wrapped, hijack):
695
+ super().__init__()
696
+
697
+ self.wrapped = wrapped
698
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
699
+ depending on model."""
700
+
701
+ self.hijack = hijack
702
+ self.chunk_length = 75
703
+
704
+ def empty_chunk(self):
705
+ """creates an empty PromptChunk and returns it"""
706
+
707
+ chunk = PromptChunk()
708
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
709
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
710
+ return chunk
711
+
712
+ def get_target_prompt_token_count(self, token_count):
713
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
714
+
715
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
716
+
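+ # e.g. (illustrative) with chunk_length == 75: a 10-token prompt rounds up to 75, 75 stays at 75, and 80 needs 150.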
717
+ def tokenize(self, texts):
718
+ """Converts a batch of texts into a batch of token ids"""
719
+
720
+ raise NotImplementedError
721
+
722
+ def encode_with_text_encoder(self, tokens):
723
+ """
724
+ converts a batch of token ids (in python lists) into a single tensor with a numeric representation of those tokens;
725
+ All python lists with tokens are assumed to have the same length, usually 77.
726
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
727
+ the model - it can be 768 or 1024.
728
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
729
+ """
730
+
731
+ raise NotImplementedError
732
+
733
+ def encode_embedding_init_text(self, init_text, nvpt):
734
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
735
+ transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
736
+
737
+ raise NotImplementedError
738
+
739
+ def tokenize_line(self, line):
740
+ """
741
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
742
+ represent the prompt.
743
+ Returns the list and the total number of tokens in the prompt.
744
+ """
745
+
746
+ if WebUIStableDiffusionPipeline.enable_emphasis:
747
+ parsed = parse_prompt_attention(line)
748
+ else:
749
+ parsed = [[line, 1.0]]
750
+
751
+ tokenized = self.tokenize([text for text, _ in parsed])
752
+
753
+ chunks = []
754
+ chunk = PromptChunk()
755
+ token_count = 0
756
+ last_comma = -1
757
+
758
+ def next_chunk(is_last=False):
759
+ """puts current chunk into the list of results and produces the next one - empty;
760
+ if is_last is true, the <end-of-text> padding tokens at the end won't add to token_count"""
761
+ nonlocal token_count
762
+ nonlocal last_comma
763
+ nonlocal chunk
764
+
765
+ if is_last:
766
+ token_count += len(chunk.tokens)
767
+ else:
768
+ token_count += self.chunk_length
769
+
770
+ to_add = self.chunk_length - len(chunk.tokens)
771
+ if to_add > 0:
772
+ chunk.tokens += [self.id_end] * to_add
773
+ chunk.multipliers += [1.0] * to_add
774
+
775
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
776
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
777
+
778
+ last_comma = -1
779
+ chunks.append(chunk)
780
+ chunk = PromptChunk()
781
+
782
+ for tokens, (text, weight) in zip(tokenized, parsed):
783
+ if text == "BREAK" and weight == -1:
784
+ next_chunk()
785
+ continue
786
+
787
+ position = 0
788
+ while position < len(tokens):
789
+ token = tokens[position]
790
+
791
+ if token == self.comma_token:
792
+ last_comma = len(chunk.tokens)
793
+
794
+ # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. comma_padding_backtrack
795
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
796
+ elif (
797
+ WebUIStableDiffusionPipeline.comma_padding_backtrack != 0
798
+ and len(chunk.tokens) == self.chunk_length
799
+ and last_comma != -1
800
+ and len(chunk.tokens) - last_comma <= WebUIStableDiffusionPipeline.comma_padding_backtrack
801
+ ):
802
+ break_location = last_comma + 1
803
+
804
+ reloc_tokens = chunk.tokens[break_location:]
805
+ reloc_mults = chunk.multipliers[break_location:]
806
+
807
+ chunk.tokens = chunk.tokens[:break_location]
808
+ chunk.multipliers = chunk.multipliers[:break_location]
809
+
810
+ next_chunk()
811
+ chunk.tokens = reloc_tokens
812
+ chunk.multipliers = reloc_mults
813
+
814
+ if len(chunk.tokens) == self.chunk_length:
815
+ next_chunk()
816
+
817
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(
818
+ tokens, position
819
+ )
820
+ if embedding is None:
821
+ chunk.tokens.append(token)
822
+ chunk.multipliers.append(weight)
823
+ position += 1
824
+ continue
825
+
826
+ emb_len = int(embedding.vec.shape[0])
827
+ if len(chunk.tokens) + emb_len > self.chunk_length:
828
+ next_chunk()
829
+
830
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
831
+
832
+ chunk.tokens += [0] * emb_len
833
+ chunk.multipliers += [weight] * emb_len
834
+ position += embedding_length_in_tokens
835
+
836
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
837
+ next_chunk(is_last=True)
838
+
839
+ return chunks, token_count
840
+
841
+ def process_texts(self, texts):
842
+ """
843
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
844
+ length, in tokens, of all texts.
845
+ """
846
+
847
+ token_count = 0
848
+
849
+ cache = {}
850
+ batch_chunks = []
851
+ for line in texts:
852
+ if line in cache:
853
+ chunks = cache[line]
854
+ else:
855
+ chunks, current_token_count = self.tokenize_line(line)
856
+ token_count = max(current_token_count, token_count)
857
+
858
+ cache[line] = chunks
859
+
860
+ batch_chunks.append(chunks)
861
+
862
+ return batch_chunks, token_count
863
+
864
+ def forward(self, texts):
865
+ """
866
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
867
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
868
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
869
+ An example shape returned by this function can be: (2, 77, 768).
870
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
871
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
872
+ """
873
+
874
+ batch_chunks, token_count = self.process_texts(texts)
875
+
876
+ used_embeddings = {}
877
+ chunk_count = max([len(x) for x in batch_chunks])
878
+
879
+ zs = []
880
+ for i in range(chunk_count):
881
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
882
+
883
+ tokens = [x.tokens for x in batch_chunk]
884
+ multipliers = [x.multipliers for x in batch_chunk]
885
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
886
+
887
+ for fixes in self.hijack.fixes:
888
+ for position, embedding in fixes:
889
+ used_embeddings[embedding.name] = embedding
890
+
891
+ z = self.process_tokens(tokens, multipliers)
892
+ zs.append(z)
893
+
894
+ if len(used_embeddings) > 0:
895
+ embeddings_list = ", ".join(
896
+ [f"{name} [{embedding.checksum()}]" for name, embedding in used_embeddings.items()]
897
+ )
898
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
899
+
900
+ return paddle.concat(zs, axis=1)
901
+
902
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
903
+ """
904
+ sends one single prompt chunk to be encoded by transformers neural network.
905
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
906
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
907
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
908
+ corresponds to one token.
909
+ """
910
+ tokens = paddle.to_tensor(remade_batch_tokens)
911
+
912
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
913
+ if self.id_end != self.id_pad:
914
+ for batch_pos in range(len(remade_batch_tokens)):
915
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
916
+ tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
917
+
918
+ z = self.encode_with_text_encoder(tokens)
919
+
920
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
921
+ batch_multipliers = paddle.to_tensor(batch_multipliers)
922
+ original_mean = z.mean()
923
+ z = z * batch_multipliers.reshape(
924
+ batch_multipliers.shape
925
+ + [
926
+ 1,
927
+ ]
928
+ ).expand(z.shape)
929
+ new_mean = z.mean()
930
+ z = z * (original_mean / new_mean)
931
+
932
+ return z
933
+
934
+
935
+ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
936
+ def __init__(self, wrapped, hijack, CLIP_stop_at_last_layers=-1):
937
+ super().__init__(wrapped, hijack)
938
+ self.CLIP_stop_at_last_layers = CLIP_stop_at_last_layers
939
+ self.tokenizer = wrapped.tokenizer
940
+
941
+ vocab = self.tokenizer.get_vocab()
942
+
943
+ self.comma_token = vocab.get(",</w>", None)
944
+
945
+ self.token_mults = {}
946
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if "(" in k or ")" in k or "[" in k or "]" in k]
947
+ for text, ident in tokens_with_parens:
948
+ mult = 1.0
949
+ for c in text:
950
+ if c == "[":
951
+ mult /= 1.1
952
+ if c == "]":
953
+ mult *= 1.1
954
+ if c == "(":
955
+ mult *= 1.1
956
+ if c == ")":
957
+ mult /= 1.1
958
+
959
+ if mult != 1.0:
960
+ self.token_mults[ident] = mult
961
+
962
+ self.id_start = self.wrapped.tokenizer.bos_token_id
963
+ self.id_end = self.wrapped.tokenizer.eos_token_id
964
+ self.id_pad = self.id_end
965
+
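+ # Example of the vocab scan above (illustrative): a token spelled "((" ends up with multiplier 1.1 * 1.1 = 1.21,
+ # one spelled "[" gets 1 / 1.1 ~= 0.909, and tokens without brackets keep an implicit multiplier of 1.0.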
966
+ def tokenize(self, texts):
967
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
968
+
969
+ return tokenized
970
+
971
+ def encode_with_text_encoder(self, tokens):
972
+ output_hidden_states = self.CLIP_stop_at_last_layers > 1
973
+ outputs = self.wrapped.text_encoder(
974
+ input_ids=tokens, output_hidden_states=output_hidden_states, return_dict=True
975
+ )
976
+
977
+ if output_hidden_states:
978
+ z = outputs.hidden_states[-self.CLIP_stop_at_last_layers]
979
+ z = self.wrapped.text_encoder.text_model.ln_final(z)
980
+ else:
981
+ z = outputs.last_hidden_state
982
+
983
+ return z
984
+
985
+ def encode_embedding_init_text(self, init_text, nvpt):
986
+ embedding_layer = self.wrapped.text_encoder.text_model
987
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pd", add_special_tokens=False)[
988
+ "input_ids"
989
+ ]
990
+ embedded = embedding_layer.token_embedding.wrapped(ids).squeeze(0)
991
+
992
+ return embedded
993
+
994
+
995
+ # extra_networks.py
996
+ import re
997
+ from collections import defaultdict
998
+
999
+
1000
+ class ExtraNetworkParams:
1001
+ def __init__(self, items=None):
1002
+ self.items = items or []
1003
+
1004
+
1005
+ re_extra_net = re.compile(r"<(\w+):([^>]+)>")
1006
+
1007
+
1008
+ def parse_prompt(prompt):
1009
+ res = defaultdict(list)
1010
+
1011
+ def found(m):
1012
+ name = m.group(1)
1013
+ args = m.group(2)
1014
+
1015
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
1016
+
1017
+ return ""
1018
+
1019
+ prompt = re.sub(re_extra_net, found, prompt)
1020
+
1021
+ return prompt, res
1022
+
1023
+
1024
+ def parse_prompts(prompts):
1025
+ res = []
1026
+ extra_data = None
1027
+
1028
+ for prompt in prompts:
1029
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
1030
+
1031
+ if extra_data is None:
1032
+ extra_data = parsed_extra_data
1033
+
1034
+ res.append(updated_prompt)
1035
+
1036
+ return res, extra_data
1037
+
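+ # Illustrative example (added editorially, "styleA" is a hypothetical name): parse_prompt("a castle <lora:styleA:0.8>, night")
+ # strips the tag and returns ("a castle , night", {"lora": [ExtraNetworkParams(items=["styleA", "0.8"])]});
+ # __call__ then looks for "styleA.safetensors" under lora_dir and merges it with ratio 0.8.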
1038
+
1039
+ # image_embeddings.py
1040
+
1041
+ import base64
1042
+ import json
1043
+ import zlib
1044
+
1045
+ import numpy as np
1046
+ from PIL import Image
1047
+
1048
+
1049
+ class EmbeddingDecoder(json.JSONDecoder):
1050
+ def __init__(self, *args, **kwargs):
1051
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
1052
+
1053
+ def object_hook(self, d):
1054
+ if "TORCHTENSOR" in d:
1055
+ return paddle.to_tensor(np.array(d["TORCHTENSOR"]))
1056
+ return d
1057
+
1058
+
1059
+ def embedding_from_b64(data):
1060
+ d = base64.b64decode(data)
1061
+ return json.loads(d, cls=EmbeddingDecoder)
1062
+
1063
+
1064
+ def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
1065
+ while True:
1066
+ seed = (a * seed + c) % m
1067
+ yield seed % 255
1068
+
1069
+
1070
+ def xor_block(block):
1071
+ g = lcg()
1072
+ randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
1073
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
1074
+
1075
+
1076
+ def crop_black(img, tol=0):
1077
+ mask = (img > tol).all(2)
1078
+ mask0, mask1 = mask.any(0), mask.any(1)
1079
+ col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
1080
+ row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
1081
+ return img[row_start:row_end, col_start:col_end]
1082
+
1083
+
1084
+ def extract_image_data_embed(image):
1085
+ d = 3
1086
+ outarr = (
1087
+ crop_black(np.array(image.convert("RGB").getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8))
1088
+ & 0x0F
1089
+ )
1090
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
1091
+ if black_cols[0].shape[0] < 2:
1092
+ print("No Image data blocks found.")
1093
+ return None
1094
+
1095
+ data_block_lower = outarr[:, : black_cols[0].min(), :].astype(np.uint8)
1096
+ data_block_upper = outarr[:, black_cols[0].max() + 1 :, :].astype(np.uint8)
1097
+
1098
+ data_block_lower = xor_block(data_block_lower)
1099
+ data_block_upper = xor_block(data_block_upper)
1100
+
1101
+ data_block = (data_block_upper << 4) | (data_block_lower)
1102
+ data_block = data_block.flatten().tobytes()
1103
+
1104
+ data = zlib.decompress(data_block)
1105
+ return json.loads(data, cls=EmbeddingDecoder)
1106
+
1107
+
1108
+ # prompt_parser.py
1109
+ import re
1110
+ from collections import namedtuple
1111
+ from typing import List
1112
+
1113
+ import lark
1114
+
1115
+ # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
1116
+ # will be represented with prompt_schedule like this (assuming steps=100):
1117
+ # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
1118
+ # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
1119
+ # [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
1120
+ # [75, 'fantasy landscape with a lake and an oak in background masterful']
1121
+ # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
1122
+
1123
+ schedule_parser = lark.Lark(
1124
+ r"""
1125
+ !start: (prompt | /[][():]/+)*
1126
+ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
1127
+ !emphasized: "(" prompt ")"
1128
+ | "(" prompt ":" prompt ")"
1129
+ | "[" prompt "]"
1130
+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
1131
+ alternate: "[" prompt ("|" prompt)+ "]"
1132
+ WHITESPACE: /\s+/
1133
+ plain: /([^\\\[\]():|]|\\.)+/
1134
+ %import common.SIGNED_NUMBER -> NUMBER
1135
+ """
1136
+ )
1137
+
1138
+
1139
+ def get_learned_conditioning_prompt_schedules(prompts, steps):
1140
+ """
1141
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
1142
+ >>> g("test")
1143
+ [[10, 'test']]
1144
+ >>> g("a [b:3]")
1145
+ [[3, 'a '], [10, 'a b']]
1146
+ >>> g("a [b: 3]")
1147
+ [[3, 'a '], [10, 'a b']]
1148
+ >>> g("a [[[b]]:2]")
1149
+ [[2, 'a '], [10, 'a [[b]]']]
1150
+ >>> g("[(a:2):3]")
1151
+ [[3, ''], [10, '(a:2)']]
1152
+ >>> g("a [b : c : 1] d")
1153
+ [[1, 'a b d'], [10, 'a c d']]
1154
+ >>> g("a[b:[c:d:2]:1]e")
1155
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
1156
+ >>> g("a [unbalanced")
1157
+ [[10, 'a [unbalanced']]
1158
+ >>> g("a [b:.5] c")
1159
+ [[5, 'a c'], [10, 'a b c']]
1160
+ >>> g("a [{b|d{:.5] c") # not handling this right now
1161
+ [[5, 'a c'], [10, 'a {b|d{ c']]
1162
+ >>> g("((a][:b:c [d:3]")
1163
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
1164
+ >>> g("[a|(b:1.1)]")
1165
+ [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
1166
+ """
1167
+
1168
+ def collect_steps(steps, tree):
1169
+ l = [steps]
1170
+
1171
+ class CollectSteps(lark.Visitor):
1172
+ def scheduled(self, tree):
1173
+ tree.children[-1] = float(tree.children[-1])
1174
+ if tree.children[-1] < 1:
1175
+ tree.children[-1] *= steps
1176
+ tree.children[-1] = min(steps, int(tree.children[-1]))
1177
+ l.append(tree.children[-1])
1178
+
1179
+ def alternate(self, tree):
1180
+ l.extend(range(1, steps + 1))
1181
+
1182
+ CollectSteps().visit(tree)
1183
+ return sorted(set(l))
1184
+
1185
+ def at_step(step, tree):
1186
+ class AtStep(lark.Transformer):
1187
+ def scheduled(self, args):
1188
+ before, after, _, when = args
1189
+ yield before or () if step <= when else after
1190
+
1191
+ def alternate(self, args):
1192
+ yield next(args[(step - 1) % len(args)])
1193
+
1194
+ def start(self, args):
1195
+ def flatten(x):
1196
+ if type(x) == str:
1197
+ yield x
1198
+ else:
1199
+ for gen in x:
1200
+ yield from flatten(gen)
1201
+
1202
+ return "".join(flatten(args))
1203
+
1204
+ def plain(self, args):
1205
+ yield args[0].value
1206
+
1207
+ def __default__(self, data, children, meta):
1208
+ for child in children:
1209
+ yield child
1210
+
1211
+ return AtStep().transform(tree)
1212
+
1213
+ def get_schedule(prompt):
1214
+ try:
1215
+ tree = schedule_parser.parse(prompt)
1216
+ except lark.exceptions.LarkError:
1217
+ if 0:
1218
+ import traceback
1219
+
1220
+ traceback.print_exc()
1221
+ return [[steps, prompt]]
1222
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
1223
+
1224
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
1225
+ return [promptdict[prompt] for prompt in prompts]
1226
+
1227
+
1228
+ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
1229
+
1230
+
1231
+ def get_learned_conditioning(model, prompts, steps):
1232
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
1233
+ and the sampling step at which this condition is to be replaced by the next one.
1234
+
1235
+ Input:
1236
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
1237
+
1238
+ Output:
1239
+ [
1240
+ [
1241
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
1242
+ ],
1243
+ [
1244
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
1245
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
1246
+ ]
1247
+ ]
1248
+ """
1249
+ res = []
1250
+
1251
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
1252
+ cache = {}
1253
+
1254
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
1255
+
1256
+ cached = cache.get(prompt, None)
1257
+ if cached is not None:
1258
+ res.append(cached)
1259
+ continue
1260
+
1261
+ texts = [x[1] for x in prompt_schedule]
1262
+ conds = model(texts)
1263
+
1264
+ cond_schedule = []
1265
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
1266
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
1267
+
1268
+ cache[prompt] = cond_schedule
1269
+ res.append(cond_schedule)
1270
+
1271
+ return res
1272
+
1273
+
1274
+ re_AND = re.compile(r"\bAND\b")
1275
+ re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
1276
+
1277
+
1278
+ def get_multicond_prompt_list(prompts):
1279
+ res_indexes = []
1280
+
1281
+ prompt_flat_list = []
1282
+ prompt_indexes = {}
1283
+
1284
+ for prompt in prompts:
1285
+ subprompts = re_AND.split(prompt)
1286
+
1287
+ indexes = []
1288
+ for subprompt in subprompts:
1289
+ match = re_weight.search(subprompt)
1290
+
1291
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
1292
+
1293
+ weight = float(weight) if weight is not None else 1.0
1294
+
1295
+ index = prompt_indexes.get(text, None)
1296
+ if index is None:
1297
+ index = len(prompt_flat_list)
1298
+ prompt_flat_list.append(text)
1299
+ prompt_indexes[text] = index
1300
+
1301
+ indexes.append((index, weight))
1302
+
1303
+ res_indexes.append(indexes)
1304
+
1305
+ return res_indexes, prompt_flat_list, prompt_indexes
1306
+
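+ # Illustrative example (added editorially): for prompts == ["a cat AND a dog :0.5"], the prompt is split on
+ # AND into two sub-prompts, giving res_indexes == [[(0, 1.0), (1, 0.5)]] and a prompt_flat_list holding the
+ # two de-duplicated sub-prompt texts.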
1307
+
1308
+ class ComposableScheduledPromptConditioning:
1309
+ def __init__(self, schedules, weight=1.0):
1310
+ self.schedules: List[ScheduledPromptConditioning] = schedules
1311
+ self.weight: float = weight
1312
+
1313
+
1314
+ class MulticondLearnedConditioning:
1315
+ def __init__(self, shape, batch):
1316
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
1317
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
1318
+
1319
+
1320
+ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
1321
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
1322
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
1323
+
1324
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
1325
+ """
1326
+
1327
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
1328
+
1329
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
1330
+
1331
+ res = []
1332
+ for indexes in res_indexes:
1333
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
1334
+
1335
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
1336
+
1337
+
1338
+ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
1339
+ param = c[0][0].cond
1340
+ res = paddle.zeros(
1341
+ [
1342
+ len(c),
1343
+ ]
1344
+ + param.shape,
1345
+ dtype=param.dtype,
1346
+ )
1347
+ for i, cond_schedule in enumerate(c):
1348
+ target_index = 0
1349
+ for current, (end_at, cond) in enumerate(cond_schedule):
1350
+ if current_step <= end_at:
1351
+ target_index = current
1352
+ break
1353
+ res[i] = cond_schedule[target_index].cond
1354
+
1355
+ return res
1356
+
1357
+
1358
+ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+     param = c.batch[0][0].schedules[0].cond
+
+     tensors = []
+     conds_list = []
+
+     for batch_no, composable_prompts in enumerate(c.batch):
+         conds_for_batch = []
+
+         for cond_index, composable_prompt in enumerate(composable_prompts):
+             target_index = 0
+             for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+                 if current_step <= end_at:
+                     target_index = current
+                     break
+
+             conds_for_batch.append((len(tensors), composable_prompt.weight))
+             tensors.append(composable_prompt.schedules[target_index].cond)
+
+         conds_list.append(conds_for_batch)
+
+     # if prompts have wildly different lengths above the token limit we'll get tensors of different shapes
+     # and won't be able to paddle.stack them, so pad the shorter ones by repeating their last vector.
+     token_count = max([x.shape[0] for x in tensors])
+     for i in range(len(tensors)):
+         if tensors[i].shape[0] != token_count:
+             last_vector = tensors[i][-1:]
+             last_vector_repeated = last_vector.tile([token_count - tensors[i].shape[0], 1])
+             tensors[i] = paddle.concat([tensors[i], last_vector_repeated], axis=0)
+
+     return conds_list, paddle.stack(tensors).cast(dtype=param.dtype)
+
+
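+ # Illustrative sketch (not part of the original file): at a given sampling step this picks
+ # the active ScheduledPromptConditioning for every subprompt, pads shorter conditionings as
+ # described above, and returns (conds_list, tensor), where conds_list holds
+ # (index_into_tensor, weight) pairs per batch item, e.g.:
+ #
+ #   conds_list, tensor = reconstruct_multicond_batch(mc, current_step=5)
+ #   conds_list    # [[(0, 1.0), (1, 0.4)]]
+ #   tensor.shape  # [2, token_count, embed_dim]
+
+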
1391
+ re_attention = re.compile(
+     r"""
+ \\\(|
+ \\\)|
+ \\\[|
+ \\]|
+ \\\\|
+ \\|
+ \(|
+ \[|
+ :([+-]?[.\d]+)\)|
+ \)|
+ ]|
+ [^\\()\[\]:]+|
+ :
+ """,
+     re.X,
+ )
+
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
+
+
1413
+ def parse_prompt_attention(text):
+     """
+     Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+     Accepted tokens are:
+       (abc) - increases attention to abc by a multiplier of 1.1
+       (abc:3.12) - increases attention to abc by a multiplier of 3.12
+       [abc] - decreases attention to abc by a multiplier of 1.1
+       \( - literal character '('
+       \[ - literal character '['
+       \) - literal character ')'
+       \] - literal character ']'
+       \\ - literal character '\'
+       anything else - just text
+
+     >>> parse_prompt_attention('normal text')
+     [['normal text', 1.0]]
+     >>> parse_prompt_attention('an (important) word')
+     [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+     >>> parse_prompt_attention('(unbalanced')
+     [['unbalanced', 1.1]]
+     >>> parse_prompt_attention('\(literal\]')
+     [['(literal]', 1.0]]
+     >>> parse_prompt_attention('(unnecessary)(parens)')
+     [['unnecessaryparens', 1.1]]
+     >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+     [['a ', 1.0],
+      ['house', 1.5730000000000004],
+      [' ', 1.1],
+      ['on', 1.0],
+      [' a ', 1.1],
+      ['hill', 0.55],
+      [', sun, ', 1.1],
+      ['sky', 1.4641000000000006],
+      ['.', 1.1]]
+     """
+
+     res = []
+     round_brackets = []
+     square_brackets = []
+
+     round_bracket_multiplier = 1.1
+     square_bracket_multiplier = 1 / 1.1
+
+     def multiply_range(start_position, multiplier):
+         for p in range(start_position, len(res)):
+             res[p][1] *= multiplier
+
+     for m in re_attention.finditer(text):
+         text = m.group(0)
+         weight = m.group(1)
+
+         if text.startswith("\\"):
+             res.append([text[1:], 1.0])
+         elif text == "(":
+             round_brackets.append(len(res))
+         elif text == "[":
+             square_brackets.append(len(res))
+         elif weight is not None and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), float(weight))
+         elif text == ")" and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), round_bracket_multiplier)
+         elif text == "]" and len(square_brackets) > 0:
+             multiply_range(square_brackets.pop(), square_bracket_multiplier)
+         else:
+             parts = re.split(re_break, text)
+             for i, part in enumerate(parts):
+                 if i > 0:
+                     res.append(["BREAK", -1])
+                 res.append([part, 1.0])
+
+     for pos in round_brackets:
+         multiply_range(pos, round_bracket_multiplier)
+
+     for pos in square_brackets:
+         multiply_range(pos, square_bracket_multiplier)
+
+     if len(res) == 0:
+         res = [["", 1.0]]
+
+     # merge runs of identical weights
+     i = 0
+     while i + 1 < len(res):
+         if res[i][1] == res[i + 1][1]:
+             res[i][0] += res[i + 1][0]
+             res.pop(i + 1)
+         else:
+             i += 1
+
+     return res
+
+
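+ # Illustrative sketch (not part of the original file): in addition to the tokens documented
+ # in the docstring, a standalone BREAK keyword is emitted as a ["BREAK", -1] entry, which
+ # downstream tokenization can use to start a new chunk, e.g.:
+ #
+ #   parse_prompt_attention("a photo BREAK of a cat")
+ #   # -> [["a photo", 1.0], ["BREAK", -1], ["of a cat", 1.0]]
+
+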
1504
+ # sd_hijack.py
+
+
+ class StableDiffusionModelHijack:
+     fixes = None
+     comments = []
+     layers = None
+     circular_enabled = False
+
+     def __init__(self, clip_model, embeddings_dir=None, CLIP_stop_at_last_layers=-1):
+         model_embeddings = clip_model.text_encoder.text_model
+         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+         clip_model = FrozenCLIPEmbedderWithCustomWords(
+             clip_model, self, CLIP_stop_at_last_layers=CLIP_stop_at_last_layers
+         )
+
+         self.embedding_db = EmbeddingDatabase(clip_model)
+         self.embedding_db.add_embedding_dir(embeddings_dir)
+
+         # hack this!
+         self.clip = clip_model
+
+         def flatten(el):
+             flattened = [flatten(children) for children in el.children()]
+             res = [el]
+             for c in flattened:
+                 res += c
+             return res
+
+         self.layers = flatten(clip_model)
+
+     def clear_comments(self):
+         self.comments = []
+
+     def get_prompt_lengths(self, text):
+         _, token_count = self.clip.process_texts([text])
+
+         return token_count, self.clip.get_target_prompt_token_count(token_count)
+
+
1544
+ class EmbeddingsWithFixes(nn.Layer):
+     def __init__(self, wrapped, embeddings):
+         super().__init__()
+         self.wrapped = wrapped
+         self.embeddings = embeddings
+
+     def forward(self, input_ids):
+         batch_fixes = self.embeddings.fixes
+         self.embeddings.fixes = None
+
+         inputs_embeds = self.wrapped(input_ids)
+
+         if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+             return inputs_embeds
+
+         vecs = []
+         for fixes, tensor in zip(batch_fixes, inputs_embeds):
+             for offset, embedding in fixes:
+                 emb = embedding.vec.cast(self.wrapped.dtype)
+                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                 tensor = paddle.concat([tensor[0 : offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len :]])
+
+             vecs.append(tensor)
+
+         return paddle.stack(vecs)
+
+
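+ # Illustrative sketch (not part of the original file): `fixes` is filled per batch by the
+ # hijacked CLIP wrapper with lists of (offset, Embedding) pairs; forward() then splices each
+ # embedding's vectors into the token embeddings at that offset. Hypothetical usage:
+ #
+ #   hijack.fixes = [[(0, my_embedding)]]     # my_embedding.vec has shape [n_vectors, embed_dim]
+ #   out = embeddings_with_fixes(input_ids)   # rows 1..n_vectors are taken from my_embedding.vec
+
+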
1571
+ # textual_inversion.py
+
+ import os
+ import sys
+ import traceback
+
+
+ class Embedding:
+     def __init__(self, vec, name, step=None):
+         self.vec = vec
+         self.name = name
+         self.step = step
+         self.shape = None
+         self.vectors = 0
+         self.cached_checksum = None
+         self.sd_checkpoint = None
+         self.sd_checkpoint_name = None
+         self.optimizer_state_dict = None
+         self.filename = None
+
+     def save(self, filename):
+         embedding_data = {
+             "string_to_token": {"*": 265},
+             "string_to_param": {"*": self.vec},
+             "name": self.name,
+             "step": self.step,
+             "sd_checkpoint": self.sd_checkpoint,
+             "sd_checkpoint_name": self.sd_checkpoint_name,
+         }
+
+         paddle.save(embedding_data, filename)
+
+     def checksum(self):
+         if self.cached_checksum is not None:
+             return self.cached_checksum
+
+         def const_hash(a):
+             r = 0
+             for v in a:
+                 r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+             return r
+
+         self.cached_checksum = f"{const_hash(self.vec.flatten() * 100) & 0xffff:04x}"
+         return self.cached_checksum
+
+
1617
+ class DirWithTextualInversionEmbeddings:
+     def __init__(self, path):
+         self.path = path
+         self.mtime = None
+
+     def has_changed(self):
+         if not os.path.isdir(self.path):
+             return False
+
+         mt = os.path.getmtime(self.path)
+         if self.mtime is None or mt > self.mtime:
+             return True
+
+         return False
+
+     def update(self):
+         if not os.path.isdir(self.path):
+             return
+
+         self.mtime = os.path.getmtime(self.path)
+
+
1637
+ class EmbeddingDatabase:
+     def __init__(self, clip):
+         self.clip = clip
+         self.ids_lookup = {}
+         self.word_embeddings = {}
+         self.skipped_embeddings = {}
+         self.expected_shape = -1
+         self.embedding_dirs = {}
+         self.previously_displayed_embeddings = ()
+
+     def add_embedding_dir(self, path):
+         if path is not None:
+             self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+     def clear_embedding_dirs(self):
+         self.embedding_dirs.clear()
+
+     def register_embedding(self, embedding, model):
+         self.word_embeddings[embedding.name] = embedding
+
+         ids = model.tokenize([embedding.name])[0]
+
+         first_id = ids[0]
+         if first_id not in self.ids_lookup:
+             self.ids_lookup[first_id] = []
+
+         self.ids_lookup[first_id] = sorted(
+             self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True
+         )
+
+         return embedding
+
+     def get_expected_shape(self):
+         vec = self.clip.encode_embedding_init_text(",", 1)
+         return vec.shape[1]
+
1673
+     def load_from_file(self, path, filename):
+         name, ext = os.path.splitext(filename)
+         ext = ext.upper()
+
+         if ext in [".PNG", ".WEBP", ".JXL", ".AVIF"]:
+             _, second_ext = os.path.splitext(name)
+             if second_ext.upper() == ".PREVIEW":
+                 return
+
+             embed_image = Image.open(path)
+             if hasattr(embed_image, "text") and "sd-ti-embedding" in embed_image.text:
+                 data = embedding_from_b64(embed_image.text["sd-ti-embedding"])
+                 name = data.get("name", name)
+             else:
+                 data = extract_image_data_embed(embed_image)
+                 if data:
+                     name = data.get("name", name)
+                 else:
+                     # if data is None, this is not an embedding but just a preview image
+                     return
+         elif ext in [".BIN", ".PT"]:
+             data = torch_load(path)
+         elif ext in [".SAFETENSORS"]:
+             data = safetensors_load(path)
+         else:
+             return
+
+         # textual inversion embeddings
+         if "string_to_param" in data:
+             param_dict = data["string_to_param"]
+             if hasattr(param_dict, "_parameters"):
+                 param_dict = getattr(param_dict, "_parameters")
+             assert len(param_dict) == 1, "embedding file has multiple terms in it"
+             emb = next(iter(param_dict.items()))[1]
+         # diffuser concepts
+         elif type(data) == dict and type(next(iter(data.values()))) == paddle.Tensor:
+             assert len(data.keys()) == 1, "embedding file has multiple terms in it"
+
+             emb = next(iter(data.values()))
+             if len(emb.shape) == 1:
+                 emb = emb.unsqueeze(0)
+         else:
+             raise Exception(
+                 f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept."
+             )
+
+         with paddle.no_grad():
+             if hasattr(emb, "detach"):
+                 emb = emb.detach()
+             if hasattr(emb, "cpu"):
+                 emb = emb.cpu()
+             if hasattr(emb, "numpy"):
+                 emb = emb.numpy()
+             emb = paddle.to_tensor(emb)
+         vec = emb.detach().cast(paddle.float32)
+         embedding = Embedding(vec, name)
+         embedding.step = data.get("step", None)
+         embedding.sd_checkpoint = data.get("sd_checkpoint", None)
+         embedding.sd_checkpoint_name = data.get("sd_checkpoint_name", None)
+         embedding.vectors = vec.shape[0]
+         embedding.shape = vec.shape[-1]
+         embedding.filename = path
+
+         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+             self.register_embedding(embedding, self.clip)
+         else:
+             self.skipped_embeddings[name] = embedding
+
1741
+     def load_from_dir(self, embdir):
+         if not os.path.isdir(embdir.path):
+             return
+
+         for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+             for fn in fns:
+                 try:
+                     fullfn = os.path.join(root, fn)
+
+                     if os.stat(fullfn).st_size == 0:
+                         continue
+
+                     self.load_from_file(fullfn, fn)
+                 except Exception:
+                     print(f"Error loading embedding {fn}:", file=sys.stderr)
+                     print(traceback.format_exc(), file=sys.stderr)
+                     continue
+
1759
+     def load_textual_inversion_embeddings(self, force_reload=False):
+         if not force_reload:
+             need_reload = False
+             for path, embdir in self.embedding_dirs.items():
+                 if embdir.has_changed():
+                     need_reload = True
+                     break
+
+             if not need_reload:
+                 return
+
+         self.ids_lookup.clear()
+         self.word_embeddings.clear()
+         self.skipped_embeddings.clear()
+         self.expected_shape = self.get_expected_shape()
+
+         for path, embdir in self.embedding_dirs.items():
+             self.load_from_dir(embdir)
+             embdir.update()
+
+         displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
+         if self.previously_displayed_embeddings != displayed_embeddings:
+             self.previously_displayed_embeddings = displayed_embeddings
+             print(
+                 f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}"
+             )
+             if len(self.skipped_embeddings) > 0:
+                 print(
+                     f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}"
+                 )
+
1790
+     def find_embedding_at_position(self, tokens, offset):
+         token = tokens[offset]
+         possible_matches = self.ids_lookup.get(token, None)
+
+         if possible_matches is None:
+             return None, None
+
+         for ids, embedding in possible_matches:
+             if tokens[offset : offset + len(ids)] == ids:
+                 return embedding, len(ids)
+
+         return None, None
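+
+
+ # Illustrative sketch (not part of the original file): ids_lookup maps the first token id of
+ # every registered embedding to its (ids, embedding) candidates, longest first, so a caller
+ # walking a tokenized prompt can do roughly:
+ #
+ #   embedding, length = db.find_embedding_at_position(tokens, offset)
+ #   if embedding is not None:
+ #       # record (offset, embedding) as a fix and skip `length` tokens
+ #       offset += length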