#########################
# app.py for ZeroGPU #
#########################
import os
import sys
import random
import torch
import gradio as gr
import spaces # for ZeroGPU usage
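
# On a ZeroGPU Space the process boots CPU-only; a GPU is attached only while a
# function decorated with @spaces.GPU (generate_images below) is executing.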
from typing import Sequence, Mapping, Any, Union
import subprocess

# Point build tools at the system CUDA install (needed when pip builds
# detectron2 and DensePose from source below).
os.environ["CUDA_HOME"] = "/usr/local/cuda"
def install_catvton_dependencies():
    # ----------------------------------------------------------------
    # 1) Install detectron2 from GitHub
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing detectron2 from GitHub...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/[email protected]"
    ])

    # ----------------------------------------------------------------
    # 2) Remove detectron2 lines from CatVTON's local requirements.
    #    Make sure the path matches your actual folder name:
    #    'The_Broccolator/custom_nodes/Comfyui-CatVTON'
    # ----------------------------------------------------------------
    catvton_repo_path = "custom_nodes/Comfyui-CatVTON"
    catvton_req_path = f"{catvton_repo_path}/requirements.txt"
    catvton_req_modified = f"{catvton_repo_path}/requirements_modified.txt"

    print("[CatVTON Setup] Removing detectron2 lines from CatVTON's requirements...")
    with open(catvton_req_path, "r") as fin, open(catvton_req_modified, "w") as fout:
        for line in fin:
            if "detectron2" not in line.lower():
                fout.write(line)

    # ----------------------------------------------------------------
    # 3) Install the rest of CatVTON's requirements
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing CatVTON requirements (minus detectron2)...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "-r", catvton_req_modified
    ])

    # ----------------------------------------------------------------
    # 4) Install DensePose
    # ----------------------------------------------------------------
    print("[CatVTON Setup] Installing DensePose...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/[email protected]#subdirectory=projects/DensePose"
    ])
    print("[CatVTON Setup] All CatVTON dependencies installed!\n")
# ----------------------------------------------------
# Call the install function on startup
# ----------------------------------------------------
install_catvton_dependencies()
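
# Everything below runs once at startup, before the Gradio app launches:
# download the (mostly gated) model weights into the local ComfyUI folders.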
# Load the Hugging Face token from the environment (set HF_TOKEN as a Space
# secret); it is required for the gated FLUX model downloads below.
token = os.environ["HF_TOKEN"]
from huggingface_hub import hf_hub_download
import pathlib
# Create the directories we need under 'The_Broccolator'
pathlib.Path("The_Broccolator/models/vae").mkdir(parents=True, exist_ok=True)
pathlib.Path("The_Broccolator/models/clip").mkdir(parents=True, exist_ok=True)
pathlib.Path("The_Broccolator/models/clip_vision").mkdir(parents=True, exist_ok=True)
pathlib.Path("The_Broccolator/models/unet").mkdir(parents=True, exist_ok=True)
pathlib.Path("The_Broccolator/models/loras").mkdir(parents=True, exist_ok=True)
pathlib.Path("The_Broccolator/models/style_models").mkdir(parents=True, exist_ok=True)
# Download each gated model into the correct local folder
# Pass the token via the `token` keyword (the older `use_auth_token` kwarg is
# deprecated in recent huggingface_hub releases).
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    filename="ae.safetensors",
    local_dir="The_Broccolator/models/vae",
    token=token,
)
hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="t5xxl_fp16.safetensors",
    local_dir="The_Broccolator/models/clip",
    token=token,
)
hf_hub_download(
    repo_id="comfyanonymous/flux_text_encoders",
    filename="clip_l.safetensors",
    local_dir="The_Broccolator/models/clip",
    token=token,
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Fill-dev",
    filename="flux1-fill-dev.safetensors",
    local_dir="The_Broccolator/models/unet",
    token=token,
)

# The CatVTON LoRA lands in a 'flux-lora/' subfolder; rename it into the flat
# loras directory so the loader nodes can reference it by bare filename.
download_path = hf_hub_download(
    repo_id="zhengchong/CatVTON",
    filename="flux-lora/pytorch_lora_weights.safetensors",
    local_dir="The_Broccolator/models/loras",
    token=token,
)
os.rename(download_path, os.path.join("The_Broccolator/models/loras", "catvton-flux-lora.safetensors"))

download_path = hf_hub_download(
    repo_id="alimama-creative/FLUX.1-Turbo-Alpha",
    filename="diffusion_pytorch_model.safetensors",
    local_dir="The_Broccolator/models/loras",
    token=token,
)
os.rename(download_path, os.path.join("The_Broccolator/models/loras", "alimama-flux-turbo-alpha.safetensors"))

hf_hub_download(
    repo_id="Comfy-Org/sigclip_vision_384",
    filename="sigclip_vision_patch14_384.safetensors",
    local_dir="The_Broccolator/models/clip_vision",
    token=token,
)
hf_hub_download(
    repo_id="black-forest-labs/FLUX.1-Redux-dev",
    filename="flux1-redux-dev.safetensors",
    local_dir="The_Broccolator/models/style_models",
    token=token,
)
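
# The_Broccolator/models/ now mirrors the standard ComfyUI model layout
# (vae/, clip/, clip_vision/, unet/, loras/, style_models/), so every loader
# node below can find its checkpoint by filename alone.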
#############################
# ComfyUI (or Broccolator) Support Functions
#############################
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return obj[index], falling back to obj["result"][index] for dict-style node outputs."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
def find_path(name: str, path: str = None) -> Union[str, None]:
    """Walk up from `path` (default: cwd) until a file or folder called `name` is found."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:  # reached the filesystem root
        return None
    return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = find_path("The_Broccolator")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
    """Spin up a headless PromptServer so ComfyUI custom nodes can register themselves."""
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()
from nodes import NODE_CLASS_MAPPINGS
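
# NODE_CLASS_MAPPINGS maps each node's registered name (as it appears in the
# ComfyUI graph) to its Python class; the pipeline below instantiates nodes
# from it and calls their entry-point methods directly instead of queueing a prompt.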
#############################################
# MAIN PIPELINE with ZeroGPU Decorator
#############################################
@spaces.GPU(duration=90)
def generate_images(user_image_path):
    import_custom_nodes()
    with torch.inference_mode():
        # Load the garment reference, the user photo, and the CatVTON pipeline.
        loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
        loadimage_116 = loadimage.load_image(image="assets_black_tshirt.png")

        loadautomasker = NODE_CLASS_MAPPINGS["LoadAutoMasker"]()
        loadautomasker_120 = loadautomasker.load(catvton_path="zhengchong/CatVTON")

        loadcatvtonpipeline = NODE_CLASS_MAPPINGS["LoadCatVTONPipeline"]()
        loadcatvtonpipeline_123 = loadcatvtonpipeline.load(
            sd15_inpaint_path="runwayml/stable-diffusion-inpainting",
            catvton_path="zhengchong/CatVTON",
            mixed_precision="bf16",
        )

        loadimage_264 = loadimage.load_image(image=user_image_path)

        # Seeds are uint64, so cap randint at 2**64 - 1 to stay in range.
        # (This noise node's output is not consumed downstream.)
        randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
        randomnoise_273 = randomnoise.get_noise(noise_seed=random.randint(1, 2**64 - 1))

        downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
        downloadandloadflorence2model_274 = downloadandloadflorence2model.loadmodel(
            model="gokaygokay/Florence-2-Flux-Large", precision="fp16", attention="sdpa"
        )
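
        # --- Stage 1: virtual try-on. AutoMasker masks the clothing region of
        # the user photo; CatVTON swaps in the reference black t-shirt.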
        automasker = NODE_CLASS_MAPPINGS["AutoMasker"]()
        automasker_119 = automasker.generate(
            cloth_type="overall",
            pipe=get_value_at_index(loadautomasker_120, 0),
            target_image=get_value_at_index(loadimage_264, 0),
        )

        catvton = NODE_CLASS_MAPPINGS["CatVTON"]()
        catvton_121 = catvton.generate(
            seed=random.randint(1, 2**64 - 1),
            steps=50,
            cfg=2.5,
            pipe=get_value_at_index(loadcatvtonpipeline_123, 0),
            target_image=get_value_at_index(loadimage_264, 0),
            refer_image=get_value_at_index(loadimage_116, 0),
            mask_image=get_value_at_index(automasker_119, 0),
        )
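
        # --- Stage 2: find the hair. Florence-2 phrase grounding on the prompt
        # "Haircut" returns bounding boxes over the try-on result, and a SAM2
        # model is loaded to turn those boxes into a pixel-accurate mask.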
        florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
        florence2run_275 = florence2run.encode(
            text_input="Haircut",
            task="caption_to_phrase_grounding",
            fill_mask=True,
            keep_model_loaded=False,
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
            output_mask_select="",
            seed=random.randint(1, 2**64 - 1),
            image=get_value_at_index(catvton_121, 0),
            florence2_model=get_value_at_index(downloadandloadflorence2model_274, 0),
        )

        downloadandloadsam2model = NODE_CLASS_MAPPINGS["DownloadAndLoadSAM2Model"]()
        downloadandloadsam2model_277 = downloadandloadsam2model.loadmodel(
            model="sam2.1_hiera_large.safetensors",
            segmentor="single_image",
            device="cuda",
            precision="fp16",
        )
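
        # --- Stage 3: build the Flux conditioning. A text prompt describing the
        # broccoli haircut is combined, via Redux, with a CLIP-vision embedding
        # of the broccoli-hair reference image.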
        dualcliploadergguf = NODE_CLASS_MAPPINGS["DualCLIPLoaderGGUF"]()
        dualcliploadergguf_284 = dualcliploadergguf.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="clip_l.safetensors",
            type="flux",
        )

        cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
        cliptextencode_279 = cliptextencode.encode(
            text="Br0k0L8, Broccoli haircut with voluminous, textured curls on top resembling broccoli florets, contrasted by closely shaved tapered sides",
            clip=get_value_at_index(dualcliploadergguf_284, 0),
        )

        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        clipvisionloader_281 = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )

        loadimage_289 = loadimage.load_image(image="assets_broc_ref.jpg")

        # CLIPVisionEncode must be instantiated before its encode() call.
        clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
        clipvisionencode_282 = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(clipvisionloader_281, 0),
            image=get_value_at_index(loadimage_289, 0),
        )

        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        vaeloader_285 = vaeloader.load_vae(vae_name="ae.safetensors")

        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        stylemodelloader_292 = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )
        stylemodelapply = NODE_CLASS_MAPPINGS["StyleModelApply"]()
        stylemodelapply_280 = stylemodelapply.apply_stylemodel(
            strength=1,
            strength_type="multiply",
            conditioning=get_value_at_index(cliptextencode_279, 0),
            style_model=get_value_at_index(stylemodelloader_292, 0),
            clip_vision_output=get_value_at_index(clipvisionencode_282, 0),
        )

        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        fluxguidance_288 = fluxguidance.append(
            guidance=30,
            conditioning=get_value_at_index(stylemodelapply_280, 0),
        )

        # Zeroed-out conditioning serves as the negative prompt for Flux.
        conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
        conditioningzeroout_287 = conditioningzeroout.zero_out(
            conditioning=get_value_at_index(fluxguidance_288, 0)
        )

        florence2tocoordinates = NODE_CLASS_MAPPINGS["Florence2toCoordinates"]()
        florence2tocoordinates_276 = florence2tocoordinates.segment(
            index="", batch=False, data=get_value_at_index(florence2run_275, 3)
        )
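
        # --- Stage 4: refine the mask. SAM2 segments the hair from Florence-2's
        # boxes, GrowMask dilates it, and a Segformer person-parsing mask of the
        # face and clothes is subtracted so only the hair region is repainted.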
        sam2segmentation = NODE_CLASS_MAPPINGS["Sam2Segmentation"]()
        sam2segmentation_278 = sam2segmentation.segment(
            keep_model_loaded=False,
            individual_objects=False,
            sam2_model=get_value_at_index(downloadandloadsam2model_277, 0),
            image=get_value_at_index(florence2run_275, 0),
            bboxes=get_value_at_index(florence2tocoordinates_276, 1),
        )

        growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
        growmask_299 = growmask.expand_mask(
            expand=35,
            tapered_corners=True,
            mask=get_value_at_index(sam2segmentation_278, 0),
        )

        layermask_segformerb2clothesultra = NODE_CLASS_MAPPINGS["LayerMask: SegformerB2ClothesUltra"]()
        layermask_segformerb2clothesultra_300 = layermask_segformerb2clothesultra.segformer_ultra(
            face=True,
            hair=False,
            hat=False,
            sunglass=False,
            left_arm=False,
            right_arm=False,
            left_leg=False,
            right_leg=False,
            upper_clothes=True,
            skirt=False,
            pants=False,
            dress=False,
            belt=False,
            shoe=False,
            bag=False,
            scarf=True,
            detail_method="VITMatte",
            detail_erode=12,
            detail_dilate=6,
            black_point=0.15,
            white_point=0.99,
            process_detail=True,
            device="cuda",
            max_megapixels=2,
            image=get_value_at_index(catvton_121, 0),
        )

        # Subtract the face/clothes mask from the grown hair mask.
        masks_subtract = NODE_CLASS_MAPPINGS["Masks Subtract"]()
        masks_subtract_296 = masks_subtract.subtract_masks(
            masks_a=get_value_at_index(growmask_299, 0),
            masks_b=get_value_at_index(layermask_segformerb2clothesultra_300, 1),
        )
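
        # --- Stage 5: set up inpainting. InpaintModelConditioning packs image,
        # mask, and conditioning for the FLUX.1 Fill model, accelerated by the
        # Alimama Turbo-Alpha LoRA (which enables the 10-step, cfg=1 sampling below).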
        inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
        inpaintmodelconditioning_286 = inpaintmodelconditioning.encode(
            noise_mask=True,
            positive=get_value_at_index(fluxguidance_288, 0),
            negative=get_value_at_index(conditioningzeroout_287, 0),
            vae=get_value_at_index(vaeloader_285, 0),
            pixels=get_value_at_index(catvton_121, 0),
            mask=get_value_at_index(masks_subtract_296, 0),
        )

        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        unetloader_291 = unetloader.load_unet(
            unet_name="flux1-fill-dev.safetensors",
            weight_dtype="default",
        )

        loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
        loraloadermodelonly_290 = loraloadermodelonly.load_lora_model_only(
            lora_name="alimama-flux-turbo-alpha.safetensors",
            strength_model=1,
            model=get_value_at_index(unetloader_291, 0),
        )
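
        # --- Stage 6: sample, decode, and save the final image.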
        ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
        vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
        saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()

        # We'll do a single pass
        for q in range(1):
            ksampler_283 = ksampler.sample(
                seed=random.randint(1, 2**64 - 1),
                steps=10,
                cfg=1,
                sampler_name="dpmpp_2m",
                scheduler="sgm_uniform",
                denoise=1,
                model=get_value_at_index(loraloadermodelonly_290, 0),
                positive=get_value_at_index(inpaintmodelconditioning_286, 0),
                negative=get_value_at_index(inpaintmodelconditioning_286, 1),
                latent_image=get_value_at_index(inpaintmodelconditioning_286, 2),
            )

            vaedecode_294 = vaedecode.decode(
                samples=get_value_at_index(ksampler_283, 0),
                vae=get_value_at_index(vaeloader_285, 0),
            )

            saveimage_295 = saveimage.save_images(
                filename_prefix="The_Broccolator_",
                images=get_value_at_index(vaedecode_294, 0),
            )

        # Return the saved file's path so Gradio can display it.
        return f"output/{saveimage_295['ui']['images'][0]['filename']}"
###################################
# A simple Gradio interface
###################################
with gr.Blocks() as demo:
    gr.Markdown("## The Broccolator 🥦\nUpload an image for `loadimage_264` and see the final output.")
    with gr.Row():
        with gr.Column():
            user_input_image = gr.Image(type="filepath", label="Input Image")
            btn_generate = gr.Button("Generate")
        with gr.Column():
            final_image = gr.Image(label="Final output (saveimage_295)")

    btn_generate.click(
        fn=generate_images,
        inputs=user_input_image,
        outputs=final_image,
    )

demo.launch(debug=True)