|
from argparse import ArgumentParser |
|
from datetime import datetime |
|
from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline |
|
from diffusers.utils import load_image |
|
from os import makedirs, path |
|
from pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline |
|
import torch |
|
from time import time |
|
from utils import * |
|
from utils.media import preprocess |
|
from utils.sdxl import * |
|
import yaml |
|
|
|
|
|
@torch.no_grad()
def inference(
    pipe, refiner, device,
    structure_image, appearance_image,
    prompt, structure_prompt, appearance_prompt,
    positive_prompt, negative_prompt,
    guidance_scale, structure_guidance_scale, appearance_guidance_scale,
    num_inference_steps, eta, seed,
    width, height,
    structure_schedule, appearance_schedule,
):
    """Run one Ctrl-X generation with ``pipe`` and return the produced images.

    Args:
        pipe: The Ctrl-X SDXL pipeline instance (``CtrlXStableDiffusionXLPipeline``
            in this script). Its scheduler is reconfigured and control hooks
            are registered on it.
        refiner: Currently unused; accepted so callers keep the same
            signature. No refiner pass is run.
        device: Device handed to ``pipe.scheduler.set_timesteps``.
        structure_image: Path string of the structure image, or ``None``.
            Only string paths are loaded and preprocessed; any other non-None
            value is forwarded to the pipeline unchanged.
        appearance_image: Appearance image (anything diffusers' ``load_image``
            accepts), or ``None``.
        prompt, structure_prompt, appearance_prompt: Prompts forwarded to the
            pipeline.
        positive_prompt, negative_prompt: Additional prompts forwarded to the
            pipeline.
        guidance_scale, structure_guidance_scale, appearance_guidance_scale:
            Guidance scales forwarded to the pipeline.
        num_inference_steps: Number of scheduler timesteps.
        eta: Forwarded to the pipeline.
        seed: Passed to ``seed_everything`` before any other work.
        width, height: Output size; input images are cropped/resized to it.
        structure_schedule, appearance_schedule: Used to build the control
            config via ``get_control_config``.

    Returns:
        Tuple ``(result, result_refiner, structure, appearance)`` with the
        first image of each pipeline output. ``result_refiner`` is always
        ``None`` because no refiner pass is performed.
    """
    seed_everything(seed)

    # Load the input images and crop/resize them to the output size.
    # Use the *parameter* here, not a global: the function must work when
    # imported and called with arbitrary arguments.
    if isinstance(structure_image, str):
        structure_image = load_image(structure_image)
        structure_image = preprocess(
            structure_image, pipe.image_processor,
            height=height, width=width, resize_mode="crop",
        )
    if appearance_image is not None:
        appearance_image = load_image(appearance_image)
        appearance_image = preprocess(
            appearance_image, pipe.image_processor,
            height=height, width=width, resize_mode="crop",
        )

    # Configure the scheduler first: register_control needs the timesteps.
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    control_config = get_control_config(structure_schedule, appearance_schedule)
    print(f"\nUsing the following control config:\n{control_config}\n")

    config = yaml.safe_load(control_config)
    register_control(
        model=pipe,
        timesteps=timesteps,
        control_schedule=config["control_schedule"],
        control_target=config["control_target"],
    )

    pipe.safety_checker = None
    pipe.requires_safety_checker = False

    self_recurrence_schedule = get_self_recurrence_schedule(
        config["self_recurrence_schedule"], num_inference_steps,
    )

    pipe.set_progress_bar_config(desc="Ctrl-X inference")

    result, structure, appearance = pipe(
        prompt=prompt,
        structure_prompt=structure_prompt,
        appearance_prompt=appearance_prompt,
        structure_image=structure_image,
        appearance_image=appearance_image,
        num_inference_steps=num_inference_steps,
        negative_prompt=negative_prompt,
        positive_prompt=positive_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        structure_guidance_scale=structure_guidance_scale,
        appearance_guidance_scale=appearance_guidance_scale,
        eta=eta,
        output_type="pil",
        return_dict=False,
        control_schedule=config["control_schedule"],
        self_recurrence_schedule=self_recurrence_schedule,
    )

    # No refiner pass is run; keep a one-element list so the return shape
    # matches the other outputs.
    result_refiner = [None]

    # Clean up ``refiner_args`` if the pipeline call left it on ``pipe``.
    # This is guarded because it is not certain every call sets it, and an
    # unconditional ``del`` would raise AttributeError in that case.
    if hasattr(pipe, "refiner_args"):
        del pipe.refiner_args

    return result[0], result_refiner[0], structure[0], appearance[0]
