Linaqruf commited on
Commit
1f66542
·
1 Parent(s): 3eed5e6

migrating to zero gpu

Browse files
Files changed (9) hide show
  1. .gitattributes +35 -0
  2. README.md +1 -3
  3. app.py +181 -625
  4. config.py +105 -0
  5. lora.toml +0 -28
  6. lora_diffusers.py +0 -478
  7. requirements.txt +6 -7
  8. style.css +4 -30
  9. utils.py +173 -1
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,13 +4,11 @@ emoji: 🌍
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.2.0
8
  app_file: app.py
9
  license: mit
10
  pinned: false
11
  suggested_hardware: a10g-small
12
- duplicated_from: hysts/SD-XL
13
- hf_oauth: true
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.20.0
8
  app_file: app.py
9
  license: mit
10
  pinned: false
11
  suggested_hardware: a10g-small
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,244 +1,71 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
  import os
6
- import random
7
  import gc
8
- import toml
9
  import gradio as gr
10
  import numpy as np
11
- import utils
12
  import torch
13
  import json
14
- import PIL.Image
15
- import base64
16
- import safetensors
17
- from io import BytesIO
18
- from typing import Tuple
19
  from datetime import datetime
20
- from PIL import PngImagePlugin
21
- import gradio_user_history as gr_user_history
22
- from huggingface_hub import hf_hub_download
23
- from safetensors.torch import load_file
24
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
25
- from lora_diffusers import LoRANetwork, create_network_from_weights
26
  from diffusers.models import AutoencoderKL
27
- from diffusers import (
28
- StableDiffusionXLPipeline,
29
- StableDiffusionXLImg2ImgPipeline,
30
- DPMSolverMultistepScheduler,
31
- DPMSolverSinglestepScheduler,
32
- KDPM2DiscreteScheduler,
33
- EulerDiscreteScheduler,
34
- EulerAncestralDiscreteScheduler,
35
- HeunDiscreteScheduler,
36
- LMSDiscreteScheduler,
37
- DDIMScheduler,
38
- DEISMultistepScheduler,
39
- UniPCMultistepScheduler,
40
- )
41
 
42
  DESCRIPTION = "Animagine XL 3.0"
43
  if not torch.cuda.is_available():
44
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
45
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
46
- MAX_SEED = np.iinfo(np.int32).max
47
  HF_TOKEN = os.getenv("HF_TOKEN")
48
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
49
  MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
50
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
51
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
52
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 
53
 
54
- MODEL = os.getenv("MODEL", "https://huggingface.co/cagliostrolab/animagine-xl-3.0/blob/main/animagine-xl-3.0.safetensors")
 
 
 
55
 
56
  torch.backends.cudnn.deterministic = True
57
  torch.backends.cudnn.benchmark = False
58
 
59
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60
 
61
- if torch.cuda.is_available():
 
62
  vae = AutoencoderKL.from_pretrained(
63
  "madebyollin/sdxl-vae-fp16-fix",
64
  torch_dtype=torch.float16,
65
  )
66
- pipeline = StableDiffusionXLPipeline.from_single_file if MODEL.endswith(".safetensors") else StableDiffusionXLPipeline.from_pretrained
67
-
 
 
 
 
68
  pipe = pipeline(
69
- MODEL,
70
  vae=vae,
71
  torch_dtype=torch.float16,
72
  custom_pipeline="lpw_stable_diffusion_xl",
73
  use_safetensors=True,
 
74
  use_auth_token=HF_TOKEN,
75
  variant="fp16",
76
  )
77
 
78
- if ENABLE_CPU_OFFLOAD:
79
- pipe.enable_model_cpu_offload()
80
- else:
81
- pipe.to(device)
82
- if USE_TORCH_COMPILE:
83
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
84
- else:
85
- pipe = None
86
-
87
-
88
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
89
- if randomize_seed:
90
- seed = random.randint(0, MAX_SEED)
91
- return seed
92
-
93
-
94
- def seed_everything(seed):
95
- torch.manual_seed(seed)
96
- torch.cuda.manual_seed_all(seed)
97
- np.random.seed(seed)
98
- generator = torch.Generator()
99
- generator.manual_seed(seed)
100
- return generator
101
-
102
-
103
- def get_image_path(base_path: str):
104
- extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"]
105
- for ext in extensions:
106
- image_path = base_path + ext
107
- if os.path.exists(image_path):
108
- return image_path
109
- return None
110
-
111
-
112
- def update_selection(selected_state: gr.SelectData):
113
- lora_repo = sdxl_loras[selected_state.index]["repo"]
114
- lora_weight = sdxl_loras[selected_state.index]["multiplier"]
115
- updated_selected_info = f"{lora_repo}"
116
-
117
- return (
118
- updated_selected_info,
119
- selected_state,
120
- lora_weight,
121
- )
122
-
123
-
124
- def parse_aspect_ratio(aspect_ratio):
125
- if aspect_ratio == "Custom":
126
- return None, None
127
- width, height = aspect_ratio.split(" x ")
128
- return int(width), int(height)
129
-
130
-
131
- def aspect_ratio_handler(aspect_ratio, custom_width, custom_height):
132
- if aspect_ratio == "Custom":
133
- return custom_width, custom_height
134
- else:
135
- width, height = parse_aspect_ratio(aspect_ratio)
136
- return width, height
137
-
138
-
139
- def create_network(text_encoders, unet, state_dict, multiplier, device):
140
- network = create_network_from_weights(
141
- text_encoders,
142
- unet,
143
- state_dict,
144
- multiplier,
145
- )
146
- network.load_state_dict(state_dict)
147
- network.to(device, dtype=unet.dtype)
148
- network.apply_to(multiplier=multiplier)
149
-
150
- return network
151
-
152
-
153
- def get_scheduler(scheduler_config, name):
154
- scheduler_map = {
155
- "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
156
- scheduler_config, use_karras_sigmas=True
157
- ),
158
- "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
159
- scheduler_config, use_karras_sigmas=True
160
- ),
161
- "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
162
- scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
163
- ),
164
- "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
165
- "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
166
- scheduler_config
167
- ),
168
- "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
169
- }
170
- return scheduler_map.get(name, lambda: None)()
171
-
172
-
173
- def free_memory():
174
- torch.cuda.empty_cache()
175
- gc.collect()
176
-
177
-
178
- def preprocess_prompt(
179
- style_dict,
180
- style_name: str,
181
- positive: str,
182
- negative: str = "",
183
- add_style: bool = True,
184
- ) -> Tuple[str, str]:
185
- p, n = style_dict.get(style_name, style_dict["(None)"])
186
-
187
- if add_style and positive.strip():
188
- formatted_positive = p.format(prompt=positive)
189
- else:
190
- formatted_positive = positive
191
-
192
- combined_negative = n + negative
193
- return formatted_positive, combined_negative
194
-
195
-
196
- def common_upscale(samples, width, height, upscale_method):
197
- return torch.nn.functional.interpolate(
198
- samples, size=(height, width), mode=upscale_method
199
- )
200
-
201
-
202
- def upscale(samples, upscale_method, scale_by):
203
- width = round(samples.shape[3] * scale_by)
204
- height = round(samples.shape[2] * scale_by)
205
- s = common_upscale(samples, width, height, upscale_method)
206
- return s
207
-
208
 
