pcuenq HF staff commited on
Commit
f3566c2
·
1 Parent(s): 05ad577

Initial demo by Shuangfei Zhai

Browse files

Reference: https://github.com/apple/ml-mdm/blob/ecbbc341bc863b014682d3501bbece5c3a8b5e8b/ml_mdm/clis/generate_sample.py

Files changed (1) hide show
  1. app.py +548 -0
app.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # For licensing see accompanying LICENSE file.
2
+ # Copyright (C) 2024 Apple Inc. All rights reserved.
3
+ import logging
4
+ import os
5
+ import shlex
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+ import gradio as gr
11
+ import simple_parsing
12
+ import yaml
13
+ from einops import rearrange, repeat
14
+
15
+ import numpy as np
16
+ import torch
17
+ from torchvision.utils import make_grid
18
+
19
+ from ml_mdm import helpers, reader
20
+ from ml_mdm.config import get_arguments, get_model, get_pipeline
21
+ from ml_mdm.language_models import factory
22
+
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Note that it is called add_arguments, not add_argument.
26
+ logging.basicConfig(
27
+ level=getattr(logging, "INFO", None),
28
+ format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
29
+ datefmt="%H:%M:%S",
30
+ )
31
+
32
+
33
+ def dividable(n):
34
+ for i in range(int(np.sqrt(n)), 0, -1):
35
+ if n % i == 0:
36
+ break
37
+ return i, n // i
38
+
39
+
40
+ def generate_lm_outputs(device, sample, tokenizer, language_model, args):
41
+ with torch.no_grad():
42
+ lm_outputs, lm_mask = language_model(sample, tokenizer)
43
+ sample["lm_outputs"] = lm_outputs
44
+ sample["lm_mask"] = lm_mask
45
+ return sample
46
+
47
+
48
+ def setup_models(args, device):
49
+ input_channels = 3
50
+
51
+ # load the language model
52
+ tokenizer, language_model = factory.create_lm(args, device=device)
53
+ language_model_dim = language_model.embed_dim
54
+ args.unet_config.conditioning_feature_dim = language_model_dim
55
+ denoising_model = get_model(args.model)(
56
+ input_channels, input_channels, args.unet_config
57
+ ).to(device)
58
+ diffusion_model = get_pipeline(args.model)(
59
+ denoising_model, args.diffusion_config
60
+ ).to(device)
61
+ # denoising_model.print_size(args.sample_image_size)
62
+ return tokenizer, language_model, diffusion_model
63
+
64
+
65
+ def plot_logsnr(logsnrs, total_steps):
66
+ import matplotlib.pyplot as plt
67
+
68
+ x = 1 - np.arange(len(logsnrs)) / (total_steps - 1)
69
+ plt.plot(x, np.asarray(logsnrs))
70
+ plt.xlabel("timesteps")
71
+ plt.ylabel("LogSNR")
72
+ plt.grid(True)
73
+ plt.xlim(0, 1)
74
+ plt.ylim(-20, 10)
75
+ plt.gca().invert_xaxis()
76
+
77
+ # Convert the plot to a numpy array
78
+ fig = plt.gcf()
79
+ fig.canvas.draw()
80
+ image = np.array(fig.canvas.renderer._renderer)
81
+ plt.close()
82
+ return image
83
+
84
+
85
+ @dataclass
86
+ class GLOBAL_DATA:
87
+ reader_config: Optional[reader.ReaderConfig] = None
88
+ tokenizer = None
89
+ args = None
90
+ language_model = None
91
+ diffusion_model = None
92
+ override_args = ""
93
+ ckpt_name = ""
94
+ config_file = ""
95
+
96
+
97
+ global_config = GLOBAL_DATA()
98
+
99
+
100
+ def stop_run():
101
+ return (
102
+ gr.update(value="Run", variant="primary", visible=True),
103
+ gr.update(visible=False),
104
+ )
105
+
106
+
107
+ def get_model_type(config_file):
108
+ with open(config_file, "r") as f:
109
+ d = yaml.safe_load(f)
110
+ return d.get("model", d.get("vision_model", "unet"))
111
+
112
+
113
+ def generate(
114
+ config_file="cc12m_64x64.yaml",
115
+ ckpt_name="vis_model_64x64.pth",
116
+ prompt="a chair",
117
+ input_template="",
118
+ negative_prompt="",
119
+ negative_template="",
120
+ batch_size=20,
121
+ guidance_scale=7.5,
122
+ threshold_function="clip",
123
+ num_inference_steps=250,
124
+ eta=0,
125
+ save_diffusion_path=False,
126
+ show_diffusion_path=False,
127
+ show_xt=False,
128
+ reader_config="",
129
+ seed=10,
130
+ comment="",
131
+ override_args="",
132
+ output_inner=False,
133
+ ):
134
+ np.random.seed(seed)
135
+ torch.random.manual_seed(seed)
136
+
137
+ if len(input_template) > 0:
138
+ prompt = input_template.format(prompt=prompt)
139
+ if len(negative_template) > 0:
140
+ negative_prompt = negative_prompt + negative_template
141
+ print(f"Postive: {prompt} / Negative: {negative_prompt}")
142
+
143
+ if not os.path.exists(ckpt_name):
144
+ logging.info(f"Did not generate because {ckpt_name} does not exist")
145
+ return None, None, f"{ckpt_name} does not exist", None, None
146
+
147
+ if (
148
+ global_config.config_file != config_file
149
+ or global_config.ckpt_name != ckpt_name
150
+ or global_config.override_args != override_args
151
+ ):
152
+ # Identify model type
153
+ model_type = get_model_type(f"configs/models/{config_file}")
154
+ # reload the arguments
155
+ args = get_arguments(
156
+ shlex.split(override_args + f" --model {model_type}"),
157
+ mode="demo",
158
+ additional_config_paths=[f"configs/models/{config_file}"],
159
+ )
160
+ helpers.print_args(args)
161
+
162
+ # setup model when the parent task changed.
163
+ tokenizer, language_model, diffusion_model = setup_models(args, device)
164
+ vision_model_file = ckpt_name
165
+ try:
166
+ other_items = diffusion_model.model.load(vision_model_file)
167
+ except Exception as e:
168
+ logging.error(f"failed to load {vision_model_file}", exc_info=e)
169
+ return None, None, "Loading Model Error", None, None
170
+
171
+ # setup global configs
172
+ global_config.batch_num = -1 # reset batch num
173
+ global_config.args = args
174
+ global_config.override_args = override_args
175
+ global_config.tokenizer = tokenizer
176
+ global_config.language_model = language_model
177
+ global_config.diffusion_model = diffusion_model
178
+ global_config.reader_config = args.reader_config
179
+ global_config.config_file = config_file
180
+ global_config.ckpt_name = ckpt_name
181
+
182
+ else:
183
+ args = global_config.args
184
+ tokenizer = global_config.tokenizer
185
+ language_model = global_config.language_model
186
+ diffusion_model = global_config.diffusion_model
187
+
188
+ tokenizer = global_config.tokenizer
189
+
190
+ sample = {}
191
+ sample["text"] = [negative_prompt, prompt] if guidance_scale != 1 else [prompt]
192
+ sample["tokens"] = np.asarray(
193
+ reader.process_text(sample["text"], tokenizer, args.reader_config)
194
+ )
195
+ sample = generate_lm_outputs(device, sample, tokenizer, language_model, args)
196
+ assert args.sample_image_size != -1
197
+
198
+ # set up thresholding
199
+ from samplers import ThresholdType
200
+
201
+ diffusion_model.sampler._config.threshold_function = {
202
+ "clip": ThresholdType.CLIP,
203
+ "dynamic (Imagen)": ThresholdType.DYNAMIC,
204
+ "dynamic (DeepFloyd)": ThresholdType.DYNAMIC_IF,
205
+ "none": ThresholdType.NONE,
206
+ }[threshold_function]
207
+
208
+ output_comments = f"{comment}\n"
209
+
210
+ bsz = batch_size
211
+ with torch.no_grad():
212
+ if bsz > 1:
213
+ sample["lm_outputs"] = repeat(
214
+ sample["lm_outputs"], "b n d -> (b r) n d", r=bsz
215
+ )
216
+ sample["lm_mask"] = repeat(sample["lm_mask"], "b n -> (b r) n", r=bsz)
217
+
218
+ num_samples = bsz
219
+ original, outputs, logsnrs = [], [], []
220
+ logging.info(f"Starting to sample from the model")
221
+ start_time = time.time()
222
+ for step, result in enumerate(
223
+ diffusion_model.sample(
224
+ num_samples,
225
+ sample,
226
+ args.sample_image_size,
227
+ device,
228
+ return_sequence=False,
229
+ num_inference_steps=num_inference_steps,
230
+ ddim_eta=eta,
231
+ guidance_scale=guidance_scale,
232
+ resample_steps=True,
233
+ disable_bar=False,
234
+ yield_output=True,
235
+ yield_full=True,
236
+ output_inner=output_inner,
237
+ )
238
+ ):
239
+ x0, x_t, extra = result
240
+ if step < num_inference_steps:
241
+ g = extra[0][0, 0, 0, 0].cpu()
242
+ logsnrs += [torch.log(g / (1 - g))]
243
+ output = x0 if not show_xt else x_t
244
+ output = torch.clamp(output * 0.5 + 0.5, min=0, max=1).cpu()
245
+ original += [
246
+ output if not output_inner else output[..., -args.sample_image_size :]
247
+ ]
248
+
249
+ output = (
250
+ make_grid(output, nrow=dividable(bsz)[0]).permute(1, 2, 0).numpy() * 255
251
+ ).astype(np.uint8)
252
+ outputs += [output]
253
+
254
+ output_video_path = None
255
+ if step == num_inference_steps and save_diffusion_path:
256
+ import imageio
257
+
258
+ writer = imageio.get_writer("temp_output.mp4", fps=32)
259
+ for output in outputs:
260
+ writer.append_data(output)
261
+ writer.close()
262
+ output_video_path = "temp_output.mp4"
263
+ if any(diffusion_model.model.vision_model.is_temporal):
264
+ data = rearrange(
265
+ original[-1],
266
+ "(a b) c (n h) (m w) -> (n m) (a h) (b w) c",
267
+ a=dividable(bsz)[0],
268
+ n=4,
269
+ m=4,
270
+ )
271
+ data = (data.numpy() * 255).astype(np.uint8)
272
+ writer = imageio.get_writer("temp_output.mp4", fps=4)
273
+ for d in data:
274
+ writer.append_data(d)
275
+ writer.close()
276
+
277
+ if show_diffusion_path or (step == num_inference_steps):
278
+ yield output, plot_logsnr(
279
+ logsnrs, num_inference_steps
280
+ ), output_comments + f"Step ({step} / {num_inference_steps}) Time ({time.time() - start_time:.4}s)", output_video_path, gr.update(
281
+ value="Run",
282
+ variant="primary",
283
+ visible=(step == num_inference_steps),
284
+ ), gr.update(
285
+ value="Stop", variant="stop", visible=(step != num_inference_steps)
286
+ )
287
+
288
+
289
+ def main(args):
290
+ # get the language model outputs
291
+ example_texts = open("data/prompts_demo.tsv").readlines()
292
+
293
+ css = """
294
+ #config-accordion, #logs-accordion {color: black !important;}
295
+ .dark #config-accordion, .dark #logs-accordion {color: white !important;}
296
+ .stop {background: darkred !important;}
297
+ """
298
+
299
+ with gr.Blocks(
300
+ title="Demo of Text-to-Image Diffusion",
301
+ theme="EveryPizza/Cartoony-Gradio-Theme",
302
+ css=css,
303
+ ) as demo:
304
+ with gr.Row(equal_height=True):
305
+ header = """
306
+ # MLR Text-to-Image Diffusion Model Web Demo
307
+
308
+ ### Usage
309
+ - Select examples below or manually input model and prompts
310
+ - Change more advanced settings such as inference steps.
311
+ """
312
+ gr.Markdown(header)
313
+
314
+ with gr.Row(equal_height=False):
315
+ pid = gr.State()
316
+ with gr.Column(scale=2):
317
+ with gr.Row(equal_height=False):
318
+ with gr.Column(scale=1):
319
+ config_file = gr.Dropdown(
320
+ [
321
+ "cc12m_64x64.yaml",
322
+ "cc12m_256x256.yaml",
323
+ "cc12m_1024x1024.yaml",
324
+ ],
325
+ value="cc12m_64x64.yaml",
326
+ label="Select the config file",
327
+ )
328
+ with gr.Column(scale=1):
329
+ ckpt_name = gr.Dropdown(
330
+ [
331
+ "vis_model_64x64.pth",
332
+ "vis_model_256x256.pth",
333
+ "vis_model_1024x1024.pth",
334
+ ],
335
+ value="vis_model_64x64.pth",
336
+ label="Load checkpoint",
337
+ )
338
+ with gr.Row(equal_height=False):
339
+ with gr.Column(scale=1):
340
+ save_diffusion_path = gr.Checkbox(
341
+ value=True, label="Show diffusion path as a video"
342
+ )
343
+ show_diffusion_path = gr.Checkbox(
344
+ value=False, label="Show diffusion progress"
345
+ )
346
+ with gr.Column(scale=1):
347
+ show_xt = gr.Checkbox(value=False, label="Show predicted x_t")
348
+ output_inner = gr.Checkbox(
349
+ value=False,
350
+ label="Output inner UNet (High-res models Only)",
351
+ )
352
+
353
+ with gr.Column(scale=2):
354
+ prompt_input = gr.Textbox(label="Input prompt")
355
+ with gr.Row(equal_height=False):
356
+ with gr.Column(scale=1):
357
+ guidance_scale = gr.Slider(
358
+ value=7.5,
359
+ minimum=0.0,
360
+ maximum=50,
361
+ step=0.1,
362
+ label="Guidance scale",
363
+ )
364
+ with gr.Column(scale=1):
365
+ batch_size = gr.Slider(
366
+ value=16, minimum=1, maximum=128, step=1, label="Batch size"
367
+ )
368
+
369
+ with gr.Row(equal_height=False):
370
+ comment = gr.Textbox(value="", label="Comments to the model (optional)")
371
+
372
+ with gr.Row(equal_height=False):
373
+ with gr.Column(scale=2):
374
+ output_image = gr.Image(value=None, label="Output image")
375
+ with gr.Column(scale=2):
376
+ output_video = gr.Video(value=None, label="Diffusion Path")
377
+
378
+ with gr.Row(equal_height=False):
379
+ with gr.Column(scale=2):
380
+ with gr.Accordion(
381
+ "Advanced settings", open=False, elem_id="config-accordion"
382
+ ):
383
+ input_template = gr.Dropdown(
384
+ [
385
+ "",
386
+ "breathtaking {prompt}. award-winning, professional, highly detailed",
387
+ "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
388
+ "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
389
+ "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
390
+ "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
391
+ "cinematic film still {prompt}. shallow depth of field, vignette, highly detailed, high budget Hollywood movie, bokeh, cinemascope, moody",
392
+ "analog film photo {prompt}. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage",
393
+ "vaporwave synthwave style {prompt}. cyberpunk, neon, vibes, stunningly beautiful, crisp, detailed, sleek, ultramodern, high contrast, cinematic composition",
394
+ "isometric style {prompt}. vibrant, beautiful, crisp, detailed, ultra detailed, intricate",
395
+ "low-poly style {prompt}. ambient occlusion, low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
396
+ "claymation style {prompt}. sculpture, clay art, centered composition, play-doh",
397
+ "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
398
+ "origami style {prompt}. paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
399
+ "pixel-art {prompt}. low-res, blocky, pixel art style, 16-bit graphics",
400
+ ],
401
+ value="",
402
+ label="Positive Template (by default, not use)",
403
+ )
404
+ with gr.Row(equal_height=False):
405
+ with gr.Column(scale=1):
406
+ negative_prompt_input = gr.Textbox(
407
+ value="", label="Negative prompt"
408
+ )
409
+ with gr.Column(scale=1):
410
+ negative_template = gr.Dropdown(
411
+ [
412
+ "",
413
+ "anime, cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry",
414
+ "photo, deformed, black and white, realism, disfigured, low contrast",
415
+ "photo, photorealistic, realism, ugly",
416
+ "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
417
+ "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
418
+ "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
419
+ "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
420
+ "illustration, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
421
+ "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy, realistic, photographic",
422
+ "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo",
423
+ "ugly, deformed, noisy, low poly, blurry, painting",
424
+ ],
425
+ value="",
426
+ label="Negative Template (by default, not use)",
427
+ )
428
+
429
+ with gr.Row(equal_height=False):
430
+ with gr.Column(scale=1):
431
+ threshold_function = gr.Dropdown(
432
+ [
433
+ "clip",
434
+ "dynamic (Imagen)",
435
+ "dynamic (DeepFloyd)",
436
+ "none",
437
+ ],
438
+ value="dynamic (DeepFloyd)",
439
+ label="Thresholding",
440
+ )
441
+ with gr.Column(scale=1):
442
+ reader_config = gr.Dropdown(
443
+ ["configs/datasets/reader_config.yaml"],
444
+ value="configs/datasets/reader_config.yaml",
445
+ label="Reader Config",
446
+ )
447
+ with gr.Row(equal_height=False):
448
+ with gr.Column(scale=1):
449
+ num_inference_steps = gr.Slider(
450
+ value=50,
451
+ minimum=1,
452
+ maximum=2000,
453
+ step=1,
454
+ label="# of steps",
455
+ )
456
+ with gr.Column(scale=1):
457
+ eta = gr.Slider(
458
+ value=0,
459
+ minimum=0,
460
+ maximum=1,
461
+ step=0.05,
462
+ label="DDIM eta",
463
+ )
464
+ seed = gr.Slider(
465
+ value=137,
466
+ minimum=0,
467
+ maximum=2147483647,
468
+ step=1,
469
+ label="Random seed",
470
+ )
471
+ override_args = gr.Textbox(
472
+ value="--reader_config.max_token_length 128 --reader_config.max_caption_length 512",
473
+ label="Override model arguments (optional)",
474
+ )
475
+
476
+ run_btn = gr.Button(value="Run", variant="primary")
477
+ stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
478
+
479
+ with gr.Column(scale=2):
480
+ with gr.Accordion(
481
+ "Addditional outputs", open=False, elem_id="output-accordion"
482
+ ):
483
+ with gr.Row(equal_height=True):
484
+ output_text = gr.Textbox(value=None, label="System output")
485
+ with gr.Row(equal_height=True):
486
+ logsnr_fig = gr.Image(value=None, label="Noise schedule")
487
+
488
+ run_event = run_btn.click(
489
+ fn=generate,
490
+ inputs=[
491
+ config_file,
492
+ ckpt_name,
493
+ prompt_input,
494
+ input_template,
495
+ negative_prompt_input,
496
+ negative_template,
497
+ batch_size,
498
+ guidance_scale,
499
+ threshold_function,
500
+ num_inference_steps,
501
+ eta,
502
+ save_diffusion_path,
503
+ show_diffusion_path,
504
+ show_xt,
505
+ reader_config,
506
+ seed,
507
+ comment,
508
+ override_args,
509
+ output_inner,
510
+ ],
511
+ outputs=[
512
+ output_image,
513
+ logsnr_fig,
514
+ output_text,
515
+ output_video,
516
+ run_btn,
517
+ stop_btn,
518
+ ],
519
+ )
520
+
521
+ stop_btn.click(
522
+ fn=stop_run,
523
+ outputs=[run_btn, stop_btn],
524
+ cancels=[run_event],
525
+ queue=False,
526
+ )
527
+ example0 = gr.Examples(
528
+ [
529
+ ["cc12m_64x64.yaml", "vis_model_64x64.pth", 64, 50, 0],
530
+ ["cc12m_256x256.yaml", "vis_model_256x256.pth", 16, 100, 0],
531
+ ["cc12m_1024x1024.yaml", "vis_model_1024x1024.pth", 4, 250, 1],
532
+ ],
533
+ inputs=[config_file, ckpt_name, batch_size, num_inference_steps, eta],
534
+ )
535
+ example1 = gr.Examples(
536
+ examples=[[t.strip()] for t in example_texts],
537
+ inputs=[prompt_input],
538
+ )
539
+
540
+ launch_args = {"server_port": int(args.port), "server_name": "0.0.0.0"}
541
+ demo.queue(default_concurrency_limit=1).launch(**launch_args)
542
+
543
+
544
+ if __name__ == "__main__":
545
+ parser = simple_parsing.ArgumentParser(description="pre-loading demo")
546
+ parser.add_argument("--port", type=int, default=19231)
547
+ args = parser.parse_known_args()[0]
548
+ main(args)