|
|
|
|
|
@torch.no_grad()
def main(args):
    """Load the SSD-1B Ctrl-X pipeline, run one inference, and save the outputs.

    Args:
        args: Parsed command-line namespace produced by this script's
            argument parser.

    Raises:
        ModuleNotFoundError: If model or sequential CPU offloading is
            requested but ``accelerate`` cannot be imported.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32
    # Request half-precision weight files only on GPU. On CPU, load the
    # default (non-variant) weights: a "fp32" variant file set is unlikely
    # to exist in the repo, and asking for one can make loading fail.
    variant = "fp16" if use_cuda else None

    model_id_or_path = "OzzyGT/SSD-1B"

    # Passing a repo id to `from_config` is deprecated in diffusers.
    # `from_pretrained` is the supported way to load a scheduler from a repo.
    scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")

    if args.model is None:
        pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
            model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype,
            variant=variant, use_safetensors=True,
        )
    else:
        print(f"Using weights {args.model} for SDXL base model.")
        pipe = CtrlXStableDiffusionXLPipeline.from_single_file(
            args.model, scheduler=scheduler, torch_dtype=torch_dtype,
        )

    if args.model_offload or args.sequential_offload:
        try:
            import accelerate  # noqa: F401  (availability check only)
        except ImportError as err:
            raise ModuleNotFoundError("`accelerate` must be installed for Model/CPU offloading.") from err

    # Sequential offload takes precedence when both flags are set.
    if args.sequential_offload:
        pipe.enable_sequential_cpu_offload()
        load_status = "loaded with sequential CPU offloading."
    elif args.model_offload:
        pipe.enable_model_cpu_offload()
        load_status = "loaded with model CPU offloading."
    else:
        pipe = pipe.to(device)
        load_status = "loaded."

    # This script never loads a refiner (inference() receives refiner=None),
    # so the message does not claim one regardless of --disable_refiner.
    print(f"Base model {load_status} Running on device: {device}.")

    start = time()
    result, result_refiner, structure, appearance = inference(
        pipe=pipe,
        refiner=None,
        device=device,
        structure_image=args.structure_image,
        appearance_image=args.appearance_image,
        prompt=args.prompt,
        structure_prompt=args.structure_prompt,
        appearance_prompt=args.appearance_prompt,
        positive_prompt=args.positive_prompt,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        structure_guidance_scale=args.structure_guidance_scale,
        appearance_guidance_scale=args.appearance_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        eta=args.eta,
        seed=args.seed,
        width=args.width,
        height=args.height,
        structure_schedule=args.structure_schedule,
        appearance_schedule=args.appearance_schedule,
    )
    # Stop timing here so the benchmark reports inference only, not JPEG I/O.
    inference_time = time() - start

    makedirs(args.output_folder, exist_ok=True)
    prefix = "ctrlx__" + datetime.now().strftime("%Y%m%d_%H%M%S")
    structure.save(path.join(args.output_folder, f"{prefix}__structure.jpg"), quality=JPEG_QUALITY)
    appearance.save(path.join(args.output_folder, f"{prefix}__appearance.jpg"), quality=JPEG_QUALITY)
    result.save(path.join(args.output_folder, f"{prefix}__result.jpg"), quality=JPEG_QUALITY)
    if result_refiner is not None:
        result_refiner.save(path.join(args.output_folder, f"{prefix}__result_refiner.jpg"), quality=JPEG_QUALITY)

    if args.benchmark:
        print(f"Inference time: {inference_time:.2f}s")
        # CUDA memory statistics are meaningless on a CPU-only run.
        if use_cuda:
            peak_memory_usage = torch.cuda.max_memory_reserved()
            print(f"Peak memory usage: {peak_memory_usage / pow(1024, 3):.2f}GiB")

    print("Done.")
|
|
|
|
|
if __name__ == "__main__": |
|
parser = ArgumentParser() |
|
|
|
parser.add_argument("--structure_image", "-si", type=str, default=None) |
|
parser.add_argument("--appearance_image", "-ai", type=str, default=None) |
|
|
|
parser.add_argument("--prompt", "-p", type=str, required=True) |
|
parser.add_argument("--structure_prompt", "-sp", type=str, default="") |
|
parser.add_argument("--appearance_prompt", "-ap", type=str, default="") |
|
|
|
parser.add_argument("--positive_prompt", "-pp", type=str, default="high quality") |
|
parser.add_argument("--negative_prompt", "-np", type=str, default="ugly, blurry, dark, low res, unrealistic") |
|
|
|
parser.add_argument("--guidance_scale", "-g", type=float, default=5.0) |
|
parser.add_argument("--structure_guidance_scale", "-sg", type=float, default=5.0) |
|
parser.add_argument("--appearance_guidance_scale", "-ag", type=float, default=5.0) |
|
|
|
parser.add_argument("--num_inference_steps", "-n", type=int, default=50) |
|
parser.add_argument("--eta", "-e", type=float, default=1.0) |
|
parser.add_argument("--seed", "-s", type=int, default=90095) |
|
|
|
parser.add_argument("--width", "-W", type=int, default=1024) |
|
parser.add_argument("--height", "-H", type=int, default=1024) |
|
|
|
parser.add_argument("--structure_schedule", "-ss", type=float, default=0.6) |
|
parser.add_argument("--appearance_schedule", "-as", type=float, default=0.6) |
|
|
|
parser.add_argument("--output_folder", "-o", type=str, default="./results") |
|
|
|
parser.add_argument( |
|
"-mo", "--model_offload", action="store_true", |
|
help="Model CPU offload, lowers memory usage with slight runtime increase. `accelerate` must be installed.", |
|
) |
|
parser.add_argument( |
|
"-so", "--sequential_offload", action="store_true", |
|
help=( |
|
"Sequential layer CPU offload, significantly lowers memory usage with massive runtime increase." |
|
"`accelerate` must be installed. If both model_offload and sequential_offload are set, then use the latter." |
|
), |
|
) |
|
parser.add_argument("-r", "--disable_refiner", action="store_true") |
|
parser.add_argument("-m", "--model", type=str, default=None, help="Optionally, load model safetensors.") |
|
parser.add_argument("-b", "--benchmark", action="store_true", help="Show inference time and max memory usage.") |
|
|
|
args = parser.parse_args() |
|
main(args) |
|
|