|
import os |
|
import shutil |
|
import numpy as np |
|
import torchvision.transforms as transforms |
|
import cv2 |
|
from omegaconf import OmegaConf |
|
from torch.utils.data import DataLoader |
|
import torch |
|
from importlib import import_module |
|
from .cldm.model import create_model |
|
from .cldm.plms_hacked import PLMSSampler |
|
from .utils.utils import * |
|
from .utils.file_util import * |
|
|
|
# Root directory of this node pack inside ComfyUI's custom-node tree;
# node_path comes from the star import of .utils above.
vition_path = node_path("ComfyUI_Seg_VITON")
# Per-run intermediate inputs/results are written under cache_dir/<seed>/...
# (see stabel_vition.sample); optionally deleted when caching is disabled.
cache_dir = os.path.join(vition_path,"cache")

# StableVITON checkpoint weights and the matching OmegaConf model config.
model_load_path = os.path.join( vition_path,"checkpoints/VITONHD.ckpt")
yaml_path = os.path.join(vition_path,"configs/VITON512_COMFYUI.yaml")
|
|
|
def tensor2img_seg(x):
    """Convert a [-1, 1]-ranged image tensor to a uint8 numpy image.

    Accepts either a batched tensor ``[BS x C x H x W]`` or a single image
    ``[C x H x W]``. Batch items are stacked vertically along the height
    axis, so the result is ``[(BS*H) x W x 3]`` uint8. Single-channel
    input is replicated to three channels.
    """
    if x.ndim == 3:
        # Promote a single image to a batch of one.
        x = x[None, ...]
    batch, channels, height, width = x.shape
    # CHW -> HWC, then fold the batch dimension into the height axis.
    arr = x.permute(0, 2, 3, 1).reshape(-1, width, channels)
    arr = arr.detach().cpu().numpy()
    # Map [-1, 1] -> [0, 255] (values outside the range are clipped first).
    arr = (np.clip(arr, -1.0, 1.0) + 1.0) / 2.0
    arr = np.uint8(arr * 255.0)
    if arr.shape[-1] == 1:
        # Grayscale -> fake RGB by channel replication.
        arr = np.concatenate([arr] * 3, axis=-1)
    return arr
|
|
|
def imread(p, h, w, is_mask=False, in_inverse_mask=False, img=None):
    """Read and preprocess an image for the VITON pipeline.

    Args:
        p: path to the image on disk (ignored when ``img`` is given).
        h: target height after resizing.
        w: target width after resizing.
        is_mask: when True, binarize to {0, 1} single-channel instead of
            normalizing RGB to [-1, 1].
        in_inverse_mask: when True (mask mode only), invert the mask.
        img: optional pre-loaded BGR (cv2-style) image; skips disk read.

    Returns:
        float32 numpy array: ``h x w x 3`` in [-1, 1] for images, or
        ``h x w x 1`` with values in {0, 1} for masks.

    Raises:
        FileNotFoundError: if ``p`` cannot be read by cv2.
    """
    if img is None:
        img = cv2.imread(p)
        if img is None:
            # cv2.imread signals failure by returning None; fail loudly
            # here instead of crashing later with an opaque cv2.error
            # inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {p}")
    if not is_mask:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # cv2.resize takes (width, height), not (height, width).
        img = cv2.resize(img, (w, h))
        # Scale uint8 [0, 255] -> float32 [-1, 1].
        img = (img.astype(np.float32) / 127.5) - 1.0
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, (w, h))
        # Threshold at mid-gray to a hard binary mask.
        img = (img >= 128).astype(np.float32)
        img = img[:, :, None]
        if in_inverse_mask:
            img = 1 - img
    return img
