Lightxr nyanko7 commited on
Commit
6fd31c7
·
0 Parent(s):

Duplicate from nyanko7/sd-diffusers-webui

Browse files

Co-authored-by: Nyanko <[email protected]>

Files changed (8) hide show
  1. .gitattributes +34 -0
  2. Dockerfile +22 -0
  3. README.md +14 -0
  4. app.py +878 -0
  5. modules/lora.py +183 -0
  6. modules/model.py +897 -0
  7. modules/prompt_parser.py +391 -0
  8. modules/safe.py +188 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile Public T4
2
+
3
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
4
+ ENV DEBIAN_FRONTEND noninteractive
5
+
6
+ WORKDIR /content
7
+
8
+ RUN apt-get update -y && apt-get upgrade -y && apt-get install -y libgl1 libglib2.0-0 wget git git-lfs python3-pip python-is-python3 && pip3 install --upgrade pip
9
+
10
+ RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchsde --extra-index-url https://download.pytorch.org/whl/cu113
11
+ RUN pip install https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.16/xformers-0.0.16+814314d.d20230118-cp310-cp310-linux_x86_64.whl
12
+ RUN pip install --pre triton
13
+ RUN pip install numexpr einops transformers k_diffusion safetensors gradio diffusers==0.12.1
14
+
15
+ ADD . .
16
+ RUN adduser --disabled-password --gecos '' user
17
+ RUN chown -R user:user /content
18
+ RUN chmod -R 777 /content
19
+ USER user
20
+
21
+ EXPOSE 7860
22
+ CMD python /content/app.py
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sd Diffusers Webui
3
+ emoji: 🐳
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: docker
7
+ sdk_version: 3.9
8
+ pinned: false
9
+ license: openrail
10
+ app_port: 7860
11
+ duplicated_from: nyanko7/sd-diffusers-webui
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import tempfile
3
+ import time
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import math
8
+ import re
9
+
10
+ from gradio import inputs
11
+ from diffusers import (
12
+ AutoencoderKL,
13
+ DDIMScheduler,
14
+ UNet2DConditionModel,
15
+ )
16
+ from modules.model import (
17
+ CrossAttnProcessor,
18
+ StableDiffusionPipeline,
19
+ )
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+ from PIL import Image
23
+ from pathlib import Path
24
+ from safetensors.torch import load_file
25
+ import modules.safe as _
26
+ from modules.lora import LoRANetwork
27
+
28
+ models = [
29
+ ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF", 2),
30
+ ("Pastal Mix", "andite/pastel-mix", 2),
31
+ ("Basil Mix", "nuigurumi/basil_mix", 2)
32
+ ]
33
+
34
+ keep_vram = ["Korakoe/AbyssOrangeMix2-HF", "andite/pastel-mix"]
35
+ base_name, base_model, clip_skip = models[0]
36
+
37
+ samplers_k_diffusion = [
38
+ ("Euler a", "sample_euler_ancestral", {}),
39
+ ("Euler", "sample_euler", {}),
40
+ ("LMS", "sample_lms", {}),
41
+ ("Heun", "sample_heun", {}),
42
+ ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
43
+ ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
44
+ ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
45
+ ("DPM++ 2M", "sample_dpmpp_2m", {}),
46
+ ("DPM++ SDE", "sample_dpmpp_sde", {}),
47
+ ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
48
+ ("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
49
+ ("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
50
+ ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
51
+ ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
52
+ ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
53
+ ]
54
+
55
+ # samplers_diffusers = [
56
+ # ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {}),
57
+ # ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {}),
58
+ # ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {})
59
+ # ]
60
+
61
+ start_time = time.time()
62
+ timeout = 90
63
+
64
+ scheduler = DDIMScheduler.from_pretrained(
65
+ base_model,
66
+ subfolder="scheduler",
67
+ )
68
+ vae = AutoencoderKL.from_pretrained(
69
+ "stabilityai/sd-vae-ft-ema",
70
+ torch_dtype=torch.float16
71
+ )
72
+ text_encoder = CLIPTextModel.from_pretrained(
73
+ base_model,
74
+ subfolder="text_encoder",
75
+ torch_dtype=torch.float16,
76
+ )
77
+ tokenizer = CLIPTokenizer.from_pretrained(
78
+ base_model,
79
+ subfolder="tokenizer",
80
+ torch_dtype=torch.float16,
81
+ )
82
+ unet = UNet2DConditionModel.from_pretrained(
83
+ base_model,
84
+ subfolder="unet",
85
+ torch_dtype=torch.float16,
86
+ )
87
+ pipe = StableDiffusionPipeline(
88
+ text_encoder=text_encoder,
89
+ tokenizer=tokenizer,
90
+ unet=unet,
91
+ vae=vae,
92
+ scheduler=scheduler,
93
+ )
94
+
95
# Install the custom attention processor on the base UNet. An *instance* is
# passed (not the class itself), matching what setup_model() does for every
# model it activates.
unet.set_attn_processor(CrossAttnProcessor())
# Configure the pipeline's text encoder with the base model's clip-skip value.
pipe.setup_text_encoder(clip_skip, text_encoder)
# Only move the pipeline to the GPU when one is actually present.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
99
+
100
def get_model_list():
    """Return the module-level ``models`` registry.

    Each entry is a ``(display name, model id, clip skip)`` tuple; the list
    itself is returned, not a copy.
    """
    registry = models
    return registry
102
+
103
+ te_cache = {
104
+ base_model: text_encoder
105
+ }
106
+
107
+ unet_cache = {
108
+ base_model: unet
109
+ }
110
+
111
+ lora_cache = {
112
+ base_model: LoRANetwork(text_encoder, unet)
113
+ }
114
+
115
+ te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
116
+ original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
117
+ current_model = base_model
118
+
119
def setup_model(name, lora_state=None, lora_scale=1.0):
    """Activate the model called *name* on the shared pipeline and return it.

    Args:
        name: Display name of a model, i.e. the first field of an entry in
            the module-level ``models`` list.
        lora_state: Path of a LoRA weights file to apply, or ``None``/``""``
            for no LoRA.
        lora_scale: Scale forwarded to ``LoRANetwork.load``.

    Returns:
        The module-level ``pipe`` with its text encoder, UNet and tokenizer
        hooks switched to the requested model.

    Raises:
        ValueError: If *name* is not a known display name.
    """
    global pipe, current_model

    # Resolve the registry entry once (first match, like list.index) and fail
    # with a message that names the offending value.
    entry = next((m for m in models if m[0] == name), None)
    if entry is None:
        raise ValueError(
            f"Unknown model {name!r}; expected one of {[m[0] for m in models]}"
        )
    _name, model, clip_skip = entry

    # Load and cache the weights the first time a model is requested.
    if model not in unet_cache:
        unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
        text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)

        unet_cache[model] = unet
        te_cache[model] = text_encoder
        lora_cache[model] = LoRANetwork(text_encoder, unet)

    if current_model != model:
        if current_model not in keep_vram:
            # offload current model
            unet_cache[current_model].to("cpu")
            te_cache[current_model].to("cpu")
            lora_cache[current_model].to("cpu")
        current_model = model

    local_te, local_unet, local_lora = te_cache[model], unet_cache[model], lora_cache[model]
    local_unet.set_attn_processor(CrossAttnProcessor())
    # Drop whatever LoRA a previous request applied before (maybe) loading a new one.
    local_lora.reset()

    if torch.cuda.is_available():
        local_unet.to("cuda")
        local_te.to("cuda")

    if lora_state is not None and lora_state != "":
        local_lora.load(lora_state, lora_scale)
        local_lora.to(local_unet.device, dtype=local_unet.dtype)

    pipe.text_encoder, pipe.unet = local_te, local_unet
    pipe.setup_unet(local_unet)
    # Undo tokenizer changes made for textual-inversion embeddings by an
    # earlier request: restore the original hook and forget added tokens.
    pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
    pipe.tokenizer.added_tokens_encoder = {}
    pipe.tokenizer.added_tokens_decoder = {}
    pipe.setup_text_encoder(clip_skip, local_te)
    return pipe
160
+
161
+
162
def error_str(error, title="Error"):
    """Format *error* as a Markdown snippet under a level-4 heading.

    Returns an empty string when *error* is falsy; otherwise the heading
    line ``#### {title}`` followed by the error text on the next line.
    """
    if not error:
        return ""
    return f"""#### {title}
{error}"""
169
+
170
def make_token_names(embs):
    """Build placeholder token names for each embedding.

    For every ``name -> vec`` pair in *embs* (in iteration order) this yields
    one list ``['emb-{name}-0', ..., 'emb-{name}-{len(vec) - 1}']``; the
    result is the list of those per-embedding lists.
    """
    return [
        [f'emb-{name}-{i}' for i in range(len(vec))]
        for name, vec in embs.items()
    ]
176
+
177
def setup_tokenizer(tokenizer, embs):
    """Hook *tokenizer* so embedding names in a prompt expand to placeholder tokens.

    Each key of *embs* that appears in a prompt as a standalone word
    (delimited by start/end of text, whitespace or a comma) is replaced by the
    space-joined placeholder tokens produced by ``make_token_names`` before
    the tokenizer's original ``prepare_for_tokenization`` runs.

    Args:
        tokenizer: Tokenizer whose ``prepare_for_tokenization`` is overridden
            (bound as a ``CLIPTokenizer`` method).
        embs: Mapping of embedding name to its vectors.

    Returns:
        Flat list of every placeholder token name, in embedding order.
    """
    token_groups = make_token_names(embs)
    # Names are user-controlled in this file (they come from uploaded file
    # stems), so escape them: an unescaped "." or "(" would otherwise change
    # what matches or make re.compile raise.
    reg_match = [
        re.compile(fr"(?:^|(?<=\s|,)){re.escape(k)}(?=,|\s|$)") for k in embs.keys()
    ]
    clip_keywords = [' '.join(group) for group in token_groups]

    def parse_prompt(prompt: str):
        for pattern, replacement in zip(reg_match, clip_keywords):
            # A callable replacement inserts the text literally; a plain string
            # would have backslashes interpreted as regex template escapes.
            prompt = pattern.sub(lambda _match, text=replacement: text, prompt)
        return prompt

    def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs):
        text = parse_prompt(text)
        return original_prepare_for_tokenization(text, is_split_into_words, **kwargs)

    # Bind the wrapper as a method so it receives the tokenizer as ``self``.
    tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer)
    return [t for group in token_groups for t in group]
