from argparse import ArgumentParser
from datetime import datetime
from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from os import makedirs, path
from pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
import torch
from time import time
from utils import *
from utils.media import preprocess
from utils.sdxl import *
import yaml


@torch.no_grad()
def inference(
    pipe,
    refiner,
    device,
    structure_image,
    appearance_image,
    prompt,
    structure_prompt,
    appearance_prompt,
    positive_prompt,
    negative_prompt,
    guidance_scale,
    structure_guidance_scale,
    appearance_guidance_scale,
    num_inference_steps,
    eta,
    seed,
    width,
    height,
    structure_schedule,
    appearance_schedule,
):
    """Run one Ctrl-X generation with ``pipe`` and return the first image of each output.

    Args:
        pipe: A loaded ``CtrlXStableDiffusionXLPipeline``.
        refiner: Unused; kept for call-site compatibility (no refiner pass is run here).
        device: Device passed to ``pipe.scheduler.set_timesteps``.
        structure_image: A path string (loaded with ``load_image`` and preprocessed),
            ``None``, or any other value, which is forwarded to the pipeline unchanged.
        appearance_image: ``None`` or anything ``load_image`` accepts; when not ``None``
            it is loaded and preprocessed.
        prompt, structure_prompt, appearance_prompt, positive_prompt, negative_prompt:
            Text prompts forwarded to the pipeline.
        guidance_scale, structure_guidance_scale, appearance_guidance_scale:
            Guidance scales forwarded to the pipeline.
        num_inference_steps: Number of scheduler steps.
        eta: Forwarded to the pipeline.
        seed: Passed to ``seed_everything`` before anything else runs.
        width, height: Output size; also the size images are crop-resized to.
        structure_schedule, appearance_schedule: Passed to ``get_control_config`` to
            build the YAML control config.

    Returns:
        ``(result, result_refiner, structure, appearance)``. ``result_refiner`` is
        always ``None`` because no refiner pass is performed.
    """
    seed_everything(seed)

    # Process images.
    # Moved from CtrlXStableDiffusionXLPipeline.__call__.
    # Check the parameter itself (not a global), so callers other than main() work.
    if isinstance(structure_image, str):
        structure_image = load_image(structure_image)
        structure_image = preprocess(
            structure_image, pipe.image_processor, height=height, width=width, resize_mode="crop"
        )
    if appearance_image is not None:
        appearance_image = load_image(appearance_image)
        appearance_image = preprocess(
            appearance_image, pipe.image_processor, height=height, width=width, resize_mode="crop"
        )

    # Scheduler.
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    control_config = get_control_config(structure_schedule, appearance_schedule)
    print(f"\nUsing the following control config:\n{control_config}\n")

    # safe_load: parse the generated YAML text without constructing arbitrary objects.
    config = yaml.safe_load(control_config)
    register_control(
        model=pipe,
        timesteps=timesteps,
        control_schedule=config["control_schedule"],
        control_target=config["control_target"],
    )

    # Pipe settings.
    pipe.safety_checker = None
    pipe.requires_safety_checker = False
    self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps)
    pipe.set_progress_bar_config(desc="Ctrl-X inference")

    # Inference.
    result, structure, appearance = pipe(
        prompt=prompt,
        structure_prompt=structure_prompt,
        appearance_prompt=appearance_prompt,
        structure_image=structure_image,
        appearance_image=appearance_image,
        num_inference_steps=num_inference_steps,
        negative_prompt=negative_prompt,
        positive_prompt=positive_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        structure_guidance_scale=structure_guidance_scale,
        appearance_guidance_scale=appearance_guidance_scale,
        eta=eta,
        output_type="pil",
        return_dict=False,
        control_schedule=config["control_schedule"],
        self_recurrence_schedule=self_recurrence_schedule,
    )

    # NOTE(review): `refiner_args` is presumably set on the pipeline during __call__;
    # it is removed here. Guarded so a pipeline that never set it does not raise.
    if hasattr(pipe, "refiner_args"):
        del pipe.refiner_args

    return result[0], None, structure[0], appearance[0]


