# Spaces: Running on Zero
#########################
# app.py for ZeroGPU
#########################
import os
import sys
import random
import subprocess
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
import spaces  # for ZeroGPU usage

# NOTE(review): presumably set so the detectron2 / DensePose source builds that
# install_catvton_dependencies() triggers can locate the CUDA toolkit.
os.environ["CUDA_HOME"] = "/usr/local/cuda"
def install_catvton_dependencies(catvton_repo_path: Union[str, None] = None) -> None:
    """Install CatVTON's Python dependencies, with detectron2/DensePose from GitHub.

    CatVTON's ``requirements.txt`` lists detectron2. That package is installed
    separately from GitHub first. The remaining requirements are installed from
    a filtered copy (``requirements_modified.txt``) written next to the original.

    Args:
        catvton_repo_path: Folder containing CatVTON's ``requirements.txt``.
            When None, ``custom_nodes/Comfyui-CatVTON`` is used if it exists,
            otherwise ``The_Broccolator/custom_nodes/Comfyui-CatVTON`` (the
            layout every other path in this app assumes).

    Raises:
        subprocess.CalledProcessError: if any pip command fails.
        FileNotFoundError: if the requirements file is missing.
    """
    # NOTE(review): these git refs were reconstructed from a copy of this file
    # in which "<repo>.git@<tag>" had been mangled by e-mail obfuscation;
    # confirm the pinned tag.
    detectron2_url = "git+https://github.com/facebookresearch/detectron2.git@v0.6"
    densepose_url = detectron2_url + "#subdirectory=projects/DensePose"

    if catvton_repo_path is None:
        # Keep the historical relative path first; fall back to the
        # The_Broccolator/ layout used for all model downloads in this file.
        candidates = (
            "custom_nodes/Comfyui-CatVTON",
            "The_Broccolator/custom_nodes/Comfyui-CatVTON",
        )
        catvton_repo_path = next(
            (path for path in candidates if os.path.isdir(path)), candidates[-1]
        )

    def pip_install(*args: str) -> None:
        # Use the running interpreter's pip so packages land in this environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", *args])

    # ----------------------------------------------------------------
    # 1) Install detectron2 from GitHub
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing detectron2 from GitHub...")
    pip_install(detectron2_url)

    # ----------------------------------------------------------------
    # 2) Remove detectron2 lines from CatVTON's local requirements
    # ----------------------------------------------------------------
    catvton_req_path = os.path.join(catvton_repo_path, "requirements.txt")
    catvton_req_modified = os.path.join(catvton_repo_path, "requirements_modified.txt")
    print("[CatVTON Setup] Removing detectron2 lines from CatVTON’s requirements...")
    with open(catvton_req_path, "r", encoding="utf-8") as fin, \
            open(catvton_req_modified, "w", encoding="utf-8") as fout:
        fout.writelines(line for line in fin if "detectron2" not in line.lower())

    # ----------------------------------------------------------------
    # 3) Install the rest of CatVTON's requirements
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing CatVTON requirements (minus detectron2)...")
    pip_install("-r", catvton_req_modified)

    # ----------------------------------------------------------------
    # 4) Install DensePose
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing DensePose...")
    pip_install(densepose_url)
    print("[CatVTON Setup] All CatVTON dependencies installed!\n")


# ----------------------------------------------------
# Call the install function on startup
# ----------------------------------------------------
install_catvton_dependencies()
# 1) Load your token from environment (make sure HF_TOKEN is set)
token = os.environ["HF_TOKEN"]
from huggingface_hub import hf_hub_download
import shutil
import pathlib

# Root of the ComfyUI ("The_Broccolator") model tree the loader nodes read from.
MODELS_ROOT = "The_Broccolator/models"

# (repo_id, file inside the repo, sub-folder of MODELS_ROOT, local file name or
# None to keep the file where hf_hub_download puts it)
MODEL_DOWNLOADS = (
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae", None),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "clip", None),
    ("comfyanonymous/flux_text_encoders", "clip_l.safetensors", "clip", None),
    ("black-forest-labs/FLUX.1-Fill-dev", "flux1-fill-dev.safetensors", "unet", None),
    (
        "zhengchong/CatVTON",
        "flux-lora/pytorch_lora_weights.safetensors",
        "loras",
        "catvton-flux-lora.safetensors",
    ),
    (
        "alimama-creative/FLUX.1-Turbo-Alpha",
        "diffusion_pytorch_model.safetensors",
        "loras",
        "alimama-flux-turbo-alpha.safetensors",
    ),
    ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision", None),
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models", None),
)