192
+
193
+
194
def convert_size(size_bytes):
    """Render a byte count as a human-readable string using 1024-based units.

    Args:
        size_bytes: Number of bytes (int or float). Negative values keep
            their sign.

    Returns:
        ``"0B"`` for zero, otherwise ``"<value> <unit>"`` with the value
        rounded to two decimals, e.g. ``"1.5 KB"``. Values beyond the largest
        unit are expressed in that unit (``"YB"``).
    """
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    # Pick the unit by repeated division instead of a floating-point log:
    # this cannot land in the wrong bucket through rounding, works for
    # values below 1, and never indexes past the last unit.
    magnitude = float(abs(size_bytes))
    i = 0
    while magnitude >= 1024 and i < len(size_name) - 1:
        magnitude /= 1024
        i += 1

    s = round(magnitude, 2)
    if size_bytes < 0:
        s = -s
    return "%s %s" % (s, size_name[i])
202
+
203
def inference(
    prompt,
    guidance,
    steps,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    state=None,
    g_strength=0.4,
    img_input=None,
    i2i_scale=0.5,
    hr_enabled=False,
    hr_method="Latent",
    hr_scale=1.5,
    hr_denoise=0.8,
    sampler="DPM++ 2M Karras",
    embs=None,
    model=None,
    lora_state=None,
    lora_scale=None,
):
    """Run one generation request and return the image as a ``gr.Image`` update.

    Depending on the arguments this runs image-to-image (``img_input`` given),
    text-to-image with latent upscaling (``hr_enabled``), or plain
    text-to-image.

    Args:
        prompt, neg_prompt: Positive / negative prompt text.
        guidance: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps (cast to int).
        width, height: Target resolution.
        seed: RNG seed; ``None`` or ``0`` picks a random one.
        state: Paint-with-words state, forwarded as ``pww_state``.
        g_strength: Forwarded as ``pww_attn_weight``.
        img_input: Optional PIL image for image-to-image.
        i2i_scale: Image-to-image strength.
        hr_enabled, hr_method, hr_scale, hr_denoise: Upscaler settings;
            ``hr_method`` must be a key of ``latent_upscale_modes``.
        sampler: Label of an entry in ``samplers_k_diffusion``.
        embs: Mapping of embedding name to file path.
        model: Display name of the model, passed to ``setup_model``.
        lora_state, lora_scale: LoRA file path and scale, passed to
            ``setup_model``.

    Returns:
        ``gr.Image.update`` carrying the first generated image, labelled with
        the seed that was used.
    """
    if seed is None or seed == 0:
        seed = random.randint(0, 2147483647)

    pipe = setup_model(model, lora_state, lora_scale)
    # The module only moves the pipeline to CUDA when it is available, so the
    # generator must follow the same rule instead of assuming a GPU.
    use_cuda = torch.cuda.is_available()
    generator = torch.Generator("cuda" if use_cuda else "cpu").manual_seed(int(seed))
    start_time = time.time()

    # Translate the UI label into the sampler function name and its options.
    sampler_name, sampler_opt = None, None
    for label, funcname, options in samplers_k_diffusion:
        if label == sampler:
            sampler_name, sampler_opt = funcname, options

    tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
    if embs is not None and len(embs) > 0:
        ti_embs = {}
        for name, file in embs.items():
            if str(file).endswith(".pt"):
                # NOTE(review): torch.load unpickles a user-uploaded file.
                # `modules.safe` is imported at file level, presumably to
                # restrict unpickling - confirm before relying on it.
                loaded_learned_embeds = torch.load(file, map_location="cpu")
            else:
                loaded_learned_embeds = load_file(file, device="cpu")
            # Files that carry a "string_to_param" table keep the vectors
            # under its "*" key; anything else is used as loaded.
            if "string_to_param" in loaded_learned_embeds:
                loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
            ti_embs[name] = loaded_learned_embeds

        if len(ti_embs) > 0:
            tokens = setup_tokenizer(tokenizer, ti_embs)
            added_tokens = tokenizer.add_tokens(tokens)
            delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)

            # Every embedding row needs exactly one freshly added token.
            assert added_tokens == delta_weight.shape[0]
            text_encoder.resize_token_embeddings(len(tokenizer))
            token_embeds = text_encoder.get_input_embeddings().weight.data
            # The new tokens occupy the tail of the embedding matrix.
            token_embeds[-delta_weight.shape[0]:] = delta_weight

    config = {
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(steps),
        "guidance_scale": guidance,
        "generator": generator,
        "sampler_name": sampler_name,
        "sampler_opt": sampler_opt,
        "pww_state": state,
        "pww_attn_weight": g_strength,
        "start_time": start_time,
        "timeout": timeout,
    }

    if img_input is not None:
        # Scale the input so it fits inside width x height, keeping its aspect ratio.
        ratio = min(height / img_input.height, width / img_input.width)
        img_input = img_input.resize(
            (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
        )
        result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
    elif hr_enabled:
        result = pipe.txt2img(
            prompt,
            width=width,
            height=height,
            upscale=True,
            upscale_x=hr_scale,
            upscale_denoising_strength=hr_denoise,
            **config,
            **latent_upscale_modes[hr_method],
        )
    else:
        result = pipe.txt2img(prompt, width=width, height=height, **config)

    end_time = time.time()
    # Memory statistics only exist on CUDA; report a placeholder otherwise.
    if use_cuda:
        vram_free, vram_total = torch.cuda.mem_get_info()
        vram_usage = f"{convert_size(vram_total-vram_free)}/{convert_size(vram_total)}"
    else:
        vram_usage = "n/a"
    print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={vram_usage}")
    return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
295
+
296
+
297
# Palette shared by every call to get_color(); it only ever grows.
color_list = []


def get_color(n):
    """Return the shared palette, grown to hold at least *n* colors.

    Missing entries are appended as 3-tuples of ``np.random.random() * 256``
    values. The returned object is the module-level ``color_list`` itself and
    may contain more than *n* entries.
    """
    shortfall = n - len(color_list)
    if shortfall > 0:
        color_list.extend(
            tuple(np.random.random(size=3) * 256) for _unused in range(shortfall)
        )
    return color_list
304
+
305
+
306
def create_mixed_img(current, state, w=512, h=512):
    """Compose the overlay image of all drawn token regions.

    Starts from an ``h x w x 4`` array filled with 255 and, for every state
    entry that has a map, paints the pixels where the map is below 255 with
    that entry's palette color. The entry whose key equals *current* gets
    alpha 200, all others alpha 150. A ``None`` state is treated as empty.
    """
    w = int(w)
    h = int(h)
    canvas = np.full([h, w, 4], 255)
    regions = {} if state is None else state

    palette = get_color(len(regions))

    # The palette index follows the entry's position in the state, whether or
    # not that entry has a map yet.
    for position, (token, entry) in enumerate(regions.items()):
        region_map = entry["map"]
        if region_map is None:
            continue
        opacity = 200 if current == token else 150
        canvas[region_map < 255] = palette[position] + (opacity,)

    return canvas
325
+
326
+
327
+ # width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
328
def apply_new_res(w, h, state):
    """Resize every stored token map to the new canvas size and redraw the overlay.

    Args:
        w, h: New width and height (cast to int).
        state: Token state mapping, or ``None``.

    Returns:
        ``(state, image update)`` where the update carries the overlay
        rendered at the new size. *state* is modified in place.
    """
    w, h = int(w), int(h)

    # Tolerate a missing state instead of failing on ``None.items()``;
    # create_mixed_img below already accepts None.
    if state is not None:
        for item in state.values():
            if item["map"] is not None:
                item["map"] = resize(item["map"], w, h)

    update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
    return state, update_img
337
+
338
+
339
def detect_text(text, state, width, height):
    """Rebuild the token state from the comma-separated token text.

    Tokens already present in *state* keep their map, weight and
    mask setting; new tokens start with no map, weight 0.5 and masking off.
    Empty items (after stripping) are ignored.

    Args:
        text: Comma-separated token list typed by the user.
        state: Previous token state mapping, or ``None``.
        width, height: Canvas size used to render the overlay.

    Returns:
        ``(new state, sketch update, radio update, overlay update)``.
        For empty text the state is an empty dict and the other three values
        clear their components.
    """
    if text is None or text == "":
        # Return an empty dict, not None: the handlers that receive this state
        # next do ``selected in state`` / ``state.items()`` and would raise on
        # None. An empty dict is also what the state starts out as.
        return {}, None, gr.Radio.update(value=None), None

    new_state = {}

    for item in text.split(","):
        item = item.strip()
        if item == "":
            continue
        if state is not None and item in state:
            previous = state[item]
            new_state[item] = {
                "map": previous["map"],
                "weight": previous["weight"],
                "mask_outsides": previous["mask_outsides"],
            }
        else:
            new_state[item] = {
                "map": None,
                "weight": 0.5,
                "mask_outsides": False
            }
    update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
    update_img = gr.update(value=create_mixed_img("", new_state, width, height))
    # No token is selected after a rebuild, so the sketch pad is cleared and locked.
    update_sketch = gr.update(value=None, interactive=False)
    return new_state, update_sketch, update, update_img
367
+
368
+
369
def resize(img, w, h):
    """Fit *img* to ``h x w`` and return it as a ``uint8`` array.

    The image is converted to PIL, resized with ``Resize(min(h, w))`` and then
    center-cropped to ``(h, w)``.
    """
    pil_img = transforms.ToPILImage()(img)
    pil_img = transforms.Resize(min(h, w))(pil_img)
    pil_img = transforms.CenterCrop((h, w))(pil_img)
    return np.array(pil_img, dtype=np.uint8)
379
+
380
+
381
def switch_canvas(entry, state, width, height):
    """React to a token being selected in the radio group.

    Args:
        entry: Selected token, or ``None`` when nothing is selected.
        state: Token state mapping, or ``None``.
        width, height: Canvas size used to render the overlay.

    Returns:
        ``(sketch value/update, strength, mask flag, overlay image)``. With no
        selection the sketch is cleared and the defaults 0.5 / False are
        returned; otherwise the sketch is cleared and made interactive and the
        selected token's stored weight and mask flag are shown (defaults if
        the token is not in *state*).
    """
    if entry is None:
        return None, 0.5, False, create_mixed_img("", state, width, height)

    # ``entry in state`` would raise on a None state; fall back to defaults.
    item = state.get(entry) if state is not None else None
    weight = item["weight"] if item is not None else 0.5
    mask = item["mask_outsides"] if item is not None else False

    return (
        gr.update(value=None, interactive=True),
        gr.update(value=weight),
        gr.update(value=mask),
        create_mixed_img(entry, state, width, height),
    )
391
+
392
+
393
def apply_canvas(selected, draw, state, w, h):
    """Store the sketch as the selected token's map and redraw the overlay.

    Args:
        selected: Currently selected token.
        draw: Sketch image from the canvas, or ``None`` if there is none.
        state: Token state mapping, or ``None``.
        w, h: Canvas width and height.

    Returns:
        ``(state, image update)``. The map is only replaced when *selected*
        is a key of *state* and a drawing is present.
    """
    # Guard against a missing state or drawing rather than raising inside
    # ``in`` / resize().
    if state is not None and draw is not None and selected in state:
        w, h = int(w), int(h)
        state[selected]["map"] = resize(draw, w, h)
    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
398
+
399
+
400
def apply_weight(selected, weight, state):
    """Set the selected token's weight in *state* and return the state.

    Nothing is changed when *state* is ``None`` or *selected* is not one of
    its keys. *state* is modified in place.
    """
    # ``selected in None`` would raise TypeError, so check the state first.
    if state is not None and selected in state:
        state[selected]["weight"] = weight
    return state
404
+
405
+
406
def apply_option(selected, mask, state):
    """Set the selected token's ``mask_outsides`` flag in *state* and return the state.

    Nothing is changed when *state* is ``None`` or *selected* is not one of
    its keys. *state* is modified in place.
    """
    # ``selected in None`` would raise TypeError, so check the state first.
    if state is not None and selected in state:
        state[selected]["mask_outsides"] = mask
    return state
410
+
411
+
412
+ # sp2, radio, width, height, global_stats
413
def apply_image(image, selected, w, h, strgength, mask, state):
    """Use an uploaded image as the selected token's map and redraw the overlay.

    Args:
        image: Uploaded image, or ``None`` if nothing was uploaded.
        selected: Currently selected token.
        w, h: Canvas width and height (cast to int).
        strgength: Token weight to store (parameter name kept as-is for
            callers).
        mask: Value for the token's ``mask_outsides`` flag.
        state: Token state mapping, or ``None``.

    Returns:
        ``(state, image update)``. The entry is only replaced when *selected*
        is a key of *state* and an image is present.
    """
    # Cast like apply_canvas / apply_new_res do before calling resize().
    w, h = int(w), int(h)

    # Guard against a missing state or image rather than raising inside
    # ``in`` / resize().
    if state is not None and image is not None and selected in state:
        state[selected] = {
            "map": resize(image, w, h),
            "weight": strgength,
            "mask_outsides": mask
        }

    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
422
+
423
+
424
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
425
def add_net(files, ti_state, lora_state):
    """Register uploaded embedding / LoRA files.

    Each file is loaded once to inspect its keys: if any key contains
    ``"lora"`` the file becomes the active LoRA, otherwise it is stored in
    *ti_state* under its stripped file stem.

    Args:
        files: Uploaded file objects (each exposing ``.name``), or ``None``.
        ti_state: Mapping of embedding name to file path; modified in place.
        lora_state: Path of the currently selected LoRA file, or ``None``.

    Returns:
        ``(ti_state, lora_state, embeddings text update, lora text update,
        upload-widget reset)`` - always five values, matching the outputs
        ``[ti_state, lora_state, ti_vals, lora_vals, uploads]``.
    """
    # Treat "nothing uploaded" as an empty batch so every path returns the
    # same five values in the order the click handler's outputs expect.
    for file in files or []:
        item = Path(file.name)
        stripedname = str(item.stem).strip()
        if item.suffix == ".pt":
            # NOTE(review): torch.load unpickles a user-uploaded file.
            # `modules.safe` is imported at file level, presumably to
            # restrict unpickling - confirm before relying on it.
            state_dict = torch.load(file.name, map_location="cpu")
        else:
            state_dict = load_file(file.name, device="cpu")
        if any("lora" in k for k in state_dict.keys()):
            lora_state = file.name
        else:
            ti_state[stripedname] = file.name

    return (
        ti_state,
        lora_state,
        gr.Text.update(f"{[key for key in ti_state.keys()]}"),
        gr.Text.update(f"{lora_state}"),
        gr.Files.update(value=None),
    )
448
+
449
+
450
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
451
def clean_states(ti_state, lora_state):
    """Reset the embedding/LoRA selection.

    Both arguments are ignored. Returns a fresh empty embedding dict, ``None``
    for the LoRA, two empty text updates and an upload-widget reset.
    """
    empty_embeddings = {}
    cleared_text = ""
    return (
        empty_embeddings,
        None,
        gr.Text.update(cleared_text),
        gr.Text.update(cleared_text),
        gr.File.update(value=None),
    )
459
+
460
+
461
+ latent_upscale_modes = {
462
+ "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
463
+ "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
464
+ "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
465
+ "Latent (bicubic antialiased)": {
466
+ "upscale_method": "bicubic",
467
+ "upscale_antialias": True,
468
+ },
469
+ "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
470
+ "Latent (nearest-exact)": {
471
+ "upscale_method": "nearest-exact",
472
+ "upscale_antialias": False,
473
+ },
474
+ }
475
+
476
+ css = """
477
+ .finetuned-diffusion-div div{
478
+ display:inline-flex;
479
+ align-items:center;
480
+ gap:.8rem;
481
+ font-size:1.75rem;
482
+ padding-top:2rem;
483
+ }
484
+ .finetuned-diffusion-div div h1{
485
+ font-weight:900;
486
+ margin-bottom:7px
487
+ }
488
+ .finetuned-diffusion-div p{
489
+ margin-bottom:10px;
490
+ font-size:94%
491
+ }
492
+ .box {
493
+ float: left;
494
+ height: 20px;
495
+ width: 20px;
496
+ margin-bottom: 15px;
497
+ border: 1px solid black;
498
+ clear: both;
499
+ }
500
+ a{
501
+ text-decoration:underline
502
+ }
503
+ .tabs{
504
+ margin-top:0;
505
+ margin-bottom:0
506
+ }
507
+ #gallery{
508
+ min-height:20rem
509
+ }
510
+ .no-border {
511
+ border: none !important;
512
+ }
513
+ """
514
+ with gr.Blocks(css=css) as demo:
515
+ gr.HTML(
516
+ f"""
517
+ <div class="finetuned-diffusion-div">
518
+ <div>
519
+ <h1>Demo for diffusion models</h1>
520
+ </div>
521
+ <p>Hso @ nyanko.sketch2img.gradio</p>
522
+ </div>
523
+ """
524
+ )
525
+ global_stats = gr.State(value={})
526
+
527
+ with gr.Row():
528
+
529
+ with gr.Column(scale=55):
530
+ model = gr.Dropdown(
531
+ choices=[k[0] for k in get_model_list()],
532
+ label="Model",
533
+ value=base_name,
534
+ )
535
+ image_out = gr.Image(height=512)
536
+ # gallery = gr.Gallery(
537
+ # label="Generated images", show_label=False, elem_id="gallery"
538
+ # ).style(grid=[1], height="auto")
539
+
540
+ with gr.Column(scale=45):
541
+
542
+ with gr.Group():
543
+
544
+ with gr.Row():
545
+ with gr.Column(scale=70):
546
+
547
+ prompt = gr.Textbox(
548
+ label="Prompt",
549
+ value="loli cat girl, blue eyes, flat chest, solo, long messy silver hair, blue capelet, cat ears, cat tail, upper body",
550
+ show_label=True,
551
+ max_lines=4,
552
+ placeholder="Enter prompt.",
553
+ )
554
+ neg_prompt = gr.Textbox(
555
+ label="Negative Prompt",
556
+ value="bad quality, low quality, jpeg artifact, cropped",
557
+ show_label=True,
558
+ max_lines=4,
559
+ placeholder="Enter negative prompt.",
560
+ )
561
+
562
+ generate = gr.Button(value="Generate").style(
563
+ rounded=(False, True, True, False)
564
+ )
565
+
566
+ with gr.Tab("Options"):
567
+
568
+ with gr.Group():
569
+
570
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
571
+ with gr.Row():
572
+ guidance = gr.Slider(
573
+ label="Guidance scale", value=7.5, maximum=15
574
+ )
575
+ steps = gr.Slider(
576
+ label="Steps", value=25, minimum=2, maximum=50, step=1
577
+ )
578
+
579
+ with gr.Row():
580
+ width = gr.Slider(
581
+ label="Width", value=512, minimum=64, maximum=768, step=64
582
+ )
583
+ height = gr.Slider(
584
+ label="Height", value=512, minimum=64, maximum=768, step=64
585
+ )
586
+
587
+ sampler = gr.Dropdown(
588
+ value="DPM++ 2M Karras",
589
+ label="Sampler",
590
+ choices=[s[0] for s in samplers_k_diffusion],
591
+ )
592
+ seed = gr.Number(label="Seed (0 = random)", value=0)
593
+
594
+ with gr.Tab("Image to image"):
595
+ with gr.Group():
596
+
597
+ inf_image = gr.Image(
598
+ label="Image", height=256, tool="editor", type="pil"
599
+ )
600
+ inf_strength = gr.Slider(
601
+ label="Transformation strength",
602
+ minimum=0,
603
+ maximum=1,
604
+ step=0.01,
605
+ value=0.5,
606
+ )
607
+
608
+ def res_cap(g, w, h, x):
609
+ if g:
610
+ return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
611
+ else:
612
+ return "Enable upscaler"
613
+
614
+ with gr.Tab("Hires fix"):
615
+ with gr.Group():
616
+
617
+ hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
618
+ hr_method = gr.Dropdown(
619
+ [key for key in latent_upscale_modes.keys()],
620
+ value="Latent",
621
+ label="Upscale method",
622
+ )
623
+ hr_scale = gr.Slider(
624
+ label="Upscale factor",
625
+ minimum=1.0,
626
+ maximum=1.5,
627
+ step=0.1,
628
+ value=1.2,
629
+ )
630
+ hr_denoise = gr.Slider(
631
+ label="Denoising strength",
632
+ minimum=0.0,
633
+ maximum=1.0,
634
+ step=0.1,
635
+ value=0.8,
636
+ )
637
+
638
+ hr_scale.change(
639
+ lambda g, x, w, h: gr.Checkbox.update(
640
+ label=res_cap(g, w, h, x)
641
+ ),
642
+ inputs=[hr_enabled, hr_scale, width, height],
643
+ outputs=hr_enabled,
644
+ queue=False,
645
+ )
646
+ hr_enabled.change(
647
+ lambda g, x, w, h: gr.Checkbox.update(
648
+ label=res_cap(g, w, h, x)
649
+ ),
650
+ inputs=[hr_enabled, hr_scale, width, height],
651
+ outputs=hr_enabled,
652
+ queue=False,
653
+ )
654
+
655
+ with gr.Tab("Embeddings/Loras"):
656
+
657
+ ti_state = gr.State(dict())
658
+ lora_state = gr.State()
659
+
660
+ with gr.Group():
661
+ with gr.Row():
662
+ with gr.Column(scale=90):
663
+ ti_vals = gr.Text(label="Loaded embeddings")
664
+
665
+ with gr.Row():
666
+ with gr.Column(scale=90):
667
+ lora_vals = gr.Text(label="Loaded loras")
668
+
669
+ with gr.Row():
670
+
671
+ uploads = gr.Files(label="Upload new embeddings/lora")
672
+
673
+ with gr.Column():
674
+ lora_scale = gr.Slider(
675
+ label="Lora scale",
676
+ minimum=0,
677
+ maximum=2,
678
+ step=0.01,
679
+ value=1.0,
680
+ )
681
+ btn = gr.Button(value="Upload")
682
+ btn_del = gr.Button(value="Reset")
683
+
684
+ btn.click(
685
+ add_net,
686
+ inputs=[uploads, ti_state, lora_state],
687
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
688
+ queue=False,
689
+ )
690
+ btn_del.click(
691
+ clean_states,
692
+ inputs=[ti_state, lora_state],
693
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
694
+ queue=False,
695
+ )
696
+
697
+ # error_output = gr.Markdown()
698
+
699
+ gr.HTML(
700
+ f"""
701
+ <div class="finetuned-diffusion-div">
702
+ <div>
703
+ <h1>Paint with words</h1>
704
+ </div>
705
+ <p>
706
+ Will use the following formula: w = scale * token_weight_martix * log(1 + sigma) * max(qk).
707
+ </p>
708
+ </div>
709
+ """
710
+ )
711
+
712
+ with gr.Row():
713
+
714
+ with gr.Column(scale=55):
715
+
716
+ rendered = gr.Image(
717
+ invert_colors=True,
718
+ source="canvas",
719
+ interactive=False,
720
+ image_mode="RGBA",
721
+ )
722
+
723
+ with gr.Column(scale=45):
724
+
725
+ with gr.Group():
726
+ with gr.Row():
727
+ with gr.Column(scale=70):
728
+ g_strength = gr.Slider(
729
+ label="Weight scaling",
730
+ minimum=0,
731
+ maximum=0.8,
732
+ step=0.01,
733
+ value=0.4,
734
+ )
735
+
736
+ text = gr.Textbox(
737
+ lines=2,
738
+ interactive=True,
739
+ label="Token to Draw: (Separate by comma)",
740
+ )
741
+
742
+ radio = gr.Radio([], label="Tokens")
743
+
744
+ sk_update = gr.Button(value="Update").style(
745
+ rounded=(False, True, True, False)
746
+ )
747
+
748
+ # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])
749
+
750
+ with gr.Tab("SketchPad"):
751
+
752
+ sp = gr.Image(
753
+ image_mode="L",
754
+ tool="sketch",
755
+ source="canvas",
756
+ interactive=False,
757
+ )
758
+
759
+ mask_outsides = gr.Checkbox(
760
+ label="Mask other areas",
761
+ value=False
762
+ )
763
+
764
+ strength = gr.Slider(
765
+ label="Token strength",
766
+ minimum=0,
767
+ maximum=0.8,
768
+ step=0.01,
769
+ value=0.5,
770
+ )
771
+
772
+
773
+ sk_update.click(
774
+ detect_text,
775
+ inputs=[text, global_stats, width, height],
776
+ outputs=[global_stats, sp, radio, rendered],
777
+ queue=False,
778
+ )
779
+ radio.change(
780
+ switch_canvas,
781
+ inputs=[radio, global_stats, width, height],
782
+ outputs=[sp, strength, mask_outsides, rendered],
783
+ queue=False,
784
+ )
785
+ sp.edit(
786
+ apply_canvas,
787
+ inputs=[radio, sp, global_stats, width, height],
788
+ outputs=[global_stats, rendered],
789
+ queue=False,
790
+ )
791
+ strength.change(
792
+ apply_weight,
793
+ inputs=[radio, strength, global_stats],
794
+ outputs=[global_stats],
795
+ queue=False,
796
+ )
797
+ mask_outsides.change(
798
+ apply_option,
799
+ inputs=[radio, mask_outsides, global_stats],
800
+ outputs=[global_stats],
801
+ queue=False,
802
+ )
803
+
804
+ with gr.Tab("UploadFile"):
805
+
806
+ sp2 = gr.Image(
807
+ image_mode="L",
808
+ source="upload",
809
+ shape=(512, 512),
810
+ )
811
+
812
+ mask_outsides2 = gr.Checkbox(
813
+ label="Mask other areas",
814
+ value=False,
815
+ )
816
+
817
+ strength2 = gr.Slider(
818
+ label="Token strength",
819
+ minimum=0,
820
+ maximum=0.8,
821
+ step=0.01,
822
+ value=0.5,
823
+ )
824
+
825
+ apply_style = gr.Button(value="Apply")
826
+ apply_style.click(
827
+ apply_image,
828
+ inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats],
829
+ outputs=[global_stats, rendered],
830
+ queue=False,
831
+ )
832
+
833
+ width.change(
834
+ apply_new_res,
835
+ inputs=[width, height, global_stats],
836
+ outputs=[global_stats, rendered],
837
+ queue=False,
838
+ )
839
+ height.change(
840
+ apply_new_res,
841
+ inputs=[width, height, global_stats],
842
+ outputs=[global_stats, rendered],
843
+ queue=False,
844
+ )
845
+
846
+ # color_stats = gr.State(value={})
847
+ # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
848
+ # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
849
+
850
+ inputs = [
851
+ prompt,
852
+ guidance,
853
+ steps,
854
+ width,
855
+ height,
856
+ seed,
857
+ neg_prompt,
858
+ global_stats,
859
+ g_strength,
860
+ inf_image,
861
+ inf_strength,
862
+ hr_enabled,
863
+ hr_method,
864
+ hr_scale,
865
+ hr_denoise,
866
+ sampler,
867
+ ti_state,
868
+ model,
869
+ lora_state,
870
+ lora_scale,
871
+ ]
872
+ outputs = [image_out]
873
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
874
+ generate.click(inference, inputs=inputs, outputs=outputs)
875
+
876
+ print(f"Space built in {time.time() - start_time:.2f} seconds")
877
+ # demo.launch(share=True)
878
+ demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
modules/lora.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ import modules.safe as _
11
+ from safetensors.torch import load_file
12
+
13
+
14
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter that replaces the forward method of the original Linear / Conv2d,
    instead of replacing the original module itself.

    Once ``apply()`` has been called, the hooked module computes
    ``org_forward(x) + lora_up(lora_down(x)) * multiplier * scale`` while ``enable`` is
    True and plain ``org_forward(x)`` otherwise.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """
        Args:
            lora_name: name of this adapter (used as its key prefix by the owning network).
            org_module: module to hook. It is treated as a convolution when its class is
                named "Conv2d"; otherwise it must expose ``in_features``/``out_features``.
            multiplier: strength applied on top of ``scale`` in ``forward``.
            lora_dim: rank of the low-rank decomposition.
            alpha: scaling numerator (number or 0-dim tensor).
                If alpha == 0 or None, alpha is rank (no scaling).
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv = org_module.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim, out_dim = org_module.in_channels, org_module.out_channels
        else:
            in_dim, out_dim = org_module.in_features, org_module.out_features
        self.lora_down, self.lora_up = self._make_pair(is_conv, in_dim, out_dim, lora_dim)

        alpha = self._as_number(alpha)
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.enable = False

    @staticmethod
    def _as_number(alpha):
        """Convert a tensor alpha to a Python float; other values pass through unchanged."""
        if isinstance(alpha, torch.Tensor):
            # cast to float32 first: without casting, bf16 causes an error
            return alpha.detach().float().item()
        return alpha

    @staticmethod
    def _make_pair(is_conv, in_dim, out_dim, rank, device=None, dtype=None):
        """Build the (down, up) projection pair; ``up`` is zeroed so the pair starts as a no-op."""
        if is_conv:
            down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False, device=device, dtype=dtype)
            up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False, device=device, dtype=dtype)
        else:
            down = torch.nn.Linear(in_dim, rank, bias=False, device=device, dtype=dtype)
            up = torch.nn.Linear(rank, out_dim, bias=False, device=device, dtype=dtype)
        torch.nn.init.zeros_(up.weight)
        return down, up

    def resize(self, rank, alpha, multiplier):
        """
        Rebuild the projection pair for a new rank and update the scaling factors.

        The new layers keep the device and dtype of the layers they replace, and
        ``lora_up`` is zero-initialised, so an adapter whose weights are never loaded
        afterwards contributes nothing instead of random noise.

        Args:
            rank: new rank of the decomposition.
            alpha: new scaling numerator (number or 0-dim tensor); 0 or None means ``rank``,
                matching ``__init__``.
            multiplier: new strength applied on top of ``scale``.
        """
        alpha = self._as_number(alpha)
        alpha = rank if alpha is None or alpha == 0 else alpha

        reference = self.lora_down.weight
        is_conv = self.lora_down.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim, out_dim = self.lora_down.in_channels, self.lora_up.out_channels
        else:
            in_dim, out_dim = self.lora_down.in_features, self.lora_up.out_features

        self.lora_dim = rank
        self.multiplier = multiplier
        self.scale = alpha / rank
        self.alpha = torch.tensor(alpha, device=self.alpha.device)
        self.lora_down, self.lora_up = self._make_pair(
            is_conv, in_dim, out_dim, rank, device=reference.device, dtype=reference.dtype
        )

    def apply(self):
        """Hook the original module's ``forward``; calling it again is a no-op."""
        if hasattr(self, "org_module"):
            self.org_forward = self.org_module.forward
            self.org_module.forward = self.forward
            # drop the reference so the original module is not kept as a child of this adapter
            del self.org_module

    def forward(self, x):
        """Return the original output, plus the scaled low-rank update when enabled."""
        if self.enable:
            return (
                self.org_forward(x)
                + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
            )
        return self.org_forward(x)