209
- def load_and_convert_thumbnail(model_path: str):
210
- with safetensors.safe_open(model_path, framework="pt") as f:
211
- metadata = f.metadata()
212
- if "modelspec.thumbnail" in metadata:
213
- base64_data = metadata["modelspec.thumbnail"]
214
- prefix, encoded = base64_data.split(",", 1)
215
- image_data = base64.b64decode(encoded)
216
- image = PIL.Image.open(BytesIO(image_data))
217
- return image
218
- return None
219
-
220
- def load_wildcard_files(wildcard_dir):
221
- wildcard_files = {}
222
- for file in os.listdir(wildcard_dir):
223
- if file.endswith(".txt"):
224
- key = f"__{file.split('.')[0]}__" # Create a key like __character__
225
- wildcard_files[key] = os.path.join(wildcard_dir, file)
226
- return wildcard_files
227
-
228
- def get_random_line_from_file(file_path):
229
- with open(file_path, 'r') as file:
230
- lines = file.readlines()
231
- if not lines:
232
- return ""
233
- return random.choice(lines).strip()
234
-
235
- def add_wildcard(prompt, wildcard_files):
236
- for key, file_path in wildcard_files.items():
237
- if key in prompt:
238
- wildcard_line = get_random_line_from_file(file_path)
239
- prompt = prompt.replace(key, wildcard_line)
240
- return prompt
241
 
 
242
  def generate(
243
  prompt: str,
244
  negative_prompt: str = "",
@@ -247,90 +74,40 @@ def generate(
247
  custom_height: int = 1024,
248
  guidance_scale: float = 7.0,
249
  num_inference_steps: int = 28,
250
- use_lora: bool = False,
251
- lora_weight: float = 1.0,
252
- selected_state: str = "",
253
  sampler: str = "Euler a",
254
  aspect_ratio_selector: str = "896 x 1152",
255
  style_selector: str = "(None)",
256
  quality_selector: str = "Standard",
257
  use_upscaler: bool = False,
258
- upscaler_strength: float = 0.5,
259
  upscale_by: float = 1.5,
260
  add_quality_tags: bool = True,
261
- profile: gr.OAuthProfile | None = None,
262
  progress=gr.Progress(track_tqdm=True),
263
- ) -> PIL.Image.Image:
264
- generator = seed_everything(seed)
265
 
266
- network = None
267
- network_state = {"current_lora": None, "multiplier": None}
268
-
269
- width, height = aspect_ratio_handler(
270
  aspect_ratio_selector,
271
  custom_width,
272
  custom_height,
273
  )
274
 
275
- prompt = add_wildcard(prompt, wildcard_files)
276
 
277
-
278
- prompt, negative_prompt = preprocess_prompt(
279
  quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
280
  )
281
- prompt, negative_prompt = preprocess_prompt(
282
  styles, style_selector, prompt, negative_prompt
283
  )
284
 
285
- if width % 8 != 0:
286
- width = width - (width % 8)
287
- if height % 8 != 0:
288
- height = height - (height % 8)
289
-
290
- if use_lora:
291
- if not selected_state:
292
- raise Exception("You must Select a LoRA")
293
- repo_name = sdxl_loras[selected_state.index]["repo"]
294
- full_path_lora = saved_names[selected_state.index]
295
- weight_name = sdxl_loras[selected_state.index]["weights"]
296
-
297
- lora_sd = load_file(full_path_lora)
298
- text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
299
-
300
- if network_state["current_lora"] != repo_name:
301
- network = create_network(
302
- text_encoders,
303
- pipe.unet,
304
- lora_sd,
305
- lora_weight,
306
- device,
307
- )
308
- network_state["current_lora"] = repo_name
309
- network_state["multiplier"] = lora_weight
310
- elif network_state["multiplier"] != lora_weight:
311
- network = create_network(
312
- text_encoders,
313
- pipe.unet,
314
- lora_sd,
315
- lora_weight,
316
- device,
317
- )
318
- network_state["multiplier"] = lora_weight
319
- else:
320
- if network:
321
- network.unapply_to()
322
- network = None
323
- network_state = {
324
- "current_lora": None,
325
- "multiplier": None,
326
- }
327
 
328
  backup_scheduler = pipe.scheduler
329
- pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
330
 
331
  if use_upscaler:
332
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
333
-
334
  metadata = {
335
  "prompt": prompt,
336
  "negative_prompt": negative_prompt,
@@ -344,11 +121,6 @@ def generate(
344
  "quality_tags": quality_selector,
345
  }
346
 
347
- if use_lora:
348
- metadata["use_lora"] = {"selected_lora": repo_name, "multiplier": lora_weight}
349
- else:
350
- metadata["use_lora"] = None
351
-
352
  if use_upscaler:
353
  new_width = int(width * upscale_by)
354
  new_height = int(height * upscale_by)
@@ -360,8 +132,7 @@ def generate(
360
  }
361
  else:
362
  metadata["use_upscaler"] = None
363
-
364
- print(json.dumps(metadata, indent=4))
365
 
366
  try:
367
  if use_upscaler:
@@ -375,8 +146,8 @@ def generate(
375
  generator=generator,
376
  output_type="latent",
377
  ).images
378
- upscaled_latents = upscale(latents, "nearest-exact", upscale_by)
379
- image = upscaler_pipe(
380
  prompt=prompt,
381
  negative_prompt=negative_prompt,
382
  image=upscaled_latents,
@@ -385,9 +156,9 @@ def generate(
385
  strength=upscaler_strength,
386
  generator=generator,
387
  output_type="pil",
388
- ).images[0]
389
  else:
390
- image = pipe(
391
  prompt=prompt,
392
  negative_prompt=negative_prompt,
393
  width=width,
@@ -396,194 +167,38 @@ def generate(
396
  num_inference_steps=num_inference_steps,
397
  generator=generator,
398
  output_type="pil",
399
- ).images[0]
400
- if network:
401
- network.unapply_to()
402
- network = None
403
- if profile is not None:
404
- gr_user_history.save_image(
405
- label=prompt,
406
- image=image,
407
- profile=profile,
408
- metadata=metadata,
409
- )
410
- if image and IS_COLAB:
411
- current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
412
- output_directory = "./outputs"
413
- os.makedirs(output_directory, exist_ok=True)
414
- filename = f"image_{current_time}.png"
415
- filepath = os.path.join(output_directory, filename)
416
-
417
- # Convert metadata to a string and save as a text chunk in the PNG
418
- metadata_str = json.dumps(metadata)
419
- info = PngImagePlugin.PngInfo()
420
- info.add_text("metadata", metadata_str)
421
- image.save(filepath, "PNG", pnginfo=info)
422
- print(f"Image saved as {filepath} with metadata")
423
 
424
- return image, metadata
 
 
 
425
 
 
426
  except Exception as e:
427
- print(f"An error occurred: {e}")
428
  raise
429
  finally:
430
- if network:
431
- network.unapply_to()
432
- network = None
433
- if use_lora:
434
- del lora_sd, text_encoders
435
  if use_upscaler:
436
  del upscaler_pipe
437
  pipe.scheduler = backup_scheduler
438
- free_memory()
439
-
440
-
441
- examples = [
442
- "1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
443
- "1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
444
- "1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
445
- "1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
446
- "A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
447
- ]
448
-
449
- quality_prompt_list = [
450
- {
451
- "name": "(None)",
452
- "prompt": "{prompt}",
453
- "negative_prompt": "nsfw, lowres, ",
454
- },
455
- {
456
- "name": "Standard",
457
- "prompt": "{prompt}, masterpiece, best quality",
458
- "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
459
- },
460
- {
461
- "name": "Light",
462
- "prompt": "{prompt}, (masterpiece), best quality, perfect face",
463
- "negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
464
- },
465
- {
466
- "name": "Heavy",
467
- "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
468
- "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
469
- },
470
- ]
471
-
472
- sampler_list = [
473
- "DPM++ 2M Karras",
474
- "DPM++ SDE Karras",
475
- "DPM++ 2M SDE Karras",
476
- "Euler",
477
- "Euler a",
478
- "DDIM",
479
- ]
480
-
481
- aspect_ratios = [
482
- "1024 x 1024",
483
- "1152 x 896",
484
- "896 x 1152",
485
- "1216 x 832",
486
- "832 x 1216",
487
- "1344 x 768",
488
- "768 x 1344",
489
- "1536 x 640",
490
- "640 x 1536",
491
- "Custom",
492
- ]
493
-
494
- style_list = [
495
- {
496
- "name": "(None)",
497
- "prompt": "{prompt}",
498
- "negative_prompt": "",
499
- },
500
- {
501
- "name": "Cinematic",
502
- "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
503
- "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
504
- },
505
- {
506
- "name": "Photographic",
507
- "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
508
- "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
509
- },
510
- {
511
- "name": "Anime",
512
- "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
513
- "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
514
- },
515
- {
516
- "name": "Manga",
517
- "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
518
- "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
519
- },
520
- {
521
- "name": "Digital Art",
522
- "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
523
- "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
524
- },
525
- {
526
- "name": "Pixel art",
527
- "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
528
- "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
529
- },
530
- {
531
- "name": "Fantasy art",
532
- "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
533
- "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
534
- },
535
- {
536
- "name": "Neonpunk",
537
- "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
538
- "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
539
- },
540
- {
541
- "name": "3D Model",
542
- "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
543
- "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
544
- },
545
- ]
546
-
547
- thumbnail_cache = {}
548
 
549
- with open("lora.toml", "r") as file:
550
- data = toml.load(file)
551
 
552
- sdxl_loras = []
553
- saved_names = []
554
- for item in data["data"]:
555
- model_path = hf_hub_download(item["repo"], item["weights"], token=HF_TOKEN)
556
- saved_names.append(model_path) # Store the path in saved_names
557
-
558
- if model_path not in thumbnail_cache:
559
- thumbnail_image = load_and_convert_thumbnail(model_path)
560
- thumbnail_cache[model_path] = thumbnail_image
561
- else:
562
- thumbnail_image = thumbnail_cache[model_path]
563
-
564
- sdxl_loras.append(
565
- {
566
- "image": thumbnail_image, # Storing the PIL image object
567
- "title": item["title"],
568
- "repo": item["repo"],
569
- "weights": item["weights"],
570
- "multiplier": item.get("multiplier", "1.0"),
571
- }
572
- )
573
 
574
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
575
  quality_prompt = {
576
- k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list
577
  }
578
 
579
- # saved_names = [
580
- # hf_hub_download(item["repo"], item["weights"], token=HF_TOKEN)
581
- # for item in sdxl_loras
582
- # ]
583
-
584
- wildcard_files = load_wildcard_files("wildcard")
585
 
586
- with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
587
  title = gr.HTML(
588
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
589
  elem_id="title",
@@ -592,187 +207,131 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
592
  f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
593
  elem_id="subtitle",
594
  )
595
- gr.Markdown(
596
- f"""Prompting is a bit different in this iteration, we train the model like this:
597
- ```
598
- 1girl/1boy, character name, from what series, everything else in any order.
599
- ```
600
- Prompting Tips
601
- ```
602
- 1. Quality Tags: `masterpiece, best quality, high quality, normal quality, worst quality, low quality`
603
- 2. Year Tags: `oldest, early, mid, late, newest`
604
- 3. Rating tags: `rating: general, rating: sensitive, rating: questionable, rating: explicit, nsfw`
605
- 4. Escape character: `character name \(series\)`
606
- 5. Recommended settings: `Euler a, cfg 5-7, 25-28 steps`
607
- 6. It's recommended to use the exact danbooru tags for more accurate result
608
- 7. To use character wildcard, add this syntax to the prompt `__character__`.
609
- ```
610
- """,
611
- elem_id="subtitle",
612
- )
613
  gr.DuplicateButton(
614
  value="Duplicate Space for private use",
615
  elem_id="duplicate-button",
616
  visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
617
  )
618
- selected_state = gr.State()
619
- with gr.Row():
620
- with gr.Column(scale=2):
621
- with gr.Tab("Txt2img"):
622
- with gr.Group():
623
- prompt = gr.Text(
624
- label="Prompt",
625
- max_lines=5,
626
- placeholder="Enter your prompt",
627
- )
628
- negative_prompt = gr.Text(
629
- label="Negative Prompt",
630
- max_lines=5,
631
- placeholder="Enter a negative prompt",
632
- )
633
- with gr.Accordion(label="Quality Tags", open=True):
634
- add_quality_tags = gr.Checkbox(label="Add Quality Tags", value=True)
635
- quality_selector = gr.Dropdown(
636
- label="Quality Tags Presets",
637
- interactive=True,
638
- choices=list(quality_prompt.keys()),
639
- value="Standard",
640
- )
641
- with gr.Row():
642
- use_lora = gr.Checkbox(label="Use LoRA", value=False)
643
- with gr.Group(visible=False) as lora_group:
644
- selector_info = gr.Text(
645
- label="Selected LoRA",
646
- max_lines=1,
647
- value="No LoRA selected.",
648
- )
649
- lora_selection = gr.Gallery(
650
- value=[(item["image"], item["title"]) for item in sdxl_loras],
651
- label="Animagine XL 2.0 LoRA",
652
- show_label=False,
653
- columns=2,
654
- show_share_button=False,
655
- )
656
- lora_weight = gr.Slider(
657
- label="Multiplier",
658
- minimum=-2,
659
- maximum=2,
660
- step=0.05,
661
- value=1,
662
- )
663
- with gr.Tab("Advanced Settings"):
664
- with gr.Group():
665
- style_selector = gr.Radio(
666
- label="Style Preset",
667
- container=True,
668
- interactive=True,
669
- choices=list(styles.keys()),
670
- value="(None)",
671
- )
672
- with gr.Group():
673
- aspect_ratio_selector = gr.Radio(
674
- label="Aspect Ratio",
675
- choices=aspect_ratios,
676
- value="896 x 1152",
677
- container=True,
678
- )
679
- with gr.Group():
680
- use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
681
- with gr.Row() as upscaler_row:
682
- upscaler_strength = gr.Slider(
683
- label="Strength",
684
- minimum=0,
685
- maximum=1,
686
- step=0.05,
687
- value=0.55,
688
- visible=False,
689
- )
690
- upscale_by = gr.Slider(
691
- label="Upscale by",
692
- minimum=1,
693
- maximum=1.5,
694
- step=0.1,
695
- value=1.5,
696
- visible=False,
697
- )
698
- with gr.Group(visible=False) as custom_resolution:
699
- with gr.Row():
700
- custom_width = gr.Slider(
701
- label="Width",
702
- minimum=MIN_IMAGE_SIZE,
703
- maximum=MAX_IMAGE_SIZE,
704
- step=8,
705
- value=1024,
706
- )
707
- custom_height = gr.Slider(
708
- label="Height",
709
- minimum=MIN_IMAGE_SIZE,
710
- maximum=MAX_IMAGE_SIZE,
711
- step=8,
712
- value=1024,
713
- )
714
- with gr.Group():
715
- sampler = gr.Dropdown(
716
- label="Sampler",
717
- choices=sampler_list,
718
- interactive=True,
719
- value="Euler a",
720
- )
721
- with gr.Group():
722
- seed = gr.Slider(
723
- label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
724
- )
725
-
726
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
727
- with gr.Group():
728
- with gr.Row():
729
- guidance_scale = gr.Slider(
730
- label="Guidance scale",
731
- minimum=1,
732
- maximum=12,
733
- step=0.1,
734
- value=7.0,
735
- )
736
- num_inference_steps = gr.Slider(
737
- label="Number of inference steps",
738
- minimum=1,
739
- maximum=50,
740
- step=1,
741
- value=28,
742
- )
743
-
744
- with gr.Tab("Past Generation"):
745
- gr_user_history.render()
746
- with gr.Column(scale=3):
747
- with gr.Blocks():
748
- run_button = gr.Button("Generate", variant="primary")
749
- result = gr.Image(label="Result", show_label=False)
750
- with gr.Accordion(label="Generation Parameters", open=False):
751
- gr_metadata = gr.JSON(label="Metadata", show_label=False)
752
- gr.Examples(
753
- examples=examples,
754
- inputs=prompt,
755
- outputs=[result, gr_metadata],
756
- fn=generate,
757
- cache_examples=CACHE_EXAMPLES,
758
  )
759
 
760
- lora_selection.select(
761
- update_selection,
762
- outputs=[
763
- selector_info,
764
- selected_state,
765
- lora_weight,
766
- ],
767
- queue=False,
768
- show_progress=False,
769
- )
770
- use_lora.change(
771
- fn=lambda x: gr.update(visible=x),
772
- inputs=use_lora,
773
- outputs=lora_group,
774
- queue=False,
775
- api_name=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  )
777
  use_upscaler.change(
778
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
@@ -797,9 +356,6 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
797
  custom_height,
798
  guidance_scale,
799
  num_inference_steps,
800
- use_lora,
801
- lora_weight,
802
- selected_state,
803
  sampler,
804
  aspect_ratio_selector,
805
  style_selector,
@@ -807,11 +363,11 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
807
  use_upscaler,
808
  upscaler_strength,
809
  upscale_by,
810
- add_quality_tags
811
  ]
812
 
813
  prompt.submit(
814
- fn=randomize_seed_fn,
815
  inputs=[seed, randomize_seed],
816
  outputs=seed,
817
  queue=False,
@@ -823,7 +379,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
823
  api_name="run",
824
  )
825
  negative_prompt.submit(
826
- fn=randomize_seed_fn,
827
  inputs=[seed, randomize_seed],
828
  outputs=seed,
829
  queue=False,
@@ -835,7 +391,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
835
  api_name=False,
836
  )
837
  run_button.click(
838
- fn=randomize_seed_fn,
839
  inputs=[seed, randomize_seed],
840
  outputs=seed,
841
  queue=False,
@@ -846,4 +402,4 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
846
  outputs=[result, gr_metadata],
847
  api_name=False,
848
  )
849
- demo.queue(max_size=30).launch(debug=IS_COLAB, share=IS_COLAB)
 
 
 
 
 
1
  import os
 
2
  import gc
 
3
  import gradio as gr
4
  import numpy as np
 
5
  import torch
6
  import json
7
+ import spaces
8
+ import config
9
+ import utils
10
+ import logging
11
+ from PIL import Image, PngImagePlugin
12
  from datetime import datetime
 
 
 
 
 
 
13
  from diffusers.models import AutoencoderKL
14
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
18
 
19
  DESCRIPTION = "Animagine XL 3.0"
20
  if not torch.cuda.is_available():
21
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
22
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
 
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
25
  MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
26
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
27
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
28
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
29
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
30
 
31
+ MODEL = os.getenv(
32
+ "MODEL",
33
+ "https://huggingface.co/cagliostrolab/animagine-xl-3.0/blob/main/animagine-xl-3.0.safetensors",
34
+ )
35
 
36
  torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
41
+
42
+ def load_pipeline(model_name):
43
  vae = AutoencoderKL.from_pretrained(
44
  "madebyollin/sdxl-vae-fp16-fix",
45
  torch_dtype=torch.float16,
46
  )
47
+ pipeline = (
48
+ StableDiffusionXLPipeline.from_single_file
49
+ if MODEL.endswith(".safetensors")
50
+ else StableDiffusionXLPipeline.from_pretrained
51
+ )
52
+
53
  pipe = pipeline(
54
+ model_name,
55
  vae=vae,
56
  torch_dtype=torch.float16,
57
  custom_pipeline="lpw_stable_diffusion_xl",
58
  use_safetensors=True,
59
+ add_watermarker=False,
60
  use_auth_token=HF_TOKEN,
61
  variant="fp16",
62
  )
63
 
64
+ pipe.to(device)
65
+ return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ @spaces.GPU
69
  def generate(
70
  prompt: str,
71
  negative_prompt: str = "",
 
74
  custom_height: int = 1024,
75
  guidance_scale: float = 7.0,
76
  num_inference_steps: int = 28,
 
 
 
77
  sampler: str = "Euler a",
78
  aspect_ratio_selector: str = "896 x 1152",
79
  style_selector: str = "(None)",
80
  quality_selector: str = "Standard",
81
  use_upscaler: bool = False,
82
+ upscaler_strength: float = 0.55,
83
  upscale_by: float = 1.5,
84
  add_quality_tags: bool = True,
 
85
  progress=gr.Progress(track_tqdm=True),
86
+ ) -> Image:
87
+ generator = utils.seed_everything(seed)
88
 
89
+ width, height = utils.aspect_ratio_handler(
 
 
 
90
  aspect_ratio_selector,
91
  custom_width,
92
  custom_height,
93
  )
94
 
95
+ prompt = utils.add_wildcard(prompt, wildcard_files)
96
 
97
+ prompt, negative_prompt = utils.preprocess_prompt(
 
98
  quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
99
  )
100
+ prompt, negative_prompt = utils.preprocess_prompt(
101
  styles, style_selector, prompt, negative_prompt
102
  )
103
 
104
+ width, height = utils.preprocess_image_dimensions(width, height)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  backup_scheduler = pipe.scheduler
107
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
108
 
109
  if use_upscaler:
110
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
 
111
  metadata = {
112
  "prompt": prompt,
113
  "negative_prompt": negative_prompt,
 
121
  "quality_tags": quality_selector,
122
  }
123
 
 
 
 
 
 
124
  if use_upscaler:
125
  new_width = int(width * upscale_by)
126
  new_height = int(height * upscale_by)
 
132
  }
133
  else:
134
  metadata["use_upscaler"] = None
135
+ logger.info(json.dumps(metadata, indent=4))
 
136
 
137
  try:
138
  if use_upscaler:
 
146
  generator=generator,
147
  output_type="latent",
148
  ).images
149
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
150
+ images = upscaler_pipe(
151
  prompt=prompt,
152
  negative_prompt=negative_prompt,
153
  image=upscaled_latents,
 
156
  strength=upscaler_strength,
157
  generator=generator,
158
  output_type="pil",
159
+ ).images
160
  else:
161
+ images = pipe(
162
  prompt=prompt,
163
  negative_prompt=negative_prompt,
164
  width=width,
 
167
  num_inference_steps=num_inference_steps,
168
  generator=generator,
169
  output_type="pil",
170
+ ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ if images and IS_COLAB:
173
+ for image in images:
174
+ filepath = utils.save_image(image, metadata, OUTPUT_DIR)
175
+ logger.info(f"Image saved as {filepath} with metadata")
176
 
177
+ return images, metadata
178
  except Exception as e:
179
+ logger.exception(f"An error occurred: {e}")
180
  raise
181
  finally:
 
 
 
 
 
182
  if use_upscaler:
183
  del upscaler_pipe
184
  pipe.scheduler = backup_scheduler
185
+ utils.free_memory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
 
 
187
 
188
+ if torch.cuda.is_available():
189
+ pipe = load_pipeline(MODEL)
190
+ logger.info("Loaded on Device!")
191
+ else:
192
+ pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
195
  quality_prompt = {
196
+ k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.quality_prompt_list
197
  }
198
 
199
+ wildcard_files = utils.load_wildcard_files("wildcard")
 
 
 
 
 
200
 
201
+ with gr.Blocks(css="style.css") as demo:
202
  title = gr.HTML(
203
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
204
  elem_id="title",
 
207
  f"""Gradio demo for [cagliostrolab/animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0)""",
208
  elem_id="subtitle",
209
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  gr.DuplicateButton(
211
  value="Duplicate Space for private use",
212
  elem_id="duplicate-button",
213
  visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
214
  )
215
+ with gr.Group():
216
+ with gr.Row():
217
+ prompt = gr.Text(
218
+ label="Prompt",
219
+ show_label=False,
220
+ max_lines=5,
221
+ placeholder="Enter your prompt",
222
+ container=False,
223
+ )
224
+ run_button = gr.Button(
225
+ "Generate",
226
+ variant="primary",
227
+ scale=0
228
+ )
229
+ result = gr.Gallery(
230
+ label="Result",
231
+ columns=1,
232
+ preview=True,
233
+ show_label=False
234
+ )
235
+ with gr.Accordion(label="Advanced Settings", open=False):
236
+ negative_prompt = gr.Text(
237
+ label="Negative Prompt",
238
+ max_lines=5,
239
+ placeholder="Enter a negative prompt",
240
+ )
241
+ with gr.Row():
242
+ add_quality_tags = gr.Checkbox(
243
+ label="Add Quality Tags",
244
+ value=True
245
+ )
246
+ quality_selector = gr.Dropdown(
247
+ label="Quality Tags Presets",
248
+ interactive=True,
249
+ choices=list(quality_prompt.keys()),
250
+ value="Standard",
251
+ )
252
+ style_selector = gr.Radio(
253
+ label="Style Preset",
254
+ container=True,
255
+ interactive=True,
256
+ choices=list(styles.keys()),
257
+ value="(None)",
258
+ )
259
+ aspect_ratio_selector = gr.Radio(
260
+ label="Aspect Ratio",
261
+ choices=config.aspect_ratios,
262
+ value="896 x 1152",
263
+ container=True,
264
+ )
265
+ with gr.Group(visible=False) as custom_resolution:
266
+ with gr.Row():
267
+ custom_width = gr.Slider(
268
+ label="Width",
269
+ minimum=MIN_IMAGE_SIZE,
270
+ maximum=MAX_IMAGE_SIZE,
271
+ step=8,
272
+ value=1024,
273
+ )
274
+ custom_height = gr.Slider(
275
+ label="Height",
276
+ minimum=MIN_IMAGE_SIZE,
277
+ maximum=MAX_IMAGE_SIZE,
278
+ step=8,
279
+ value=1024,
280
+ )
281
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
282
+ with gr.Row() as upscaler_row:
283
+ upscaler_strength = gr.Slider(
284
+ label="Strength",
285
+ minimum=0,
286
+ maximum=1,
287
+ step=0.05,
288
+ value=0.55,
289
+ visible=False,
290
+ )
291
+ upscale_by = gr.Slider(
292
+ label="Upscale by",
293
+ minimum=1,
294
+ maximum=1.5,
295
+ step=0.1,
296
+ value=1.5,
297
+ visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  )
299
 
300
+ sampler = gr.Dropdown(
301
+ label="Sampler",
302
+ choices=config.sampler_list,
303
+ interactive=True,
304
+ value="Euler a",
305
+ )
306
+ with gr.Row():
307
+ seed = gr.Slider(
308
+ label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
309
+ )
310
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
311
+ with gr.Group():
312
+ with gr.Row():
313
+ guidance_scale = gr.Slider(
314
+ label="Guidance scale",
315
+ minimum=1,
316
+ maximum=12,
317
+ step=0.1,
318
+ value=7.0,
319
+ )
320
+ num_inference_steps = gr.Slider(
321
+ label="Number of inference steps",
322
+ minimum=1,
323
+ maximum=50,
324
+ step=1,
325
+ value=28,
326
+ )
327
+ with gr.Accordion(label="Generation Parameters", open=False):
328
+ gr_metadata = gr.JSON(label="Metadata", show_label=False)
329
+ gr.Examples(
330
+ examples=config.examples,
331
+ inputs=prompt,
332
+ outputs=[result, gr_metadata],
333
+ fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
334
+ cache_examples=CACHE_EXAMPLES,
335
  )
336
  use_upscaler.change(
337
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
 
356
  custom_height,
357
  guidance_scale,
358
  num_inference_steps,
 
 
 
359
  sampler,
360
  aspect_ratio_selector,
361
  style_selector,
 
363
  use_upscaler,
364
  upscaler_strength,
365
  upscale_by,
366
+ add_quality_tags,
367
  ]
368
 
369
  prompt.submit(
370
+ fn=utils.randomize_seed_fn,
371
  inputs=[seed, randomize_seed],
372
  outputs=seed,
373
  queue=False,
 
379
  api_name="run",
380
  )
381
  negative_prompt.submit(
382
+ fn=utils.randomize_seed_fn,
383
  inputs=[seed, randomize_seed],
384
  outputs=seed,
385
  queue=False,
 
391
  api_name=False,
392
  )
393
  run_button.click(
394
+ fn=utils.randomize_seed_fn,
395
  inputs=[seed, randomize_seed],
396
  outputs=seed,
397
  queue=False,
 
402
  outputs=[result, gr_metadata],
403
  api_name=False,
404
  )
405
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
config.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ examples = [
2
+ "1girl, arima kana, oshi no ko, solo, idol, idol clothes, one eye closed, red shirt, black skirt, black headwear, gloves, stage light, singing, open mouth, crowd, smile, pointing at viewer",
3
+ "1girl, c.c., code geass, white shirt, long sleeves, turtleneck, sitting, looking at viewer, eating, pizza, plate, fork, knife, table, chair, table, restaurant, cinematic angle, cinematic lighting",
4
+ "1girl, sakurauchi riko, \(love live\), queen hat, noble coat, red coat, noble shirt, sitting, crossed legs, gentle smile, parted lips, throne, cinematic angle",
5
+ "1girl, amiya \(arknights\), arknights, dirty face, outstretched hand, close-up, cinematic angle, foreshortening, dark, dark background",
6
+ "A boy and a girl, Emiya Shirou and Artoria Pendragon from fate series, having their breakfast in the dining room. Emiya Shirou wears white t-shirt and jacket. Artoria Pendragon wears white dress with blue neck ribbon. Rice, soup, and minced meats are served on the table. They look at each other while smiling happily",
7
+ ]
8
+
9
+ quality_prompt_list = [
10
+ {
11
+ "name": "(None)",
12
+ "prompt": "{prompt}",
13
+ "negative_prompt": "nsfw, lowres, ",
14
+ },
15
+ {
16
+ "name": "Standard",
17
+ "prompt": "{prompt}, masterpiece, best quality",
18
+ "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, ",
19
+ },
20
+ {
21
+ "name": "Light",
22
+ "prompt": "{prompt}, (masterpiece), best quality, perfect face",
23
+ "negative_prompt": "nsfw, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, ",
24
+ },
25
+ {
26
+ "name": "Heavy",
27
+ "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), illustration, disheveled hair, perfect composition, moist skin, intricate details, earrings",
28
+ "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, ",
29
+ },
30
+ ]
31
+
32
+ sampler_list = [
33
+ "DPM++ 2M Karras",
34
+ "DPM++ SDE Karras",
35
+ "DPM++ 2M SDE Karras",
36
+ "Euler",
37
+ "Euler a",
38
+ "DDIM",
39
+ ]
40
+
41
+ aspect_ratios = [
42
+ "1024 x 1024",
43
+ "1152 x 896",
44
+ "896 x 1152",
45
+ "1216 x 832",
46
+ "832 x 1216",
47
+ "1344 x 768",
48
+ "768 x 1344",
49
+ "1536 x 640",
50
+ "640 x 1536",
51
+ "Custom",
52
+ ]
53
+
54
+ style_list = [
55
+ {
56
+ "name": "(None)",
57
+ "prompt": "{prompt}",
58
+ "negative_prompt": "",
59
+ },
60
+ {
61
+ "name": "Cinematic",
62
+ "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
63
+ "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
64
+ },
65
+ {
66
+ "name": "Photographic",
67
+ "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
68
+ "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
69
+ },
70
+ {
71
+ "name": "Anime",
72
+ "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
73
+ "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
74
+ },
75
+ {
76
+ "name": "Manga",
77
+ "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
78
+ "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
79
+ },
80
+ {
81
+ "name": "Digital Art",
82
+ "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
83
+ "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
84
+ },
85
+ {
86
+ "name": "Pixel art",
87
+ "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
88
+ "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
89
+ },
90
+ {
91
+ "name": "Fantasy art",
92
+ "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
93
+ "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
94
+ },
95
+ {
96
+ "name": "Neonpunk",
97
+ "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
98
+ "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
99
+ },
100
+ {
101
+ "name": "3D Model",
102
+ "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
103
+ "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
104
+ },
105
+ ]
lora.toml DELETED
@@ -1,28 +0,0 @@
1
- [[data]]
2
- title = "Style Enhancer XL"
3
- repo = "Linaqruf/style-enhancer-xl-lora"
4
- weights = "style-enhancer-xl.safetensors"
5
- multiplier = 0.6
6
- [[data]]
7
- title = "Anime Detailer XL"
8
- repo = "Linaqruf/anime-detailer-xl-lora"
9
- weights = "anime-detailer-xl.safetensors"
10
- multiplier = 2.0
11
-
12
- [[data]]
13
- title = "Sketch Style XL"
14
- repo = "Linaqruf/sketch-style-xl-lora"
15
- weights = "sketch-style-xl.safetensors"
16
- multiplier = 0.6
17
-
18
- [[data]]
19
- title = "Pastel Style XL 2.0"
20
- repo = "Linaqruf/pastel-style-xl-lora"
21
- weights = "pastel-style-xl-v2.safetensors"
22
- multiplier = 0.6
23
-
24
- [[data]]
25
- title = "Anime Nouveau XL"
26
- repo = "Linaqruf/anime-nouveau-xl-lora"
27
- weights = "anime-nouveau-xl.safetensors"
28
- multiplier = 0.6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lora_diffusers.py DELETED
@@ -1,478 +0,0 @@
1
- """
2
- LoRA module for Diffusers
3
- ==========================
4
-
5
- This file works independently and is designed to operate with Diffusers.
6
-
7
- Credits
8
- -------
9
- - Modified from: https://github.com/vladmandic/automatic/blob/master/modules/lora_diffusers.py
10
- - Originally from: https://github.com/kohya-ss/sd-scripts/blob/sdxl/networks/lora_diffusers.py
11
- """
12
-
13
- import bisect
14
- import math
15
- import random
16
- from typing import Any, Dict, List, Mapping, Optional, Union
17
- from diffusers import UNet2DConditionModel
18
- import numpy as np
19
- from tqdm import tqdm
20
- from transformers import CLIPTextModel
21
- import torch
22
-
23
-
24
- def make_unet_conversion_map() -> Dict[str, str]:
25
- unet_conversion_map_layer = []
26
-
27
- for i in range(3): # num_blocks is 3 in sdxl
28
- # loop over downblocks/upblocks
29
- for j in range(2):
30
- # loop over resnets/attentions for downblocks
31
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
32
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
33
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
34
-
35
- if i < 3:
36
- # no attention layers in down_blocks.3
37
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
38
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
39
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
40
-
41
- for j in range(3):
42
- # loop over resnets/attentions for upblocks
43
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
44
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
45
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
46
-
47
- # if i > 0: commentout for sdxl
48
- # no attention layers in up_blocks.0
49
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
50
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
51
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
52
-
53
- if i < 3:
54
- # no downsample in down_blocks.3
55
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
56
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
57
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
58
-
59
- # no upsample in up_blocks.3
60
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
61
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
62
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
63
-
64
- hf_mid_atn_prefix = "mid_block.attentions.0."
65
- sd_mid_atn_prefix = "middle_block.1."
66
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
67
-
68
- for j in range(2):
69
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
70
- sd_mid_res_prefix = f"middle_block.{2*j}."
71
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
72
-
73
- unet_conversion_map_resnet = [
74
- # (stable-diffusion, HF Diffusers)
75
- ("in_layers.0.", "norm1."),
76
- ("in_layers.2.", "conv1."),
77
- ("out_layers.0.", "norm2."),
78
- ("out_layers.3.", "conv2."),
79
- ("emb_layers.1.", "time_emb_proj."),
80
- ("skip_connection.", "conv_shortcut."),
81
- ]
82
-
83
- unet_conversion_map = []
84
- for sd, hf in unet_conversion_map_layer:
85
- if "resnets" in hf:
86
- for sd_res, hf_res in unet_conversion_map_resnet:
87
- unet_conversion_map.append((sd + sd_res, hf + hf_res))
88
- else:
89
- unet_conversion_map.append((sd, hf))
90
-
91
- for j in range(2):
92
- hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
93
- sd_time_embed_prefix = f"time_embed.{j*2}."
94
- unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
95
-
96
- for j in range(2):
97
- hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
98
- sd_label_embed_prefix = f"label_emb.0.{j*2}."
99
- unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
100
-
101
- unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
102
- unet_conversion_map.append(("out.0.", "conv_norm_out."))
103
- unet_conversion_map.append(("out.2.", "conv_out."))
104
-
105
- sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
106
- return sd_hf_conversion_map
107
-
108
-
109
- UNET_CONVERSION_MAP = make_unet_conversion_map()
110
-
111
-
112
- class LoRAModule(torch.nn.Module):
113
- """
114
- replaces forward method of the original Linear, instead of replacing the original Linear module.
115
- """
116
-
117
- def __init__(
118
- self,
119
- lora_name,
120
- org_module: torch.nn.Module,
121
- multiplier=1.0,
122
- lora_dim=4,
123
- alpha=1,
124
- ):
125
- """if alpha == 0 or None, alpha is rank (no scaling)."""
126
- super().__init__()
127
- self.lora_name = lora_name
128
-
129
- if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
130
- in_dim = org_module.in_channels
131
- out_dim = org_module.out_channels
132
- else:
133
- in_dim = org_module.in_features
134
- out_dim = org_module.out_features
135
-
136
- self.lora_dim = lora_dim
137
-
138
- if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
139
- kernel_size = org_module.kernel_size
140
- stride = org_module.stride
141
- padding = org_module.padding
142
- self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
143
- self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
144
- else:
145
- self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
146
- self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
147
-
148
- if type(alpha) == torch.Tensor:
149
- alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
150
- alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
151
- self.scale = alpha / self.lora_dim
152
- self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
153
-
154
- # same as microsoft's
155
- torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
156
- torch.nn.init.zeros_(self.lora_up.weight)
157
-
158
- self.multiplier = multiplier
159
- self.org_module = [org_module]
160
- self.enabled = True
161
- self.network: LoRANetwork = None
162
- self.org_forward = None
163
-
164
- # override org_module's forward method
165
- def apply_to(self, multiplier=None):
166
- if multiplier is not None:
167
- self.multiplier = multiplier
168
- if self.org_forward is None:
169
- self.org_forward = self.org_module[0].forward
170
- self.org_module[0].forward = self.forward
171
-
172
- # restore org_module's forward method
173
- def unapply_to(self):
174
- if self.org_forward is not None:
175
- self.org_module[0].forward = self.org_forward
176
-
177
- # forward with lora
178
- # scale is used LoRACompatibleConv, but we ignore it because we have multiplier
179
- def forward(self, x, scale=1.0):
180
- if not self.enabled:
181
- return self.org_forward(x)
182
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
183
-
184
- def set_network(self, network):
185
- self.network = network
186
-
187
- # merge lora weight to org weight
188
- def merge_to(self, multiplier=1.0):
189
- # get lora weight
190
- lora_weight = self.get_weight(multiplier)
191
-
192
- # get org weight
193
- org_sd = self.org_module[0].state_dict()
194
- org_weight = org_sd["weight"]
195
- weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
196
-
197
- # set weight to org_module
198
- org_sd["weight"] = weight
199
- self.org_module[0].load_state_dict(org_sd)
200
-
201
- # restore org weight from lora weight
202
- def restore_from(self, multiplier=1.0):
203
- # get lora weight
204
- lora_weight = self.get_weight(multiplier)
205
-
206
- # get org weight
207
- org_sd = self.org_module[0].state_dict()
208
- org_weight = org_sd["weight"]
209
- weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
210
-
211
- # set weight to org_module
212
- org_sd["weight"] = weight
213
- self.org_module[0].load_state_dict(org_sd)
214
-
215
- # return lora weight
216
- def get_weight(self, multiplier=None):
217
- if multiplier is None:
218
- multiplier = self.multiplier
219
-
220
- # get up/down weight from module
221
- up_weight = self.lora_up.weight.to(torch.float)
222
- down_weight = self.lora_down.weight.to(torch.float)
223
-
224
- # pre-calculated weight
225
- if len(down_weight.size()) == 2:
226
- # linear
227
- weight = self.multiplier * (up_weight @ down_weight) * self.scale
228
- elif down_weight.size()[2:4] == (1, 1):
229
- # conv2d 1x1
230
- weight = (
231
- self.multiplier
232
- * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
233
- * self.scale
234
- )
235
- else:
236
- # conv2d 3x3
237
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
238
- weight = self.multiplier * conved * self.scale
239
-
240
- return weight
241
-
242
-
243
- # Create network from weights for inference, weights are not loaded here
244
- def create_network_from_weights(
245
- text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
246
- ):
247
- # get dim/alpha mapping
248
- modules_dim = {}
249
- modules_alpha = {}
250
- for key, value in weights_sd.items():
251
- if "." not in key:
252
- continue
253
-
254
- lora_name = key.split(".")[0]
255
- if "alpha" in key:
256
- modules_alpha[lora_name] = value
257
- elif "lora_down" in key:
258
- dim = value.size()[0]
259
- modules_dim[lora_name] = dim
260
- # print(lora_name, value.size(), dim)
261
-
262
- # support old LoRA without alpha
263
- for key in modules_dim.keys():
264
- if key not in modules_alpha:
265
- modules_alpha[key] = modules_dim[key]
266
-
267
- return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
268
-
269
-
270
- def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
271
- text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
272
- unet = pipe.unet
273
-
274
- lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
275
- lora_network.load_state_dict(weights_sd)
276
- lora_network.merge_to(multiplier=multiplier)
277
-
278
-
279
- # block weightや学習に対応しない簡易版 / simple version without block weight and training
280
- class LoRANetwork(torch.nn.Module):
281
- UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
282
- UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
283
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
284
- LORA_PREFIX_UNET = "lora_unet"
285
- LORA_PREFIX_TEXT_ENCODER = "lora_te"
286
-
287
- # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
288
- LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
289
- LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
290
-
291
- def __init__(
292
- self,
293
- text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
294
- unet: UNet2DConditionModel,
295
- multiplier: float = 1.0,
296
- modules_dim: Optional[Dict[str, int]] = None,
297
- modules_alpha: Optional[Dict[str, int]] = None,
298
- varbose: Optional[bool] = False,
299
- ) -> None:
300
- super().__init__()
301
- self.multiplier = multiplier
302
-
303
- print(f"create LoRA network from weights")
304
-
305
- # convert SDXL Stability AI's U-Net modules to Diffusers
306
- converted = self.convert_unet_modules(modules_dim, modules_alpha)
307
- if converted:
308
- print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
309
-
310
- # create module instances
311
- def create_modules(
312
- is_unet: bool,
313
- text_encoder_idx: Optional[int], # None, 1, 2
314
- root_module: torch.nn.Module,
315
- target_replace_modules: List[torch.nn.Module],
316
- ) -> List[LoRAModule]:
317
- prefix = (
318
- self.LORA_PREFIX_UNET
319
- if is_unet
320
- else (
321
- self.LORA_PREFIX_TEXT_ENCODER
322
- if text_encoder_idx is None
323
- else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
324
- )
325
- )
326
- loras = []
327
- skipped = []
328
- for name, module in root_module.named_modules():
329
- if module.__class__.__name__ in target_replace_modules:
330
- for child_name, child_module in module.named_modules():
331
- is_linear = (
332
- child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
333
- )
334
- is_conv2d = (
335
- child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
336
- )
337
-
338
- if is_linear or is_conv2d:
339
- lora_name = prefix + "." + name + "." + child_name
340
- lora_name = lora_name.replace(".", "_")
341
-
342
- if lora_name not in modules_dim:
343
- # print(f"skipped {lora_name} (not found in modules_dim)")
344
- skipped.append(lora_name)
345
- continue
346
-
347
- dim = modules_dim[lora_name]
348
- alpha = modules_alpha[lora_name]
349
- lora = LoRAModule(
350
- lora_name,
351
- child_module,
352
- self.multiplier,
353
- dim,
354
- alpha,
355
- )
356
- loras.append(lora)
357
- return loras, skipped
358
-
359
- text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
360
-
361
- # create LoRA for text encoder
362
- # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
363
- self.text_encoder_loras: List[LoRAModule] = []
364
- skipped_te = []
365
- for i, text_encoder in enumerate(text_encoders):
366
- if len(text_encoders) > 1:
367
- index = i + 1
368
- else:
369
- index = None
370
-
371
- text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
372
- self.text_encoder_loras.extend(text_encoder_loras)
373
- skipped_te += skipped
374
- print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
375
- if len(skipped_te) > 0:
376
- print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
377
-
378
- # extend U-Net target modules to include Conv2d 3x3
379
- target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
380
-
381
- self.unet_loras: List[LoRAModule]
382
- self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
383
- print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
384
- if len(skipped_un) > 0:
385
- print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
386
-
387
- # assertion
388
- names = set()
389
- for lora in self.text_encoder_loras + self.unet_loras:
390
- names.add(lora.lora_name)
391
- for lora_name in modules_dim.keys():
392
- assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
393
-
394
- # make to work load_state_dict
395
- for lora in self.text_encoder_loras + self.unet_loras:
396
- self.add_module(lora.lora_name, lora)
397
-
398
- # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
399
- def convert_unet_modules(self, modules_dim, modules_alpha):
400
- converted_count = 0
401
- not_converted_count = 0
402
-
403
- map_keys = list(UNET_CONVERSION_MAP.keys())
404
- map_keys.sort()
405
-
406
- for key in list(modules_dim.keys()):
407
- if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
408
- search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
409
- position = bisect.bisect_right(map_keys, search_key)
410
- map_key = map_keys[position - 1]
411
- if search_key.startswith(map_key):
412
- new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
413
- modules_dim[new_key] = modules_dim[key]
414
- modules_alpha[new_key] = modules_alpha[key]
415
- del modules_dim[key]
416
- del modules_alpha[key]
417
- converted_count += 1
418
- else:
419
- not_converted_count += 1
420
- assert (
421
- converted_count == 0 or not_converted_count == 0
422
- ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
423
- return converted_count
424
-
425
- def set_multiplier(self, multiplier):
426
- self.multiplier = multiplier
427
- for lora in self.text_encoder_loras + self.unet_loras:
428
- lora.multiplier = self.multiplier
429
-
430
- def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
431
- if apply_text_encoder:
432
- print("enable LoRA for text encoder")
433
- for lora in self.text_encoder_loras:
434
- lora.apply_to(multiplier)
435
- if apply_unet:
436
- print("enable LoRA for U-Net")
437
- for lora in self.unet_loras:
438
- lora.apply_to(multiplier)
439
-
440
- def unapply_to(self):
441
- for lora in self.text_encoder_loras + self.unet_loras:
442
- lora.unapply_to()
443
-
444
- def merge_to(self, multiplier=1.0):
445
- print("merge LoRA weights to original weights")
446
- for lora in tqdm(self.text_encoder_loras + self.unet_loras):
447
- lora.merge_to(multiplier)
448
- print(f"weights are merged")
449
-
450
- def restore_from(self, multiplier=1.0):
451
- print("restore LoRA weights from original weights")
452
- for lora in tqdm(self.text_encoder_loras + self.unet_loras):
453
- lora.restore_from(multiplier)
454
- print(f"weights are restored")
455
-
456
- def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
457
- # convert SDXL Stability AI's state dict to Diffusers' based state dict
458
- map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
459
- map_keys.sort()
460
- for key in list(state_dict.keys()):
461
- if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
462
- search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
463
- position = bisect.bisect_right(map_keys, search_key)
464
- map_key = map_keys[position - 1]
465
- if search_key.startswith(map_key):
466
- new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
467
- state_dict[new_key] = state_dict[key]
468
- del state_dict[key]
469
-
470
- # in case of V2, some weights have different shape, so we need to convert them
471
- # because V2 LoRA is based on U-Net created by use_linear_projection=False
472
- my_state_dict = self.state_dict()
473
- for key in state_dict.keys():
474
- if state_dict[key].size() != my_state_dict[key].size():
475
- # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
476
- state_dict[key] = state_dict[key].view(my_state_dict[key].size())
477
-
478
- return super().load_state_dict(state_dict, strict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,11 +1,10 @@
1
- accelerate==0.24.1
2
- diffusers==0.23.0
3
- gradio==4.2.0
4
  invisible-watermark==0.2.0
5
- Pillow==10.1.0
 
6
  torch==2.0.1
7
- transformers==4.35.0
8
- toml==0.10.2
9
  omegaconf==2.3.0
10
  timm==0.9.10
11
- git+https://huggingface.co/spaces/Wauplin/gradio-user-history
 
1
+ accelerate==0.27.2
2
+ diffusers==0.26.3
3
+ gradio==4.20.0
4
  invisible-watermark==0.2.0
5
+ Pillow==10.2.0
6
+ spaces==0.24.0
7
  torch==2.0.1
8
+ transformers==4.38.1
 
9
  omegaconf==2.3.0
10
  timm==0.9.10
 
style.css CHANGED
@@ -1,11 +1,6 @@
1
  h1 {
2
  text-align: center;
3
- font-size: 10vw; /* relative to the viewport width */
4
- }
5
-
6
- h2 {
7
- text-align: center;
8
- font-size: 10vw; /* relative to the viewport width */
9
  }
10
 
11
  #duplicate-button {
@@ -15,24 +10,12 @@ h2 {
15
  border-radius: 100vh;
16
  }
17
 
18
- #component-0 {
19
- max-width: 80%; /* relative to the parent element's width */
20
  margin: auto;
21
  padding-top: 1.5rem;
22
  }
23
 
24
- /* You can also use media queries to adjust your style for different screen sizes */
25
- @media (max-width: 600px) {
26
- #component-0 {
27
- max-width: 90%;
28
- padding-top: 1rem;
29
- }
30
- }
31
-
32
- #gallery .grid-wrap{
33
- min-height: 25%;
34
- }
35
-
36
  #title-container {
37
  display: flex;
38
  justify-content: center;
@@ -43,18 +26,9 @@ h2 {
43
  #title {
44
  font-size: 3em;
45
  text-align: center;
46
- color: #333;
47
- font-family: 'Helvetica Neue', sans-serif;
48
- text-transform: uppercase;
49
  background: transparent;
50
  }
51
 
52
- #title span {
53
- background: -webkit-linear-gradient(45deg, #4EACEF, #28b485);
54
- -webkit-background-clip: text;
55
- -webkit-text-fill-color: transparent;
56
- }
57
-
58
  #subtitle {
59
  text-align: center;
60
- }
 
1
  h1 {
2
  text-align: center;
3
+ display: block;
 
 
 
 
 
4
  }
5
 
6
  #duplicate-button {
 
10
  border-radius: 100vh;
11
  }
12
 
13
+ .gradio-container {
14
+ max-width: 730px !important;
15
  margin: auto;
16
  padding-top: 1.5rem;
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  #title-container {
20
  display: flex;
21
  justify-content: center;
 
26
  #title {
27
  font-size: 3em;
28
  text-align: center;
 
 
 
29
  background: transparent;
30
  }
31
 
 
 
 
 
 
 
32
  #subtitle {
33
  text-align: center;
34
+ }
utils.py CHANGED
@@ -1,7 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def is_google_colab():
2
  try:
3
  import google.colab
4
-
5
  return True
6
  except:
7
  return False
 
1
+ import gc
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import json
6
+ import torch
7
+ from PIL import Image, PngImagePlugin
8
+ from datetime import datetime
9
+ from dataclasses import dataclass
10
+ from typing import Callable, Dict, Optional, Tuple
11
+ from diffusers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ DPMSolverSinglestepScheduler,
15
+ EulerAncestralDiscreteScheduler,
16
+ EulerDiscreteScheduler,
17
+ )
18
+
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+
21
+
22
+ @dataclass
23
+ class StyleConfig:
24
+ prompt: str
25
+ negative_prompt: str
26
+
27
+
28
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
29
+ if randomize_seed:
30
+ seed = random.randint(0, MAX_SEED)
31
+ return seed
32
+
33
+
34
+ def seed_everything(seed: int) -> torch.Generator:
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ np.random.seed(seed)
38
+ generator = torch.Generator()
39
+ generator.manual_seed(seed)
40
+ return generator
41
+
42
+
43
+ def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
44
+ if aspect_ratio == "Custom":
45
+ return None
46
+ width, height = aspect_ratio.split(" x ")
47
+ return int(width), int(height)
48
+
49
+
50
+ def aspect_ratio_handler(
51
+ aspect_ratio: str, custom_width: int, custom_height: int
52
+ ) -> Tuple[int, int]:
53
+ if aspect_ratio == "Custom":
54
+ return custom_width, custom_height
55
+ else:
56
+ width, height = parse_aspect_ratio(aspect_ratio)
57
+ return width, height
58
+
59
+
60
+ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
61
+ scheduler_factory_map = {
62
+ "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
63
+ scheduler_config, use_karras_sigmas=True
64
+ ),
65
+ "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
66
+ scheduler_config, use_karras_sigmas=True
67
+ ),
68
+ "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
69
+ scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
70
+ ),
71
+ "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
72
+ "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
73
+ scheduler_config
74
+ ),
75
+ "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
76
+ }
77
+ return scheduler_factory_map.get(name, lambda: None)()
78
+
79
+
80
+ def free_memory() -> None:
81
+ torch.cuda.empty_cache()
82
+ gc.collect()
83
+
84
+
85
+ def preprocess_prompt(
86
+ style_dict,
87
+ style_name: str,
88
+ positive: str,
89
+ negative: str = "",
90
+ add_style: bool = True,
91
+ ) -> Tuple[str, str]:
92
+ p, n = style_dict.get(style_name, style_dict["(None)"])
93
+
94
+ if add_style and positive.strip():
95
+ formatted_positive = p.format(prompt=positive)
96
+ else:
97
+ formatted_positive = positive
98
+
99
+ combined_negative = n
100
+ if negative.strip():
101
+ if combined_negative:
102
+ combined_negative += ", " + negative
103
+ else:
104
+ combined_negative = negative
105
+
106
+ return formatted_positive, combined_negative
107
+
108
+
109
+ def common_upscale(
110
+ samples: torch.Tensor,
111
+ width: int,
112
+ height: int,
113
+ upscale_method: str,
114
+ ) -> torch.Tensor:
115
+ return torch.nn.functional.interpolate(
116
+ samples, size=(height, width), mode=upscale_method
117
+ )
118
+
119
+
120
+ def upscale(
121
+ samples: torch.Tensor, upscale_method: str, scale_by: float
122
+ ) -> torch.Tensor:
123
+ width = round(samples.shape[3] * scale_by)
124
+ height = round(samples.shape[2] * scale_by)
125
+ return common_upscale(samples, width, height, upscale_method)
126
+
127
+
128
+ def load_wildcard_files(wildcard_dir: str) -> Dict[str, str]:
129
+ wildcard_files = {}
130
+ for file in os.listdir(wildcard_dir):
131
+ if file.endswith(".txt"):
132
+ key = f"__{file.split('.')[0]}__" # Create a key like __character__
133
+ wildcard_files[key] = os.path.join(wildcard_dir, file)
134
+ return wildcard_files
135
+
136
+
137
+ def get_random_line_from_file(file_path: str) -> str:
138
+ with open(file_path, "r") as file:
139
+ lines = file.readlines()
140
+ if not lines:
141
+ return ""
142
+ return random.choice(lines).strip()
143
+
144
+
145
+ def add_wildcard(prompt: str, wildcard_files: Dict[str, str]) -> str:
146
+ for key, file_path in wildcard_files.items():
147
+ if key in prompt:
148
+ wildcard_line = get_random_line_from_file(file_path)
149
+ prompt = prompt.replace(key, wildcard_line)
150
+ return prompt
151
+
152
+
153
+ def preprocess_image_dimensions(width, height):
154
+ if width % 8 != 0:
155
+ width = width - (width % 8)
156
+ if height % 8 != 0:
157
+ height = height - (height % 8)
158
+ return width, height
159
+
160
+
161
+ def save_image(image, metadata, output_dir):
162
+ current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
163
+ os.makedirs(output_dir, exist_ok=True)
164
+ filename = f"image_{current_time}.png"
165
+ filepath = os.path.join(output_dir, filename)
166
+
167
+ metadata_str = json.dumps(metadata)
168
+ info = PngImagePlugin.PngInfo()
169
+ info.add_text("metadata", metadata_str)
170
+ image.save(filepath, "PNG", pnginfo=info)
171
+ return filepath
172
+
173
+
174
  def is_google_colab():
175
  try:
176
  import google.colab
 
177
  return True
178
  except:
179
  return False