import argparse
import os

import torch
from diffusers.utils import load_image, check_min_version

from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline


def main(image_path, mask_path, prompt):
    check_min_version("0.30.2")

    # Enable memory optimizations
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()

    # Tune the CUDA caching allocator to reduce memory fragmentation
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

    # Build pipeline components
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Enable memory-efficient (sliced) attention
    pipe.enable_attention_slicing(1)

    # Load the input image and mask from disk and resize them
    size = (384, 384)  # or even (256, 256)
    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)

    # Fix the seed for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(24)

    # Run inference with automatic mixed precision
    with torch.cuda.amp.autocast():
        result = pipe(
            prompt=prompt,
            height=size[1],
            width=size[0],
            control_image=image,
            control_mask=mask,
            num_inference_steps=28,
            generator=generator,
            controlnet_conditioning_scale=0.9,
            guidance_scale=3.5,
            negative_prompt="",
            true_guidance_scale=1.0,
        ).images[0]

    # Clear cache after generation
    torch.cuda.empty_cache()
    print("Successfully inpainted image")
    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    args = parser.parse_args()

    result = main(args.image_path, args.mask_path, args.prompt)
    result.save("output.png")
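
# Example invocation, as a sketch: the script name and file paths below are
# placeholders, not names defined in this repository.
#
#   python flux_controlnet_inpaint.py \
#       --image_path ./input.png \
#       --mask_path ./mask.png \
#       --prompt "a wooden bench in a sunny park"
#
# The mask is typically white in the region to be repainted and black elsewhere;
# the result is written to output.png in the current working directory.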