File size: 12,344 Bytes
5c4e8c1
ab11d6f
 
 
 
 
 
 
 
 
5c4e8c1
4d6f2bc
ab11d6f
4d6f2bc
d179c4c
 
 
 
80a3408
d179c4c
aafe7f2
effc0a0
 
 
 
 
 
 
 
 
 
ad19934
b7c0c19
97256ff
ad19934
b7c0c19
 
 
 
 
effc0a0
5d1fc69
 
 
 
 
 
 
 
effc0a0
 
 
 
 
 
 
 
 
 
 
60849d7
80a3408
 
 
 
 
 
 
39a6792
4d6f2bc
48c31e7
710fb68
 
4d6f2bc
 
48c31e7
4d6f2bc
 
 
 
 
48c31e7
4d6f2bc
ab11d6f
 
4d6f2bc
b7fd57e
4d6f2bc
 
48c31e7
 
4d6f2bc
 
48c31e7
b7fd57e
b70fffe
effc0a0
b70fffe
98afd85
 
 
 
 
 
 
 
 
 
ab11d6f
98afd85
 
 
 
 
 
 
b70fffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97256ff
effc0a0
 
 
 
 
 
ca2f5d2
effc0a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab11d6f
effc0a0
 
 
 
 
 
 
 
 
 
 
 
 
98afd85
ab11d6f
98afd85
 
7a7cda5
 
 
 
 
ab11d6f
7a7cda5
 
98afd85
 
 
 
ab11d6f
98afd85
 
 
 
 
 
 
 
7a7cda5
98afd85
 
 
 
 
ab11d6f
7a7cda5
 
 
 
 
 
98afd85
ab11d6f
98afd85
7a7cda5
98afd85
 
48c31e7
effc0a0
 
b7c0c19
effc0a0
ab11d6f
 
 
ad19934
effc0a0
5d1fc69
97256ff
effc0a0
5d1fc69
4d6f2bc
effc0a0
 
 
 
 
 
61ad3d2
60849d7
effc0a0
5d1fc69
 
 
effc0a0
5d1fc69
effc0a0
5c4e8c1
 
effc0a0
dffd0bb
ab11d6f
 
5d1fc69
4d6f2bc
 
ab11d6f
4d6f2bc
ab11d6f
 
 
4d6f2bc
 
 
ab11d6f
48c31e7
 
4d6f2bc
 
60849d7
9edebae
 
4d6f2bc
48c31e7
ab11d6f
4d6f2bc
 
 
5c4e8c1
 
 
 
 
 
ab11d6f
 
 
 
 
 
 
 
 
 
507ffa3
5c4e8c1
b70fffe
5c4e8c1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
import argparse
import os
from importlib.util import find_spec

# Opt in to ZeroGPU v2 (improved GPU handling and progress bars).
# NOTE: both variables are exported before the gradio / huggingface_hub imports
# further down, presumably so those libraries see them when they are imported.
os.environ.update(ZEROGPU_V2="1")

# Enable the Rust-based downloader, but only when the optional
# `hf_transfer` package can actually be found.
if find_spec("hf_transfer") is not None:
    os.environ.update(HF_HUB_ENABLE_HF_TRANSFER="1")

import gradio as gr
from huggingface_hub._snapshot_download import snapshot_download

from lib import (
    Config,
    generate,
    read_file,
    read_json,
)

# Browser-side handler that receives the current seed value.
# It copies the seed into the `--seed` CSS custom property on the element with
# id "refresh" (the refresh button), which is how the button's hover text is
# updated. The value is wrapped in double quotes because the CSS `content`
# attribute expects a string. The seed itself is returned unchanged.
seed_js = """
(seed) => {
    const button = document.getElementById("refresh");
    button.style.setProperty("--seed", `"${seed}"`);
    return seed;
}
"""

# Browser-side handler that generates a fresh random seed.
# It draws an integer in [0, Number.MAX_SAFE_INTEGER), stores it in the `--seed`
# CSS custom property of the element with id "refresh", and returns it.
# The CSS `content` attribute expects a string so we need to wrap the number in quotes
refresh_seed_js = """
() => {
    const n = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER);
    const button = document.getElementById("refresh");
    button.style.setProperty("--seed", `"${n}"`);
    return n;
}
"""

