######################### # app.py for ZeroGPU # ######################### import os import sys import random import torch import gradio as gr import spaces # for ZeroGPU usage from typing import Sequence, Mapping, Any, Union import os os.environ["CUDA_HOME"] = "/usr/local/cuda" import subprocess def install_catvton_dependencies(): import shutil import pathlib # ---------------------------------------------------------------- # 1) Install detectron2 from GitHub # ---------------------------------------------------------------- print("[CatVTON Setup] Installing detectron2 from GitHub...") subprocess.check_call([ sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git@v0.6" ]) # ---------------------------------------------------------------- # 2) Remove detectron2 lines from CatVTON’s local requirements # Make sure the path matches your actual folder name: # 'The_Broccolator/custom_nodes/Comfyui-CatVTON' # ---------------------------------------------------------------- catvton_repo_path = "custom_nodes/Comfyui-CatVTON" catvton_req_path = f"{catvton_repo_path}/requirements.txt" catvton_req_modified = f"{catvton_repo_path}/requirements_modified.txt" print("[CatVTON Setup] Removing detectron2 lines from CatVTON’s requirements...") with open(catvton_req_path, "r") as fin, open(catvton_req_modified, "w") as fout: for line in fin: if "detectron2" not in line.lower(): fout.write(line) # ---------------------------------------------------------------- # 3) Install the rest of CatVTON’s requirements # ---------------------------------------------------------------- print("[CatVTON Setup] Installing CatVTON requirements (minus detectron2)...") subprocess.check_call([ sys.executable, "-m", "pip", "install", "-r", catvton_req_modified ]) # ---------------------------------------------------------------- # 4) Install DensePose # ---------------------------------------------------------------- print("[CatVTON Setup] Installing DensePose...") subprocess.check_call([ sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git@v0.6#subdirectory=projects/DensePose" ]) print("[CatVTON Setup] All CatVTON dependencies installed!\n") # ---------------------------------------------------- # Call the install function on startup # ---------------------------------------------------- install_catvton_dependencies() # 1) Load your token from environment (make sure HF_TOKEN is set) token = os.environ["HF_TOKEN"] from huggingface_hub import hf_hub_download import shutil import pathlib # Create the directories we need under 'The_Broccolator' pathlib.Path("The_Broccolator/models/vae").mkdir(parents=True, exist_ok=True) pathlib.Path("The_Broccolator/models/clip").mkdir(parents=True, exist_ok=True) pathlib.Path("The_Broccolator/models/clip_vision").mkdir(parents=True, exist_ok=True) pathlib.Path("The_Broccolator/models/unet").mkdir(parents=True, exist_ok=True) pathlib.Path("The_Broccolator/models/loras").mkdir(parents=True, exist_ok=True) pathlib.Path("The_Broccolator/models/style_models").mkdir(parents=True, exist_ok=True) # Download each gated model into the correct local folder hf_hub_download( repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="The_Broccolator/models/vae", use_auth_token=token ) hf_hub_download( repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="The_Broccolator/models/clip", use_auth_token=token ) hf_hub_download( repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="The_Broccolator/models/clip", use_auth_token=token ) hf_hub_download( repo_id="black-forest-labs/FLUX.1-Fill-dev", filename="flux1-fill-dev.safetensors", local_dir="The_Broccolator/models/unet", use_auth_token=token ) download_path = hf_hub_download( repo_id="zhengchong/CatVTON", filename="flux-lora/pytorch_lora_weights.safetensors", local_dir="The_Broccolator/models/loras", use_auth_token=token ) os.rename(download_path, os.path.join("The_Broccolator/models/loras", "catvton-flux-lora.safetensors")) download_path = hf_hub_download( repo_id="alimama-creative/FLUX.1-Turbo-Alpha", filename="diffusion_pytorch_model.safetensors", local_dir="The_Broccolator/models/loras", use_auth_token=token ) os.rename(download_path, os.path.join("The_Broccolator/models/loras", "alimama-flux-turbo-alpha.safetensors")) hf_hub_download( repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="The_Broccolator/models/clip_vision", use_auth_token=token ) hf_hub_download( repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="The_Broccolator/models/style_models", use_auth_token=token ) ############################# # ComfyUI (or Broccolator) Support Functions ############################# def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any: try: return obj[index] except KeyError: return obj["result"][index] def find_path(name: str, path: str = None) -> str: if path is None: path = os.getcwd() if name in os.listdir(path): path_name = os.path.join(path, name) print(f"{name} found: {path_name}") return path_name parent_directory = os.path.dirname(path) if parent_directory == path: return None return find_path(name, parent_directory) def add_comfyui_directory_to_sys_path() -> None: comfyui_path = find_path("The_Broccolator") if comfyui_path is not None and os.path.isdir(comfyui_path): sys.path.append(comfyui_path) print(f"'{comfyui_path}' added to sys.path") def add_extra_model_paths() -> None: try: from main import load_extra_path_config except ImportError: print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.") from utils.extra_config import load_extra_path_config extra_model_paths = find_path("extra_model_paths.yaml") if extra_model_paths is not None: load_extra_path_config(extra_model_paths) else: print("Could not find the extra_model_paths config file.") add_comfyui_directory_to_sys_path() add_extra_model_paths() def import_custom_nodes() -> None: import asyncio import execution from nodes import init_extra_nodes import server loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server_instance = server.PromptServer(loop) execution.PromptQueue(server_instance) init_extra_nodes() from nodes import NODE_CLASS_MAPPINGS ############################################# # MAIN PIPELINE with ZeroGPU Decorator ############################################# @spaces.GPU(duration=90) def generate_images(user_image_path): import_custom_nodes() with torch.inference_mode(): loadimage = NODE_CLASS_MAPPINGS["LoadImage"]() loadimage_116 = loadimage.load_image(image="assets_black_tshirt.png") loadautomasker = NODE_CLASS_MAPPINGS["LoadAutoMasker"]() loadautomasker_120 = loadautomasker.load(catvton_path="zhengchong/CatVTON") loadcatvtonpipeline = NODE_CLASS_MAPPINGS["LoadCatVTONPipeline"]() loadcatvtonpipeline_123 = loadcatvtonpipeline.load( sd15_inpaint_path="runwayml/stable-diffusion-inpainting", catvton_path="zhengchong/CatVTON", mixed_precision="bf16", ) loadimage_264 = loadimage.load_image(image=user_image_path) randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]() randomnoise_273 = randomnoise.get_noise(noise_seed=random.randint(1, 2**64)) downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]() downloadandloadflorence2model_274 = downloadandloadflorence2model.loadmodel( model="gokaygokay/Florence-2-Flux-Large", precision="fp16", attention="sdpa" ) automasker = NODE_CLASS_MAPPINGS["AutoMasker"]() automasker_119 = automasker.generate( cloth_type="overall", pipe=get_value_at_index(loadautomasker_120, 0), target_image=get_value_at_index(loadimage_264, 0), ) catvton = NODE_CLASS_MAPPINGS["CatVTON"]() catvton_121 = catvton.generate( seed=random.randint(1, 2**64), steps=50, cfg=2.5, pipe=get_value_at_index(loadcatvtonpipeline_123, 0), target_image=get_value_at_index(loadimage_264, 0), refer_image=get_value_at_index(loadimage_116, 0), mask_image=get_value_at_index(automasker_119, 0), ) florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]() florence2run_275 = florence2run.encode( text_input="Haircut", task="caption_to_phrase_grounding", fill_mask=True, keep_model_loaded=False, max_new_tokens=1024, num_beams=3, do_sample=True, output_mask_select="", seed=random.randint(1, 2**64), image=get_value_at_index(catvton_121, 0), florence2_model=get_value_at_index(downloadandloadflorence2model_274, 0), ) downloadandloadsam2model = NODE_CLASS_MAPPINGS["DownloadAndLoadSAM2Model"]() downloadandloadsam2model_277 = downloadandloadsam2model.loadmodel( model="sam2.1_hiera_large.safetensors", segmentor="single_image", device="cuda", precision="fp16", ) dualcliploadergguf = NODE_CLASS_MAPPINGS["DualCLIPLoaderGGUF"]() dualcliploadergguf_284 = dualcliploadergguf.load_clip( clip_name1="t5xxl_fp16.safetensors", clip_name2="clip_l.safetensors", type="flux", ) cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]() cliptextencode_279 = cliptextencode.encode( text="Br0k0L8, Broccoli haircut with voluminous, textured curls on top resembling broccoli florets, contrasted by closely shaved tapered sides", clip=get_value_at_index(dualcliploadergguf_284, 0), ) clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]() clipvisionloader_281 = clipvisionloader.load_clip( clip_name="sigclip_vision_patch14_384.safetensors" ) loadimage_289 = loadimage.load_image(image="assets_broc_ref.jpg") clipvisionencode_282 = clipvisionencode.encode( crop="center", clip_vision=get_value_at_index(clipvisionloader_281, 0), image=get_value_at_index(loadimage_289, 0), ) vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]() vaeloader_285 = vaeloader.load_vae(vae_name="ae.safetensors") stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]() stylemodelloader_292 = stylemodelloader.load_style_model( style_model_name="flux1-redux-dev.safetensors" ) stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]() stylemodelapply_280 = stylemodelapply.apply_stylemodel( strength=1, strength_type="multiply", conditioning=get_value_at_index(cliptextencode_279, 0), style_model=get_value_at_index(stylemodelloader_292, 0), clip_vision_output=get_value_at_index(clipvisionencode_282, 0), ) fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]() fluxguidance_288 = fluxguidance.append( guidance=30, conditioning=get_value_at_index(stylemodelapply_280, 0) ) conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]() conditioningzeroout_287 = conditioningzeroout.zero_out( conditioning=get_value_at_index(fluxguidance_288, 0) ) florence2tocoordinates = NODE_CLASS_MAPPINGS["Florence2toCoordinates"]() florence2tocoordinates_276 = florence2tocoordinates.segment( index="", batch=False, data=get_value_at_index(florence2run_275, 3) ) sam2segmentation = NODE_CLASS_MAPPINGS["Sam2Segmentation"]() sam2segmentation_278 = sam2segmentation.segment( keep_model_loaded=False, individual_objects=False, sam2_model=get_value_at_index(downloadandloadsam2model_277, 0), image=get_value_at_index(florence2run_275, 0), bboxes=get_value_at_index(florence2tocoordinates_276, 1), ) growmask = NODE_CLASS_MAPPINGS["GrowMask"]() growmask_299 = growmask.expand_mask( expand=35, tapered_corners=True, mask=get_value_at_index(sam2segmentation_278, 0), ) layermask_segformerb2clothesultra = NODE_CLASS_MAPPINGS["LayerMask: SegformerB2ClothesUltra"]() layermask_segformerb2clothesultra_300 = layermask_segformerb2clothesultra.segformer_ultra( face=True, hair=False, hat=False, sunglass=False, left_arm=False, right_arm=False, left_leg=False, right_leg=False, upper_clothes=True, skirt=False, pants=False, dress=False, belt=False, shoe=False, bag=False, scarf=True, detail_method="VITMatte", detail_erode=12, detail_dilate=6, black_point=0.15, white_point=0.99, process_detail=True, device="cuda", max_megapixels=2, image=get_value_at_index(catvton_121, 0), ) masks_subtract = NODE_CLASS_MAPPINGS["Masks Subtract"]() masks_subtract_296 = masks_subtract.subtract_masks( masks_a=get_value_at_index(growmask_299, 0), masks_b=get_value_at_index(layermask_segformerb2clothesultra_300, 1), ) inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]() inpaintmodelconditioning_286 = inpaintmodelconditioning.encode( noise_mask=True, positive=get_value_at_index(fluxguidance_288, 0), negative=get_value_at_index(conditioningzeroout_287, 0), vae=get_value_at_index(vaeloader_285, 0), pixels=get_value_at_index(catvton_121, 0), mask=get_value_at_index(masks_subtract_296, 0), ) unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]() unetloader_291 = unetloader.load_unet( unet_name="flux1-fill-dev.safetensors", weight_dtype="default" ) loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]() loraloadermodelonly_290 = loraloadermodelonly.load_lora_model_only( lora_name="alimama-flux-turbo-alpha.safetensors", strength_model=1, model=get_value_at_index(unetloader_291, 0), ) ksampler = NODE_CLASS_MAPPINGS["KSampler"]() vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]() saveimage = NODE_CLASS_MAPPINGS["SaveImage"]() # We'll do a single pass for q in range(1): ksampler_283 = ksampler.sample( seed=random.randint(1, 2**64), steps=10, cfg=1, sampler_name="dpmpp_2m", scheduler="sgm_uniform", denoise=1, model=get_value_at_index(loraloadermodelonly_290, 0), positive=get_value_at_index(inpaintmodelconditioning_286, 0), negative=get_value_at_index(inpaintmodelconditioning_286, 1), latent_image=get_value_at_index(inpaintmodelconditioning_286, 2), ) vaedecode_294 = vaedecode.decode( samples=get_value_at_index(ksampler_283, 0), vae=get_value_at_index(vaeloader_285, 0), ) saveimage_295 = saveimage.save_images( filename_prefix="The_Broccolator_", images=get_value_at_index(vaedecode_294, 0), ) # final output return f"output/{saveimage_295['ui']['images'][0]['filename']}" ################################### # A simple Gradio interface ################################### with gr.Blocks() as demo: gr.Markdown("## The Broccolator 🥦\nUpload an image for `loadimage_264` and see final output.") with gr.Row(): with gr.Column(): user_input_image = gr.Image(type="filepath", label="Input Image") btn_generate = gr.Button("Generate") with gr.Column(): final_image = gr.Image(label="Final output (saveimage_295)") btn_generate.click( fn=generate_images, inputs=user_input_image, outputs=final_image ) demo.launch(debug=True)