def _download_models(hf_token: str) -> None:
    """Download every MODEL_DOWNLOADS entry into its folder under MODELS_ROOT.

    Raises whatever hf_hub_download raises (e.g. on a missing/invalid token
    for the gated black-forest-labs repos).
    """
    for repo_id, filename, subfolder, local_name in MODEL_DOWNLOADS:
        target_dir = os.path.join(MODELS_ROOT, subfolder)
        pathlib.Path(target_dir).mkdir(parents=True, exist_ok=True)
        # `token=` is the current name of the deprecated `use_auth_token=` argument.
        download_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=target_dir,
            token=hf_token,
        )
        if local_name is not None:
            # The ComfyUI workflow refers to these files by a custom name;
            # os.replace also overwrites an existing target on Windows.
            os.replace(download_path, os.path.join(target_dir, local_name))


# Download each gated model into the correct local folder
_download_models(token)
#############################
# ComfyUI (or Broccolator) Support Functions
#############################
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return output ``index`` of a ComfyUI node result.

    Node methods return either a plain tuple of outputs or a mapping that keeps
    those outputs under ``"result"``. The mapping form is detected by the
    KeyError raised when indexing it directly.
    """
    try:
        value = obj[index]
    except KeyError:
        # Mapping-style result: the positional outputs live under "result".
        value = obj["result"][index]
    return value
def find_path(name: str, path: str = None) -> str:
    """Search ``path`` and then each ancestor directory for an entry called ``name``.

    Args:
        name: File or directory name to look for.
        path: Directory to start from; defaults to the current working directory.

    Returns:
        The joined path of the first match (printed as it is found), or None
        once the filesystem root has been searched without success.
    """
    current = os.getcwd() if path is None else path
    while True:
        if name in os.listdir(current):
            match = os.path.join(current, name)
            print(f"{name} found: {match}")
            return match
        parent = os.path.dirname(current)
        if parent == current:
            # dirname() of the root is the root itself: nowhere left to look.
            return None
        current = parent
def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``The_Broccolator`` directory (cwd or an ancestor) to sys.path.

    Does nothing when no such directory is found, so that later imports of
    the ComfyUI modules (``main``, ``nodes``, ``server``, ...) can succeed when it is.
    """
    broccolator_dir = find_path("The_Broccolator")
    if broccolator_dir is None or not os.path.isdir(broccolator_dir):
        return
    sys.path.append(broccolator_dir)
    print(f"'{broccolator_dir}' added to sys.path")
