Commit cb5daed (verified) by adamelliotfields
Parent: c62ffd9

Add AuraSR GAN
Files changed (5):

1. README.md (+8 -4)
2. app.py (+11 -15)
3. cli.py (+4 -0)
4. generate.py (+47 -20)
5. requirements.txt (+1 -0)
README.md CHANGED

@@ -43,7 +43,14 @@ preload_from_hub:
 
 # diffusion
 
-Gradio-based UI for Stable Diffusion pipelines.
+Gradio app for Stable Diffusion 1.5 including:
+* curated models and TI embeddings
+* multiple samplers with Karras schedule
+* Compel prompting
+* 100+ styles from sdxl_prompt_styler
+* AuraSR GAN
+* DeepCache and ToMe
+* optional TAESD
 
 ## Usage
 
@@ -65,8 +72,5 @@ python cli.py 'an astronaut riding a horse on mars'
 
 ## TODO
 
-- [ ] Hires fix
-- [ ] Support LoRA
 - [ ] Metadata embed and display
 - [ ] Image-to-image
-- [ ] Latent preview
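Reviewer note: the "Compel prompting" bullet refers to the compel library's weighting syntax, where `+`/`-` suffixes up- or down-weight sub-prompts and the result is passed to the pipeline as embeddings rather than a raw string. A minimal sketch of that flow; the checkpoint id and prompt below are illustrative, not from this repo:

import torch
from compel import Compel
from diffusers import StableDiffusionPipeline

# illustrative SD 1.5 checkpoint; the app loads its own curated models
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# "++" upweights and "--" downweights the preceding sub-prompt
embeds = compel("a colorful calico cat++ in a garden--")
image = pipe(prompt_embeds=embeds, num_inference_steps=30).images[0]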
app.py CHANGED

@@ -100,7 +100,6 @@ with gr.Blocks(
         style = gr.Dropdown(
             value=cfg.STYLE,
             label="Style",
-            scale=1,
             choices=["None"] + [f"{style['name']}" for style in styles],
         )
         scheduler = gr.Dropdown(
@@ -109,7 +108,6 @@
             label="Scheduler",
             filterable=False,
             min_width=200,
-            scale=1,
             choices=cfg.SCHEDULERS,
         )
 
@@ -119,7 +117,6 @@
             label="Guidance Scale",
             minimum=1.0,
             maximum=15.0,
-            scale=1,
             step=0.1,
         )
         inference_steps = gr.Slider(
@@ -127,7 +124,6 @@
             label="Inference Steps",
             minimum=1,
             maximum=50,
-            scale=1,
             step=1,
         )
         seed = gr.Number(
@@ -135,32 +131,28 @@
             label="Seed",
             minimum=-1,
             maximum=(2**64) - 1,
-            scale=1,
         )
 
         with gr.Row():
             width = gr.Slider(
                 value=cfg.WIDTH,
                 label="Width",
-                minimum=256,
-                maximum=1024,
+                minimum=320,
+                maximum=768,
                 step=32,
-                scale=1,
             )
             height = gr.Slider(
                 value=cfg.HEIGHT,
                 label="Height",
-                minimum=256,
-                maximum=1024,
+                minimum=320,
+                maximum=768,
                 step=32,
-                scale=1,
             )
             num_images = gr.Dropdown(
                 choices=list(range(1, 5)),
                 value=cfg.NUM_IMAGES,
                 filterable=False,
                 label="Images",
-                scale=1,
             )
 
         with gr.Row():
@@ -174,6 +166,12 @@
                 elem_classes=["checkbox"],
                 label="Autoincrement",
                 value=True,
+                scale=1,
+            )
+            upscale_4x = gr.Checkbox(
+                elem_classes=["checkbox"],
+                label="Upscale 4x",
+                value=False,
                 scale=3,
             )
 
@@ -206,19 +204,16 @@
                 elem_classes=["checkbox"],
                 label="Tiny VAE",
                 value=False,
-                scale=1,
             )
             use_clip_skip = gr.Checkbox(
                 elem_classes=["checkbox"],
                 label="Clip skip",
                 value=False,
-                scale=1,
             )
             truncate_prompts = gr.Checkbox(
                 elem_classes=["checkbox"],
                 label="Truncate prompts",
                 value=False,
-                scale=1,
             )
 
         with gr.TabItem("ℹ️ Usage"):
@@ -304,6 +299,7 @@
             increment_seed,
             deepcache_interval,
             tome_ratio,
+            upscale_4x,
         ],
     )
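Reviewer note: in Gradio, a new control only reaches the backend if it is appended to the event's `inputs` list in the same position the matching parameter occupies in the handler's signature, which is why `upscale_4x` is added both as a component and at the end of the inputs above. A stripped-down sketch of that wiring (the handler body here is a stand-in for the real `generate()`):

import gradio as gr

def generate(prompt, upscale_4x=False):
    # stand-in for generate.generate(); just echoes its inputs
    return f"prompt={prompt!r}, upscale_4x={upscale_4x}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    upscale_4x = gr.Checkbox(label="Upscale 4x", value=False)
    output = gr.Textbox(label="Result")
    button = gr.Button("Generate")
    # component order here must match the handler's positional parameters
    button.click(fn=generate, inputs=[prompt, upscale_4x], outputs=output)

demo.launch()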
cli.py CHANGED

@@ -1,3 +1,5 @@
+# CLI
+# usage: python cli.py 'colorful calico cat artstation'
 import argparse
 
 import config as cfg
@@ -31,6 +33,7 @@ def main():
     parser.add_argument("--clip-skip", action="store_true")
     parser.add_argument("--truncate", action="store_true")
     parser.add_argument("--karras", action="store_true")
+    parser.add_argument("--upscale", action="store_true")
     parser.add_argument("--no-increment", action="store_false")
     # fmt: on
 
@@ -54,6 +57,7 @@ def main():
         args.no_increment,
         args.deepcache,
         args.tome,
+        args.upscale,
     )
     save_images(images, args.filename)
 
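Reviewer note: `--upscale` uses `action="store_true"`, so it defaults to False and existing invocations are unaffected; it is then passed positionally into `generate()`, matching the new `upscale` parameter's position. A self-contained illustration of the two flag styles this parser uses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--upscale", action="store_true")        # False unless the flag is passed
parser.add_argument("--no-increment", action="store_false")  # True unless the flag is passed

args = parser.parse_args(["--upscale"])
print(args.upscale)       # True
print(args.no_increment)  # True (flag omitted)

A typical run with upscaling would then be: python cli.py 'colorful calico cat artstation' --upscale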
generate.py CHANGED

@@ -1,16 +1,16 @@
 import json
+import os
 import re
 import time
 from contextlib import contextmanager
 from datetime import datetime
 from itertools import product
-from os import environ
-from types import MethodType
 from typing import Callable
 
 import spaces
 import tomesd
 import torch
+from aura_sr import AuraSR
 from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
 from DeepCache import DeepCacheSDHelper
@@ -27,13 +27,13 @@ from diffusers import (
 from diffusers.models import AutoencoderKL, AutoencoderTiny
 from torch._dynamo import OptimizedModule
 
-# some models use the deprecated CLIPFeatureExtractor class (should use CLIPImageProcessor)
+__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
 __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 __import__("transformers").logging.set_verbosity_error()
 
 ZERO_GPU = (
-    environ.get("SPACES_ZERO_GPU", "").lower() == "true"
-    or environ.get("SPACES_ZERO_GPU", "") == "1"
+    os.environ.get("SPACES_ZERO_GPU", "").lower() == "true"
+    or os.environ.get("SPACES_ZERO_GPU", "") == "1"
 )
 
 EMBEDDINGS = {
@@ -58,6 +58,7 @@ class Loader:
         cls._instance = super(Loader, cls).__new__(cls)
         cls._instance.cpu = torch.device("cpu")
         cls._instance.gpu = torch.device("cuda")
+        cls._instance.gan = None
         cls._instance.pipe = None
         return cls._instance
 
@@ -105,7 +106,7 @@
         )
         return self.pipe.vae
 
-    def load(self, model, scheduler, karras, taesd, deepcache_interval, dtype=None):
+    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype=None):
         model_lower = model.lower()
 
         schedulers = {
@@ -127,7 +128,7 @@
             "steps_offset": 1,
         }
 
-        if scheduler == "PNDM" or scheduler == "Euler a":
+        if scheduler in ["Euler a", "PNDM"]:
             del scheduler_kwargs["use_karras_sigmas"]
 
         pipe_kwargs = {
@@ -159,7 +160,7 @@
 
             self._load_vae(model_lower, taesd, dtype)
             self._load_deepcache(interval=deepcache_interval)
-            return self.pipe
+            return self.pipe, self.gan
         else:
             print(f"Unloading {model_name.lower()}...")
             self.pipe = None
@@ -181,7 +182,17 @@
         )
         self._load_vae(model_lower, taesd, dtype)
         self._load_deepcache(interval=deepcache_interval)
-        return self.pipe
+
+        if upscale and self.gan is None:
+            print("Loading fal/AuraSR-v2...")
+            self.gan = AuraSR.from_pretrained("fal/AuraSR-v2")
+
+        if not upscale and self.gan is not None:
+            print("Unloading fal/AuraSR-v2...")
+            self.gan = None
+            torch.cuda.empty_cache()
+
+        return self.pipe, self.gan
 
 
 # applies tome to the pipeline
@@ -227,8 +238,7 @@ def apply_style(prompt, style_name, negative=False):
     return prompt
 
 
-# 1024x1024 for 50 steps can take ~10s each
-@spaces.GPU(duration=44)
+@spaces.GPU(duration=40)
 def generate(
     positive_prompt,
     negative_prompt="",
@@ -248,6 +258,7 @@
     increment_seed=True,
     deepcache_interval=1,
     tome_ratio=0,
+    upscale=False,
     log: Callable[[str], None] = None,
     Error=Exception,
 ):
@@ -258,9 +269,11 @@
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
+    GPU = torch.device("cuda")
+
     TORCH_DTYPE = (
         torch.bfloat16
-        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported(including_emulation=False)
         else torch.float16
     )
 
@@ -273,7 +286,15 @@
     with torch.inference_mode():
         start = time.perf_counter()
         loader = Loader()
-        pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)
+        pipe, gan = loader.load(
+            model,
+            scheduler,
+            karras,
+            taesd,
+            deepcache_interval,
+            upscale,
+            TORCH_DTYPE,
+        )
 
         # prompt embeds
         compel = Compel(
@@ -283,7 +304,7 @@
             truncate_long_prompts=truncate_prompts,
             text_encoder=pipe.text_encoder,
             tokenizer=pipe.tokenizer,
-            device=pipe.device,
+            device=GPU,
         )
 
         images = []
@@ -297,7 +318,7 @@
 
         for i in range(num_images):
             # seeded generator for each iteration
-            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
+            generator = torch.Generator(device=GPU).manual_seed(current_seed)
 
             try:
                 all_positive_prompts = parse_prompt(positive_prompt)
@@ -312,7 +333,7 @@
                 raise Error("ParsingException: Invalid prompt")
 
             with token_merging(pipe, tome_ratio=tome_ratio):
-                result = pipe(
+                image = pipe(
                     num_inference_steps=inference_steps,
                     negative_prompt_embeds=neg_embeds,
                     guidance_scale=guidance_scale,
@@ -320,8 +341,14 @@
                     generator=generator,
                     height=height,
                     width=width,
-                )
-            images.append((result.images[0], str(current_seed)))
+                ).images[0]
+
+            if upscale:
+                print("Upscaling image...")
+                batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
+                image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)
+
+            images.append((image, str(current_seed)))
 
             if increment_seed:
                 current_seed += 1
@@ -329,9 +356,9 @@
         if ZERO_GPU:
             # spaces always start fresh
             loader.pipe = None
+            loader.gan = None
 
-        end = time.perf_counter()
-        diff = end - start
+        diff = time.perf_counter() - start
         if log:
             log(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
         return images
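Reviewer note: the AuraSR calls introduced here (`AuraSR.from_pretrained` and `upscale_4x_overlapped`) also work standalone, which is handy for testing the upscaler outside the app. A minimal sketch using the same calls as the diff; the filenames are illustrative:

from aura_sr import AuraSR
from PIL import Image

# same checkpoint the Loader pulls in
gan = AuraSR.from_pretrained("fal/AuraSR-v2")

image = Image.open("generated.png")  # e.g. a 512x512 pipeline output
# overlapped tiling upscales in patches; lower max_batch_size on smaller GPUs
upscaled = gan.upscale_4x_overlapped(image, max_batch_size=4)
upscaled.save("generated_4x.png")    # 4x: 2048x2048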
requirements.txt CHANGED

@@ -1,4 +1,5 @@
 accelerate
+aura-sr==0.0.4
 compel
 deepcache==0.1.1
 diffusers