# Update width and height on aspect ratio change.
# Arguments: `ar` is the selected aspect-ratio value (a "W,H" string, or a falsy
# value for the "Custom" entry), and `w` / `h` are the current width and height.
# Returns [width, height]: the current pair unchanged when `ar` is falsy,
# otherwise the two integers parsed from `ar`.
aspect_ratio_js = """
(ar, w, h) => {
    if (!ar) return [w, h];
    const [width, height] = ar.split(",");
    return [parseInt(width), parseInt(height)];
}
"""

# Map a (width, height) pair back to an aspect-ratio value.
# Returns the matching "W,H" preset string when the pair equals one of the five
# hard-coded presets, and null otherwise (i.e. the "Custom" case).
# NOTE(review): the presets are repeated here as literals; they need to stay in
# sync with the choices offered by the aspect-ratio dropdown.
custom_aspect_ratio_js = """
(w, h) => {
    if (w === 384 && h === 672) return "384,672";
    if (w === 448 && h === 576) return "448,576";
    if (w === 512 && h === 512) return "512,512"; 
    if (w === 576 && h === 448) return "576,448";
    if (w === 672 && h === 384) return "672,384";
    return null;
}
"""

# Browser-side handler that returns a random prompt different from the current one.
# This is an f-string: doubled braces are literal JavaScript braces, and the
# single-braced expression is evaluated once, when this module is loaded, pasting
# the string form of `read_json("data/prompts.json")` into the JavaScript source.
# NOTE(review): this assumes that string form is a valid JavaScript array
# literal — confirm against `lib.read_json`.
# NOTE(review): if every entry equals the current prompt (or the list is empty),
# `filtered` is empty and the function returns undefined.
random_prompt_js = f"""
(prompt) => {{
    const prompts = {read_json("data/prompts.json")};
    const filtered = prompts.filter(p => p !== prompt);
    return filtered[Math.floor(Math.random() * filtered.length)];
}}
"""

