nvn04 committed on
Commit 699e2a3 · verified · 1 Parent(s): e65a343

Update app.py

Files changed (1)
  1. app.py +579 -778
app.py CHANGED
@@ -1,778 +1,579 @@
1
- import argparse
2
- import os
3
- os.environ['CUDA_HOME'] = '/usr/local/cuda'
4
- os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
5
- from datetime import datetime
6
-
7
- import gradio as gr
8
- import spaces
9
- import numpy as np
10
- import torch
11
- from diffusers.image_processor import VaeImageProcessor
12
- from huggingface_hub import snapshot_download
13
- from PIL import Image
14
- torch.jit.script = lambda f: f
15
- from model.cloth_masker import AutoMasker, vis_mask
16
- from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
17
- from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
- from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
-
20
-
21
- def parse_args():
22
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
23
- parser.add_argument(
24
- "--base_model_path",
25
- type=str,
26
- default="booksforcharlie/stable-diffusion-inpainting",
27
- help=(
28
- "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
29
- ),
30
- )
31
- parser.add_argument(
32
- "--p2p_base_model_path",
33
- type=str,
34
- default="timbrooks/instruct-pix2pix",
35
- help=(
36
- "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
37
- ),
38
- )
39
- parser.add_argument(
40
- "--resume_path",
41
- type=str,
42
- default="zhengchong/CatVTON",
43
- help=(
44
- "The Path to the checkpoint of trained tryon model."
45
- ),
46
- )
47
- parser.add_argument(
48
- "--output_dir",
49
- type=str,
50
- default="resource/demo/output",
51
- help="The output directory where the model predictions will be written.",
52
- )
53
-
54
- parser.add_argument(
55
- "--width",
56
- type=int,
57
- default=768,
58
- help=(
59
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
60
- " resolution"
61
- ),
62
- )
63
- parser.add_argument(
64
- "--height",
65
- type=int,
66
- default=1024,
67
- help=(
68
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
69
- " resolution"
70
- ),
71
- )
72
- parser.add_argument(
73
- "--repaint",
74
- action="store_true",
75
- help="Whether to repaint the result image with the original background."
76
- )
77
- parser.add_argument(
78
- "--allow_tf32",
79
- action="store_true",
80
- default=True,
81
- help=(
82
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
83
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
84
- ),
85
- )
86
- parser.add_argument(
87
- "--mixed_precision",
88
- type=str,
89
- default="bf16",
90
- choices=["no", "fp16", "bf16"],
91
- help=(
92
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
93
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
94
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
95
- ),
96
- )
97
-
98
- args = parser.parse_args()
99
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
100
- if env_local_rank != -1 and env_local_rank != args.local_rank:
101
- args.local_rank = env_local_rank
102
-
103
- return args
104
-
105
- def image_grid(imgs, rows, cols):
106
- assert len(imgs) == rows * cols
107
-
108
- w, h = imgs[0].size
109
- grid = Image.new("RGB", size=(cols * w, rows * h))
110
-
111
- for i, img in enumerate(imgs):
112
- grid.paste(img, box=(i % cols * w, i // cols * h))
113
- return grid
114
-
115
-
116
- args = parse_args()
117
-
118
- # Mask-based CatVTON
119
- catvton_repo = "zhengchong/CatVTON"
120
- repo_path = snapshot_download(repo_id=catvton_repo)
121
- # Pipeline
122
- pipeline = CatVTONPipeline(
123
- base_ckpt=args.base_model_path,
124
- attn_ckpt=repo_path,
125
- attn_ckpt_version="mix",
126
- weight_dtype=init_weight_dtype(args.mixed_precision),
127
- use_tf32=args.allow_tf32,
128
- device='cuda'
129
- )
130
- # AutoMasker
131
- mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
132
- automasker = AutoMasker(
133
- densepose_ckpt=os.path.join(repo_path, "DensePose"),
134
- schp_ckpt=os.path.join(repo_path, "SCHP"),
135
- device='cuda',
136
- )
137
-
138
-
139
- # Flux-based CatVTON
140
- access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
141
- flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
142
- pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
143
- pipeline_flux.load_lora_weights(
144
- os.path.join(repo_path, "flux-lora"),
145
- weight_name='pytorch_lora_weights.safetensors'
146
- )
147
- pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))
148
-
149
-
150
- # Mask-free CatVTON
151
- catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
152
- repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
153
- pipeline_p2p = CatVTONPix2PixPipeline(
154
- base_ckpt=args.p2p_base_model_path,
155
- attn_ckpt=repo_path_mf,
156
- attn_ckpt_version="mix-48k-1024",
157
- weight_dtype=init_weight_dtype(args.mixed_precision),
158
- use_tf32=args.allow_tf32,
159
- device='cuda'
160
- )
161
-
162
-
163
- @spaces.GPU(duration=120)
164
- def submit_function(
165
- person_image,
166
- cloth_image,
167
- cloth_type,
168
- num_inference_steps,
169
- guidance_scale,
170
- seed,
171
- show_type
172
- ):
173
- person_image, mask = person_image["background"], person_image["layers"][0]
174
- mask = Image.open(mask).convert("L")
175
- if len(np.unique(np.array(mask))) == 1:
176
- mask = None
177
- else:
178
- mask = np.array(mask)
179
- mask[mask > 0] = 255
180
- mask = Image.fromarray(mask)
181
-
182
- tmp_folder = args.output_dir
183
- date_str = datetime.now().strftime("%Y%m%d%H%M%S")
184
- result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
185
- if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
186
- os.makedirs(os.path.join(tmp_folder, date_str[:8]))
187
-
188
- generator = None
189
- if seed != -1:
190
- generator = torch.Generator(device='cuda').manual_seed(seed)
191
-
192
- person_image = Image.open(person_image).convert("RGB")
193
- cloth_image = Image.open(cloth_image).convert("RGB")
194
- person_image = resize_and_crop(person_image, (args.width, args.height))
195
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
196
-
197
- # Process mask
198
- if mask is not None:
199
- mask = resize_and_crop(mask, (args.width, args.height))
200
- else:
201
- mask = automasker(
202
- person_image,
203
- cloth_type
204
- )['mask']
205
- mask = mask_processor.blur(mask, blur_factor=9)
206
-
207
- # Inference
208
- # try:
209
- result_image = pipeline(
210
- image=person_image,
211
- condition_image=cloth_image,
212
- mask=mask,
213
- num_inference_steps=num_inference_steps,
214
- guidance_scale=guidance_scale,
215
- generator=generator
216
- )[0]
217
- # except Exception as e:
218
- # raise gr.Error(
219
- # "An error occurred. Please try again later: {}".format(e)
220
- # )
221
-
222
- # Post-process
223
- masked_person = vis_mask(person_image, mask)
224
- save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
225
- save_result_image.save(result_save_path)
226
- if show_type == "result only":
227
- return result_image
228
- else:
229
- width, height = person_image.size
230
- if show_type == "input & result":
231
- condition_width = width // 2
232
- conditions = image_grid([person_image, cloth_image], 2, 1)
233
- else:
234
- condition_width = width // 3
235
- conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
236
- conditions = conditions.resize((condition_width, height), Image.NEAREST)
237
- new_result_image = Image.new("RGB", (width + condition_width + 5, height))
238
- new_result_image.paste(conditions, (0, 0))
239
- new_result_image.paste(result_image, (condition_width + 5, 0))
240
- return new_result_image
241
-
242
- @spaces.GPU(duration=120)
243
- def submit_function_p2p(
244
- person_image,
245
- cloth_image,
246
- num_inference_steps,
247
- guidance_scale,
248
- seed):
249
- person_image= person_image["background"]
250
-
251
- tmp_folder = args.output_dir
252
- date_str = datetime.now().strftime("%Y%m%d%H%M%S")
253
- result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
254
- if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
255
- os.makedirs(os.path.join(tmp_folder, date_str[:8]))
256
-
257
- generator = None
258
- if seed != -1:
259
- generator = torch.Generator(device='cuda').manual_seed(seed)
260
-
261
- person_image = Image.open(person_image).convert("RGB")
262
- cloth_image = Image.open(cloth_image).convert("RGB")
263
- person_image = resize_and_crop(person_image, (args.width, args.height))
264
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
265
-
266
- # Inference
267
- try:
268
- result_image = pipeline_p2p(
269
- image=person_image,
270
- condition_image=cloth_image,
271
- num_inference_steps=num_inference_steps,
272
- guidance_scale=guidance_scale,
273
- generator=generator
274
- )[0]
275
- except Exception as e:
276
- raise gr.Error(
277
- "An error occurred. Please try again later: {}".format(e)
278
- )
279
-
280
- # Post-process
281
- save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
282
- save_result_image.save(result_save_path)
283
- return result_image
284
-
285
- @spaces.GPU(duration=120)
286
- def submit_function_flux(
287
- person_image,
288
- cloth_image,
289
- cloth_type,
290
- num_inference_steps,
291
- guidance_scale,
292
- seed,
293
- show_type
294
- ):
295
-
296
- # Process image editor input
297
- person_image, mask = person_image["background"], person_image["layers"][0]
298
- mask = Image.open(mask).convert("L")
299
- if len(np.unique(np.array(mask))) == 1:
300
- mask = None
301
- else:
302
- mask = np.array(mask)
303
- mask[mask > 0] = 255
304
- mask = Image.fromarray(mask)
305
-
306
- # Set random seed
307
- generator = None
308
- if seed != -1:
309
- generator = torch.Generator(device='cuda').manual_seed(seed)
310
-
311
- # Process input images
312
- person_image = Image.open(person_image).convert("RGB")
313
- cloth_image = Image.open(cloth_image).convert("RGB")
314
-
315
- # Adjust image sizes
316
- person_image = resize_and_crop(person_image, (args.width, args.height))
317
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
318
-
319
- # Process mask
320
- if mask is not None:
321
- mask = resize_and_crop(mask, (args.width, args.height))
322
- else:
323
- mask = automasker(
324
- person_image,
325
- cloth_type
326
- )['mask']
327
- mask = mask_processor.blur(mask, blur_factor=9)
328
-
329
- # Inference
330
- result_image = pipeline_flux(
331
- image=person_image,
332
- condition_image=cloth_image,
333
- mask_image=mask,
334
- width=args.width,
335
- height=args.height,
336
- num_inference_steps=num_inference_steps,
337
- guidance_scale=guidance_scale,
338
- generator=generator
339
- ).images[0]
340
-
341
- # Post-processing
342
- masked_person = vis_mask(person_image, mask)
343
-
344
- # Return result based on show type
345
- if show_type == "result only":
346
- return result_image
347
- else:
348
- width, height = person_image.size
349
- if show_type == "input & result":
350
- condition_width = width // 2
351
- conditions = image_grid([person_image, cloth_image], 2, 1)
352
- else:
353
- condition_width = width // 3
354
- conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
355
-
356
- conditions = conditions.resize((condition_width, height), Image.NEAREST)
357
- new_result_image = Image.new("RGB", (width + condition_width + 5, height))
358
- new_result_image.paste(conditions, (0, 0))
359
- new_result_image.paste(result_image, (condition_width + 5, 0))
360
- return new_result_image
361
-
362
-
363
- def person_example_fn(image_path):
364
- return image_path
365
-
366
-
367
- HEADER = """
368
- <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
369
- <div style="display: flex; justify-content: center; align-items: center;">
370
- <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
371
- <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
372
- </a>
373
- <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
374
- <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
375
- </a>
376
- <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
377
- <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
378
- </a>
379
- <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
380
- <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
381
- </a>
382
- <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
383
- <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
384
- </a>
385
- <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
386
- <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
387
- </a>
388
- <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
389
- <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
390
- </a>
391
- </div>
392
- <br>
393
- · This demo and our weights are only for Non-commercial Use. <br>
394
- · Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
395
- · SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
396
- """
397
-
398
- def app_gradio():
399
- with gr.Blocks(title="CatVTON") as demo:
400
- gr.Markdown(HEADER)
401
- with gr.Tab("Mask-based & SD1.5"):
402
- with gr.Row():
403
- with gr.Column(scale=1, min_width=350):
404
- with gr.Row():
405
- image_path = gr.Image(
406
- type="filepath",
407
- interactive=True,
408
- visible=False,
409
- )
410
- person_image = gr.ImageEditor(
411
- interactive=True, label="Person Image", type="filepath"
412
- )
413
-
414
- with gr.Row():
415
- with gr.Column(scale=1, min_width=230):
416
- cloth_image = gr.Image(
417
- interactive=True, label="Condition Image", type="filepath"
418
- )
419
- with gr.Column(scale=1, min_width=120):
420
- gr.Markdown(
421
- '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
422
- )
423
- cloth_type = gr.Radio(
424
- label="Try-On Cloth Type",
425
- choices=["upper", "lower", "overall"],
426
- value="upper",
427
- )
428
-
429
-
430
- submit = gr.Button("Submit")
431
- gr.Markdown(
432
- '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
433
- )
434
-
435
- gr.Markdown(
436
- '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
437
- )
438
- with gr.Accordion("Advanced Options", open=False):
439
- num_inference_steps = gr.Slider(
440
- label="Inference Step", minimum=10, maximum=100, step=5, value=50
441
- )
442
- # Guidance Scale
443
- guidance_scale = gr.Slider(
444
- label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
445
- )
446
- # Random Seed
447
- seed = gr.Slider(
448
- label="Seed", minimum=-1, maximum=10000, step=1, value=42
449
- )
450
- show_type = gr.Radio(
451
- label="Show Type",
452
- choices=["result only", "input & result", "input & mask & result"],
453
- value="input & mask & result",
454
- )
455
-
456
- with gr.Column(scale=2, min_width=500):
457
- result_image = gr.Image(interactive=False, label="Result")
458
- with gr.Row():
459
- # Photo Examples
460
- root_path = "resource/demo/example"
461
- with gr.Column():
462
- men_exm = gr.Examples(
463
- examples=[
464
- os.path.join(root_path, "person", "men", _)
465
- for _ in os.listdir(os.path.join(root_path, "person", "men"))
466
- ],
467
- examples_per_page=4,
468
- inputs=image_path,
469
- label="Person Examples ①",
470
- )
471
- women_exm = gr.Examples(
472
- examples=[
473
- os.path.join(root_path, "person", "women", _)
474
- for _ in os.listdir(os.path.join(root_path, "person", "women"))
475
- ],
476
- examples_per_page=4,
477
- inputs=image_path,
478
- label="Person Examples ②",
479
- )
480
- gr.Markdown(
481
- '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
482
- )
483
- with gr.Column():
484
- condition_upper_exm = gr.Examples(
485
- examples=[
486
- os.path.join(root_path, "condition", "upper", _)
487
- for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
488
- ],
489
- examples_per_page=4,
490
- inputs=cloth_image,
491
- label="Condition Upper Examples",
492
- )
493
- condition_overall_exm = gr.Examples(
494
- examples=[
495
- os.path.join(root_path, "condition", "overall", _)
496
- for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
497
- ],
498
- examples_per_page=4,
499
- inputs=cloth_image,
500
- label="Condition Overall Examples",
501
- )
502
- condition_person_exm = gr.Examples(
503
- examples=[
504
- os.path.join(root_path, "condition", "person", _)
505
- for _ in os.listdir(os.path.join(root_path, "condition", "person"))
506
- ],
507
- examples_per_page=4,
508
- inputs=cloth_image,
509
- label="Condition Reference Person Examples",
510
- )
511
- gr.Markdown(
512
- '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
513
- )
514
-
515
- image_path.change(
516
- person_example_fn, inputs=image_path, outputs=person_image
517
- )
518
-
519
- submit.click(
520
- submit_function,
521
- [
522
- person_image,
523
- cloth_image,
524
- cloth_type,
525
- num_inference_steps,
526
- guidance_scale,
527
- seed,
528
- show_type,
529
- ],
530
- result_image,
531
- )
532
-
533
- with gr.Tab("Mask-based & Flux.1 Fill Dev"):
534
- with gr.Row():
535
- with gr.Column(scale=1, min_width=350):
536
- with gr.Row():
537
- image_path_flux = gr.Image(
538
- type="filepath",
539
- interactive=True,
540
- visible=False,
541
- )
542
- person_image_flux = gr.ImageEditor(
543
- interactive=True, label="Person Image", type="filepath"
544
- )
545
-
546
- with gr.Row():
547
- with gr.Column(scale=1, min_width=230):
548
- cloth_image_flux = gr.Image(
549
- interactive=True, label="Condition Image", type="filepath"
550
- )
551
- with gr.Column(scale=1, min_width=120):
552
- gr.Markdown(
553
- '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
554
- )
555
- cloth_type = gr.Radio(
556
- label="Try-On Cloth Type",
557
- choices=["upper", "lower", "overall"],
558
- value="upper",
559
- )
560
-
561
- submit_flux = gr.Button("Submit")
562
- gr.Markdown(
563
- '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
564
- )
565
-
566
- with gr.Accordion("Advanced Options", open=False):
567
- num_inference_steps_flux = gr.Slider(
568
- label="Inference Step", minimum=10, maximum=100, step=5, value=50
569
- )
570
- # Guidance Scale
571
- guidance_scale_flux = gr.Slider(
572
- label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
573
- )
574
- # Random Seed
575
- seed_flux = gr.Slider(
576
- label="Seed", minimum=-1, maximum=10000, step=1, value=42
577
- )
578
- show_type = gr.Radio(
579
- label="Show Type",
580
- choices=["result only", "input & result", "input & mask & result"],
581
- value="input & mask & result",
582
- )
583
-
584
- with gr.Column(scale=2, min_width=500):
585
- result_image_flux = gr.Image(interactive=False, label="Result")
586
- with gr.Row():
587
- # Photo Examples
588
- root_path = "resource/demo/example"
589
- with gr.Column():
590
- gr.Examples(
591
- examples=[
592
- os.path.join(root_path, "person", "men", _)
593
- for _ in os.listdir(os.path.join(root_path, "person", "men"))
594
- ],
595
- examples_per_page=4,
596
- inputs=image_path_flux,
597
- label="Person Examples ①",
598
- )
599
- gr.Examples(
600
- examples=[
601
- os.path.join(root_path, "person", "women", _)
602
- for _ in os.listdir(os.path.join(root_path, "person", "women"))
603
- ],
604
- examples_per_page=4,
605
- inputs=image_path_flux,
606
- label="Person Examples ②",
607
- )
608
- gr.Markdown(
609
- '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
610
- )
611
- with gr.Column():
612
- gr.Examples(
613
- examples=[
614
- os.path.join(root_path, "condition", "upper", _)
615
- for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
616
- ],
617
- examples_per_page=4,
618
- inputs=cloth_image_flux,
619
- label="Condition Upper Examples",
620
- )
621
- gr.Examples(
622
- examples=[
623
- os.path.join(root_path, "condition", "overall", _)
624
- for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
625
- ],
626
- examples_per_page=4,
627
- inputs=cloth_image_flux,
628
- label="Condition Overall Examples",
629
- )
630
- condition_person_exm = gr.Examples(
631
- examples=[
632
- os.path.join(root_path, "condition", "person", _)
633
- for _ in os.listdir(os.path.join(root_path, "condition", "person"))
634
- ],
635
- examples_per_page=4,
636
- inputs=cloth_image_flux,
637
- label="Condition Reference Person Examples",
638
- )
639
- gr.Markdown(
640
- '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
641
- )
642
-
643
-
644
- image_path_flux.change(
645
- person_example_fn, inputs=image_path_flux, outputs=person_image_flux
646
- )
647
-
648
- submit_flux.click(
649
- submit_function_flux,
650
- [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
651
- result_image_flux,
652
- )
653
-
654
-
655
- with gr.Tab("Mask-free & SD1.5"):
656
- with gr.Row():
657
- with gr.Column(scale=1, min_width=350):
658
- with gr.Row():
659
- image_path_p2p = gr.Image(
660
- type="filepath",
661
- interactive=True,
662
- visible=False,
663
- )
664
- person_image_p2p = gr.ImageEditor(
665
- interactive=True, label="Person Image", type="filepath"
666
- )
667
-
668
- with gr.Row():
669
- with gr.Column(scale=1, min_width=230):
670
- cloth_image_p2p = gr.Image(
671
- interactive=True, label="Condition Image", type="filepath"
672
- )
673
-
674
- submit_p2p = gr.Button("Submit")
675
- gr.Markdown(
676
- '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
677
- )
678
-
679
- gr.Markdown(
680
- '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
681
- )
682
- with gr.Accordion("Advanced Options", open=False):
683
- num_inference_steps_p2p = gr.Slider(
684
- label="Inference Step", minimum=10, maximum=100, step=5, value=50
685
- )
686
- # Guidance Scale
687
- guidance_scale_p2p = gr.Slider(
688
- label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
689
- )
690
- # Random Seed
691
- seed_p2p = gr.Slider(
692
- label="Seed", minimum=-1, maximum=10000, step=1, value=42
693
- )
694
- # show_type = gr.Radio(
695
- # label="Show Type",
696
- # choices=["result only", "input & result", "input & mask & result"],
697
- # value="input & mask & result",
698
- # )
699
-
700
- with gr.Column(scale=2, min_width=500):
701
- result_image_p2p = gr.Image(interactive=False, label="Result")
702
- with gr.Row():
703
- # Photo Examples
704
- root_path = "resource/demo/example"
705
- with gr.Column():
706
- gr.Examples(
707
- examples=[
708
- os.path.join(root_path, "person", "men", _)
709
- for _ in os.listdir(os.path.join(root_path, "person", "men"))
710
- ],
711
- examples_per_page=4,
712
- inputs=image_path_p2p,
713
- label="Person Examples ①",
714
- )
715
- gr.Examples(
716
- examples=[
717
- os.path.join(root_path, "person", "women", _)
718
- for _ in os.listdir(os.path.join(root_path, "person", "women"))
719
- ],
720
- examples_per_page=4,
721
- inputs=image_path_p2p,
722
- label="Person Examples ②",
723
- )
724
- gr.Markdown(
725
- '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
726
- )
727
- with gr.Column():
728
- gr.Examples(
729
- examples=[
730
- os.path.join(root_path, "condition", "upper", _)
731
- for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
732
- ],
733
- examples_per_page=4,
734
- inputs=cloth_image_p2p,
735
- label="Condition Upper Examples",
736
- )
737
- gr.Examples(
738
- examples=[
739
- os.path.join(root_path, "condition", "overall", _)
740
- for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
741
- ],
742
- examples_per_page=4,
743
- inputs=cloth_image_p2p,
744
- label="Condition Overall Examples",
745
- )
746
- condition_person_exm = gr.Examples(
747
- examples=[
748
- os.path.join(root_path, "condition", "person", _)
749
- for _ in os.listdir(os.path.join(root_path, "condition", "person"))
750
- ],
751
- examples_per_page=4,
752
- inputs=cloth_image_p2p,
753
- label="Condition Reference Person Examples",
754
- )
755
- gr.Markdown(
756
- '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
757
- )
758
-
759
- image_path_p2p.change(
760
- person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
761
- )
762
-
763
- submit_p2p.click(
764
- submit_function_p2p,
765
- [
766
- person_image_p2p,
767
- cloth_image_p2p,
768
- num_inference_steps_p2p,
769
- guidance_scale_p2p,
770
- seed_p2p],
771
- result_image_p2p,
772
- )
773
-
774
- demo.queue().launch(share=True, show_error=True)
775
-
776
-
777
- if __name__ == "__main__":
778
- app_gradio()
 
1
+ import argparse
2
+ import os
3
+ os.environ['CUDA_HOME'] = '/usr/local/cuda'
4
+ os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
5
+ from datetime import datetime
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from huggingface_hub import snapshot_download
13
+ from PIL import Image
14
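+ # Replace torch.jit.script with an identity function: TorchScript compilation becomes a no-op and decorated functions in the imported model code run as plain Python.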
+ torch.jit.script = lambda f: f
15
+ from model.cloth_masker import AutoMasker, vis_mask
16
+ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
17
+ from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
+
20
+
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
23
+ parser.add_argument(
24
+ "--base_model_path",
25
+ type=str,
26
+ default="booksforcharlie/stable-diffusion-inpainting",
27
+ help=(
28
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
29
+ ),
30
+ )
31
+ parser.add_argument(
32
+ "--p2p_base_model_path",
33
+ type=str,
34
+ default="timbrooks/instruct-pix2pix",
35
+ help=(
36
+ "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
37
+ ),
38
+ )
39
+ parser.add_argument(
40
+ "--resume_path",
41
+ type=str,
42
+ default="zhengchong/CatVTON",
43
+ help=(
44
+ "The Path to the checkpoint of trained tryon model."
45
+ ),
46
+ )
47
+ parser.add_argument(
48
+ "--output_dir",
49
+ type=str,
50
+ default="resource/demo/output",
51
+ help="The output directory where the model predictions will be written.",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--width",
56
+ type=int,
57
+ default=768,
58
+ help=(
59
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
60
+ " resolution"
61
+ ),
62
+ )
63
+ parser.add_argument(
64
+ "--height",
65
+ type=int,
66
+ default=1024,
67
+ help=(
68
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
69
+ " resolution"
70
+ ),
71
+ )
72
+ parser.add_argument(
73
+ "--repaint",
74
+ action="store_true",
75
+ help="Whether to repaint the result image with the original background."
76
+ )
77
+ parser.add_argument(
78
+ "--allow_tf32",
79
+ action="store_true",
80
+ default=True,
81
+ help=(
82
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
83
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
84
+ ),
85
+ )
86
+ parser.add_argument(
87
+ "--mixed_precision",
88
+ type=str,
89
+ default="bf16",
90
+ choices=["no", "fp16", "bf16"],
91
+ help=(
92
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
93
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
94
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
95
+ ),
96
+ )
97
+
98
+ args = parser.parse_args()
99
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
100
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
101
+ args.local_rank = env_local_rank
102
+
103
+ return args
104
+
105
+ def image_grid(imgs, rows, cols):
106
+ assert len(imgs) == rows * cols
107
+
108
+ w, h = imgs[0].size
109
+ grid = Image.new("RGB", size=(cols * w, rows * h))
110
+
111
+ for i, img in enumerate(imgs):
112
+ grid.paste(img, box=(i % cols * w, i // cols * h))
113
+ return grid
114
+
115
+
116
+ args = parse_args()
117
+
118
+ # Mask-based CatVTON
119
+ catvton_repo = "zhengchong/CatVTON"
120
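+ # Download a local snapshot of the CatVTON checkpoint repo from the Hugging Face Hub; repo_path points at the cached files.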
+ repo_path = snapshot_download(repo_id=catvton_repo)
121
+ # Pipeline
122
+ pipeline = CatVTONPipeline(
123
+ base_ckpt=args.base_model_path,
124
+ attn_ckpt=repo_path,
125
+ attn_ckpt_version="mix",
126
+ weight_dtype=init_weight_dtype(args.mixed_precision),
127
+ use_tf32=args.allow_tf32,
128
+ device='cuda'
129
+ )
130
+ # AutoMasker
131
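+ # Mask post-processor: converts masks to grayscale and binarizes them; also used below to blur (feather) mask edges.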
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
132
+ automasker = AutoMasker(
133
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
134
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
135
+ device='cuda',
136
+ )
137
+
138
+
139
+ # Flux-based CatVTON
140
+ access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
141
+ flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
142
+ pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
143
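+ # Load the CatVTON try-on LoRA adapter from the downloaded checkpoint repo on top of the FLUX.1 Fill base weights.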
+ pipeline_flux.load_lora_weights(
144
+ os.path.join(repo_path, "flux-lora"),
145
+ weight_name='pytorch_lora_weights.safetensors'
146
+ )
147
+ pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))
148
+
149
+
150
+
151
+
152
+ @spaces.GPU(duration=120)
153
+ def submit_function(
154
+ person_image,
155
+ cloth_image,
156
+ cloth_type,
157
+ num_inference_steps,
158
+ guidance_scale,
159
+ seed,
160
+ show_type
161
+ ):
162
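+ # gr.ImageEditor with type="filepath" returns a dict: "background" is the uploaded person photo, "layers" contains the user-drawn mask strokes.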
+ person_image, mask = person_image["background"], person_image["layers"][0]
163
+ mask = Image.open(mask).convert("L")
164
+ if len(np.unique(np.array(mask))) == 1:
165
+ mask = None
166
+ else:
167
+ mask = np.array(mask)
168
+ mask[mask > 0] = 255
169
+ mask = Image.fromarray(mask)
170
+
171
+ tmp_folder = args.output_dir
172
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
173
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
174
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
175
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
176
+
177
+ generator = None
178
+ if seed != -1:
179
+ generator = torch.Generator(device='cuda').manual_seed(seed)
180
+
181
+ person_image = Image.open(person_image).convert("RGB")
182
+ cloth_image = Image.open(cloth_image).convert("RGB")
183
+ person_image = resize_and_crop(person_image, (args.width, args.height))
184
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
185
+
186
+ # Process mask
187
+ if mask is not None:
188
+ mask = resize_and_crop(mask, (args.width, args.height))
189
+ else:
190
+ mask = automasker(
191
+ person_image,
192
+ cloth_type
193
+ )['mask']
194
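+ # Feather the mask edges with a blur so the inpainted try-on region blends smoothly into the original photo.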
+ mask = mask_processor.blur(mask, blur_factor=9)
195
+
196
+ # Inference
197
+ # try:
198
+ result_image = pipeline(
199
+ image=person_image,
200
+ condition_image=cloth_image,
201
+ mask=mask,
202
+ num_inference_steps=num_inference_steps,
203
+ guidance_scale=guidance_scale,
204
+ generator=generator
205
+ )[0]
206
+ # except Exception as e:
207
+ # raise gr.Error(
208
+ # "An error occurred. Please try again later: {}".format(e)
209
+ # )
210
+
211
+ # Post-process
212
+ masked_person = vis_mask(person_image, mask)
213
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
214
+ save_result_image.save(result_save_path)
215
+ if show_type == "result only":
216
+ return result_image
217
+ else:
218
+ width, height = person_image.size
219
+ if show_type == "input & result":
220
+ condition_width = width // 2
221
+ conditions = image_grid([person_image, cloth_image], 2, 1)
222
+ else:
223
+ condition_width = width // 3
224
+ conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
225
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
226
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
227
+ new_result_image.paste(conditions, (0, 0))
228
+ new_result_image.paste(result_image, (condition_width + 5, 0))
229
+ return new_result_image
230
+
231
+
232
+ @spaces.GPU(duration=120)
233
+ def submit_function_flux(
234
+ person_image,
235
+ cloth_image,
236
+ cloth_type,
237
+ num_inference_steps,
238
+ guidance_scale,
239
+ seed,
240
+ show_type
241
+ ):
242
+
243
+ # Process image editor input
244
+ person_image, mask = person_image["background"], person_image["layers"][0]
245
+ mask = Image.open(mask).convert("L")
246
+ if len(np.unique(np.array(mask))) == 1:
247
+ mask = None
248
+ else:
249
+ mask = np.array(mask)
250
+ mask[mask > 0] = 255
251
+ mask = Image.fromarray(mask)
252
+
253
+ # Set random seed
254
+ generator = None
255
+ if seed != -1:
256
+ generator = torch.Generator(device='cuda').manual_seed(seed)
257
+
258
+ # Process input images
259
+ person_image = Image.open(person_image).convert("RGB")
260
+ cloth_image = Image.open(cloth_image).convert("RGB")
261
+
262
+ # Adjust image sizes
263
+ person_image = resize_and_crop(person_image, (args.width, args.height))
264
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
265
+
266
+ # Process mask
267
+ if mask is not None:
268
+ mask = resize_and_crop(mask, (args.width, args.height))
269
+ else:
270
+ mask = automasker(
271
+ person_image,
272
+ cloth_type
273
+ )['mask']
274
+ mask = mask_processor.blur(mask, blur_factor=9)
275
+
276
+ # Inference
277
+ result_image = pipeline_flux(
278
+ image=person_image,
279
+ condition_image=cloth_image,
280
+ mask_image=mask,
281
+ width=args.width,
282
+ height=args.height,
283
+ num_inference_steps=num_inference_steps,
284
+ guidance_scale=guidance_scale,
285
+ generator=generator
286
+ ).images[0]
287
+
288
+ # Post-processing
289
+ masked_person = vis_mask(person_image, mask)
290
+
291
+ # Return result based on show type
292
+ if show_type == "result only":
293
+ return result_image
294
+ else:
295
+ width, height = person_image.size
296
+ if show_type == "input & result":
297
+ condition_width = width // 2
298
+ conditions = image_grid([person_image, cloth_image], 2, 1)
299
+ else:
300
+ condition_width = width // 3
301
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
302
+
303
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
304
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
305
+ new_result_image.paste(conditions, (0, 0))
306
+ new_result_image.paste(result_image, (condition_width + 5, 0))
307
+ return new_result_image
308
+
309
+
310
+ def person_example_fn(image_path):
311
+ return image_path
312
+
313
+
314
+ HEADER = ""
315
+
316
+ def app_gradio():
317
+ with gr.Blocks(title="CatVTON") as demo:
318
+ gr.Markdown(HEADER)
319
+ with gr.Tab("Mask-based & SD1.5"):
320
+ with gr.Row():
321
+ with gr.Column(scale=1, min_width=350):
322
+ with gr.Row():
323
+ image_path = gr.Image(
324
+ type="filepath",
325
+ interactive=True,
326
+ visible=False,
327
+ )
328
+ person_image = gr.ImageEditor(
329
+ interactive=True, label="Person Image", type="filepath"
330
+ )
331
+
332
+ with gr.Row():
333
+ with gr.Column(scale=1, min_width=230):
334
+ cloth_image = gr.Image(
335
+ interactive=True, label="Condition Image", type="filepath"
336
+ )
337
+ with gr.Column(scale=1, min_width=120):
338
+ gr.Markdown(
339
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
340
+ )
341
+ cloth_type = gr.Radio(
342
+ label="Try-On Cloth Type",
343
+ choices=["upper", "lower", "overall"],
344
+ value="upper",
345
+ )
346
+
347
+
348
+ submit = gr.Button("Submit")
349
+ gr.Markdown(
350
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
351
+ )
352
+
353
+ gr.Markdown(
354
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
355
+ )
356
+ with gr.Accordion("Advanced Options", open=False):
357
+ num_inference_steps = gr.Slider(
358
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
359
+ )
360
+ # Guidance Scale
361
+ guidance_scale = gr.Slider(
362
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
363
+ )
364
+ # Random Seed
365
+ seed = gr.Slider(
366
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
367
+ )
368
+ show_type = gr.Radio(
369
+ label="Show Type",
370
+ choices=["result only", "input & result", "input & mask & result"],
371
+ value="input & mask & result",
372
+ )
373
+
374
+ with gr.Column(scale=2, min_width=500):
375
+ result_image = gr.Image(interactive=False, label="Result")
376
+ with gr.Row():
377
+ # Photo Examples
378
+ root_path = "resource/demo/example"
379
+ with gr.Column():
380
+ men_exm = gr.Examples(
381
+ examples=[
382
+ os.path.join(root_path, "person", "men", _)
383
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
384
+ ],
385
+ examples_per_page=4,
386
+ inputs=image_path,
387
+ label="Person Examples ①",
388
+ )
389
+ women_exm = gr.Examples(
390
+ examples=[
391
+ os.path.join(root_path, "person", "women", _)
392
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
393
+ ],
394
+ examples_per_page=4,
395
+ inputs=image_path,
396
+ label="Person Examples ②",
397
+ )
398
+ gr.Markdown(
399
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
400
+ )
401
+ with gr.Column():
402
+ condition_upper_exm = gr.Examples(
403
+ examples=[
404
+ os.path.join(root_path, "condition", "upper", _)
405
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
406
+ ],
407
+ examples_per_page=4,
408
+ inputs=cloth_image,
409
+ label="Condition Upper Examples",
410
+ )
411
+ condition_overall_exm = gr.Examples(
412
+ examples=[
413
+ os.path.join(root_path, "condition", "overall", _)
414
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
415
+ ],
416
+ examples_per_page=4,
417
+ inputs=cloth_image,
418
+ label="Condition Overall Examples",
419
+ )
420
+ condition_person_exm = gr.Examples(
421
+ examples=[
422
+ os.path.join(root_path, "condition", "person", _)
423
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
424
+ ],
425
+ examples_per_page=4,
426
+ inputs=cloth_image,
427
+ label="Condition Reference Person Examples",
428
+ )
429
+ gr.Markdown(
430
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
431
+ )
432
+
433
+ image_path.change(
434
+ person_example_fn, inputs=image_path, outputs=person_image
435
+ )
436
+
437
+ submit.click(
438
+ submit_function,
439
+ [
440
+ person_image,
441
+ cloth_image,
442
+ cloth_type,
443
+ num_inference_steps,
444
+ guidance_scale,
445
+ seed,
446
+ show_type,
447
+ ],
448
+ result_image,
449
+ )
450
+
451
+ with gr.Tab("Mask-based & Flux.1 Fill Dev"):
452
+ with gr.Row():
453
+ with gr.Column(scale=1, min_width=350):
454
+ with gr.Row():
455
+ image_path_flux = gr.Image(
456
+ type="filepath",
457
+ interactive=True,
458
+ visible=False,
459
+ )
460
+ person_image_flux = gr.ImageEditor(
461
+ interactive=True, label="Person Image", type="filepath"
462
+ )
463
+
464
+ with gr.Row():
465
+ with gr.Column(scale=1, min_width=230):
466
+ cloth_image_flux = gr.Image(
467
+ interactive=True, label="Condition Image", type="filepath"
468
+ )
469
+ with gr.Column(scale=1, min_width=120):
470
+ gr.Markdown(
471
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
472
+ )
473
+ cloth_type = gr.Radio(
474
+ label="Try-On Cloth Type",
475
+ choices=["upper", "lower", "overall"],
476
+ value="upper",
477
+ )
478
+
479
+ submit_flux = gr.Button("Submit")
480
+ gr.Markdown(
481
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
482
+ )
483
+
484
+ with gr.Accordion("Advanced Options", open=False):
485
+ num_inference_steps_flux = gr.Slider(
486
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
487
+ )
488
+ # Guidance Scale
489
+ guidance_scale_flux = gr.Slider(
490
+ label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
491
+ )
492
+ # Random Seed
493
+ seed_flux = gr.Slider(
494
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
495
+ )
496
+ show_type = gr.Radio(
497
+ label="Show Type",
498
+ choices=["result only", "input & result", "input & mask & result"],
499
+ value="input & mask & result",
500
+ )
501
+
502
+ with gr.Column(scale=2, min_width=500):
503
+ result_image_flux = gr.Image(interactive=False, label="Result")
504
+ with gr.Row():
505
+ # Photo Examples
506
+ root_path = "resource/demo/example"
507
+ with gr.Column():
508
+ gr.Examples(
509
+ examples=[
510
+ os.path.join(root_path, "person", "men", _)
511
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
512
+ ],
513
+ examples_per_page=4,
514
+ inputs=image_path_flux,
515
+ label="Person Examples ①",
516
+ )
517
+ gr.Examples(
518
+ examples=[
519
+ os.path.join(root_path, "person", "women", _)
520
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
521
+ ],
522
+ examples_per_page=4,
523
+ inputs=image_path_flux,
524
+ label="Person Examples ②",
525
+ )
526
+ gr.Markdown(
527
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
528
+ )
529
+ with gr.Column():
530
+ gr.Examples(
531
+ examples=[
532
+ os.path.join(root_path, "condition", "upper", _)
533
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
534
+ ],
535
+ examples_per_page=4,
536
+ inputs=cloth_image_flux,
537
+ label="Condition Upper Examples",
538
+ )
539
+ gr.Examples(
540
+ examples=[
541
+ os.path.join(root_path, "condition", "overall", _)
542
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
543
+ ],
544
+ examples_per_page=4,
545
+ inputs=cloth_image_flux,
546
+ label="Condition Overall Examples",
547
+ )
548
+ condition_person_exm = gr.Examples(
549
+ examples=[
550
+ os.path.join(root_path, "condition", "person", _)
551
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
552
+ ],
553
+ examples_per_page=4,
554
+ inputs=cloth_image_flux,
555
+ label="Condition Reference Person Examples",
556
+ )
557
+ gr.Markdown(
558
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
559
+ )
560
+
561
+
562
+ image_path_flux.change(
563
+ person_example_fn, inputs=image_path_flux, outputs=person_image_flux
564
+ )
565
+
566
+ submit_flux.click(
567
+ submit_function_flux,
568
+ [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
569
+ result_image_flux,
570
+ )
571
+
572
+
573
+
574
+
575
+ demo.queue().launch(share=True, show_error=True)
576
+
577
+
578
+ if __name__ == "__main__":
579
+ app_gradio()