ironjr commited on
Commit
1e1c50f
·
1 Parent(s): 35a5a5e

first commit

Browse files
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: SemanticPalette3
3
- emoji: 🚀
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
1
  ---
2
+ title: Semantic Palette with Stable Diffusion 3
3
+ emoji: 🧠🎨3️
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.36.1
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,948 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import sys
22
+
23
+ sys.path.append('../../src')
24
+
25
+ import argparse
26
+ import random
27
+ import time
28
+ import json
29
+ import os
30
+ import glob
31
+ import pathlib
32
+ from functools import partial
33
+ from pprint import pprint
34
+
35
+ import numpy as np
36
+ from PIL import Image
37
+ import torch
38
+
39
+ import gradio as gr
40
+ from huggingface_hub import snapshot_download
41
+
42
+ from model import StableMultiDiffusion3Pipeline
43
+ from util import seed_everything
44
+ from prompt_util import preprocess_prompts, _quality_dict, _style_dict
45
+
46
+
47
+ ### Utils
48
+
49
+
50
+
51
+
52
+ def log_state(state):
53
+ pprint(vars(opt))
54
+ if isinstance(state, gr.State):
55
+ state = state.value
56
+ pprint(vars(state))
57
+
58
+
59
+ def is_empty_image(im: Image.Image) -> bool:
60
+ if im is None:
61
+ return True
62
+ im = np.array(im)
63
+ has_alpha = (im.shape[2] == 4)
64
+ if not has_alpha:
65
+ return False
66
+ elif im.sum() == 0:
67
+ return True
68
+ else:
69
+ return False
70
+
71
+
72
+ ### Argument passing
73
+
74
+ parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SD3 support.')
75
+ parser.add_argument('-H', '--height', type=int, default=1024)
76
+ parser.add_argument('-W', '--width', type=int, default=2560)
77
+ parser.add_argument('--model', type=str, default=None, help='Hugging face model repository or local path for a SD1.5 model checkpoint to run.')
78
+ parser.add_argument('--bootstrap_steps', type=int, default=2)
79
+ parser.add_argument('--seed', type=int, default=-1)
80
+ parser.add_argument('--device', type=int, default=0)
81
+ parser.add_argument('--port', type=int, default=8000)
82
+ opt = parser.parse_args()
83
+
84
+
85
+ ### Global variables and data structures
86
+
87
+ device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
88
+
89
+
90
+ if opt.model is None:
91
+ model_dict = {
92
+ 'Stable Diffusion 3': 'stabilityai/stable-diffusion-3-medium-diffusers',
93
+ }
94
+ else:
95
+ if opt.model.endswith('.safetensors'):
96
+ opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
97
+ model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
98
+
99
+ dtype = torch.float32 if device == 'cpu' else torch.float16
100
+ models = {
101
+ k: StableMultiDiffusion3Pipeline(device, dtype=dtype, hf_key=v, has_i2t=False)
102
+ for k, v in model_dict.items()
103
+ }
104
+
105
+
106
+ prompt_suggestions = [
107
+ '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
108
+ '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
109
+ '1girl, arima kana, oshi no ko, solo, upper body, from behind',
110
+ ]
111
+
112
+ opt.max_palettes = 4
113
+ opt.default_prompt_strength = 1.0
114
+ opt.default_mask_strength = 1.0
115
+ opt.default_mask_std = 0.0
116
+ opt.default_negative_prompt = (
117
+ 'nsfw, worst quality, bad quality, normal quality, cropped, framed'
118
+ )
119
+ opt.verbose = True
120
+ opt.colors = [
121
+ '#000000',
122
+ '#2692F3',
123
+ '#F89E12',
124
+ '#16C232',
125
+ '#F92F6C',
126
+ # '#AC6AEB',
127
+ # '#92C62C',
128
+ # '#92C6EC',
129
+ # '#FECAC0',
130
+ ]
131
+
132
+
133
+ ### Event handlers
134
+
135
+ def add_palette(state):
136
+ old_actives = state.active_palettes
137
+ state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
138
+
139
+ if opt.verbose:
140
+ log_state(state)
141
+
142
+ if state.active_palettes != old_actives:
143
+ return [state] + [
144
+ gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
145
+ ] + [
146
+ gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
147
+ for i in range(opt.max_palettes)
148
+ ]
149
+ else:
150
+ return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
151
+
152
+
153
+ def select_palette(state, button, idx):
154
+ if idx < 0 or idx > opt.max_palettes:
155
+ idx = 0
156
+ old_idx = state.current_palette
157
+ if old_idx == idx:
158
+ return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
159
+
160
+ state.current_palette = idx
161
+
162
+ if opt.verbose:
163
+ log_state(state)
164
+
165
+ updates = [state] + [
166
+ gr.update() if i not in (idx, old_idx) else
167
+ gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
168
+ for i in range(opt.max_palettes + 1)
169
+ ]
170
+ label = 'Background' if idx == 0 else f'Palette {idx}'
171
+ updates.extend([
172
+ gr.update(value=button, interactive=(idx > 0)),
173
+ gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
174
+ gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
175
+ (
176
+ gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
177
+ gr.update(value=opt.default_mask_strength, interactive=False)
178
+ ),
179
+ (
180
+ gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
181
+ gr.update(value=opt.default_prompt_strength, interactive=False)
182
+ ),
183
+ (
184
+ gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
185
+ gr.update(value=opt.default_mask_std, interactive=False)
186
+ ),
187
+ ])
188
+ return updates
189
+
190
+
191
+ def change_prompt_strength(state, strength):
192
+ if state.current_palette == 0:
193
+ return state
194
+
195
+ state.prompt_strengths[state.current_palette - 1] = strength
196
+ if opt.verbose:
197
+ log_state(state)
198
+
199
+ return state
200
+
201
+
202
+ def change_std(state, std):
203
+ if state.current_palette == 0:
204
+ return state
205
+
206
+ state.mask_stds[state.current_palette - 1] = std
207
+ if opt.verbose:
208
+ log_state(state)
209
+
210
+ return state
211
+
212
+
213
+ def change_mask_strength(state, strength):
214
+ if state.current_palette == 0:
215
+ return state
216
+
217
+ state.mask_strengths[state.current_palette - 1] = strength
218
+ if opt.verbose:
219
+ log_state(state)
220
+
221
+ return state
222
+
223
+
224
+ def reset_seed(state, seed):
225
+ state.seed = seed
226
+ if opt.verbose:
227
+ log_state(state)
228
+
229
+ return state
230
+
231
+ def rename_prompt(state, name):
232
+ state.prompt_names[state.current_palette] = name
233
+ if opt.verbose:
234
+ log_state(state)
235
+
236
+ return [state] + [
237
+ gr.update() if i != state.current_palette else gr.update(value=name)
238
+ for i in range(opt.max_palettes + 1)
239
+ ]
240
+
241
+
242
+ def change_prompt(state, prompt):
243
+ state.prompts[state.current_palette] = prompt
244
+ if opt.verbose:
245
+ log_state(state)
246
+
247
+ return state
248
+
249
+
250
+ def change_neg_prompt(state, neg_prompt):
251
+ state.neg_prompts[state.current_palette] = neg_prompt
252
+ if opt.verbose:
253
+ log_state(state)
254
+
255
+ return state
256
+
257
+
258
+ def select_model(state, model_id):
259
+ state.model_id = model_id
260
+ if opt.verbose:
261
+ log_state(state)
262
+
263
+ return state
264
+
265
+
266
+ def select_style(state, style_name):
267
+ state.style_name = style_name
268
+ if opt.verbose:
269
+ log_state(state)
270
+
271
+ return state
272
+
273
+
274
+ def select_quality(state, quality_name):
275
+ state.quality_name = quality_name
276
+ if opt.verbose:
277
+ log_state(state)
278
+
279
+ return state
280
+
281
+
282
+ def import_state(state, json_text):
283
+ current_palette = state.current_palette
284
+ # active_palettes = state.active_palettes
285
+ state = argparse.Namespace(**json.loads(json_text))
286
+ state.active_palettes = opt.max_palettes
287
+ return [state] + [
288
+ gr.update(value=v, visible=True) for v in state.prompt_names
289
+ ] + [
290
+ # state.model_id,
291
+ # state.style_name,
292
+ # state.quality_name,
293
+ state.prompts[current_palette],
294
+ state.prompt_names[current_palette],
295
+ state.neg_prompts[current_palette],
296
+ state.prompt_strengths[current_palette - 1],
297
+ state.mask_strengths[current_palette - 1],
298
+ state.mask_stds[current_palette - 1],
299
+ state.seed,
300
+ ]
301
+
302
+
303
+ ### Main worker
304
+
305
+ def generate(state, *args, **kwargs):
306
+ return models[state.model_id](*args, **kwargs)
307
+
308
+
309
+
310
+ def run(state, drawpad):
311
+ seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
312
+ print('Generate!')
313
+
314
+ background = drawpad['background'].convert('RGBA')
315
+ inpainting_mode = np.asarray(background).sum() != 0
316
+ print('Inpainting mode: ', inpainting_mode)
317
+
318
+ user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
319
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
320
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
321
+
322
+ palette = torch.tensor([
323
+ tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
324
+ for s in opt.colors[1:]
325
+ ]) # (N, 3)
326
+ masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
327
+ has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
328
+ print('Has mask: ', has_masks)
329
+ masks = masks * foreground_mask
330
+ masks = masks[has_masks]
331
+
332
+ if inpainting_mode:
333
+ prompts = [state.prompts[v + 1] for v in has_masks]
334
+ negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
335
+ mask_strengths = [state.mask_strengths[v] for v in has_masks]
336
+ mask_stds = [state.mask_stds[v] for v in has_masks]
337
+ prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
338
+ else:
339
+ masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
340
+ prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
341
+ negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
342
+ mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
343
+ mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
344
+ prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
345
+
346
+ prompts, negative_prompts = preprocess_prompts(
347
+ prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
348
+
349
+ return generate(
350
+ state,
351
+ prompts,
352
+ negative_prompts,
353
+ masks=masks,
354
+ mask_strengths=mask_strengths,
355
+ mask_stds=mask_stds,
356
+ prompt_strengths=prompt_strengths,
357
+ background=background.convert('RGB'),
358
+ background_prompt=state.prompts[0],
359
+ background_negative_prompt=state.neg_prompts[0],
360
+ height=opt.height,
361
+ width=opt.width,
362
+ bootstrap_steps=2,
363
+ guidance_scale=0,
364
+ )
365
+
366
+
367
+
368
+ ### Load examples
369
+
370
+
371
+ root = pathlib.Path(__file__).parent
372
+ print(root)
373
+ example_root = os.path.join(root, 'examples')
374
+ example_images = glob.glob(os.path.join(example_root, '*.webp'))
375
+ example_images = [Image.open(i) for i in example_images]
376
+
377
+ with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
378
+ prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
379
+
380
+ with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
381
+ prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
382
+
383
+ with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
384
+ prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
385
+
386
+ with open(os.path.join(example_root, 'prompt_props.txt')) as f:
387
+ prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
388
+ prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
389
+
390
+ prompt_background = lambda: random.choice(prompts_background)
391
+ prompt_girl = lambda: random.choice(prompts_girl)
392
+ prompt_boy = lambda: random.choice(prompts_boy)
393
+ prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
394
+
395
+
396
+ ### Main application
397
+
398
+ css = f"""
399
+ #run-button {{
400
+ font-size: 30pt;
401
+ background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
402
+ margin: 0;
403
+ padding: 15px 45px;
404
+ text-align: center;
405
+ text-transform: uppercase;
406
+ transition: 0.5s;
407
+ background-size: 200% auto;
408
+ color: white;
409
+ box-shadow: 0 0 20px #eee;
410
+ border-radius: 10px;
411
+ display: block;
412
+ background-position: right center;
413
+ }}
414
+
415
+ #run-button:hover {{
416
+ background-position: left center;
417
+ color: #fff;
418
+ text-decoration: none;
419
+ }}
420
+
421
+ #semantic-palette {{
422
+ border-style: solid;
423
+ border-width: 0.2em;
424
+ border-color: #eee;
425
+ }}
426
+
427
+ #semantic-palette:hover {{
428
+ box-shadow: 0 0 20px #eee;
429
+ }}
430
+
431
+ #output-screen {{
432
+ width: 100%;
433
+ aspect-ratio: {opt.width} / {opt.height};
434
+ }}
435
+
436
+ .layer-wrap {{
437
+ display: none;
438
+ }}
439
+
440
+ .rainbow {{
441
+ text-align: center;
442
+ text-decoration: underline;
443
+ font-size: 32px;
444
+ font-family: monospace;
445
+ letter-spacing: 5px;
446
+ }}
447
+ .rainbow_text_animated {{
448
+ background: linear-gradient(to right, #6666ff, #0099ff , #00ff00, #ff3399, #6666ff);
449
+ -webkit-background-clip: text;
450
+ background-clip: text;
451
+ color: transparent;
452
+ animation: rainbow_animation 6s ease-in-out infinite;
453
+ background-size: 400% 100%;
454
+ }}
455
+
456
+ @keyframes rainbow_animation {{
457
+ 0%,100% {{
458
+ background-position: 0 0;
459
+ }}
460
+
461
+ 50% {{
462
+ background-position: 100% 0;
463
+ }}
464
+ }}
465
+
466
+ .gallery {{
467
+ --z: 16px; /* control the zig-zag */
468
+ --s: 144px; /* control the size */
469
+ --g: 4px; /* control the gap */
470
+
471
+ display: grid;
472
+ gap: var(--g);
473
+ width: calc(2*var(--s) + var(--g));
474
+ grid-auto-flow: column;
475
+ }}
476
+ .gallery > a {{
477
+ width: 0;
478
+ min-width: calc(100% + var(--z)/2);
479
+ height: var(--s);
480
+ object-fit: cover;
481
+ -webkit-mask: var(--mask);
482
+ mask: var(--mask);
483
+ cursor: pointer;
484
+ transition: .5s;
485
+ }}
486
+ .gallery > a:hover {{
487
+ width: calc(var(--s)/2);
488
+ }}
489
+ .gallery > a:first-child {{
490
+ place-self: start;
491
+ clip-path: polygon(calc(2*var(--z)) 0,100% 0,100% 100%,0 100%);
492
+ --mask:
493
+ conic-gradient(from -135deg at right,#0000,#000 1deg 89deg,#0000 90deg)
494
+ 50%/100% calc(2*var(--z)) repeat-y;
495
+ }}
496
+ .gallery > a:last-child {{
497
+ place-self: end;
498
+ clip-path: polygon(0 0,100% 0,calc(100% - 2*var(--z)) 100%,0 100%);
499
+ --mask:
500
+ conic-gradient(from 45deg at left ,#0000,#000 1deg 89deg,#0000 90deg)
501
+ 50% calc(50% - var(--z))/100% calc(2*var(--z)) repeat-y;
502
+ }}
503
+ """
504
+
505
+ for i in range(opt.max_palettes + 1):
506
+ css = css + f"""
507
+ .secondary#semantic-palette-{i} {{
508
+ background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
509
+ color: white;
510
+ }}
511
+
512
+ .primary#semantic-palette-{i} {{
513
+ background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
514
+ color: white;
515
+ }}
516
+ """
517
+
518
+
519
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
520
+
521
+ iface = argparse.Namespace()
522
+
523
+ def _define_state():
524
+ state = argparse.Namespace()
525
+
526
+ # Cursor.
527
+ state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
528
+ state.model_id = list(model_dict.keys())[0]
529
+ state.style_name = '(None)'
530
+ state.quality_name = '(None)' # 'Standard v3.1'
531
+
532
+ # State variables (one-hot).
533
+ state.active_palettes = 1
534
+
535
+ # Front-end initialized to the default values.
536
+ prompt_props_ = prompt_props()
537
+ state.prompt_names = [
538
+ '🌄 Background',
539
+ '👧 Girl',
540
+ '👦 Boy',
541
+ ] + prompt_props_ + ['🎨 New Palette' for _ in range(opt.max_palettes - 5)]
542
+ state.prompts = [
543
+ prompt_background(),
544
+ prompt_girl(),
545
+ prompt_boy(),
546
+ ] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
547
+ state.neg_prompts = [
548
+ opt.default_negative_prompt
549
+ + (', humans, humans, humans' if i == 0 else '')
550
+ for i in range(opt.max_palettes + 1)
551
+ ]
552
+ state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
553
+ state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
554
+ state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
555
+ state.seed = opt.seed
556
+ return state
557
+
558
+ state = gr.State(value=_define_state)
559
+
560
+
561
+ ### Demo user interface
562
+
563
+ gr.HTML(
564
+ """
565
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
566
+ <div>
567
+ <h1>🧠 Semantic Palette with <font class="rainbow rainbow_text_animated">Stable Diffusion 3</font> 🎨</h1>
568
+ <h5 style="margin: 0;">powered by</h5>
569
+ <h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
570
+ <h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
571
+ </br>
572
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
573
+ <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
574
+ <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
575
+ </a>
576
+ &nbsp;
577
+ <a href='https://arxiv.org/abs/2403.09055'>
578
+ <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
579
+ </a>
580
+ &nbsp;
581
+ <a href='https://github.com/ironjr/StreamMultiDiffusion'>
582
+ <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
583
+ </a>
584
+ &nbsp;
585
+ <a href='https://twitter.com/_ironjr_'>
586
+ <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
587
+ </a>
588
+ &nbsp;
589
+ <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
590
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
591
+ </a>
592
+ &nbsp;
593
+ <a href='https://huggingface.co/spaces/ironjr/StreamMultiDiffusion'>
594
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-StreamMultiDiffusion-yellow'>
595
+ </a>
596
+ &nbsp;
597
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
598
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD1.5-yellow'>
599
+ </a>
600
+ &nbsp;
601
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPaletteXL'>
602
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SDXL-yellow'>
603
+ </a>
604
+ &nbsp;
605
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPalette3'>
606
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SD3-yellow'>
607
+ </a>
608
+ </div>
609
+ </div>
610
+ </div>
611
+ <div>
612
+ </br>
613
+ </div>
614
+ """
615
+ )
616
+
617
+ with gr.Row():
618
+
619
+ iface.image_slot = gr.Image(
620
+ interactive=False,
621
+ show_label=False,
622
+ show_download_button=True,
623
+ type='pil',
624
+ label='Generated Result',
625
+ elem_id='output-screen',
626
+ value=lambda: random.choice(example_images),
627
+ )
628
+
629
+ with gr.Row():
630
+
631
+ with gr.Column(scale=1):
632
+
633
+ with gr.Group(elem_id='semantic-palette'):
634
+
635
+ gr.HTML(
636
+ """
637
+ <div style="justify-content: center; align-items: center;">
638
+ <br/>
639
+ <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
640
+ <br/>
641
+ </div>
642
+ """
643
+ )
644
+
645
+ iface.btn_semantics = [gr.Button(
646
+ value=state.value.prompt_names[0],
647
+ variant='primary',
648
+ elem_id='semantic-palette-0',
649
+ )]
650
+ for i in range(opt.max_palettes):
651
+ iface.btn_semantics.append(gr.Button(
652
+ value=state.value.prompt_names[i + 1],
653
+ variant='secondary',
654
+ visible=(i < state.value.active_palettes),
655
+ elem_id=f'semantic-palette-{i + 1}'
656
+ ))
657
+
658
+ iface.btn_add_palette = gr.Button(
659
+ value='Create New Semantic Brush',
660
+ variant='primary',
661
+ )
662
+
663
+ with gr.Accordion(label='Import/Export Semantic Palette', open=False):
664
+ iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
665
+ iface.json_state_export = gr.JSON(label='Exported Palette')
666
+ iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
667
+ iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')
668
+
669
+ gr.HTML(
670
+ """
671
+ <div>
672
+ </br>
673
+ </div>
674
+ <div style="justify-content: center; align-items: center;">
675
+ <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
676
+ </br>
677
+ <div style="justify-content: center; align-items: left; text-align: left;">
678
+ <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
679
+ <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
680
+ <p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
681
+ <p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
682
+ <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
683
+ <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
684
+ </div>
685
+ </div>
686
+ """
687
+ )
688
+
689
+ gr.HTML(
690
+ """
691
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
692
+ <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
693
+ </div>
694
+ """
695
+ )
696
+
697
+ gr.DuplicateButton()
698
+
699
+ with gr.Column(scale=4):
700
+
701
+ with gr.Row():
702
+
703
+ with gr.Column(scale=3):
704
+
705
+ iface.ctrl_semantic = gr.ImageEditor(
706
+ image_mode='RGBA',
707
+ sources=['upload', 'clipboard', 'webcam'],
708
+ transforms=['crop'],
709
+ crop_size=(opt.width, opt.height),
710
+ brush=gr.Brush(
711
+ colors=opt.colors[1:],
712
+ color_mode="fixed",
713
+ ),
714
+ layers=False,
715
+ canvas_size=(opt.width, opt.height),
716
+ type='pil',
717
+ label='Semantic Drawpad',
718
+ elem_id='drawpad',
719
+ )
720
+
721
+ with gr.Column(scale=1):
722
+
723
+ iface.btn_generate = gr.Button(
724
+ value='Generate!',
725
+ variant='primary',
726
+ # scale=1,
727
+ elem_id='run-button'
728
+ )
729
+
730
+
731
+ gr.HTML(
732
+ """
733
+ <h3 style="text-align: center;">Try other demos in HF 🤗 Space!</h3>
734
+ <div style="display: flex; justify-content: center; text-align: center;">
735
+ <div><b style="color: #2692F3">Semantic Palette<br>Animagine XL 3.1</b></div>
736
+ <div style="margin-left: 10px; margin-right: 10px; margin-top: 8px">or</div>
737
+ <div><b style="color: #F89E12">Official Demo of<br>StreamMultiDiffusion</b></div>
738
+ </div>
739
+ <div style="display: inline-block; margin-top: 10px">
740
+ <div class="gallery">
741
+ <a href="https://huggingface.co/spaces/ironjr/SemanticPaletteXL" target="_blank">
742
+ <img alt="AnimagineXL3.1 Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/sdxl.webp?raw=true">
743
+ </a>
744
+ <a href="https://huggingface.co/spaces/ironjr/StreamMultiDiffusion" target="_blank">
745
+ <img alt="StreamMultiDiffusion Demo" src="https://github.com/ironjr/StreamMultiDiffusion/blob/main/demo/semantic_palette_sd3/examples/icons/smd.gif?raw=true">
746
+ </a>
747
+ </div>
748
+ </div>
749
+ """
750
+ )
751
+
752
+ # iface.model_select = gr.Radio(
753
+ # list(model_dict.keys()),
754
+ # label='Stable Diffusion Checkpoint',
755
+ # info='Choose your favorite style.',
756
+ # value=state.value.model_id,
757
+ # )
758
+
759
+ # with gr.Accordion(label='Prompt Engineering', open=True):
760
+ # iface.quality_select = gr.Dropdown(
761
+ # label='Quality Presets',
762
+ # interactive=True,
763
+ # choices=list(_quality_dict.keys()),
764
+ # value='Standard v3.1',
765
+ # )
766
+ # iface.style_select = gr.Radio(
767
+ # label='Style Preset',
768
+ # container=True,
769
+ # interactive=True,
770
+ # choices=list(_style_dict.keys()),
771
+ # value='(None)',
772
+ # )
773
+
774
+ with gr.Group(elem_id='control-panel'):
775
+
776
+ with gr.Row():
777
+ iface.tbox_prompt = gr.Textbox(
778
+ label='Edit Prompt for Background',
779
+ info='What do you want to draw?',
780
+ value=state.value.prompts[0],
781
+ placeholder=lambda: random.choice(prompt_suggestions),
782
+ scale=2,
783
+ )
784
+
785
+ iface.tbox_name = gr.Textbox(
786
+ label='Edit Brush Name',
787
+ info='Just for your convenience.',
788
+ value=state.value.prompt_names[0],
789
+ placeholder='🌄 Background',
790
+ scale=1,
791
+ )
792
+
793
+ with gr.Row():
794
+ iface.tbox_neg_prompt = gr.Textbox(
795
+ label='Edit Negative Prompt for Background',
796
+ info='Add unwanted objects for this semantic brush.',
797
+ value=opt.default_negative_prompt,
798
+ scale=2,
799
+ )
800
+
801
+ iface.slider_strength = gr.Slider(
802
+ label='Prompt Strength',
803
+ info='Blends fg & bg in the prompt level, >0.8 Preferred.',
804
+ minimum=0.5,
805
+ maximum=1.0,
806
+ value=opt.default_prompt_strength,
807
+ scale=1,
808
+ )
809
+
810
+ with gr.Row():
811
+ iface.slider_alpha = gr.Slider(
812
+ label='Mask Alpha',
813
+ info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
814
+ minimum=0.5,
815
+ maximum=1.0,
816
+ value=opt.default_mask_strength,
817
+ )
818
+
819
+ iface.slider_std = gr.Slider(
820
+ label='Mask Blur STD',
821
+ info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
822
+ minimum=0.0001,
823
+ maximum=100.0,
824
+ value=opt.default_mask_std,
825
+ )
826
+
827
+ iface.slider_seed = gr.Slider(
828
+ label='Seed',
829
+ info='The global seed.',
830
+ minimum=-1,
831
+ maximum=2147483647,
832
+ step=1,
833
+ value=opt.seed,
834
+ )
835
+
836
+ ### Attach event handlers
837
+
838
+ for idx, btn in enumerate(iface.btn_semantics):
839
+ btn.click(
840
+ fn=partial(select_palette, idx=idx),
841
+ inputs=[state, btn],
842
+ outputs=[state] + iface.btn_semantics + [
843
+ iface.tbox_name,
844
+ iface.tbox_prompt,
845
+ iface.tbox_neg_prompt,
846
+ iface.slider_alpha,
847
+ iface.slider_strength,
848
+ iface.slider_std,
849
+ ],
850
+ api_name=f'select_palette_{idx}',
851
+ )
852
+
853
+ iface.btn_add_palette.click(
854
+ fn=add_palette,
855
+ inputs=state,
856
+ outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
857
+ api_name='create_new',
858
+ )
859
+
860
+ iface.btn_generate.click(
861
+ fn=run,
862
+ inputs=[state, iface.ctrl_semantic],
863
+ outputs=iface.image_slot,
864
+ api_name='run',
865
+ )
866
+
867
+ iface.slider_alpha.input(
868
+ fn=change_mask_strength,
869
+ inputs=[state, iface.slider_alpha],
870
+ outputs=state,
871
+ api_name='change_alpha',
872
+ )
873
+ iface.slider_std.input(
874
+ fn=change_std,
875
+ inputs=[state, iface.slider_std],
876
+ outputs=state,
877
+ api_name='change_std',
878
+ )
879
+ iface.slider_strength.input(
880
+ fn=change_prompt_strength,
881
+ inputs=[state, iface.slider_strength],
882
+ outputs=state,
883
+ api_name='change_strength',
884
+ )
885
+ iface.slider_seed.input(
886
+ fn=reset_seed,
887
+ inputs=[state, iface.slider_seed],
888
+ outputs=state,
889
+ api_name='reset_seed',
890
+ )
891
+
892
+ iface.tbox_name.input(
893
+ fn=rename_prompt,
894
+ inputs=[state, iface.tbox_name],
895
+ outputs=[state] + iface.btn_semantics,
896
+ api_name='prompt_rename',
897
+ )
898
+ iface.tbox_prompt.input(
899
+ fn=change_prompt,
900
+ inputs=[state, iface.tbox_prompt],
901
+ outputs=state,
902
+ api_name='prompt_edit',
903
+ )
904
+ iface.tbox_neg_prompt.input(
905
+ fn=change_neg_prompt,
906
+ inputs=[state, iface.tbox_neg_prompt],
907
+ outputs=state,
908
+ api_name='neg_prompt_edit',
909
+ )
910
+
911
+ # iface.model_select.change(
912
+ # fn=select_model,
913
+ # inputs=[state, iface.model_select],
914
+ # outputs=state,
915
+ # api_name='model_select',
916
+ # )
917
+ # iface.style_select.change(
918
+ # fn=select_style,
919
+ # inputs=[state, iface.style_select],
920
+ # outputs=state,
921
+ # api_name='style_select',
922
+ # )
923
+ # iface.quality_select.change(
924
+ # fn=select_quality,
925
+ # inputs=[state, iface.quality_select],
926
+ # outputs=state,
927
+ # api_name='quality_select',
928
+ # )
929
+
930
+ iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
931
+ iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
932
+ state,
933
+ *iface.btn_semantics,
934
+ # iface.model_select,
935
+ # iface.style_select,
936
+ # iface.quality_select,
937
+ iface.tbox_prompt,
938
+ iface.tbox_name,
939
+ iface.tbox_neg_prompt,
940
+ iface.slider_strength,
941
+ iface.slider_alpha,
942
+ iface.slider_std,
943
+ iface.slider_seed,
944
+ ])
945
+
946
+
947
+ if __name__ == '__main__':
948
+ demo.launch(server_port=opt.port)
examples/prompt_background.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Maximalism, best quality, high quality, no humans, background, clear sky, ㅠblack sky, starry universe, planets
2
+ Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
3
+ Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
4
+ Maximalism, best quality, high quality, no humans, background, galaxy
5
+ Maximalism, best quality, high quality, no humans, background, sky, daylight
6
+ Maximalism, best quality, high quality, no humans, background, skyscrappers, rooftop, city of light, helicopters, bright night, sky
7
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
8
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/prompt_boy.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1boy, looking at viewer, brown hair, blue shirt
2
+ 1boy, looking at viewer, brown hair, red shirt
3
+ 1boy, looking at viewer, brown hair, purple shirt
4
+ 1boy, looking at viewer, brown hair, orange shirt
5
+ 1boy, looking at viewer, brown hair, yellow shirt
6
+ 1boy, looking at viewer, brown hair, green shirt
7
+ 1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
8
+ 1boy, looking back, short hair, renaissance cloths, noble boy
9
+ 1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
10
+ 1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
11
+ 1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
12
+ 1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
13
+ 1boy, looking at viewer, black haired, old eastern cloth
14
+ 1boy, looking back, messy hair, suit, short beard, noir
15
+ 1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
2
+ 1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
3
+ 1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
4
+ 1girl, looking at viewer, fantasy adventurer, backpack
5
+ 1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
6
+ 1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
7
+ 1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
8
+ 1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
9
+ 1girl, looking at viewer, evil smile, very short hair, suit, evil genius
10
+ 1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
11
+ 1girl, looking at viewer, purple hair, happy face, black leather jacket
12
+ 1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
13
+ 1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
14
+ 1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
15
+ 1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
16
+ 1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 🏯 Palace, Gyeongbokgung palace
2
+ 🌳 Garden, Chinese garden
3
+ 🏛️ Rome, Ancient city of Rome
4
+ 🧱 Wall, Castle wall
5
+ 🔴 Mars, Martian desert, Red rocky desert
6
+ 🌻 Grassland, Grasslands
7
+ 🏡 Village, A fantasy village
8
+ 🐉 Dragon, a flying chinese dragon
9
+ 🌏 Earth, Earth seen from ISS
10
+ 🚀 Space Station, the international space station
11
+ 🪻 Grassland, Rusty grassland with flowers
12
+ 🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
13
+ 🏙️ City Ruin, city, ruins, ruins, ruins, deserted
14
+ 🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
15
+ 🌷 Flowers, Flower garden
16
+ 🌼 Flowers, Flower garden, spring garden
17
+ 🌹 Flowers, Flowers flowers, flowers
18
+ ⛰️ Dolomites Mountains, Dolomites
19
+ ⛰️ Himalayas Mountains, Himalayas
20
+ ⛰️ Alps Mountains, Alps
21
+ ⛰️ Mountains, Mountains
22
+ ❄️⛰️ Mountains, Winter mountains
23
+ 🌷⛰️ Mountains, Spring mountains
24
+ 🌞⛰️ Mountains, Summer mountains
25
+ 🌵 Desert, A sandy desert, dunes
26
+ 🪨🌵 Desert, A rocky desert
27
+ 💦 Waterfall, A giant waterfall
28
+ 🌊 Ocean, Ocean
29
+ ⛱️ Seashore, Seashore
30
+ 🌅 Sea Horizon, Sea horizon
31
+ 🌊 Lake, Clear blue lake
32
+ 💻 Computer, A giant supecomputer
33
+ 🌳 Tree, A giant tree
34
+ 🌳 Forest, A forest
35
+ 🌳🌳 Forest, A dense forest
36
+ 🌲 Forest, Winter forest
37
+ 🌴 Forest, Summer forest, tropical forest
38
+ 👒 Hat, A hat
39
+ 🐶 Dog, Doggy body parts
40
+ 😻 Cat, A cat
41
+ 🦉 Owl, A small sitting owl
42
+ 🦅 Eagle, A small sitting eagle
43
+ 🚀 Rocket, A flying rocket
model.py ADDED
@@ -0,0 +1,1095 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import inspect
22
+ from typing import Any, Callable, Dict, List, Literal, Tuple, Optional, Union
23
+ from tqdm import tqdm
24
+ from PIL import Image
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torchvision.transforms as T
30
+ from einops import rearrange
31
+
32
+ from transformers import (
33
+ CLIPTextModelWithProjection,
34
+ CLIPTokenizer,
35
+ T5EncoderModel,
36
+ T5TokenizerFast,
37
+ )
38
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
39
+
40
+ from diffusers.image_processor import VaeImageProcessor
41
+ from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
42
+ from diffusers.models.attention_processor import (
43
+ AttnProcessor2_0,
44
+ FusedAttnProcessor2_0,
45
+ LoRAAttnProcessor2_0,
46
+ LoRAXFormersAttnProcessor,
47
+ XFormersAttnProcessor,
48
+ )
49
+ from diffusers.models.autoencoders import AutoencoderKL
50
+ from diffusers.models.transformers import SD3Transformer2DModel
51
+ from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3PipelineOutput
52
+ from diffusers.schedulers import (
53
+ FlowMatchEulerDiscreteScheduler,
54
+ FlashFlowMatchEulerDiscreteScheduler,
55
+ )
56
+ from diffusers.utils import (
57
+ is_torch_xla_available,
58
+ logging,
59
+ replace_example_docstring,
60
+ )
61
+ from diffusers.utils.torch_utils import randn_tensor
62
+ from diffusers import (
63
+ DiffusionPipeline,
64
+ StableDiffusion3Pipeline,
65
+ )
66
+
67
+ from peft import PeftModel
68
+
69
+ from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
70
+
71
+
72
+ if is_torch_xla_available():
73
+ import torch_xla.core.xla_model as xm
74
+
75
+ XLA_AVAILABLE = True
76
+ else:
77
+ XLA_AVAILABLE = False
78
+
79
+
80
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
81
+
82
+ EXAMPLE_DOC_STRING = """
83
+ Examples:
84
+ ```py
85
+ >>> import torch
86
+ >>> from diffusers import StableDiffusion3Pipeline
87
+
88
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
89
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
90
+ ... )
91
+ >>> pipe.to("cuda")
92
+ >>> prompt = "A cat holding a sign that says hello world"
93
+ >>> image = pipe(prompt).images[0]
94
+ >>> image.save("sd3.png")
95
+ ```
96
+ """
97
+
98
+
99
+ class StableMultiDiffusion3Pipeline(nn.Module):
100
+ def __init__(
101
+ self,
102
+ device: torch.device,
103
+ dtype: torch.dtype = torch.float16,
104
+ hf_key: Optional[str] = None,
105
+ lora_key: Optional[str] = None,
106
+ load_from_local: bool = False, # Turn on if you have already downloaed LoRA & Hugging Face hub is down.
107
+ default_mask_std: float = 1.0, # 8.0
108
+ default_mask_strength: float = 1.0,
109
+ default_prompt_strength: float = 1.0, # 8.0
110
+ default_bootstrap_steps: int = 1,
111
+ default_boostrap_mix_steps: float = 1.0,
112
+ default_bootstrap_leak_sensitivity: float = 0.2,
113
+ default_preprocess_mask_cover_alpha: float = 0.3,
114
+ t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # # [0, 12, 25, 37], # Magic number.
115
+ mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
116
+ has_i2t: bool = True,
117
+ lora_weight: float = 1.0,
118
+ ) -> None:
119
+ r"""Stabilized MultiDiffusion for fast sampling.
120
+
121
+ Accelrated region-based text-to-image synthesis with Latent Consistency
122
+ Model while preserving mask fidelity and quality.
123
+
124
+ Args:
125
+ device (torch.device): Specify CUDA device.
126
+ hf_key (Optional[str]): Custom StableDiffusion checkpoint for
127
+ stylized generation.
128
+ lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
129
+ load_from_local (bool): Turn on if you have already downloaed LoRA
130
+ & Hugging Face hub is down.
131
+ default_mask_std (float): Preprocess mask with Gaussian blur with
132
+ specified standard deviation.
133
+ default_mask_strength (float): Preprocess mask by multiplying it
134
+ globally with the specified variable. Caution: extremely
135
+ sensitive. Recommended range: 0.98-1.
136
+ default_prompt_strength (float): Preprocess foreground prompts
137
+ globally by linearly interpolating its embedding with the
138
+ background prompt embeddint with specified mix ratio. Useful
139
+ control handle for foreground blending. Recommended range:
140
+ 0.5-1.
141
+ default_bootstrap_steps (int): Bootstrapping stage steps to
142
+ encourage region separation. Recommended range: 1-3.
143
+ default_boostrap_mix_steps (float): Bootstrapping background is a
144
+ linear interpolation between background latent and the white
145
+ image latent. This handle controls the mix ratio. Available
146
+ range: 0-(number of bootstrapping inference steps). For
147
+ example, 2.3 means that for the first two steps, white image
148
+ is used as a bootstrapping background and in the third step,
149
+ mixture of white (0.3) and registered background (0.7) is used
150
+ as a bootstrapping background.
151
+ default_bootstrap_leak_sensitivity (float): Postprocessing at each
152
+ inference step by masking away the remaining bootstrap
153
+ backgrounds t Recommended range: 0-1.
154
+ default_preprocess_mask_cover_alpha (float): Optional preprocessing
155
+ where each mask covered by other masks is reduced in its alpha
156
+ value by this specified factor.
157
+ t_index_list (List[int]): The default scheduling for the scheduler.
158
+ mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
159
+ defines the mask quantization modes. Details in the codes of
160
+ `self.process_mask`. Basically, this (subtly) controls the
161
+ smoothness of foreground-background blending. More continuous
162
+ means more blending, but smaller generated patch depending on
163
+ the mask standard deviation.
164
+ has_i2t (bool): Automatic background image to text prompt con-
165
+ version with BLIP-2 model. May not be necessary for the non-
166
+ streaming application.
167
+ lora_weight (float): Adjusts weight of the LCM/Lightning LoRA.
168
+ Heavily affects the overall quality!
169
+ """
170
+ super().__init__()
171
+
172
+ self.device = device
173
+ self.dtype = dtype
174
+
175
+ self.default_mask_std = default_mask_std
176
+ self.default_mask_strength = default_mask_strength
177
+ self.default_prompt_strength = default_prompt_strength
178
+ self.default_t_list = t_index_list
179
+ self.default_bootstrap_steps = default_bootstrap_steps
180
+ self.default_boostrap_mix_steps = default_boostrap_mix_steps
181
+ self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
182
+ self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
183
+ self.mask_type = mask_type
184
+
185
+ # Create model.
186
+ print(f'[INFO] Loading Stable Diffusion...')
187
+ if hf_key is not None:
188
+ print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
189
+ else:
190
+ hf_key = "stabilityai/stable-diffusion-3-medium-diffusers"
191
+
192
+ transformer = SD3Transformer2DModel.from_pretrained(
193
+ hf_key,
194
+ subfolder="transformer",
195
+ torch_dtype=torch.float16,
196
+ ).to(self.device)
197
+
198
+ transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-sd3").to(self.device)
199
+
200
+ self.pipe = StableDiffusion3Pipeline.from_pretrained(
201
+ "stabilityai/stable-diffusion-3-medium-diffusers",
202
+ transformer=transformer,
203
+ torch_dtype=torch.float16,
204
+ text_encoder_3=None,
205
+ tokenizer_3=None
206
+ ).to(self.device)
207
+
208
+ # Create model
209
+ if has_i2t:
210
+ self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
211
+ self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
212
+
213
+ # Use SDXL-Lightning LoRA by default.
214
+ self.pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler.from_pretrained(
215
+ "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler")
216
+ self.pipe = self.pipe.to(self.device)
217
+
218
+ self.scheduler = self.pipe.scheduler
219
+ self.default_num_inference_steps = 4
220
+ self.default_guidance_scale = 0.0
221
+
222
+ if t_index_list is None:
223
+ self.prepare_flashflowmatch_schedule(
224
+ list(range(self.default_num_inference_steps)),
225
+ self.default_num_inference_steps,
226
+ )
227
+ else:
228
+ self.prepare_flashflowmatch_schedule(t_index_list, 50)
229
+
230
+ self.vae = self.pipe.vae
231
+ self.tokenizer = self.pipe.tokenizer
232
+ self.tokenizer_2 = self.pipe.tokenizer_2
233
+ self.tokenizer_3 = self.pipe.tokenizer_3
234
+ self.text_encoder = self.pipe.text_encoder
235
+ self.text_encoder_2 = self.pipe.text_encoder_2
236
+ self.text_encoder_3 = self.pipe.text_encoder_3
237
+ self.transformer = self.pipe.transformer
238
+ self.vae_scale_factor = self.pipe.vae_scale_factor
239
+
240
+ # Prepare white background for bootstrapping.
241
+ self.get_white_background(1024, 1024)
242
+
243
+ print(f'[INFO] Model is loaded!')
244
+
245
+ def prepare_flashflowmatch_schedule(
246
+ self,
247
+ t_index_list: Optional[List[int]] = None,
248
+ num_inference_steps: Optional[int] = None,
249
+ ) -> None:
250
+ r"""Set up different inference schedule for the diffusion model.
251
+
252
+ You do not have to run this explicitly if you want to use the default
253
+ setting, but if you want other time schedules, run this function
254
+ between the module initialization and the main call.
255
+
256
+ Note:
257
+ - Recommended t_index_lists for LCMs:
258
+ - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
259
+ panorama. Not recommended if you want to use bootstrapping.
260
+ Because bootstrapping stage affects the initial structuring
261
+ of the generated image & in this four step LCM, this is done
262
+ with only at the first step, the structure may be distorted.
263
+ - [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
264
+ strapping. Default initialization in this implementation.
265
+ - [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
266
+ bootstrapping.
267
+ - Due to the characteristic of SD1.5 LCM LoRA, setting
268
+ `num_inference_steps` larger than 20 may results in overly blurry
269
+ and unrealistic images. Beware!
270
+
271
+ Args:
272
+ t_index_list (Optional[List[int]]): The specified scheduling step
273
+ regarding the maximum timestep as `num_inference_steps`, which
274
+ is by default, 50. That means that
275
+ `t_index_list=[0, 12, 25, 37]` is a relative time indices basd
276
+ on the full scale of 50. If None, reinitialize the module with
277
+ the default value.
278
+ num_inference_steps (Optional[int]): The maximum timestep of the
279
+ sampler. Defines relative scale of the `t_index_list`. Rarely
280
+ used in practice. If None, reinitialize the module with the
281
+ default value.
282
+ """
283
+ if t_index_list is None:
284
+ t_index_list = self.default_t_list
285
+ if num_inference_steps is None:
286
+ num_inference_steps = self.default_num_inference_steps
287
+
288
+ self.scheduler.set_timesteps(num_inference_steps)
289
+ self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)].to(self.device)
290
+
291
+ # FlashFlowMatchEulerDiscreteScheduler
292
+ # https://github.com/initml/diffusers/blob/clement/feature/flash_sd3/src/diffusers/schedulers/scheduling_flash_flow_match_euler_discrete.py
293
+
294
+ self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)].to(self.device)
295
+ self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:].to(self.device)
296
+
297
+ noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
298
+ self.noise_lvs = noise_lvs[None, :, None, None, None]
299
+ self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
300
+
301
+ @torch.no_grad()
302
+ def get_text_prompts(self, image: Image.Image) -> str:
303
+ r"""A convenient method to extract text prompt from an image.
304
+
305
+ This is called if the user does not provide background prompt but only
306
+ the background image. We use BLIP-2 to automatically generate prompts.
307
+
308
+ Args:
309
+ image (Image.Image): A PIL image.
310
+
311
+ Returns:
312
+ A single string of text prompt.
313
+ """
314
+ if hasattr(self, 'i2t_model'):
315
+ question = 'Question: What are in the image? Answer:'
316
+ inputs = self.i2t_processor(image, question, return_tensors='pt')
317
+ out = self.i2t_model.generate(**inputs, max_new_tokens=77)
318
+ prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
319
+ return prompt
320
+ else:
321
+ return ''
322
+
323
+ @torch.no_grad()
324
+ def encode_imgs(
325
+ self,
326
+ imgs: torch.Tensor,
327
+ generator: Optional[torch.Generator] = None,
328
+ vae: Optional[nn.Module] = None,
329
+ ) -> torch.Tensor:
330
+ r"""A wrapper function for VAE encoder of the latent diffusion model.
331
+
332
+ Args:
333
+ imgs (torch.Tensor): An image to get StableDiffusion latents.
334
+ Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
335
+ generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
336
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
337
+ the demo application with TinyVAE).
338
+
339
+ Returns:
340
+ An image latent embedding with 1/8 size (depending on the auto-
341
+ encoder. Shape: (B, 4, H//8, W//8).
342
+ """
343
+ def _retrieve_latents(
344
+ encoder_output: torch.Tensor,
345
+ generator: Optional[torch.Generator] = None,
346
+ sample_mode: str = 'sample',
347
+ ):
348
+ if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
349
+ return encoder_output.latent_dist.sample(generator)
350
+ elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
351
+ return encoder_output.latent_dist.mode()
352
+ elif hasattr(encoder_output, 'latents'):
353
+ return encoder_output.latents
354
+ else:
355
+ raise AttributeError('Could not access latents of provided encoder_output')
356
+
357
+ vae = self.vae if vae is None else vae
358
+ imgs = 2 * imgs - 1
359
+ latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
360
+ return latents
361
+
362
+ @torch.no_grad()
363
+ def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
364
+ r"""A wrapper function for VAE decoder of the latent diffusion model.
365
+
366
+ Args:
367
+ latents (torch.Tensor): An image latent to get associated images.
368
+ Expected shape: (B, 4, H//8, W//8).
369
+ vae (Optional[nn.Module]): Explicitly specify VAE (used for
370
+ the demo application with TinyVAE).
371
+
372
+ Returns:
373
+ An image latent embedding with 1/8 size (depending on the auto-
374
+ encoder. Shape: (B, 3, H, W).
375
+ """
376
+ vae = self.vae if vae is None else vae
377
+ latents = 1 / vae.config.scaling_factor * latents
378
+ imgs = vae.decode(latents).sample
379
+ imgs = (imgs / 2 + 0.5).clip_(0, 1)
380
+ return imgs
381
+
382
+ @torch.no_grad()
383
+ def get_white_background(self, height: int, width: int) -> torch.Tensor:
384
+ r"""White background image latent for bootstrapping or in case of
385
+ absent background.
386
+
387
+ Additionally stores the maximally-sized white latent for fast retrieval
388
+ in the future. By default, we initially call this with 1024x1024 sized
389
+ white image, so the function is rarely visited twice.
390
+
391
+ Args:
392
+ height (int): The height of the white *image*, not its latent.
393
+ width (int): The width of the white *image*, not its latent.
394
+
395
+ Returns:
396
+ A white image latent of size (1, 4, height//8, width//8). A cropped
397
+ version of the stored white latent is returned if the requested
398
+ size is smaller than what we already have created.
399
+ """
400
+ if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
401
+ white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
402
+ self.white = self.encode_imgs(white)
403
+ return self.white
404
+ return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
405
+
406
+ @torch.no_grad()
407
+ def process_mask(
408
+ self,
409
+ masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
410
+ strength: Optional[Union[torch.Tensor, float]] = None,
411
+ std: Optional[Union[torch.Tensor, float]] = None,
412
+ height: int = 1024,
413
+ width: int = 1024,
414
+ use_boolean_mask: bool = True,
415
+ timesteps: Optional[torch.Tensor] = None,
416
+ preprocess_mask_cover_alpha: Optional[float] = None,
417
+ ) -> Tuple[torch.Tensor]:
418
+ r"""Fast preprocess of masks for region-based generation with fine-
419
+ grained controls.
420
+
421
+ Mask preprocessing is done in four steps:
422
+ 1. Resizing: Resize the masks into the specified width and height by
423
+ nearest neighbor interpolation.
424
+ 2. (Optional) Ordering: Masks with higher indices are considered to
425
+ cover the masks with smaller indices. Covered masks are decayed
426
+ in its alpha value by the specified factor of
427
+ `preprocess_mask_cover_alpha`.
428
+ 3. Blurring: Gaussian blur is applied to the mask with the specified
429
+ standard deviation (isotropic). This results in gradual increase of
430
+ masked region as the timesteps evolve, naturally blending fore-
431
+ ground and the predesignated background. Not strictly required if
432
+ you want to produce images from scratch withoout background.
433
+ 4. Quantization: Split the real-numbered masks of value between [0, 1]
434
+ into predefined noise levels for each quantized scheduling step of
435
+ the diffusion sampler. For example, if the diffusion model sampler
436
+ has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
437
+ is the default noise level of this module with schedule [0, 4, 12,
438
+ 25, 37], the masks are split into binary masks whose values are
439
+ greater than these levels. This results in tradual increase of mask
440
+ region as the timesteps increase. Details are described in our
441
+ paper at https://arxiv.org/pdf/2403.09055.pdf.
442
+
443
+ On the Three Modes of `mask_type`:
444
+ `self.mask_type` is predefined at the initialization stage of this
445
+ pipeline. Three possible modes are available: 'discrete', 'semi-
446
+ continuous', and 'continuous'. These define the mask quantization
447
+ modes we use. Basically, this (subtly) controls the smoothness of
448
+ foreground-background blending. Continuous modes produces nonbinary
449
+ masks to further blend foreground and background latents by linear-
450
+ ly interpolating between them. Semi-continuous masks only applies
451
+ continuous mask at the last step of the LCM sampler. Due to the
452
+ large step size of the LCM scheduler, we find that our continuous
453
+ blending helps generating seamless inpainting and editing results.
454
+
455
+ Args:
456
+ masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
457
+ strength (Optional[Union[torch.Tensor, float]]): Mask strength that
458
+ overrides the default value. A globally multiplied factor to
459
+ the mask at the initial stage of processing. Can be applied
460
+ seperately for each mask.
461
+ std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
462
+ kernel's standard deviation. Overrides the default value. Can
463
+ be applied seperately for each mask.
464
+ height (int): The height of the expected generation. Mask is
465
+ resized to (height//8, width//8) with nearest neighbor inter-
466
+ polation.
467
+ width (int): The width of the expected generation. Mask is resized
468
+ to (height//8, width//8) with nearest neighbor interpolation.
469
+ use_boolean_mask (bool): Specify this to treat the mask image as
470
+ a boolean tensor. The retion with dark part darker than 0.5 of
471
+ the maximal pixel value (that is, 127.5) is considered as the
472
+ designated mask.
473
+ timesteps (Optional[torch.Tensor]): Defines the scheduler noise
474
+ levels that acts as bins of mask quantization.
475
+ preprocess_mask_cover_alpha (Optional[float]): Optional pre-
476
+ processing where each mask covered by other masks is reduced in
477
+ its alpha value by this specified factor. Overrides the default
478
+ value.
479
+
480
+ Returns: A tuple of tensors.
481
+ - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
482
+ binary masks (see the explanation on `mask_type` above) for
483
+ region-based image synthesis.
484
+ - masks_blurred: Gaussian blurred masks. Used for optionally
485
+ specified foreground-background blending after image
486
+ generation.
487
+ - std: Mask blur standard deviation. Used for optionally specified
488
+ foreground-background blending after image generation.
489
+ """
490
+ if isinstance(masks, Image.Image):
491
+ masks = [masks]
492
+ if isinstance(masks, (tuple, list)):
493
+ # Assumes white background for Image.Image;
494
+ # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
495
+ if use_boolean_mask:
496
+ proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
497
+ else:
498
+ proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
499
+ masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
500
+ masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
501
+ masks = masks.to(self.device)
502
+
503
+ # Background mask alpha is decayed by the specified factor where foreground masks covers it.
504
+ if preprocess_mask_cover_alpha is None:
505
+ preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
506
+ if preprocess_mask_cover_alpha > 0:
507
+ masks = torch.stack([
508
+ torch.where(
509
+ masks[i + 1:].sum(dim=0) > 0,
510
+ mask * preprocess_mask_cover_alpha,
511
+ mask,
512
+ ) if i < len(masks) - 1 else mask
513
+ for i, mask in enumerate(masks)
514
+ ], dim=0)
515
+
516
+ # Scheduler noise levels for mask quantization.
517
+ if timesteps is None:
518
+ noise_lvs = self.noise_lvs
519
+ next_noise_lvs = self.next_noise_lvs
520
+ else:
521
+ noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
522
+ # noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
523
+ noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
524
+ next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
525
+
526
+ # Mask preprocessing parameters are fetched from the default settings.
527
+ if std is None:
528
+ std = self.default_mask_std
529
+ if isinstance(std, (int, float)):
530
+ std = [std] * len(masks)
531
+ if isinstance(std, (list, tuple)):
532
+ std = torch.as_tensor(std, dtype=torch.float, device=self.device)
533
+
534
+ if strength is None:
535
+ strength = self.default_mask_strength
536
+ if isinstance(strength, (int, float)):
537
+ strength = [strength] * len(masks)
538
+ if isinstance(strength, (list, tuple)):
539
+ strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
540
+
541
+ if (std > 0).any():
542
+ std = torch.where(std > 0, std, 1e-5)
543
+ masks = gaussian_lowpass(masks, std)
544
+ masks_blurred = masks
545
+
546
+ # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
547
+ # gives unpleasant results.
548
+ masks = masks * strength[:, None, None, None]
549
+ masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
550
+
551
+ # Mask is quantized according to the current noise levels specified by the scheduler.
552
+ if self.mask_type == 'discrete':
553
+ # Discrete mode.
554
+ masks = masks > noise_lvs
555
+ elif self.mask_type == 'semi-continuous':
556
+ # Semi-continuous mode (continuous at the last step only).
557
+ masks = torch.cat((
558
+ masks[:, :-1] > noise_lvs[:, :-1],
559
+ (
560
+ (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
561
+ ).clip_(0, 1),
562
+ ), dim=1)
563
+ elif self.mask_type == 'continuous':
564
+ # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
565
+ # decreases continuously after the discrete mode boundary to become `0` at the
566
+ # next lower threshold.
567
+ masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
568
+
569
+ # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
570
+ # fine-grained mask alpha channel tuning is available with this form.
571
+ # masks = masks * strength[None, :, None, None, None]
572
+
573
+ h = height // self.vae_scale_factor
574
+ w = width // self.vae_scale_factor
575
+ masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
576
+ masks = F.interpolate(masks, size=(h, w), mode='nearest')
577
+ masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
578
+ return masks, masks_blurred, std
579
+
580
+ def scheduler_step(
581
+ self,
582
+ noise_pred: torch.Tensor,
583
+ idx: int,
584
+ latent: torch.Tensor,
585
+ ) -> torch.Tensor:
586
+ r"""Denoise-only step for reverse diffusion scheduler.
587
+
588
+ Designed to match the interface of the original `pipe.scheduler.step`,
589
+ which is a combination of this method and the following
590
+ `scheduler_add_noise`.
591
+
592
+ Args:
593
+ noise_pred (torch.Tensor): Noise prediction results from the U-Net.
594
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
595
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
596
+ latent (torch.Tensor): Noisy latent.
597
+
598
+ Returns:
599
+ A denoised tensor with the same size as latent.
600
+ """
601
+ # Upcast to avoid precision issues when computing prev_sample.
602
+ latent = latent.to(torch.float32)
603
+ prev_sample = latent - noise_pred * self.sigmas[idx]
604
+ return prev_sample.to(self.dtype)
605
+
606
+ def scheduler_add_noise(
607
+ self,
608
+ latent: torch.Tensor,
609
+ noise: Optional[torch.Tensor],
610
+ idx: int,
611
+ ) -> torch.Tensor:
612
+ r"""Separated noise-add step for the reverse diffusion scheduler.
613
+
614
+ Designed to match the interface of the original
615
+ `pipe.scheduler.add_noise`.
616
+
617
+ Args:
618
+ latent (torch.Tensor): Denoised latent.
619
+ noise (torch.Tensor): Added noise. Can be None. If None, a random
620
+ noise is newly sampled for addition.
621
+ idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
622
+ for the timesteps tensor (ranged in [0, len(timesteps)-1]).
623
+
624
+ Returns:
625
+ A noisy tensor with the same size as latent.
626
+ """
627
+ if idx < len(self.sigmas) and idx >= 0:
628
+ noise = torch.randn_like(latent) if noise is None else noise
629
+ return (1.0 - self.sigmas[idx]) * latent + self.sigmas[idx] * noise
630
+ else:
631
+ return latent
632
+
633
+ @torch.no_grad()
634
+ def __call__(
635
+ self,
636
+ prompts: Optional[Union[str, List[str]]] = None,
637
+ negative_prompts: Union[str, List[str]] = '',
638
+ suffix: Optional[str] = None, #', background is ',
639
+ background: Optional[Union[torch.Tensor, Image.Image]] = None,
640
+ background_prompt: Optional[str] = None,
641
+ background_negative_prompt: str = '',
642
+ height: int = 1024,
643
+ width: int = 1024,
644
+ num_inference_steps: Optional[int] = None,
645
+ guidance_scale: Optional[float] = None,
646
+ prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
647
+ masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
648
+ mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
649
+ mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
650
+ use_boolean_mask: bool = True,
651
+ do_blend: bool = True,
652
+ tile_size: int = 1024,
653
+ bootstrap_steps: Optional[int] = None,
654
+ boostrap_mix_steps: Optional[float] = None,
655
+ bootstrap_leak_sensitivity: Optional[float] = None,
656
+ preprocess_mask_cover_alpha: Optional[float] = None,
657
+ # SDXL Pipeline setting.
658
+ guidance_rescale: float = 0.7,
659
+ output_type = 'pil',
660
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
661
+ clip_skip: Optional[int] = None,
662
+ ) -> Image.Image:
663
+ r"""Arbitrary-size image generation from multiple pairs of (regional)
664
+ text prompt-mask pairs.
665
+
666
+ This is a main routine for this pipeline.
667
+
668
+ Example:
669
+ >>> device = torch.device('cuda:0')
670
+ >>> smd = StableMultiDiffusionPipeline(device)
671
+ >>> prompts = {... specify prompts}
672
+ >>> masks = {... specify mask tensors}
673
+ >>> height, width = masks.shape[-2:]
674
+ >>> image = smd(
675
+ >>> prompts, masks=masks.float(), height=height, width=width)
676
+ >>> image.save('my_beautiful_creation.png')
677
+
678
+ Args:
679
+ prompts (Union[str, List[str]]): A text prompt.
680
+ negative_prompts (Union[str, List[str]]): A negative text prompt.
681
+ suffix (Optional[str]): One option for blending foreground prompts
682
+ with background prompts by simply appending background prompt
683
+ to the end of each foreground prompt with this `middle word` in
684
+ between. For example, if you set this as `, background is`,
685
+ then the foreground prompt will be changed into
686
+ `(fg), background is (bg)` before conditional generation.
687
+ background (Optional[Union[torch.Tensor, Image.Image]]): a
688
+ background image, if the user wants to draw in front of the
689
+ specified image. Background prompt will automatically generated
690
+ with a BLIP-2 model.
691
+ background_prompt (Optional[str]): The background prompt is used
692
+ for preprocessing foreground prompt embeddings to blend
693
+ foreground and background.
694
+ background_negative_prompt (Optional[str]): The negative background
695
+ prompt.
696
+ height (int): Height of a generated image. It is tiled if larger
697
+ than `tile_size`.
698
+ width (int): Width of a generated image. It is tiled if larger
699
+ than `tile_size`.
700
+ num_inference_steps (Optional[int]): Number of inference steps.
701
+ Default inference scheduling is used if none is specified.
702
+ guidance_scale (Optional[float]): Classifier guidance scale.
703
+ Default value is used if none is specified.
704
+ prompt_strength (float): Overrides default value. Preprocess
705
+ foreground prompts globally by linearly interpolating its
706
+ embedding with the background prompt embeddint with specified
707
+ mix ratio. Useful control handle for foreground blending.
708
+ Recommended range: 0.5-1.
709
+ masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
710
+ mask images. Each mask associates with each of the text prompts
711
+ and each of the negative prompts. If specified as an image, it
712
+ regards the image as a boolean mask. Also accepts torch.Tensor
713
+ masks, which can have nonbinary values for fine-grained
714
+ controls in mixing regional generations.
715
+ mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
716
+ Overrides the default value. an be assigned for each mask
717
+ separately. Preprocess mask by multiplying it globally with the
718
+ specified variable. Caution: extremely sensitive. Recommended
719
+ range: 0.98-1.
720
+ mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
721
+ Overrides the default value. Can be assigned for each mask
722
+ separately. Preprocess mask with Gaussian blur with specified
723
+ standard deviation. Recommended range: 0-64.
724
+ use_boolean_mask (bool): Turn this off if you want to treat the
725
+ mask image as nonbinary one. The module will use the last
726
+ channel of the given image in `masks` as the mask value.
727
+ do_blend (bool): Blend the generated foreground and the optionally
728
+ predefined background by smooth boundary obtained from Gaussian
729
+ blurs of the foreground `masks` with the given `mask_stds`.
730
+ tile_size (Optional[int]): Tile size of the panorama generation.
731
+ Works best with the default training size of the Stable-
732
+ Diffusion model, i.e., 1024 or 1024 for SD1.5 and 1024 for SDXL.
733
+ bootstrap_steps (int): Overrides the default value. Bootstrapping
734
+ stage steps to encourage region separation. Recommended range:
735
+ 1-3.
736
+ boostrap_mix_steps (float): Overrides the default value.
737
+ Bootstrapping background is a linear interpolation between
738
+ background latent and the white image latent. This handle
739
+ controls the mix ratio. Available range: 0-(number of
740
+ bootstrapping inference steps). For example, 2.3 means that for
741
+ the first two steps, white image is used as a bootstrapping
742
+ background and in the third step, mixture of white (0.3) and
743
+ registered background (0.7) is used as a bootstrapping
744
+ background.
745
+ bootstrap_leak_sensitivity (float): Overrides the default value.
746
+ Postprocessing at each inference step by masking away the
747
+ remaining bootstrap backgrounds t Recommended range: 0-1.
748
+ preprocess_mask_cover_alpha (float): Overrides the default value.
749
+ Optional preprocessing where each mask covered by other masks
750
+ is reduced in its alpha value by this specified factor.
751
+
752
+ Returns: A PIL.Image image of a panorama (large-size) image.
753
+ """
754
+
755
+ ### Simplest cases
756
+
757
+ # prompts is None: return background.
758
+ # masks is None but prompts is not None: return prompts
759
+ # masks is not None and prompts is not None: Do StableMultiDiffusion.
760
+
761
+ if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
762
+ if background is None and background_prompt is not None:
763
+ return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
764
+ return background
765
+ elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
766
+ return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
767
+
768
+
769
+ ### Prepare generation
770
+
771
+ if num_inference_steps is not None:
772
+ self.prepare_flashflowmatch_schedule(list(range(num_inference_steps)), num_inference_steps)
773
+
774
+ if guidance_scale is None:
775
+ guidance_scale = self.default_guidance_scale
776
+ self.pipe._guidance_scale = guidance_scale
777
+ self.pipe._clip_skip = clip_skip
778
+ self.pipe._joint_attention_kwargs = joint_attention_kwargs
779
+ self.pipe._interrupt = False
780
+ do_classifier_free_guidance = guidance_scale > 1.0
781
+
782
+
783
+ ### Prompts & Masks
784
+
785
+ # asserts #m > 0 and #p > 0.
786
+ # #m == #p == #n > 0: We happily generate according to the prompts & masks.
787
+ # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
788
+ # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
789
+
790
+ if isinstance(masks, Image.Image):
791
+ masks = [masks]
792
+ if isinstance(prompts, str):
793
+ prompts = [prompts]
794
+ if isinstance(negative_prompts, str):
795
+ negative_prompts = [negative_prompts]
796
+ num_masks = len(masks)
797
+ num_prompts = len(prompts)
798
+ num_nprompts = len(negative_prompts)
799
+ assert num_prompts in (num_masks, 1), \
800
+ f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
801
+ assert num_nprompts in (num_prompts, 1), \
802
+ f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
803
+
804
+ fg_masks, masks_g, std = self.process_mask(
805
+ masks,
806
+ mask_strengths,
807
+ mask_stds,
808
+ height=height,
809
+ width=width,
810
+ use_boolean_mask=use_boolean_mask,
811
+ timesteps=self.timesteps,
812
+ preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
813
+ ) # (p, t, 1, H, W)
814
+ bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
815
+ has_background = bg_masks.sum() > 0
816
+
817
+ h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
818
+ w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
819
+
820
+
821
+ ### Background
822
+
823
+ # background == None && background_prompt == None: Initialize with white background.
824
+ # background == None && background_prompt != None: Generate background *along with other prompts*.
825
+ # background != None && background_prompt == None: Retrieve text prompt using BLIP.
826
+ # background != None && background_prompt != None: Use the given arguments.
827
+
828
+ # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
829
+ # has_background && prompt_strength != 1: mix only for this case.
830
+
831
+ bg_latent = None
832
+ if has_background:
833
+ if background is None and background_prompt is not None:
834
+ fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
835
+ if suffix is not None:
836
+ prompts = [p + suffix + background_prompt for p in prompts]
837
+ prompts = [background_prompt] + prompts
838
+ negative_prompts = [background_negative_prompt] + negative_prompts
839
+ has_background = False # Regard that background does not exist.
840
+ else:
841
+ if background is None and background_prompt is None:
842
+ background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
843
+ background_prompt = 'simple white background image'
844
+ elif background is not None and background_prompt is None:
845
+ background_prompt = self.get_text_prompts(background)
846
+ if suffix is not None:
847
+ prompts = [p + suffix + background_prompt for p in prompts]
848
+ prompts = [background_prompt] + prompts
849
+ negative_prompts = [background_negative_prompt] + negative_prompts
850
+ if isinstance(background, Image.Image):
851
+ background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
852
+ background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
853
+ bg_latent = self.encode_imgs(background)
854
+
855
+ # Bootstrapping stage preparation.
856
+
857
+ if bootstrap_steps is None:
858
+ bootstrap_steps = self.default_bootstrap_steps
859
+ if boostrap_mix_steps is None:
860
+ boostrap_mix_steps = self.default_boostrap_mix_steps
861
+ if bootstrap_leak_sensitivity is None:
862
+ bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
863
+ if bootstrap_steps > 0:
864
+ height_ = min(height, tile_size)
865
+ width_ = min(width, tile_size)
866
+ white = self.get_white_background(height, width) # (1, 4, h, w)
867
+
868
+
869
+ ### Prepare text embeddings (optimized for the minimal encoder batch size)
870
+
871
+ # SD3 pipeline settings.
872
+ batch_size = 1
873
+ num_images_per_prompt = 1
874
+
875
+ original_size = (height, width)
876
+ target_size = (height, width)
877
+ crops_coords_top_left = (0, 0)
878
+ negative_original_size = None
879
+ negative_target_size = None
880
+ negative_crops_coords_top_left = (0, 0)
881
+
882
+ prompt_2 = None
883
+ prompt_3 = None
884
+ negative_prompt_2 = None
885
+ negative_prompt_3 = None
886
+ prompt_embeds = None
887
+ negative_prompt_embeds = None
888
+ pooled_prompt_embeds = None
889
+ negative_pooled_prompt_embeds = None
890
+ text_encoder_lora_scale = None
891
+
892
+ (
893
+ prompt_embeds,
894
+ negative_prompt_embeds,
895
+ pooled_prompt_embeds,
896
+ negative_pooled_prompt_embeds,
897
+ ) = self.pipe.encode_prompt(
898
+ prompt=prompts,
899
+ prompt_2=prompt_2,
900
+ prompt_3=prompt_3,
901
+ negative_prompt=negative_prompts,
902
+ negative_prompt_2=negative_prompt_2,
903
+ negative_prompt_3=negative_prompt_3,
904
+ do_classifier_free_guidance=do_classifier_free_guidance,
905
+ prompt_embeds=prompt_embeds,
906
+ negative_prompt_embeds=negative_prompt_embeds,
907
+ pooled_prompt_embeds=pooled_prompt_embeds,
908
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
909
+ device=self.device,
910
+ clip_skip=self.pipe.clip_skip,
911
+ num_images_per_prompt=num_images_per_prompt,
912
+ )
913
+
914
+ if has_background:
915
+ # First channel is background prompt text embeds. Background prompt itself is not used for generation.
916
+ s = prompt_strengths
917
+ if prompt_strengths is None:
918
+ s = self.default_prompt_strength
919
+ if isinstance(s, (int, float)):
920
+ s = [s] * num_prompts
921
+ if isinstance(s, (list, tuple)):
922
+ assert len(s) == num_prompts, \
923
+ f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
924
+ s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
925
+ s = s[:, None, None]
926
+
927
+ be = prompt_embeds[:1]
928
+ fe = prompt_embeds[1:]
929
+ prompt_embeds = torch.lerp(be, fe, s) # (p, 77, 1024)
930
+
931
+ if negative_prompt_embeds is not None:
932
+ bu = negative_prompt_embeds[:1]
933
+ fu = negative_prompt_embeds[1:]
934
+ if num_prompts > num_nprompts:
935
+ # # negative prompts = 1; # prompts > 1.
936
+ assert fu.shape[0] == 1 and fe.shape == num_prompts
937
+ fu = fu.repeat(num_prompts, 1, 1)
938
+ negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
939
+
940
+ be = pooled_prompt_embeds[:1]
941
+ fe = pooled_prompt_embeds[1:]
942
+ pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0]) # (p, 1280)
943
+
944
+ if negative_pooled_prompt_embeds is not None:
945
+ bu = negative_pooled_prompt_embeds[:1]
946
+ fu = negative_pooled_prompt_embeds[1:]
947
+ if num_prompts > num_nprompts:
948
+ # # negative prompts = 1; # prompts > 1.
949
+ assert fu.shape[0] == 1 and fe.shape == num_prompts
950
+ fu = fu.repeat(num_prompts, 1)
951
+ negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0]) # (n, 1280)
952
+ elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
953
+ # # negative prompts = 1; # prompts > 1.
954
+ assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
955
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
956
+
957
+ assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
958
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
959
+ # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
960
+ if num_masks > num_prompts:
961
+ assert masks.shape[0] == num_masks and num_prompts == 1
962
+ prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
963
+ if negative_prompt_embeds is not None:
964
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
965
+
966
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
967
+ if negative_pooled_prompt_embeds is not None:
968
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)
969
+
970
+ # SD3 pipeline settings.
971
+ if do_classifier_free_guidance:
972
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
973
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
974
+ del negative_prompt_embeds, negative_pooled_prompt_embeds
975
+
976
+ prompt_embeds = prompt_embeds.to(self.device)
977
+ pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
978
+
979
+
980
+ ### Run
981
+
982
+ # Latent initialization.
983
+ num_channels_latents = self.transformer.config.in_channels
984
+ noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
985
+ if self.timesteps[0] < 999 and has_background:
986
+ latent = self.scheduler_add_noise(bg_latent, noise, 0)
987
+ else:
988
+ noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
989
+ latent = noise
990
+
991
+ if has_background:
992
+ noise_bg_latents = [
993
+ self.scheduler_add_noise(bg_latent, noise, i) for i in range(len(self.timesteps))
994
+ ] + [bg_latent]
995
+
996
+ # Tiling (if needed).
997
+ if height > tile_size or width > tile_size:
998
+ t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
999
+ views, tile_masks = get_panorama_views(h, w, t)
1000
+ tile_masks = tile_masks.to(self.device)
1001
+ else:
1002
+ views = [(0, h, 0, w)]
1003
+ tile_masks = latent.new_ones((1, 1, h, w))
1004
+ value = torch.zeros_like(latent)
1005
+ count_all = torch.zeros_like(latent)
1006
+
1007
+ with torch.autocast('cuda'):
1008
+ for i, t in enumerate(tqdm(self.timesteps)):
1009
+ if self.pipe.interrupt:
1010
+ continue
1011
+
1012
+ fg_mask = fg_masks[:, i]
1013
+ bg_mask = bg_masks[i:i + 1]
1014
+
1015
+ value.zero_()
1016
+ count_all.zero_()
1017
+ for j, (h_start, h_end, w_start, w_end) in enumerate(views):
1018
+ fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
1019
+ latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
1020
+
1021
+ # Bootstrap for tight background.
1022
+ if i < bootstrap_steps:
1023
+ mix_ratio = min(1, max(0, boostrap_mix_steps - i))
1024
+ # Treat the first foreground latent as the background latent if one does not exist.
1025
+ bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
1026
+ white_ = white[..., h_start:h_end, w_start:w_end]
1027
+ white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i)
1028
+ bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
1029
+ latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
1030
+
1031
+ # Centering.
1032
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
1033
+
1034
+ # expand the latents if we are doing classifier free guidance
1035
+ latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
1036
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1037
+ timestep = t.expand(latent_model_input.shape[0])
1038
+
1039
+ # Perform one step of the reverse diffusion.
1040
+ noise_pred = self.transformer(
1041
+ hidden_states=latent_model_input,
1042
+ timestep=timestep,
1043
+ encoder_hidden_states=prompt_embeds,
1044
+ pooled_projections=pooled_prompt_embeds,
1045
+ joint_attention_kwargs=joint_attention_kwargs,
1046
+ return_dict=False,
1047
+ )[0]
1048
+
1049
+ if do_classifier_free_guidance:
1050
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
1051
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
1052
+
1053
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1054
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1055
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
1056
+
1057
+ latent_ = self.scheduler_step(noise_pred, i, latent_)
1058
+
1059
+ if i < bootstrap_steps:
1060
+ # Uncentering.
1061
+ latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
1062
+
1063
+ # Remove leakage (optional).
1064
+ leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
1065
+ leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
1066
+ fg_mask_ = fg_mask_ * leak_sigmoid
1067
+
1068
+ # Mix the latents.
1069
+ fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
1070
+ value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
1071
+ count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
1072
+
1073
+ latent = torch.where(count_all > 0, value / count_all, value)
1074
+ bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
1075
+ if has_background:
1076
+ latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1] # bg_latent
1077
+
1078
+ # Noise is added after mixing.
1079
+ if i < len(self.timesteps) - 1:
1080
+ latent = self.scheduler_add_noise(latent, None, i + 1)
1081
+
1082
+ if not output_type == "latent":
1083
+ latent = (latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1084
+ image = self.vae.decode(latent, return_dict=False)[0]
1085
+ else:
1086
+ image = latent
1087
+
1088
+ # Return PIL Image.
1089
+ image = image[0].clip_(-1, 1) * 0.5 + 0.5
1090
+ if has_background and do_blend:
1091
+ fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
1092
+ image = blend(image, background[0], fg_mask)
1093
+ else:
1094
+ image = T.ToPILImage()(image)
1095
+ return image
prompt_util.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+
4
+ quality_prompt_list = [
5
+ {
6
+ "name": "(None)",
7
+ "prompt": "{prompt}",
8
+ "negative_prompt": "nsfw, lowres",
9
+ },
10
+ {
11
+ "name": "Standard v3.0",
12
+ "prompt": "{prompt}, masterpiece, best quality",
13
+ "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
14
+ },
15
+ {
16
+ "name": "Standard v3.1",
17
+ "prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
18
+ "negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
19
+ },
20
+ {
21
+ "name": "Light v3.1",
22
+ "prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
23
+ "negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
24
+ },
25
+ {
26
+ "name": "Heavy v3.1",
27
+ "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
28
+ "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
29
+ },
30
+ ]
31
+
32
+ style_list = [
33
+ {
34
+ "name": "(None)",
35
+ "prompt": "{prompt}",
36
+ "negative_prompt": "",
37
+ },
38
+ {
39
+ "name": "Cinematic",
40
+ "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
41
+ "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
42
+ },
43
+ {
44
+ "name": "Photographic",
45
+ "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
46
+ "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
47
+ },
48
+ {
49
+ "name": "Anime",
50
+ "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
51
+ "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
52
+ },
53
+ {
54
+ "name": "Manga",
55
+ "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
56
+ "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
57
+ },
58
+ {
59
+ "name": "Digital Art",
60
+ "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
61
+ "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
62
+ },
63
+ {
64
+ "name": "Pixel art",
65
+ "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
66
+ "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
67
+ },
68
+ {
69
+ "name": "Fantasy art",
70
+ "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
71
+ "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
72
+ },
73
+ {
74
+ "name": "Neonpunk",
75
+ "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
76
+ "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
77
+ },
78
+ {
79
+ "name": "3D Model",
80
+ "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
81
+ "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
82
+ },
83
+ ]
84
+
85
+
86
+ _style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
87
+ _quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
88
+
89
+
90
+ def preprocess_prompt(
91
+ positive: str,
92
+ negative: str = "",
93
+ style_dict: Dict[str, dict] = _quality_dict,
94
+ style_name: str = "Standard v3.1", # "Heavy v3.1"
95
+ add_style: bool = True,
96
+ ) -> Tuple[str, str]:
97
+ p, n = style_dict.get(style_name, style_dict["(None)"])
98
+
99
+ if add_style and positive.strip():
100
+ formatted_positive = p.format(prompt=positive)
101
+ else:
102
+ formatted_positive = positive
103
+
104
+ combined_negative = n
105
+ if negative.strip():
106
+ if combined_negative:
107
+ combined_negative += ", " + negative
108
+ else:
109
+ combined_negative = negative
110
+
111
+ return formatted_positive, combined_negative
112
+
113
+
114
+ def preprocess_prompts(
115
+ positives: List[str],
116
+ negatives: List[str] = None,
117
+ style_dict = _style_dict,
118
+ style_name: str = "Manga", # "(None)"
119
+ quality_dict = _quality_dict,
120
+ quality_name: str = "Standard v3.1", # "Heavy v3.1"
121
+ add_style: bool = True,
122
+ add_quality_tags = True,
123
+ ) -> Tuple[List[str], List[str]]:
124
+ if negatives is None:
125
+ negatives = ['' for _ in positives]
126
+
127
+ positives_ = []
128
+ negatives_ = []
129
+ for pos, neg in zip(positives, negatives):
130
+ pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
131
+ pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
132
+ positives_.append(pos)
133
+ negatives_.append(neg)
134
+ return positives_, negatives_
135
+
136
+
137
+ def print_prompts(
138
+ positives: Union[str, List[str]],
139
+ negatives: Union[str, List[str]],
140
+ has_background: bool = False,
141
+ ) -> None:
142
+ if isinstance(positives, str):
143
+ positives = [positives]
144
+ if isinstance(negatives, str):
145
+ negatives = [negatives]
146
+
147
+ for i, prompt in enumerate(positives):
148
+ prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
149
+ if has_background else f'Prompt{i + 1}')
150
+ print(prefix + ': ' + prompt)
151
+ for i, prompt in enumerate(negatives):
152
+ prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
153
+ if has_background else f'Negative Prompt{i + 1}')
154
+ print(prefix + ': ' + prompt)
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision
3
+ xformers==0.0.22
4
+ einops
5
+ diffusers @ git+https://github.com/initml/diffusers.git@clement/feature/flash_sd3
6
+ transformers
7
+ huggingface_hub[torch]
8
+ gradio==4.39.0
9
+ Pillow
10
+ emoji
11
+ numpy
12
+ tqdm
13
+ jupyterlab
14
+ peft>=0.10.0
15
+ sentencepiece
16
+ protobuf
share_btn.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ share_js = """async () => {
2
+ async function uploadFile(file) {
3
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
4
+ const response = await fetch(UPLOAD_URL, {
5
+ method: 'POST',
6
+ headers: {
7
+ 'Content-Type': file.type,
8
+ 'X-Requested-With': 'XMLHttpRequest',
9
+ },
10
+ body: file, /// <- File inherits from Blob
11
+ });
12
+ const url = await response.text();
13
+ return url;
14
+ }
15
+ async function getBase64(file) {
16
+ var reader = new FileReader();
17
+ reader.readAsDataURL(file);
18
+ reader.onload = function () {
19
+ console.log(reader.result);
20
+ };
21
+ reader.onerror = function (error) {
22
+ console.log('Error: ', error);
23
+ };
24
+ }
25
+ const toDataURL = url => fetch(url)
26
+ .then(response => response.blob())
27
+ .then(blob => new Promise((resolve, reject) => {
28
+ const reader = new FileReader()
29
+ reader.onloadend = () => resolve(reader.result)
30
+ reader.onerror = reject
31
+ reader.readAsDataURL(blob)
32
+ }));
33
+ async function dataURLtoFile(dataurl, filename) {
34
+ var arr = dataurl.split(','), mime = arr[0].match(/:(.*?);/)[1],
35
+ bstr = atob(arr[1]), n = bstr.length, u8arr = new Uint8Array(n);
36
+ while (n--) {
37
+ u8arr[n] = bstr.charCodeAt(n);
38
+ }
39
+ return new File([u8arr], filename, {type:mime});
40
+ };
41
+
42
+ const gradioEl = document.querySelector('body > gradio-app');
43
+ const imgEls = gradioEl.querySelectorAll('#output-screen img');
44
+ if(!imgEls.length){
45
+ return;
46
+ };
47
+
48
+ const urls = await Promise.all([...imgEls].map((imgEl) => {
49
+ const origURL = imgEl.src;
50
+ const imgId = Date.now() % 200;
51
+ const fileName = 'semantic-palette-xl-' + imgId + '.png';
52
+ return toDataURL(origURL)
53
+ .then(dataUrl => {
54
+ return dataURLtoFile(dataUrl, fileName);
55
+ })
56
+ })).then(fileData => {return Promise.all([...fileData].map((file) => {
57
+ return uploadFile(file);
58
+ }))});
59
+
60
+ const htmlImgs = urls.map(url => `<img src='${url}' width='2560' height='1024'>`);
61
+ const descriptionMd = `<div style='display: flex; flex-wrap: wrap; column-gap: 0.75rem;'>
62
+ ${htmlImgs.join(`\n`)}
63
+ </div>`;
64
+ const params = new URLSearchParams({
65
+ title: `My creation`,
66
+ description: descriptionMd,
67
+ });
68
+ const paramsStr = params.toString();
69
+ window.open(`https://huggingface.co/spaces/ironjr/SemanticPaletteXL/discussions/new?${paramsStr}`, '_blank');
70
+ }"""
util.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import concurrent.futures
22
+ import time
23
+ from typing import Any, Callable, List, Literal, Tuple, Union
24
+
25
+ from PIL import Image
26
+ import numpy as np
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.cuda.amp as amp
31
+ import torchvision.transforms as T
32
+ import torchvision.transforms.functional as TF
33
+
34
+ from diffusers import (
35
+ DiffusionPipeline,
36
+ StableDiffusionPipeline,
37
+ StableDiffusionXLPipeline,
38
+ )
39
+
40
+
41
+ def seed_everything(seed: int) -> None:
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = True
46
+
47
+
48
+ def load_model(
49
+ model_key: str,
50
+ sd_version: Literal['1.5', 'xl'],
51
+ device: torch.device,
52
+ dtype: torch.dtype,
53
+ ) -> torch.nn.Module:
54
+ if model_key.endswith('.safetensors'):
55
+ if sd_version == '1.5':
56
+ pipeline = StableDiffusionPipeline
57
+ elif sd_version == 'xl':
58
+ pipeline = StableDiffusionXLPipeline
59
+ else:
60
+ raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
61
+ return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
62
+ try:
63
+ return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
64
+ except:
65
+ return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
66
+
67
+
68
+ def get_cutoff(cutoff: float = None, scale: float = None) -> float:
69
+ if cutoff is not None:
70
+ return cutoff
71
+
72
+ if scale is not None and cutoff is None:
73
+ return 0.5 / scale
74
+
75
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
76
+
77
+
78
+ def get_scale(cutoff: float = None, scale: float = None) -> float:
79
+ if scale is not None:
80
+ return scale
81
+
82
+ if cutoff is not None and scale is None:
83
+ return 0.5 / cutoff
84
+
85
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
86
+
87
+
88
+ def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
89
+ assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
90
+ # assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
91
+
92
+ b, c, h, w = x.shape
93
+ ks = k.shape[-1]
94
+ k = k.view(1, 1, -1).repeat(c, 1, 1)
95
+
96
+ x = x.permute(0, 2, 1, 3)
97
+ x = x.reshape(b * h, c, w)
98
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
99
+ x = F.conv1d(x, k, groups=c)
100
+ x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
101
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
102
+ x = F.conv1d(x, k, groups=c)
103
+ x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
104
+ return x
105
+
106
+
107
+ def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
108
+ assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
109
+
110
+ x = F.pad(x, (
111
+ k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
112
+ k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
113
+ ), mode='replicate')
114
+
115
+ b, c, _, _ = x.shape
116
+ if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
117
+ k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
118
+ x = F.conv2d(x, k, groups=c)
119
+ elif len(k.shape) == 3:
120
+ assert k.shape[0] == b, \
121
+ 'The number of kernels should match the batch size.'
122
+
123
+ k = k.unsqueeze(1)
124
+ x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
125
+ return x
126
+
127
+
128
+ @amp.autocast(False)
129
+ def filter_by_kernel(
130
+ x: torch.Tensor,
131
+ k: torch.Tensor,
132
+ is_batch: bool = False,
133
+ ) -> torch.Tensor:
134
+ k_dim = len(k.shape)
135
+ if k_dim == 1 or k_dim == 2 and is_batch:
136
+ return filter_2d_by_kernel_1d(x, k)
137
+ elif k_dim == 2 or k_dim == 3 and is_batch:
138
+ return filter_2d_by_kernel_2d(x, k)
139
+ else:
140
+ raise ValueError('Kernel size should be one of (1, 2, 3).')
141
+
142
+
143
+ def gen_gauss_lowpass_filter_2d(
144
+ std: torch.Tensor,
145
+ window_size: int = None,
146
+ ) -> torch.Tensor:
147
+ # Gaussian kernel size is odd in order to preserve the center.
148
+ if window_size is None:
149
+ window_size = (
150
+ 2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
151
+
152
+ y = torch.arange(
153
+ window_size, dtype=std.dtype, device=std.device
154
+ ).view(-1, 1).repeat(1, window_size)
155
+ grid = torch.stack((y.t(), y), dim=-1)
156
+ grid -= 0.5 * (window_size - 1) # (W, W)
157
+ var = (std * std).unsqueeze(-1).unsqueeze(-1)
158
+ distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
159
+ k = torch.exp(-0.5 * distsq / var)
160
+ k /= k.sum(dim=(-2, -1), keepdim=True)
161
+ return k
162
+
163
+
164
+ def gaussian_lowpass(
165
+ x: torch.Tensor,
166
+ std: Union[float, Tuple[float], torch.Tensor] = None,
167
+ cutoff: Union[float, torch.Tensor] = None,
168
+ scale: Union[float, torch.Tensor] = None,
169
+ ) -> torch.Tensor:
170
+ if std is None:
171
+ cutoff = get_cutoff(cutoff, scale)
172
+ std = 0.5 / (np.pi * cutoff)
173
+ if isinstance(std, (float, int)):
174
+ std = (std, std)
175
+ if isinstance(std, torch.Tensor):
176
+ """Using nn.functional.conv2d with Gaussian kernels built in runtime is
177
+ 80% faster than transforms.functional.gaussian_blur for individual
178
+ items.
179
+
180
+ (in GPU); However, in CPU, the result is exactly opposite. But you
181
+ won't gonna run this on CPU, right?
182
+ """
183
+ if len(list(s for s in std.shape if s != 1)) >= 2:
184
+ raise NotImplementedError(
185
+ 'Anisotropic Gaussian filter is not currently available.')
186
+
187
+ # k.shape == (B, W, W).
188
+ k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
189
+ if k.shape[0] == 1:
190
+ return filter_by_kernel(x, k[0], False)
191
+ else:
192
+ return filter_by_kernel(x, k, True)
193
+ else:
194
+ # Gaussian kernel size is odd in order to preserve the center.
195
+ window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
196
+ return TF.gaussian_blur(x, window_size, std)
197
+
198
+
199
+ def blend(
200
+ fg: Union[torch.Tensor, Image.Image],
201
+ bg: Union[torch.Tensor, Image.Image],
202
+ mask: Union[torch.Tensor, Image.Image],
203
+ std: float = 0.0,
204
+ ) -> Image.Image:
205
+ if not isinstance(fg, torch.Tensor):
206
+ fg = T.ToTensor()(fg)
207
+ if not isinstance(bg, torch.Tensor):
208
+ bg = T.ToTensor()(bg)
209
+ if not isinstance(mask, torch.Tensor):
210
+ mask = (T.ToTensor()(mask) < 0.5).float()[:1]
211
+ if std > 0:
212
+ mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
213
+ return T.ToPILImage()(fg * mask + bg * (1 - mask))
214
+
215
+
216
+ def get_panorama_views(
217
+ panorama_height: int,
218
+ panorama_width: int,
219
+ window_size: int = 64,
220
+ ) -> tuple[List[Tuple[int]], torch.Tensor]:
221
+ stride = window_size // 2
222
+ is_horizontal = panorama_width > panorama_height
223
+ num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
224
+ num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
225
+ total_num_blocks = num_blocks_height * num_blocks_width
226
+
227
+ half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
228
+ half_rev = half_fwd.flip(0)
229
+ if window_size % 2 == 1:
230
+ half_rev = half_rev[1:]
231
+ c = torch.cat((half_fwd, half_rev))
232
+ one = torch.ones_like(c)
233
+ f = c.clone()
234
+ f[:window_size // 2] = 1
235
+ b = c.clone()
236
+ b[-(window_size // 2):] = 1
237
+
238
+ h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
239
+ w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
240
+
241
+ views = []
242
+ masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
243
+ for i in range(total_num_blocks):
244
+ hi, wi = i // num_blocks_width, i % num_blocks_width
245
+ h_start = hi * stride
246
+ h_end = min(h_start + window_size, panorama_height)
247
+ w_start = wi * stride
248
+ w_end = min(w_start + window_size, panorama_width)
249
+ views.append((h_start, h_end, w_start, w_end))
250
+
251
+ h_width = h_end - h_start
252
+ w_width = w_end - w_start
253
+ masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
254
+
255
+ # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
256
+ return views, masks[None] # (1, n, h, w)
257
+
258
+
259
+ def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
260
+ h, w = mask.shape[-2:]
261
+ device = mask.device
262
+ mask = mask.reshape(-1, h, w)
263
+ # assert mask.shape[0] == im.shape[0]
264
+ h_occupied = mask.sum(dim=-2) > 0
265
+ w_occupied = mask.sum(dim=-1) > 0
266
+ l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
267
+ r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
268
+ t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
269
+ b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
270
+ tb = (t + b + 1) // 2
271
+ lr = (l + r + 1) // 2
272
+ shifts = (tb - (h // 2), lr - (w // 2))
273
+ shifts = torch.cat(shifts, dim=1) # (p, 2)
274
+ if reverse:
275
+ shifts = shifts * -1
276
+ return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
277
+
278
+
279
+ class Streamer:
280
+ def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
281
+ self.fn = fn
282
+ self.ema_alpha = ema_alpha
283
+
284
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
285
+ self.future = self.executor.submit(fn)
286
+ self.image = None
287
+
288
+ self.prev_exec_time = 0
289
+ self.ema_exec_time = 0
290
+
291
+ @property
292
+ def throughput(self) -> float:
293
+ return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
294
+
295
+ def timed_fn(self) -> Any:
296
+ start = time.time()
297
+ res = self.fn()
298
+ end = time.time()
299
+ self.prev_exec_time = end - start
300
+ self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
301
+ return res
302
+
303
+ def __call__(self) -> Any:
304
+ if self.future.done() or self.image is None:
305
+ # get the result (the new image) and start a new task
306
+ image = self.future.result()
307
+ self.future = self.executor.submit(self.timed_fn)
308
+ self.image = image
309
+ return image
310
+ else:
311
+ # if self.fn() is not ready yet, use the previous image
312
+ # NOTE: This assumes that we have access to a previously generated image here.
313
+ # If there's no previous image (i.e., this is the first invocation), you could fall
314
+ # back to some default image or handle it differently based on your requirements.
315
+ return self.image