|
import os |
|
import sys |
|
import time |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
|
|
from utils_stableviton import get_mask_location |
|
|
|
# Repository root: the directory one level above the folder that holds this
# file. ``parents[1]`` of an absolute path is already absolute.
PROJECT_ROOT = Path(__file__).absolute().parents[1]
# Search the repository root first when importing -- presumably so that the
# ``preprocess.*`` imports further down resolve against it.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
from preprocess.detectron2.projects.DensePose.apply_net_gradio import DensePose4Gradio |
|
from preprocess.humanparsing.run_parsing import Parsing |
|
from preprocess.openpose.run_openpose import OpenPose |
|
|
|
# Set the GRADIO_TEMP_DIR environment variable to the relative path ./tmp.
os.environ.update({'GRADIO_TEMP_DIR': './tmp'})
|
|
|
|
|
# Preprocessing models, constructed once at import time and reused by
# process_hd on every call. The positional 0 passed to OpenPose / Parsing is
# presumably a GPU index -- TODO confirm against those classes.
_DENSEPOSE_CFG = 'preprocess/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml'
_DENSEPOSE_WEIGHTS = 'https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl'

openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
densepose_model_hd = DensePose4Gradio(cfg=_DENSEPOSE_CFG, model=_DENSEPOSE_WEIGHTS)
# Placeholder only (a literal Ellipsis): no try-on model is instantiated here.
stable_viton_model_hd = ...
|
|
|
# Garment category names. The two lists appear to be index-aligned (entry i of
# each names the same category) -- verify before relying on it.
category_dict = [
    'upperbody',
    'lowerbody',
    'dress',
]
# Spellings handed to get_mask_location by process_hd.
category_dict_utils = [
    'upper_body',
    'lower_body',
    'dresses',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_hd(vton_img, garm_img, n_samples, n_steps, guidance_scale, seed):
    """Run the preprocessing stages of the 'hd' try-on demo on one image pair.

    Opens both images, derives the garment-agnostic person image from the
    OpenPose keypoints and the human-parsing result, and runs DensePose on the
    person image. Each stage prints its label followed by the elapsed
    wall-clock seconds. The OpenPose body-estimation model is moved to 'cuda'
    first, so a CUDA device is required.

    Args:
        vton_img: Person ("Model") image in any form ``PIL.Image.open``
            accepts; the UI supplies a file path.
        garm_img: Garment image, same form as ``vton_img``.
        n_samples: Not read by this function.
        n_steps: Not read by this function.
        guidance_scale: Not read by this function.
        seed: Not read by this function.

    Returns:
        None. NOTE(review): the UI binds this callback to the output gallery,
        and the resized garment, the agnostic image and the DensePose result
        are computed but never consumed -- the try-on inference that should
        use them (and the four unused parameters) appears to be missing.
    """
    model_type = 'hd'
    category = 0  # index into category_dict_utils, i.e. 'upper_body'
    full_size = (768, 1024)
    low_size = (384, 512)

    with torch.no_grad():
        openpose_model_hd.preprocessor.body_estimation.model.to('cuda')

        started = time.time()
        # flush=True so the label is visible while the stage is still running.
        print('load images... ', end='', flush=True)
        garm_img = Image.open(garm_img).resize(full_size)
        vton_img = Image.open(vton_img).resize(full_size)
        print(f'{time.time() - started:.2f}s')

        started = time.time()
        print('get agnostic map... ', end='', flush=True)
        # Pose estimation and parsing consume the same downscaled copy, so it
        # is built once (assumes neither model mutates its input image).
        vton_small = vton_img.resize(low_size)
        keypoints = openpose_model_hd(vton_small)
        model_parse, _ = parsing_model_hd(vton_small)
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # Masks come back at the low resolution; scale them up without
        # interpolation so the mask values are not blended.
        mask = mask.resize(full_size, Image.NEAREST)
        mask_gray = mask_gray.resize(full_size, Image.NEAREST)
        # Where ``mask`` is set, show ``mask_gray`` instead of the person.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
        print(f'{time.time() - started:.2f}s')

        started = time.time()
        print('get densepose... ', end='', flush=True)
        # vton_img is already at full_size, so it is passed through as-is.
        densepose = densepose_model_hd.execute(vton_img)
        print(f'{time.time() - started:.2f}s')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Folder of bundled sample images, located next to this file.
example_path = os.path.join(os.path.dirname(__file__), 'examples')
# Default person ("model") and garment images preloaded into the UI.
model_hd, garment_hd = (
    os.path.join(example_path, relative)
    for relative in ('model/model_1.png', 'garment/00055_00.jpg')
)
|
|
|
# Gradio page: banner, person / garment pickers with sample images, an output
# gallery, and the run controls wired to process_hd.
# NOTE(review): the nesting of the layout contexts below is reconstructed from
# statement order -- confirm against the intended layout.
with gr.Blocks(css='style.css') as demo:
    # Title plus arXiv / project page / GitHub stars / license badges.
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <div>
        <h1>StableVITON Demo πππ</h1>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href='https://arxiv.org/abs/2312.01725'>
        <img src="https://img.shields.io/badge/arXiv-2312.01725-red">
        </a>

        <a href='https://rlawjdghek.github.io/StableVITON/'>
        <img src='https://img.shields.io/badge/page-github.io-blue.svg'>
        </a>

        <a href='https://github.com/rlawjdghek/StableVITON'>
        <img src='https://img.shields.io/github/stars/rlawjdghek/StableVITON'>
        </a>

        <a href='https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode'>
        <img src='https://img.shields.io/badge/license-CC_BY--NC--SA_4.0-lightgrey'>
        </a>
        </div>
        </div>
        </div>
        """
    )
    with gr.Row():
        gr.Markdown("## Experience virtual try-on with your own images!")
    with gr.Row():
        # Person image, delivered to the callback as a file path.
        with gr.Column():
            vton_img = gr.Image(value=model_hd, label="Model", type="filepath", height=384)
            example = gr.Examples(
                examples=[
                    os.path.join(example_path, relative)
                    for relative in (
                        'model/model_1.png',
                        'model/model_2.png',
                        'model/model_3.png',
                    )
                ],
                inputs=vton_img,
                examples_per_page=14,
            )
        # Garment image, delivered to the callback as a file path.
        with gr.Column():
            garm_img = gr.Image(value=garment_hd, label="Garment", type="filepath", height=384)
            example = gr.Examples(
                examples=[
                    os.path.join(example_path, relative)
                    for relative in (
                        'garment/00055_00.jpg',
                        'garment/00126_00.jpg',
                        'garment/00151_00.jpg',
                    )
                ],
                inputs=garm_img,
                examples_per_page=14,
            )
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output',
                show_label=False,
                elem_id="gallery",
                preview=True,
                scale=1,
            )
    with gr.Column():
        run_button = gr.Button(value="Run")

        n_samples = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
        n_steps = gr.Slider(label="Steps", value=20, minimum=20, maximum=40, step=1)
        guidance_scale = gr.Slider(label="Guidance scale", value=2.0, minimum=1.0, maximum=5.0, step=0.1)
        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=2147483647, step=1)

    # Component order matches the positional parameters of process_hd.
    ips = [vton_img, garm_img, n_samples, n_steps, guidance_scale, seed]
    run_button.click(fn=process_hd, inputs=ips, outputs=[result_gallery])

# Runs unconditionally: the app is launched whenever this module is executed
# or imported.
demo.launch()
|
|