import torch
from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    retrieve_timesteps,
    retrieve_latents,
)
from PIL import Image
from inversion_utils import get_ddpm_inversion_scheduler, create_xts
from config import get_config, get_num_steps_actual
from functools import partial
from compel import Compel, ReturnedEmbeddingsType
from hidiffusion import apply_hidiffusion, remove_hidiffusion


class Object(object):
    """Bare attribute container used as a stand-in for a CLI argument namespace."""

    pass


# Hard-coded run arguments; `args` is read by get_config() and by run() below.
args = Object()
args.images_paths = None
args.images_folder = None
args.force_use_cpu = False
args.folder_name = 'test_measure_time'
args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
args.save_intermediate_results = False
args.batch_size = None
args.skip_p_to_p = True
args.only_p_to_p = False
args.fp16 = False
args.prompts_file = 'dataset_measure_time/dataset.json'
args.images_in_prompts_file = None
args.seed = 986
args.time_measure_n = 1

assert (
    args.batch_size is None or args.save_intermediate_results is False
), "save_intermediate_results is not implemented for batch_size > 1"

# Kept for backward compatibility. run() creates its own seeded generator and
# passes it explicitly to encode_image(); this module-level value is never
# reassigned here.
generator = None
device = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# BASE_MODEL = "stabilityai/sdxl-turbo"

# Load the SDXL img2img pipeline in fp16 and swap in a DDPM scheduler.
pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline = pipeline.to(device)
pipeline.scheduler = DDPMScheduler.from_pretrained(
    BASE_MODEL,
    subfolder="scheduler",
)
apply_hidiffusion(pipeline)

config = get_config(args)

# Compel builds SDXL prompt embeddings from both tokenizer/text-encoder pairs.
# A pooled embedding is requested only from the second encoder.
compel_proc = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)


def run(
    input_image: Image.Image,
    src_prompt: str,
    tgt_prompt: str,
    generate_size: int,
    seed: int,
    w1: float,
    w2: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
):
    """Edit ``input_image`` from ``src_prompt`` toward ``tgt_prompt``.

    The image is VAE-encoded and noised along the selected timesteps with
    ``create_xts``. The pipeline's scheduler is then replaced by the DDPM
    inversion scheduler from ``inversion_utils``. Finally the pipeline runs on a
    3-entry batch ``[src_prompt, src_prompt, tgt_prompt]``.

    Args:
        input_image: image to edit.
        src_prompt: prompt describing the input image.
        tgt_prompt: prompt describing the desired edit.
        generate_size: accepted for interface compatibility; not used here.
        seed: seed for the torch generator used by VAE encoding, noise
            creation and sampling.
        w1, w2: per-step weights written to ``config.ws1`` / ``config.ws2``.
        num_steps: total number of inversion steps (``config.num_steps_inversion``).
        start_step: first step index (``config.step_start``).
        guidance_scale: classifier-free guidance scale for the pipeline call.

    Returns:
        Element 2 of the pipeline's ``.images`` output, i.e. the entry
        conditioned on ``tgt_prompt``.

    Side effects:
        Mutates the module-level ``config``. Rebinds ``pipeline.__call__`` and
        ``pipeline.scheduler``.
    """
    # One seeded generator drives every random draw of this run.
    generator = torch.Generator().manual_seed(seed)

    config.num_steps_inversion = num_steps
    config.step_start = start_step
    num_steps_actual = get_num_steps_actual(config)
    num_steps_inversion = config.num_steps_inversion
    # Fraction of the schedule skipped before denoising begins.
    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")

    timesteps, num_inference_steps = retrieve_timesteps(
        pipeline.scheduler, num_steps_inversion, device, None
    )
    timesteps, num_inference_steps = pipeline.get_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_start,
        strength=0,
        device=device,
    )
    timesteps = timesteps.type(torch.int64)
    timesteps = [torch.tensor(t) for t in timesteps.tolist()]

    # The pipeline's timestep cut-off may keep a different number of steps than
    # get_num_steps_actual() predicted. Shift step_start and size the
    # per-step settings to the count actually kept.
    timesteps_len = len(timesteps)
    config.step_start = start_step + num_steps_actual - timesteps_len
    num_steps_actual = timesteps_len
    # Only the last step gets a finite value (15.5). Presumably -1 means
    # "no norm clamp" -- TODO confirm in inversion_utils.
    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
    print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")

    # Pre-bind sampling arguments on the instance. The explicit
    # `pipeline.__call__(...)` at the end of this function picks this up.
    pipeline.__call__ = partial(
        pipeline.__call__,
        num_inference_steps=num_steps_inversion,
        guidance_scale=guidance_scale,
        generator=generator,
        denoising_start=denoising_start,
        strength=0,
    )

    x_0_image = input_image
    # Pass the seeded generator so VAE latent sampling is reproducible. The
    # module-level `generator` is None, which previously left this unseeded.
    x_0 = encode_image(x_0_image, pipeline, generator=generator)
    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
    latents = [x_ts[0]]
    x_ts_c_hat = [None]

    config.ws1 = [w1] * num_steps_actual
    config.ws2 = [w2] * num_steps_actual

    # Buffers handed to the inversion scheduler; they are not read again here.
    v1s_images, v2s_images, deltas_images = [], [], []
    v1_x0s, v2_x0s, deltas_x0s = [], [], []
    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config.step_function,
        config,
        timesteps,
        config.save_timesteps,
        latents,
        x_ts,
        x_ts_c_hat,
        args.save_intermediate_results,
        pipeline,
        x_0,
        v1s_images,
        v2s_images,
        deltas_images,
        v1_x0s,
        v2_x0s,
        deltas_x0s,
        "res12",
        image_name="im_name",
        time_measure_n=args.time_measure_n,
    )

    # Three batch entries share the same starting latent. The first two use
    # the source prompt and the last uses the target prompt.
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    conditioning, pooled = compel_proc(prompt)
    image = pipeline.__call__(
        image=latent,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        eta=1,
    ).images
    return image[2]


