|
from copy import deepcopy |
|
from dataclasses import dataclass |
|
from diffusers import StableDiffusionXLPipeline |
|
from diffusers.image_processor import PipelineImageInput |
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
    rescale_noise_cfg, retrieve_latents, retrieve_timesteps
)
|
from diffusers.utils import BaseOutput |
|
from diffusers.utils.torch_utils import randn_tensor |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
from utils.utils import batch_dict_to_tensor, batch_tensor_to_dict, noise_prev, noise_t2t |
|
from utils.sdxl import register_attr |
|
|
|
|
|
|
|
|
|
BATCH_ORDER = [ |
|
"structure_uncond", "appearance_uncond", "uncond", "structure_cond", "appearance_cond", "cond", |
|
] |
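# BATCH_ORDER lists the order in which the six sub-batches are concatenated into a single
# UNet forward pass. batch_dict_to_tensor / batch_tensor_to_dict rely on this ordering:
# unconditional batches first, conditional batches last.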
|
|
|
def get_last_control_i(control_schedule, num_inference_steps): |
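"""Return the step indices after which structure control and appearance control stop.

When `control_schedule` is None, both indices equal `num_inference_steps`. Otherwise each
index is the largest schedule fraction found for that control type across all blocks and
layers, scaled by `num_inference_steps` and rounded.
"""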
|
if control_schedule is None: |
|
return num_inference_steps, num_inference_steps |
|
|
|
def max_(l): |
|
if len(l) == 0: |
|
return 0.0 |
|
return max(l) |
|
|
|
structure_max = 0.0 |
|
appearance_max = 0.0 |
|
for block in control_schedule.values(): |
|
if isinstance(block, list): |
|
block = {0: block} |
|
for layer in block.values(): |
|
structure_max = max(structure_max, max_(layer[0] + layer[1])) |
|
appearance_max = max(appearance_max, max_(layer[2])) |
|
|
|
structure_i = round(num_inference_steps * structure_max) |
|
appearance_i = round(num_inference_steps * appearance_max) |
|
|
|
return structure_i, appearance_i |
|
|
|
@dataclass |
|
class CtrlXStableDiffusionXLPipelineOutput(BaseOutput): |
|
images: Union[List[Image.Image], np.ndarray] |
|
structures: Union[List[Image.Image], np.ndarray]
appearances: Union[List[Image.Image], np.ndarray]
|
|
|
class CtrlXStableDiffusionXLPipeline(StableDiffusionXLPipeline): |
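"""Ctrl-X pipeline for Stable Diffusion XL.

Jointly denoises an output branch, a structure branch, and an appearance branch, with
control between branches toggled via `register_attr` and consumed by U-Net modules patched
elsewhere in the repository.

Illustrative usage (a sketch; the checkpoint name, image preprocessing, and the exact
`control_schedule` format are assumptions, not defined in this file):

    pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    result = pipe(
        prompt="a photo of a cat",
        structure_image=structure_tensor,    # preprocessed image tensor
        appearance_image=appearance_tensor,  # preprocessed image tensor
        control_schedule=control_schedule,   # per-block control fractions
        num_inference_steps=50,
    )
    result.images[0].save("output.png")
"""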
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]] = None, |
|
structure_prompt: Optional[Union[str, List[str]]] = None, |
|
appearance_prompt: Optional[Union[str, List[str]]] = None, |
|
structure_image: Optional[PipelineImageInput] = None, |
|
appearance_image: Optional[PipelineImageInput] = None, |
|
num_inference_steps: int = 50, |
|
timesteps: List[int] = None, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
positive_prompt: Optional[Union[str, List[str]]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
guidance_scale: float = 5.0, |
|
structure_guidance_scale: Optional[float] = None, |
|
appearance_guidance_scale: Optional[float] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.Tensor] = None, |
|
structure_latents: Optional[torch.Tensor] = None, |
|
appearance_latents: Optional[torch.Tensor] = None, |
|
prompt_embeds: Optional[torch.Tensor] = None, |
|
structure_prompt_embeds: Optional[torch.Tensor] = None, |
|
appearance_prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_prompt_embeds: Optional[torch.Tensor] = None, |
|
pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
structure_pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
appearance_pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, |
|
control_schedule: Optional[Dict] = None, |
|
self_recurrence_schedule: Optional[List[int]] = None,
|
decode_structure: Optional[bool] = True, |
|
decode_appearance: Optional[bool] = True, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.0, |
|
original_size: Tuple[int, int] = None, |
|
crops_coords_top_left: Tuple[int, int] = (0, 0), |
|
target_size: Tuple[int, int] = None, |
|
clip_skip: Optional[int] = None, |
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
|
**kwargs, |
|
): |
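"""Generate an image whose structure follows `structure_image`/`structure_prompt` and
whose appearance follows `appearance_image`/`appearance_prompt`.

The three branches (output, structure, appearance) are denoised together by `cfg_loop`;
control is applied for the fraction of steps given by `control_schedule`, and
`self_recurrence_schedule` optionally adds extra denoising passes per step. Returns a
`CtrlXStableDiffusionXLPipelineOutput`, or a plain tuple when `return_dict` is False
(raw latents are returned when `output_type == "latent"`).
"""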
|
callback = kwargs.pop("callback", None) |
|
callback_steps = kwargs.pop("callback_steps", None) |
|
self._guidance_scale = guidance_scale
self._structure_guidance_scale = structure_guidance_scale
self._appearance_guidance_scale = appearance_guidance_scale
|
|
|
|
|
height = height or self.default_sample_size * self.vae_scale_factor |
|
width = width or self.default_sample_size * self.vae_scale_factor |
|
original_size = original_size or (height, width) |
|
target_size = target_size or (height, width) |
|
|
|
|
|
batch_size = 1 |
|
if isinstance(prompt, list): |
|
assert len(prompt) == batch_size |
|
if prompt_embeds is not None: |
|
assert prompt_embeds.shape[0] == batch_size |
|
|
|
device = self._execution_device |
|
|
|
|
|
text_encoder_lora_scale = ( |
|
cross_attention_kwargs.get("scale", None) |
|
if cross_attention_kwargs is not None else None |
|
) |
|
|
|
|
|
|
|
|
|
# One (prompt, prompt_embeds, negative_prompt, pooled_prompt_embeds) tuple per branch:
# output, structure, appearance. The structure/appearance branches drop the negative
# prompt when a reference image is given.
prompts = [
(prompt, prompt_embeds, negative_prompt, pooled_prompt_embeds),
(structure_prompt, structure_prompt_embeds, negative_prompt if structure_image is None else "", structure_pooled_prompt_embeds),
(appearance_prompt, appearance_prompt_embeds, negative_prompt if appearance_image is None else "", appearance_pooled_prompt_embeds),
]
|
prompt_embeds_list = [] |
|
add_text_embeds_list = [] |
|
for item in prompts: |
|
prompt_text, prompt_embeds_temp, negative_prompt_temp, pooled_prompt_embeds_temp = item[:4] |
|
|
|
if prompt_text is not None and prompt_text != "": |
|
( |
|
prompt_embeds_, |
|
negative_prompt_embeds, |
|
pooled_prompt_embeds_, |
|
negative_pooled_prompt_embeds, |
|
) = self.encode_prompt( |
|
prompt=prompt_text, |
|
prompt_2=None, |
|
device=device, |
|
num_images_per_prompt=num_images_per_prompt, |
|
do_classifier_free_guidance=True, |
|
negative_prompt=negative_prompt_temp, |
|
negative_prompt_2=None, |
|
prompt_embeds=prompt_embeds_temp, |
|
negative_prompt_embeds=None, |
|
pooled_prompt_embeds=pooled_prompt_embeds_temp, |
|
negative_pooled_prompt_embeds=None, |
|
lora_scale=text_encoder_lora_scale, |
|
clip_skip=clip_skip, |
|
) |
|
prompt_embeds_list.append(torch.cat([negative_prompt_embeds, prompt_embeds_], dim=0).to(device)) |
|
add_text_embeds_list.append(torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_], dim=0).to(device)) |
|
else: |
|
prompt_embeds_list.append(prompt_embeds_list[0]) |
|
add_text_embeds_list.append(add_text_embeds_list[0]) |
|
|
|
|
|
|
|
|
|
if self.text_encoder_2 is None: |
|
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) |
|
else: |
|
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim |
|
|
|
add_time_ids = self._get_add_time_ids( |
|
original_size, |
|
crops_coords_top_left, |
|
target_size, |
|
dtype=self.dtype, |
|
text_encoder_projection_dim=text_encoder_projection_dim, |
|
) |
|
negative_add_time_ids = add_time_ids |
|
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device) |
|
|
|
|
|
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
|
|
|
|
latents, _ = self.prepare_latents( |
|
None, batch_size, num_images_per_prompt, num_channels_latents, height, width, |
|
self.dtype, device, generator, latents |
|
) |
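# Prepare per-branch latents: the output branch uses fresh noise (or the provided
# `latents`); the structure/appearance branches reuse that noise unless their own latents
# are given, and any reference image is also encoded to "clean" latents so the branch can
# be re-noised from it during the denoising loop.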
|
latents_ = [structure_latents, appearance_latents] |
|
clean_latents_ = [] |
|
for image_index, image_ in enumerate([structure_image, appearance_image]): |
|
if image_ is not None: |
|
|
|
_, clean_latent = self.prepare_latents( |
|
image_, batch_size, num_images_per_prompt, num_channels_latents, height, width, |
|
self.dtype, device, generator, latents_[image_index] |
|
) |
|
clean_latents_.append(clean_latent) |
|
else: |
|
clean_latents_.append(None) |
|
if latents_[image_index] is None: |
|
latents_[image_index] = latents |
|
latents_ = [latents] + latents_ |
|
|
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) |
|
|
|
|
|
if hasattr(self, 'denoising_end') and self.denoising_end is not None and 0.0 < float(self.denoising_end) < 1.0: |
|
discrete_timestep_cutoff = int( |
|
round( |
|
self.scheduler.config.num_train_timesteps |
|
- (self.denoising_end * self.scheduler.config.num_train_timesteps) |
|
) |
|
) |
|
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) |
|
timesteps = timesteps[:num_inference_steps] |
|
|
|
|
|
timestep_cond = None |
|
assert self.unet.config.time_cond_proj_dim is None |
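# When a reference image is provided for a branch, that branch is kept on the image's
# trajectory during denoising, so its unconditional batch is unnecessary and dropped.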
|
|
|
|
|
batch_order = deepcopy(BATCH_ORDER) |
|
if structure_image is not None: |
|
batch_order.remove("structure_uncond") |
|
if appearance_image is not None: |
|
batch_order.remove("appearance_uncond") |
|
|
|
baked_latents = self.cfg_loop(batch_order, |
|
prompt_embeds_list, |
|
add_text_embeds_list, |
|
add_time_ids, |
|
latents_, |
|
clean_latents_, |
|
num_inference_steps, |
|
num_warmup_steps, |
|
extra_step_kwargs, |
|
timesteps, |
|
timestep_cond=timestep_cond, |
|
control_schedule=control_schedule, |
|
self_recurrence_schedule=self_recurrence_schedule, |
|
guidance_rescale=guidance_rescale, |
|
callback=callback, |
|
callback_steps=callback_steps, |
|
cross_attention_kwargs=cross_attention_kwargs) |
|
latents, structure_latents, appearance_latents = baked_latents |
|
|
|
|
|
self.refiner_args = {"latents": latents.detach(), "prompt": prompt, "negative_prompt": negative_prompt} |
|
|
|
if not output_type == "latent": |
|
|
|
if self.vae.config.force_upcast: |
|
self.upcast_vae() |
|
vae_dtype = next(iter(self.vae.post_quant_conv.parameters())).dtype |
|
latents = latents.to(vae_dtype) |
|
structure_latents = structure_latents.to(vae_dtype) |
|
appearance_latents = appearance_latents.to(vae_dtype) |
|
|
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
image = self.image_processor.postprocess(image, output_type=output_type) |
|
if decode_structure: |
|
structure = self.vae.decode(structure_latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
structure = self.image_processor.postprocess(structure, output_type=output_type) |
|
else: |
|
structure = structure_latents |
|
if decode_appearance: |
|
appearance = self.vae.decode(appearance_latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
appearance = self.image_processor.postprocess(appearance, output_type=output_type) |
|
else: |
|
appearance = appearance_latents |
|
|
|
|
|
if self.vae.config.force_upcast: |
|
self.vae.to(dtype=torch.float16) |
|
else: |
|
return CtrlXStableDiffusionXLPipelineOutput( |
|
images=latents, structures=structure_latents, appearances=appearance_latents |
|
) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return image, structure, appearance |
|
|
|
return CtrlXStableDiffusionXLPipelineOutput(images=image, structures=structure, appearances=appearance) |
|
|
|
def cfg_loop(self, |
|
batch_order, |
|
prompt_embeds_list, |
|
add_text_embeds_list, |
|
add_time_ids, |
|
latents_, |
|
clean_latents_, |
|
num_inference_steps, |
|
num_warmup_steps, |
|
extra_step_kwargs, |
|
timesteps, |
|
timestep_cond=None, |
|
control_schedule=None, |
|
self_recurrence_schedule=None, |
|
guidance_rescale=0.0, |
|
callback=None, |
|
callback_steps=None, |
|
callback_on_step_end=None, |
|
callback_on_step_end_tensor_inputs=None, |
|
cross_attention_kwargs=None): |
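"""Joint denoising loop over the structure, appearance, and output branches.

At each timestep the active sub-batches (see BATCH_ORDER) are concatenated into one UNet
call, split back into per-branch noise predictions, combined with classifier-free
guidance, and stepped by the scheduler. Branches with a reference image are re-noised
from their clean latents each step, and the optional self-recurrence schedule repeats
denoising of the output branch. Returns (latents, structure_latents, appearance_latents).
"""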
|
prompt_embeds, structure_prompt_embeds, appearance_prompt_embeds = prompt_embeds_list |
|
add_text_embeds, structure_add_text_embeds, appearance_add_text_embeds = add_text_embeds_list |
|
latents, structure_latents, appearance_latents = latents_ |
|
clean_structure_latents, clean_appearance_latents = clean_latents_ |
|
structure_control_stop_i, appearance_control_stop_i = get_last_control_i(control_schedule, num_inference_steps) |
|
|
|
if not self_recurrence_schedule:  # default: no self-recurrence at any step
|
self_recurrence_schedule = [0] * num_inference_steps |
|
|
|
self._num_timesteps = len(timesteps) |
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
if hasattr(self, 'interrupt') and self.interrupt: |
|
continue |
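# Past the last control step of a branch, drop its conditional batch as well if its
# unconditional batch is already absent (i.e. a reference image was provided), since
# that branch then needs no further UNet evaluations.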
|
|
|
if i == structure_control_stop_i: |
|
if "structure_uncond" not in batch_order: |
|
batch_order.remove("structure_cond") |
|
if i == appearance_control_stop_i: |
|
if "appearance_uncond" not in batch_order: |
|
batch_order.remove("appearance_cond") |
|
|
|
register_attr(self, t=t.item(), do_control=True, batch_order=batch_order) |
|
|
|
|
|
latent_model_input = self.scheduler.scale_model_input(latents, t) |
|
structure_latent_model_input = self.scheduler.scale_model_input(structure_latents, t) |
|
appearance_latent_model_input = self.scheduler.scale_model_input(appearance_latents, t) |
|
|
|
# Assemble the per-branch UNet inputs for the sub-batches that are still active.
|
all_latent_model_input = { |
|
"structure_uncond": structure_latent_model_input[0:1], |
|
"appearance_uncond": appearance_latent_model_input[0:1], |
|
"uncond": latent_model_input[0:1], |
|
"structure_cond": structure_latent_model_input[0:1], |
|
"appearance_cond": appearance_latent_model_input[0:1], |
|
"cond": latent_model_input[0:1], |
|
} |
|
all_prompt_embeds = { |
|
"structure_uncond": structure_prompt_embeds[0:1], |
|
"appearance_uncond": appearance_prompt_embeds[0:1], |
|
"uncond": prompt_embeds[0:1], |
|
"structure_cond": structure_prompt_embeds[1:2], |
|
"appearance_cond": appearance_prompt_embeds[1:2], |
|
"cond": prompt_embeds[1:2], |
|
} |
|
all_add_text_embeds = { |
|
"structure_uncond": structure_add_text_embeds[0:1], |
|
"appearance_uncond": appearance_add_text_embeds[0:1], |
|
"uncond": add_text_embeds[0:1], |
|
"structure_cond": structure_add_text_embeds[1:2], |
|
"appearance_cond": appearance_add_text_embeds[1:2], |
|
"cond": add_text_embeds[1:2], |
|
} |
|
all_time_ids = { |
|
"structure_uncond": add_time_ids[0:1], |
|
"appearance_uncond": add_time_ids[0:1], |
|
"uncond": add_time_ids[0:1], |
|
"structure_cond": add_time_ids[1:2], |
|
"appearance_cond": add_time_ids[1:2], |
|
"cond": add_time_ids[1:2], |
|
} |
|
|
|
concat_latent_model_input = batch_dict_to_tensor(all_latent_model_input, batch_order) |
|
concat_prompt_embeds = batch_dict_to_tensor(all_prompt_embeds, batch_order) |
|
concat_add_text_embeds = batch_dict_to_tensor(all_add_text_embeds, batch_order) |
|
concat_add_time_ids = batch_dict_to_tensor(all_time_ids, batch_order) |
|
|
|
|
|
added_cond_kwargs = {"text_embeds": concat_add_text_embeds, "time_ids": concat_add_time_ids} |
|
|
|
concat_noise_pred = self.unet( |
|
concat_latent_model_input, |
|
t, |
|
encoder_hidden_states=concat_prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
added_cond_kwargs=added_cond_kwargs, |
|
).sample |
|
all_noise_pred = batch_tensor_to_dict(concat_noise_pred, batch_order) |
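# Classifier-free guidance per branch: the output branch always uses guidance_scale;
# the structure/appearance branches apply their own scales only when their
# unconditional batches are present.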
|
|
|
|
|
noise_pred = all_noise_pred["uncond"] +\ |
|
self.guidance_scale * (all_noise_pred["cond"] - all_noise_pred["uncond"]) |
|
|
|
structure_noise_pred = all_noise_pred["structure_cond"]\ |
|
if "structure_cond" in batch_order else noise_pred |
|
if "structure_uncond" in all_noise_pred: |
|
structure_noise_pred = all_noise_pred["structure_uncond"] +\ |
|
self.structure_guidance_scale * (structure_noise_pred - all_noise_pred["structure_uncond"]) |
|
|
|
appearance_noise_pred = all_noise_pred["appearance_cond"]\ |
|
if "appearance_cond" in batch_order else noise_pred |
|
if "appearance_uncond" in all_noise_pred: |
|
appearance_noise_pred = all_noise_pred["appearance_uncond"] +\ |
|
self.appearance_guidance_scale * (appearance_noise_pred - all_noise_pred["appearance_uncond"]) |
|
|
|
if guidance_rescale > 0.0: |
|
noise_pred = rescale_noise_cfg( |
|
noise_pred, all_noise_pred["cond"], guidance_rescale=guidance_rescale |
|
) |
|
if "structure_uncond" in all_noise_pred: |
|
structure_noise_pred = rescale_noise_cfg( |
|
structure_noise_pred, all_noise_pred["structure_cond"], |
|
guidance_rescale=guidance_rescale |
|
) |
|
if "appearance_uncond" in all_noise_pred: |
|
appearance_noise_pred = rescale_noise_cfg( |
|
appearance_noise_pred, all_noise_pred["appearance_cond"], |
|
guidance_rescale=guidance_rescale |
|
) |
|
|
|
|
|
concat_noise_pred = torch.cat( |
|
[structure_noise_pred, appearance_noise_pred, noise_pred], dim=0, |
|
) |
|
concat_latents = torch.cat( |
|
[structure_latents, appearance_latents, latents], dim=0, |
|
) |
|
structure_latents, appearance_latents, latents = self.scheduler.step( |
|
concat_noise_pred, t, concat_latents, **extra_step_kwargs, |
|
).prev_sample.chunk(3) |
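# Branches with a reference image are reset to that image's trajectory by re-noising
# the clean latents to the previous timestep (noise_prev).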
|
|
|
if clean_structure_latents is not None: |
|
structure_latents = noise_prev(self.scheduler, t, clean_structure_latents) |
|
if clean_appearance_latents is not None: |
|
appearance_latents = noise_prev(self.scheduler, t, clean_appearance_latents) |
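# Self-recurrence: re-noise the output latents from t_prev back to t (noise_t2t) and
# denoise them once more without structure/appearance control, repeated as many times
# as the schedule specifies for this step.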
|
|
|
|
|
for _ in range(self_recurrence_schedule[i]): |
|
if hasattr(self.scheduler, "_step_index"): |
|
self.scheduler._step_index -= 1 |
|
|
|
t_prev = 0 if i + 1 >= num_inference_steps else timesteps[i + 1] |
|
latents = noise_t2t(self.scheduler, t_prev, t, latents) |
|
latent_model_input = torch.cat([latents] * 2) |
|
|
|
register_attr(self, t=t.item(), do_control=False, batch_order=["uncond", "cond"]) |
|
|
|
|
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} |
|
noise_pred_uncond, noise_pred_ = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
added_cond_kwargs=added_cond_kwargs, |
|
).sample.chunk(2) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_ - noise_pred_uncond) |
|
|
|
if guidance_rescale > 0.0: |
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_, guidance_rescale=guidance_rescale) |
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
|
|
|
|
|
assert callback_on_step_end is None |
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, t, latents) |
|
|
|
|
|
if clean_structure_latents is not None: |
|
structure_latents = clean_structure_latents |
|
if clean_appearance_latents is not None: |
|
appearance_latents = clean_appearance_latents |
|
|
|
return latents, structure_latents, appearance_latents |
|
|
|
@property |
|
def appearance_guidance_scale(self): |
|
return self._guidance_scale if self._appearance_guidance_scale is None else self._appearance_guidance_scale |
|
|
|
@property |
|
def structure_guidance_scale(self): |
|
return self._guidance_scale if self._structure_guidance_scale is None else self._structure_guidance_scale |
|
|
|
def prepare_latents(self, image, batch_size, num_images_per_prompt, num_channels_latents, height, width, |
|
dtype, device, generator=None, noise=None): |
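"""Return a (noise, init_latents) pair.

`noise` is the provided tensor moved to `device`, or freshly sampled noise scaled by the
scheduler's `init_noise_sigma`. `init_latents` is the VAE encoding of `image` (or `image`
itself if it is already a 4-channel latent), or None when no image is given.
"""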
|
batch_size = batch_size * num_images_per_prompt |
|
|
|
if noise is None: |
|
shape = ( |
|
batch_size, |
|
num_channels_latents, |
|
height // self.vae_scale_factor, |
|
width // self.vae_scale_factor |
|
) |
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
|
noise = noise * self.scheduler.init_noise_sigma |
|
else: |
|
noise = noise.to(device) |
|
|
|
if image is None: |
|
return noise, None |
|
|
|
if not isinstance(image, (torch.Tensor, Image.Image, list)): |
|
raise ValueError( |
|
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" |
|
) |
|
|
|
|
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
|
self.text_encoder_2.to("cpu") |
|
torch.cuda.empty_cache() |
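# Note: `image` is assumed to already be a preprocessed tensor at this point; PIL inputs
# would need to be converted (e.g. via the image processor) before reaching here.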
|
|
|
image = image.to(device=device, dtype=dtype) |
|
|
|
if image.shape[1] == 4: |
|
init_latents = image |
|
|
|
else: |
|
|
|
if self.vae.config.force_upcast: |
|
image = image.to(torch.float32) |
|
self.vae.to(torch.float32) |
|
|
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
elif isinstance(generator, list): |
|
init_latents = [ |
|
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) |
|
for i in range(batch_size) |
|
] |
|
init_latents = torch.cat(init_latents, dim=0) |
|
else: |
|
init_latents = retrieve_latents(self.vae.encode(image), generator=generator) |
|
|
|
if self.vae.config.force_upcast: |
|
self.vae.to(dtype) |
|
|
|
init_latents = init_latents.to(dtype) |
|
init_latents = self.vae.config.scaling_factor * init_latents |
|
|
|
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: |
|
|
|
additional_image_per_prompt = batch_size // init_latents.shape[0] |
|
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) |
|
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: |
|
raise ValueError( |
|
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." |
|
) |
|
else: |
|
init_latents = torch.cat([init_latents], dim=0) |
|
|
|
return noise, init_latents |
|
|