import argparse
import os

import torch
from PIL import Image

from model import FluxModel


def parse_args():
    parser = argparse.ArgumentParser(description='Flux Image Generation Tool')

    # Required arguments
    parser.add_argument('--mode', type=str, required=True,
                        choices=['variation', 'img2img', 'inpaint', 'controlnet', 'controlnet-inpaint'],
                        help='Generation mode')
    parser.add_argument('--input_image', type=str, required=True,
                        help='Path to the input image')

    # Optional arguments
    parser.add_argument('--prompt', type=str, default="",
                        help='Text prompt to guide the generation')
    parser.add_argument('--reference_image', type=str, default=None,
                        help='Path to the reference image (for img2img/controlnet modes)')
    parser.add_argument('--mask_image', type=str, default=None,
                        help='Path to the mask image (for inpainting modes)')
    parser.add_argument('--output_dir', type=str, default='outputs',
                        help='Directory to save generated images')
    parser.add_argument('--image_count', type=int, default=1,
                        help='Number of images to generate')
    parser.add_argument('--aspect_ratio', type=str, default='1:1',
                        choices=['1:1', '16:9', '9:16', '2.4:1', '3:4', '4:3'],
                        help='Output image aspect ratio')
    parser.add_argument('--steps', type=int, default=28,
                        help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=7.5,
                        help='Guidance scale for generation')
    parser.add_argument('--denoise_strength', type=float, default=0.8,
                        help='Denoising strength for img2img/inpaint')

    # Attention-related arguments
    parser.add_argument('--center_x', type=float, default=None,
                        help='X coordinate of attention center (0-1)')
    parser.add_argument('--center_y', type=float, default=None,
                        help='Y coordinate of attention center (0-1)')
    parser.add_argument('--radius', type=float, default=None,
                        help='Radius of attention circle (0-1)')

    # ControlNet-related arguments
    parser.add_argument('--line_mode', action='store_true',
                        help='Enable line detection mode for ControlNet')
    parser.add_argument('--depth_mode', action='store_true',
                        help='Enable depth mode for ControlNet')
    parser.add_argument('--line_strength', type=float, default=0.4,
                        help='Strength of line guidance')
    parser.add_argument('--depth_strength', type=float, default=0.2,
                        help='Strength of depth guidance')

    # Device selection
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to run the model on')
    parser.add_argument('--turbo', action='store_true',
                        help='Enable turbo mode for faster inference')

    return parser.parse_args()


def load_image(image_path):
    """Load and return a PIL Image in RGB mode."""
    try:
        return Image.open(image_path).convert('RGB')
    except Exception as e:
        raise ValueError(f"Error loading image {image_path}: {str(e)}")


def save_images(images, output_dir, prefix="generated"):
    """Save generated images with sequential numbering."""
    os.makedirs(output_dir, exist_ok=True)
    for i, image in enumerate(images):
        output_path = os.path.join(output_dir, f"{prefix}_{i + 1}.png")
        image.save(output_path)
        print(f"Saved image to {output_path}")


def get_required_features(args):
    """Determine which model features are required based on the arguments."""
    features = []
    if args.mode in ['controlnet', 'controlnet-inpaint']:
        features.append('controlnet')
        if args.depth_mode:
            features.append('depth')
        if args.line_mode:
            features.append('line')
    if args.mode in ['inpaint', 'controlnet-inpaint']:
        features.append('sam')  # If you're using SAM for mask generation
    return features


def main():
    args = parse_args()

    # Fall back to CPU if CUDA was requested but is unavailable
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA requested but not available. Falling back to CPU.")
        args.device = 'cpu'

    # Determine required features based on mode and arguments
    required_features = get_required_features(args)

    # Initialize model with only the required features
    print(f"Initializing model on {args.device} with features: {required_features}")
    model = FluxModel(
        is_turbo=args.turbo,
        device=args.device,
        required_features=required_features
    )

    # Load input images
    input_image = load_image(args.input_image)
    reference_image = load_image(args.reference_image) if args.reference_image else None
    mask_image = load_image(args.mask_image) if args.mask_image else None

    # Validate inputs based on mode
    if args.mode in ['inpaint', 'controlnet-inpaint'] and mask_image is None:
        raise ValueError(f"{args.mode} mode requires a mask image")

    # Generate images
    print(f"Generating {args.image_count} images in {args.mode} mode...")
    generated_images = model.generate(
        input_image_a=input_image,
        input_image_b=reference_image,
        prompt=args.prompt,
        mask_image=mask_image,
        mode=args.mode,
        imageCount=args.image_count,
        aspect_ratio=args.aspect_ratio,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
        denoise_strength=args.denoise_strength,
        center_x=args.center_x,
        center_y=args.center_y,
        radius=args.radius,
        line_mode=args.line_mode,
        depth_mode=args.depth_mode,
        line_strength=args.line_strength,
        depth_strength=args.depth_strength
    )

    # Save generated images
    save_images(generated_images, args.output_dir)
    print("Generation completed successfully!")


if __name__ == "__main__":
    main()
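
# Example invocations (illustrative only: the script filename, image paths, and
# prompt text below are placeholders, not part of the project):
#
#   python main.py --mode img2img --input_image photo.jpg \
#       --prompt "a watercolor painting" --denoise_strength 0.6
#
#   python main.py --mode controlnet-inpaint --input_image room.png \
#       --mask_image mask.png --line_mode --depth_mode --turbo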