def add_extra_model_paths() -> None:
    """Load ``extra_model_paths.yaml`` (searched from cwd upwards) into ComfyUI, if present.

    ``load_extra_path_config`` is imported from ``main`` when available and
    from ``utils.extra_config`` otherwise.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if config_file is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(config_file)


# Make the ComfyUI tree importable first; the extra-paths loader lives inside it.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
    """Initialise ComfyUI's extra/custom nodes via ``nodes.init_extra_nodes``.

    A fresh asyncio event loop is created and installed for the calling thread.
    A PromptServer and a PromptQueue are constructed on that loop before the
    nodes are initialised.
    """
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # NOTE(review): the server/queue objects are never used here; presumably
    # some custom nodes need them to exist at import time — confirm.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)

    init_extra_nodes()
from nodes import NODE_CLASS_MAPPINGS | |
#############################################
# MAIN PIPELINE with ZeroGPU Decorator
#############################################
# NOTE(review): ZeroGPU's default slot is short and this pipeline loads several
# large models per call; consider `@spaces.GPU(duration=...)` if calls time out.
@spaces.GPU
def generate_images(user_image_path):
    """Run the Broccolator ComfyUI pipeline on one uploaded image.

    Stages (every node is looked up in NODE_CLASS_MAPPINGS):
      1. CatVTON dresses the person in ``assets_black_tshirt.png``.
      2. Florence-2 grounds "Haircut" and SAM2 segments the returned boxes.
         GrowMask then expands the mask (expand=35). The SegformerB2
         face/upper-clothes/scarf mask is subtracted from it.
      3. FLUX.1-Fill with the alimama Turbo-Alpha LoRA inpaints that mask.
         Conditioning is the broccoli-haircut text prompt styled by the Redux
         model from ``assets_broc_ref.jpg``.

    Args:
        user_image_path: Path of the uploaded image (Gradio ``type="filepath"``).

    Returns:
        Path of the image written by the SaveImage node.
    """
    import folder_paths  # ComfyUI core module; importable once The_Broccolator is on sys.path

    import_custom_nodes()

    def new_seed():
        # randint's upper bound is inclusive, and seeds must fit an unsigned
        # 64-bit integer, so 2**64 itself is out of range.
        return random.randint(1, 2**64 - 1)

    with torch.inference_mode():
        # ---- Stage 1: virtual try-on of the reference t-shirt (CatVTON) ----
        loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
        loadimage_116 = loadimage.load_image(image="assets_black_tshirt.png")
        loadautomasker = NODE_CLASS_MAPPINGS["LoadAutoMasker"]()
        loadautomasker_120 = loadautomasker.load(catvton_path="zhengchong/CatVTON")
        loadcatvtonpipeline = NODE_CLASS_MAPPINGS["LoadCatVTONPipeline"]()
        loadcatvtonpipeline_123 = loadcatvtonpipeline.load(
            sd15_inpaint_path="runwayml/stable-diffusion-inpainting",
            catvton_path="zhengchong/CatVTON",
            mixed_precision="bf16",
        )
        loadimage_264 = loadimage.load_image(image=user_image_path)
        downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
        downloadandloadflorence2model_274 = downloadandloadflorence2model.loadmodel(
            model="gokaygokay/Florence-2-Flux-Large", precision="fp16", attention="sdpa"
        )
        automasker = NODE_CLASS_MAPPINGS["AutoMasker"]()
        automasker_119 = automasker.generate(
            cloth_type="overall",
            pipe=get_value_at_index(loadautomasker_120, 0),
            target_image=get_value_at_index(loadimage_264, 0),
        )
        catvton = NODE_CLASS_MAPPINGS["CatVTON"]()
        catvton_121 = catvton.generate(
            seed=new_seed(),
            steps=50,
            cfg=2.5,
            pipe=get_value_at_index(loadcatvtonpipeline_123, 0),
            target_image=get_value_at_index(loadimage_264, 0),
            refer_image=get_value_at_index(loadimage_116, 0),
            mask_image=get_value_at_index(automasker_119, 0),
        )

        # ---- Hair localisation on the try-on result (Florence-2) ----
        florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
        florence2run_275 = florence2run.encode(
            text_input="Haircut",
            task="caption_to_phrase_grounding",
            fill_mask=True,
            keep_model_loaded=False,
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
            output_mask_select="",
            seed=new_seed(),
            image=get_value_at_index(catvton_121, 0),
            florence2_model=get_value_at_index(downloadandloadflorence2model_274, 0),
        )
        downloadandloadsam2model = NODE_CLASS_MAPPINGS["DownloadAndLoadSAM2Model"]()
        downloadandloadsam2model_277 = downloadandloadsam2model.loadmodel(
            model="sam2.1_hiera_large.safetensors",
            segmentor="single_image",
            device="cuda",
            precision="fp16",
        )

        # ---- Conditioning: text prompt + Redux style from the reference photo ----
        dualcliploadergguf = NODE_CLASS_MAPPINGS["DualCLIPLoaderGGUF"]()
        dualcliploadergguf_284 = dualcliploadergguf.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="clip_l.safetensors",
            type="flux",
        )
        cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
        cliptextencode_279 = cliptextencode.encode(
            text="Br0k0L8, Broccoli haircut with voluminous, textured curls on top resembling broccoli florets, contrasted by closely shaved tapered sides",
            clip=get_value_at_index(dualcliploadergguf_284, 0),
        )
        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        clipvisionloader_281 = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )
        loadimage_289 = loadimage.load_image(image="assets_broc_ref.jpg")
        # This node instance was never created in the original code, so the
        # encode call below raised NameError on every request.
        clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
        clipvisionencode_282 = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(clipvisionloader_281, 0),
            image=get_value_at_index(loadimage_289, 0),
        )
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        vaeloader_285 = vaeloader.load_vae(vae_name="ae.safetensors")
        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        stylemodelloader_292 = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )
        stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]()
        stylemodelapply_280 = stylemodelapply.apply_stylemodel(
            strength=1,
            strength_type="multiply",
            conditioning=get_value_at_index(cliptextencode_279, 0),
            style_model=get_value_at_index(stylemodelloader_292, 0),
            clip_vision_output=get_value_at_index(clipvisionencode_282, 0),
        )
        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        fluxguidance_288 = fluxguidance.append(
            guidance=30,
            conditioning=get_value_at_index(stylemodelapply_280, 0),
        )
        conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
        conditioningzeroout_287 = conditioningzeroout.zero_out(
            conditioning=get_value_at_index(fluxguidance_288, 0)
        )

        # ---- Stage 2: hair mask = grown SAM2 mask minus face/clothes/scarf ----
        florence2tocoordinates = NODE_CLASS_MAPPINGS["Florence2toCoordinates"]()
        florence2tocoordinates_276 = florence2tocoordinates.segment(
            index="", batch=False, data=get_value_at_index(florence2run_275, 3)
        )
        sam2segmentation = NODE_CLASS_MAPPINGS["Sam2Segmentation"]()
        sam2segmentation_278 = sam2segmentation.segment(
            keep_model_loaded=False,
            individual_objects=False,
            sam2_model=get_value_at_index(downloadandloadsam2model_277, 0),
            image=get_value_at_index(florence2run_275, 0),
            bboxes=get_value_at_index(florence2tocoordinates_276, 1),
        )
        growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
        growmask_299 = growmask.expand_mask(
            expand=35,
            tapered_corners=True,
            mask=get_value_at_index(sam2segmentation_278, 0),
        )
        layermask_segformerb2clothesultra = NODE_CLASS_MAPPINGS["LayerMask: SegformerB2ClothesUltra"]()
        layermask_segformerb2clothesultra_300 = layermask_segformerb2clothesultra.segformer_ultra(
            face=True,
            hair=False,
            hat=False,
            sunglass=False,
            left_arm=False,
            right_arm=False,
            left_leg=False,
            right_leg=False,
            upper_clothes=True,
            skirt=False,
            pants=False,
            dress=False,
            belt=False,
            shoe=False,
            bag=False,
            scarf=True,
            detail_method="VITMatte",
            detail_erode=12,
            detail_dilate=6,
            black_point=0.15,
            white_point=0.99,
            process_detail=True,
            device="cuda",
            max_megapixels=2,
            image=get_value_at_index(catvton_121, 0),
        )
        masks_subtract = NODE_CLASS_MAPPINGS["Masks Subtract"]()
        masks_subtract_296 = masks_subtract.subtract_masks(
            masks_a=get_value_at_index(growmask_299, 0),
            masks_b=get_value_at_index(layermask_segformerb2clothesultra_300, 1),
        )

        # ---- Stage 3: FLUX.1-Fill inpainting of the hair mask ----
        inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
        inpaintmodelconditioning_286 = inpaintmodelconditioning.encode(
            noise_mask=True,
            positive=get_value_at_index(fluxguidance_288, 0),
            negative=get_value_at_index(conditioningzeroout_287, 0),
            vae=get_value_at_index(vaeloader_285, 0),
            pixels=get_value_at_index(catvton_121, 0),
            mask=get_value_at_index(masks_subtract_296, 0),
        )
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        unetloader_291 = unetloader.load_unet(
            unet_name="flux1-fill-dev.safetensors",
            weight_dtype="default",
        )
        loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
        loraloadermodelonly_290 = loraloadermodelonly.load_lora_model_only(
            lora_name="alimama-flux-turbo-alpha.safetensors",
            strength_model=1,
            model=get_value_at_index(unetloader_291, 0),
        )
        ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
        ksampler_283 = ksampler.sample(
            seed=new_seed(),
            steps=10,
            cfg=1,
            sampler_name="dpmpp_2m",
            scheduler="sgm_uniform",
            denoise=1,
            model=get_value_at_index(loraloadermodelonly_290, 0),
            positive=get_value_at_index(inpaintmodelconditioning_286, 0),
            negative=get_value_at_index(inpaintmodelconditioning_286, 1),
            latent_image=get_value_at_index(inpaintmodelconditioning_286, 2),
        )
        vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
        vaedecode_294 = vaedecode.decode(
            samples=get_value_at_index(ksampler_283, 0),
            vae=get_value_at_index(vaeloader_285, 0),
        )
        saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
        saveimage_295 = saveimage.save_images(
            filename_prefix="The_Broccolator_",
            images=get_value_at_index(vaedecode_294, 0),
        )

    # SaveImage writes under ComfyUI's output directory (inside The_Broccolator),
    # not ./output relative to the process cwd, and may use a subfolder.
    saved = saveimage_295["ui"]["images"][0]
    return os.path.join(
        folder_paths.get_output_directory(), saved["subfolder"], saved["filename"]
    )
###################################
# A simple Gradio interface
###################################
# Layout: left column = upload + button, right column = generated result.
with gr.Blocks() as demo:
    gr.Markdown("## The Broccolator 🥦\nUpload an image for `loadimage_264` and see final output.")
    with gr.Row():
        with gr.Column():
            user_input_image = gr.Image(type="filepath", label="Input Image")
            btn_generate = gr.Button("Generate")
        with gr.Column():
            final_image = gr.Image(label="Final output (saveimage_295)")

    # The uploaded file path goes straight into the pipeline; its returned
    # image path is shown in the output component.
    btn_generate.click(
        generate_images,
        inputs=[user_input_image],
        outputs=[final_image],
    )

demo.launch(debug=True)