"""Virtual try-on sampling node: caches its inputs on disk and runs PLMS sampling."""

import os
import shutil
from importlib import import_module

import numpy as np
import torchvision.transforms as transforms
import cv2
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import torch

from .cldm.model import create_model
from .cldm.plms_hacked import PLMSSampler
from .utils.utils import *
from .utils.file_util import *

# Locations resolved relative to the node's install directory.
vition_path = node_path("ComfyUI_Seg_VITON")
cache_dir = os.path.join(vition_path, "cache")
model_load_path = os.path.join(vition_path, "checkpoints/VITONHD.ckpt")
yaml_path = os.path.join(vition_path, "configs/VITON512_COMFYUI.yaml")


def tensor2img_seg(x):
    """Convert a tensor with values in [-1, 1] into a uint8 image array.

    Args:
        x: tensor shaped [BS x C x H x W] or [C x H x W].

    Returns:
        A ``numpy.uint8`` array shaped ``(BS * H, W, C)``: the batch items are
        stacked along the first axis. Values are clipped to [-1, 1] and mapped
        to [0, 255]. A single-channel result is replicated to three channels.
    """
    if x.ndim == 3:
        x = x.unsqueeze(0)
    # Unpacking also enforces that the input is 4-dimensional at this point.
    _, channels, _, width = x.shape
    arr = x.permute(0, 2, 3, 1).reshape(-1, width, channels).detach().cpu().numpy()
    arr = np.clip(arr, -1, 1)
    arr = (arr + 1) / 2
    arr = np.uint8(arr * 255.0)
    if arr.shape[-1] == 1:
        arr = np.concatenate([arr, arr, arr], axis=-1)
    return arr


def imread(p, h, w, is_mask=False, in_inverse_mask=False, img=None):
    """Load an image (or use ``img``) and normalise it for the model.

    Args:
        p: path passed to ``cv2.imread``; only used when ``img`` is None.
        h: target height in pixels.
        w: target width in pixels.
        is_mask: when true, produce a binary single-channel mask instead of an
            RGB image.
        in_inverse_mask: when true (and ``is_mask`` is true), invert the mask.
        img: optional already-decoded image in cv2's BGR layout.

    Returns:
        For images: float32 array ``(h, w, 3)`` in RGB order, range [-1, 1].
        For masks: float32 array ``(h, w, 1)`` holding 0.0 or 1.0.

    Raises:
        FileNotFoundError: if ``cv2.imread`` could not read ``p``.
    """
    if img is None:
        img = cv2.imread(p)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque error from cvtColor further down.
            raise FileNotFoundError(f"could not read image: {p}")

    if not is_mask:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (w, h))
        return (img.astype(np.float32) / 127.5) - 1.0  # [-1, 1]

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (w, h))
    img = (img >= 128).astype(np.float32)  # 0 or 1
    img = img[:, :, None]
    if in_inverse_mask:
        img = 1 - img
    return img