# Build the complete UI and bind it to `demo`.
# Layout: an intro HTML partial followed by three tabs (Home, Settings, Info).
# After the layout, the event listeners are attached: several run only as
# JavaScript snippets (no Python callback), and one calls `generate`.
with gr.Blocks(
    # `head` is passed as file contents, while `css` and `js` are passed as paths
    head=read_file("./partials/head.html"),
    css="./app.css",
    js="./app.js",
    theme=gr.themes.Default(
        # colors
        neutral_hue=gr.themes.colors.gray,
        primary_hue=gr.themes.colors.orange,
        secondary_hue=gr.themes.colors.blue,
        # sizing
        text_size=gr.themes.sizes.text_md,
        radius_size=gr.themes.sizes.radius_sm,
        spacing_size=gr.themes.sizes.spacing_md,
        # fonts
        font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
        font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), "monospace"],
    ).set(
        # overrides on top of the default theme: tighter gap, no block shadow
        # (zero offset, fully transparent color), gray block backgrounds
        layout_gap="8px",
        block_shadow="0 0 #0000",
        block_shadow_dark="0 0 #0000",
        block_background_fill=gr.themes.colors.gray.c50,
        block_background_fill_dark=gr.themes.colors.gray.c900,
    ),
) as demo:
    gr.HTML(read_file("./partials/intro.html"))

    with gr.Tabs():
        # Home tab: output gallery, prompt box, and the action buttons
        with gr.TabItem("🏠 Home"):
            with gr.Column():
                # Receives the images returned by `generate`; not editable by the user
                output_images = gr.Gallery(
                    elem_classes=["gallery"],
                    show_share_button=False,
                    object_fit="cover",
                    interactive=False,
                    show_label=False,
                    label="Output",
                    format="png",
                    columns=2,
                )
                positive_prompt = gr.Textbox(
                    placeholder="What do you want to see?",
                    autoscroll=False,
                    show_label=False,
                    label="Prompt",
                    max_lines=3,
                    lines=3,
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    # Fills the prompt box with a random prompt (see random_prompt_js)
                    random_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="random",
                        min_width=0,
                        value="🎲",
                    )
                    # The id "refresh" is looked up by seed_js and refresh_seed_js
                    # via document.getElementById, so it must not be renamed alone
                    refresh_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="refresh",
                        min_width=0,
                        value="🔄",
                    )
                    # Clears only the output gallery
                    clear_btn = gr.ClearButton(
                        elem_classes=["icon-button", "popover"],
                        components=[output_images],
                        variant="secondary",
                        elem_id="clear",
                        min_width=0,
                        value="🗑️",
                    )

        # Settings tab: every generation parameter other than the prompt itself
        with gr.TabItem("⚙️ Settings", elem_id="settings"):
            # Prompt settings
            gr.HTML("<h3>Prompt</h3>")
            with gr.Row():
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="nsfw, <fast_negative>",
                    lines=1,
                )

            # Model settings
            gr.HTML("<h3>Model</h3>")
            with gr.Row():
                model = gr.Dropdown(
                    choices=Config.MODELS,
                    value=Config.MODEL,
                    filterable=False,
                    label="Checkpoint",
                    min_width=240,
                )
                # Choices are the keys of the Config.SCHEDULERS mapping
                scheduler = gr.Dropdown(
                    choices=Config.SCHEDULERS.keys(),
                    value=Config.SCHEDULER,
                    elem_id="scheduler",
                    label="Scheduler",
                    filterable=False,
                )

            # Generation settings
            gr.HTML("<h3>Generation</h3>")
            with gr.Row():
                guidance_scale = gr.Slider(
                    value=Config.GUIDANCE_SCALE,
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    step=0.1,
                )
                inference_steps = gr.Slider(
                    value=Config.INFERENCE_STEPS,
                    label="Inference Steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                )
                deepcache_interval = gr.Slider(
                    value=Config.DEEPCACHE_INTERVAL,
                    label="DeepCache",
                    minimum=1,
                    maximum=4,
                    step=1,
                )
            with gr.Row():
                # Width and height share the same range (256-768) and step (32)
                width = gr.Slider(
                    value=Config.WIDTH,
                    label="Width",
                    minimum=256,
                    maximum=768,
                    step=32,
                )
                height = gr.Slider(
                    value=Config.HEIGHT,
                    label="Height",
                    minimum=256,
                    maximum=768,
                    step=32,
                )
                # Preset values are "W,H" strings, the format aspect_ratio_js splits
                # on "," and custom_aspect_ratio_js returns. "Custom" carries None,
                # matching the null returned when no preset fits.
                aspect_ratio = gr.Dropdown(
                    value=f"{Config.WIDTH},{Config.HEIGHT}",
                    label="Aspect Ratio",
                    filterable=False,
                    choices=[
                        ("Custom", None),
                        ("4:7 (384x672)", "384,672"),
                        ("7:9 (448x576)", "448,576"),
                        ("1:1 (512x512)", "512,512"),
                        ("9:7 (576x448)", "576,448"),
                        ("7:4 (672x384)", "672,384"),
                    ],
                )
            with gr.Row():
                # 1 to 4 images per request
                num_images = gr.Dropdown(
                    choices=list(range(1, 5)),
                    value=Config.NUM_IMAGES,
                    filterable=False,
                    label="Images",
                )
                # Shown as "<s>x" while the raw scale value is what gets submitted
                scale = gr.Dropdown(
                    choices=[(f"{s}x", s) for s in Config.SCALES],
                    filterable=False,
                    value=Config.SCALE,
                    label="Scale",
                )
                # Default is -1 — presumably "pick a random seed"; confirm in lib.generate.
                # NOTE(review): the maximum (2**64 - 1) exceeds JavaScript's
                # Number.MAX_SAFE_INTEGER, the bound refresh_seed_js uses, so very
                # large seeds may not round-trip exactly through the browser.
                seed = gr.Number(
                    value=-1,
                    label="Seed",
                    minimum=-1,
                    maximum=(2**64) - 1,
                )
            with gr.Row():
                use_karras = gr.Checkbox(
                    elem_classes=["checkbox"],
                    label="Karras σ",
                    value=True,
                )

            # Image-to-Image settings
            gr.HTML("<h3>Image-to-Image</h3>")
            with gr.Row():
                # All three image inputs are delivered to Python as PIL images
                image_input = gr.Image(
                    show_share_button=False,
                    label="Initial Image",
                    min_width=640,
                    format="png",
                    type="pil",
                )
            with gr.Row():
                controlnet_input = gr.Image(
                    show_share_button=False,
                    label="Control Image",
                    min_width=320,
                    format="png",
                    type="pil",
                )
                ip_adapter_input = gr.Image(
                    show_share_button=False,
                    label="IP-Adapter Image",
                    min_width=320,
                    format="png",
                    type="pil",
                )
            with gr.Row():
                denoising_strength = gr.Slider(
                    label="Initial Image Strength",
                    value=Config.DENOISING_STRENGTH,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                )
                controlnet_annotator = gr.Dropdown(
                    label="ControlNet Annotator",
                    # TODO: annotators should be in config with names
                    choices=[("Canny", "canny")],
                    value=Config.ANNOTATOR,
                    filterable=False,
                )
            with gr.Row():
                use_ip_adapter_face = gr.Checkbox(
                    label="Use IP-Adapter Face",
                    elem_classes=["checkbox"],
                    value=False,
                )

        # Info tab: renders the contents of DOCS.md as Markdown
        with gr.TabItem("ℹ️ Info"):
            gr.Markdown(read_file("DOCS.md"))

    # Random prompt on click. No Python callback (fn is None): the JS snippet gets
    # the current prompt and its return value becomes the new prompt.
    random_btn.click(
        None, inputs=[positive_prompt], outputs=[positive_prompt], js=random_prompt_js
    )

    # Update seed on click: the value returned by the JS snippet is written to `seed`
    refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)

    # Update seed button hover text. No outputs: the snippet's only effect is the
    # CSS variable it sets on the refresh button.
    seed.change(None, inputs=[seed], outputs=[], js=seed_js)

    # Update width and height on aspect ratio change.
    # NOTE(review): this listens to `.input` rather than `.change`, presumably so
    # that only user edits trigger it and the two handlers below and above do not
    # re-trigger each other — confirm.
    aspect_ratio.input(
        None,
        inputs=[aspect_ratio, width, height],
        outputs=[width, height],
        js=aspect_ratio_js,
    )

    # Show "Custom" aspect ratio when manually changing width or height
    # (or the matching preset when the new pair equals one)
    gr.on(
        triggers=[width.input, height.input],
        fn=None,
        inputs=[width, height],
        outputs=[aspect_ratio],
        js=custom_aspect_ratio_js,
    )

    # Generate images: fired by the Generate button or by submitting the prompt box.
    # The result of `generate` fills the gallery.
    # NOTE(review): `inputs` are passed positionally, so their order presumably has
    # to match the parameter order of `lib.generate` — verify when editing the list.
    gr.on(
        triggers=[generate_btn.click, positive_prompt.submit],
        fn=generate,
        api_name="generate",
        outputs=[output_images],
        inputs=[
            positive_prompt,
            negative_prompt,
            image_input,
            controlnet_input,
            ip_adapter_input,
            seed,
            model,
            scheduler,
            controlnet_annotator,
            width,
            height,
            guidance_scale,
            inference_steps,
            denoising_strength,
            deepcache_interval,
            scale,
            num_images,
            use_karras,
            use_ip_adapter_face,
        ],
    )

if __name__ == "__main__":
    # Command line: only the bind address and port are configurable.
    # Automatic -h/--help and option-prefix abbreviations are both turned off.
    cli = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    cli.add_argument("-s", "--server", type=str, metavar="STR", default="0.0.0.0")
    cli.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
    opts = cli.parse_args()

    # Download the files of every configured model repo before the server starts.
    # HF_TOKEN is optional; when it is unset, None is passed as the token.
    hf_token = os.environ.get("HF_TOKEN")
    for repo, patterns in Config.HF_REPOS.items():
        snapshot_download(
            repo_id=repo,
            allow_patterns=patterns,
            ignore_patterns=None,
            repo_type="model",
            revision="main",
            token=hf_token,
        )

    # https://www.gradio.app/docs/gradio/interface#interface-queue
    # Enable the request queue with a default concurrency limit of 1, then serve.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(server_name=opts.server, server_port=opts.port)