|
|
|
|
|
|
|
class stabel_vition:
    """ComfyUI node that runs StableVITON virtual try-on inference.

    The diffusion model and PLMS sampler are built lazily on first use and
    cached on the node instance. Incoming tensors are first dumped to
    ``cache_dir/<seed>/...`` on disk and then re-read through the cv2
    preprocessing in :func:`imread`, so inference inputs match the
    training-time pipeline.
    """

    def __init__(self):
        # Lazily initialized in sample() so node construction stays cheap.
        self.model = None
        self.sampler = None

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input declaration: conditioning images, target resolution,
        # sampling controls, and cache/repaint toggles.
        # NOTE(review): "eta" is declared INT although DDIM-style eta is
        # usually a float in [0, 1]; "batch_size" allows a min of 0, which
        # DataLoader would reject — confirm intended ranges.
        return {"required":
                    {
                        "agn":("IMAGE", {"default": "","multiline": False}),
                        "agn_mask":("MASK", {"default": "","multiline": False}),
                        "cloth":("IMAGE", {"default": "","multiline": False}),
                        "image":("IMAGE", {"default": "","multiline": False}),
                        "image_densepose":("IMAGE", {"default": "","multiline": False}),
                        "img_H": ("INT", {"default": 512, "min": 268, "max": 2048}),
                        "img_W": ("INT", {"default": 384, "min": 268, "max": 2048}),
                        "denoise_steps": ("INT", {"default": 20, "min": 5, "max": 200}),
                        "batch_size": ("INT", {"default": 16, "min": 0, "max": 32, "step": 16}),
                        "eta": ("INT", {"default": 0, "min": 0, "max": 200}),
                        "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                        "cache": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}),
                        "repaint": ("BOOLEAN", {"default": False, "label_on": "enabled", "label_off": "disabled"}),
                    }
                }

    # ComfyUI node metadata: outputs, entry point, and menu category.
    RETURN_TYPES = ("IMAGE","BOOLEAN")
    RETURN_NAMES = ("image","open")
    OUTPUT_NODE = True
    FUNCTION = "sample"
    CATEGORY = "CXH"

    def sample(self,agn,agn_mask,cloth,image,image_densepose,img_H,img_W,denoise_steps,batch_size,eta,seed,cache,repaint):
        """Run try-on sampling and return ``(first_result_tensor, True)``.

        Inputs are written to ``cache_dir/<seed>/...``, re-read with cv2
        preprocessing, fed through the diffusion model via a dynamically
        imported dataset class, and decoded results are saved as
        ``result_<i>.jpg``. If ``cache`` is False the seed directory is
        removed afterwards.
        """
        # The seed doubles as the cache sub-directory name and filename stem.
        seed = str(seed)
        img_fn = seed+"_img.jpg"
        cloth_fn = seed+"_cloth.jpg"

        # Persist each conditioning input under its own sub-directory.
        mkdir(cache_dir)
        agnostic_v3_2_dir = os.path.join(cache_dir,seed,"agnostic_v3_2")
        mkdir(agnostic_v3_2_dir)
        agnostic_v3_2_img_path = os.path.join(agnostic_v3_2_dir,img_fn)
        save_tensor_image(agn,agnostic_v3_2_img_path)

        agnostic_mask_dir = os.path.join(cache_dir,seed,"agnostic_mask")
        mkdir(agnostic_mask_dir)
        agnostic_mask_img_path = os.path.join(agnostic_mask_dir,img_fn)
        save_tensor_image(agn_mask,agnostic_mask_img_path)

        cloth_dir = os.path.join(cache_dir,seed,"cloth")
        mkdir(cloth_dir)
        # NOTE(review): the cloth image is saved under img_fn, not cloth_fn;
        # cloth_fn is only handed to the dataset below — confirm intentional.
        cloth_img_path = os.path.join(cloth_dir,img_fn)
        save_tensor_image(cloth,cloth_img_path)

        image_dir = os.path.join(cache_dir,seed,"image")
        mkdir(image_dir)
        image_img_path = os.path.join(image_dir,img_fn)
        save_tensor_image(image,image_img_path)

        image_densepose_dir = os.path.join(cache_dir,seed,"image_densepose")
        mkdir(image_densepose_dir)
        image_densepose_img_path = os.path.join(image_densepose_dir,img_fn)
        save_tensor_image(image_densepose,image_densepose_img_path)

        # Re-read everything from disk with cv2 preprocessing: RGB [-1, 1]
        # for images, inverted binary mask for the agnostic mask.
        agn = imread(agnostic_v3_2_img_path, img_H, img_W)
        agn_mask = imread(agnostic_mask_img_path, img_H, img_W, is_mask=True, in_inverse_mask=True)
        cloth = imread(cloth_img_path, img_H, img_W)
        image = imread(image_img_path, img_H, img_W)
        image_densepose = imread(image_densepose_img_path, img_H, img_W)

        # Load the model config and inject the requested resolution.
        config = OmegaConf.load(yaml_path)
        config.model.params.img_H = img_H
        config.model.params.img_W = img_W
        params = config.model.params

        # Build the diffusion model once and keep it cached on the node.
        # NOTE(review): because the model is cached, a later call with a
        # different img_H/img_W will not rebuild it — confirm acceptable.
        if self.model == None:
            self.model = create_model(config_path=None, config=config)
            self.model.load_state_dict(torch.load(model_load_path, map_location="cpu"))
            self.model = self.model.cuda()
            self.model.eval()

        if self.sampler == None:
            self.sampler = PLMSSampler(self.model)

        # The dataset class name comes from the YAML (config.dataset_name)
        # and is resolved dynamically from the comyui_dataset module.
        dataset = getattr(import_module("comyui_dataset"), config.dataset_name)(
            img_fn,
            cloth_fn,
            agn,
            agn_mask,
            cloth,
            image,
            image_densepose,
        )
        dataloader = DataLoader(dataset, num_workers=4, shuffle=False, batch_size=batch_size, pin_memory=True)

        # Latent shape: 4 channels at 1/8 spatial resolution (VAE downscale).
        shape = (4, img_H//8, img_W//8)
        x_sample_list =[]

        for batch_idx, batch in enumerate(dataloader):
            print(f"{batch_idx}/{len(dataloader)}")
            z, c = self.model.get_input(batch, params.first_stage_key)
            bs = z.shape[0]
            c_crossattn = c["c_crossattn"][0][:bs]
            # A 4-D cross-attention conditioning is still raw image data;
            # encode it into embedding space first.
            if c_crossattn.ndim == 4:
                c_crossattn = self.model.get_learned_conditioning(c_crossattn)
                c["c_crossattn"] = [c_crossattn]
            # Unconditional branch for classifier-free guidance: shares the
            # concat/first-stage conditioning but uses null cross-attention.
            uc_cross = self.model.get_unconditional_conditioning(bs)
            uc_full = {"c_concat": c["c_concat"], "c_crossattn": [uc_cross]}
            uc_full["first_stage_cond"] = c["first_stage_cond"]
            # Move all tensors in the batch to the GPU before sampling.
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.cuda()
            self.sampler.model.batch = batch

            # Start from the fully-noised latent at t=999 rather than from
            # pure Gaussian noise.
            ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
            start_code = self.model.q_sample(z, ts)

            samples, _, _ = self.sampler.sample(
                denoise_steps,
                bs,
                shape,
                c,
                x_T=start_code,
                verbose=False,
                eta=eta,
                unconditional_conditioning=uc_full,
            )

            # Decode latents back to pixel space and post-process each sample.
            x_samples = self.model.decode_first_stage(samples)
            # NOTE(review): the loop variable cloth_fn shadows the outer
            # cloth_fn defined above; fn and cloth_fn are otherwise unused
            # inside this loop.
            for sample_idx, (x_sample, fn, cloth_fn) in enumerate(zip(x_samples, batch['img_fn'], batch["cloth_fn"])):
                x_sample_img = tensor2img_seg(x_sample)
                # NOTE(review): appended BEFORE repainting, so the returned
                # tensor is the raw sample while the saved JPEG below may be
                # the repainted one — confirm intentional.
                x_sample_list.append(x_sample_img)
                if repaint:
                    # Paste original pixels back everywhere outside the
                    # inpainted (agnostic) region.
                    repaint_agn_img = np.uint8((batch["image"][sample_idx].cpu().numpy()+1)/2 * 255)
                    repaint_agn_mask_img = batch["agn_mask"][sample_idx].cpu().numpy()
                    x_sample_img = repaint_agn_img * repaint_agn_mask_img + x_sample_img * (1-repaint_agn_mask_img)
                    x_sample_img = np.uint8(x_sample_img)
                to_path = os.path.join(cache_dir,seed,"result_"+str(sample_idx)+".jpg")
                # RGB -> BGR channel flip for cv2.imwrite.
                cv2.imwrite(to_path, x_sample_img[:,:,::-1])

        # Caching disabled: remove the whole per-seed working directory.
        if not cache:
            shutil.rmtree(os.path.join(cache_dir,seed))

        # Only the first sample is returned to the graph.
        return pil2tensor(x_sample_list[0]),True
|
|
|
|