basso4 committed · verified
Commit f8c0fcf · 1 Parent(s): 3f9659e

Update app.py

Files changed (1)
  1. app.py +257 -257
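
The substantive change in this diff is inside start_tryon: the handler's first parameter is renamed, and the base image for the final seamless-clone step is now taken from the editor value's 'background' entry instead of calling PIL methods on the value itself. A minimal sketch of that access pattern follows; the variable names and the blank placeholder image are illustrative and not part of app.py, which builds the same kind of dict ('background', 'layers', 'composite') for its examples.

# Sketch only: how the gr.ImageEditor value is unpacked after this commit.
from PIL import Image

editor_value = {
    "background": Image.new("RGB", (768, 1024), "white"),  # placeholder for the uploaded photo
    "layers": None,
    "composite": None,
}

# Before this commit, the clone base was built with the equivalent of
#     base = editor_value.convert("RGB")   # fails when the value is a plain dict
# After this commit, the PIL image is pulled out of the dict first:
base = editor_value["background"].convert("RGB").resize((384, 512))
print(base.size)  # (384, 512)
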
app.py CHANGED
@@ -1,257 +1,257 @@
- import spaces
- import gradio as gr
- import apply_net
-
- import os
- import sys
- import cv2
-
- sys.path.append('./')
- import numpy as np
- import argparse
-
- import torch
- import torchvision
- import pytorch_lightning
- from torch import autocast
- from torchvision import transforms
- from pytorch_lightning import seed_everything
-
- from einops import rearrange
- from functools import partial
- from omegaconf import OmegaConf
- from PIL import Image
- from typing import List
- import matplotlib.pyplot as plt
- from torchvision.transforms.functional import to_pil_image
- from utils_mask import get_mask_location
- from preprocess.humanparsing.run_parsing import Parsing
- from preprocess.openpose.run_openpose import OpenPose
- from ldm.util import instantiate_from_config, get_obj_from_str
- from ldm.models.diffusion.ddim import DDIMSampler
- from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
-
-
-
- if __name__ == "__main__":
-
- parser = argparse.ArgumentParser(description="Script for demo model")
- parser.add_argument("-b", "--base", type=str, default=r"configs/test_vitonhd.yaml")
- parser.add_argument("-c", "--ckpt", type=str, default=r"ckpt/hitonhd.ckpt")
- parser.add_argument("-s", "--seed", type=str, default=42)
- parser.add_argument("-d", "--ddim", type=str, default=16)
- opt = parser.parse_args()
-
- seed_everything(opt.seed)
- config = OmegaConf.load(f"{opt.base}")
-
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
- model = instantiate_from_config(config.model)
- model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/basso4/hitonhd/resolve/main/hitonhd.ckpt")["state_dict"], strict=False)
- model.cuda()
- model.eval()
- model = model.to(device)
- sampler = DDIMSampler(model)
-
- # model = instantiate_from_config(config.model)
- # model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
- # model.cuda()
- # model.eval()
- # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- # model = model.to(device)
- # sampler = DDIMSampler(model)
-
- precision_scope = autocast
-
-
- @spaces.GPU
- def start_tryon(dict_human,garm_img):
- #load human image
- human_img = dict_human['background'].convert("RGB").resize((768,1024))
-
- #mask
- tensor_transfrom = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- parsing_model = Parsing(0)
- openose_model = OpenPose(0)
- openose_model.preprocessor.body_estimation.model.to(device)
-
- keypoints = openose_model(human_img.resize((384,512)))
- model_parse, _ = parsing_model(human_img.resize((384,512)))
- mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
- mask_cv = mask
- mask = mask.resize((768, 1024))
- mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
- mask_gray = to_pil_image((mask_gray+1.0)/2.0)
-
- #densepose
- human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
- human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
- args = apply_net.create_argument_parser().parse_args(('show',
- './configs/configs_densepose/densepose_rcnn_R_50_FPN_s1x.yaml',
- 'https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl',
- 'dp_segm', '-v',
- '--opts',
- 'MODEL.DEVICE',
- 'cuda'))
- # verbosity = getattr(args, "verbosity", None)
- pose_img = args.func(args,human_img_arg)
- pose_img = pose_img[:,:,::-1]
- pose_img = Image.fromarray(pose_img).resize((768,1024))
-
- #preprocessing image
- human_img = human_img.convert("RGB").resize((512, 512))
- human_img = torchvision.transforms.ToTensor()(human_img)
-
- garm_img = garm_img.convert("RGB").resize((224, 224))
- garm_img = torchvision.transforms.ToTensor()(garm_img)
-
- mask = mask.convert("L").resize((512,512))
- mask = torchvision.transforms.ToTensor()(mask)
- mask = 1-mask
-
- pose_img = pose_img.convert("RGB").resize((512, 512))
- pose_img = torchvision.transforms.ToTensor()(pose_img)
-
- #Normalize
- human_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(human_img)
- garm_img = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
- (0.26862954, 0.26130258, 0.27577711))(garm_img)
- pose_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(pose_img)
-
- #create inpaint & hint
- inpaint = human_img * mask
- hint = torchvision.transforms.Resize((512, 512))(garm_img)
- hint = torch.cat((hint, pose_img), dim=0)
-
- # {"human_img": human_img, # [3, 512, 512]
- # "inpaint_image": inpaint, # [3, 512, 512]
- # "inpaint_mask": mask, # [1, 512, 512]
- # "garm_img": garm_img, # [3, 224, 224]
- # "hint": hint, # [6, 512, 512]
- # }
-
-
- with torch.no_grad():
- with precision_scope("cuda"):
- #loading data
- inpaint = inpaint.unsqueeze(0).to(torch.float16).to(device)
- reference = garm_img.unsqueeze(0).to(torch.float16).to(device)
- mask = mask.unsqueeze(0).to(torch.float16).to(device)
- hint = hint.unsqueeze(0).to(torch.float16).to(device)
- truth = human_img.unsqueeze(0).to(torch.float16).to(device)
-
- #data preprocessing
- encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
- z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
- mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(mask)
- test_model_kwargs = {}
- test_model_kwargs['inpaint_image'] = z_inpaint
- test_model_kwargs['inpaint_mask'] = mask_resize
- shape = (model.channels, model.image_size, model.image_size)
-
- #predict
- samples, _ = sampler.sample(S=opt.ddim,
- batch_size=1,
- shape=shape,
- pose=hint,
- conditioning=reference,
- verbose=False,
- eta=0,
- test_model_kwargs=test_model_kwargs)
- samples = 1. / model.scale_factor * samples
- x_samples = model.first_stage_model.decode(samples[:,:4,:,:])
-
- x_samples_ddim = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
- x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image=x_samples_ddim
- x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512,384])
-
- #apply seamlessClone technique here
- #img_base
- dict_human = dict_human.convert("RGB").resize((384, 512))
- dict_human = np.array(dict_human)
- dict_human = cv2.cvtColor(dict_human, cv2.COLOR_RGB2BGR)
-
- #img_output
- img_cv = rearrange(x_checked_image_torch[0], 'c h w -> h w c').cpu().numpy()
- img_cv = (img_cv * 255).astype(np.uint8)
- img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2BGR)
-
- #mask
- mask_cv = mask_cv.convert("L").resize((384,512))
- mask_cv = np.array(mask_cv)
- mask_cv = 255-mask_cv
-
- img_C = cv2.seamlessClone(dict_human, img_cv, mask_cv, (192,256), cv2.NORMAL_CLONE)
-
-
- return img_C, mask_gray
-
-
- example_path = os.path.join(os.path.dirname(__file__), 'example')
-
- garm_list = os.listdir(os.path.join(example_path,"cloth"))
- garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
-
- human_list = os.listdir(os.path.join(example_path,"human"))
- human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
-
- human_ex_list = []
- for ex_human in human_list_path:
- ex_dict= {}
- ex_dict['background'] = ex_human
- ex_dict['layers'] = None
- ex_dict['composite'] = None
- human_ex_list.append(ex_dict)
-
- ##default human
-
-
- image_blocks = gr.Blocks().queue()
- with image_blocks as demo:
- gr.Markdown("## FPT_VTON 👕👔👚")
- gr.Markdown("Virtual Try-on with your image and garment image")
- with gr.Row():
- with gr.Column():
- imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Picture or use Examples below', interactive=True)
-
- example = gr.Examples(
- inputs=imgs,
- examples_per_page=10,
- examples=human_list_path
- )
-
- with gr.Column():
- garm_img = gr.Image(label="Garment", sources='upload', type="pil")
-
- example = gr.Examples(
- inputs=garm_img,
- examples_per_page=8,
- examples=garm_list_path
- )
-
- with gr.Column():
- image_out_c = gr.Image(label="Output", elem_id="output-img",show_download_button=True)
- try_button = gr.Button(value="Try-on")
-
- # with gr.Column():
- # image_out_c = gr.Image(label="Output", elem_id="output-img",show_download_button=False)
-
- with gr.Column():
- masked_img = gr.Image(label="Masked image output", elem_id="masked_img", show_download_button=True)
-
-
- try_button.click(fn=start_tryon, inputs=[imgs,garm_img], outputs=[image_out_c,masked_img], api_name='tryon')
-
-
-
- image_blocks.launch()
 
+ import spaces
+ import gradio as gr
+ import apply_net
+
+ import os
+ import sys
+ import cv2
+
+ sys.path.append('./')
+ import numpy as np
+ import argparse
+
+ import torch
+ import torchvision
+ import pytorch_lightning
+ from torch import autocast
+ from torchvision import transforms
+ from pytorch_lightning import seed_everything
+
+ from einops import rearrange
+ from functools import partial
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from typing import List
+ import matplotlib.pyplot as plt
+ from torchvision.transforms.functional import to_pil_image
+ from utils_mask import get_mask_location
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose
+ from ldm.util import instantiate_from_config, get_obj_from_str
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
+
+
+
+ if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Script for demo model")
+ parser.add_argument("-b", "--base", type=str, default=r"configs/test_vitonhd.yaml")
+ parser.add_argument("-c", "--ckpt", type=str, default=r"ckpt/hitonhd.ckpt")
+ parser.add_argument("-s", "--seed", type=str, default=42)
+ parser.add_argument("-d", "--ddim", type=str, default=16)
+ opt = parser.parse_args()
+
+ seed_everything(opt.seed)
+ config = OmegaConf.load(f"{opt.base}")
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/basso4/hitonhd/resolve/main/hitonhd.ckpt")["state_dict"], strict=False)
+ model.cuda()
+ model.eval()
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+
+ # model = instantiate_from_config(config.model)
+ # model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
+ # model.cuda()
+ # model.eval()
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ # model = model.to(device)
+ # sampler = DDIMSampler(model)
+
+ precision_scope = autocast
+
+
+ @spaces.GPU
+ def start_tryon(dict,garm_img):
+ #load human image
+ human_img = dict['background'].convert("RGB").resize((768,1024))
+
+ #mask
+ tensor_transfrom = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ parsing_model = Parsing(0)
+ openose_model = OpenPose(0)
+ openose_model.preprocessor.body_estimation.model.to(device)
+
+ keypoints = openose_model(human_img.resize((384,512)))
+ model_parse, _ = parsing_model(human_img.resize((384,512)))
+ mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+ mask_cv = mask
+ mask = mask.resize((768, 1024))
+ mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+ mask_gray = to_pil_image((mask_gray+1.0)/2.0)
+
+ #densepose
+ human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
+ human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+ args = apply_net.create_argument_parser().parse_args(('show',
+ './configs/configs_densepose/densepose_rcnn_R_50_FPN_s1x.yaml',
+ 'https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl',
+ 'dp_segm', '-v',
+ '--opts',
+ 'MODEL.DEVICE',
+ 'cuda'))
+ # verbosity = getattr(args, "verbosity", None)
+ pose_img = args.func(args,human_img_arg)
+ pose_img = pose_img[:,:,::-1]
+ pose_img = Image.fromarray(pose_img).resize((768,1024))
+
+ #preprocessing image
+ human_img = human_img.convert("RGB").resize((512, 512))
+ human_img = torchvision.transforms.ToTensor()(human_img)
+
+ garm_img = garm_img.convert("RGB").resize((224, 224))
+ garm_img = torchvision.transforms.ToTensor()(garm_img)
+
+ mask = mask.convert("L").resize((512,512))
+ mask = torchvision.transforms.ToTensor()(mask)
+ mask = 1-mask
+
+ pose_img = pose_img.convert("RGB").resize((512, 512))
+ pose_img = torchvision.transforms.ToTensor()(pose_img)
+
+ #Normalize
+ human_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(human_img)
+ garm_img = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711))(garm_img)
+ pose_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(pose_img)
+
+ #create inpaint & hint
+ inpaint = human_img * mask
+ hint = torchvision.transforms.Resize((512, 512))(garm_img)
+ hint = torch.cat((hint, pose_img), dim=0)
+
+ # {"human_img": human_img, # [3, 512, 512]
+ # "inpaint_image": inpaint, # [3, 512, 512]
+ # "inpaint_mask": mask, # [1, 512, 512]
+ # "garm_img": garm_img, # [3, 224, 224]
+ # "hint": hint, # [6, 512, 512]
+ # }
+
+
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ #loading data
+ inpaint = inpaint.unsqueeze(0).to(torch.float16).to(device)
+ reference = garm_img.unsqueeze(0).to(torch.float16).to(device)
+ mask = mask.unsqueeze(0).to(torch.float16).to(device)
+ hint = hint.unsqueeze(0).to(torch.float16).to(device)
+ truth = human_img.unsqueeze(0).to(torch.float16).to(device)
+
+ #data preprocessing
+ encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
+ z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
+ mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(mask)
+ test_model_kwargs = {}
+ test_model_kwargs['inpaint_image'] = z_inpaint
+ test_model_kwargs['inpaint_mask'] = mask_resize
+ shape = (model.channels, model.image_size, model.image_size)
+
+ #predict
+ samples, _ = sampler.sample(S=opt.ddim,
+ batch_size=1,
+ shape=shape,
+ pose=hint,
+ conditioning=reference,
+ verbose=False,
+ eta=0,
+ test_model_kwargs=test_model_kwargs)
+ samples = 1. / model.scale_factor * samples
+ x_samples = model.first_stage_model.decode(samples[:,:4,:,:])
+
+ x_samples_ddim = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ x_checked_image=x_samples_ddim
+ x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+ x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512,384])
+
+ #apply seamlessClone technique here
+ #img_base
+ dict = dict['background'].convert("RGB").resize((384, 512))
+ dict = np.array(dict)
+ dict = cv2.cvtColor(dict, cv2.COLOR_RGB2BGR)
+
+ #img_output
+ img_cv = rearrange(x_checked_image_torch[0], 'c h w -> h w c').cpu().numpy()
+ img_cv = (img_cv * 255).astype(np.uint8)
+ img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2BGR)
+
+ #mask
+ mask_cv = mask_cv.convert("L").resize((384,512))
+ mask_cv = np.array(mask_cv)
+ mask_cv = 255-mask_cv
+
+ img_C = cv2.seamlessClone(dict, img_cv, mask_cv, (192,256), cv2.NORMAL_CLONE)
+
+
+ return img_C, mask_gray
+
+
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+ garm_list = os.listdir(os.path.join(example_path,"cloth"))
+ garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
+
+ human_list = os.listdir(os.path.join(example_path,"human"))
+ human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
+
+ human_ex_list = []
+ for ex_human in human_list_path:
+ ex_dict= {}
+ ex_dict['background'] = ex_human
+ ex_dict['layers'] = None
+ ex_dict['composite'] = None
+ human_ex_list.append(ex_dict)
+
+ ##default human
+
+
+ image_blocks = gr.Blocks().queue()
+ with image_blocks as demo:
+ gr.Markdown("## FPT_VTON 👕👔👚")
+ gr.Markdown("Virtual Try-on with your image and garment image")
+ with gr.Row():
+ with gr.Column():
+ imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Picture or use Examples below', interactive=True)
+
+ example = gr.Examples(
+ inputs=imgs,
+ examples_per_page=10,
+ examples=human_list_path
+ )
+
+ with gr.Column():
+ garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+
+ example = gr.Examples(
+ inputs=garm_img,
+ examples_per_page=8,
+ examples=garm_list_path
+ )
+
+ with gr.Column():
+ image_out_c = gr.Image(label="Output", elem_id="output-img",show_download_button=True)
+ try_button = gr.Button(value="Try-on")
+
+ # with gr.Column():
+ # image_out_c = gr.Image(label="Output", elem_id="output-img",show_download_button=False)
+
+ with gr.Column():
+ masked_img = gr.Image(label="Masked image output", elem_id="masked_img", show_download_button=True)
+
+
+ try_button.click(fn=start_tryon, inputs=[imgs,garm_img], outputs=[image_out_c,masked_img], api_name='tryon')
+
+
+
+ image_blocks.launch()
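
For context on the compositing step at the end of start_tryon: cv2.seamlessClone copies the source image into the destination wherever the mask is non-zero, so passing the original photo as the source and the inverted garment mask as the mask keeps the generated pixels only inside the garment region and restores the original photo everywhere else. Below is a self-contained sketch of that call with synthetic arrays; the array names, sizes, and mask rectangle are assumptions for illustration only.

# Sketch only: the seamless-clone compositing pattern used in start_tryon.
import cv2
import numpy as np

h, w = 512, 384                                       # same 384x512 working size as app.py
generated = np.full((h, w, 3), 127, dtype=np.uint8)   # stand-in for the decoded try-on image
original = np.full((h, w, 3), 200, dtype=np.uint8)    # stand-in for the user's photo

garment_mask = np.zeros((h, w), dtype=np.uint8)
garment_mask[128:384, 96:288] = 255                   # stand-in for the upper-body mask

keep_original = 255 - garment_mask                    # clone the photo back outside the garment
keep_original[:2, :] = 0                              # clear a thin border so the mask's
keep_original[-2:, :] = 0                             # bounding box stays strictly inside
keep_original[:, :2] = 0                              # the destination image
keep_original[:, -2:] = 0

center = (w // 2, h // 2)                             # (192, 256), the center used in app.py
composited = cv2.seamlessClone(original, generated, keep_original, center, cv2.NORMAL_CLONE)
print(composited.shape)                               # (512, 384, 3)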