86
+
87
+
88
class LoRANetwork(torch.nn.Module):
    """
    Collection of LoRAModule adapters hooked into a text encoder and a U-Net.

    Adapters are created for every Linear and 1x1 Conv2d found below a module whose class
    name is listed in the ``*_TARGET_REPLACE_MODULE`` lists. They start disabled and are
    switched on by ``load()``.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
        """
        Args:
            text_encoder: module to scan for target layers, or an already-built list of
                LoRAModule objects to reuse as the text-encoder adapters.
            unet: module to scan for target layers.
            multiplier: initial strength handed to every adapter.
            lora_dim: initial rank handed to every adapter.
            alpha: initial alpha handed to every adapter.
        """
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
            loras = []
            seen = set()
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    child_type = child_module.__class__.__name__
                    is_target = child_type == "Linear" or (
                        child_type == "Conv2d" and child_module.kernel_size == (1, 1)
                    )
                    if not is_target:
                        continue
                    lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                    # A layer can be reached through two matching parents (a target module
                    # nested in another target module); hook it only once.
                    if lora_name in seen:
                        continue
                    seen.add(lora_name)
                    loras.append(
                        LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
                    )
            return loras

        if isinstance(text_encoder, list):
            self.text_encoder_loras = text_encoder
        else:
            self.text_encoder_loras = create_modules(
                LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
                text_encoder,
                LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE,
            )
        print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras = create_modules(
            LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE
        )
        print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # Names double as submodule names (and therefore state-dict prefixes), so they must
        # be unique. Raised explicitly so the check also runs under ``python -O``.
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            if lora.lora_name in names:
                raise AssertionError(f"duplicated lora name: {lora.lora_name}")
            names.add(lora.lora_name)

            lora.apply()
            self.add_module(lora.lora_name, lora)

    def reset(self):
        """Disable every adapter."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enable = False

    def load(self, file, scale):
        """
        Load LoRA weights from ``file`` and enable the adapters they cover.

        Every adapter is resized to the rank/alpha found in the file. Text-encoder adapters
        are enabled only when the file has keys starting with ``LORA_PREFIX_TEXT_ENCODER``,
        and U-Net adapters only when it has keys starting with ``LORA_PREFIX_UNET``.

        Args:
            file: path to a ``.safetensors`` file, or to any other file readable by ``torch.load``.
            scale: multiplier handed to every adapter.

        Raises:
            ValueError: if there are adapters to resize but no rank can be found in the file.
        """
        if os.path.splitext(file)[1] == ".safetensors":
            weights = load_file(file)
        else:
            # NOTE(review): torch.load unpickles the file. This module imports `modules.safe`
            # (as `_`), presumably to restrict unpickling -- confirm before loading untrusted files.
            weights = torch.load(file, map_location="cpu")

        if not weights:
            return

        network_alpha = None
        network_dim = None
        weights_has_text_encoder = weights_has_unet = False
        for key, value in weights.items():
            if network_alpha is None and "alpha" in key:
                network_alpha = value
            if network_dim is None and "lora_down" in key and len(value.size()) == 2:
                network_dim = value.size()[0]
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                weights_has_text_encoder = True
            if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                weights_has_unet = True

        if network_alpha is None:
            network_alpha = network_dim

        groups = (
            (self.text_encoder_loras, weights_has_text_encoder),
            (self.unet_loras, weights_has_unet),
        )
        if network_dim is None and any(loras for loras, _ in groups):
            raise ValueError(
                f"cannot determine the LoRA rank: no 2-D 'lora_down' weight found in {file}"
            )

        for loras, present_in_file in groups:
            for lora in loras:
                lora.resize(network_dim, network_alpha, scale)
                if present_in_file:
                    lora.enable = True

        # non-strict: the file may cover only some of the adapters
        info = self.load_state_dict(weights, False)
        if info.unexpected_keys:
            print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
