# SliderSpace / utils / train_util.py
# Last commit 4cbd4f2 by RohitGandikota: "adding utils for sliders"
from typing import Optional, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel, SchedulerMixin, FluxImg2ImgPipeline
from diffusers.image_processor import VaeImageProcessor
# from model_util import SDXL_TEXT_ENCODER_TYPE
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm
# Latent channel count of the Stable Diffusion UNet (fixed at 4; SDXL uses the same).
UNET_IN_CHANNELS: int = 4
# Spatial downsampling of the SD VAE: 2 ** (len(vae.config.block_out_channels) - 1) == 8.
VAE_SCALE_FACTOR: int = 8
# SDXL-only dimensions; get_add_time_ids checks that
# UNET_ATTENTION_TIME_EMBED_DIM * 6 + TEXT_ENCODER_2_PROJECTION_DIM
# equals UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM.
UNET_ATTENTION_TIME_EMBED_DIM: int = 256
TEXT_ENCODER_2_PROJECTION_DIM: int = 1280
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: int = 2816
def get_random_noise(
    batch_size: int, height: int, width: int, generator: torch.Generator = None
) -> torch.Tensor:
    """Sample standard-normal latent noise on the CPU.

    Args:
        batch_size: number of noise samples to draw.
        height: image height in pixels.
        width: image width in pixels.
        generator: optional RNG for reproducibility. It is passed to
            ``torch.randn`` together with ``device="cpu"``, so it should be a
            CPU generator.

    Returns:
        A tensor of shape ``(batch_size, UNET_IN_CHANNELS,
        height // VAE_SCALE_FACTOR, width // VAE_SCALE_FACTOR)``.
    """
    latent_shape = (
        batch_size,
        UNET_IN_CHANNELS,
        height // VAE_SCALE_FACTOR,  # latent rows (NCHW order: height first)
        width // VAE_SCALE_FACTOR,  # latent columns
    )
    return torch.randn(latent_shape, generator=generator, device="cpu")
# https://www.crosslabs.org/blog/diffusion-with-offset-noise
def apply_noise_offset(
    latents: torch.FloatTensor,
    noise_offset: float,
    generator: Optional[torch.Generator] = None,
):
    """Add per-sample, per-channel "offset noise" to ``latents``.

    One Gaussian value is drawn for every (sample, channel) pair. It is scaled by
    ``noise_offset`` and broadcast over all remaining (spatial) dimensions.

    Args:
        latents: tensor of shape ``(batch, channels, *spatial)`` (at least 2-D).
        noise_offset: multiplier applied to the offset noise.
        generator: optional RNG for reproducible offsets. It must live on
            ``latents.device``.

    Returns:
        A new tensor with the same shape, dtype and device as ``latents``.
    """
    # Use singleton trailing dims so the offset broadcasts over every spatial axis,
    # whatever the latent rank.
    offset_shape = latents.shape[:2] + (1,) * (latents.ndim - 2)
    offset = torch.randn(
        offset_shape,
        generator=generator,
        device=latents.device,
        # Match the input dtype: a float32 offset would silently upcast
        # fp16/bf16 latents through type promotion.
        dtype=latents.dtype,
    )
    return latents + noise_offset * offset
def get_initial_latents(
    scheduler: SchedulerMixin,
    n_imgs: int,
    height: int,
    width: int,
    n_prompts: int,
    generator=None,
) -> torch.Tensor:
    """Build the starting latents for sampling.

    Draws ``n_imgs`` noise samples with ``get_random_noise`` (on the CPU). It then
    tiles them ``n_prompts`` times along the batch axis, so every prompt starts
    from the same set of noises, and scales the result by
    ``scheduler.init_noise_sigma``.

    Returns:
        Tensor with ``n_prompts * n_imgs`` entries along dim 0.
    """
    base_noise = get_random_noise(n_imgs, height, width, generator=generator)
    tiled_noise = base_noise.repeat(n_prompts, 1, 1, 1)
    return tiled_noise * scheduler.init_noise_sigma
def text_tokenize(
    tokenizer: CLIPTokenizer,  # SD uses one tokenizer; SDXL calls this once per tokenizer
    prompts: list[str],
):
    """Tokenize ``prompts`` into a tensor of input ids.

    Every prompt is padded or truncated to ``tokenizer.model_max_length``.

    Returns:
        The ``input_ids`` PyTorch tensor produced by the tokenizer.
    """
    encoded = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
