Balaji23 committed on
Commit 6c71e37 · verified · 1 parent: 631a767

update app.py: wrap the cv2 import in a try/except so a missing OpenCV install is reported at startup instead of raising an unhandled ImportError
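The diff below replaces the whole file, but the only functional change is the guarded import at the top of app.py. A minimal sketch of the committed pattern follows; the `cv2 = None` fallback is an illustrative assumption, not part of this commit (the rest of the script still calls cv2 unconditionally):

```python
# Guarded OpenCV import as added in this commit.
try:
    import cv2
    print("OpenCV is installed correctly.")
except ImportError:
    cv2 = None  # hypothetical fallback, not in the commit; keeps the name bound
    print("OpenCV is not installed.")
```

Without such a fallback, a later call such as `cv2.addWeighted(...)` in `visualize_parsing` would still raise a NameError when OpenCV is absent.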

Files changed (1)
  1. app.py +476 -472
app.py CHANGED
@@ -1,472 +1,476 @@
1
- import sys
2
- sys.path.append('./')
3
- from PIL import Image
4
- import cv2
5
- import gradio as gr
6
- from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
7
- from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
8
- from src.unet_hacked_tryon import UNet2DConditionModel
9
- from transformers import (
10
- CLIPImageProcessor,
11
- CLIPVisionModelWithProjection,
12
- CLIPTextModel,
13
- CLIPTextModelWithProjection,
14
- )
15
- from diffusers import DDPMScheduler,AutoencoderKL
16
- from typing import List
17
-
18
- import torch
19
- import os
20
- from transformers import AutoTokenizer
21
- import numpy as np
22
- from utils_mask import get_mask_location
23
- from torchvision import transforms
24
- import apply_net
25
- from preprocess.humanparsing.run_parsing import Parsing
26
- from preprocess.openpose.run_openpose import OpenPose
27
- from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
28
- from torchvision.transforms.functional import to_pil_image
29
-
30
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
31
-
32
- def pil_to_binary_mask(pil_image, threshold=0):
33
- np_image = np.array(pil_image)
34
- grayscale_image = Image.fromarray(np_image).convert("L")
35
- binary_mask = np.array(grayscale_image) > threshold
36
- mask = np.zeros(binary_mask.shape, dtype=np.uint8)
37
- for i in range(binary_mask.shape[0]):
38
- for j in range(binary_mask.shape[1]):
39
- if binary_mask[i,j] == True :
40
- mask[i,j] = 1
41
- mask = (mask*255).astype(np.uint8)
42
- output_mask = Image.fromarray(mask)
43
- return output_mask
44
-
45
-
46
- base_path = 'yisol/IDM-VTON'
47
- example_path = os.path.join(os.path.dirname(__file__), 'example')
48
-
49
- unet = UNet2DConditionModel.from_pretrained(
50
- base_path,
51
- subfolder="unet",
52
- torch_dtype=torch.float16,
53
- )
54
- unet.requires_grad_(False)
55
- tokenizer_one = AutoTokenizer.from_pretrained(
56
- base_path,
57
- subfolder="tokenizer",
58
- revision=None,
59
- use_fast=False,
60
- )
61
- tokenizer_two = AutoTokenizer.from_pretrained(
62
- base_path,
63
- subfolder="tokenizer_2",
64
- revision=None,
65
- use_fast=False,
66
- )
67
- noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
68
-
69
- text_encoder_one = CLIPTextModel.from_pretrained(
70
- base_path,
71
- subfolder="text_encoder",
72
- torch_dtype=torch.float16,
73
- )
74
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
75
- base_path,
76
- subfolder="text_encoder_2",
77
- torch_dtype=torch.float16,
78
- )
79
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
80
- base_path,
81
- subfolder="image_encoder",
82
- torch_dtype=torch.float16,
83
- )
84
- vae = AutoencoderKL.from_pretrained(base_path,
85
- subfolder="vae",
86
- torch_dtype=torch.float16,
87
- )
88
-
89
- # "stabilityai/stable-diffusion-xl-base-1.0",
90
- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
91
- base_path,
92
- subfolder="unet_encoder",
93
- torch_dtype=torch.float16,
94
- )
95
-
96
- parsing_model = Parsing(0)
97
- openpose_model = OpenPose(0)
98
-
99
- UNet_Encoder.requires_grad_(False)
100
- image_encoder.requires_grad_(False)
101
- vae.requires_grad_(False)
102
- unet.requires_grad_(False)
103
- text_encoder_one.requires_grad_(False)
104
- text_encoder_two.requires_grad_(False)
105
- tensor_transfrom = transforms.Compose(
106
- [
107
- transforms.ToTensor(),
108
- transforms.Normalize([0.5], [0.5]),
109
- ]
110
- )
111
-
112
- pipe = TryonPipeline.from_pretrained(
113
- base_path,
114
- unet=unet,
115
- vae=vae,
116
- feature_extractor= CLIPImageProcessor(),
117
- text_encoder = text_encoder_one,
118
- text_encoder_2 = text_encoder_two,
119
- tokenizer = tokenizer_one,
120
- tokenizer_2 = tokenizer_two,
121
- scheduler = noise_scheduler,
122
- image_encoder=image_encoder,
123
- torch_dtype=torch.float16,
124
- )
125
- pipe.unet_encoder = UNet_Encoder
126
-
127
- # Function to visualize parsing
128
- def visualize_parsing(image, mask):
129
- """
130
- Visualize the parsing by applying a color map to the segmentation mask.
131
- """
132
- # Ensure image is in RGB format and convert to numpy array
133
- image_array = np.array(image.convert('RGB'), dtype=np.uint8)
134
-
135
- # Create a color map
136
- num_classes = np.max(mask) + 1
137
- colors = np.random.randint(0, 255, size=(num_classes, 3), dtype=np.uint8)
138
-
139
- # Apply color map to the mask
140
- color_mask = colors[mask.astype(int)]
141
-
142
- # Ensure color_mask is correctly shaped and typed
143
- color_mask = np.array(color_mask, dtype=np.uint8)
144
-
145
- # Combine the original image and the color mask
146
- combined_image = cv2.addWeighted(image_array, 0.5, color_mask, 0.5, 0)
147
-
148
- return Image.fromarray(combined_image)
149
-
150
- def process_densepose(human_img):
151
- """
152
- Processes the human image using DensePose and returns the DensePose image.
153
- Assumes human_img is a dictionary with a 'background' key pointing to the image path.
154
- """
155
- # Load image from path
156
- image_path = human_img['background'] # Assuming 'background' is the correct key
157
- if isinstance(image_path, Image.Image):
158
- image = image_path
159
- else:
160
- image = Image.open(image_path) # Only call Image.open if it's not already an Image object
161
-
162
- # Apply EXIF orientation and resize
163
- human_img_arg = _apply_exif_orientation(image.resize((384, 512)))
164
- human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
165
-
166
- # Setup DensePose arguments
167
- args = apply_net.create_argument_parser().parse_args(
168
- ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda')
169
- )
170
- pose_img = args.func(args, human_img_arg)
171
- pose_img = pose_img[:, :, ::-1] # Convert from BGR to RGB
172
- pose_img = Image.fromarray(pose_img).resize((768, 1024))
173
-
174
- return pose_img, pose_img
175
-
176
- def process_human_parsing(human_img):
177
- """
178
- Processes the human image to perform segmentation using a human parsing model.
179
- """
180
-
181
- image_path = human_img['background'] # Assuming 'background' is the correct key
182
- if isinstance(image_path, Image.Image):
183
- image = image_path
184
- else:
185
- image = Image.open(image_path) # Only call Image.open if it's not already an Image object
186
-
187
- image = image.resize((384, 512))
188
- model_parse, _ = parsing_model(image)
189
- # parsing_image = visualize_parsing(human_img, model_parse) # Visualization function needed
190
- # vis_image = visualize_parsing(image, model_parse)
191
- # state_message = "Human parsing processing completed"
192
- return model_parse
193
-
194
- def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
195
- """
196
- Preprocesses images and generates outputs using various models.
197
-
198
- Parameters:
199
- - human_img: PIL image of the human.
200
- - garm_img: PIL image of the garment.
201
- - garment_des: Description of the garment.
202
- - is_checked: Boolean flag indicating whether to use auto-generated mask.
203
- - is_checked_crop: Boolean flag indicating whether to use auto-crop & resizing.
204
- - denoise_steps: Number of denoising steps.
205
- - seed: Seed for random generator.
206
- - pose_img: DensePose image generated in the previous step.
207
-
208
- Returns:
209
- - Processed images: Depending on the conditions, it returns human_img_orig, mask_gray, and final output images.
210
- """
211
- openpose_model.preprocessor.body_estimation.model.to(device)
212
- pipe.to(device)
213
- pipe.unet_encoder.to(device)
214
-
215
- garm_img= garm_img.convert("RGB").resize((768,1024))
216
- human_img_orig = dict["background"].convert("RGB")
217
-
218
- if is_checked_crop:
219
- width, height = human_img_orig.size
220
- target_width = int(min(width, height * (3 / 4)))
221
- target_height = int(min(height, width * (4 / 3)))
222
- left = (width - target_width) / 2
223
- top = (height - target_height) / 2
224
- right = (width + target_width) / 2
225
- bottom = (height + target_height) / 2
226
- cropped_img = human_img_orig.crop((left, top, right, bottom))
227
- crop_size = cropped_img.size
228
- human_img = cropped_img.resize((768,1024))
229
- else:
230
- human_img = human_img_orig.resize((768,1024))
231
-
232
-
233
- if is_checked:
234
- keypoints = openpose_model(human_img.resize((384,512)))
235
- print(keypoints)
236
- model_parse, _ = parsing_model(human_img.resize((384,512)))
237
- print(model_parse)
238
- mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
239
- mask = mask.resize((768,1024))
240
- else:
241
- mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
242
- # mask = transforms.ToTensor()(mask)
243
- # mask = mask.unsqueeze(0)
244
- mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
245
- mask_gray = to_pil_image((mask_gray+1.0)/2.0)
246
-
247
-
248
- human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
249
- human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
250
-
251
-
252
-
253
- args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
254
- # verbosity = getattr(args, "verbosity", None)
255
- pose_img = args.func(args,human_img_arg)
256
- pose_img = pose_img[:,:,::-1]
257
- pose_img = Image.fromarray(pose_img).resize((768,1024))
258
-
259
- with torch.no_grad():
260
- # Extract the images
261
- with torch.cuda.amp.autocast():
262
- with torch.no_grad():
263
- prompt = "model is wearing " + garment_des
264
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
265
- with torch.inference_mode():
266
- (
267
- prompt_embeds,
268
- negative_prompt_embeds,
269
- pooled_prompt_embeds,
270
- negative_pooled_prompt_embeds,
271
- ) = pipe.encode_prompt(
272
- prompt,
273
- num_images_per_prompt=1,
274
- do_classifier_free_guidance=True,
275
- negative_prompt=negative_prompt,
276
- )
277
-
278
- prompt = "a photo of " + garment_des
279
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
280
- if not isinstance(prompt, List):
281
- prompt = [prompt] * 1
282
- if not isinstance(negative_prompt, List):
283
- negative_prompt = [negative_prompt] * 1
284
- with torch.inference_mode():
285
- (
286
- prompt_embeds_c,
287
- _,
288
- _,
289
- _,
290
- ) = pipe.encode_prompt(
291
- prompt,
292
- num_images_per_prompt=1,
293
- do_classifier_free_guidance=False,
294
- negative_prompt=negative_prompt,
295
- )
296
-
297
-
298
-
299
- pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
300
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
301
- generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
302
- images = pipe(
303
- prompt_embeds=prompt_embeds.to(device,torch.float16),
304
- negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
305
- pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
306
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
307
- num_inference_steps=denoise_steps,
308
- generator=generator,
309
- strength = 1.0,
310
- pose_img = pose_img.to(device,torch.float16),
311
- text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
312
- cloth = garm_tensor.to(device,torch.float16),
313
- mask_image=mask,
314
- image=human_img,
315
- height=1024,
316
- width=768,
317
- ip_adapter_image = garm_img.resize((768,1024)),
318
- guidance_scale=2.0,
319
- )[0]
320
-
321
- if is_checked_crop:
322
- out_img = images[0].resize(crop_size)
323
- human_img_orig.paste(out_img, (int(left), int(top)))
324
- return human_img_orig, mask_gray
325
- else:
326
- # out_img = images[0].resize(crop_size)
327
- return images[0], mask_gray
328
-
329
-
330
-
331
-
332
-
333
- garm_list = os.listdir(os.path.join(example_path,"cloth"))
334
- garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
335
-
336
- human_list = os.listdir(os.path.join(example_path,"human"))
337
- human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
338
-
339
- human_ex_list = []
340
- for ex_human in human_list_path:
341
- ex_dict= {}
342
- ex_dict['background'] = ex_human
343
- ex_dict['layers'] = None
344
- ex_dict['composite'] = None
345
- human_ex_list.append(ex_dict)
346
-
347
- ##default human
348
-
349
-
350
- image_blocks = gr.Blocks().queue()
351
- with image_blocks as demo:
352
- with gr.Row():
353
- with gr.Column():
354
- imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
355
- with gr.Row():
356
- is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
357
- with gr.Row():
358
- is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
359
-
360
- example = gr.Examples(
361
- inputs=imgs,
362
- examples_per_page=10,
363
- examples=human_ex_list
364
- )
365
-
366
- with gr.Column():
367
- garm_img = gr.Image(label="Garment", sources='upload', type="pil")
368
- with gr.Row(elem_id="prompt-container"):
369
- with gr.Row():
370
- prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
371
- example = gr.Examples(
372
- inputs=garm_img,
373
- examples_per_page=8,
374
- examples=garm_list_path)
375
- with gr.Column():
376
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
377
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
378
-
379
- with gr.Column():
380
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
381
- image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
382
-
383
- with gr.Column():
384
- densepose_img_out = gr.Image(label="Output", elem_id="densepose-img",show_share_button=False)
385
- # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
386
-
387
-
388
-
389
- with gr.Column():
390
- try_button = gr.Button(value="Try-on")
391
- with gr.Accordion(label="Advanced Settings", open=False):
392
- with gr.Row():
393
- denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
394
- seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
395
-
396
- densepose_state = gr.State(None)
397
-
398
- # Define the steps in sequence
399
- image_blocks = gr.Blocks().queue()
400
- with image_blocks as demo:
401
- with gr.Row():
402
- with gr.Column():
403
- imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
404
- with gr.Row():
405
- is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
406
- with gr.Row():
407
- is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
408
-
409
- example = gr.Examples(
410
- inputs=imgs,
411
- examples_per_page=10,
412
- examples=human_ex_list
413
- )
414
-
415
- with gr.Column():
416
- garm_img = gr.Image(label="Garment", sources='upload', type="pil")
417
- with gr.Row(elem_id="prompt-container"):
418
- with gr.Row():
419
- prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
420
- example = gr.Examples(
421
- inputs=garm_img,
422
- examples_per_page=8,
423
- examples=garm_list_path)
424
- with gr.Column():
425
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
426
-
427
- with gr.Column():
428
- image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
429
-
430
- with gr.Column():
431
- densepose_img_out = gr.Image(label="Dense-pose", elem_id="densepose-img", show_share_button=False)
432
- # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
433
-
434
- with gr.Column():
435
- human_parse_img_out = gr.Image(label="Human-Parse", elem_id="humanparse-img", show_share_button=False)
436
- # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
437
-
438
- with gr.Column():
439
- try_button = gr.Button(value="Try-on")
440
- get_denspose =gr.Button(value="Get-DensePose")
441
- get_humanparse =gr.Button(value="Get-HumanParse")
442
- with gr.Accordion(label="Advanced Settings", open=False):
443
- with gr.Row():
444
- denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
445
- seed = gr.Number(label="Seed", minimum=-1, maximum =2147483647, step=1, value=42)
446
-
447
- densepose_state = gr.State(None)
448
-
449
- # Define the steps in sequence
450
- get_denspose.click(
451
- fn=process_densepose,
452
- inputs=[imgs],
453
- outputs=[densepose_img_out, densepose_state],
454
- api_name='process_densepose'
455
- )
456
- get_humanparse.click(
457
- fn=process_human_parsing,
458
- inputs=[imgs],
459
- outputs=[human_parse_img_out],
460
- api_name='process_humanparse'
461
- )
462
- try_button.click(
463
- fn=start_tryon,
464
- inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
465
- outputs=[image_out, masked_img],
466
- api_name='start_tryon'
467
- )
468
-
469
- image_blocks.launch(server_name="0.0.0.0", server_port=3000)
470
-
471
-
472
-
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('./')
3
+ from PIL import Image
4
+ try:
5
+ import cv2
6
+ print("OpenCV is installed correctly.")
7
+ except ImportError:
8
+ print("OpenCV is not installed.")
9
+ import gradio as gr
10
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
11
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
12
+ from src.unet_hacked_tryon import UNet2DConditionModel
13
+ from transformers import (
14
+ CLIPImageProcessor,
15
+ CLIPVisionModelWithProjection,
16
+ CLIPTextModel,
17
+ CLIPTextModelWithProjection,
18
+ )
19
+ from diffusers import DDPMScheduler,AutoencoderKL
20
+ from typing import List
21
+
22
+ import torch
23
+ import os
24
+ from transformers import AutoTokenizer
25
+ import numpy as np
26
+ from utils_mask import get_mask_location
27
+ from torchvision import transforms
28
+ import apply_net
29
+ from preprocess.humanparsing.run_parsing import Parsing
30
+ from preprocess.openpose.run_openpose import OpenPose
31
+ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
32
+ from torchvision.transforms.functional import to_pil_image
33
+
34
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
35
+
36
+ def pil_to_binary_mask(pil_image, threshold=0):
37
+ np_image = np.array(pil_image)
38
+ grayscale_image = Image.fromarray(np_image).convert("L")
39
+ binary_mask = np.array(grayscale_image) > threshold
40
+ mask = np.zeros(binary_mask.shape, dtype=np.uint8)
41
+ for i in range(binary_mask.shape[0]):
42
+ for j in range(binary_mask.shape[1]):
43
+ if binary_mask[i,j] == True :
44
+ mask[i,j] = 1
45
+ mask = (mask*255).astype(np.uint8)
46
+ output_mask = Image.fromarray(mask)
47
+ return output_mask
48
+
49
+
50
+ base_path = 'yisol/IDM-VTON'
51
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
52
+
53
+ unet = UNet2DConditionModel.from_pretrained(
54
+ base_path,
55
+ subfolder="unet",
56
+ torch_dtype=torch.float16,
57
+ )
58
+ unet.requires_grad_(False)
59
+ tokenizer_one = AutoTokenizer.from_pretrained(
60
+ base_path,
61
+ subfolder="tokenizer",
62
+ revision=None,
63
+ use_fast=False,
64
+ )
65
+ tokenizer_two = AutoTokenizer.from_pretrained(
66
+ base_path,
67
+ subfolder="tokenizer_2",
68
+ revision=None,
69
+ use_fast=False,
70
+ )
71
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
72
+
73
+ text_encoder_one = CLIPTextModel.from_pretrained(
74
+ base_path,
75
+ subfolder="text_encoder",
76
+ torch_dtype=torch.float16,
77
+ )
78
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
79
+ base_path,
80
+ subfolder="text_encoder_2",
81
+ torch_dtype=torch.float16,
82
+ )
83
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
84
+ base_path,
85
+ subfolder="image_encoder",
86
+ torch_dtype=torch.float16,
87
+ )
88
+ vae = AutoencoderKL.from_pretrained(base_path,
89
+ subfolder="vae",
90
+ torch_dtype=torch.float16,
91
+ )
92
+
93
+ # "stabilityai/stable-diffusion-xl-base-1.0",
94
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
95
+ base_path,
96
+ subfolder="unet_encoder",
97
+ torch_dtype=torch.float16,
98
+ )
99
+
100
+ parsing_model = Parsing(0)
101
+ openpose_model = OpenPose(0)
102
+
103
+ UNet_Encoder.requires_grad_(False)
104
+ image_encoder.requires_grad_(False)
105
+ vae.requires_grad_(False)
106
+ unet.requires_grad_(False)
107
+ text_encoder_one.requires_grad_(False)
108
+ text_encoder_two.requires_grad_(False)
109
+ tensor_transfrom = transforms.Compose(
110
+ [
111
+ transforms.ToTensor(),
112
+ transforms.Normalize([0.5], [0.5]),
113
+ ]
114
+ )
115
+
116
+ pipe = TryonPipeline.from_pretrained(
117
+ base_path,
118
+ unet=unet,
119
+ vae=vae,
120
+ feature_extractor= CLIPImageProcessor(),
121
+ text_encoder = text_encoder_one,
122
+ text_encoder_2 = text_encoder_two,
123
+ tokenizer = tokenizer_one,
124
+ tokenizer_2 = tokenizer_two,
125
+ scheduler = noise_scheduler,
126
+ image_encoder=image_encoder,
127
+ torch_dtype=torch.float16,
128
+ )
129
+ pipe.unet_encoder = UNet_Encoder
130
+
131
+ # Function to visualize parsing
132
+ def visualize_parsing(image, mask):
133
+ """
134
+ Visualize the parsing by applying a color map to the segmentation mask.
135
+ """
136
+ # Ensure image is in RGB format and convert to numpy array
137
+ image_array = np.array(image.convert('RGB'), dtype=np.uint8)
138
+
139
+ # Create a color map
140
+ num_classes = np.max(mask) + 1
141
+ colors = np.random.randint(0, 255, size=(num_classes, 3), dtype=np.uint8)
142
+
143
+ # Apply color map to the mask
144
+ color_mask = colors[mask.astype(int)]
145
+
146
+ # Ensure color_mask is correctly shaped and typed
147
+ color_mask = np.array(color_mask, dtype=np.uint8)
148
+
149
+ # Combine the original image and the color mask
150
+ combined_image = cv2.addWeighted(image_array, 0.5, color_mask, 0.5, 0)
151
+
152
+ return Image.fromarray(combined_image)
153
+
154
+ def process_densepose(human_img):
155
+ """
156
+ Processes the human image using DensePose and returns the DensePose image.
157
+ Assumes human_img is a dictionary with a 'background' key pointing to the image path.
158
+ """
159
+ # Load image from path
160
+ image_path = human_img['background'] # Assuming 'background' is the correct key
161
+ if isinstance(image_path, Image.Image):
162
+ image = image_path
163
+ else:
164
+ image = Image.open(image_path) # Only call Image.open if it's not already an Image object
165
+
166
+ # Apply EXIF orientation and resize
167
+ human_img_arg = _apply_exif_orientation(image.resize((384, 512)))
168
+ human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
169
+
170
+ # Setup DensePose arguments
171
+ args = apply_net.create_argument_parser().parse_args(
172
+ ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda')
173
+ )
174
+ pose_img = args.func(args, human_img_arg)
175
+ pose_img = pose_img[:, :, ::-1] # Convert from BGR to RGB
176
+ pose_img = Image.fromarray(pose_img).resize((768, 1024))
177
+
178
+ return pose_img, pose_img
179
+
180
+ def process_human_parsing(human_img):
181
+ """
182
+ Processes the human image to perform segmentation using a human parsing model.
183
+ """
184
+
185
+ image_path = human_img['background'] # Assuming 'background' is the correct key
186
+ if isinstance(image_path, Image.Image):
187
+ image = image_path
188
+ else:
189
+ image = Image.open(image_path) # Only call Image.open if it's not already an Image object
190
+
191
+ image = image.resize((384, 512))
192
+ model_parse, _ = parsing_model(image)
193
+ # parsing_image = visualize_parsing(human_img, model_parse) # Visualization function needed
194
+ # vis_image = visualize_parsing(image, model_parse)
195
+ # state_message = "Human parsing processing completed"
196
+ return model_parse
197
+
198
+ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
199
+ """
200
+ Preprocesses images and generates outputs using various models.
201
+
202
+ Parameters:
203
+ - human_img: PIL image of the human.
204
+ - garm_img: PIL image of the garment.
205
+ - garment_des: Description of the garment.
206
+ - is_checked: Boolean flag indicating whether to use auto-generated mask.
207
+ - is_checked_crop: Boolean flag indicating whether to use auto-crop & resizing.
208
+ - denoise_steps: Number of denoising steps.
209
+ - seed: Seed for random generator.
210
+ - pose_img: DensePose image generated in the previous step.
211
+
212
+ Returns:
213
+ - Processed images: Depending on the conditions, it returns human_img_orig, mask_gray, and final output images.
214
+ """
215
+ openpose_model.preprocessor.body_estimation.model.to(device)
216
+ pipe.to(device)
217
+ pipe.unet_encoder.to(device)
218
+
219
+ garm_img= garm_img.convert("RGB").resize((768,1024))
220
+ human_img_orig = dict["background"].convert("RGB")
221
+
222
+ if is_checked_crop:
223
+ width, height = human_img_orig.size
224
+ target_width = int(min(width, height * (3 / 4)))
225
+ target_height = int(min(height, width * (4 / 3)))
226
+ left = (width - target_width) / 2
227
+ top = (height - target_height) / 2
228
+ right = (width + target_width) / 2
229
+ bottom = (height + target_height) / 2
230
+ cropped_img = human_img_orig.crop((left, top, right, bottom))
231
+ crop_size = cropped_img.size
232
+ human_img = cropped_img.resize((768,1024))
233
+ else:
234
+ human_img = human_img_orig.resize((768,1024))
235
+
236
+
237
+ if is_checked:
238
+ keypoints = openpose_model(human_img.resize((384,512)))
239
+ print(keypoints)
240
+ model_parse, _ = parsing_model(human_img.resize((384,512)))
241
+ print(model_parse)
242
+ mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
243
+ mask = mask.resize((768,1024))
244
+ else:
245
+ mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
246
+ # mask = transforms.ToTensor()(mask)
247
+ # mask = mask.unsqueeze(0)
248
+ mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
249
+ mask_gray = to_pil_image((mask_gray+1.0)/2.0)
250
+
251
+
252
+ human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
253
+ human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
254
+
255
+
256
+
257
+ args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
258
+ # verbosity = getattr(args, "verbosity", None)
259
+ pose_img = args.func(args,human_img_arg)
260
+ pose_img = pose_img[:,:,::-1]
261
+ pose_img = Image.fromarray(pose_img).resize((768,1024))
262
+
263
+ with torch.no_grad():
264
+ # Extract the images
265
+ with torch.cuda.amp.autocast():
266
+ with torch.no_grad():
267
+ prompt = "model is wearing " + garment_des
268
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
269
+ with torch.inference_mode():
270
+ (
271
+ prompt_embeds,
272
+ negative_prompt_embeds,
273
+ pooled_prompt_embeds,
274
+ negative_pooled_prompt_embeds,
275
+ ) = pipe.encode_prompt(
276
+ prompt,
277
+ num_images_per_prompt=1,
278
+ do_classifier_free_guidance=True,
279
+ negative_prompt=negative_prompt,
280
+ )
281
+
282
+ prompt = "a photo of " + garment_des
283
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
284
+ if not isinstance(prompt, List):
285
+ prompt = [prompt] * 1
286
+ if not isinstance(negative_prompt, List):
287
+ negative_prompt = [negative_prompt] * 1
288
+ with torch.inference_mode():
289
+ (
290
+ prompt_embeds_c,
291
+ _,
292
+ _,
293
+ _,
294
+ ) = pipe.encode_prompt(
295
+ prompt,
296
+ num_images_per_prompt=1,
297
+ do_classifier_free_guidance=False,
298
+ negative_prompt=negative_prompt,
299
+ )
300
+
301
+
302
+
303
+ pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
304
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
305
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
306
+ images = pipe(
307
+ prompt_embeds=prompt_embeds.to(device,torch.float16),
308
+ negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
309
+ pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
310
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
311
+ num_inference_steps=denoise_steps,
312
+ generator=generator,
313
+ strength = 1.0,
314
+ pose_img = pose_img.to(device,torch.float16),
315
+ text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
316
+ cloth = garm_tensor.to(device,torch.float16),
317
+ mask_image=mask,
318
+ image=human_img,
319
+ height=1024,
320
+ width=768,
321
+ ip_adapter_image = garm_img.resize((768,1024)),
322
+ guidance_scale=2.0,
323
+ )[0]
324
+
325
+ if is_checked_crop:
326
+ out_img = images[0].resize(crop_size)
327
+ human_img_orig.paste(out_img, (int(left), int(top)))
328
+ return human_img_orig, mask_gray
329
+ else:
330
+ # out_img = images[0].resize(crop_size)
331
+ return images[0], mask_gray
332
+
333
+
334
+
335
+
336
+
337
+ garm_list = os.listdir(os.path.join(example_path,"cloth"))
338
+ garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
339
+
340
+ human_list = os.listdir(os.path.join(example_path,"human"))
341
+ human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
342
+
343
+ human_ex_list = []
344
+ for ex_human in human_list_path:
345
+ ex_dict= {}
346
+ ex_dict['background'] = ex_human
347
+ ex_dict['layers'] = None
348
+ ex_dict['composite'] = None
349
+ human_ex_list.append(ex_dict)
350
+
351
+ ##default human
352
+
353
+
354
+ image_blocks = gr.Blocks().queue()
355
+ with image_blocks as demo:
356
+ with gr.Row():
357
+ with gr.Column():
358
+ imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
359
+ with gr.Row():
360
+ is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
361
+ with gr.Row():
362
+ is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
363
+
364
+ example = gr.Examples(
365
+ inputs=imgs,
366
+ examples_per_page=10,
367
+ examples=human_ex_list
368
+ )
369
+
370
+ with gr.Column():
371
+ garm_img = gr.Image(label="Garment", sources='upload', type="pil")
372
+ with gr.Row(elem_id="prompt-container"):
373
+ with gr.Row():
374
+ prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
375
+ example = gr.Examples(
376
+ inputs=garm_img,
377
+ examples_per_page=8,
378
+ examples=garm_list_path)
379
+ with gr.Column():
380
+ # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
381
+ masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
382
+
383
+ with gr.Column():
384
+ # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
385
+ image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
386
+
387
+ with gr.Column():
388
+ densepose_img_out = gr.Image(label="Output", elem_id="densepose-img",show_share_button=False)
389
+ # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
390
+
391
+
392
+
393
+ with gr.Column():
394
+ try_button = gr.Button(value="Try-on")
395
+ with gr.Accordion(label="Advanced Settings", open=False):
396
+ with gr.Row():
397
+ denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
398
+ seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
399
+
400
+ densepose_state = gr.State(None)
401
+
402
+ # Define the steps in sequence
403
+ image_blocks = gr.Blocks().queue()
404
+ with image_blocks as demo:
405
+ with gr.Row():
406
+ with gr.Column():
407
+ imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
408
+ with gr.Row():
409
+ is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
410
+ with gr.Row():
411
+ is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
412
+
413
+ example = gr.Examples(
414
+ inputs=imgs,
415
+ examples_per_page=10,
416
+ examples=human_ex_list
417
+ )
418
+
419
+ with gr.Column():
420
+ garm_img = gr.Image(label="Garment", sources='upload', type="pil")
421
+ with gr.Row(elem_id="prompt-container"):
422
+ with gr.Row():
423
+ prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
424
+ example = gr.Examples(
425
+ inputs=garm_img,
426
+ examples_per_page=8,
427
+ examples=garm_list_path)
428
+ with gr.Column():
429
+ masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)
430
+
431
+ with gr.Column():
432
+ image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
433
+
434
+ with gr.Column():
435
+ densepose_img_out = gr.Image(label="Dense-pose", elem_id="densepose-img", show_share_button=False)
436
+ # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
437
+
438
+ with gr.Column():
439
+ human_parse_img_out = gr.Image(label="Human-Parse", elem_id="humanparse-img", show_share_button=False)
440
+ # densepose_img = gr.Gallery(label="All images", show_label=False, elem_id="all-images", columns=[3], rows=[1], object_fit="contain", height="auto")
441
+
442
+ with gr.Column():
443
+ try_button = gr.Button(value="Try-on")
444
+ get_denspose =gr.Button(value="Get-DensePose")
445
+ get_humanparse =gr.Button(value="Get-HumanParse")
446
+ with gr.Accordion(label="Advanced Settings", open=False):
447
+ with gr.Row():
448
+ denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
449
+ seed = gr.Number(label="Seed", minimum=-1, maximum =2147483647, step=1, value=42)
450
+
451
+ densepose_state = gr.State(None)
452
+
453
+ # Define the steps in sequence
454
+ get_denspose.click(
455
+ fn=process_densepose,
456
+ inputs=[imgs],
457
+ outputs=[densepose_img_out, densepose_state],
458
+ api_name='process_densepose'
459
+ )
460
+ get_humanparse.click(
461
+ fn=process_human_parsing,
462
+ inputs=[imgs],
463
+ outputs=[human_parse_img_out],
464
+ api_name='process_humanparse'
465
+ )
466
+ try_button.click(
467
+ fn=start_tryon,
468
+ inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
469
+ outputs=[image_out, masked_img],
470
+ api_name='start_tryon'
471
+ )
472
+
473
+ image_blocks.launch(server_name="0.0.0.0", server_port=3000)
474
+
475
+
476
+