def encode_image(image, pipe, generator=None):
    """Encode ``image`` into scaled VAE latents, returned as float16.

    Args:
        image: any input accepted by ``pipe.image_processor.preprocess``.
        pipe: pipeline providing ``image_processor``, ``vae`` and ``dtype``.
        generator: optional torch generator, or a list with one generator per
            preprocessed batch element. It is used when sampling from the VAE
            latent distribution. ``None`` (the default) leaves sampling unseeded.

    Returns:
        ``vae.config.scaling_factor * latents``, cast to ``torch.float16``.
    """
    image = pipe.image_processor.preprocess(image)
    origin_dtype = pipe.dtype
    image = image.to(device=device, dtype=origin_dtype)

    upcast = pipe.vae.config.force_upcast
    if upcast:
        # The VAE config requests float32 encoding.
        image = image.float()
        pipe.vae.to(dtype=torch.float32)

    try:
        if isinstance(generator, list):
            # Encode each batch element with its own generator.
            init_latents = [
                retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
    finally:
        # Restore the VAE dtype even if encoding raised.
        if upcast:
            pipe.vae.to(origin_dtype)

    init_latents = init_latents.to(origin_dtype)
    init_latents = pipe.vae.config.scaling_factor * init_latents
    return init_latents.to(dtype=torch.float16)


def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
    """Select the timesteps to run from ``pipe.scheduler.timesteps``.

    If ``denoising_start`` is None, the first ``num_inference_steps - init_timestep``
    steps are skipped, where ``init_timestep`` is derived from ``strength``.
    Otherwise, only the timesteps below the cut-off implied by
    ``denoising_start`` are kept, and ``strength`` is ignored.

    ``device`` is accepted but not used.

    Returns:
        ``(timesteps, num_inference_steps)`` for the selected range.
    """
    # get the original timestep using init_timestep
    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
    else:
        t_start = 0

    timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    if denoising_start is not None:
        discrete_timestep_cutoff = int(
            round(
                pipe.scheduler.config.num_train_timesteps
                - (denoising_start * pipe.scheduler.config.num_train_timesteps)
            )
        )

        num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
        if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
            # For a 2nd-order scheduler every timestep except the highest is
            # duplicated. An even count would therefore cut between the 1st
            # and 2nd derivative steps, which gives incorrect results. Adding
            # 1 makes the process end after a 2nd-derivative step.
            num_inference_steps = num_inference_steps + 1

        # because t_n+1 >= t_n, we slice the timesteps starting from the end
        timesteps = timesteps[-num_inference_steps:]
        return timesteps, num_inference_steps

    return timesteps, num_inference_steps - t_start