def text_encode(text_encoder: CLIPTextModel, tokens):
    """Run ``text_encoder`` on ``tokens`` and return the first element of its output.

    ``tokens`` is moved to ``text_encoder.device`` first. For a ``CLIPTextModel``
    the first output element is the last hidden state.
    """
    encoder_outputs = text_encoder(tokens.to(text_encoder.device))
    return encoder_outputs[0]
def encode_prompts(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompts: list[str],
):
    """Tokenize ``prompts`` and encode them with ``text_encoder``.

    Args:
        tokenizer: tokenizer used by ``text_tokenize`` (max-length padding).
        text_encoder: the text encoder model. It was previously mis-annotated
            as ``CLIPTokenizer``.
        prompts: prompts to encode.

    Returns:
        The first output of ``text_encoder`` (see ``text_encode``).
    """
    text_tokens = text_tokenize(tokenizer, prompts)
    return text_encode(text_encoder, text_tokens)
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
    text_encoder,
    tokens: torch.FloatTensor,
    num_images_per_prompt: int = 1,
):
    """Encode ``tokens`` the way the SDXL pipeline does.

    ``tokens`` is annotated as FloatTensor, but in practice it is presumably the
    input-id tensor from ``text_tokenize``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``:

        - ``prompt_embeds`` is the penultimate hidden layer, repeated
          ``num_images_per_prompt`` times per prompt, with shape
          ``(batch * num_images_per_prompt, seq_len, hidden)``.
        - ``pooled_prompt_embeds`` is the encoder's first output element and is
          NOT repeated here.
    """
    encoder_out = text_encoder(
        tokens.to(text_encoder.device), output_hidden_states=True
    )
    pooled = encoder_out[0]
    # SDXL conditions on the penultimate layer, not the final one.
    hidden = encoder_out.hidden_states[-2]
    batch, seq_len, _ = hidden.shape
    repeated = hidden.repeat(1, num_images_per_prompt, 1)
    repeated = repeated.view(batch * num_images_per_prompt, seq_len, -1)
    return repeated, pooled