183
+
modules/model.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+ import math
4
+ from pathlib import Path
5
+ import re
6
+ from collections import defaultdict
7
+ from typing import List, Optional, Union
8
+
9
+ import time
10
+ import k_diffusion
11
+ import numpy as np
12
+ import PIL
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange
17
+ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
18
+ from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords
19
+ from torch import einsum
20
+ from torch.autograd.function import Function
21
+
22
+ from diffusers import DiffusionPipeline
23
+ from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
24
+ from diffusers.utils import logging, randn_tensor
25
+
26
+ import modules.safe as _
27
+ from safetensors.torch import load_file
28
+
29
+ xformers_available = False
30
+ try:
31
+ import xformers
32
+
33
+ xformers_available = True
34
+ except ImportError:
35
+ pass
36
+
37
EPSILON = 1e-6


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    return val if exists(val) else d
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
def get_attention_scores(attn, query, key, attention_mask=None):
    """
    Compute the pre-softmax attention scores ``attn.scale * query @ key^T`` (plus mask).

    Args:
        attn: object providing ``upcast_attention``, ``upcast_softmax`` and ``scale``.
        query: batched query tensor.
        key: batched key tensor.
        attention_mask: optional tensor added to the scores.

    Returns:
        The score tensor, cast to float32 when ``attn.upcast_softmax`` is set.
    """
    if attn.upcast_attention:
        query, key = query.float(), key.float()

    # With beta=0 the contents of this uninitialised buffer are ignored by baddbmm.
    scratch = torch.empty(
        query.shape[0],
        query.shape[1],
        key.shape[1],
        dtype=query.dtype,
        device=query.device,
    )
    scores = torch.baddbmm(
        scratch,
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=attn.scale,
    )

    if attention_mask is not None:
        scores = scores + attention_mask

    return scores.float() if attn.upcast_softmax else scores