class stabel_vition:
    """Node that runs the try-on diffusion model on pre-segmented inputs.

    The model and sampler are created lazily on the first ``sample`` call and
    kept on the instance for subsequent calls.
    """

    def __init__(self):
        self.model = None
        self.sampler = None

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required inputs."""
        return {
            "required": {
                "agn": ("IMAGE", {"default": "", "multiline": False}),
                "agn_mask": ("MASK", {"default": "", "multiline": False}),
                "cloth": ("IMAGE", {"default": "", "multiline": False}),
                "image": ("IMAGE", {"default": "", "multiline": False}),
                "image_densepose": ("IMAGE", {"default": "", "multiline": False}),
                "img_H": ("INT", {"default": 512, "min": 268, "max": 2048}),
                "img_W": ("INT", {"default": 384, "min": 268, "max": 2048}),
                "denoise_steps": ("INT", {"default": 20, "min": 5, "max": 200}),
                "batch_size": ("INT", {"default": 16, "min": 0, "max": 32, "step": 16}),
                "eta": ("INT", {"default": 0, "min": 0, "max": 200}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "cache": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}),
                "repaint": ("BOOLEAN", {"default": False, "label_on": "enabled", "label_off": "disabled"}),
            }
        }

    RETURN_TYPES = ("IMAGE", "BOOLEAN")
    RETURN_NAMES = ("image", "open")
    OUTPUT_NODE = True
    FUNCTION = "sample"
    CATEGORY = "CXH"

    @staticmethod
    def _cache_inputs(seed_dir, img_fn, tensors):
        """Save each tensor to ``<seed_dir>/<name>/<img_fn>`` and return the paths by name."""
        # Create the cache folders and cache the inputs locally
        # (TODO: convert the tensors to cv2 images directly instead).
        mkdir(cache_dir)
        paths = {}
        for name, tensor in tensors.items():
            sub_dir = os.path.join(seed_dir, name)
            mkdir(sub_dir)
            paths[name] = os.path.join(sub_dir, img_fn)
            save_tensor_image(tensor, paths[name])
        return paths

    def _ensure_model(self, config):
        """Create the model and sampler on first use and keep them on the instance."""
        # NOTE(review): the cached model is built from the first call's config;
        # later changes to img_H / img_W are not propagated to it - confirm
        # whether the model depends on those parameters.
        if self.model is None:
            model = create_model(config_path=None, config=config)
            model.load_state_dict(torch.load(model_load_path, map_location="cpu"))
            model = model.cuda()
            model.eval()
            # Assign only after a successful load so a failed load is retried
            # on the next call instead of leaving an unloaded model cached.
            self.model = model
        if self.sampler is None:
            self.sampler = PLMSSampler(self.model)

    def sample(self, agn, agn_mask, cloth, image, image_densepose, img_H, img_W,
               denoise_steps, batch_size, eta, seed, cache, repaint):
        """Run the try-on model and return the first generated image.

        Args:
            agn: agnostic person image tensor.
            agn_mask: agnostic mask tensor.
            cloth: garment image tensor.
            image: person image tensor.
            image_densepose: densepose image tensor.
            img_H: working height; also written into the model config.
            img_W: working width; also written into the model config.
            denoise_steps: number of sampler steps.
            batch_size: DataLoader batch size (values below 1 are treated as 1).
            eta: forwarded to the sampler.
            seed: used (as a string) to name the cache sub-directory and files.
            cache: when false, the seed's cache directory is removed afterwards.
            repaint: when true, composite the input image and the sample using
                the agnostic mask before saving/returning.

        Returns:
            ``(pil2tensor(first_result), True)``.
        """
        seed = str(seed)
        img_fn = seed + "_img.jpg"
        cloth_fn = seed + "_cloth.jpg"
        seed_dir = os.path.join(cache_dir, seed)

        try:
            paths = self._cache_inputs(seed_dir, img_fn, {
                "agnostic_v3_2": agn,
                "agnostic_mask": agn_mask,
                "cloth": cloth,
                "image": image,
                "image_densepose": image_densepose,
            })

            agn_img = imread(paths["agnostic_v3_2"], img_H, img_W)
            agn_mask_img = imread(paths["agnostic_mask"], img_H, img_W,
                                  is_mask=True, in_inverse_mask=True)
            cloth_img = imread(paths["cloth"], img_H, img_W)
            person_img = imread(paths["image"], img_H, img_W)
            densepose_img = imread(paths["image_densepose"], img_H, img_W)

            config = OmegaConf.load(yaml_path)
            config.model.params.img_H = img_H
            config.model.params.img_W = img_W
            params = config.model.params

            self._ensure_model(config)

            dataset_cls = getattr(import_module("comyui_dataset"), config.dataset_name)
            dataset = dataset_cls(
                img_fn,
                cloth_fn,
                agn_img,
                agn_mask_img,
                cloth_img,
                person_img,
                densepose_img,
            )
            # INPUT_TYPES allows batch_size == 0, which DataLoader rejects.
            dataloader = DataLoader(dataset, num_workers=4, shuffle=False,
                                    batch_size=max(batch_size, 1), pin_memory=True)

            shape = (4, img_H // 8, img_W // 8)
            x_sample_list = []
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                for batch_idx, batch in enumerate(dataloader):
                    print(f"{batch_idx}/{len(dataloader)}")
                    z, c = self.model.get_input(batch, params.first_stage_key)
                    bs = z.shape[0]
                    c_crossattn = c["c_crossattn"][0][:bs]
                    if c_crossattn.ndim == 4:
                        c_crossattn = self.model.get_learned_conditioning(c_crossattn)
                        c["c_crossattn"] = [c_crossattn]
                    uc_cross = self.model.get_unconditional_conditioning(bs)
                    uc_full = {"c_concat": c["c_concat"], "c_crossattn": [uc_cross]}
                    uc_full["first_stage_cond"] = c["first_stage_cond"]
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.cuda()
                    self.sampler.model.batch = batch

                    ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
                    start_code = self.model.q_sample(z, ts)

                    samples, _, _ = self.sampler.sample(
                        denoise_steps,
                        bs,
                        shape,
                        c,
                        x_T=start_code,
                        verbose=False,
                        eta=eta,
                        unconditional_conditioning=uc_full,
                    )

                    x_samples = self.model.decode_first_stage(samples)
                    for sample_idx, (x_sample, _fn, _cloth_fn) in enumerate(
                            zip(x_samples, batch['img_fn'], batch["cloth_fn"])):
                        x_sample_img = tensor2img_seg(x_sample)
                        if repaint:
                            repaint_agn_img = np.uint8(
                                (batch["image"][sample_idx].cpu().numpy() + 1) / 2 * 255)  # [0,255]
                            repaint_agn_mask_img = batch["agn_mask"][sample_idx].cpu().numpy()  # 0 or 1
                            x_sample_img = (repaint_agn_img * repaint_agn_mask_img
                                            + x_sample_img * (1 - repaint_agn_mask_img))
                            x_sample_img = np.uint8(x_sample_img)
                        # Collect after the optional repaint so the returned
                        # image matches the one written to disk.
                        x_sample_list.append(x_sample_img)

                        to_path = os.path.join(cache_dir, seed, "result_" + str(sample_idx) + ".jpg")
                        # Reverse the channel axis: cv2.imwrite expects BGR.
                        cv2.imwrite(to_path, x_sample_img[:, :, ::-1])
        finally:
            if not cache:
                # Best-effort cleanup, also on failure; never mask the real error.
                shutil.rmtree(seed_dir, ignore_errors=True)

        return pil2tensor(x_sample_list[0]), True