Spaces: Running on Zero

adamelliotfields committed: Add AuraSR GAN

Files changed:
- README.md +8 -4
- app.py +11 -15
- cli.py +4 -0
- generate.py +47 -20
- requirements.txt +1 -0
README.md
CHANGED

@@ -43,7 +43,14 @@ preload_from_hub:
 
 # diffusion
 
-Gradio
+Gradio app for Stable Diffusion 1.5 including:
+* curated models and TI embeddings
+* multiple samplers with Karras schedule
+* Compel prompting
+* 100+ styles from sdxl_prompt_styler
+* AuraSR GAN
+* DeepCache and ToMe
+* optional TAESD
 
 ## Usage
 
@@ -65,8 +72,5 @@ python cli.py 'an astronaut riding a horse on mars'
 
 ## TODO
 
-- [ ] Hires fix
-- [ ] Support LoRA
 - [ ] Metadata embed and display
 - [ ] Image-to-image
-- [ ] Latent preview
app.py
CHANGED

@@ -100,7 +100,6 @@ with gr.Blocks(
                 style = gr.Dropdown(
                     value=cfg.STYLE,
                     label="Style",
-                    scale=1,
                     choices=["None"] + [f"{style['name']}" for style in styles],
                 )
                 scheduler = gr.Dropdown(
@@ -109,7 +108,6 @@
                     label="Scheduler",
                     filterable=False,
                     min_width=200,
-                    scale=1,
                     choices=cfg.SCHEDULERS,
                 )
 
@@ -119,7 +117,6 @@
                     label="Guidance Scale",
                     minimum=1.0,
                     maximum=15.0,
-                    scale=1,
                     step=0.1,
                 )
                 inference_steps = gr.Slider(
@@ -127,7 +124,6 @@
                     label="Inference Steps",
                     minimum=1,
                     maximum=50,
-                    scale=1,
                     step=1,
                 )
                 seed = gr.Number(
@@ -135,32 +131,28 @@
                     label="Seed",
                     minimum=-1,
                     maximum=(2**64) - 1,
-                    scale=1,
                 )
 
             with gr.Row():
                 width = gr.Slider(
                     value=cfg.WIDTH,
                     label="Width",
-                    minimum=
-                    maximum=
+                    minimum=320,
+                    maximum=768,
                     step=32,
-                    scale=1,
                 )
                 height = gr.Slider(
                     value=cfg.HEIGHT,
                     label="Height",
-                    minimum=
-                    maximum=
+                    minimum=320,
+                    maximum=768,
                     step=32,
-                    scale=1,
                 )
                 num_images = gr.Dropdown(
                     choices=list(range(1, 5)),
                     value=cfg.NUM_IMAGES,
                     filterable=False,
                     label="Images",
-                    scale=1,
                 )
 
             with gr.Row():
@@ -174,6 +166,12 @@
                     elem_classes=["checkbox"],
                     label="Autoincrement",
                     value=True,
+                    scale=1,
+                )
+                upscale_4x = gr.Checkbox(
+                    elem_classes=["checkbox"],
+                    label="Upscale 4x",
+                    value=False,
                     scale=3,
                 )
 
@@ -206,19 +204,16 @@
                     elem_classes=["checkbox"],
                     label="Tiny VAE",
                     value=False,
-                    scale=1,
                 )
                 use_clip_skip = gr.Checkbox(
                     elem_classes=["checkbox"],
                     label="Clip skip",
                     value=False,
-                    scale=1,
                 )
                 truncate_prompts = gr.Checkbox(
                     elem_classes=["checkbox"],
                     label="Truncate prompts",
                     value=False,
-                    scale=1,
                 )
 
         with gr.TabItem("ℹ️ Usage"):
@@ -304,6 +299,7 @@
             increment_seed,
             deepcache_interval,
             tome_ratio,
+            upscale_4x,
         ],
     )
 
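The only functional UI change here is the new `upscale_4x` checkbox, which is appended to the inputs list of the generate event so its boolean value is passed through to `generate()`. Below is a minimal, self-contained sketch of that wiring; `fake_generate` is a hypothetical stand-in, not the Space's actual layout or handler.

```python
# Minimal sketch of how a Gradio checkbox value reaches a handler as a bool,
# mirroring how upscale_4x is appended to the inputs list above.
import gradio as gr

def fake_generate(prompt, upscale):  # hypothetical stand-in for generate()
    return f"prompt={prompt!r}, upscale={upscale}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    upscale_4x = gr.Checkbox(label="Upscale 4x", value=False)
    output = gr.Textbox(label="Result")
    button = gr.Button("Generate")
    # Gradio passes each input component's current value positionally,
    # so the checkbox arrives in the callback as True/False.
    button.click(fake_generate, inputs=[prompt, upscale_4x], outputs=output)

if __name__ == "__main__":
    demo.launch()
```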
cli.py
CHANGED

@@ -1,3 +1,5 @@
+# CLI
+# usage: python cli.py 'colorful calico cat artstation'
 import argparse
 
 import config as cfg
@@ -31,6 +33,7 @@ def main():
     parser.add_argument("--clip-skip", action="store_true")
     parser.add_argument("--truncate", action="store_true")
     parser.add_argument("--karras", action="store_true")
+    parser.add_argument("--upscale", action="store_true")
     parser.add_argument("--no-increment", action="store_false")
     # fmt: on
 
@@ -54,6 +57,7 @@ def main():
         args.no_increment,
         args.deepcache,
         args.tome,
+        args.upscale,
     )
     save_images(images, args.filename)
 
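With the new flag, upscaling can also be requested from the command line, e.g. `python cli.py 'colorful calico cat artstation' --upscale` (following the usage comment added at the top of the file); the parsed boolean is forwarded positionally to `generate()` alongside `args.deepcache` and `args.tome`.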
generate.py
CHANGED

@@ -1,16 +1,16 @@
 import json
+import os
 import re
 import time
 from contextlib import contextmanager
 from datetime import datetime
 from itertools import product
-from os import environ
-from types import MethodType
 from typing import Callable
 
 import spaces
 import tomesd
 import torch
+from aura_sr import AuraSR
 from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
 from DeepCache import DeepCacheSDHelper
@@ -27,13 +27,13 @@ from diffusers import (
 from diffusers.models import AutoencoderKL, AutoencoderTiny
 from torch._dynamo import OptimizedModule
 
-
+__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
 __import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 __import__("transformers").logging.set_verbosity_error()
 
 ZERO_GPU = (
-    environ.get("SPACES_ZERO_GPU", "").lower() == "true"
-    or environ.get("SPACES_ZERO_GPU", "") == "1"
+    os.environ.get("SPACES_ZERO_GPU", "").lower() == "true"
+    or os.environ.get("SPACES_ZERO_GPU", "") == "1"
 )
 
 EMBEDDINGS = {
@@ -58,6 +58,7 @@ class Loader:
         cls._instance = super(Loader, cls).__new__(cls)
         cls._instance.cpu = torch.device("cpu")
         cls._instance.gpu = torch.device("cuda")
+        cls._instance.gan = None
         cls._instance.pipe = None
         return cls._instance
 
@@ -105,7 +106,7 @@ class Loader:
         )
         return self.pipe.vae
 
-    def load(self, model, scheduler, karras, taesd, deepcache_interval, dtype=None):
+    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype=None):
         model_lower = model.lower()
 
         schedulers = {
@@ -127,7 +128,7 @@ class Loader:
             "steps_offset": 1,
         }
 
-        if scheduler
+        if scheduler in ["Euler a", "PNDM"]:
             del scheduler_kwargs["use_karras_sigmas"]
 
         pipe_kwargs = {
@@ -159,7 +160,7 @@ class Loader:
 
             self._load_vae(model_lower, taesd, dtype)
             self._load_deepcache(interval=deepcache_interval)
-            return self.pipe
+            return self.pipe, self.gan
         else:
             print(f"Unloading {model_name.lower()}...")
             self.pipe = None
@@ -181,7 +182,17 @@ class Loader:
         )
         self._load_vae(model_lower, taesd, dtype)
         self._load_deepcache(interval=deepcache_interval)
-
+
+        if upscale and self.gan is None:
+            print("Loading fal/AuraSR-v2...")
+            self.gan = AuraSR.from_pretrained("fal/AuraSR-v2")
+
+        if not upscale and self.gan is not None:
+            print("Unloading fal/AuraSR-v2...")
+            self.gan = None
+            torch.cuda.empty_cache
+
+        return self.pipe, self.gan
 
 
 # applies tome to the pipeline
@@ -227,8 +238,7 @@ def apply_style(prompt, style_name, negative=False):
     return prompt
 
 
-
-@spaces.GPU(duration=44)
+@spaces.GPU(duration=40)
 def generate(
     positive_prompt,
     negative_prompt="",
@@ -248,6 +258,7 @@ def generate(
     increment_seed=True,
     deepcache_interval=1,
     tome_ratio=0,
+    upscale=False,
     log: Callable[[str], None] = None,
     Error=Exception,
 ):
@@ -258,9 +269,11 @@
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
+    GPU = torch.device("cuda")
+
     TORCH_DTYPE = (
         torch.bfloat16
-        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported(including_emulation=False)
         else torch.float16
     )
 
@@ -273,7 +286,15 @@
     with torch.inference_mode():
         start = time.perf_counter()
         loader = Loader()
-        pipe = loader.load(
+        pipe, gan = loader.load(
+            model,
+            scheduler,
+            karras,
+            taesd,
+            deepcache_interval,
+            upscale,
+            TORCH_DTYPE,
+        )
 
         # prompt embeds
         compel = Compel(
@@ -283,7 +304,7 @@
             truncate_long_prompts=truncate_prompts,
             text_encoder=pipe.text_encoder,
             tokenizer=pipe.tokenizer,
-            device=
+            device=GPU,
         )
 
         images = []
@@ -297,7 +318,7 @@
 
         for i in range(num_images):
             # seeded generator for each iteration
-            generator = torch.Generator(device=
+            generator = torch.Generator(device=GPU).manual_seed(current_seed)
 
             try:
                 all_positive_prompts = parse_prompt(positive_prompt)
@@ -312,7 +333,7 @@
                 raise Error("ParsingException: Invalid prompt")
 
                 with token_merging(pipe, tome_ratio=tome_ratio):
-
+                    image = pipe(
                         num_inference_steps=inference_steps,
                         negative_prompt_embeds=neg_embeds,
                         guidance_scale=guidance_scale,
@@ -320,8 +341,14 @@
                         generator=generator,
                         height=height,
                         width=width,
-                    )
-
+                    ).images[0]
+
+                    if upscale:
+                        print("Upscaling image...")
+                        batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
+                        image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)
+
+                    images.append((image, str(current_seed)))
 
             if increment_seed:
                 current_seed += 1
@@ -329,9 +356,9 @@
         if ZERO_GPU:
             # spaces always start fresh
             loader.pipe = None
+            loader.gan = None
 
-
-        diff = end - start
+        diff = time.perf_counter() - start
         if log:
             log(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
         return images
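For reference, the new upscale step can be exercised on its own. The following is a minimal sketch that assumes only the `aura_sr` calls appearing in the diff (`AuraSR.from_pretrained("fal/AuraSR-v2")` and `upscale_4x_overlapped` with `max_batch_size`) plus a hypothetical `input.png`; it is not the Space's full pipeline.

```python
# Standalone sketch of the AuraSR upscale path added in this commit.
# Assumptions: aura-sr==0.0.4 as pinned in requirements.txt and a local
# input.png (e.g. a 512x512 Stable Diffusion output).
from PIL import Image
from aura_sr import AuraSR

gan = AuraSR.from_pretrained("fal/AuraSR-v2")

image = Image.open("input.png")
# The _overlapped variant upscales in overlapping tiles to reduce seams;
# max_batch_size caps how many tiles run per forward pass (VRAM bound),
# which is why the diff picks a smaller value off ZeroGPU.
upscaled = gan.upscale_4x_overlapped(image, max_batch_size=4)
upscaled.save("output.png")  # 4x the input resolution
```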
requirements.txt
CHANGED

@@ -1,4 +1,5 @@
 accelerate
+aura-sr==0.0.4
 compel
 deepcache==0.1.1
 diffusers