70
+
71
+
72
class CrossAttnProcessor(nn.Module):
    """
    Attention processor that can bias cross-attention scores with per-token spatial weights.

    For cross-attention, ``encoder_hidden_states`` is a dict with the keys ``"img_state"``,
    ``"states"``, ``"weight_func"`` and ``"sigma"``. When ``"img_state"`` is itself a dict
    (keyed by query sequence length) the explicit score path is used so the weights can be
    added before the softmax; otherwise a memory-efficient attention kernel is used.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        _, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        is_xattn = encoder_hidden_states is not None
        if is_xattn:
            img_state = encoder_hidden_states["img_state"]
            context = encoder_hidden_states["states"]
            weight_func = encoder_hidden_states["weight_func"]
            sigma = encoder_hidden_states["sigma"]
        else:
            # self-attention: keys and values come from the hidden states themselves
            context = hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        if is_xattn and isinstance(img_state, dict):
            # use torch.baddbmm method (slow)
            scores = get_attention_scores(attn, query, key, attention_mask)
            region_weights = img_state[sequence_length].to(query.device)
            bias = weight_func(region_weights, sigma, scores)
            # the bias is per sample; repeat it for every attention head
            scores += torch.repeat_interleave(bias, repeats=attn.heads, dim=0)

            # calc probs
            probs = scores.softmax(dim=-1).to(query.dtype)
            hidden_states = torch.bmm(probs, value)
        elif xformers_available:
            hidden_states = xformers.ops.memory_efficient_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                attn_bias=attention_mask,
            ).to(query.dtype)
        else:
            q_bucket_size, k_bucket_size = 512, 1024

            # use flash-attention
            hidden_states = FlashAttentionFunction.apply(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                attention_mask,
                False,
                q_bucket_size,
                k_bucket_size,
            ).to(query.dtype)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # to_out[0] is the linear projection, to_out[1] the dropout
        projected = attn.to_out[0](hidden_states)
        return attn.to_out[1](projected)
148
+
149
class ModelWrapper:
    """Adapter exposing ``apply_model`` and ``alphas_cumprod`` around a callable model."""

    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        """
        Call the wrapped model and return the ``.sample`` attribute of its output.

        The conditioning may be given as a third positional argument or as the ``cond``
        keyword (a non-None ``cond`` wins). It is forwarded to the model as
        ``encoder_hidden_states``; ``cond`` itself is never forwarded.
        """
        encoder_hidden_states = None
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]

        # Always remove `cond` so a None value is not passed on as an unknown keyword.
        cond = kwargs.pop("cond", None)
        if cond is not None:
            encoder_hidden_states = cond

        return self.model(
            *args, encoder_hidden_states=encoder_hidden_states, **kwargs
        ).sample
163
+
164
+
165
class StableDiffusionPipeline(DiffusionPipeline):
    """
    Text-to-image / image-to-image pipeline driven by k-diffusion samplers.

    Prompts are encoded with ``FrozenCLIPEmbedderWithCustomWords``. Optional per-token
    region maps (``pww_state``) are turned into cross-attention weights that are handed
    to the U-Net through the ``encoder_hidden_states`` dict.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.setup_unet(self.unet)
        self.setup_text_encoder()

    def setup_text_encoder(self, n=1, new_encoder=None):
        """(Re)build the prompt parser, optionally swapping in ``new_encoder`` first.

        ``n`` is stored as the parser's ``CLIP_stop_at_last_layers``.
        """
        if new_encoder is not None:
            self.text_encoder = new_encoder

        self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
        self.prompt_parser.CLIP_stop_at_last_layers = n

    def setup_unet(self, unet):
        """Wrap ``unet`` in the k-diffusion denoiser matching the scheduler's prediction type."""
        unet = unet.to(self.device)
        model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
        if self.scheduler.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def get_scheduler(self, scheduler_type: str):
        """Return the function named ``scheduler_type`` from ``k_diffusion.sampling``.

        Raises:
            AttributeError: if no such name exists in ``k_diffusion.sampling``.
        """
        return getattr(k_diffusion.sampling, scheduler_type)

    def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
        """
        Turn per-token region maps into cross-attention weight tensors.

        Args:
            state: mapping ``token text -> {"map", "mask_outsides", "weight"}`` or None.
                ``"map"`` is a 2-D numpy array (or None to skip the entry); pixels below 255
                mark the region. With ``"mask_outsides"`` set, pixels outside get weight -1.
            scale_ratio: downscale factor for the first entry; doubled for each further one.
            g_strength: global factor multiplied into every map.
            text_ids: pair ``(uncond_ids, cond_ids)`` of token-id tensors.

        Returns:
            A dict ``{positions: tensor}`` with one entry per U-Net down block, where
            ``positions`` is the downscaled ``w * h`` and the tensor stacks the uncond and
            cond weights. An empty ``torch.FloatTensor`` is returned when there is nothing
            to encode (callers test ``isinstance(result, dict)``).
        """
        if state is None:
            return torch.FloatTensor(0)

        img_state = []
        for token_text, entry in state.items():
            if entry["map"] is None:
                continue

            token_ids = self.tokenizer(
                token_text,
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                add_special_tokens=False,
            ).input_ids

            region = (entry["map"] < 255).astype(float)
            if entry["mask_outsides"]:
                region[region == 0] = -1

            weights = torch.from_numpy(region * float(entry["weight"]) * g_strength)
            img_state.append((token_ids, weights))

        if not img_state:
            return torch.FloatTensor(0)

        uncond = text_ids[0].tolist()
        cond = text_ids[1].tolist()
        token_count = len(cond)
        # all maps are assumed to share the shape of the first one
        w, h = img_state[0][1].shape

        w_tensors = {}
        for _ in self.unet.down_blocks:
            w_r, h_r = w // scale_ratio, h // scale_ratio
            positions = int(w_r * h_r)

            ret_cond_tensor = torch.zeros((1, positions, token_count), dtype=torch.float32)
            ret_uncond_tensor = torch.zeros((1, positions, token_count), dtype=torch.float32)

            for token_ids, weights in img_state:
                span = len(token_ids)
                found = False

                # downscale the map and repeat it for every token of the phrase
                ret = (
                    F.interpolate(
                        weights.unsqueeze(0).unsqueeze(1),
                        scale_factor=1 / scale_ratio,
                        mode="bilinear",
                        align_corners=True,
                    )
                    .squeeze()
                    .reshape(-1, 1)
                    .repeat(1, span)
                )

                # add the map at every position where the phrase occurs in the prompt ids
                for ids, target in ((cond, ret_cond_tensor), (uncond, ret_uncond_tensor)):
                    for idx in range(len(ids)):
                        if ids[idx : idx + span] == token_ids:
                            found = True
                            target[0, :, idx : idx + span] += ret

                if not found:
                    print(f"tokens {token_ids} not found in text")

            w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
            scale_ratio *= 2

        return w_tensors

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            head_dim = self.unet.config.attention_head_dim
            if isinstance(head_dim, int):
                # half the attention head size is usually a good trade-off between
                # speed and memory
                slice_size = head_dim // 2
            else:
                # `attention_head_dim` is a per-block list: take the smallest head size
                slice_size = min(head_dim)
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. The unet, text_encoder,
        vae and (when present) safety checker are handed to accelerate's `cpu_offload` for `cuda:{gpu_id}`.

        Raises:
            ImportError: if accelerate is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.unet,
            self.text_encoder,
            self.vae,
            # __init__ never registers a safety checker, so the attribute may be missing
            getattr(self, "safety_checker", None),
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            hook = getattr(module, "_hf_hook", None)
            if hook is not None and getattr(hook, "execution_device", None) is not None:
                return torch.device(hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        """Decode latents with the VAE into a float32 numpy array in [0, 1], channels last."""
        latents = latents.to(self.device, dtype=self.vae.dtype)
        latents = 1 / 0.18215 * latents  # inverse of the 0.18215 factor applied when encoding
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        """Validate the txt2img arguments.

        Raises:
            ValueError: if ``prompt`` is neither str nor list, if ``height``/``width`` are not
                multiples of 8, or if ``callback_steps`` is not a positive int.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        # `None` fails the isinstance check as well
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Return ``latents`` moved to ``device``, or fresh unit-variance noise when it is None.

        The noise has shape ``(batch_size, num_channels_latents, height // 8, width // 8)``.
        Scaling by the initial sigma is left to the caller.
        """
        if latents is not None:
            return latents.to(device)

        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if device.type == "mps":
            # randn does not work reproducibly on mps
            return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

    def preprocess(self, image):
        """Convert a PIL image (or list of PIL images / tensors) into a batched tensor.

        PIL images are resized down to a multiple of 8 and mapped to [-1, 1] with channels
        first. Tensors are returned unchanged; a list of tensors is concatenated on dim 0.
        """
        if isinstance(image, torch.Tensor):
            return image
        elif isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            w, h = image[0].size
            w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

            image = [
                np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
                for i in image
            ]
            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = 2.0 * image - 1.0
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)
        return image

    def _build_denoise_fn(self, text_embeddings, img_state, guidance_scale, start_time, timeout):
        """Build the ``model_fn(x, sigma)`` callback shared by txt2img and img2img.

        The callback runs the denoiser on a doubled batch (uncond + cond) and applies
        classifier-free guidance. While ``start_time > 0`` and ``timeout > 0`` it aborts
        with ``AssertionError`` once ``timeout`` seconds have elapsed since ``start_time``.
        """

        def weight_func(w, sigma, qk):
            return w * math.log(1 + sigma) * qk.max()

        def model_fn(x, sigma):
            # Raised explicitly (not via `assert`) so the limit also holds under `python -O`.
            if start_time > 0 and timeout > 0 and (time.time() - start_time) >= timeout:
                raise AssertionError("inference process timed out")

            latent_model_input = torch.cat([x] * 2)
            encoder_state = {
                "img_state": img_state,
                "states": text_embeddings,
                "sigma": sigma[0],
                "weight_func": weight_func,
            }

            noise_pred = self.k_diffusion_model(
                latent_model_input, sigma, cond=encoder_state
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            return noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        return model_fn

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[torch.Generator] = None,
        image: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        latents=None,
        strength=1.0,
        pww_state=None,
        pww_attn_weight=1.0,
        sampler_name="",
        sampler_opt=None,
        start_time=-1,
        timeout=180,
        scale_ratio=8.0,
    ):
        """
        Denoise an existing image (or latents) towards ``prompt``.

        Args:
            prompt / negative_prompt: conditioning and unconditioning text.
            num_inference_steps: number of sampler steps actually run.
            guidance_scale: classifier-free guidance weight; must be > 1.
            generator: RNG used for VAE sampling and for the added noise.
            image: input image(s); encoded with the VAE when given (takes precedence over ``latents``).
            output_type: ``"pil"`` to convert the result to PIL images.
            latents: starting latents, used when ``image`` is None.
            strength: denoising strength; lower values start later in the sigma schedule.
            pww_state / pww_attn_weight: region maps and their global strength (see ``encode_sketchs``).
            sampler_name: name of a function in ``k_diffusion.sampling``.
            sampler_opt: options read by ``get_sigmas``; None means no options.
            start_time / timeout: abort once ``timeout`` seconds passed since ``start_time``
                (only while both are positive).
            scale_ratio: currently unused.

        Returns:
            A 1-tuple holding the decoded image batch.

        Raises:
            ValueError: if ``guidance_scale <= 1`` or neither ``image`` nor ``latents`` is given.
        """
        sampler_opt = {} if sampler_opt is None else sampler_opt
        sampler = self.get_scheduler(sampler_name)

        if image is not None:
            image = self.preprocess(image)
            image = image.to(self.vae.device, dtype=self.vae.dtype)

            init_latents = self.vae.encode(image).latent_dist.sample(generator)
            latents = 0.18215 * init_latents
        if latents is None:
            raise ValueError("either `image` or `latents` has to be provided")

        # 2. Define call parameters
        device = self._execution_device
        latents = latents.to(device, dtype=self.unet.dtype)
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance, which this pipeline does not support.
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
        text_embeddings = text_embeddings.to(self.unet.dtype)

        # 4. Build a longer schedule and keep only its last `num_inference_steps` steps
        init_timestep = (
            int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
        )
        sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
            text_embeddings.device, dtype=text_embeddings.dtype
        )

        t_start = max(init_timestep - num_inference_steps, 0)
        sigma_sched = sigmas[t_start:]

        noise = randn_tensor(
            latents.shape,
            generator=generator,
            device=device,
            dtype=text_embeddings.dtype,
        )
        latents = latents.to(device)
        latents = latents + noise * sigma_sched[0]

        # 5. Keep the denoiser's sigma tables on the latents' device
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
            latents.device
        )

        img_state = self.encode_sketchs(
            pww_state,
            g_strength=pww_attn_weight,
            text_ids=text_ids,
        )

        model_fn = self._build_denoise_fn(
            text_embeddings, img_state, guidance_scale, start_time, timeout
        )
        sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
        latents = sampler(model_fn, latents, **sampler_args)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return (image,)

    def get_sigmas(self, steps, params):
        """Return the sigma schedule for ``steps`` sampler steps.

        ``params`` may set ``"scheduler": "karras"`` to use the Karras schedule and
        ``"discard_next_to_last_sigma": True`` to build one extra step and drop the
        next-to-last sigma. None is treated as no options.
        """
        params = {} if params is None else params
        discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
        steps += 1 if discard_next_to_last_sigma else 0

        if params.get("scheduler", None) == "karras":
            sigma_min, sigma_max = (
                self.k_diffusion_model.sigmas[0].item(),
                self.k_diffusion_model.sigmas[-1].item(),
            )
            sigmas = k_diffusion.sampling.get_sigmas_karras(
                n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
            )
        else:
            sigmas = self.k_diffusion_model.get_sigmas(steps)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
    def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func):
        """Build the keyword arguments accepted by sampler ``func`` for a txt2img run.

        Only parameters present in ``func``'s signature are filled in.
        """
        accepted = inspect.signature(func).parameters
        extra_params_kwargs = {}

        if "eta" in accepted:
            extra_params_kwargs["eta"] = eta

        if "sigma_min" in accepted:
            # The step schedule `sigmas` runs from large to small and ends in zero, so its
            # first/last entries are not min/max. Use the denoiser's own table, whose first
            # and last entries are the smallest and largest sigma (as in `get_sigmas`).
            extra_params_kwargs["sigma_min"] = self.k_diffusion_model.sigmas[0].item()
            extra_params_kwargs["sigma_max"] = self.k_diffusion_model.sigmas[-1].item()

        if "n" in accepted:
            extra_params_kwargs["n"] = steps

        if "sigmas" in accepted:
            extra_params_kwargs["sigmas"] = sigmas

        return extra_params_kwargs

    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
    def get_sampler_extra_args_i2i(self, sigmas, func):
        """Build the keyword arguments accepted by sampler ``func`` for an img2img run.

        Only parameters present in ``func``'s signature are filled in.
        """
        accepted = inspect.signature(func).parameters
        extra_params_kwargs = {}

        if "sigma_min" in accepted:
            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
            extra_params_kwargs["sigma_min"] = sigmas[-2]

        if "sigma_max" in accepted:
            extra_params_kwargs["sigma_max"] = sigmas[0]

        if "n" in accepted:
            extra_params_kwargs["n"] = len(sigmas) - 1

        if "sigma_sched" in accepted:
            extra_params_kwargs["sigma_sched"] = sigmas

        if "sigmas" in accepted:
            extra_params_kwargs["sigmas"] = sigmas

        return extra_params_kwargs

    @torch.no_grad()
    def txt2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_steps: Optional[int] = 1,
        upscale=False,
        upscale_x: float = 2.0,
        upscale_method: str = "bicubic",
        upscale_antialias: bool = False,
        upscale_denoising_strength: float = 0.7,
        pww_state=None,
        pww_attn_weight=1.0,
        sampler_name="",
        sampler_opt=None,
        start_time=-1,
        timeout=180,
    ):
        """
        Generate an image from ``prompt``, optionally followed by a latent upscale pass.

        Args:
            prompt / negative_prompt: conditioning and unconditioning text.
            height / width: output size in pixels; both must be multiples of 8.
            num_inference_steps: number of sampler steps.
            guidance_scale: classifier-free guidance weight; must be > 1.
            eta: forwarded to samplers that accept an ``eta`` argument.
            generator: RNG used for the initial noise.
            latents: optional starting noise; random noise is drawn when None.
            output_type: ``"pil"`` to convert the result to PIL images.
            callback_steps: only validated (must be a positive int).
            upscale: when True, the latents are resized by ``upscale_x`` with
                ``upscale_method`` / ``upscale_antialias`` and refined through ``img2img``
                with strength ``upscale_denoising_strength``.
            pww_state / pww_attn_weight: region maps and their global strength (see ``encode_sketchs``).
            sampler_name: name of a function in ``k_diffusion.sampling``.
            sampler_opt: options read by ``get_sigmas``; None means no options.
            start_time / timeout: abort once ``timeout`` seconds passed since ``start_time``
                (only while both are positive).

        Returns:
            A 1-tuple holding the decoded image batch.

        Raises:
            ValueError: on invalid inputs (see ``check_inputs``) or ``guidance_scale <= 1``.
        """
        sampler_opt = {} if sampler_opt is None else sampler_opt
        sampler = self.get_scheduler(sampler_name)
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance, which this pipeline does not support.
        if guidance_scale <= 1.0:
            raise ValueError("has to use guidance_scale")

        # 3. Encode input prompt
        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
        text_embeddings = text_embeddings.to(self.unet.dtype)

        # 4. Prepare timesteps
        sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
            text_embeddings.device, dtype=text_embeddings.dtype
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        # scale the unit-variance noise to the first (largest) sigma of the schedule
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
            latents.device
        )

        img_state = self.encode_sketchs(
            pww_state,
            g_strength=pww_attn_weight,
            text_ids=text_ids,
        )

        model_fn = self._build_denoise_fn(
            text_embeddings, img_state, guidance_scale, start_time, timeout
        )
        extra_args = self.get_sampler_extra_args_t2i(
            sigmas, eta, num_inference_steps, sampler
        )
        latents = sampler(model_fn, latents, **extra_args)

        if upscale:
            target_height = height * upscale_x
            target_width = width * upscale_x
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
            latents = torch.nn.functional.interpolate(
                latents,
                size=(
                    int(target_height // vae_scale_factor),
                    int(target_width // vae_scale_factor),
                ),
                mode=upscale_method,
                antialias=upscale_antialias,
            )
            # NOTE(review): start_time/timeout are not forwarded, so the refinement pass is
            # not subject to the timeout -- confirm whether that is intended.
            return self.img2img(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                generator=generator,
                latents=latents,
                output_type=output_type,
                strength=upscale_denoising_strength,
                sampler_name=sampler_name,
                sampler_opt=sampler_opt,
                pww_state=None,
                pww_attn_weight=pww_attn_weight / 2,
            )

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return (image,)
741
+
742
+
743
class FlashAttentionFunction(Function):
    """Memory-efficient attention computed bucket by bucket.

    Queries are processed in row buckets of ``q_bucket_size`` and keys/values
    in column buckets of ``k_bucket_size``, so only one
    (q-bucket x k-bucket) attention tile is materialised at a time.
    ``Function`` is presumably ``torch.autograd.Function`` (imported outside
    this section - TODO confirm).
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper

        Computes softmax(q @ k^T * scale) @ v tile by tile while keeping a
        running row maximum and running row sum, and returns ``o`` (same
        shape as ``q``). ``q``, ``k``, ``v``, ``o`` and the per-row
        log-sum-exp are saved on ``ctx`` for ``backward``.
        """

        device = q.device
        # Most negative finite value of q's dtype; used as "minus infinity"
        # when masking logits.
        max_neg_value = -torch.finfo(q.dtype).max
        # Offset applied to query positions for the causal test when there
        # are more keys than queries.
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # Output plus running softmax statistics, one value per query row.
        # NOTE: the two statistic tensors are created without an explicit
        # dtype, so they use torch's default dtype rather than q.dtype.
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            # One ``None`` placeholder per query bucket so the zip below
            # still yields every bucket.
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            # NOTE(review): the mask is split in q_bucket_size pieces along
            # its last axis, yet each piece is later broadcast against the
            # key axis of the attention tile - confirm the intended layout.
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # Scaled logits for this (query bucket, key bucket) tile.
                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # Only tiles that reach past the diagonal need a causal mask.
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Subtract the tile's own row maximum before exponentiating.
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                # Clamp so a fully masked row never yields a zero denominator.
                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                # Factors that re-express the old accumulator and the new
                # tile relative to the updated running maximum.
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                # In-place update of this bucket's slice of ``o``: rescale
                # the previous partial result, then add the new tile.
                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                # Write the statistics back into the all_row_* tensors.
                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        # Per-row log-sum-exp of the logits, reused by backward().
        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper

        Recomputes each attention tile from the saved ``q``, ``k`` and
        log-sum-exp, then accumulates gradients. Returns
        ``(dq, dk, dv, None, None, None, None)``: gradients for the three
        tensor inputs and ``None`` for ``mask``, ``causal`` and the two
        bucket sizes.
        """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            lse.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Attention probabilities rebuilt from the saved log-sum-exp.
                p = torch.exp(attn_weights - lsec)

                # Unlike forward(), the padding mask is applied only after
                # exponentiation, by zeroing the masked probabilities.
                if exists(row_mask):
                    p.masked_fill_(~row_mask, 0.0)

                dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
                dp = einsum("... i d, ... j d -> ... i j", doc, vc)

                # Row-wise dot product of upstream gradient and output.
                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)

                # Accumulate in place into the bucket slices of dq/dk/dv.
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
modules/prompt_parser.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ # Code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/8e2aeee4a127b295bfc880800e4a312e0f049b85, modified.
8
+
9
class PromptChunk:
    """
    Container for one fixed-size slice of a tokenized prompt.

    Holds the token ids, the per-token weights (multipliers such as 1.4) and
    textual inversion embedding info for the slice. A short prompt fits in a
    single PromptChunk; longer prompts need several. Once finalized, every
    chunk carries exactly 77 tokens: a start token, 75 prompt tokens and an
    end token.
    """

    def __init__(self):
        # Three independent, initially empty lists.
        self.tokens, self.multipliers, self.fixes = [], [], []
21
+
22
+
23
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
    """Wrapper around a CLIP-style text encoder that lifts the prompt length
    limit and lets individual tokens carry a weight.

    Subclasses provide ``id_start``, ``id_end``, ``id_pad``, ``comma_token``,
    ``tokenize()`` and ``encode_with_transformers()``.
    """

    def __init__(self, text_encoder, enable_emphasis=True):
        super().__init__()

        self.enable_emphasis = enable_emphasis
        # Number of prompt tokens per chunk (start/end tokens not included).
        self.chunk_length = 75
        # Looked up lazily so the encoder may be moved after construction.
        self.device = lambda: text_encoder.device

    def empty_chunk(self):
        """Build and return a PromptChunk that contains padding only."""

        width = self.chunk_length + 2
        padding = PromptChunk()
        padding.tokens = [self.id_start, *([self.id_end] * (width - 1))]
        padding.multipliers = [1.0] * width
        return padding

    def get_target_prompt_token_count(self, token_count):
        """Return how many tokens a prompt of ``token_count`` tokens may grow
        to before one more PromptChunk is needed to represent it."""

        chunks_needed = math.ceil(max(token_count, 1) / self.chunk_length)
        return chunks_needed * self.chunk_length

    def tokenize_line(self, line):
        """
        Turn one prompt into as many PromptChunk objects as it requires.

        Returns ``(chunks, token_count)``: the list of chunks and the total
        number of tokens in the prompt.
        """

        parsed = parse_prompt_attention(line) if self.enable_emphasis else [[line, 1.0]]
        tokenized = self.tokenize([fragment for fragment, _ in parsed])

        finished = []
        current = PromptChunk()
        total_tokens = 0
        comma_at = -1

        def next_chunk(is_last=False):
            """Pad and store the chunk being built, then start a fresh one.

            For the last chunk only the real tokens are counted; any other
            chunk counts as a full ``chunk_length``.
            """
            nonlocal total_tokens, comma_at, current

            total_tokens += len(current.tokens) if is_last else self.chunk_length

            missing = self.chunk_length - len(current.tokens)
            if missing > 0:
                current.tokens.extend([self.id_end] * missing)
                current.multipliers.extend([1.0] * missing)

            current.tokens = [self.id_start, *current.tokens, self.id_end]
            current.multipliers = [1.0, *current.multipliers, 1.0]

            comma_at = -1
            finished.append(current)
            current = PromptChunk()

        # default value in https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/shared.py#L410
        backtrack = 20

        for tokens, (text, weight) in zip(tokenized, parsed):
            # A BREAK marker closes the current chunk right away.
            if text == "BREAK" and weight == -1:
                next_chunk()
                continue

            for token in tokens:
                filled = len(current.tokens)

                if token == self.comma_token:
                    comma_at = filled
                elif (
                    backtrack != 0
                    and filled == self.chunk_length
                    and comma_at != -1
                    and filled - comma_at <= backtrack
                ):
                    # The chunk is full, this token is not a comma, and a
                    # comma lies close enough behind: move everything after
                    # that comma into the next chunk so the phrase stays
                    # together.
                    split_at = comma_at + 1
                    carried_tokens = current.tokens[split_at:]
                    carried_mults = current.multipliers[split_at:]
                    current.tokens = current.tokens[:split_at]
                    current.multipliers = current.multipliers[:split_at]

                    next_chunk()
                    current.tokens = carried_tokens
                    current.multipliers = carried_mults

                if len(current.tokens) == self.chunk_length:
                    next_chunk()

                current.tokens.append(token)
                current.multipliers.append(weight)

        # Flush the trailing chunk; an empty prompt still yields one chunk.
        if current.tokens or not finished:
            next_chunk(is_last=True)

        return finished, total_tokens

    def process_texts(self, texts):
        """
        Run tokenize_line() on every text, reusing results for repeated
        texts. Returns the per-text chunk lists and the largest token count
        seen among the texts.
        """

        seen = {}
        longest = 0
        per_text = []

        for text in texts:
            if text not in seen:
                text_chunks, count = self.tokenize_line(text)
                longest = max(count, longest)
                seen[text] = text_chunks
            per_text.append(seen[text])

        return per_text, longest

    def forward(self, texts):
        """
        Encode a list of texts with the transformer network.

        Returns ``(token_ids, embeddings)``. The embeddings tensor has shape
        (B, T, C): B is the number of texts, T is the padded length in
        tokens (a multiple of 77) and C is the per-token dimensionality -
        768 for SD1 and 1024 for SD2. An example shape is (2, 77, 768).
        Webui usually sends a single text; more than one element appears
        with prompt editing such as
        "a picture of a [cat:dog:0.4] eating ice cream".
        """

        batch_chunks, _ = self.process_texts(texts)
        n_chunks = max(len(text_chunks) for text_chunks in batch_chunks)

        token_blocks = []
        embedding_blocks = []
        for idx in range(n_chunks):
            # Texts with fewer chunks are filled up with padding chunks.
            row = [
                text_chunks[idx] if idx < len(text_chunks) else self.empty_chunk()
                for text_chunks in batch_chunks
            ]
            ids = [piece.tokens for piece in row]
            weights = [piece.multipliers for piece in row]

            embedding_blocks.append(self.process_tokens(ids, weights))
            token_blocks.append(ids)

        return np.hstack(token_blocks), torch.hstack(embedding_blocks)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
        Encode one prompt chunk per batch entry with the transformer network.

        ``remade_batch_tokens`` is a list of token-id lists (usually 77 ids
        each); ``batch_multipliers`` has the same layout and holds one weight
        per token, used to scale that token's encoder output.
        """
        tokens = torch.asarray(remade_batch_tokens).to(self.device())

        # this is for SD2: SD1 uses the same token for padding and end of
        # text, while SD2 uses different ones.
        if self.id_end != self.id_pad:
            for row, ids in enumerate(remade_batch_tokens):
                first_end = ids.index(self.id_end)
                tokens[row, first_end + 1 : tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

        # restoring original mean is likely not correct, but it seems to work
        # well to prevent artifacts that happen otherwise
        weights = torch.asarray(batch_multipliers).to(self.device())
        mean_before = z.mean()
        z = z * weights.unsqueeze(-1).expand(z.shape)
        mean_after = z.mean()

        return z * (mean_before / mean_after)
223
+
224
+
225
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    """Prompt embedder bound to a concrete tokenizer / text encoder pair.

    Args:
        tokenizer: tokenizer object; the code uses ``get_vocab()``,
            ``bos_token_id``, ``eos_token_id`` and calls it on a list of
            strings.
        text_encoder: encoder module; called with token ids and
            ``output_hidden_states=True``.
        clip_stop_at_last_layers: "clip skip". ``1`` (the default, and the
            previous hard-coded behaviour) uses the encoder's last hidden
            state. A value ``n > 1`` takes the n-th hidden state from the
            end and applies the encoder's final layer norm to it.
    """

    def __init__(self, tokenizer, text_encoder, clip_stop_at_last_layers=1):
        super().__init__(text_encoder)
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

        vocab = self.tokenizer.get_vocab()

        # Token id of a standalone comma; None when the vocab has none.
        self.comma_token = vocab.get(",</w>", None)

        # Net emphasis multiplier of vocabulary entries that themselves
        # contain bracket characters, keyed by token id.
        self.token_mults = {}
        for text, ident in vocab.items():
            if not any(ch in text for ch in "()[]"):
                continue

            mult = 1.0
            for c in text:
                if c == "(" or c == "]":
                    mult *= 1.1
                elif c == "[" or c == ")":
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

        self.id_start = self.tokenizer.bos_token_id
        self.id_end = self.tokenizer.eos_token_id
        # Padding reuses the end-of-text id here.
        self.id_pad = self.id_end

    def tokenize(self, texts):
        """Return the token ids of each text, untruncated and without
        special tokens."""
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)[
            "input_ids"
        ]

    def encode_with_transformers(self, tokens):
        """Run the text encoder on ``tokens`` and return the hidden states
        selected by ``self.clip_stop_at_last_layers``."""
        skip = self.clip_stop_at_last_layers
        tokens = tokens.to(self.text_encoder.device)
        outputs = self.text_encoder(tokens, output_hidden_states=True)

        if skip > 1:
            # Earlier hidden states have not been normalised yet, so apply
            # the encoder's final layer norm manually.
            z = outputs.hidden_states[-skip]
            z = self.text_encoder.text_model.final_layer_norm(z)
        else:
            z = outputs.last_hidden_state

        return z
279
+
280
+
281
# Tokenizer for the emphasis syntax. Alternatives are tried in order:
# escaped brackets / backslash, opening brackets, an explicit ":<weight>)"
# suffix (the weight is capture group 1), closing brackets, a run of plain
# text, and finally a lone colon.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)

# The word BREAK together with any whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    The word BREAK is emitted as the marker pair ``["BREAK", -1]``.
    A ":<weight>)" suffix whose weight is not a valid number (for example
    ``:1.2.3)``) is kept as literal text instead of raising.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    # Stacks of indices into ``res`` where each still-open bracket began.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        """Scale the weight of every fragment from start_position onwards."""
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)

        # The pattern accepts any mix of digits and dots, so the captured
        # weight may not be a number; treat such a token as plain text.
        weight = None
        raw_weight = m.group(1)
        if raw_weight is not None:
            try:
                weight = float(raw_weight)
            except ValueError:
                weight = None

        if token.startswith("\\"):
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), weight)
        elif token == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            for i, part in enumerate(re.split(re_break, token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Brackets that were never closed still apply up to the end of the text.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # merge runs of identical weights (single pass)
    merged = [res[0]]
    for fragment, fragment_weight in res[1:]:
        if merged[-1][1] == fragment_weight:
            merged[-1][0] += fragment
        else:
            merged.append([fragment, fragment_weight])

    return merged
modules/safe.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this code is adapted from the script contributed by anon from /h/
2
+ # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
3
+
4
+ import io
5
+ import pickle
6
+ import collections
7
+ import sys
8
+ import traceback
9
+
10
+ import torch
11
+ import numpy
12
+ import _codecs
13
+ import zipfile
14
+ import re
15
+
16
+
17
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage.
# Resolve whichever name exists; RestrictedUnpickler.persistent_load()
# returns an instance of this class for every 'storage' persistent id.
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
19
+
20
+
21
def encode(*args):
    """Forward all arguments to ``_codecs.encode`` and return its result.

    RestrictedUnpickler.find_class() hands this wrapper out when a pickle
    asks for ``_codecs.encode``.
    """
    return _codecs.encode(*args)
24
+
25
+
26
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves an explicit allowlist of globals.

    It is used to scan a checkpoint for forbidden globals before the real
    ``torch.load`` runs; every global outside the allowlist raises.
    """

    # Optional callable ``(module, name) -> object | None`` consulted before
    # the built-in allowlist; returning None falls through to the allowlist.
    extra_handler = None

    def persistent_load(self, saved_id):
        """Return an empty placeholder storage for a 'storage' persistent id.

        Raises:
            Exception: if ``saved_id`` is not a non-empty tuple/list whose
                first element is ``'storage'``. An explicit raise is used
                instead of ``assert`` so the check survives ``python -O``.
        """
        kind = saved_id[0] if isinstance(saved_id, (tuple, list)) and saved_id else None
        if kind != 'storage':
            raise Exception(f"persistent id '{kind}' is forbidden, only 'storage' is allowed")
        return TypedStorage()

    def find_class(self, module, name):
        """Resolve ``module.name`` if it is allowlisted, else raise.

        Raises:
            Exception: for every global that neither ``extra_handler`` nor
                the allowlist below accepts.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        # pytorch_lightning is imported lazily, only when a checkpoint
        # actually references it.
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
64
+
65
+
66
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
# Matches only the pickle member, '<dirname>/data.pkl'.
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    """Raise if the archive contains a member with an unexpected name.

    ``filename`` appears only in the error message; ``names`` is the list of
    member names to validate. Returns None when every name is allowed.
    """
    for member in names:
        if not allowed_zip_names_re.match(member):
            raise Exception(f"bad file inside {filename}: {member}")
76
+
77
+
78
def check_pt(filename, extra_handler):
    """Scan a checkpoint with RestrictedUnpickler without keeping the result.

    Any forbidden global, unexpected archive member or unpickling failure
    propagates to the caller as an exception.
    """

    def scan(stream, pickles=1):
        """Read ``pickles`` consecutive pickled objects from ``stream``."""
        checker = RestrictedUnpickler(stream)
        checker.extra_handler = extra_handler
        for _ in range(pickles):
            checker.load()

    try:
        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as archive:
            members = archive.namelist()
            check_zip_filenames(filename, members)

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            candidates = [member for member in members if data_pkl_re.match(member)]
            if not candidates:
                raise Exception(f"data.pkl not found in {filename}")
            if len(candidates) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")

            with archive.open(candidates[0]) as payload:
                scan(payload)

    except zipfile.BadZipfile:
        # if it's not a zip file, it's an old pytorch format, with five
        # objects written to pickle
        with open(filename, "rb") as legacy:
            scan(legacy, pickles=5)
104
+
105
+
106
def load(filename, *args, **kwargs):
    """Checked replacement for ``torch.load``.

    Verifies ``filename`` through load_with_extra() using the current
    ``global_extra_handler`` and forwards every other argument to the
    original ``torch.load``. Returns None when verification fails.

    The handler is passed positionally on purpose: passing it as a keyword
    ahead of ``*args`` made the first extra positional argument (e.g.
    ``torch.load(path, "cpu")``) collide with ``extra_handler`` and raise
    ``TypeError: got multiple values for argument 'extra_handler'``.
    """
    return load_with_extra(filename, global_extra_handler, *args, **kwargs)
108
+
109
+
110
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Verify a checkpoint, then load it with the original ``torch.load``.

    Meant for extensions whose models contain extra classes that the default
    restricted unpickler would reject. ``extra_handler`` is a function that
    receives the module and field name as text and returns that field's
    value, or None to defer to the built-in allowlist:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative is to call safe.unsafe_torch_load('model.pt') directly,
    which as the name implies is definitely unsafe.

    Returns None (after printing a report to stderr) when verification
    fails; otherwise returns whatever the original ``torch.load`` returns
    for ``filename`` and the remaining arguments.
    """

    try:
        check_pt(filename, extra_handler)
    except Exception as err:
        report = [
            f"Error verifying pickled file from {filename}:",
            traceback.format_exc(),
        ]
        if isinstance(err, pickle.UnpicklingError):
            report.append("The file is most likely corrupted.")
        else:
            report.append("\nThe file may be malicious, so the program is not going to read it.")
            report.append("You can skip this check with --disable-safe-unpickle commandline argument.\n\n")

        for message in report:
            print(message, file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)
149
+
150
+
151
class Extra:
    """
    Context manager that temporarily installs the global extra handler, for
    cases where load_with_extra cannot be called explicitly (because other
    code makes the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```

    Blocks cannot be nested: entering while a handler is already installed
    raises AssertionError.
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # Raised explicitly (rather than via ``assert``) so the guard is
        # still enforced under ``python -O``; the exception type is the same.
        if global_extra_handler is not None:
            raise AssertionError('already inside an Extra() block')
        global_extra_handler = self.handler
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        # Always clear the handler; exceptions from the body propagate.
        global_extra_handler = None
184
+
185
+
186
# Keep a handle on the original torch.load before it is replaced below;
# load_with_extra() calls it once a file has passed check_pt().
unsafe_torch_load = torch.load
# Import-time monkey patch: from here on torch.load is the checked load().
# NOTE(review): executing this module a second time (e.g. importlib.reload)
# would capture the patched function as "unsafe_torch_load" - confirm that
# this never happens.
torch.load = load
# Handler read by load(); set and cleared by Extra.__enter__/__exit__.
global_extra_handler = None