def encode_prompts_xl(
    tokenizers,
    text_encoders,
    prompts: list[str],
    num_images_per_prompt: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode ``prompts`` with every (tokenizer, text_encoder) pair, SDXL style.

    Args:
        tokenizers: tokenizers, paired positionally with ``text_encoders``.
        text_encoders: text encoders. The pooled embedding of the LAST one is
            returned.
        prompts: prompts to encode.
        num_images_per_prompt: how many times each prompt's embeddings are
            repeated.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``:

        - ``prompt_embeds`` is the per-encoder penultimate hidden states,
          concatenated along the last (feature) axis.
        - ``pooled_prompt_embeds`` is the last encoder's pooled output, repeated
          ``num_images_per_prompt`` times.

    Raises:
        ValueError: if no (tokenizer, text_encoder) pair is given.
    """
    pairs = list(zip(tokenizers, text_encoders))
    if not pairs:
        # Fail fast with a clear message instead of an AttributeError on
        # ``None.shape`` further down.
        raise ValueError(
            "encode_prompts_xl needs at least one tokenizer/text_encoder pair"
        )

    text_embeds_list = []
    pooled_text_embeds = None
    for tokenizer, text_encoder in pairs:
        text_tokens_input_ids = text_tokenize(tokenizer, prompts)
        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, text_tokens_input_ids, num_images_per_prompt
        )
        text_embeds_list.append(text_embeds)

    # Only the last encoder's pooled output survives the loop
    # (for SDXL that is text_encoder_2).
    bs_embed = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
def concat_embeddings(
    unconditional: torch.FloatTensor,
    conditional: torch.FloatTensor,
    n_imgs: int,
):
    """Stack ``unconditional`` over ``conditional`` along dim 0, then repeat each row.

    Each row is repeated ``n_imgs`` times consecutively. For example, rows
    ``[u, c]`` with ``n_imgs=2`` become ``[u, u, c, c]``.
    """
    paired = torch.cat((unconditional, conditional), dim=0)
    return paired.repeat_interleave(n_imgs, dim=0)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
def predict_noise(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings, concatenated
    guidance_scale=7.5,
) -> torch.FloatTensor:
    """Predict the noise residual for ``latents`` at ``timestep``.

    The behaviour depends on ``guidance_scale``:

    * ``0``: a single forward pass on ``latents`` with ``text_embeddings`` as
      given; the raw prediction is returned.
    * ``1``: ``latents`` are duplicated to pair with ``[uncond, cond]``
      embeddings, and the stacked, uncombined prediction (twice the batch) is
      returned. Callers pick the half they need.
    * any other value: classifier-free guidance,
      ``uncond + guidance_scale * (cond - uncond)``.
    """
    # Duplicate the latents so one forward pass covers both uncond and cond.
    model_input = latents if guidance_scale == 0 else torch.cat([latents] * 2)
    model_input = scheduler.scale_model_input(model_input, timestep)

    prediction = unet(
        model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
    ).sample

    if guidance_scale == 0 or guidance_scale == 1:
        return prediction
    pred_uncond, pred_cond = prediction.chunk(2)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
@torch.no_grad()
def diffusion(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # starting latents (pure noise)
    text_embeddings: torch.FloatTensor,
    total_timesteps: int = 1000,
    start_timesteps=0,
    guidance_scale=1,
    composition=False,
    **kwargs,
):
    """Denoise ``latents`` over ``scheduler.timesteps[start_timesteps:total_timesteps]``.

    Args:
        unet: noise-prediction UNet.
        scheduler: scheduler providing ``timesteps`` and ``step``.
        latents: starting latents.
        text_embeddings: in the standard path, the concatenated
            ``[uncond, cond]`` embeddings. In composition mode it is sliced
            along dim 0, once per index.
        total_timesteps: end index (exclusive) into ``scheduler.timesteps``.
        start_timesteps: start index into ``scheduler.timesteps``.
        guidance_scale: the CFG scale in the standard path, or the weight applied
            to every per-index prediction in composition mode.
        composition: if True, sum the weighted predictions obtained from each
            ``text_embeddings[idx:idx + 1]``.
        **kwargs: accepted and ignored.

    Returns:
        The latents after the last scheduler step.

    Raises:
        ValueError: if ``composition`` is True and ``text_embeddings`` is empty.
    """
    if composition and text_embeddings.shape[0] == 0:
        raise ValueError("composition=True requires at least one text embedding")

    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        if not composition:
            noise_pred = predict_noise(
                unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale
            )
            if guidance_scale == 1:
                # At scale 1, predict_noise returns the stacked [uncond, cond]
                # batch; keep only the conditional half.
                _, noise_pred = noise_pred.chunk(2)
        else:
            noise_pred = None
            uncond = None
            for idx in range(text_embeddings.shape[0]):
                # NOTE(review): predict_noise doubles the latent batch, while each
                # slice here holds a single embedding row; confirm the intended
                # embedding layout for composition.
                pred = predict_noise(
                    unet, scheduler, timestep, latents, text_embeddings[idx:idx + 1], guidance_scale=1
                )
                # Split THIS index's prediction. Chunking ``noise_pred`` here would
                # hit an undefined name on the first step and a stale tensor later.
                uncond, pred = pred.chunk(2)
                weighted = guidance_scale * pred
                noise_pred = weighted if noise_pred is None else noise_pred + weighted
            # NOTE(review): the unconditional term is added once, after summing the
            # weighted predictions; confirm this matches the intended composition formula.
            noise_pred = noise_pred + uncond
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents
def rescale_noise_cfg(
    noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
):
    """Blend ``noise_cfg`` with a copy rescaled to match the std of ``noise_pred_text``.

    Follows Sec. 3.4 of "Common Diffusion Noise Schedules and Sample Steps are
    Flawed" (https://arxiv.org/pdf/2305.08891.pdf).

    Standard deviations are taken per sample, over every non-batch dimension.
    The result is
    ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.
    """
    sigma_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
    )
    sigma_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescaling counters over-exposure; mixing back in avoids "plain looking" images.
    rescaled = noise_cfg * (sigma_text / sigma_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def predict_noise_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    timestep: int,  # current timestep
    latents: torch.FloatTensor,
    text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings, concatenated
    add_text_embeddings: torch.FloatTensor,  # pooled text embeddings
    add_time_ids: torch.FloatTensor,
    guidance_scale=7.5,
    guidance_rescale=0.7,
) -> torch.FloatTensor:
    """SDXL counterpart of ``predict_noise``.

    It additionally passes ``add_text_embeddings`` and ``add_time_ids`` to the
    UNet as ``added_cond_kwargs``. The ``guidance_scale`` handling is the same as
    ``predict_noise``:

    * ``0``: a single pass with the raw output returned.
    * ``1``: the stacked ``[uncond, cond]`` output is returned.
    * any other value: classifier-free guidance.

    ``guidance_rescale`` is accepted but unused; no ``rescale_noise_cfg`` step is
    applied.
    """
    model_input = latents if guidance_scale == 0 else torch.cat([latents] * 2)
    model_input = scheduler.scale_model_input(model_input, timestep)

    prediction = unet(
        model_input,
        timestep,
        encoder_hidden_states=text_embeddings,
        added_cond_kwargs={
            "text_embeds": add_text_embeddings,
            "time_ids": add_time_ids,
        },
    ).sample

    if guidance_scale == 0 or guidance_scale == 1:
        return prediction
    pred_uncond, pred_cond = prediction.chunk(2)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
@torch.no_grad()
def diffusion_xl(
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    latents: torch.FloatTensor,  # starting latents (pure noise)
    text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
    add_text_embeddings: torch.FloatTensor,  # pooled text embeddings
    add_time_ids: torch.FloatTensor,
    guidance_scale: float = 1.0,
    total_timesteps: int = 1000,
    start_timesteps=0,
    composition=False,
):
    """SDXL denoising loop over ``scheduler.timesteps[start_timesteps:total_timesteps]``.

    ``text_embeddings`` is forwarded unchanged to ``predict_noise_xl`` as the
    UNet's ``encoder_hidden_states``. NOTE(review): it is annotated as a tuple but
    is presumably a single concatenated tensor; confirm.

    Returns:
        The latents after the last scheduler step.

    Raises:
        NotImplementedError: if ``composition`` is True. There is no SDXL
            composition path, and without this check the loop would call
            ``scheduler.step`` with an undefined noise prediction.
    """
    if composition:
        raise NotImplementedError("composition=True is not supported by diffusion_xl")

    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise_xl(
            unet,
            scheduler,
            timestep,
            latents,
            text_embeddings,
            add_text_embeddings,
            add_time_ids,
            guidance_scale=guidance_scale,
            guidance_rescale=0.7,
        )
        if guidance_scale == 1:
            # At scale 1 the prediction is the stacked [uncond, cond] batch;
            # keep the conditional half.
            _, noise_pred = noise_pred.chunk(2)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents
# for XL
def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
):
    """Build the SDXL ``add_time_ids`` conditioning tensor.

    The six values are ``original_size + crops_coords_top_left + target_size``,
    returned as a ``(1, 6)`` tensor of ``dtype``.

    Args:
        height: target image height.
        width: target image width.
        dynamic_crops: if True, pretend the target was cropped from a larger
            original. The original size is the target scaled by a random factor
            in ``[1, 3)``, and the crop's top-left corner is drawn uniformly from
            every offset that keeps the crop inside that original.
        dtype: dtype of the returned tensor.

    Raises:
        ValueError: if the module-level SDXL embedding constants are inconsistent
            with six time ids.
    """
    target_size = (height, width)
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # randint's upper bound is exclusive. The ``+ 1`` makes the bottom/right-most
        # offset reachable and keeps the range non-empty when int() truncation
        # leaves original_size equal to the target size (randint(0, 0) would raise).
        crops_coords_top_left = (
            torch.randint(0, original_size[0] - height + 1, (1,)).item(),
            torch.randint(0, original_size[1] - width + 1, (1,)).item(),
        )
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)

    # 6 entries: original size (2) + crop top-left (2) + target size (2)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # With the module constants: 256 * 6 + 1280 == 2816
    passed_add_embed_dim = (
        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)
        + TEXT_ENCODER_2_PROJECTION_DIM
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    return torch.tensor([add_time_ids], dtype=dtype)
def get_optimizer(name: str):
    """Resolve an optimizer class from a case-insensitive name.

    Supported names:

    * ``dadaptadam`` and ``dadaptlion``, from ``dadaptation``.
    * ``adam8bit`` and ``lion8bit``, from ``bitsandbytes`` (noted upstream as
      untested).
    * ``adam`` and ``adamw``, from ``torch.optim``.
    * ``lion``, from ``lion_pytorch``.
    * ``prodigy``, from ``prodigyopt``.

    Optional packages are imported lazily: ``dadaptation`` for any name starting
    with "dadapt", and ``bitsandbytes`` for any name ending with "8bit".

    Raises:
        ValueError: for an unrecognised name.
    """
    key = name.lower()

    if key.startswith("dadapt"):
        import dadaptation

        if key == "dadaptadam":
            return dadaptation.DAdaptAdam
        if key == "dadaptlion":
            return dadaptation.DAdaptLion
        raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")

    if key.endswith("8bit"):  # untested
        import bitsandbytes as bnb

        if key == "adam8bit":
            return bnb.optim.Adam8bit
        if key == "lion8bit":
            return bnb.optim.Lion8bit
        raise ValueError("8bit optimizer must be adam8bit or lion8bit")

    if key == "adam":
        return torch.optim.Adam
    if key == "adamw":
        return torch.optim.AdamW
    if key == "lion":
        from lion_pytorch import Lion

        return Lion
    if key == "prodigy":
        import prodigyopt

        return prodigyopt.Prodigy
    raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
@torch.no_grad()
def get_noisy_image(
    image,
    vae,
    unet,
    scheduler,
    timesteps_to=1000,
    generator=None,
    **kwargs,
):
    """Encode ``image`` with ``vae`` and noise it to ``scheduler.timesteps[timesteps_to]``.

    Args:
        image: anything ``VaeImageProcessor.preprocess`` accepts.
        vae: VAE used to encode the image. Its device and dtype are used for the
            computation.
        unet: unused; kept for call-site compatibility.
        scheduler: scheduler providing ``timesteps`` and ``add_noise``.
        timesteps_to: index into ``scheduler.timesteps`` selecting the noise level.
            If ``scheduler.timesteps`` has 1000 or fewer entries, the default
            gives an empty slice, so callers should pass an explicit index.
        generator: optional RNG used both for sampling the latent distribution
            and for the noise.
        **kwargs: accepted and ignored.

    Returns:
        ``(noisy_latents, noise)``.
    """
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
    device = vae.device
    image = image_processor.preprocess(image).to(device).to(vae.dtype)

    # AutoencoderKL.encode returns an output with ``latent_dist`` and no ``latents``
    # attribute; retrieve_latents handles both output styles.
    init_latents = retrieve_latents(vae.encode(image), generator=generator)
    init_latents = vae.config.scaling_factor * init_latents

    # Match the latents' dtype so add_noise does not promote half-precision
    # latents to float32.
    noise = randn_tensor(
        init_latents.shape, generator=generator, device=device, dtype=init_latents.dtype
    )
    timestep = scheduler.timesteps[timesteps_to:timesteps_to + 1]
    init_latents = scheduler.add_noise(init_latents, noise, timestep)
    return init_latents, noise
def get_lr_scheduler(
    name: Optional[str],
    optimizer: torch.optim.Optimizer,
    max_iterations: Optional[int],
    lr_min: Optional[float],
    **kwargs,
):
    """Build a ``torch.optim.lr_scheduler`` instance by name.

    Args:
        name: one of ``cosine``, ``cosine_with_restarts``, ``step``, ``constant``
            or ``linear``.
        optimizer: optimizer to schedule.
        max_iterations: total training iterations. It is used to derive T_max,
            T_0, step_size or total_iters.
        lr_min: minimum learning rate (cosine variants only).
        **kwargs: forwarded to the scheduler constructor.

    Raises:
        ValueError: for an unknown ``name``.
    """
    if name == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
        )
    elif name == "cosine_with_restarts":
        # T_0 must be a positive integer; clamp for runs shorter than 10 iterations.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max(1, max_iterations // 10), T_mult=2, eta_min=lr_min, **kwargs
        )
    elif name == "step":
        # step_size 0 would divide by zero on the first step(); clamp to 1.
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=max(1, max_iterations // 100), gamma=0.999, **kwargs
        )
    elif name == "constant":
        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
    elif name == "linear":
        # LinearLR has no ``factor`` argument; passing one raised TypeError.
        # The former factor=0.5 is mapped to start_factor (ramping back to the
        # base LR). NOTE(review): confirm whether a decay (1.0 -> 0.5) was intended.
        return torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.5,
            end_factor=1.0,
            total_iters=max(1, max_iterations // 100),
            **kwargs,
        )
    else:
        raise ValueError(
            "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
        )
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
    """Draw a random ``(height, width)`` for a resolution bucket.

    Each side is an independent multiple of 64 between ``bucket_resolution // 2``
    and ``bucket_resolution``. Both ends are inclusive, after rounding down to
    multiples of 64. Each side is never smaller than 64, so tiny buckets still
    yield a usable size.
    """
    step = 64
    min_step = max(1, (bucket_resolution // 2) // step)
    max_step = max(min_step, bucket_resolution // step)
    # torch.randint's upper bound is exclusive; +1 makes bucket_resolution reachable.
    height = torch.randint(min_step, max_step + 1, (1,)).item() * step
    width = torch.randint(min_step, max_step + 1, (1,)).item() * step
    return height, width
def _get_t5_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt,
    max_sequence_length=512,
    device=None,
    dtype=None
):
    """Encode ``prompt`` with a T5 text encoder, in Flux format.

    ``prompt`` may be a single string or a list of strings. Tokens are padded or
    truncated to ``max_sequence_length``. The encoder's first output is cast to
    ``dtype`` and moved to ``device``; both default to the encoder's own values.
    """
    target_device = device or text_encoder.device
    target_dtype = dtype or text_encoder.dtype
    prompts = [prompt] if isinstance(prompt, str) else prompt

    tokenized = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    hidden = text_encoder(tokenized.input_ids.to(target_device), output_hidden_states=False)[0]
    return hidden.to(dtype=target_dtype, device=target_device)
def _get_clip_prompt_embeds(
    text_encoder,
    tokenizer,
    prompt,
    device=None,
):
    """Return CLIP *pooled* prompt embeddings, in Flux format.

    ``prompt`` may be a single string or a list of strings. Tokens are padded or
    truncated to ``tokenizer.model_max_length``. The encoder's ``pooler_output``
    is cast to ``text_encoder.dtype`` and moved to ``device``, which defaults to
    the encoder's device.
    """
    target_device = device or text_encoder.device
    prompts = [prompt] if isinstance(prompt, str) else prompt

    tokenized = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    encoder_out = text_encoder(tokenized.input_ids.to(target_device), output_hidden_states=False)
    # Flux uses the pooled output, not the per-token hidden states.
    pooled = encoder_out.pooler_output
    return pooled.to(dtype=text_encoder.dtype, device=target_device)
@torch.no_grad()
def get_noisy_image_flux(
    image,
    vae,
    transformer,
    scheduler,
    timesteps_to=1000,
    generator=None,
    params=None
):
    """Encode ``image`` and noise it to ``timesteps_to``, returning packed Flux latents.

    Args:
        image (Union[PIL.Image.Image, torch.Tensor]): input image.
        vae (AutoencoderKL): Flux VAE model.
        transformer (FluxTransformer2DModel): Flux transformer. It supplies the
            latent channel count (``in_channels // 4``), the dtype and the device.
        scheduler (FlowMatchEulerDiscreteScheduler): used via ``scale_noise``.
        timesteps_to (int | float | torch.Tensor): timestep to noise to. Plain
            numbers are converted to a tensor. The value is repeated
            ``params['batchsize']`` times.
        generator (torch.Generator, optional): RNG for reproducibility.
        params (dict): must contain ``vae_scale_factor``, ``height``, ``width``
            and ``batchsize``.

    Returns:
        tuple: ``(noisy_latents, latent_image_ids)``, the noisy latents in packed
        Flux format and their positional ids.

    Raises:
        ValueError: if ``params`` is not provided.
    """
    if params is None:
        raise ValueError(
            "get_noisy_image_flux requires `params` with keys "
            "'vae_scale_factor', 'height', 'width' and 'batchsize'"
        )

    vae_scale_factor = params['vae_scale_factor']
    # Latents are later packed into 2x2 patches, so the image is processed with
    # twice the VAE scale factor.
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
    image = image_processor.preprocess(image, height=params['height'], width=params['width'])
    image = image.to(dtype=torch.float32)

    num_channels_latents = transformer.config.in_channels // 4
    # Accept plain numbers (including the default); tensors pass through unchanged.
    timestep = torch.as_tensor(timesteps_to)
    latents, latent_image_ids = prepare_latents_flux(
        image,
        timestep.repeat(params['batchsize']),
        params['batchsize'],
        num_channels_latents,
        params['height'],
        params['width'],
        transformer.dtype,
        transformer.device,
        generator,
        None,
        vae_scale_factor,
        vae,
        scheduler
    )
    return latents, latent_image_ids
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
"""
Pack latents into Flux's 2x2 patch format
"""
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _unpack_latents(latents, height, width, vae_scale_factor):
"""
Unpack latents from Flux's 2x2 patch format back to image space
"""
batch_size, num_patches, channels = latents.shape
# Account for VAE compression and packing
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents_flux(
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    vae_scale_factor=None,
    vae=None,
    scheduler=None
):
    """Produce noisy, packed Flux latents from ``image`` (img2img-style noising).

    The steps are:

    1. Encode ``image`` with ``vae``.
    2. Tile the encoded batch up to ``batch_size``.
    3. Mix it with fresh Gaussian noise via
       ``scheduler.scale_noise(image_latents, timestep, noise)``.
    4. Pack the result into 2x2 patches.

    If ``latents`` is given, it is only moved to ``device``/``dtype`` and returned,
    with no encoding or noising.

    ``height`` and ``width`` are in pixels. Each latent side is
    ``2 * (pixels // (2 * vae_scale_factor))``.

    Returns:
        ``(latents, latent_image_ids)``.

    Raises:
        ValueError: if a list of generators does not match ``batch_size``, or the
            encoded image batch cannot be tiled evenly to ``batch_size``.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # The VAE downsamples by vae_scale_factor, and packing needs an even latent
    # grid, so each side is rounded down to a multiple of 2.
    latent_height = 2 * (int(height) // (vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (vae_scale_factor * 2))
    noise_shape = (batch_size, num_channels_latents, latent_height, latent_width)
    latent_image_ids = _prepare_latent_image_ids(
        batch_size, latent_height // 2, latent_width // 2, device, dtype
    )

    if latents is not None:
        return latents.to(device=device, dtype=dtype), latent_image_ids

    image_latents = _encode_vae_image(
        vae=vae, image=image.to(device=device, dtype=dtype), generator=generator
    )
    n_encoded = image_latents.shape[0]
    if batch_size > n_encoded:
        if batch_size % n_encoded != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # Repeat the whole encoded batch (not each item) to reach batch_size.
        image_latents = torch.cat([image_latents] * (batch_size // n_encoded), dim=0)

    noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
    noisy = scheduler.scale_noise(image_latents, timestep, noise)
    packed = _pack_latents(noisy, batch_size, num_channels_latents, latent_height, latent_width)
    return packed, latent_image_ids
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
    """Encode ``image`` with ``vae`` and normalise the result.

    The normalisation is ``(latents - shift_factor) * scaling_factor``, using
    values from ``vae.config``. With a list of generators, each image is encoded
    separately, with ``generator[i]`` used for ``image[i]``.
    """
    if not isinstance(generator, list):
        raw = retrieve_latents(vae.encode(image), generator=generator)
    else:
        per_image = []
        for i in range(image.shape[0]):
            single = vae.encode(image[i : i + 1])
            per_image.append(retrieve_latents(single, generator=generator[i]))
        raw = torch.cat(per_image, dim=0)
    return (raw - vae.config.shift_factor) * vae.config.scaling_factor
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    Resolution order:

    1. If ``encoder_output`` has ``latent_dist`` and ``sample_mode`` is
       ``"sample"``, return a draw from it using ``generator``.
    2. If it has ``latent_dist`` and ``sample_mode`` is ``"argmax"``, return its
       mode.
    3. Otherwise, if it has ``latents``, return that attribute.

    Raises:
        AttributeError: if none of the above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")