QY-H00 committed on
Commit 0320907
1 Parent(s): c824a93
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: AID V2
- emoji: 🏃
- colorFrom: green
- colorTo: indigo
+ title: PAID
+ emoji: 🏢
+ colorFrom: pink
+ colorTo: red
  sdk: gradio
- sdk_version: 5.1.0
+ sdk_version: 4.22.0
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,517 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from pipeline_interpolated_sd import InterpolationStableDiffusionPipeline
11
+ from pipeline_interpolated_sdxl import InterpolationStableDiffusionXLPipeline
12
+ from prior import BetaPriorPipeline
13
+
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ title = r"""
18
+ <h1 align="center">PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</h1>
19
+ """
20
+
21
+ description = r"""
22
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</b></a>.<br>
23
+ How to use:<br>
24
+ 1. Input prompt 1, prompt 2 and the negative prompt.
+ 2. For <b>Compositional Generation</b>: also input the guidance prompt, then pick the result you are satisfied with!
+ 3. For <b>Image Morphing</b>: input image prompt 1 and image prompt 2, and choose IP-Adapter.
+ 4. For <b>Scale Control</b>: use the same text for prompt 1 and prompt 2, leave image prompt 1 blank, upload image prompt 2, then choose IP-Adapter or IP-Composition-Adapter.
+ 5. <b>Note that the SD series with an exploration size of 10 takes around 120 seconds; the XL series with an exploration size of 5 takes around 5 minutes 30 seconds.</b>
+ 6. Click the <b>Generate</b> button to begin generating images.
+ 7. Enjoy! 😊"""
31
+
32
+ article = r"""
33
+ ---
34
+ ✒️ **Citation**
35
+ <br>
36
+ If you found this demo/our paper useful, please consider citing:
37
+ ```bibtex
38
+ @article{he2024aid,
39
+ title={AID: Attention Interpolation of Text-to-Image Diffusion},
40
+ author={He, Qiyuan and Wang, Jinghao and Liu, Ziwei and Yao, Angela},
41
+ journal={arXiv preprint arXiv:2403.17924},
42
+ year={2024}
43
+ }
44
+ ```
45
+ 📧 **Contact**
46
+ <br>
47
+ If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>Github Repo</b></a> or reach out to us directly at <b>[email protected]</b>.
48
+ """
49
+
50
+ MAX_SEED = np.iinfo(np.int32).max
51
+ CACHE_EXAMPLES = False
52
+ USE_TORCH_COMPILE = False
53
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
54
+ PREVIEW_IMAGES = False
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ pipeline = InterpolationStableDiffusionPipeline.from_pretrained(
58
+ "SG161222/Realistic_Vision_V4.0_noVAE",
59
+ torch_dtype=torch.float16
60
+ )
61
+ pipeline.to(device, dtype=torch.float16)
62
+
63
+
64
+ def change_model_fn(model_name: str) -> None:
65
+ global device
66
+ name_mapping = {
67
+ "AOM3": "hogiahien/aom3",
+ "SD1.4-512": "CompVis/stable-diffusion-v1-4",
+ "SD1.5-512": "stable-diffusion-v1-5/stable-diffusion-v1-5",
69
+ "SD2.1-768": "stabilityai/stable-diffusion-2-1",
70
+ "RealVis-v4.0": "SG161222/Realistic_Vision_V4.0_noVAE",
71
+ "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
72
+ "Playground-XL-v2": "playgroundai/playground-v2.5-1024px-aesthetic",
73
+ "Juggernaut-XL-v9": "RunDiffusion/Juggernaut-XL-v9"
74
+ }
75
+ dtype = torch.float16  # float16 on both CPU and GPU, matching the global pipeline above
79
+ if "XL" not in model_name:
80
+ globals()["pipeline"] = InterpolationStableDiffusionPipeline.from_pretrained(
81
+ name_mapping[model_name], torch_dtype=dtype
82
+ )
83
+ globals()["pipeline"].to(device, dtype=torch.float16)
84
+ else:
85
+ globals()["pipeline"] = InterpolationStableDiffusionXLPipeline.from_pretrained(
86
+ name_mapping[model_name], torch_dtype=dtype
87
+ )
88
+ globals()["pipeline"].to(device)
89
+
90
+
91
+ def change_adapter_fn(adapter_name: str) -> None:
92
+ global pipeline
93
+ if adapter_name == "IP-Adapter":
94
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
95
+ pipeline.load_aid_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
96
+ else:
97
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", "", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
98
+ elif adapter_name == "IP-Composition-Adapter":
99
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
100
+ pipeline.load_aid_ip_adapter("ostris/ip-composition-adapter", subfolder="", weight_name="ip_plus_composition_sd15.safetensors")
101
+ else:
102
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", subfolder="", weight_name="ip_plus_composition_sdxl.safetensors")
103
+ else:
104
+ pipeline.load_aid()
105
+
106
+
107
+ def save_image(img, index):
108
+ unique_name = f"{index}.png"
109
+ img = Image.fromarray(img)
110
+ img.save(unique_name)
111
+ return unique_name
112
+
113
+
114
+ def get_example() -> list[list[str | float | int ]]:
115
+ case = [
116
+ [
117
+ "A statue",
118
+ "A dragon",
119
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
120
+ "",
121
+ None,
122
+ None,
123
+ 50,
124
+ 10,
125
+ 5,
126
+ 5.0,
127
+ 0.5,
128
+ "RealVis-v4.0",
129
+ "None",
130
+ 0,
131
+ True,
132
+ ],
133
+ [
134
+ "A photo of a statue",
135
+ "Het meisje met de parel, by Vermeer",
136
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
137
+ "",
138
+ Image.open("asset/statue.jpg"),
139
+ Image.open("asset/vermeer.jpg"),
140
+ 50,
141
+ 10,
142
+ 5,
143
+ 5.0,
144
+ 0.5,
145
+ "RealVis-v4.0",
146
+ "IP-Adapter",
147
+ 0,
148
+ True,
149
+ ],
150
+ [
151
+ "A boy is smiling",
152
+ "A boy is smiling",
153
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
154
+ "",
155
+ None,
156
+ Image.open("asset/vermeer.jpg"),
157
+ 50,
158
+ 10,
159
+ 5,
160
+ 5.0,
161
+ 0.5,
162
+ "RealVis-v4.0",
163
+ "IP-Composition-Adapter",
164
+ 0,
165
+ True,
166
+ ],
167
+ [
168
+ "masterpiece, best quality, very aesthetic, absurdres, A dog",
169
+ "masterpiece, best quality, very aesthetic, absurdres, A car",
170
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
171
+ "masterpiece, best quality, very aesthetic, absurdres, the toy, named 'Dog-Car', is designed as a dog figure with car wheels instead of feet",
172
+ None,
173
+ None,
174
+ 50,
175
+ 5,
176
+ 5,
177
+ 5.0,
178
+ 0.5,
179
+ "RealVis-v4.0",
180
+ "None",
181
+ 1002,
182
+ True
183
+ ],
184
+ [
185
+ "masterpiece, best quality, very aesthetic, absurdres, A dog",
186
+ "masterpiece, best quality, very aesthetic, absurdres, A car",
187
+ "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
188
+ "masterpiece, best quality, very aesthetic, absurdres, a dog is driving a car",
189
+ None,
190
+ None,
191
+ 28,
192
+ 5,
193
+ 5,
194
+ 5.0,
195
+ 0.5,
196
+ "Playground-XL-v2",
197
+ "None",
198
+ 1002,
199
+ True
200
+ ]
201
+ # [
202
+ # "masterpiece, best quality, very aesthetic, absurdres, A cat is smiling, face portrait",
203
+ # "masterpiece, best quality, very aesthetic, absurdres, A beautiful lady, face portrait",
204
+ # "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
205
+ # None,
206
+ # None,
207
+ # None,
208
+ # 28,
209
+ # 7,
210
+ # 5,
211
+ # 5.0,
212
+ # 1.0,
213
+ # "Playground-XL-v2"
214
+ # ],
215
+ # [
216
+ # "masterpiece, best quality, very aesthetic, absurdres, A dog",
217
+ # "masterpiece, best quality, very aesthetic, absurdres, A car",
218
+ # "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
219
+ # "masterpiece, best quality, very aesthetic, absurdres, the toy, named 'Dog-Car', is designed as a dog figure with car wheels instead of feet",
220
+ # None,
221
+ # None,
222
+ # 28,
223
+ # 5,
224
+ # 5,
225
+ # 5.0,
226
+ # 0.5,
227
+ # "Playground-XL-v2"
228
+ # ],
229
+
230
+ ]
231
+ return case
232
+
233
+
234
+ def change_generate_button_fn(enable: int) -> gr.Button:
235
+ if enable == 0:
236
+ return gr.Button(interactive=False, value="Switching Model...")
237
+ else:
238
+ return gr.Button(interactive=True, value="Generate")
239
+
240
+
241
+ def dynamic_gallery_fn(interpolation_size: int):
242
+ return gr.Gallery(
243
+ label="Result", show_label=False, rows=1, columns=interpolation_size
244
+ )
245
+
246
+
247
+ @torch.no_grad()
248
+ def generate(
249
+ prompt1,
250
+ prompt2,
251
+ negative_prompt,
252
+ guide_prompt=None,
253
+ image_prompt1=None,
254
+ image_prompt2=None,
255
+ num_inference_steps=28,
256
+ exploration_size=16,
257
+ interpolation_size=7,
258
+ guidance_scale=5.0,
259
+ warmup_ratio=0.5,
260
+ seed=0,
261
+ same_latent=True,
262
+ ) -> np.ndarray:
263
+ global pipeline
264
+ global adapter_choice
265
+ beta_pipe = BetaPriorPipeline(pipeline)
266
+ if guide_prompt == "":
267
+ guide_prompt = None
268
+ generator = (
269
+ torch.cuda.manual_seed(seed)
270
+ if torch.cuda.is_available()
271
+ else torch.manual_seed(seed)
272
+ )
273
+ size = pipeline.unet.config.sample_size
274
+ latent1 = torch.randn((1, 4, size, size), device=device, dtype=pipeline.unet.dtype, generator=generator)
+ if same_latent:
+ latent2 = latent1.clone()
+ else:
+ latent2 = torch.randn((1, 4, size, size), device=device, dtype=pipeline.unet.dtype, generator=generator)
279
+
280
+ if image_prompt1 is None and image_prompt2 is None:
281
+ pipeline.load_aid()
282
+ elif (image_prompt1 is not None and image_prompt2 is not None):
283
+ if adapter_choice.value == "IP-Adapter":
284
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
285
+ pipeline.load_aid_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
286
+ else:
287
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", "", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
288
+ elif adapter_choice.value == "IP-Composition-Adapter":
289
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
290
+ pipeline.load_aid_ip_adapter("ostris/ip-composition-adapter", subfolder="", weight_name="ip_plus_composition_sd15.safetensors")
291
+ else:
292
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", subfolder="", weight_name="ip_plus_composition_sdxl.safetensors")
293
+ elif (image_prompt1 is None and image_prompt2 is not None):
294
+ if adapter_choice.value == "IP-Adapter":
295
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
296
+ pipeline.load_aid_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", early="scale_control")
297
+ else:
298
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", "", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors", early="scale_control")
299
+ elif adapter_choice.value == "IP-Composition-Adapter":
300
+ if isinstance(pipeline, InterpolationStableDiffusionPipeline):
301
+ pipeline.load_aid_ip_adapter("ostris/ip-composition-adapter", subfolder="", weight_name="ip_plus_composition_sd15.safetensors", early="scale_control")
302
+ else:
303
+ pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", subfolder="", weight_name="ip_plus_composition_sdxl.safetensors", early="scale_control")
304
+ else:
305
+ raise ValueError("To use scale control, please provide only the right image; to use image morphing, please provide images on both sides.")
306
+ images = beta_pipe.generate_interpolation(
307
+ gr.Progress(),
308
+ prompt1,
309
+ prompt2,
310
+ negative_prompt,
311
+ latent1,
312
+ latent2,
313
+ num_inference_steps,
314
+ image_start=image_prompt1,
315
+ image_end=image_prompt2,
316
+ exploration_size=exploration_size,
317
+ interpolation_size=interpolation_size,
318
+ output_type="np",
319
+ guide_prompt=guide_prompt,
320
+ guidance_scale=guidance_scale,
321
+ warmup_ratio=warmup_ratio
322
+ )
323
+ return images
324
+
325
+
326
+ interpolation_size = None
327
+
328
+ with gr.Blocks(css="style.css") as demo:
329
+ gr.Markdown(title)
330
+ gr.Markdown(description)
331
+ with gr.Row(elem_classes="grid-container"):
332
+ with gr.Group():
333
+ with gr.Column(elem_classes="grid-item"):  # left column
334
+ prompt1 = gr.Text(
335
+ label="Prompt 1",
336
+ max_lines=3,
337
+ placeholder="Enter the First Prompt",
338
+ interactive=True,
339
+ value="A photo of a cat",
340
+ )
341
+ prompt2 = gr.Text(
342
+ label="Prompt 2",
343
+ max_lines=3,
344
+ placeholder="Enter the Second Prompt",
345
+ interactive=True,
346
+ value="A photo of a beautiful lady",
347
+ )
348
+ negative_prompt = gr.Text(
349
+ label="Negative prompt",
350
+ max_lines=3,
351
+ placeholder="Enter a Negative Prompt",
352
+ interactive=True,
353
+ value="nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
354
+ )
355
+ guidance_prompt = gr.Text(
356
+ label="Guidance prompt (Optional)",
357
+ max_lines=3,
358
+ placeholder="Enter a Guidance Prompt",
359
+ interactive=True,
360
+ value="",
361
+ )
362
+
363
+ with gr.Group():
364
+ with gr.Column(elem_classes="grid-item"):  # right column
365
+ with gr.Row(elem_classes="flex-grow"):
366
+ image_prompt1 = gr.Image(label="Image Prompt 1 (Optional)", interactive=True, height=236, width=235)
367
+ image_prompt2 = gr.Image(label="Image Prompt 2 (Optional)", interactive=True, height=236, width=235)
368
+ with gr.Row(elem_classes="flex-grow"):
369
+ model_choice = gr.Dropdown(
370
+ ["RealVis-v4.0", "SD1.4-512", "SD1.5-512", "SD2.1-768", "AOM3", "SDXL-1024", "Playground-XL-v2", "Juggernaut-XL-v9"],
371
+ label="Model",
372
+ value="RealVis-v4.0",
373
+ interactive=True,
374
+ info="All models run in float16; SD2.1 does not support IP-Adapter; the XL series takes longer",
375
+ )
376
+ adapter_choice = gr.Dropdown(
377
+ ["None", "IP-Adapter", "IP-Composition-Adapter"],
378
+ label="IP-Adapter",
379
+ value="None",
380
+ interactive=True,
381
+ info="Only set to IP-Adapter or IP-Composition-Adapter when using an image prompt",
382
+ )
383
+
384
+ with gr.Group():
385
+ result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
386
+ generate_button = gr.Button(value="Generate", variant="primary")
387
+
388
+ with gr.Accordion("Advanced options", open=True):
389
+ with gr.Group():
390
+ with gr.Row():
391
+ with gr.Column():
392
+ interpolation_size = gr.Slider(
393
+ label="Interpolation Size",
394
+ minimum=3,
395
+ maximum=7,
396
+ step=1,
397
+ value=5,
398
+ info="Interpolation size includes the start and end images",
399
+ )
400
+ exploration_size = gr.Slider(
401
+ label="Exploration Size",
402
+ minimum=7,
403
+ maximum=16,
404
+ step=1,
405
+ value=10,
406
+ info="Exploration size has to be larger than interpolation size",
407
+ )
408
+ with gr.Row():
409
+ with gr.Column():
410
+ warmup_ratio = gr.Slider(
411
+ label="Warmup Ratio",
412
+ minimum=0.02,
413
+ maximum=1,
414
+ step=0.01,
415
+ value=0.5,
416
+ interactive=True,
417
+ )
418
+ guidance_scale = gr.Slider(
419
+ label="Guidance Scale",
420
+ minimum=0,
421
+ maximum=20,
422
+ step=0.1,
423
+ value=5.0,
424
+ interactive=True,
425
+ )
426
+ num_inference_steps = gr.Slider(
427
+ label="Inference Steps",
428
+ minimum=25,
429
+ maximum=50,
430
+ step=1,
431
+ value=50,
432
+ interactive=True,
433
+ )
434
+ with gr.Column():
435
+ seed = gr.Slider(
436
+ label="Seed",
437
+ minimum=0,
438
+ maximum=MAX_SEED,
439
+ step=1,
440
+ value=0,
441
+ )
442
+ same_latent = gr.Checkbox(
443
+ label="Same latent",
444
+ value=False,
445
+ info="Use the same latent for start and end images",
446
+ show_label=True,
447
+ )
448
+
449
+ gr.Examples(
450
+ examples=get_example(),
451
+ inputs=[
452
+ prompt1,
453
+ prompt2,
454
+ negative_prompt,
455
+ guidance_prompt,
456
+ image_prompt1,
457
+ image_prompt2,
458
+ num_inference_steps,
459
+ exploration_size,
460
+ interpolation_size,
461
+ guidance_scale,
462
+ warmup_ratio,
463
+ model_choice,
464
+ adapter_choice,
465
+ seed,
466
+ same_latent,
467
+ ],
468
+ cache_examples=CACHE_EXAMPLES,
469
+ )
470
+
471
+ model_choice.change(
472
+ fn=change_generate_button_fn,
473
+ inputs=gr.Number(0, visible=False),
474
+ outputs=generate_button,
475
+ ).then(fn=change_model_fn, inputs=model_choice).then(
476
+ fn=change_generate_button_fn,
477
+ inputs=gr.Number(1, visible=False),
478
+ outputs=generate_button,
479
+ )
480
+
481
+ adapter_choice.change(
482
+ fn=change_generate_button_fn,
483
+ inputs=gr.Number(0, visible=False),
484
+ outputs=generate_button,
485
+ ).then(fn=change_adapter_fn, inputs=[adapter_choice]).then(
486
+ fn=change_generate_button_fn,
487
+ inputs=gr.Number(1, visible=False),
488
+ outputs=generate_button,
489
+ )
490
+
491
+ inputs = [
492
+ prompt1,
493
+ prompt2,
494
+ negative_prompt,
495
+ guidance_prompt,
496
+ image_prompt1,
497
+ image_prompt2,
498
+ num_inference_steps,
499
+ exploration_size,
500
+ interpolation_size,
501
+ guidance_scale,
502
+ warmup_ratio,
503
+ seed,
504
+ same_latent,
505
+ ]
506
+ generate_button.click(
507
+ fn=dynamic_gallery_fn,
508
+ inputs=interpolation_size,
509
+ outputs=result,
510
+ ).then(
511
+ fn=generate,
512
+ inputs=inputs,
513
+ outputs=result,
514
+ )
515
+ gr.Markdown(article)
516
+
517
+ demo.launch()
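For readers skimming the diff, the sketch below condenses the generation path that app.py wires through Gradio into a plain script. It is a minimal sketch, not a documented API: it only reuses calls that literally appear in app.py (`from_pretrained`, `load_aid`, `BetaPriorPipeline.generate_interpolation`), the prompt values are placeholders, and the `gr.Progress()` argument is kept solely to mirror the demo's call signature (outside a Gradio event it may be inert).

```python
# Condensed, non-UI sketch of the generation path used in app.py above.
# Assumptions: a CUDA GPU and this repo's modules on PYTHONPATH.
import gradio as gr
import torch

from pipeline_interpolated_sd import InterpolationStableDiffusionPipeline
from prior import BetaPriorPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = InterpolationStableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE", torch_dtype=torch.float16
)
pipe.to(device, dtype=torch.float16)
pipe.load_aid()  # text-only interpolation, i.e. no IP-Adapter branch

beta_pipe = BetaPriorPipeline(pipe)
size = pipe.unet.config.sample_size
generator = torch.Generator(device=device).manual_seed(0)
latent1 = torch.randn((1, 4, size, size), device=device, dtype=pipe.unet.dtype, generator=generator)
latent2 = latent1.clone()  # the "same latent" setting from the UI

images = beta_pipe.generate_interpolation(
    gr.Progress(),          # app.py passes the Gradio progress tracker here
    "A statue",             # prompt 1
    "A dragon",             # prompt 2
    "lowres, bad quality",  # negative prompt (placeholder)
    latent1,
    latent2,
    50,                     # num_inference_steps
    image_start=None,
    image_end=None,
    exploration_size=10,
    interpolation_size=5,
    output_type="np",
    guide_prompt=None,
    guidance_scale=5.0,
    warmup_ratio=0.5,
)
print(images.shape if hasattr(images, "shape") else len(images))
```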
asset/statue.jpg ADDED
asset/vermeer.jpg ADDED
interpolation.py ADDED
@@ -0,0 +1,918 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import FloatTensor, LongTensor, Size, Tensor
5
+ from torch import nn as nn
6
+
7
+ from prior import generate_beta_tensor
8
+
9
+
10
+ class InterpolatedAttnProcessor(nn.Module):
11
+ def __init__(
12
+ self,
13
+ t: Optional[float] = None,
14
+ size: int = 7,
15
+ is_fused: bool = False,
16
+ alpha: float = 1,
17
+ beta: float = 1,
18
+ ):
19
+ super().__init__()
20
+ if t is None:
21
+ ts = generate_beta_tensor(size, alpha=alpha, beta=beta)
22
+ ts[0], ts[-1] = 0, 1
23
+ else:
24
+ assert t > 0 and t < 1, "t must be between 0 and 1"
25
+ ts = [0, t, 1]
26
+ ts = torch.tensor(ts)
27
+ size = 3
28
+
29
+ self.size = size
30
+ self.coef = ts
31
+ self.is_fused = is_fused
32
+ self.activated = True
33
+
34
+ def deactivate(self):
35
+ self.activated = False
36
+
37
+ def activate(self, t):
38
+ self.activated = True
39
+ assert t > 0 and t < 1, "t must be between 0 and 1"
40
+ ts = [0, t, 1]
41
+ ts = torch.tensor(ts)
42
+ self.coef = ts
43
+
44
+ def load_end_point(self, key_begin, value_begin, key_end, value_end):
45
+ self.key_begin = key_begin
46
+ self.value_begin = value_begin
47
+ self.key_end = key_end
48
+ self.value_end = value_end
49
+
50
+
51
+ class ScaleControlIPAttnProcessor(InterpolatedAttnProcessor):
52
+ r"""
53
+ Personalized processor for controlling the impact of the image prompt via attention interpolation.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ t: Optional[float] = None,
59
+ size: int = 7,
60
+ is_fused: bool = False,
61
+ alpha: float = 1,
62
+ beta: float = 1,
63
+ ip_attn: Optional[nn.Module] = None,
64
+ ):
65
+ """
66
+ t: float, interpolation point between 0 and 1, if specified, size is set to 3
67
+ """
68
+ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
69
+
70
+ self.num_tokens = (
71
+ ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,)
72
+ )
73
+ self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None
74
+ self.ip_attn = ip_attn
75
+
76
+ def __call__(
77
+ self,
78
+ attn,
79
+ hidden_states: torch.FloatTensor,
80
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
81
+ attention_mask: Optional[torch.FloatTensor] = None,
82
+ temb: Optional[torch.FloatTensor] = None,
83
+ ) -> torch.Tensor:
84
+ residual = hidden_states
85
+
86
+ if encoder_hidden_states is None:
87
+ encoder_hidden_states = hidden_states
88
+ ip_hidden_states = None
89
+ else:
90
+ if isinstance(encoder_hidden_states, tuple):
91
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
92
+ else:
93
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
94
+ encoder_hidden_states, ip_hidden_states = (
95
+ encoder_hidden_states[:, :end_pos, :],
96
+ [encoder_hidden_states[:, end_pos:, :]],
97
+ )
98
+
99
+ if attn.spatial_norm is not None:
100
+ hidden_states = attn.spatial_norm(hidden_states, temb)
101
+
102
+ input_ndim = hidden_states.ndim
103
+
104
+ if input_ndim == 4:
105
+ batch_size, channel, height, width = hidden_states.shape
106
+ hidden_states = hidden_states.view(
107
+ batch_size, channel, height * width
108
+ ).transpose(1, 2)
109
+
110
+ batch_size, sequence_length, _ = (
111
+ hidden_states.shape
112
+ if encoder_hidden_states is None
113
+ else encoder_hidden_states.shape
114
+ )
115
+ attention_mask = attn.prepare_attention_mask(
116
+ attention_mask, sequence_length, batch_size
117
+ )
118
+
119
+ if attn.group_norm is not None:
120
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
121
+ 1, 2
122
+ )
123
+
124
+ query = attn.to_q(hidden_states)
125
+ query = attn.head_to_batch_dim(query)
126
+
127
+ key = attn.to_k(encoder_hidden_states)
128
+ value = attn.to_v(encoder_hidden_states)
129
+
130
+ if not self.activated:
131
+ key = attn.head_to_batch_dim(key)
132
+ value = attn.head_to_batch_dim(value)
133
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
134
+ hidden_states = torch.bmm(attention_probs, value)
135
+ hidden_states = attn.batch_to_head_dim(hidden_states)
136
+ if ip_hidden_states is not None:
137
+ key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][6:9])
138
+ value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][6:9])
139
+ key = attn.head_to_batch_dim(key)
140
+ value = attn.head_to_batch_dim(value)
141
+ ip_attention_probs = attn.get_attention_scores(
142
+ query, key, attention_mask
143
+ )
144
+ ip_hidden_states = torch.bmm(ip_attention_probs, value)
145
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
146
+ hidden_states = (
147
+ hidden_states
148
+ + self.coef.reshape(-1, 1, 1).to(key.device, key.dtype)
149
+ * ip_hidden_states
150
+ )
151
+ else:
152
+ key_begin = key[0:1].expand(3, *key.shape[1:])
153
+ key_end = key[-1:].expand(3, *key.shape[1:])
154
+ value_begin = value[0:1].expand(3, *value.shape[1:])
155
+ value_end = value[-1:].expand(3, *value.shape[1:])
156
+ key_begin = attn.head_to_batch_dim(key_begin)
157
+ value_begin = attn.head_to_batch_dim(value_begin)
158
+ key_end = attn.head_to_batch_dim(key_end)
159
+ value_end = attn.head_to_batch_dim(value_end)
160
+
161
+ if self.is_fused:
162
+ key = attn.head_to_batch_dim(key)
163
+ value = attn.head_to_batch_dim(value)
164
+ key_end = torch.cat([key, key_end], dim=-2)
165
+ value_end = torch.cat([value, value_end], dim=-2)
166
+ key_begin = torch.cat([key, key_begin], dim=-2)
167
+ value_begin = torch.cat([value, value_begin], dim=-2)
168
+
169
+ attention_probs_end = attn.get_attention_scores(
170
+ query, key_end, attention_mask
171
+ )
172
+ hidden_states_end = torch.bmm(attention_probs_end, value_end)
173
+ hidden_states_end = attn.batch_to_head_dim(hidden_states_end)
174
+ attention_probs_begin = attn.get_attention_scores(
175
+ query, key_begin, attention_mask
176
+ )
177
+ hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
178
+ hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)
179
+
180
+ # Apply outer interpolation on attention
181
+ coef = self.coef.reshape(-1, 1, 1)
182
+ coef = coef.to(key.device, key.dtype)
183
+ hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end
184
+
185
+ # for ip-adapter
186
+ if ip_hidden_states is not None:
187
+ key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][6:9])
188
+ value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][6:9])
189
+ key = attn.head_to_batch_dim(key)
190
+ value = attn.head_to_batch_dim(value)
191
+ ip_attention_probs = attn.get_attention_scores(
192
+ query, key, attention_mask
193
+ )
194
+ ip_hidden_states = torch.bmm(ip_attention_probs, value)
195
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
196
+ hidden_states = hidden_states + coef * ip_hidden_states
197
+
198
+ hidden_states = attn.to_out[0](hidden_states)
199
+ hidden_states = attn.to_out[1](hidden_states)
200
+
201
+ if input_ndim == 4:
202
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
203
+ batch_size, channel, height, width
204
+ )
205
+
206
+ if attn.residual_connection:
207
+ hidden_states = hidden_states + residual
208
+
209
+ hidden_states = hidden_states / attn.rescale_output_factor
210
+
211
+ return hidden_states
212
+
213
+
214
+ class OuterInterpolatedIPAttnProcessor(InterpolatedAttnProcessor):
215
+ r"""
216
+ Personalized processor for performing outer attention interpolation.
217
+ Combined with IP-Adapter attention processor.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ t: Optional[float] = None,
223
+ size: int = 7,
224
+ is_fused: bool = False,
225
+ alpha: float = 1,
226
+ beta: float = 1,
227
+ ip_attn: Optional[nn.Module] = None,
228
+ ):
229
+ """
230
+ t: float, interpolation point between 0 and 1, if specified, size is set to 3
231
+ """
232
+ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
233
+
234
+ self.num_tokens = (
235
+ ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,)
236
+ )
237
+ self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None
238
+ self.ip_attn = ip_attn
239
+
240
+ def __call__(
241
+ self,
242
+ attn,
243
+ hidden_states: torch.FloatTensor,
244
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
245
+ attention_mask: Optional[torch.FloatTensor] = None,
246
+ temb: Optional[torch.FloatTensor] = None,
247
+ ) -> torch.Tensor:
248
+ if not self.activated:
249
+ return self.ip_attn(
250
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb
251
+ )
252
+
253
+ residual = hidden_states
254
+
255
+ if encoder_hidden_states is None:
256
+ encoder_hidden_states = hidden_states
257
+ ip_hidden_states = None
258
+ else:
259
+ if isinstance(encoder_hidden_states, tuple):
260
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
261
+ else:
262
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
263
+ encoder_hidden_states, ip_hidden_states = (
264
+ encoder_hidden_states[:, :end_pos, :],
265
+ [encoder_hidden_states[:, end_pos:, :]],
266
+ )
267
+
268
+ if attn.spatial_norm is not None:
269
+ hidden_states = attn.spatial_norm(hidden_states, temb)
270
+
271
+ input_ndim = hidden_states.ndim
272
+
273
+ if input_ndim == 4:
274
+ batch_size, channel, height, width = hidden_states.shape
275
+ hidden_states = hidden_states.view(
276
+ batch_size, channel, height * width
277
+ ).transpose(1, 2)
278
+
279
+ batch_size, sequence_length, _ = (
280
+ hidden_states.shape
281
+ if encoder_hidden_states is None
282
+ else encoder_hidden_states.shape
283
+ )
284
+ attention_mask = attn.prepare_attention_mask(
285
+ attention_mask, sequence_length, batch_size
286
+ )
287
+
288
+ if attn.group_norm is not None:
289
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
290
+ 1, 2
291
+ )
292
+
293
+ query = attn.to_q(hidden_states)
294
+ query = attn.head_to_batch_dim(query)
295
+
296
+ key = attn.to_k(encoder_hidden_states)
297
+ value = attn.to_v(encoder_hidden_states)
298
+
299
+ # Specify the first and last key and value
300
+ key_begin = key[0:1].expand(3, *key.shape[1:])
301
+ key_end = key[-1:].expand(3, *key.shape[1:])
302
+ value_begin = value[0:1].expand(3, *value.shape[1:])
303
+ value_end = value[-1:].expand(3, *value.shape[1:])
304
+ key_begin = attn.head_to_batch_dim(key_begin)
305
+ value_begin = attn.head_to_batch_dim(value_begin)
306
+ key_end = attn.head_to_batch_dim(key_end)
307
+ value_end = attn.head_to_batch_dim(value_end)
308
+
309
+ # Fused with self-attention
310
+ if self.is_fused:
311
+ key = attn.head_to_batch_dim(key)
312
+ value = attn.head_to_batch_dim(value)
313
+ key_end = torch.cat([key, key_end], dim=-2)
314
+ value_end = torch.cat([value, value_end], dim=-2)
315
+ key_begin = torch.cat([key, key_begin], dim=-2)
316
+ value_begin = torch.cat([value, value_begin], dim=-2)
317
+
318
+ attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask)
319
+ hidden_states_end = torch.bmm(attention_probs_end, value_end)
320
+ hidden_states_end = attn.batch_to_head_dim(hidden_states_end)
321
+
322
+ attention_probs_begin = attn.get_attention_scores(
323
+ query, key_begin, attention_mask
324
+ )
325
+ hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
326
+ hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)
327
+
328
+ # for ip-adapter
329
+ if ip_hidden_states is not None:
330
+ key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][::3])
331
+ value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][::3])
332
+
333
+ # Specify the first and last key and value
334
+ key_begin = key[0:1].expand(3, *key.shape[1:])
335
+ key_end = key[-1:].expand(3, *key.shape[1:])
336
+ value_begin = value[0:1].expand(3, *value.shape[1:])
337
+ value_end = value[-1:].expand(3, *value.shape[1:])
338
+ key_begin = attn.head_to_batch_dim(key_begin)
339
+ value_begin = attn.head_to_batch_dim(value_begin)
340
+ key_end = attn.head_to_batch_dim(key_end)
341
+ value_end = attn.head_to_batch_dim(value_end)
342
+
343
+ # Fused with self-attention
344
+ if self.is_fused:
345
+ key = attn.head_to_batch_dim(key)
346
+ value = attn.head_to_batch_dim(value)
347
+ key_end = torch.cat([key, key_end], dim=-2)
348
+ value_end = torch.cat([value, value_end], dim=-2)
349
+ key_begin = torch.cat([key, key_begin], dim=-2)
350
+ value_begin = torch.cat([value, value_begin], dim=-2)
351
+
352
+ ip_attention_probs_end = attn.get_attention_scores(
353
+ query, key_end, attention_mask
354
+ )
355
+ ip_hidden_states_end = torch.bmm(ip_attention_probs_end, value_end)
356
+ ip_hidden_states_end = attn.batch_to_head_dim(ip_hidden_states_end)
357
+
358
+ ip_attention_probs_begin = attn.get_attention_scores(
359
+ query, key_begin, attention_mask
360
+ )
361
+ ip_hidden_states_begin = torch.bmm(ip_attention_probs_begin, value_begin)
362
+ ip_hidden_states_begin = attn.batch_to_head_dim(ip_hidden_states_begin)
363
+
364
+ hidden_states_begin = (
365
+ hidden_states_begin + self.scale[0] * ip_hidden_states_begin
366
+ )
367
+ hidden_states_end = hidden_states_end + self.scale[0] * ip_hidden_states_end
368
+
369
+ # Apply outer interpolation on attention
370
+ coef = self.coef.reshape(-1, 1, 1)
371
+ coef = coef.to(key.device, key.dtype)
372
+ hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end
373
+
374
+ hidden_states = attn.to_out[0](hidden_states)
375
+ hidden_states = attn.to_out[1](hidden_states)
376
+
377
+ if input_ndim == 4:
378
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
379
+ batch_size, channel, height, width
380
+ )
381
+
382
+ if attn.residual_connection:
383
+ hidden_states = hidden_states + residual
384
+
385
+ hidden_states = hidden_states / attn.rescale_output_factor
386
+
387
+ return hidden_states
388
+
389
+
390
+ class InnerInterpolatedIPAttnProcessor(InterpolatedAttnProcessor):
391
+ r"""
392
+ Personalized processor for performing inner attention interpolation.
393
+
394
+ With IP-Adapter.
395
+ """
396
+
397
+ def __init__(
398
+ self,
399
+ t: Optional[float] = None,
400
+ size: int = 7,
401
+ is_fused: bool = False,
402
+ alpha: float = 1,
403
+ beta: float = 1,
404
+ ip_attn: Optional[nn.Module] = None,
405
+ ):
406
+ """
407
+ t: float, interpolation point between 0 and 1, if specified, size is set to 3
408
+ """
409
+ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
410
+
411
+ self.num_tokens = (
412
+ ip_attn.num_tokens if hasattr(ip_attn, "num_tokens") else (16,)
413
+ )
414
+ self.scale = ip_attn.scale if hasattr(ip_attn, "scale") else None
415
+ self.ip_attn = ip_attn
416
+
417
+ def __call__(
418
+ self,
419
+ attn,
420
+ hidden_states: torch.FloatTensor,
421
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
422
+ attention_mask: Optional[torch.FloatTensor] = None,
423
+ temb: Optional[torch.FloatTensor] = None,
424
+ ) -> torch.Tensor:
425
+ if not self.activated:
426
+ return self.ip_attn(
427
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb
428
+ )
429
+
430
+ residual = hidden_states
431
+
432
+ if encoder_hidden_states is None:
433
+ encoder_hidden_states = hidden_states
434
+ ip_hidden_states = None
435
+ else:
436
+ if isinstance(encoder_hidden_states, tuple):
437
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
438
+ else:
439
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
440
+ encoder_hidden_states, ip_hidden_states = (
441
+ encoder_hidden_states[:, :end_pos, :],
442
+ [encoder_hidden_states[:, end_pos:, :]],
443
+ )
444
+
445
+ if attn.spatial_norm is not None:
446
+ hidden_states = attn.spatial_norm(hidden_states, temb)
447
+
448
+ input_ndim = hidden_states.ndim
449
+
450
+ if input_ndim == 4:
451
+ batch_size, channel, height, width = hidden_states.shape
452
+ hidden_states = hidden_states.view(
453
+ batch_size, channel, height * width
454
+ ).transpose(1, 2)
455
+
456
+ batch_size, sequence_length, _ = (
457
+ hidden_states.shape
458
+ if encoder_hidden_states is None
459
+ else encoder_hidden_states.shape
460
+ )
461
+ attention_mask = attn.prepare_attention_mask(
462
+ attention_mask, sequence_length, batch_size
463
+ )
464
+
465
+ if attn.group_norm is not None:
466
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
467
+ 1, 2
468
+ )
469
+
470
+ query = attn.to_q(hidden_states)
471
+ query = attn.head_to_batch_dim(query)
472
+
473
+ key = attn.to_k(encoder_hidden_states)
474
+ value = attn.to_v(encoder_hidden_states)
475
+
476
+ # Specify the first and last key and value
477
+ key_begin = key[0:1].expand(3, *key.shape[1:])
478
+ key_end = key[-1:].expand(3, *key.shape[1:])
479
+ value_begin = value[0:1].expand(3, *value.shape[1:])
480
+ value_end = value[-1:].expand(3, *value.shape[1:])
481
+
482
+ coef = self.coef.reshape(-1, 1, 1)
483
+ coef = coef.to(key.device, key.dtype)
484
+ key_cross = (1 - coef) * key_begin + coef * key_end
485
+ value_cross = (1 - coef) * value_begin + coef * value_end
486
+ key_cross = attn.head_to_batch_dim(key_cross)
487
+ value_cross = attn.head_to_batch_dim(value_cross)
488
+
489
+ # Fused with self-attention
490
+ if self.is_fused:
491
+ key = attn.head_to_batch_dim(key)
492
+ value = attn.head_to_batch_dim(value)
493
+ key_cross = torch.cat([key, key_cross], dim=-2)
494
+ value_cross = torch.cat([value, value_cross], dim=-2)
495
+
496
+ attention_probs = attn.get_attention_scores(query, key_cross, attention_mask)
497
+ hidden_states = torch.bmm(attention_probs, value_cross)
498
+ hidden_states = attn.batch_to_head_dim(hidden_states)
499
+
500
+ # for ip-adapter
501
+ if ip_hidden_states is not None:
502
+ key = self.ip_attn.to_k_ip[0](ip_hidden_states[0][::3])
503
+ value = self.ip_attn.to_v_ip[0](ip_hidden_states[0][::3])
504
+ key = key.squeeze()
505
+ value = value.squeeze()
506
+
507
+ # Specify the first and last key and value
508
+ key_begin = key[0:1].expand(3, *key.shape[1:])
509
+ key_end = key[-1:].expand(3, *key.shape[1:])
510
+ value_begin = value[0:1].expand(3, *value.shape[1:])
511
+ value_end = value[-1:].expand(3, *value.shape[1:])
512
+ key_cross = (1 - coef) * key_begin + coef * key_end
513
+ value_cross = (1 - coef) * value_begin + coef * value_end
514
+
515
+ key_cross = attn.head_to_batch_dim(key_cross)
516
+ value_cross = attn.head_to_batch_dim(value_cross)
517
+
518
+ # Fused with self-attention
519
+ if self.is_fused:
520
+ key = attn.head_to_batch_dim(key)
521
+ value = attn.head_to_batch_dim(value)
522
+ key_cross = torch.cat([key, key_cross], dim=-2)
523
+ value_cross = torch.cat([value, value_cross], dim=-2)
524
+
525
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
526
+
527
+ ip_hidden_states = torch.bmm(attention_probs, value)
528
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
529
+
530
+ hidden_states = hidden_states + self.scale[0] * ip_hidden_states
531
+
532
+ hidden_states = attn.to_out[0](hidden_states)
533
+ hidden_states = attn.to_out[1](hidden_states)
534
+
535
+ if input_ndim == 4:
536
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
537
+ batch_size, channel, height, width
538
+ )
539
+
540
+ if attn.residual_connection:
541
+ hidden_states = hidden_states + residual
542
+
543
+ hidden_states = hidden_states / attn.rescale_output_factor
544
+
545
+ return hidden_states
546
+
547
+
548
+ class OuterInterpolatedAttnProcessor(InterpolatedAttnProcessor):
549
+ r"""
550
+ Personalized processor for performing outer attention interpolation.
551
+
552
+ The attention output of the interpolated image is obtained by interpolating the attention outputs of the two endpoints:
+ (1 - t) * Attn(Q_t, K_1, V_1) + t * Attn(Q_t, K_m, V_m);
+ If fused with self-attention:
+ (1 - t) * Attn(Q_t, [K_1, K_t], [V_1, V_t]) + t * Attn(Q_t, [K_m, K_t], [V_m, V_t]);
556
+ """
557
+
558
+ def __init__(
559
+ self,
560
+ t: Optional[float] = None,
561
+ size: int = 7,
562
+ is_fused: bool = False,
563
+ alpha: float = 1,
564
+ beta: float = 1,
565
+ original_attn: Optional[nn.Module] = None,
566
+ ):
567
+ """
568
+ t: float, interpolation point between 0 and 1, if specified, size is set to 3
569
+ """
570
+ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
571
+ self.original_attn = original_attn
572
+
573
+ def __call__(
574
+ self,
575
+ attn,
576
+ hidden_states: torch.FloatTensor,
577
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
578
+ attention_mask: Optional[torch.FloatTensor] = None,
579
+ temb: Optional[torch.FloatTensor] = None,
580
+ ) -> torch.Tensor:
581
+ if not self.activated:
582
+ return self.original_attn(
583
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb
584
+ )
585
+
586
+ residual = hidden_states
587
+
588
+ if attn.spatial_norm is not None:
589
+ hidden_states = attn.spatial_norm(hidden_states, temb)
590
+
591
+ input_ndim = hidden_states.ndim
592
+
593
+ if input_ndim == 4:
594
+ batch_size, channel, height, width = hidden_states.shape
595
+ hidden_states = hidden_states.view(
596
+ batch_size, channel, height * width
597
+ ).transpose(1, 2)
598
+
599
+ batch_size, sequence_length, _ = (
600
+ hidden_states.shape
601
+ if encoder_hidden_states is None
602
+ else encoder_hidden_states.shape
603
+ )
604
+ attention_mask = attn.prepare_attention_mask(
605
+ attention_mask, sequence_length, batch_size
606
+ )
607
+
608
+ if attn.group_norm is not None:
609
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
610
+ 1, 2
611
+ )
612
+
613
+ query = attn.to_q(hidden_states)
614
+ query = attn.head_to_batch_dim(query)
615
+
616
+ if encoder_hidden_states is None:
617
+ encoder_hidden_states = hidden_states
618
+ elif attn.norm_cross:
619
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
620
+ encoder_hidden_states
621
+ )
622
+
623
+ key = attn.to_k(encoder_hidden_states)
624
+ value = attn.to_v(encoder_hidden_states)
625
+
626
+ # Specify the first and last key and value
627
+ key_begin = key[0:1]
628
+ key_end = key[-1:]
629
+ value_begin = value[0:1]
630
+ value_end = value[-1:]
631
+
632
+ key_begin = torch.cat([key_begin] * (self.size))
633
+ key_end = torch.cat([key_end] * (self.size))
634
+ value_begin = torch.cat([value_begin] * (self.size))
635
+ value_end = torch.cat([value_end] * (self.size))
636
+
637
+ key_begin = attn.head_to_batch_dim(key_begin)
638
+ value_begin = attn.head_to_batch_dim(value_begin)
639
+ key_end = attn.head_to_batch_dim(key_end)
640
+ value_end = attn.head_to_batch_dim(value_end)
641
+
642
+ # Fused with self-attention
643
+ if self.is_fused:
644
+ key = attn.head_to_batch_dim(key)
645
+ value = attn.head_to_batch_dim(value)
646
+ key_end = torch.cat([key, key_end], dim=-2)
647
+ value_end = torch.cat([value, value_end], dim=-2)
648
+ key_begin = torch.cat([key, key_begin], dim=-2)
649
+ value_begin = torch.cat([value, value_begin], dim=-2)
650
+
651
+ attention_probs_end = attn.get_attention_scores(query, key_end, attention_mask)
652
+ hidden_states_end = torch.bmm(attention_probs_end, value_end)
653
+ hidden_states_end = attn.batch_to_head_dim(hidden_states_end)
654
+
655
+ attention_probs_begin = attn.get_attention_scores(
656
+ query, key_begin, attention_mask
657
+ )
658
+ hidden_states_begin = torch.bmm(attention_probs_begin, value_begin)
659
+ hidden_states_begin = attn.batch_to_head_dim(hidden_states_begin)
660
+
661
+ # Apply outer interpolation on attention
662
+ coef = self.coef.reshape(-1, 1, 1)
663
+ coef = coef.to(key.device, key.dtype)
664
+ hidden_states = (1 - coef) * hidden_states_begin + coef * hidden_states_end
665
+
666
+ hidden_states = attn.to_out[0](hidden_states)
667
+ hidden_states = attn.to_out[1](hidden_states)
668
+
669
+ if input_ndim == 4:
670
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
671
+ batch_size, channel, height, width
672
+ )
673
+
674
+ if attn.residual_connection:
675
+ hidden_states = hidden_states + residual
676
+
677
+ hidden_states = hidden_states / attn.rescale_output_factor
678
+
679
+ return hidden_states
680
+
681
+
682
+ class InnerInterpolatedAttnProcessor(InterpolatedAttnProcessor):
683
+ r"""
684
+ Personalized processor for performing inner attention interpolation.
685
+
686
+ The attention output of the interpolated image is obtained by interpolating the keys and values of the two endpoints before attention:
+ Attn(Q_t, (1 - t) * K_1 + t * K_m, (1 - t) * V_1 + t * V_m);
+ If fused with self-attention:
+ Attn(Q_t, [(1 - t) * K_1 + t * K_m, K_t], [(1 - t) * V_1 + t * V_m, V_t]);
690
+ """
691
+
692
+ def __init__(
693
+ self,
694
+ t: Optional[float] = None,
695
+ size: int = 7,
696
+ is_fused: bool = False,
697
+ alpha: float = 1,
698
+ beta: float = 1,
699
+ original_attn: Optional[nn.Module] = None,
700
+ ):
701
+ """
702
+ t: float, interpolation point between 0 and 1, if specified, size is set to 3
703
+ """
704
+ super().__init__(t=t, size=size, is_fused=is_fused, alpha=alpha, beta=beta)
705
+ self.original_attn = original_attn
706
+
707
+ def __call__(
708
+ self,
709
+ attn,
710
+ hidden_states: torch.FloatTensor,
711
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
712
+ attention_mask: Optional[torch.FloatTensor] = None,
713
+ temb: Optional[torch.FloatTensor] = None,
714
+ ) -> torch.Tensor:
715
+ if not self.activated:
716
+ return self.original_attn(
717
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb
718
+ )
719
+
720
+ residual = hidden_states
721
+
722
+ if attn.spatial_norm is not None:
723
+ hidden_states = attn.spatial_norm(hidden_states, temb)
724
+
725
+ input_ndim = hidden_states.ndim
726
+
727
+ if input_ndim == 4:
728
+ batch_size, channel, height, width = hidden_states.shape
729
+ hidden_states = hidden_states.view(
730
+ batch_size, channel, height * width
731
+ ).transpose(1, 2)
732
+
733
+ batch_size, sequence_length, _ = (
734
+ hidden_states.shape
735
+ if encoder_hidden_states is None
736
+ else encoder_hidden_states.shape
737
+ )
738
+ attention_mask = attn.prepare_attention_mask(
739
+ attention_mask, sequence_length, batch_size
740
+ )
741
+
742
+ if attn.group_norm is not None:
743
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
744
+ 1, 2
745
+ )
746
+
747
+ query = attn.to_q(hidden_states)
748
+ query = attn.head_to_batch_dim(query)
749
+
750
+ if encoder_hidden_states is None:
751
+ encoder_hidden_states = hidden_states
752
+ elif attn.norm_cross:
753
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
754
+ encoder_hidden_states
755
+ )
756
+
757
+ key = attn.to_k(encoder_hidden_states)
758
+ value = attn.to_v(encoder_hidden_states)
759
+
760
+ # Specify the first and last key and value
761
+ key_start = key[0:1]
762
+ key_end = key[-1:]
763
+ value_start = value[0:1]
764
+ value_end = value[-1:]
765
+
766
+ key_start = torch.cat([key_start] * (self.size))
767
+ key_end = torch.cat([key_end] * (self.size))
768
+ value_start = torch.cat([value_start] * (self.size))
769
+ value_end = torch.cat([value_end] * (self.size))
770
+
771
+ # Apply inner interpolation on attention
772
+ coef = self.coef.reshape(-1, 1, 1)
773
+ coef = coef.to(key.device, key.dtype)
774
+ key_cross = (1 - coef) * key_start + coef * key_end
775
+ value_cross = (1 - coef) * value_start + coef * value_end
776
+
777
+ key_cross = attn.head_to_batch_dim(key_cross)
778
+ value_cross = attn.head_to_batch_dim(value_cross)
779
+
780
+ # Fused with self-attention
781
+ if self.is_fused:
782
+ key = attn.head_to_batch_dim(key)
783
+ value = attn.head_to_batch_dim(value)
784
+ key_cross = torch.cat([key, key_cross], dim=-2)
785
+ value_cross = torch.cat([value, value_cross], dim=-2)
786
+
787
+ attention_probs = attn.get_attention_scores(query, key_cross, attention_mask)
788
+
789
+ hidden_states = torch.bmm(attention_probs, value_cross)
790
+ hidden_states = attn.batch_to_head_dim(hidden_states)
791
+ hidden_states = attn.to_out[0](hidden_states)
792
+ hidden_states = attn.to_out[1](hidden_states)
793
+
794
+ if input_ndim == 4:
795
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
796
+ batch_size, channel, height, width
797
+ )
798
+
799
+ if attn.residual_connection:
800
+ hidden_states = hidden_states + residual
801
+
802
+ hidden_states = hidden_states / attn.rescale_output_factor
803
+
804
+ return hidden_states
805
+
806
+
807
+ def linear_interpolation(
808
+ l1: FloatTensor, l2: FloatTensor, ts: Optional[FloatTensor] = None, size: int = 5
809
+ ) -> FloatTensor:
810
+ """
811
+ Linear interpolation
812
+
813
+ Args:
814
+ l1: Starting vector: (1, *)
815
+ l2: Final vector: (1, *)
816
+ ts: FloatTensor, interpolation points between 0 and 1
817
+ size: int, number of interpolation points including l1 and l2
818
+
819
+ Returns:
820
+ Interpolated vectors: (size, *)
821
+ """
822
+ assert l1.shape == l2.shape, "shapes of l1 and l2 must match"
823
+
824
+ res = []
825
+ if ts is not None:
826
+ for t in ts:
827
+ li = torch.lerp(l1, l2, t)
828
+ res.append(li)
829
+ else:
830
+ for i in range(size):
831
+ t = i / (size - 1)
832
+ li = torch.lerp(l1, l2, t)
833
+ res.append(li)
834
+ res = torch.cat(res, dim=0)
835
+ return res
836
+
837
+
838
+ def spherical_interpolation(l1: FloatTensor, l2: FloatTensor, size=5) -> FloatTensor:
839
+ """
840
+ Spherical interpolation
841
+
842
+ Args:
843
+ l1: Starting vector: (1, *)
844
+ l2: Final vector: (1, *)
845
+ size: int, number of interpolation points including l1 and l2
846
+
847
+ Returns:
848
+ Interpolated vectors: (size, *)
849
+ """
850
+ assert l1.shape == l2.shape, "shapes of l1 and l2 must match"
851
+
852
+ res = []
853
+ for i in range(size):
854
+ t = i / (size - 1)
855
+ li = slerp(l1, l2, t)
856
+ res.append(li)
857
+ res = torch.cat(res, dim=0)
858
+ return res
859
+
860
+
861
+ def slerp(v0: FloatTensor, v1: FloatTensor, t, threshold=0.9995):
862
+ """
863
+ Spherical linear interpolation
864
+ Args:
865
+ v0: Starting vector
866
+ v1: Final vector
867
+ t: Float value between 0.0 and 1.0
868
+ threshold: Threshold for considering the two vectors as
869
+ colinear. Not recommended to alter this.
870
+ Returns:
871
+ Interpolation vector between v0 and v1
872
+ """
873
+ assert v0.shape == v1.shape, "shapes of v0 and v1 must match"
874
+
875
+ # Normalize the vectors to get the directions and angles
876
+ v0_norm: FloatTensor = torch.norm(v0, dim=-1)
877
+ v1_norm: FloatTensor = torch.norm(v1, dim=-1)
878
+
879
+ v0_normed: FloatTensor = v0 / v0_norm.unsqueeze(-1)
880
+ v1_normed: FloatTensor = v1 / v1_norm.unsqueeze(-1)
881
+
882
+ # Dot product with the normalized vectors
883
+ dot: FloatTensor = (v0_normed * v1_normed).sum(-1)
884
+ dot_mag: FloatTensor = dot.abs()
885
+
886
+ # if dp is NaN, it's because the v0 or v1 row was filled with 0s
887
+ # If absolute value of dot product is almost 1, vectors are ~colinear, so use torch.lerp
888
+ gotta_lerp: LongTensor = dot_mag.isnan() | (dot_mag > threshold)
889
+ can_slerp: LongTensor = ~gotta_lerp
890
+
891
+ t_batch_dim_count: int = max(0, t.dim() - v0.dim()) if isinstance(t, Tensor) else 0
892
+ t_batch_dims: Size = (
893
+ t.shape[:t_batch_dim_count] if isinstance(t, Tensor) else Size([])
894
+ )
895
+ out: FloatTensor = torch.zeros_like(v0.expand(*t_batch_dims, *[-1] * v0.dim()))
896
+
897
+ # if no elements are lerpable, our vectors become 0-dimensional, preventing broadcasting
898
+ if gotta_lerp.any():
899
+ lerped: FloatTensor = torch.lerp(v0, v1, t)
900
+
901
+ out: FloatTensor = lerped.where(gotta_lerp.unsqueeze(-1), out)
902
+
903
+ # if no elements are slerpable, our vectors become 0-dimensional, preventing broadcasting
904
+ if can_slerp.any():
905
+ # Calculate initial angle between v0 and v1
906
+ theta_0: FloatTensor = dot.arccos().unsqueeze(-1)
907
+ sin_theta_0: FloatTensor = theta_0.sin()
908
+ # Angle at timestep t
909
+ theta_t: FloatTensor = theta_0 * t
910
+ sin_theta_t: FloatTensor = theta_t.sin()
911
+ # Finish the slerp algorithm
912
+ s0: FloatTensor = (theta_0 - theta_t).sin() / sin_theta_0
913
+ s1: FloatTensor = sin_theta_t / sin_theta_0
914
+ slerped: FloatTensor = s0 * v0 + s1 * v1
915
+
916
+ out: FloatTensor = slerped.where(can_slerp.unsqueeze(-1), out)
917
+
918
+ return out
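Since `linear_interpolation`, `spherical_interpolation` and `slerp` above are self-contained tensor utilities, they can be sanity-checked in isolation. The sketch below is illustrative only: the shapes are arbitrary and it simply exercises the `(size, *)` output contract stated in the docstrings.

```python
# Toy check of the interpolation helpers defined above; shapes are arbitrary.
import torch

from interpolation import linear_interpolation, spherical_interpolation

l1 = torch.randn(1, 8)  # starting vector, shape (1, *)
l2 = torch.randn(1, 8)  # final vector, same shape

lerp_path = linear_interpolation(l1, l2, size=5)      # (5, 8)
slerp_path = spherical_interpolation(l1, l2, size=5)  # (5, 8)

# Both paths keep the endpoints (up to float tolerance).
assert torch.allclose(lerp_path[0], l1[0]) and torch.allclose(lerp_path[-1], l2[0])
assert torch.allclose(slerp_path[0], l1[0], atol=1e-5)
print(lerp_path.shape, slerp_path.shape)
```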
pipeline_interpolated_sd.py ADDED
@@ -0,0 +1,1963 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
20
+ from diffusers.configuration_utils import FrozenDict
21
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
22
+ from diffusers.loaders import (
23
+ FromSingleFileMixin,
24
+ IPAdapterMixin,
25
+ TextualInversionLoaderMixin,
26
+ )
27
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ FusedAttnProcessor2_0,
30
+ )
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
33
+ StableDiffusionPipelineOutput,
34
+ )
35
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
36
+ StableDiffusionSafetyChecker,
37
+ )
38
+ from diffusers.schedulers import KarrasDiffusionSchedulers
39
+ from diffusers.utils import (
40
+ deprecate,
41
+ is_torch_xla_available,
42
+ logging,
43
+ replace_example_docstring,
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from packaging import version
47
+
48
+ from interpolation import (
49
+ InnerInterpolatedAttnProcessor,
50
+ InnerInterpolatedIPAttnProcessor,
51
+ OuterInterpolatedAttnProcessor,
52
+ OuterInterpolatedIPAttnProcessor,
53
+ ScaleControlIPAttnProcessor,
54
+ slerp,
55
+ )
56
+ from transformers import (
57
+ CLIPImageProcessor,
58
+ CLIPTextModel,
59
+ CLIPTokenizer,
60
+ CLIPVisionModelWithProjection,
61
+ )
62
+
63
+
64
+ if is_torch_xla_available():
65
+ import torch_xla.core.xla_model as xm # type: ignore
66
+
67
+ XLA_AVAILABLE = True
68
+ else:
69
+ XLA_AVAILABLE = False
70
+
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+ EXAMPLE_DOC_STRING = """
75
+ Examples:
76
+ ```py
77
+ >>> import torch
78
+ >>> from diffusers import StableDiffusionPipeline
79
+
80
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
81
+ ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
82
+ ... )
83
+ >>> pipe = pipe.to("cuda")
84
+
85
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
86
+ >>> image = pipe(prompt).images[0]
87
+ ```
88
+ """
89
+
90
+
91
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
92
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
93
+ """
94
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
95
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
96
+ """
97
+ std_text = noise_pred_text.std(
98
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
99
+ )
100
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
101
+ # rescale the results from guidance (fixes overexposure)
102
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
103
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
104
+ noise_cfg = (
105
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
106
+ )
107
+ return noise_cfg
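A quick sanity check of the rescaling above on dummy tensors, assuming the function is in scope; shapes and values are illustrative only.

```python
import torch

noise_pred_uncond = torch.randn(2, 4, 8, 8)   # unconditional branch
noise_pred_text = torch.randn(2, 4, 8, 8)     # text-conditioned branch
noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)  # standard CFG combine

out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
# The per-sample std of `out` now lies between std(noise_cfg) and std(noise_pred_text),
# which is what tames the overexposure described in the docstring.
assert out.shape == noise_cfg.shape
```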
108
+
109
+
110
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
111
+ def retrieve_timesteps(
112
+ scheduler,
113
+ num_inference_steps: Optional[int] = None,
114
+ device: Optional[Union[str, torch.device]] = None,
115
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
116
+ **kwargs,
117
+ ):
118
+ """
119
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
120
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
121
+
122
+ Args:
123
+ scheduler (`SchedulerMixin`):
124
+ The scheduler to get timesteps from.
125
+ num_inference_steps (`int`):
126
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
127
+ `timesteps` must be `None`.
128
+ device (`str` or `torch.device`, *optional*):
129
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
130
+ timesteps (`List[int]`, *optional*):
131
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
132
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
133
+ must be `None`.
134
+
135
+ Returns:
136
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
137
+ second element is the number of inference steps.
138
+ """
139
+ if timesteps is not None:
140
+ accepts_timesteps = "timesteps" in set(
141
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
142
+ )
143
+ if not accepts_timesteps:
144
+ raise ValueError(
145
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
146
+ f" timestep schedules. Please check whether you are using the correct scheduler."
147
+ )
148
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ num_inference_steps = len(timesteps)
151
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
+ )
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
152
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
153
+ timesteps = scheduler.timesteps
154
+ return timesteps, num_inference_steps
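A minimal usage sketch of the helper above, assuming it is in scope (e.g. imported from this module). It uses a default-configured DDIM scheduler; passing explicit `timesteps` or `sigmas` instead only works for schedulers whose `set_timesteps` accepts those arguments, otherwise the helper raises a `ValueError`.

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config, 1000 training timesteps
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
assert n == 30 and len(timesteps) == 30  # a descending 30-step schedule
```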
155
+
156
+
157
+ class StableDiffusionMixin:
158
+ r"""
159
+ Helper for DiffusionPipeline with a VAE and UNet (mainly for latent diffusion models such as Stable Diffusion).
160
+ """
161
+
162
+ def enable_vae_slicing(self):
163
+ r"""
164
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
165
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
166
+ """
167
+ self.vae.enable_slicing()
168
+
169
+ def disable_vae_slicing(self):
170
+ r"""
171
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
172
+ computing decoding in one step.
173
+ """
174
+ self.vae.disable_slicing()
175
+
176
+ def enable_vae_tiling(self):
177
+ r"""
178
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
179
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
180
+ processing larger images.
181
+ """
182
+ self.vae.enable_tiling()
183
+
184
+ def disable_vae_tiling(self):
185
+ r"""
186
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
187
+ computing decoding in one step.
188
+ """
189
+ self.vae.disable_tiling()
190
+
191
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
192
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
193
+
194
+ The suffixes after the scaling factors represent the stages where they are being applied.
195
+
196
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
197
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
198
+
199
+ Args:
200
+ s1 (`float`):
201
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
202
+ mitigate "oversmoothing effect" in the enhanced denoising process.
203
+ s2 (`float`):
204
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
205
+ mitigate "oversmoothing effect" in the enhanced denoising process.
206
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
207
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
208
+ """
209
+ if not hasattr(self, "unet"):
210
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
211
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
212
+
213
+ def disable_freeu(self):
214
+ """Disables the FreeU mechanism if enabled."""
215
+ self.unet.disable_freeu()
216
+
217
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
218
+ """
219
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
220
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
221
+
222
+ <Tip warning={true}>
223
+
224
+ This API is 🧪 experimental.
225
+
226
+ </Tip>
227
+
228
+ Args:
229
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
230
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
231
+ """
232
+ self.fusing_unet = False
233
+ self.fusing_vae = False
234
+
235
+ if unet:
236
+ self.fusing_unet = True
237
+ self.unet.fuse_qkv_projections()
238
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
239
+
240
+ if vae:
241
+ if not isinstance(self.vae, AutoencoderKL):
242
+ raise ValueError(
243
+ "`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`."
244
+ )
245
+
246
+ self.fusing_vae = True
247
+ self.vae.fuse_qkv_projections()
248
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
249
+
250
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
251
+ """Disable QKV projection fusion if enabled.
252
+
253
+ <Tip warning={true}>
254
+
255
+ This API is 🧪 experimental.
256
+
257
+ </Tip>
258
+
259
+ Args:
260
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
261
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
262
+
263
+ """
264
+ if unet:
265
+ if not self.fusing_unet:
266
+ logger.warning(
267
+ "The UNet was not initially fused for QKV projections. Doing nothing."
268
+ )
269
+ else:
270
+ self.unet.unfuse_qkv_projections()
271
+ self.fusing_unet = False
272
+
273
+ if vae:
274
+ if not self.fusing_vae:
275
+ logger.warning(
276
+ "The VAE was not initially fused for QKV projections. Doing nothing."
277
+ )
278
+ else:
279
+ self.vae.unfuse_qkv_projections()
280
+ self.fusing_vae = False
281
+
282
+
283
+ class InterpolationStableDiffusionPipeline(
284
+ DiffusionPipeline,
285
+ StableDiffusionMixin,
286
+ TextualInversionLoaderMixin,
287
+ IPAdapterMixin,
288
+ FromSingleFileMixin,
289
+ ):
290
+ r"""
291
+ Pipeline for text-to-image generation with attention interpolation (AID) using Stable Diffusion.
292
+
293
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
294
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
295
+
296
+ The pipeline also inherits the following loading methods:
297
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
298
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
299
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
300
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
301
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
302
+
303
+ Args:
304
+ vae ([`AutoencoderKL`]):
305
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
306
+ text_encoder ([`~transformers.CLIPTextModel`]):
307
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
308
+ tokenizer ([`~transformers.CLIPTokenizer`]):
309
+ A `CLIPTokenizer` to tokenize text.
310
+ unet ([`UNet2DConditionModel`]):
311
+ A `UNet2DConditionModel` to denoise the encoded image latents.
312
+ scheduler ([`SchedulerMixin`]):
313
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
314
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
315
+ safety_checker ([`StableDiffusionSafetyChecker`]):
316
+ Classification module that estimates whether generated images could be considered offensive or harmful.
317
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
318
+ about a model's potential harms.
319
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
320
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
321
+ """
322
+
323
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
324
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
325
+ _exclude_from_cpu_offload = ["safety_checker"]
326
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
327
+
328
+ def __init__(
329
+ self,
330
+ vae: AutoencoderKL,
331
+ text_encoder: CLIPTextModel,
332
+ tokenizer: CLIPTokenizer,
333
+ unet: UNet2DConditionModel,
334
+ scheduler: KarrasDiffusionSchedulers,
335
+ safety_checker: StableDiffusionSafetyChecker,
336
+ feature_extractor: CLIPImageProcessor,
337
+ image_encoder: CLIPVisionModelWithProjection = None,
338
+ requires_safety_checker: bool = True,
339
+ ):
340
+ super().__init__()
341
+
342
+ if (
343
+ hasattr(scheduler.config, "steps_offset")
344
+ and scheduler.config.steps_offset != 1
345
+ ):
346
+ deprecation_message = (
347
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
348
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
349
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
350
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
351
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
352
+ " file"
353
+ )
354
+ deprecate(
355
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
356
+ )
357
+ new_config = dict(scheduler.config)
358
+ new_config["steps_offset"] = 1
359
+ scheduler._internal_dict = FrozenDict(new_config)
360
+
361
+ if (
362
+ hasattr(scheduler.config, "clip_sample")
363
+ and scheduler.config.clip_sample is True
364
+ ):
365
+ deprecation_message = (
366
+ f"The configuration file of this scheduler: {scheduler} has set `clip_sample` to True."
367
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
368
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
369
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
370
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
371
+ )
372
+ deprecate(
373
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
374
+ )
375
+ new_config = dict(scheduler.config)
376
+ new_config["clip_sample"] = False
377
+ scheduler._internal_dict = FrozenDict(new_config)
378
+
379
+ if safety_checker is None and requires_safety_checker:
380
+ logger.warning(
381
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
382
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
383
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
384
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
385
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
386
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
387
+ )
388
+
389
+ if safety_checker is not None and feature_extractor is None:
390
+ raise ValueError(
391
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
392
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
393
+ )
394
+
395
+ is_unet_version_less_0_9_0 = hasattr(
396
+ unet.config, "_diffusers_version"
397
+ ) and version.parse(
398
+ version.parse(unet.config._diffusers_version).base_version
399
+ ) < version.parse(
400
+ "0.9.0.dev0"
401
+ )
402
+ is_unet_sample_size_less_64 = (
403
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
404
+ )
405
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
406
+ deprecation_message = (
407
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
408
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
409
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
410
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
411
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
412
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
413
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
414
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
415
+ " the `unet/config.json` file"
416
+ )
417
+ deprecate(
418
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
419
+ )
420
+ new_config = dict(unet.config)
421
+ new_config["sample_size"] = 64
422
+ unet._internal_dict = FrozenDict(new_config)
423
+
424
+ self.register_modules(
425
+ vae=vae,
426
+ text_encoder=text_encoder,
427
+ tokenizer=tokenizer,
428
+ unet=unet,
429
+ scheduler=scheduler,
430
+ safety_checker=safety_checker,
431
+ feature_extractor=feature_extractor,
432
+ image_encoder=image_encoder,
433
+ )
434
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
435
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
436
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
437
+
438
+ self.load_aid()
439
+
440
+ def _encode_prompt(
441
+ self,
442
+ prompt,
443
+ device,
444
+ num_images_per_prompt,
445
+ do_classifier_free_guidance,
446
+ negative_prompt=None,
447
+ prompt_embeds: Optional[torch.Tensor] = None,
448
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
449
+ lora_scale: Optional[float] = None,
450
+ **kwargs,
451
+ ):
452
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
453
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ prompt_embeds_tuple = self.encode_prompt(
456
+ prompt=prompt,
457
+ device=device,
458
+ num_images_per_prompt=num_images_per_prompt,
459
+ do_classifier_free_guidance=do_classifier_free_guidance,
460
+ negative_prompt=negative_prompt,
461
+ prompt_embeds=prompt_embeds,
462
+ negative_prompt_embeds=negative_prompt_embeds,
463
+ lora_scale=lora_scale,
464
+ **kwargs,
465
+ )
466
+
467
+ # concatenate for backwards comp
468
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
469
+
470
+ return prompt_embeds
471
+
472
+ def encode_prompt(
473
+ self,
474
+ prompt,
475
+ device,
476
+ num_images_per_prompt,
477
+ do_classifier_free_guidance,
478
+ negative_prompt=None,
479
+ prompt_embeds: Optional[torch.Tensor] = None,
480
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
481
+ lora_scale: Optional[float] = None,
482
+ clip_skip: Optional[int] = None,
483
+ ):
484
+ r"""
485
+ Encodes the prompt into text encoder hidden states.
486
+
487
+ Args:
488
+ prompt (`str` or `List[str]`, *optional*):
489
+ prompt to be encoded
490
+ device: (`torch.device`):
491
+ torch device
492
+ num_images_per_prompt (`int`):
493
+ number of images that should be generated per prompt
494
+ do_classifier_free_guidance (`bool`):
495
+ whether to use classifier free guidance or not
496
+ negative_prompt (`str` or `List[str]`, *optional*):
497
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
498
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
499
+ less than `1`).
500
+ prompt_embeds (`torch.Tensor`, *optional*):
501
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
502
+ provided, text embeddings will be generated from `prompt` input argument.
503
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
504
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
505
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
506
+ argument.
507
+ lora_scale (`float`, *optional*):
508
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
509
+ clip_skip (`int`, *optional*):
510
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
511
+ the output of the pre-final layer will be used for computing the prompt embeddings.
512
+ """
513
+
514
+ if prompt is not None and isinstance(prompt, str):
515
+ batch_size = 1
516
+ elif prompt is not None and isinstance(prompt, list):
517
+ batch_size = len(prompt)
518
+ else:
519
+ batch_size = prompt_embeds.shape[0]
520
+
521
+ if prompt_embeds is None:
522
+ # textual inversion: process multi-vector tokens if necessary
523
+ if isinstance(self, TextualInversionLoaderMixin):
524
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
525
+
526
+ text_inputs = self.tokenizer(
527
+ prompt,
528
+ padding="max_length",
529
+ max_length=self.tokenizer.model_max_length,
530
+ truncation=True,
531
+ return_tensors="pt",
532
+ )
533
+ text_input_ids = text_inputs.input_ids
534
+ untruncated_ids = self.tokenizer(
535
+ prompt, padding="longest", return_tensors="pt"
536
+ ).input_ids
537
+
538
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
539
+ -1
540
+ ] and not torch.equal(text_input_ids, untruncated_ids):
541
+ removed_text = self.tokenizer.batch_decode(
542
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
543
+ )
544
+ logger.warning(
545
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
546
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
547
+ )
548
+
549
+ if (
550
+ hasattr(self.text_encoder.config, "use_attention_mask")
551
+ and self.text_encoder.config.use_attention_mask
552
+ ):
553
+ attention_mask = text_inputs.attention_mask.to(device)
554
+ else:
555
+ attention_mask = None
556
+
557
+ if clip_skip is None:
558
+ prompt_embeds = self.text_encoder(
559
+ text_input_ids.to(device), attention_mask=attention_mask
560
+ )
561
+ prompt_embeds = prompt_embeds[0]
562
+ else:
563
+ prompt_embeds = self.text_encoder(
564
+ text_input_ids.to(device),
565
+ attention_mask=attention_mask,
566
+ output_hidden_states=True,
567
+ )
568
+ # Access the `hidden_states` first, that contains a tuple of
569
+ # all the hidden states from the encoder layers. Then index into
570
+ # the tuple to access the hidden states from the desired layer.
571
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
572
+ # We also need to apply the final LayerNorm here to not mess with the
573
+ # representations. The `last_hidden_states` that we typically use for
574
+ # obtaining the final prompt representations passes through the LayerNorm
575
+ # layer.
576
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(
577
+ prompt_embeds
578
+ )
579
+
580
+ if self.text_encoder is not None:
581
+ prompt_embeds_dtype = self.text_encoder.dtype
582
+ elif self.unet is not None:
583
+ prompt_embeds_dtype = self.unet.dtype
584
+ else:
585
+ prompt_embeds_dtype = prompt_embeds.dtype
586
+
587
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
588
+
589
+ bs_embed, seq_len, _ = prompt_embeds.shape
590
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
591
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
592
+ prompt_embeds = prompt_embeds.view(
593
+ bs_embed * num_images_per_prompt, seq_len, -1
594
+ )
595
+
596
+ # get unconditional embeddings for classifier free guidance
597
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
598
+ uncond_tokens: List[str]
599
+ if negative_prompt is None:
600
+ uncond_tokens = [""] * batch_size
601
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
602
+ raise TypeError(
603
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
604
+ f" {type(prompt)}."
605
+ )
606
+ elif isinstance(negative_prompt, str):
607
+ uncond_tokens = [negative_prompt]
608
+ elif batch_size != len(negative_prompt):
609
+ raise ValueError(
610
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
611
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
612
+ " the batch size of `prompt`."
613
+ )
614
+ else:
615
+ uncond_tokens = negative_prompt
616
+
617
+ # textual inversion: process multi-vector tokens if necessary
618
+ if isinstance(self, TextualInversionLoaderMixin):
619
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
620
+
621
+ max_length = prompt_embeds.shape[1]
622
+ uncond_input = self.tokenizer(
623
+ uncond_tokens,
624
+ padding="max_length",
625
+ max_length=max_length,
626
+ truncation=True,
627
+ return_tensors="pt",
628
+ )
629
+
630
+ if (
631
+ hasattr(self.text_encoder.config, "use_attention_mask")
632
+ and self.text_encoder.config.use_attention_mask
633
+ ):
634
+ attention_mask = uncond_input.attention_mask.to(device)
635
+ else:
636
+ attention_mask = None
637
+
638
+ negative_prompt_embeds = self.text_encoder(
639
+ uncond_input.input_ids.to(device),
640
+ attention_mask=attention_mask,
641
+ )
642
+ negative_prompt_embeds = negative_prompt_embeds[0]
643
+
644
+ if do_classifier_free_guidance:
645
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
646
+ seq_len = negative_prompt_embeds.shape[1]
647
+
648
+ negative_prompt_embeds = negative_prompt_embeds.to(
649
+ dtype=prompt_embeds_dtype, device=device
650
+ )
651
+
652
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
653
+ 1, num_images_per_prompt, 1
654
+ )
655
+ negative_prompt_embeds = negative_prompt_embeds.view(
656
+ batch_size * num_images_per_prompt, seq_len, -1
657
+ )
658
+
659
+ return prompt_embeds, negative_prompt_embeds
660
+
661
+ def encode_image(
662
+ self, image, device, num_images_per_prompt, output_hidden_states=None
663
+ ):
664
+ dtype = next(self.image_encoder.parameters()).dtype
665
+
666
+ if not isinstance(image, torch.Tensor):
667
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
668
+
669
+ image = image.to(device=device, dtype=dtype)
670
+ if output_hidden_states:
671
+ image_enc_hidden_states = self.image_encoder(
672
+ image, output_hidden_states=True
673
+ ).hidden_states[-2]
674
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
675
+ num_images_per_prompt, dim=0
676
+ )
677
+ uncond_image_enc_hidden_states = self.image_encoder(
678
+ torch.zeros_like(image), output_hidden_states=True
679
+ ).hidden_states[-2]
680
+ uncond_image_enc_hidden_states = (
681
+ uncond_image_enc_hidden_states.repeat_interleave(
682
+ num_images_per_prompt, dim=0
683
+ )
684
+ )
685
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
686
+ else:
687
+ image_embeds = self.image_encoder(image).image_embeds
688
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
689
+ uncond_image_embeds = torch.zeros_like(image_embeds)
690
+
691
+ return image_embeds, uncond_image_embeds
692
+
693
+ def prepare_ip_adapter_image_embeds(
694
+ self,
695
+ ip_adapter_image,
696
+ ip_adapter_image_embeds,
697
+ device,
698
+ num_images_per_prompt,
699
+ do_classifier_free_guidance,
700
+ ):
701
+ image_embeds = []
702
+ if do_classifier_free_guidance:
703
+ negative_image_embeds = []
704
+ if ip_adapter_image_embeds is None:
705
+ if not isinstance(ip_adapter_image, list):
706
+ ip_adapter_image = [ip_adapter_image]
707
+
708
+ if len(ip_adapter_image) != len(
709
+ self.unet.encoder_hid_proj.image_projection_layers
710
+ ):
711
+ raise ValueError(
712
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
713
+ )
714
+
715
+ for single_ip_adapter_image, image_proj_layer in zip(
716
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
717
+ ):
718
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
719
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
720
+ single_ip_adapter_image, device, 1, output_hidden_state
721
+ )
722
+
723
+ image_embeds.append(single_image_embeds[None, :])
724
+ if do_classifier_free_guidance:
725
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
726
+ else:
727
+ for single_image_embeds in ip_adapter_image_embeds:
728
+ if do_classifier_free_guidance:
729
+ single_negative_image_embeds, single_image_embeds = (
730
+ single_image_embeds.chunk(2)
731
+ )
732
+ negative_image_embeds.append(single_negative_image_embeds)
733
+ image_embeds.append(single_image_embeds)
734
+
735
+ ip_adapter_image_embeds = []
736
+ for i, single_image_embeds in enumerate(image_embeds):
737
+ single_image_embeds = torch.cat(
738
+ [single_image_embeds] * num_images_per_prompt, dim=0
739
+ )
740
+ if do_classifier_free_guidance:
741
+ single_negative_image_embeds = torch.cat(
742
+ [negative_image_embeds[i]] * num_images_per_prompt, dim=0
743
+ )
744
+ single_image_embeds = torch.cat(
745
+ [single_negative_image_embeds, single_image_embeds], dim=0
746
+ )
747
+
748
+ single_image_embeds = single_image_embeds.to(device=device)
749
+ ip_adapter_image_embeds.append(single_image_embeds)
750
+
751
+ return ip_adapter_image_embeds
752
+
753
+ def run_safety_checker(self, image, device, dtype):
754
+ if self.safety_checker is None:
755
+ has_nsfw_concept = None
756
+ else:
757
+ if torch.is_tensor(image):
758
+ feature_extractor_input = self.image_processor.postprocess(
759
+ image, output_type="pil"
760
+ )
761
+ else:
762
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
763
+ safety_checker_input = self.feature_extractor(
764
+ feature_extractor_input, return_tensors="pt"
765
+ ).to(device)
766
+ image, has_nsfw_concept = self.safety_checker(
767
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
768
+ )
769
+ return image, has_nsfw_concept
770
+
771
+ def decode_latents(self, latents):
772
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
773
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
774
+
775
+ latents = 1 / self.vae.config.scaling_factor * latents
776
+ image = self.vae.decode(latents, return_dict=False)[0]
777
+ image = (image / 2 + 0.5).clamp(0, 1)
778
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
779
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
780
+ return image
781
+
782
+ def prepare_extra_step_kwargs(self, generator, eta):
783
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
784
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
785
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
786
+ # and should be between [0, 1]
787
+
788
+ accepts_eta = "eta" in set(
789
+ inspect.signature(self.scheduler.step).parameters.keys()
790
+ )
791
+ extra_step_kwargs = {}
792
+ if accepts_eta:
793
+ extra_step_kwargs["eta"] = eta
794
+
795
+ # check if the scheduler accepts generator
796
+ accepts_generator = "generator" in set(
797
+ inspect.signature(self.scheduler.step).parameters.keys()
798
+ )
799
+ if accepts_generator:
800
+ extra_step_kwargs["generator"] = generator
801
+ return extra_step_kwargs
802
+
803
+ def check_inputs(
804
+ self,
805
+ prompt,
806
+ height,
807
+ width,
808
+ callback_steps,
809
+ negative_prompt=None,
810
+ prompt_embeds=None,
811
+ negative_prompt_embeds=None,
812
+ ip_adapter_image=None,
813
+ ip_adapter_image_embeds=None,
814
+ callback_on_step_end_tensor_inputs=None,
815
+ ):
816
+ if height % 8 != 0 or width % 8 != 0:
817
+ raise ValueError(
818
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
819
+ )
820
+
821
+ if callback_steps is not None and (
822
+ not isinstance(callback_steps, int) or callback_steps <= 0
823
+ ):
824
+ raise ValueError(
825
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
826
+ f" {type(callback_steps)}."
827
+ )
828
+ if callback_on_step_end_tensor_inputs is not None and not all(
829
+ k in self._callback_tensor_inputs
830
+ for k in callback_on_step_end_tensor_inputs
831
+ ):
832
+ raise ValueError(
833
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
834
+ )
835
+
836
+ if prompt is not None and prompt_embeds is not None:
837
+ raise ValueError(
838
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
839
+ " only forward one of the two."
840
+ )
841
+ elif prompt is None and prompt_embeds is None:
842
+ raise ValueError(
843
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
844
+ )
845
+ elif prompt is not None and (
846
+ not isinstance(prompt, str) and not isinstance(prompt, list)
847
+ ):
848
+ raise ValueError(
849
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
850
+ )
851
+
852
+ if negative_prompt is not None and negative_prompt_embeds is not None:
853
+ raise ValueError(
854
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
855
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
856
+ )
857
+
858
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
859
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
860
+ raise ValueError(
861
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
862
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
863
+ f" {negative_prompt_embeds.shape}."
864
+ )
865
+
866
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
867
+ raise ValueError(
868
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
869
+ )
870
+
871
+ if ip_adapter_image_embeds is not None:
872
+ if not isinstance(ip_adapter_image_embeds, list):
873
+ raise ValueError(
874
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
875
+ )
876
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
877
+ raise ValueError(
878
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
879
+ )
880
+
881
+ def prepare_latents(
882
+ self,
883
+ batch_size,
884
+ num_channels_latents,
885
+ height,
886
+ width,
887
+ dtype,
888
+ device,
889
+ generator,
890
+ latents=None,
891
+ ):
892
+ shape = (
893
+ batch_size,
894
+ num_channels_latents,
895
+ int(height) // self.vae_scale_factor,
896
+ int(width) // self.vae_scale_factor,
897
+ )
898
+ if isinstance(generator, list) and len(generator) != batch_size:
899
+ raise ValueError(
900
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
901
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
902
+ )
903
+
904
+ if latents is None:
905
+ latents = randn_tensor(
906
+ shape, generator=generator, device=device, dtype=dtype
907
+ )
908
+ else:
909
+ latents = latents.to(device)
910
+
911
+ # scale the initial noise by the standard deviation required by the scheduler
912
+ latents = latents * self.scheduler.init_noise_sigma
913
+ return latents
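As a concrete instance of the shape computation above, a standalone sketch with plain tensors; note that `init_noise_sigma` is 1.0 for DDIM/PNDM-style schedulers and larger for Karras-style ones, so the final scaling is a no-op only in the former case.

```python
import torch

batch_size, num_channels_latents, vae_scale_factor = 2, 4, 8
height = width = 512
shape = (
    batch_size,
    num_channels_latents,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
latents = torch.randn(shape)   # stand-in for randn_tensor(...)
latents = latents * 1.0        # times scheduler.init_noise_sigma (1.0 here)
assert latents.shape == (2, 4, 64, 64)
```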
914
+
915
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
916
+ def get_guidance_scale_embedding(
917
+ self,
918
+ w: torch.Tensor,
919
+ embedding_dim: int = 512,
920
+ dtype: torch.dtype = torch.float32,
921
+ ) -> torch.Tensor:
922
+ """
923
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
924
+
925
+ Args:
926
+ w (`torch.Tensor`):
927
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
928
+ embedding_dim (`int`, *optional*, defaults to 512):
929
+ Dimension of the embeddings to generate.
930
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
931
+ Data type of the generated embeddings.
932
+
933
+ Returns:
934
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
935
+ """
936
+ assert len(w.shape) == 1
937
+ w = w * 1000.0
938
+
939
+ half_dim = embedding_dim // 2
940
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
941
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
942
+ emb = w.to(dtype)[:, None] * emb[None, :]
943
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
944
+ if embedding_dim % 2 == 1: # zero pad
945
+ emb = torch.nn.functional.pad(emb, (0, 1))
946
+ assert emb.shape == (w.shape[0], embedding_dim)
947
+ return emb
948
+
949
+ # load interpolated attention processor
950
+ def load_aid(
951
+ self, t: Optional[float] = 0.5, is_fused: bool = True, atype="fused_outer"
952
+ ):
953
+ attn_procs = {}
954
+ for name in self.unet.attn_processors.keys():
955
+ if not name.startswith("encoder"):
956
+ if atype == "fused_outer":
957
+ attn_procs[name] = OuterInterpolatedAttnProcessor(
958
+ t=t,
959
+ is_fused=is_fused,
960
+ original_attn=self.unet.attn_processors[name],
961
+ )
962
+ elif atype == "fused_inner":
963
+ attn_procs[name] = InnerInterpolatedAttnProcessor(
964
+ t=t,
965
+ is_fused=is_fused,
966
+ original_attn=self.unet.attn_processors[name],
967
+ )
968
+ else:
969
+ attn_procs[name] = self.unet.attn_processors[name]
970
+ self.unet.set_attn_processor(attn_procs)
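A minimal sketch of installing the interpolated attention processors, assuming the standard SD 1.5 checkpoint layout and a CUDA device; this is not necessarily how app.py wires the pipeline up.

```python
import torch
from pipeline_interpolated_sd import InterpolationStableDiffusionPipeline

pipe = InterpolationStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# __init__ already calls load_aid() with the defaults (fused outer interpolation at t=0.5);
# calling it again swaps in a different processor variant and blend point.
pipe.load_aid(t=0.3, is_fused=True, atype="fused_inner")
```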
971
+
972
+ # load customized ip_adapter
973
+ def load_aid_ip_adapter(
974
+ self,
975
+ pretrained_model_name_or_path_or_dict: Union[
976
+ str, List[str], Dict[str, torch.Tensor]
977
+ ],
978
+ subfolder: Union[str, List[str]],
979
+ weight_name: Union[str, List[str]],
980
+ t: Optional[float] = 0.5,
981
+ is_fused: bool = True,
982
+ image_encoder_folder: Optional[str] = "image_encoder",
983
+ early="fused_outer",
984
+ **kwargs,
985
+ ):
986
+ self.load_ip_adapter(
987
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
988
+ subfolder=subfolder,
989
+ weight_name=weight_name,
990
+ image_encoder_folder=image_encoder_folder,
991
+ **kwargs,
992
+ )
993
+ attn_procs = {}
994
+ for name in self.unet.attn_processors.keys():
995
+ if not name.startswith("encoder"):
996
+ if early == "fused_outer":
997
+ attn_procs[name] = OuterInterpolatedIPAttnProcessor(
998
+ t=t, is_fused=is_fused, ip_attn=self.unet.attn_processors[name]
999
+ )
1000
+ elif early == "fused_inner":
1001
+ attn_procs[name] = InnerInterpolatedIPAttnProcessor(
1002
+ t=t, is_fused=is_fused, ip_attn=self.unet.attn_processors[name]
1003
+ )
1004
+ elif early == "scale_control":
1005
+ attn_procs[name] = ScaleControlIPAttnProcessor(
1006
+ t=t, is_fused=is_fused, ip_attn=self.unet.attn_processors[name]
1007
+ )
1008
+ else:
1009
+ attn_procs[name] = self.unet.attn_processors[name]
1010
+ self.unet.set_attn_processor(attn_procs)
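For image-conditioned interpolation, the wrapper above first delegates to diffusers' `load_ip_adapter` and then swaps in the IP-aware processors. A hedged sketch continuing the `pipe` example above; the `h94/IP-Adapter` checkpoint layout is a common community choice and an assumption here, not something this file prescribes.

```python
pipe.load_aid_ip_adapter(
    "h94/IP-Adapter",               # assumed community IP-Adapter repo for SD 1.5
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    t=0.5,
    is_fused=True,
    early="scale_control",          # or "fused_outer" / "fused_inner"
)
```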
1011
+
1012
+ def activate_aid(self, it: float):
1013
+ for name in self.unet.attn_processors.keys():
1014
+ if not name.startswith("encoder"):
1015
+ self.unet.attn_processors[name].activate(it)
1016
+
1017
+ def deactivate_aid(self):
1018
+ for name in self.unet.attn_processors.keys():
1019
+ if not name.startswith("encoder"):
1020
+ self.unet.attn_processors[name].deactivate()
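`activate_aid` pushes an interpolation coefficient into every non-encoder attention processor, and `deactivate_aid` calls each processor's `deactivate()` hook (presumably restoring plain attention). A hedged toggling sketch, continuing the `pipe` example above:

```python
for it in (0.25, 0.5, 0.75):   # interpolation coefficients between the two endpoints
    pipe.activate_aid(it)      # enable interpolated attention at this coefficient
    # ... run a denoising pass here, e.g. via interpolate_single(...) ...
    pipe.deactivate_aid()      # switch the processors back to ordinary attention
```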
1021
+
1022
+ @property
1023
+ def guidance_scale(self):
1024
+ return self._guidance_scale
1025
+
1026
+ @property
1027
+ def guidance_rescale(self):
1028
+ return self._guidance_rescale
1029
+
1030
+ @property
1031
+ def clip_skip(self):
1032
+ return self._clip_skip
1033
+
1034
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1035
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1036
+ # corresponds to doing no classifier free guidance.
1037
+ @property
1038
+ def do_classifier_free_guidance(self):
1039
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1040
+
1041
+ @property
1042
+ def cross_attention_kwargs(self):
1043
+ return self._cross_attention_kwargs
1044
+
1045
+ @property
1046
+ def num_timesteps(self):
1047
+ return self._num_timesteps
1048
+
1049
+ @property
1050
+ def interrupt(self):
1051
+ return self._interrupt
1052
+
1053
+ @torch.no_grad()
1054
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1055
+ def __call__(
1056
+ self,
1057
+ prompt: Union[str, List[str]] = None,
1058
+ height: Optional[int] = None,
1059
+ width: Optional[int] = None,
1060
+ num_inference_steps: int = 50,
1061
+ timesteps: List[int] = None,
1062
+ sigmas: List[float] = None,
1063
+ guidance_scale: float = 7.5,
1064
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1065
+ num_images_per_prompt: Optional[int] = 1,
1066
+ eta: float = 0.0,
1067
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1068
+ latents: Optional[torch.Tensor] = None,
1069
+ prompt_embeds: Optional[torch.Tensor] = None,
1070
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1071
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1072
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1073
+ output_type: Optional[str] = "pil",
1074
+ return_dict: bool = True,
1075
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1076
+ guidance_rescale: float = 0.0,
1077
+ clip_skip: Optional[int] = None,
1078
+ callback_on_step_end: Optional[
1079
+ Union[
1080
+ Callable[[int, int, Dict], None],
1081
+ PipelineCallback,
1082
+ MultiPipelineCallbacks,
1083
+ ]
1084
+ ] = None,
1085
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1086
+ **kwargs,
1087
+ ):
1088
+ r"""
1089
+ The call function to the pipeline for generation.
1090
+
1091
+ Args:
1092
+ prompt (`str` or `List[str]`, *optional*):
1093
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1094
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1095
+ The height in pixels of the generated image.
1096
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1097
+ The width in pixels of the generated image.
1098
+ num_inference_steps (`int`, *optional*, defaults to 50):
1099
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1100
+ expense of slower inference.
1101
+ timesteps (`List[int]`, *optional*):
1102
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1103
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1104
+ passed will be used. Must be in descending order.
1105
+ sigmas (`List[float]`, *optional*):
1106
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1107
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1108
+ will be used.
1109
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1110
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1111
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1112
+ negative_prompt (`str` or `List[str]`, *optional*):
1113
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1114
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1115
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1116
+ The number of images to generate per prompt.
1117
+ eta (`float`, *optional*, defaults to 0.0):
1118
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1119
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1120
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1121
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1122
+ generation deterministic.
1123
+ latents (`torch.Tensor`, *optional*):
1124
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1125
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1126
+ tensor is generated by sampling using the supplied random `generator`.
1127
+ prompt_embeds (`torch.Tensor`, *optional*):
1128
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1129
+ provided, text embeddings are generated from the `prompt` input argument.
1130
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1131
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1132
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1133
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1134
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1135
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
1136
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1137
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1138
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1139
+ output_type (`str`, *optional*, defaults to `"pil"`):
1140
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1141
+ return_dict (`bool`, *optional*, defaults to `True`):
1142
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1143
+ plain tuple.
1144
+ cross_attention_kwargs (`dict`, *optional*):
1145
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1146
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1147
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1148
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
1149
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
1150
+ using zero terminal SNR.
1151
+ clip_skip (`int`, *optional*):
1152
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1153
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1154
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1155
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1156
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1157
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1158
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1159
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1160
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1161
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1162
+ `._callback_tensor_inputs` attribute of your pipeline class.
1163
+
1164
+ Examples:
1165
+
1166
+ Returns:
1167
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1168
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1169
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1170
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1171
+ "not-safe-for-work" (nsfw) content.
1172
+ """
1173
+
1174
+ callback = kwargs.pop("callback", None)
1175
+ callback_steps = kwargs.pop("callback_steps", None)
1176
+
1177
+ if callback is not None:
1178
+ deprecate(
1179
+ "callback",
1180
+ "1.0.0",
1181
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1182
+ )
1183
+ if callback_steps is not None:
1184
+ deprecate(
1185
+ "callback_steps",
1186
+ "1.0.0",
1187
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1188
+ )
1189
+
1190
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1191
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1192
+
1193
+ # 0. Default height and width to unet
1194
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1195
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1196
+ # to deal with lora scaling and other possible forward hooks
1197
+
1198
+ # 1. Check inputs. Raise error if not correct
1199
+ self.check_inputs(
1200
+ prompt,
1201
+ height,
1202
+ width,
1203
+ callback_steps,
1204
+ negative_prompt,
1205
+ prompt_embeds,
1206
+ negative_prompt_embeds,
1207
+ ip_adapter_image,
1208
+ ip_adapter_image_embeds,
1209
+ callback_on_step_end_tensor_inputs,
1210
+ )
1211
+
1212
+ self._guidance_scale = guidance_scale
1213
+ self._guidance_rescale = guidance_rescale
1214
+ self._clip_skip = clip_skip
1215
+ self._cross_attention_kwargs = cross_attention_kwargs
1216
+ self._interrupt = False
1217
+
1218
+ # 2. Define call parameters
1219
+ if prompt is not None and isinstance(prompt, str):
1220
+ batch_size = 1
1221
+ elif prompt is not None and isinstance(prompt, list):
1222
+ batch_size = len(prompt)
1223
+ else:
1224
+ batch_size = prompt_embeds.shape[0]
1225
+
1226
+ device = self._execution_device
1227
+
1228
+ # 3. Encode input prompt
1229
+ lora_scale = (
1230
+ self.cross_attention_kwargs.get("scale", None)
1231
+ if self.cross_attention_kwargs is not None
1232
+ else None
1233
+ )
1234
+
1235
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1236
+ prompt,
1237
+ device,
1238
+ num_images_per_prompt,
1239
+ self.do_classifier_free_guidance,
1240
+ negative_prompt,
1241
+ prompt_embeds=prompt_embeds,
1242
+ negative_prompt_embeds=negative_prompt_embeds,
1243
+ lora_scale=lora_scale,
1244
+ clip_skip=self.clip_skip,
1245
+ )
1246
+
1247
+ # For classifier free guidance, we need to do two forward passes.
1248
+ # Here we concatenate the unconditional and text embeddings into a single batch
1249
+ # to avoid doing two forward passes
1250
+ if self.do_classifier_free_guidance:
1251
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1252
+
1253
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1254
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1255
+ ip_adapter_image,
1256
+ ip_adapter_image_embeds,
1257
+ device,
1258
+ batch_size * num_images_per_prompt,
1259
+ self.do_classifier_free_guidance,
1260
+ )
1261
+
1262
+ # 4. Prepare timesteps
1263
+ timesteps, num_inference_steps = retrieve_timesteps(
1264
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1265
+ )
1266
+
1267
+ # 5. Prepare latent variables
1268
+ num_channels_latents = self.unet.config.in_channels
1269
+ latents = self.prepare_latents(
1270
+ batch_size * num_images_per_prompt,
1271
+ num_channels_latents,
1272
+ height,
1273
+ width,
1274
+ prompt_embeds.dtype,
1275
+ device,
1276
+ generator,
1277
+ latents,
1278
+ )
1279
+
1280
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1281
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1282
+
1283
+ # 6.1 Add image embeds for IP-Adapter
1284
+ added_cond_kwargs = (
1285
+ {"image_embeds": image_embeds}
1286
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1287
+ else None
1288
+ )
1289
+
1290
+ # 6.2 Optionally get Guidance Scale Embedding
1291
+ timestep_cond = None
1292
+ if self.unet.config.time_cond_proj_dim is not None:
1293
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
1294
+ batch_size * num_images_per_prompt
1295
+ )
1296
+ timestep_cond = self.get_guidance_scale_embedding(
1297
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1298
+ ).to(device=device, dtype=latents.dtype)
1299
+
1300
+ # 7. Denoising loop
1301
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1302
+ self._num_timesteps = len(timesteps)
1303
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1304
+ for i, t in enumerate(timesteps):
1305
+ if self.interrupt:
1306
+ continue
1307
+
1308
+ # expand the latents if we are doing classifier free guidance
1309
+ latent_model_input = (
1310
+ torch.cat([latents] * 2)
1311
+ if self.do_classifier_free_guidance
1312
+ else latents
1313
+ )
1314
+ latent_model_input = self.scheduler.scale_model_input(
1315
+ latent_model_input, t
1316
+ )
1317
+
1318
+ # predict the noise residual
1319
+ noise_pred = self.unet(
1320
+ latent_model_input,
1321
+ t,
1322
+ encoder_hidden_states=prompt_embeds,
1323
+ timestep_cond=timestep_cond,
1324
+ cross_attention_kwargs=self.cross_attention_kwargs,
1325
+ added_cond_kwargs=added_cond_kwargs,
1326
+ return_dict=False,
1327
+ )[0]
1328
+
1329
+ # perform guidance
1330
+ if self.do_classifier_free_guidance:
1331
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1332
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1333
+ noise_pred_text - noise_pred_uncond
1334
+ )
1335
+
1336
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1337
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1338
+ noise_pred = rescale_noise_cfg(
1339
+ noise_pred,
1340
+ noise_pred_text,
1341
+ guidance_rescale=self.guidance_rescale,
1342
+ )
1343
+
1344
+ # compute the previous noisy sample x_t -> x_t-1
1345
+ latents = self.scheduler.step(
1346
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1347
+ )[0]
1348
+
1349
+ if callback_on_step_end is not None:
1350
+ callback_kwargs = {}
1351
+ for k in callback_on_step_end_tensor_inputs:
1352
+ callback_kwargs[k] = locals()[k]
1353
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1354
+
1355
+ latents = callback_outputs.pop("latents", latents)
1356
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1357
+ negative_prompt_embeds = callback_outputs.pop(
1358
+ "negative_prompt_embeds", negative_prompt_embeds
1359
+ )
1360
+
1361
+ # call the callback, if provided
1362
+ if i == len(timesteps) - 1 or (
1363
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1364
+ ):
1365
+ progress_bar.update()
1366
+ if callback is not None and i % callback_steps == 0:
1367
+ step_idx = i // getattr(self.scheduler, "order", 1)
1368
+ callback(step_idx, t, latents)
1369
+
1370
+ if XLA_AVAILABLE:
1371
+ xm.mark_step()
1372
+
1373
+ if not output_type == "latent":
1374
+ image = self.vae.decode(
1375
+ latents / self.vae.config.scaling_factor,
1376
+ return_dict=False,
1377
+ generator=generator,
1378
+ )[0]
1379
+ image, has_nsfw_concept = self.run_safety_checker(
1380
+ image, device, prompt_embeds.dtype
1381
+ )
1382
+ else:
1383
+ image = latents
1384
+ has_nsfw_concept = None
1385
+
1386
+ if has_nsfw_concept is None:
1387
+ do_denormalize = [True] * image.shape[0]
1388
+ else:
1389
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1390
+
1391
+ image = self.image_processor.postprocess(
1392
+ image, output_type=output_type, do_denormalize=do_denormalize
1393
+ )
1394
+
1395
+ # Offload all models
1396
+ self.maybe_free_model_hooks()
1397
+
1398
+ if not return_dict:
1399
+ return (image, has_nsfw_concept)
1400
+
1401
+ return StableDiffusionPipelineOutput(
1402
+ images=image, nsfw_content_detected=has_nsfw_concept
1403
+ )
1404
+
1405
+ @torch.no_grad()
1406
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1407
+ def interpolate_single(
1408
+ self,
1409
+ it: float = 0.5,
1410
+ prompt_start: Optional[str] = None,
1411
+ prompt_end: Optional[str] = None,
1412
+ latent_start: Optional[torch.FloatTensor] = None,
1413
+ latent_end: Optional[torch.FloatTensor] = None,
1414
+ image_start: Optional[PipelineImageInput] = None,
1415
+ image_end: Optional[PipelineImageInput] = None,
1416
+ guide_prompt: Optional[str] = None,
1417
+ warmup_ratio: float = 0.5,
1418
+ is_fused: bool = True,
1419
+ atype: str = "outer",
1420
+ init: str = "linear",
1421
+ height: Optional[int] = None,
1422
+ width: Optional[int] = None,
1423
+ num_inference_steps: int = 50,
1424
+ timesteps: List[int] = None,
1425
+ sigmas: List[float] = None,
1426
+ guidance_scale: float = 7.5,
1427
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1428
+ num_images_per_prompt: Optional[int] = 1,
1429
+ eta: float = 0.0,
1430
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1431
+ latents: Optional[torch.FloatTensor] = None,
1432
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1433
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1434
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1435
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1436
+ output_type: Optional[str] = "pil",
1437
+ return_dict: bool = True,
1438
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1439
+ guidance_rescale: float = 0.0,
1440
+ clip_skip: Optional[int] = None,
1441
+ callback_on_step_end: Optional[
1442
+ Union[
1443
+ Callable[[int, int, Dict], None],
1444
+ PipelineCallback,
1445
+ MultiPipelineCallbacks,
1446
+ ]
1447
+ ] = None,
1448
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1449
+ **kwargs,
1450
+ ):
1451
+ r"""
1452
+ Function invoked when calling the pipeline for generation.
1453
+
1454
+ Args:
1455
+ it (`float`, *optional*, defaults to 0.5):
1456
+ The interpolation coefficient in `[0, 1]`; `0` corresponds to the first endpoint and `1` to the second.
1457
+ prompt_start (`str`, *optional*):
1458
+ The prompt describing the first endpoint of the interpolation.
1459
+ prompt_end (`str`, *optional*):
1460
+ The prompt describing the second endpoint. If `guide_prompt` is given it steers the interpolated sample; otherwise the prompt embeddings themselves are interpolated.
1461
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1462
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1463
+ Anything below 512 pixels won't work well for
1464
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1465
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1466
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1467
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1468
+ Anything below 512 pixels won't work well for
1469
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1470
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1471
+ num_inference_steps (`int`, *optional*, defaults to 50):
1472
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1473
+ expense of slower inference.
1474
+ timesteps (`List[int]`, *optional*):
1475
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1476
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1477
+ passed will be used. Must be in descending order.
1478
+ denoising_end (`float`, *optional*):
1479
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1480
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1481
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1482
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1483
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1484
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1485
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1486
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1487
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1488
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1489
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1490
+ usually at the expense of lower image quality.
1491
+ negative_prompt (`str` or `List[str]`, *optional*):
1492
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1493
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1494
+ less than `1`).
1495
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1496
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1497
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1498
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1499
+ The number of images to generate per prompt.
1500
+ eta (`float`, *optional*, defaults to 0.0):
1501
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1502
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1503
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1504
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1505
+ to make generation deterministic.
1506
+ latents (`torch.FloatTensor`, *optional*):
1507
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1508
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1509
+ tensor will be generated by sampling using the supplied random `generator`.
1510
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1511
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1512
+ provided, text embeddings will be generated from `prompt` input argument.
1513
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1514
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1515
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1516
+ argument.
1517
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1518
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1519
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1520
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1521
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1522
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1523
+ input argument.
1524
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1525
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1526
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
1527
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
1528
+ if `do_classifier_free_guidance` is set to `True`.
1529
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
1530
+ output_type (`str`, *optional*, defaults to `"pil"`):
1531
+ The output format of the generated image. Choose between
1532
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1533
+ return_dict (`bool`, *optional*, defaults to `True`):
1534
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead
1535
+ of a plain tuple.
1536
+ cross_attention_kwargs (`dict`, *optional*):
1537
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1538
+ `self.processor` in
1539
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1540
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1541
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1542
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
1543
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1544
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1545
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1546
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1547
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1548
+ explained in section 2.2 of
1549
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1550
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1551
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1552
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1553
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1554
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1555
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1556
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1557
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1558
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1559
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1560
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1561
+ micro-conditioning as explained in section 2.2 of
1562
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1563
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1564
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1565
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1566
+ micro-conditioning as explained in section 2.2 of
1567
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1568
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1569
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1570
+ To negatively condition the generation process based on a target image resolution. It should be the same
1571
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1572
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1573
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1574
+ callback_on_step_end (`Callable`, *optional*):
1575
+ A function that is called at the end of each denoising step during inference. The function is called
1576
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1577
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1578
+ `callback_on_step_end_tensor_inputs`.
1579
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1580
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1581
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1582
+ `._callback_tensor_inputs` attribute of your pipeline class.
1583
+
1584
+ Examples:
1585
+
1586
+ Returns:
1587
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1588
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
1589
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1590
+ """
1591
+
1592
+ callback = kwargs.pop("callback", None)
1593
+ callback_steps = kwargs.pop("callback_steps", None)
1594
+
1595
+ if callback is not None:
1596
+ deprecate(
1597
+ "callback",
1598
+ "1.0.0",
1599
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1600
+ )
1601
+ if callback_steps is not None:
1602
+ deprecate(
1603
+ "callback_steps",
1604
+ "1.0.0",
1605
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1606
+ )
1607
+
1608
+ if image_start is not None and image_end is None:
1609
+ # throw error
1610
+ raise ValueError(
1611
+ "Please provide both `image_start` and `image_end` to interpolate, or only `image_end` to control the scale."
1612
+ )
1613
+
1614
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1615
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1616
+
1617
+ # 0. Default height and width to unet
1618
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1619
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1620
+
1621
+ # 1. Check inputs. Raise error if not correct
1622
+ self.check_inputs(
1623
+ prompt_start,
1624
+ height,
1625
+ width,
1626
+ callback_steps,
1627
+ negative_prompt,
1628
+ prompt_embeds,
1629
+ negative_prompt_embeds,
1630
+ ip_adapter_image,
1631
+ ip_adapter_image_embeds,
1632
+ callback_on_step_end_tensor_inputs,
1633
+ )
1634
+
1635
+ self._guidance_scale = guidance_scale
1636
+ self._guidance_rescale = guidance_rescale
1637
+ self._clip_skip = clip_skip
1638
+ self._cross_attention_kwargs = cross_attention_kwargs
1639
+ self._interrupt = False
1640
+
1641
+ # 2. Define call parameters
1642
+ batch_size = 3 # [Source A, Interpolated, Source B]
1643
+
1644
+ device = self._execution_device
1645
+
1646
+ # 3. Encode input prompt
1647
+ lora_scale = (
1648
+ self.cross_attention_kwargs.get("scale", None)
1649
+ if self.cross_attention_kwargs is not None
1650
+ else None
1651
+ )
1652
+
1653
+ (prompt_embeds_start, negative_prompt_embeds_start) = self.encode_prompt(
1654
+ prompt=prompt_start,
1655
+ device=device,
1656
+ num_images_per_prompt=num_images_per_prompt,
1657
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1658
+ negative_prompt=negative_prompt,
1659
+ prompt_embeds=prompt_embeds,
1660
+ negative_prompt_embeds=negative_prompt_embeds,
1661
+ lora_scale=lora_scale,
1662
+ clip_skip=self.clip_skip,
1663
+ )
1664
+
1665
+ (prompt_embeds_end, negative_prompt_embeds_end) = self.encode_prompt(
1666
+ prompt=prompt_end,
1667
+ device=device,
1668
+ num_images_per_prompt=num_images_per_prompt,
1669
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1670
+ negative_prompt=negative_prompt,
1671
+ prompt_embeds=prompt_embeds,
1672
+ negative_prompt_embeds=negative_prompt_embeds,
1673
+ lora_scale=lora_scale,
1674
+ clip_skip=self.clip_skip,
1675
+ )
1676
+
1677
+ if guide_prompt is not None:
1678
+ (prompt_embeds_target, negative_prompt_embeds_target) = self.encode_prompt(
1679
+ prompt=guide_prompt,
1680
+ device=device,
1681
+ num_images_per_prompt=num_images_per_prompt,
1682
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1683
+ negative_prompt=negative_prompt,
1684
+ prompt_embeds=prompt_embeds,
1685
+ negative_prompt_embeds=negative_prompt_embeds,
1686
+ lora_scale=lora_scale,
1687
+ clip_skip=self.clip_skip,
1688
+ )
1689
+ else:
1690
+ if init == "linear":
1691
+ prompt_embeds_target = torch.lerp(
1692
+ prompt_embeds_start, prompt_embeds_end, it
1693
+ )
1694
+ negative_prompt_embeds_target = torch.lerp(
1695
+ negative_prompt_embeds_start, negative_prompt_embeds_end, it
1696
+ )
1697
+ else:
1698
+ prompt_embeds_target = slerp(prompt_embeds_start, prompt_embeds_end, it)
1699
+ negative_prompt_embeds_target = slerp(
1700
+ negative_prompt_embeds_start, negative_prompt_embeds_end, it
1701
+ )
1702
+
1703
+ prompt_embeds = torch.cat(
1704
+ [prompt_embeds_start, prompt_embeds_target, prompt_embeds_end], dim=0
1705
+ ).to(device=device)
1706
+ negative_prompt_embeds = torch.cat(
1707
+ [
1708
+ negative_prompt_embeds_start,
1709
+ negative_prompt_embeds_target,
1710
+ negative_prompt_embeds_end,
1711
+ ],
1712
+ dim=0,
1713
+ ).to(device=device)
1714
+
1715
+ # 4. Prepare timesteps
1716
+ timesteps, num_inference_steps = retrieve_timesteps(
1717
+ self.scheduler, num_inference_steps, device, timesteps
1718
+ )
1719
+
1720
+ # 5. Prepare latent variables
1721
+ num_channels_latents = self.unet.config.in_channels
1722
+ latent_start = self.prepare_latents(
1723
+ 1,
1724
+ num_channels_latents,
1725
+ height,
1726
+ width,
1727
+ prompt_embeds.dtype,
1728
+ device,
1729
+ generator,
1730
+ latent_start,
1731
+ )
1732
+
1733
+ latent_end = self.prepare_latents(
1734
+ 1,
1735
+ num_channels_latents,
1736
+ height,
1737
+ width,
1738
+ prompt_embeds.dtype,
1739
+ device,
1740
+ generator,
1741
+ latent_end,
1742
+ )
1743
+
1744
+ latent_target = slerp(latent_start, latent_end, it)
1745
+ latents = torch.cat([latent_start, latent_target, latent_end], dim=0).to(
1746
+ device=device
1747
+ )
1748
+
1749
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1750
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1751
+
1752
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1753
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1754
+ ip_adapter_image,
1755
+ ip_adapter_image_embeds,
1756
+ device,
1757
+ 3,
1758
+ self.do_classifier_free_guidance,
1759
+ )
1760
+
1761
+ # 6.1 Prepare image embeddings for interpolation
1762
+ if image_end is not None:
1763
+ image_embeds_end = self.prepare_ip_adapter_image_embeds(
1764
+ image_end,
1765
+ None,
1766
+ device,
1767
+ 3,
1768
+ self.do_classifier_free_guidance,
1769
+ )
1770
+ negative_image_embeds_end, image_embeds_end = image_embeds_end[0].chunk(2)
1771
+
1772
+ if image_start is None:
1773
+ image_embeds_start = negative_image_embeds_end
1774
+ negative_image_embeds_start = negative_image_embeds_end
1775
+ else:
1776
+ image_embeds_start = self.prepare_ip_adapter_image_embeds(
1777
+ image_start,
1778
+ None,
1779
+ device,
1780
+ 3,
1781
+ self.do_classifier_free_guidance,
1782
+ )
1783
+ negative_image_embeds_start, image_embeds_start = image_embeds_start[
1784
+ 0
1785
+ ].chunk(2)
1786
+
1787
+ if init == "linear":
1788
+ image_embeds_target = torch.lerp(
1789
+ image_embeds_start, image_embeds_end, it
1790
+ )
1791
+ negative_image_embeds_target = torch.lerp(
1792
+ negative_image_embeds_start, negative_image_embeds_end, it
1793
+ )
1794
+ else:
1795
+ image_embeds_target = slerp(image_embeds_start, image_embeds_end, it)
1796
+ negative_image_embeds_target = slerp(
1797
+ negative_image_embeds_start, negative_image_embeds_end, it
1798
+ )
1799
+
1800
+ image_embeds = torch.cat(
1801
+ [image_embeds_start, image_embeds_target, image_embeds_end], dim=0
1802
+ ).to(device=device)
1803
+
1804
+ negative_image_embeds = torch.cat(
1805
+ [
1806
+ negative_image_embeds_start,
1807
+ negative_image_embeds_target,
1808
+ negative_image_embeds_end,
1809
+ ],
1810
+ dim=0,
1811
+ ).to(device=device)
1812
+
1813
+ image_embeds = [image_embeds]
1814
+ negative_image_embeds = [negative_image_embeds]
1815
+
1816
+ # 6.2 Optionally get Guidance Scale Embedding
1817
+ timestep_cond = None
1818
+ if self.unet.config.time_cond_proj_dim is not None:
1819
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
1820
+ batch_size * num_images_per_prompt
1821
+ )
1822
+ timestep_cond = self.get_guidance_scale_embedding(
1823
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1824
+ ).to(device=device, dtype=latents.dtype)
1825
+
1826
+ # 7. Denoising loop
1827
+ num_warmup_steps = max(
1828
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
1829
+ )
1830
+
1831
+ warmup_steps = int(num_inference_steps * warmup_ratio)
1832
+ self._num_timesteps = len(timesteps)
1833
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1834
+ for i, t in enumerate(timesteps):
1835
+ if self.interrupt:
1836
+ continue
1837
+
1838
+ # the latents are not expanded here; the conditional and unconditional passes are run separately below
1839
+ latent_model_input = latents
1840
+ latent_model_input = self.scheduler.scale_model_input(
1841
+ latent_model_input, t
1842
+ )
1843
+
1844
+ # Set the interpolated attention processor
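+ # AID attention interpolation is only active for the first `warmup_ratio`
+ # fraction of the denoising steps; the remaining steps fall back to the
+ # standard attention processors so each sample refines independently.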
1845
+ if i < warmup_steps:
1846
+ self.activate_aid(it)
1847
+ else:
1848
+ self.deactivate_aid()
1849
+
1850
+ # predict the noise residual for conditional noise
1851
+ if (
1852
+ (image_start is not None or image_end is not None)
1853
+ or ip_adapter_image is not None
1854
+ or ip_adapter_image_embeds is not None
1855
+ ):
1856
+ added_cond_kwargs = {"image_embeds": image_embeds}
1857
+ else:
1858
+ added_cond_kwargs = None
1859
+ noise_pred_text = self.unet(
1860
+ latent_model_input,
1861
+ t,
1862
+ encoder_hidden_states=prompt_embeds,
1863
+ timestep_cond=timestep_cond,
1864
+ cross_attention_kwargs=self.cross_attention_kwargs,
1865
+ added_cond_kwargs=added_cond_kwargs,
1866
+ return_dict=False,
1867
+ )[0]
1868
+
1869
+ # Set back to the usual attention processor (if using image_embeds, don't do this)
1870
+ self.deactivate_aid()
1871
+
1872
+ # predict the noise residual for negative noise
1873
+ if (
1874
+ (image_start is not None or image_end is not None)
1875
+ or ip_adapter_image is not None
1876
+ or ip_adapter_image_embeds is not None
1877
+ ):
1878
+ added_cond_kwargs = {"image_embeds": negative_image_embeds}
1879
+ else:
1880
+ added_cond_kwargs = None
1881
+ noise_pred_uncond = self.unet(
1882
+ latent_model_input,
1883
+ t,
1884
+ encoder_hidden_states=negative_prompt_embeds,
1885
+ timestep_cond=timestep_cond,
1886
+ cross_attention_kwargs=self.cross_attention_kwargs,
1887
+ added_cond_kwargs=added_cond_kwargs,
1888
+ return_dict=False,
1889
+ )[0]
1890
+
1891
+ # perform guidance
1892
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1893
+ noise_pred_text - noise_pred_uncond
1894
+ )
1895
+
1896
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1897
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1898
+ noise_pred = rescale_noise_cfg(
1899
+ noise_pred,
1900
+ noise_pred_text,
1901
+ guidance_rescale=self.guidance_rescale,
1902
+ )
1903
+
1904
+ # compute the previous noisy sample x_t -> x_t-1
1905
+ latents = self.scheduler.step(
1906
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1907
+ )[0]
1908
+
1909
+ if callback_on_step_end is not None:
1910
+ callback_kwargs = {}
1911
+ for k in callback_on_step_end_tensor_inputs:
1912
+ callback_kwargs[k] = locals()[k]
1913
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1914
+
1915
+ latents = callback_outputs.pop("latents", latents)
1916
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1917
+ negative_prompt_embeds = callback_outputs.pop(
1918
+ "negative_prompt_embeds", negative_prompt_embeds
1919
+ )
1920
+
1921
+ # call the callback, if provided
1922
+ if i == len(timesteps) - 1 or (
1923
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1924
+ ):
1925
+ progress_bar.update()
1926
+ if callback is not None and i % callback_steps == 0:
1927
+ step_idx = i // getattr(self.scheduler, "order", 1)
1928
+ callback(step_idx, t, latents)
1929
+
1930
+ if XLA_AVAILABLE:
1931
+ xm.mark_step()
1932
+
1933
+ if not output_type == "latent":
1934
+ image = self.vae.decode(
1935
+ latents / self.vae.config.scaling_factor,
1936
+ return_dict=False,
1937
+ generator=generator,
1938
+ )[0]
1939
+ image, has_nsfw_concept = self.run_safety_checker(
1940
+ image, device, prompt_embeds.dtype
1941
+ )
1942
+ else:
1943
+ image = latents
1944
+ has_nsfw_concept = None
1945
+
1946
+ if has_nsfw_concept is None:
1947
+ do_denormalize = [True] * image.shape[0]
1948
+ else:
1949
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1950
+
1951
+ image = self.image_processor.postprocess(
1952
+ image, output_type=output_type, do_denormalize=do_denormalize
1953
+ )
1954
+
1955
+ # Offload all models
1956
+ self.maybe_free_model_hooks()
1957
+
1958
+ if not return_dict:
1959
+ return (image, has_nsfw_concept)
1960
+
1961
+ return StableDiffusionPipelineOutput(
1962
+ images=image, nsfw_content_detected=has_nsfw_concept
1963
+ )
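Note: `interpolate_single` above relies on a `slerp` helper that is not part of this diff (it is imported elsewhere in the repository). As a reference, here is a minimal sketch of a spherical linear interpolation consistent with how it is called here; this is an illustrative assumption, not necessarily the repository's exact implementation.

```python
import torch

def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
    """Spherical linear interpolation between two tensors, treated as flat vectors."""
    v0_flat, v1_flat = v0.reshape(-1), v1.reshape(-1)
    # Cosine of the angle between the two (normalized) vectors
    cos_theta = torch.clamp(
        torch.dot(v0_flat / v0_flat.norm(), v1_flat / v1_flat.norm()), -1.0, 1.0
    )
    theta = torch.acos(cos_theta)
    if theta.abs() < eps:
        # Nearly parallel vectors: fall back to plain linear interpolation
        return torch.lerp(v0, v1, t)
    sin_theta = torch.sin(theta)
    return (torch.sin((1.0 - t) * theta) / sin_theta) * v0 + (
        torch.sin(t * theta) / sin_theta
    ) * v1
```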
pipeline_interpolated_sdxl.py ADDED
The diff for this file is too large to render. See raw diff
 
prior.py ADDED
@@ -0,0 +1,506 @@
1
+ import numpy as np
2
+ import torch
3
+ from bayes_opt import BayesianOptimization, SequentialDomainReductionTransformer
4
+ from lpips import LPIPS
5
+ from scipy.optimize import curve_fit
6
+ from scipy.stats import beta as beta_distribution
7
+
8
+ from transformers import CLIPImageProcessor, CLIPModel
9
+ from utils import compute_lpips, compute_smoothness_and_consistency
10
+
11
+
12
+ class BetaPriorPipeline:
13
+ def __init__(self, pipe, model_ID="openai/clip-vit-base-patch32"):
14
+ self.model = CLIPModel.from_pretrained(model_ID)
15
+ self.preprocess = CLIPImageProcessor.from_pretrained(model_ID)
16
+ self.pipe = pipe
17
+
18
+ def _compute_clip(self, embedding_a, embedding_b):
19
+ similarity_score = torch.nn.functional.cosine_similarity(
20
+ embedding_a, embedding_b
21
+ )
22
+ return 1 - similarity_score[0]
23
+
24
+ def _get_feature(self, image):
25
+ with torch.no_grad():
26
+ if isinstance(image, np.ndarray):
27
+ image = self.preprocess(
28
+ image, return_tensors="pt", do_rescale=False
29
+ ).pixel_values
30
+ else:
31
+ image = self.preprocess(image, return_tensors="pt").pixel_values
32
+ embedding = self.model.get_image_features(image)
33
+ return embedding
34
+
35
+ def _update_alpha_beta(self, xs, ds):
36
+ uniform_point = []
37
+ ds_sum = sum(ds)
38
+ for i in range(len(ds)):
39
+ uniform_point.append(ds[i] / ds_sum)
40
+ uniform_point = [0] + uniform_point
41
+ uniform_points = np.cumsum(uniform_point)
42
+
43
+ xs = np.asarray(xs)
44
+ uniform_points = np.asarray(uniform_points)
45
+
46
+ def beta_cdf(x, alpha, beta_param):
47
+ return beta_distribution.cdf(x, alpha, beta_param)
48
+
49
+ initial_guess = [1.0, 1.0]
50
+ bounds = ([1e-6, 1e-6], [np.inf, np.inf])
51
+ params, covariance = curve_fit(
52
+ beta_cdf, xs, uniform_points, p0=initial_guess, bounds=bounds
53
+ )
54
+
55
+ fitted_alpha, fitted_beta = params
56
+ return fitted_alpha, fitted_beta
57
+
58
+ def _add_next_point(
59
+ self,
60
+ ds,
61
+ xs,
62
+ images,
63
+ features,
64
+ alpha,
65
+ beta_param,
66
+ prompt_start,
67
+ prompt_end,
68
+ negative_prompt,
69
+ latent_start,
70
+ latent_end,
71
+ num_inference_steps,
72
+ uniform=False,
73
+ **kwargs,
74
+ ):
75
+ idx = np.argmax(ds)
76
+ A = xs[idx]
77
+ B = xs[idx + 1]
78
+ F_A = beta_distribution.cdf(A, alpha, beta_param)
79
+ F_B = beta_distribution.cdf(B, alpha, beta_param)
80
+
81
+ # Compute the target CDF for t
82
+ F_t = (F_A + F_B) / 2
83
+
84
+ # Compute the value of t using the inverse CDF (percent point function)
85
+ t = beta_distribution.ppf(F_t, alpha, beta_param)
86
+
87
+ if uniform:
88
+ idx = np.argmax(np.array(xs) - np.array([0] + xs[:-1])) - 1
89
+ t = (xs[idx] + xs[idx + 1]) / 2
90
+
91
+ if t < 0 or t > 1:
92
+ return xs, False
93
+
94
+ ims = self.pipe.interpolate_single(
95
+ t,
96
+ prompt_start=prompt_start,
97
+ prompt_end=prompt_end,
98
+ negative_prompt=negative_prompt,
99
+ latent_start=latent_start,
100
+ latent_end=latent_end,
101
+ early="fused_outer",
102
+ num_inference_steps=num_inference_steps,
103
+ **kwargs,
104
+ )
105
+
106
+ added_image = ims.images[1]
107
+ added_feature = self._get_feature(added_image)
108
+ d1 = self._compute_clip(features[idx], added_feature)
109
+ d2 = self._compute_clip(features[idx + 1], added_feature)
110
+
111
+ images.insert(idx + 1, ims.images[1])
112
+ features.insert(idx + 1, added_feature)
113
+ xs.insert(idx + 1, t)
114
+ del ds[idx]
115
+ ds.insert(idx, d1)
116
+ ds.insert(idx + 1, d2)
117
+ return xs, True
118
+
119
+ def explore_with_beta(
120
+ self,
121
+ progress,
122
+ prompt_start,
123
+ prompt_end,
124
+ negative_prompt,
125
+ latent_start,
126
+ latent_end,
127
+ num_inference_steps=28,
128
+ exploration_size=16,
129
+ init_alpha=3,
130
+ init_beta=3,
131
+ uniform=False,
132
+ **kwargs,
133
+ ):
134
+ xs = [0.0, 0.5, 1.0]
135
+ images = self.pipe.interpolate_single(
136
+ 0.5,
137
+ prompt_start=prompt_start,
138
+ prompt_end=prompt_end,
139
+ negative_prompt=negative_prompt,
140
+ latent_start=latent_start,
141
+ latent_end=latent_end,
142
+ early="fused_outer",
143
+ num_inference_steps=num_inference_steps,
144
+ **kwargs,
145
+ )
146
+ images = images.images
147
+ images = [images[0], images[1], images[2]]
148
+ features = [self._get_feature(image) for image in images]
149
+ ds = [
150
+ self._compute_clip(features[0], features[1]),
151
+ self._compute_clip(features[1], features[2]),
152
+ ]
153
+ alpha = init_alpha
154
+ beta_param = init_beta
155
+ print(
156
+ "Alpha:",
157
+ alpha,
158
+ "| Beta:",
159
+ beta_param,
160
+ "| Current Coefs:",
161
+ xs,
162
+ "| Current Distances:",
163
+ ds,
164
+ )
165
+ progress(3, desc="Exploration")
166
+ for i in progress.tqdm(range(3, exploration_size)):
167
+ xs, flag = self._add_next_point(
168
+ ds,
169
+ xs,
170
+ images,
171
+ features,
172
+ alpha,
173
+ beta_param,
174
+ prompt_start,
175
+ prompt_end,
176
+ negative_prompt,
177
+ latent_start,
178
+ latent_end,
179
+ num_inference_steps,
180
+ uniform=uniform,
181
+ **kwargs,
182
+ )
183
+ if not flag:
184
+ break
185
+ alpha, beta_param = self._update_alpha_beta(xs, ds)
186
+ if uniform:
187
+ alpha = 1
188
+ beta_param = 1
189
+ print(f"--------Exploration: {len(xs)} / {exploration_size}--------")
190
+ print(
191
+ "Alpha:",
192
+ alpha,
193
+ "| Beta:",
194
+ beta_param,
195
+ "| Current Coefs:",
196
+ xs,
197
+ "| Current Distances:",
198
+ ds,
199
+ )
200
+
201
+ return images, features, ds, xs, alpha, beta_param
202
+
203
+ def extract_uniform_points(self, ds, interpolation_size):
204
+ expected_dis = sum(ds) / (interpolation_size - 1)
205
+ current_sum = 0
206
+ output_idxs = [0]
207
+ for idx, d in enumerate(ds):
208
+ current_sum += d
209
+ if current_sum >= expected_dis:
210
+ output_idxs.append(idx)
211
+ current_sum = 0
212
+ return output_idxs
213
+
214
+ def extract_uniform_points_plus(self, features, interpolation_size):
215
+ weights = -1 * np.ones((len(features), len(features)))
216
+ for i in range(len(features)):
217
+ for j in range(i + 1, len(features)):
218
+ weights[i][j] = self._compute_clip(features[i], features[j])
219
+ m = len(features)
220
+ n = interpolation_size
221
+ _, best_path = self.find_minimal_spread_and_path(n, m, weights)
222
+ print("Optimal smooth path:", best_path)
223
+ return best_path
224
+
225
+ def find_minimal_spread_and_path(self, n, m, weights):
226
+ # Collect all unique edge weights, excluding non-existent edges (-1)
227
+ W = sorted(
228
+ {
229
+ weights[i][j]
230
+ for i in range(m - 1)
231
+ for j in range(i + 1, m)
232
+ if weights[i][j] != -1
233
+ }
234
+ )
235
+ min_weight = W[0]
236
+ max_weight = W[-1]
237
+
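+ # Binary search over the allowed weight spread D = w_max - w_min:
+ # shrink [low, high] until the smallest D is found for which an
+ # n-node path exists whose edge weights all fit inside a window of width D.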
238
+ low = 0.0
239
+ high = max_weight - min_weight
240
+ epsilon = 1e-6 # Desired precision
241
+
242
+ best_D = None
243
+ best_path = None
244
+
245
+ while high - low > epsilon:
246
+ D = (low + high) / 2
247
+ result = self.is_path_possible(D, n, m, weights, W)
248
+ if result is not None:
249
+ # A valid path is found
250
+ high = D
251
+ best_D = D
252
+ best_path = result
253
+ else:
254
+ low = D
255
+
256
+ return best_D, best_path
257
+
258
+ def is_path_possible(self, D, n, m, weights, W):
259
+ for w_min in W:
260
+ w_max = w_min + D
261
+ if w_max > W[-1]:
262
+ break
263
+
264
+ # Dynamic Programming to check for a valid path
265
+ dp = [[None] * (n + 1) for _ in range(m)]
266
+ dp[0][1] = (
267
+ float("-inf"),
268
+ float("inf"),
269
+ [0],
270
+ ) # Start from x1 with path length 1
271
+
272
+ for l in range(1, n):
273
+ for i in range(m):
274
+ if dp[i][l] is not None:
275
+ max_w, min_w, path = dp[i][l]
276
+ for j in range(i + 1, m):
277
+ w = weights[i][j]
278
+ if w != -1 and w_min <= w <= w_max:
279
+ # Update max and min weights along the path
280
+ new_max_w = max(max_w, w)
281
+ new_min_w = min(min_w, w)
282
+ new_diff = new_max_w - new_min_w
283
+ if new_diff <= D:
284
+ dp_j_l_plus_1 = dp[j][l + 1]
285
+ if dp_j_l_plus_1 is None or new_diff < (
286
+ dp_j_l_plus_1[0] - dp_j_l_plus_1[1]
287
+ ):
288
+ dp[j][l + 1] = (
289
+ new_max_w,
290
+ new_min_w,
291
+ path + [j],
292
+ )
293
+
294
+ if dp[m - 1][n] is not None:
295
+ # Reconstruct the path
296
+ _, _, path = dp[m - 1][n]
297
+ return path # Return the path if found
298
+
299
+ return None # Return None if no valid path is found
300
+
301
+ def generate_interpolation(
302
+ self,
303
+ progress,
304
+ prompt_start,
305
+ prompt_end,
306
+ negative_prompt,
307
+ latent_start,
308
+ latent_end,
309
+ num_inference_steps=28,
310
+ exploration_size=16,
311
+ init_alpha=3,
312
+ init_beta=3,
313
+ interpolation_size=7,
314
+ uniform=False,
315
+ **kwargs,
316
+ ):
317
+ images, features, ds, xs, alpha, beta_param = self.explore_with_beta(
318
+ progress,
319
+ prompt_start,
320
+ prompt_end,
321
+ negative_prompt,
322
+ latent_start,
323
+ latent_end,
324
+ num_inference_steps,
325
+ exploration_size,
326
+ init_alpha,
327
+ init_beta,
328
+ uniform=uniform,
329
+ **kwargs,
330
+ )
331
+ # output_idx = self.extract_uniform_points(ds, interpolation_size)
332
+ output_idx = self.extract_uniform_points_plus(features, interpolation_size)
333
+ output_images = []
334
+ for idx in output_idx:
335
+ output_images.append(images[idx])
336
+
337
+ # for call_back
338
+ self.images = images
339
+ self.ds = ds
340
+ self.xs = xs
341
+ self.alpha = alpha
342
+ self.beta_param = beta_param
343
+
344
+ return output_images
345
+
346
+
347
+ def bayesian_prior_selection(
348
+ interpolation_pipe,
349
+ latent1: torch.FloatTensor,
350
+ latent2: torch.FloatTensor,
351
+ prompt1: str,
352
+ prompt2: str,
353
+ lpips_model: LPIPS,
354
+ guide_prompt: str | None = None,
355
+ negative_prompt: str = "",
356
+ size: int = 3,
357
+ num_inference_steps: int = 25,
358
+ warmup_ratio: float = 1,
359
+ early: str = "vfused",
360
+ late: str = "self",
361
+ target_score: float = 0.9,
362
+ n_iter: int = 15,
363
+ p_min: float | None = None,
364
+ p_max: float | None = None,
365
+ ) -> tuple:
366
+ """
367
+ Select the alpha and beta parameters for the interpolation using Bayesian optimization.
368
+
369
+ Args:
370
+ interpolation_pipe (any): The interpolation pipeline.
371
+ latent1 (torch.FloatTensor): The first source latent vector.
372
+ latent2 (torch.FloatTensor): The second source latent vector.
373
+ prompt1 (str): The first source prompt.
374
+ prompt2 (str): The second source prompt.
375
+ lpips_model (any): The LPIPS model used to compute perceptual distances.
376
+ guide_prompt (str | None, optional): The guide prompt for the interpolation, if any. Defaults to None.
377
+ negative_prompt (str, optional): The negative prompt for the interpolation, default to empty string. Defaults to "".
378
+ size (int, optional): The size of the interpolation sequence. Defaults to 3.
379
+ num_inference_steps (int, optional): The number of inference steps. Defaults to 25.
380
+ warmup_ratio (float, optional): The warmup ratio. Defaults to 1.
381
+ early (str, optional): The early fusion method. Defaults to "vfused".
382
+ late (str, optional): The late fusion method. Defaults to "self".
383
+ target_score (float, optional): The target score. Defaults to 0.9.
384
+ n_iter (int, optional): The maximum number of iterations. Defaults to 15.
385
+ p_min (float, optional): The minimum value of alpha and beta. Defaults to None.
386
+ p_max (float, optional): The maximum value of alpha and beta. Defaults to None.
387
+ Returns:
388
+ tuple: A tuple containing the selected alpha and beta parameters.
389
+ """
390
+
391
+ def get_smoothness(alpha, beta):
392
+ """
393
+ Black-box objective function of Bayesian Optimization.
394
+ Get the smoothness of the interpolated sequence with the given alpha and beta.
395
+ """
396
+ if alpha < beta and large_alpha_prior:
397
+ return 0
398
+ if alpha > beta and not large_alpha_prior:
399
+ return 0
400
+ if alpha == beta:
401
+ return init_smoothness
402
+ interpolation_sequence = interpolation_pipe.interpolate_save_gpu(
403
+ latent1,
404
+ latent2,
405
+ prompt1,
406
+ prompt2,
407
+ guide_prompt=guide_prompt,
408
+ negative_prompt=negative_prompt,
409
+ size=size,
410
+ num_inference_steps=num_inference_steps,
411
+ warmup_ratio=warmup_ratio,
412
+ early=early,
413
+ late=late,
414
+ alpha=alpha,
415
+ beta=beta,
416
+ )
417
+ smoothness, _, _ = compute_smoothness_and_consistency(
418
+ interpolation_sequence, lpips_model
419
+ )
420
+ return smoothness
421
+
422
+ # Add prior into selection of alpha and beta
423
+ # We firstly compute the interpolated images with t=0.5
424
+ images = interpolation_pipe.interpolate_single(
425
+ 0.5,
426
+ latent1,
427
+ latent2,
428
+ prompt1,
429
+ prompt2,
430
+ guide_prompt=guide_prompt,
431
+ negative_prompt=negative_prompt,
432
+ num_inference_steps=num_inference_steps,
433
+ warmup_ratio=warmup_ratio,
434
+ early=early,
435
+ late=late,
436
+ )
437
+ # We compute the perceptual distances of the interpolated images (t=0.5) to the source image
438
+ distances = compute_lpips(images, lpips_model)
439
+ # We compute the init_smoothness as the smoothness when alpha=beta to avoid recomputation
440
+ init_smoothness, _, _ = compute_smoothness_and_consistency(images, lpips_model)
441
+ # If perceptual distance to the first source image is smaller, alpha should be larger than beta
442
+ large_alpha_prior = distances[0] < distances[1]
443
+
444
+ # Bayesian optimization configuration
445
+ num_warmup_steps = warmup_ratio * num_inference_steps
446
+ if p_min is None:
447
+ p_min = 1
448
+ if p_max is None:
449
+ p_max = num_warmup_steps
450
+ pbounds = {"alpha": (p_min, p_max), "beta": (p_min, p_max)}
451
+ bounds_transformer = SequentialDomainReductionTransformer(minimum_window=0.1)
452
+ optimizer = BayesianOptimization(
453
+ f=get_smoothness,
454
+ pbounds=pbounds,
455
+ random_state=1,
456
+ bounds_transformer=bounds_transformer,
457
+ allow_duplicate_points=True,
458
+ )
459
+ alpha_init = [p_min, (p_min + p_max) / 2, p_max]
460
+ beta_init = [p_min, (p_min + p_max) / 2, p_max]
461
+
462
+ # Initial probing
463
+ for alpha in alpha_init:
464
+ for beta in beta_init:
465
+ optimizer.probe(params={"alpha": alpha, "beta": beta}, lazy=False)
466
+ latest_result = optimizer.res[-1] # Get the last result
467
+ latest_score = latest_result["target"]
468
+ if latest_score >= target_score:
469
+ return alpha, beta
470
+
471
+ # Start optimization
472
+ for _ in range(n_iter): # Max iterations
473
+ optimizer.maximize(init_points=0, n_iter=1) # One iteration at a time
474
+ max_score = optimizer.max["target"] # Get the highest score so far
475
+ if max_score >= target_score:
476
+ print(f"Stopping early, target of {target_score} reached.")
477
+ break # Exit the loop if target is reached or exceeded
478
+
479
+ results = optimizer.max
480
+ alpha = results["params"]["alpha"]
481
+ beta = results["params"]["beta"]
482
+ return alpha, beta
483
+
484
+
485
+ def generate_beta_tensor(
486
+ size: int, alpha: float = 3, beta: float = 3
487
+ ) -> torch.FloatTensor:
488
+ """
489
+ Assume the size is n.
490
+ Generates a PyTorch tensor of values [x0, x1, ..., xn-1] for the Beta distribution
491
+ where each xi satisfies F(xi) = i/(n-1) for the CDF F of the Beta distribution.
492
+
493
+ Args:
494
+ size (int): The number of values to generate.
495
+ alpha (float): The alpha parameter of the Beta distribution.
496
+ beta (float): The beta parameter of the Beta distribution.
497
+
498
+ Returns:
499
+ torch.Tensor: A tensor of the inverse CDF values of the Beta distribution.
500
+ """
501
+ # Generating the inverse CDF values
502
+ prob_values = [i / (size - 1) for i in range(size)]
503
+ inverse_cdf_values = beta_distribution.ppf(prob_values, alpha, beta)
504
+
505
+ # Converting to a PyTorch tensor
506
+ return torch.tensor(inverse_cdf_values, dtype=torch.float32)
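As a quick illustration of how `generate_beta_tensor` spaces interpolation coefficients, the standalone snippet below (values approximate) inverts the Beta CDF at evenly spaced probabilities; with `alpha == beta` the coefficients are symmetric around 0.5:

```python
import numpy as np
from scipy.stats import beta as beta_distribution

size, alpha, beta_param = 7, 3.0, 3.0
prob_values = np.linspace(0.0, 1.0, size)              # i / (size - 1)
ts = beta_distribution.ppf(prob_values, alpha, beta_param)
print(np.round(ts, 2))
# Approximately [0.   0.3  0.41 0.5  0.59 0.7  1.  ] for alpha = beta = 3,
# i.e. coefficients cluster around the midpoint, with larger jumps near the endpoints.
```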
requirements.txt ADDED
@@ -0,0 +1,66 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.27.2
3
+ addict==2.4.0
4
+ antlr4-python3-runtime==4.9.3
5
+ bayesian-optimization==1.4.3
6
+ clean-fid==0.1.35
7
+ clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
8
+ colorama==0.4.6
9
+ contourpy==1.2.0
10
+ cycler==0.12.1
11
+ diffusers==0.27.1
12
+ einops==0.7.0
13
+ facexlib==0.3.0
14
+ filterpy==1.4.5
15
+ fonttools==4.49.0
16
+ fsspec==2024.2.0
17
+ ftfy==6.1.3
18
+ future==1.0.0
19
+ grpcio==1.62.0
20
+ huggingface-hub==0.20.3
21
+ imageio==2.34.0
22
+ imgaug==0.4.0
23
+ joblib==1.3.2
24
+ kiwisolver==1.4.5
25
+ lazy_loader==0.3
26
+ llvmlite==0.42.0
27
+ lmdb==1.4.1
28
+ lpips==0.1.4
29
+ Markdown==3.5.2
30
+ matplotlib==3.8.3
31
+ mkl-service==2.4.0
32
+ numba==0.59.0
33
+ numpy==1.24.4
34
+ omegaconf==2.3.0
35
+ openai-clip==1.0.1
36
+ opencv-python==4.9.0.80
37
+ pandas==2.2.0
38
+ protobuf==4.25.3
39
+ pyiqa==0.1.10
40
+ pyparsing==3.1.1
41
+ python-dateutil==2.8.2
42
+ pytorch-fid==0.3.0
43
+ pytz==2024.1
44
+ regex==2023.12.25
45
+ safetensors==0.4.2
46
+ scikit-image==0.22.0
47
+ scikit-learn==1.4.1.post1
48
+ scipy==1.9.1
49
+ shapely==2.0.3
50
+ tensorboard==2.16.2
51
+ tensorboard-data-server==0.7.2
52
+ threadpoolctl==3.3.0
53
+ tifffile==2024.2.12
54
+ timm==0.9.16
55
+ tokenizers==0.15.2
56
+ tomli==2.0.1
57
+ torch==2.1.0
58
+ torchmetrics
59
+ torchaudio==2.1.0
60
+ torchvision==0.16.0
61
+ tqdm==4.66.2
62
+ transformers==4.38.2
63
+ triton==2.1.0
64
+ tzdata==2024.1
65
+ Werkzeug==3.0.1
66
+ yapf==0.40.2
style.css ADDED
@@ -0,0 +1,95 @@
1
+ h1 {
2
+ text-align: center;
3
+ justify-content: center;
4
+ }
5
+
6
+ [role="tabpanel"] {
7
+ border: 0
8
+ }
9
+
10
+ #duplicate-button {
11
+ margin: auto;
12
+ color: #fff;
13
+ background: #1565c0;
14
+ border-radius: 100vh;
15
+ }
16
+
17
+ .gradio-container {
18
+ max-width: 690px ! important;
19
+ }
20
+
21
+ .equal-height {
22
+ display: flex;
23
+ flex: 1;
24
+ }
25
+
26
+ .grid-container {
27
+ display: grid;
28
+ grid-template-columns: 1fr 1fr; /* two columns of equal width */
29
+ gap: 20px;
30
+ height: 100%; /* make sure the container takes the full height */
31
+ }
32
+
33
+ .grid-item {
34
+ display: flex;
35
+ flex-direction: column;
36
+ height: 100%;
37
+ }
38
+
39
+ .flex-grow {
40
+ flex-grow: 1; /* let this element take up the remaining height */
41
+ display: flex;
42
+ flex-direction: column;
43
+ }
44
+
45
+ #share-btn-container {
46
+ padding-left: 0.5rem !important;
47
+ padding-right: 0.5rem !important;
48
+ background-color: #000000;
49
+ justify-content: center;
50
+ align-items: center;
51
+ border-radius: 9999px !important;
52
+ max-width: 13rem;
53
+ margin-left: auto;
54
+ margin-top: 0.35em;
55
+ }
56
+
57
+ div#share-btn-container>div {
58
+ flex-direction: row;
59
+ background: black;
60
+ align-items: center
61
+ }
62
+
63
+ #share-btn-container:hover {
64
+ background-color: #060606
65
+ }
66
+
67
+ #share-btn {
68
+ all: initial;
69
+ color: #ffffff;
70
+ font-weight: 600;
71
+ cursor: pointer;
72
+ font-family: 'IBM Plex Sans', sans-serif;
73
+ margin-left: 0.5rem !important;
74
+ padding-top: 0.5rem !important;
75
+ padding-bottom: 0.5rem !important;
76
+ right: 0;
77
+ font-size: 15px;
78
+ }
79
+
80
+ #share-btn * {
81
+ all: unset
82
+ }
83
+
84
+ #share-btn-container div:nth-child(-n+2) {
85
+ width: auto !important;
86
+ min-height: 0px !important;
87
+ }
88
+
89
+ #share-btn-container .wrap {
90
+ display: none !important
91
+ }
92
+
93
+ #share-btn-container.hidden {
94
+ display: none !important
95
+ }
utils.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ from lpips import LPIPS
8
+ from PIL import Image
9
+ from torchvision.transforms import Normalize
10
+
11
+
12
+ def show_images_horizontally(
13
+ list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
14
+ ) -> None:
15
+ """
16
+ Visualize the list of images horizontally and save the figure as PNG.
17
+
18
+ Args:
19
+ list_of_files: The list of images as numpy array with shape (N, H, W, C).
20
+ output_file: The output file path to save the figure as PNG.
21
+ interact: Whether to show the figure interactively (e.g., in a Jupyter notebook) instead of saving it to `output_file`.
22
+ """
23
+ number_of_files = len(list_of_files)
24
+
25
+ heights = [a[0].shape[0] for a in list_of_files]
26
+ widths = [a.shape[1] for a in list_of_files[0]]
27
+
28
+ fig_width = 8.0 # inches
29
+ fig_height = fig_width * sum(heights) / sum(widths)
30
+
31
+ # Create a figure with subplots
32
+ _, axs = plt.subplots(
33
+ 1, number_of_files, figsize=(fig_width * number_of_files, fig_height)
34
+ )
35
+ plt.tight_layout()
36
+ for i in range(number_of_files):
37
+ _image = list_of_files[i]
38
+ axs[i].imshow(_image)
39
+ axs[i].axis("off")
40
+
41
+ # Save the figure as PNG
42
+ if interact:
43
+ plt.show()
44
+ else:
45
+ plt.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
46
+
47
+
48
+ def image_grids(images, rows=None, cols=None):
49
+ if not images:
50
+ raise ValueError("The image list is empty.")
51
+
52
+ n_images = len(images)
53
+ if cols is None:
54
+ cols = int(n_images**0.5)
55
+ if rows is None:
56
+ rows = (n_images + cols - 1) // cols
57
+
58
+ width, height = images[0].size
59
+ grid_width = cols * width
60
+ grid_height = rows * height
61
+
62
+ grid_image = Image.new("RGB", (grid_width, grid_height))
63
+
64
+ for i, image in enumerate(images):
65
+ row, col = divmod(i, cols)
66
+ grid_image.paste(image, (col * width, row * height))
67
+
68
+ return grid_image
69
+
70
+
71
+ def save_image(image: np.array, file_name: str) -> None:
72
+ """
73
+ Save the image as JPG.
74
+
75
+ Args:
76
+ image: The input image as numpy array with shape (H, W, C).
77
+ file_name: The file name to save the image.
78
+ """
79
+ image = Image.fromarray(image)
80
+ image.save(file_name)
81
+
82
+
83
+ def load_and_process_images(load_dir: str) -> np.array:
84
+ """
85
+ Load and process the images into numpy array from the directory.
86
+
87
+ Args:
88
+ load_dir: The directory to load the images.
89
+
90
+ Returns:
91
+ images: The images as numpy array with shape (N, H, W, C).
92
+ """
93
+ images = []
94
+ print(load_dir)
95
+ filenames = sorted(
96
+ os.listdir(load_dir), key=lambda x: int(x.split(".")[0])
97
+ ) # Ensure the files are sorted numerically
98
+ for filename in filenames:
99
+ if filename.endswith(".jpg"):
100
+ img = Image.open(os.path.join(load_dir, filename))
101
+ img_array = (
102
+ np.asarray(img) / 255.0
103
+ ) # Convert to numpy array and scale pixel values to [0, 1]
104
+ images.append(img_array)
105
+ return images
106
+
107
+
108
+ def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
109
+ """
110
+ Compute the LPIPS of the input images.
111
+
112
+ Args:
113
+ images: The input images as numpy array with shape (N, H, W, C).
114
+ lpips_model: The LPIPS model used to compute perceptual distances.
115
+
116
+ Returns:
117
+ distances: The LPIPS of the input images.
118
+ """
119
+ # Get device of lpips_model
120
+ device = next(lpips_model.parameters()).device
121
+ device = str(device)
122
+
123
+ # Change the input images into tensor
124
+ images = torch.tensor(images).to(device).float()
125
+ images = torch.permute(images, (0, 3, 1, 2))
126
+ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
127
+ images = normalize(images)
128
+
129
+ # Compute the LPIPS between each adjacent input images
130
+ distances = []
131
+ for i in range(images.shape[0]):
132
+ if i == images.shape[0] - 1:
133
+ break
134
+ img1 = images[i].unsqueeze(0)
135
+ img2 = images[i + 1].unsqueeze(0)
136
+ loss = lpips_model(img1, img2)
137
+ distances.append(loss.item())
138
+ distances = np.array(distances)
139
+ return distances
140
+
141
+
142
+ def compute_gini(distances: np.array) -> float:
143
+ """
144
+ Compute the Gini index of the input distances.
145
+
146
+ Args:
147
+ distances: The input distances as numpy array.
148
+
149
+ Returns:
150
+ gini: The Gini index of the input distances.
151
+ """
152
+ if len(distances) < 2:
153
+ return 0.0 # Gini index is 0 for less than two elements
154
+
155
+ # Sort the list of distances
156
+ sorted_distances = sorted(distances)
157
+ n = len(sorted_distances)
158
+ mean_distance = sum(sorted_distances) / n
159
+
160
+ # Compute the sum of absolute differences
161
+ sum_of_differences = 0
162
+ for di in sorted_distances:
163
+ for dj in sorted_distances:
164
+ sum_of_differences += abs(di - dj)
165
+
166
+ # Normalize the sum of differences by the mean and the number of elements
167
+ gini = sum_of_differences / (2 * n * n * mean_distance)
168
+ return gini
169
+
170
+
171
+ def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
172
+ """
173
+ Compute the smoothness and consistency of the input images.
174
+
175
+ Args:
176
+ images: The input images as numpy array with shape (N, H, W, C).
177
+ lpips_model: The LPIPS model used to compute perceptual distances.
178
+
179
+ Returns:
180
+ smoothness: One minus gini index of LPIPS of consecutive images.
181
+ consistency: The mean LPIPS of consecutive images.
182
+ max_inception_distance: The maximum LPIPS of consecutive images.
183
+ """
184
+ distances = compute_lpips(images, lpips_model)
185
+ smoothness = 1 - compute_gini(distances)
186
+ consistency = np.mean(distances)
187
+ max_inception_distance = np.max(distances)
188
+ return smoothness, consistency, max_inception_distance
189
+
190
+
191
+ def separate_source_and_interpolated_images(images: np.array) -> tuple:
192
+ """
193
+ Separate the input images into source and interpolated images.
194
+ The input source is the start and end of the images, while the interpolated images are the rest.
195
+
196
+ Args:
197
+ images: The input images as numpy array with shape (N, H, W, C).
198
+
199
+ Returns:
200
+ source: The source images as numpy array with shape (2, H, W, C).
201
+ interpolation: The interpolated images as numpy array with shape (N-2, H, W, C).
202
+ """
203
+ # Check if the array has at least two elements
204
+ if len(images) < 2:
205
+ raise ValueError("The input array should have at least two elements.")
206
+
207
+ # Separate the array into two parts
208
+ # First part takes the first and last element
209
+ source = np.array([images[0], images[-1]])
210
+ # Second part takes the rest of the elements
211
+ interpolation = images[1:-1]
212
+ return source, interpolation
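For reference, a tiny standalone illustration of the smoothness and consistency metrics defined above (the distance values are made up):

```python
import numpy as np

# Hypothetical LPIPS distances between consecutive frames of a sequence
distances = np.array([0.10, 0.12, 0.11, 0.40])

n = len(distances)
mean_d = distances.mean()
# Gini index as in compute_gini: normalized mean absolute difference
gini = np.abs(distances[:, None] - distances[None, :]).sum() / (2 * n * n * mean_d)

smoothness = 1 - gini            # closer to 1 -> perceptual change is evenly spread
consistency = mean_d             # average perceptual change between neighbours
max_distance = distances.max()   # the largest single jump
print(round(smoothness, 3), round(consistency, 3), round(max_distance, 3))
```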