import spaces import gradio as gr import os import math from preprocess.humanparsing.run_parsing import Parsing from preprocess.dwpose import DWposeDetector from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor import torch import torch.nn as nn from src.pose_guider import PoseGuider from PIL import Image from src.utils_mask import get_mask_location import numpy as np from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton import cv2 import random from huggingface_hub import snapshot_download example_path = os.path.join(os.path.dirname(__file__), 'examples') access_token = os.getenv("HUGGING_FACE_HUB_TOKEN") fitdit_repo = "BoyuanJiang/FitDiT" repo_path = snapshot_download(repo_id=fitdit_repo) class FitDiTGenerator: def __init__(self, model_root, device="cuda", with_fp16=False): weight_dtype = torch.float16 if with_fp16 else torch.bfloat16 transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(model_root, "transformer_garm"), torch_dtype=weight_dtype) transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(model_root, "transformer_vton"), torch_dtype=weight_dtype) pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512)) pose_guider.load_state_dict(torch.load(os.path.join(model_root, "pose_guider", "diffusion_pytorch_model.bin"))) image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype) image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype) pose_guider.to(device=device, dtype=weight_dtype) image_encoder_large.to(device=device) image_encoder_bigG.to(device=device) self.pipeline = StableDiffusion3TryOnPipeline.from_pretrained(model_root, torch_dtype=weight_dtype, transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG) self.pipeline.to(device) self.dwprocessor = DWposeDetector(model_root=model_root, device=device) self.parsing_model = Parsing(model_root=model_root, device=device) @spaces.GPU def generate_mask(self, vton_img, category, offset_top, offset_bottom, offset_left, offset_right): with torch.inference_mode(): vton_img = Image.open(vton_img) vton_img_det = resize_image(vton_img) pose_image, keypoints, _, candidate = self.dwprocessor(np.array(vton_img_det)[:,:,::-1]) candidate[candidate<0]=0 candidate = candidate[0] candidate[:, 0]*=vton_img_det.width candidate[:, 1]*=vton_img_det.height pose_image = pose_image[:,:,::-1] #rgb pose_image = Image.fromarray(pose_image) model_parse, _ = self.parsing_model(vton_img_det) mask, mask_gray = get_mask_location(category, model_parse, \ candidate, model_parse.width, model_parse.height, \ offset_top, offset_bottom, offset_left, offset_right) mask = mask.resize(vton_img.size) mask_gray = mask_gray.resize(vton_img.size) mask = mask.convert("L") mask_gray = mask_gray.convert("L") masked_vton_img = Image.composite(mask_gray, vton_img, mask) im = {} im['background'] = np.array(vton_img.convert("RGBA")) im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)] im['composite'] = np.array(masked_vton_img.convert("RGBA")) return im, pose_image @spaces.GPU def process(self, vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution): assert resolution in ["768x1024", "1152x1536", "1536x2048"] new_width, new_height = resolution.split("x") new_width = int(new_width) new_height = int(new_height) with torch.inference_mode(): garm_img = Image.open(garm_img) vton_img = Image.open(vton_img) model_image_size = vton_img.size garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height) vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height) mask = pre_mask["layers"][0][:,:,3] mask = Image.fromarray(mask) mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0)) mask = mask.convert("L") pose_image = Image.fromarray(pose_image) pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0)) if seed==-1: seed = random.randint(0, 2147483647) res = self.pipeline( height=new_height, width=new_width, guidance_scale=image_scale, num_inference_steps=n_steps, generator=torch.Generator("cpu").manual_seed(seed), cloth_image=garm_img, model_image=vton_img, mask=mask, pose_image=pose_image, num_images_per_prompt=num_images_per_prompt ).images for idx in range(len(res)): res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1]) return res def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS): old_width, old_height = im.size ratio_w = new_width / old_width ratio_h = new_height / old_height if ratio_w < ratio_h: new_size = (new_width, round(old_height * ratio_w)) else: new_size = (round(old_width * ratio_h), new_height) im_resized = im.resize(new_size, mode) pad_w = math.ceil((new_width - im_resized.width) / 2) pad_h = math.ceil((new_height - im_resized.height) / 2) new_im = Image.new('RGB', (new_width, new_height), pad_color) new_im.paste(im_resized, (pad_w, pad_h)) return new_im, pad_w, pad_h def unpad_and_resize(padded_im, pad_w, pad_h, original_width, original_height): width, height = padded_im.size left = pad_w top = pad_h right = width - pad_w bottom = height - pad_h cropped_im = padded_im.crop((left, top, right, bottom)) resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS) return resized_im def resize_image(img, target_size=768): width, height = img.size if width < height: scale = target_size / width else: scale = target_size / height new_width = int(round(width * scale)) new_height = int(round(height * scale)) resized_img = img.resize((new_width, new_height), Image.LANCZOS) return resized_img HEADER = """