@torch.no_grad()
def main(args):
    """Load the SDXL Ctrl-X pipeline, run one inference, and save the images as JPEGs.

    Args:
        args: Parsed command-line namespace produced by ``_build_parser()``.

    Raises:
        ModuleNotFoundError: If CPU offloading is requested but ``accelerate`` is missing.
    """
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id_or_path = "OzzyGT/SSD-1B"
    # NOTE: no refiner is loaded by this script (a refiner such as
    # "stabilityai/stable-diffusion-xl-refiner-1.0" is not wired in), so
    # `--disable_refiner` currently has no effect.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    variant = "fp16" if device == "cuda" else "fp32"

    # from_pretrained is the supported way to load a scheduler from a repo/path;
    # passing a path to from_config is deprecated in diffusers.
    scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")

    if args.model is None:
        pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
            model_id_or_path,
            scheduler=scheduler,
            torch_dtype=torch_dtype,
            variant=variant,
            use_safetensors=True,
        )
    else:
        print(f"Using weights {args.model} for SDXL base model.")
        pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)

    if args.model_offload or args.sequential_offload:
        try:
            import accelerate  # noqa: F401  (only checking that it is installed)
        except ImportError as err:
            raise ModuleNotFoundError("`accelerate` must be installed for Model/CPU offloading.") from err

    # Sequential offload takes precedence when both flags are set.
    if args.sequential_offload:
        pipe.enable_sequential_cpu_offload()
        load_status = "loaded with sequential CPU offloading."
    elif args.model_offload:
        pipe.enable_model_cpu_offload()
        load_status = "loaded with model CPU offloading."
    else:
        pipe = pipe.to(device)
        load_status = "loaded."
    print(f"Base model {load_status} Running on device: {device}.")

    start = time()
    result, result_refiner, structure, appearance = inference(
        pipe=pipe,
        refiner=None,
        device=device,
        structure_image=args.structure_image,
        appearance_image=args.appearance_image,
        prompt=args.prompt,
        structure_prompt=args.structure_prompt,
        appearance_prompt=args.appearance_prompt,
        positive_prompt=args.positive_prompt,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        structure_guidance_scale=args.structure_guidance_scale,
        appearance_guidance_scale=args.appearance_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        eta=args.eta,
        seed=args.seed,
        width=args.width,
        height=args.height,
        structure_schedule=args.structure_schedule,
        appearance_schedule=args.appearance_schedule,
    )
    # Measure before saving so the benchmark reports inference time, not disk I/O.
    inference_time = time() - start

    makedirs(args.output_folder, exist_ok=True)
    prefix = "ctrlx__" + datetime.now().strftime("%Y%m%d_%H%M%S")
    outputs = {
        "structure": structure,
        "appearance": appearance,
        "result": result,
        "result_refiner": result_refiner,
    }
    for suffix, image in outputs.items():
        if image is not None:
            image.save(path.join(args.output_folder, f"{prefix}__{suffix}.jpg"), quality=JPEG_QUALITY)

    if args.benchmark:
        print(f"Inference time: {inference_time:.2f}s")
        if device == "cuda":
            peak_memory_usage = torch.cuda.max_memory_reserved()
            print(f"Peak memory usage: {peak_memory_usage / pow(1024, 3):.2f}GiB")

    print("Done.")


def _build_parser():
    """Build the command-line parser for this script."""
    parser = ArgumentParser()
    parser.add_argument("--structure_image", "-si", type=str, default=None)
    parser.add_argument("--appearance_image", "-ai", type=str, default=None)

    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--structure_prompt", "-sp", type=str, default="")
    parser.add_argument("--appearance_prompt", "-ap", type=str, default="")

    parser.add_argument("--positive_prompt", "-pp", type=str, default="high quality")
    parser.add_argument("--negative_prompt", "-np", type=str, default="ugly, blurry, dark, low res, unrealistic")

    parser.add_argument("--guidance_scale", "-g", type=float, default=5.0)
    parser.add_argument("--structure_guidance_scale", "-sg", type=float, default=5.0)
    parser.add_argument("--appearance_guidance_scale", "-ag", type=float, default=5.0)

    parser.add_argument("--num_inference_steps", "-n", type=int, default=50)
    parser.add_argument("--eta", "-e", type=float, default=1.0)
    parser.add_argument("--seed", "-s", type=int, default=90095)

    parser.add_argument("--width", "-W", type=int, default=1024)
    parser.add_argument("--height", "-H", type=int, default=1024)

    parser.add_argument("--structure_schedule", "-ss", type=float, default=0.6)
    parser.add_argument("--appearance_schedule", "-as", type=float, default=0.6)

    parser.add_argument("--output_folder", "-o", type=str, default="./results")

    parser.add_argument(
        "-mo",
        "--model_offload",
        action="store_true",
        help="Model CPU offload, lowers memory usage with slight runtime increase. `accelerate` must be installed.",
    )
    parser.add_argument(
        "-so",
        "--sequential_offload",
        action="store_true",
        help=(
            "Sequential layer CPU offload, significantly lowers memory usage with massive runtime increase. "
            "`accelerate` must be installed. If both model_offload and sequential_offload are set, then use the latter."
        ),
    )
    parser.add_argument("-r", "--disable_refiner", action="store_true")
    parser.add_argument("-m", "--model", type=str, default=None, help="Optionally, load model safetensors.")
    parser.add_argument("-b", "--benchmark", action="store_true", help="Show inference time and max memory usage.")
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()
    main(args)