|
import torch |
|
from typing import Literal, Optional |
|
|
|
from toolkit.basic import value_map |
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO |
|
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds |
|
from toolkit.stable_diffusion_model import StableDiffusion |
|
from toolkit.train_tools import get_torch_dtype |
|
|
|
# Values accepted in ``dataset_config.guidance_type``; each one selects a loss
# implementation in ``get_guidance_loss`` (which also dispatches on "tnt").
GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct", "tnt"]

# Fraction of the latent difference between the conditional and unconditional
# images that ``get_targeted_polarity_loss`` mixes into the training noise.
DIFFERENTIAL_SCALER = 0.2
|
|
|
|
|
|
|
|
|
|
|
def get_differential_mask(
        conditional_latents: torch.Tensor,
        unconditional_latents: torch.Tensor,
        threshold: float = 0.2,
        gradient: bool = False,
):
    """Build a mask marking where two latent tensors differ.

    The absolute difference is normalised per sample (dim 0) by its maximum over
    every remaining dimension, so each sample's values lie in ``[0, 1]``. A sample
    whose inputs are identical normalises to all zeros instead of NaN.

    Args:
        conditional_latents: Latents of the conditional image. Dim 0 is the batch
            and at least one further dim is expected.
        unconditional_latents: Latents of the unconditional image, same shape.
        threshold: Binary mode: normalised differences below ``threshold`` become
            0, the rest 1. Gradient mode: the output range handed to ``value_map``
            is widened to ``[-threshold, 1 + threshold]`` before clamping.
        gradient: If True return a soft mask, otherwise a binary 0/1 mask.

    Returns:
        A tensor with the same shape as the inputs and values in ``[0, 1]``.
    """
    differential_mask = torch.abs(conditional_latents - unconditional_latents)

    # Per-sample maximum over all non-batch dims (the same as reducing dims 1, 2, 3
    # for 4D latents).
    reduce_dims = tuple(range(1, differential_mask.dim()))
    max_differential = differential_mask.amax(dim=reduce_dims, keepdim=True)

    # Identical inputs give a zero maximum. Dividing by it would produce NaN,
    # which the binary branch would turn into an all-ones mask. Divide by 1
    # instead, so those samples stay 0.
    safe_max = torch.where(
        max_differential > 0,
        max_differential,
        torch.ones_like(max_differential),
    )
    differential_mask = differential_mask / safe_max

    if gradient:
        # Assuming value_map linearly rescales [min, max] onto
        # [-threshold, 1 + threshold]: after clamping, values near either end
        # saturate to exactly 0 or 1.
        # NOTE(review): min/max here are over the whole batch, not per sample.
        # If they are equal, the result depends on value_map's handling of a
        # zero-width range — confirm.
        differential_mask = value_map(
            differential_mask,
            differential_mask.min(),
            differential_mask.max(),
            -threshold,
            1 + threshold,
        )
        differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
    else:
        # Binary mask: 0 below the threshold, 1 at or above it.
        differential_mask = torch.where(
            differential_mask < threshold,
            torch.zeros_like(differential_mask),
            torch.ones_like(differential_mask),
        )

    return differential_mask
|
|
|
|
|
def get_targeted_polarity_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Targeted polarity loss: train on noise nudged towards the other image of the pair.

    Let ``d = conditional_latents - unconditional_latents`` and
    ``s = DIFFERENTIAL_SCALER``.

    * The conditional latents are noised with ``noise - s * d``, and their
      prediction is trained towards that same noise.
    * The unconditional latents are noised with ``noise + s * d``, and trained
      likewise.

    Both halves use ``conditional_embeds`` and run in a single ``predict_noise``
    call. The loss is the sum of the two per-element MSEs, averaged over dims
    1-3 and then over the batch.

    ``backward()`` is called here. The returned tensor is a detached copy with
    ``requires_grad=True``; presumably this lets the caller call backward on it
    without error — confirm against the caller.

    ``noisy_latents``, ``match_adapter_assist`` and ``network_weight_list`` are
    accepted for signature compatibility but unused.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # Scaled pair differences. Each branch's noise is shifted by the
        # difference pointing towards the *other* image.
        conditional_diff_noise = ((conditional_latents - unconditional_latents) * DIFFERENTIAL_SCALER).detach()
        unconditional_diff_noise = ((unconditional_latents - conditional_latents) * DIFFERENTIAL_SCALER).detach()

        conditional_noise = noise + unconditional_diff_noise
        unconditional_noise = noise + conditional_diff_noise

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            conditional_noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            unconditional_noise,
            timesteps
        ).detach()

        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    sd.unet.train()

    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

    pred_loss = torch.nn.functional.mse_loss(
        pred_pos.float(),
        conditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    pred_neg_loss = torch.nn.functional.mse_loss(
        pred_neg.float(),
        unconditional_noise.float(),
        reduction="none"
    ).mean([1, 2, 3])

    loss = (pred_loss + pred_neg_loss).mean()
    loss.backward()

    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
|
|
|
def get_direct_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional[PromptEmbeds] = None,
        mask_multiplier=None,
        prior_pred=None,
        **kwargs
):
    """Direct guidance loss: train a CFG-style mix of two predictions towards ``noise``.

    The unconditional and conditional latents are both noised with the same
    ``noise``. They are batched as ``[unconditional, conditional]``, and each
    half is paired with ``conditional_embeds``.

    The two predictions are combined as ``uncond + 1.1 * (cond - uncond)`` and
    compared to ``noise`` with per-element MSE. ``mask_multiplier``, when given,
    multiplies the per-element loss before averaging over dims 1-3 and then the
    batch.

    If ``unconditional_embeds`` is given, it is moved, detached, duplicated for
    both halves, and forwarded to ``predict_noise`` as
    ``unconditional_embeddings``.

    ``backward()`` is called here. A detached copy with ``requires_grad=True``
    is returned.

    ``noisy_latents``, ``match_adapter_assist``, ``network_weight_list`` and
    ``prior_pred`` are unused.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        clean_pair = [
            batch.latents.to(device, dtype=dtype).detach(),
            batch.unconditional_latents.to(device, dtype=dtype).detach(),
        ]
        # Conditional first, then unconditional, matching the add_noise call order.
        noised_cond, noised_uncond = (
            sd.add_noise(clean, noise, timesteps).detach() for clean in clean_pair
        )

    sd.unet.train()

    if unconditional_embeds is not None:
        detached_uncond = unconditional_embeds.to(device, dtype=dtype).detach()
        unconditional_embeds = concat_prompt_embeds([detached_uncond] * 2)

    batched_latents = torch.cat([noised_uncond, noised_cond]).to(device, dtype=dtype).detach()
    batched_embeds = concat_prompt_embeds([conditional_embeds] * 2).to(device, dtype=dtype).detach()
    batched_timesteps = torch.cat([timesteps] * 2)

    prediction = sd.predict_noise(
        latents=batched_latents,
        conditional_embeddings=batched_embeds,
        unconditional_embeddings=unconditional_embeds,
        timestep=batched_timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    pred_uncond, pred_cond = prediction.chunk(2, dim=0)

    # Extrapolate 10% past the conditional prediction.
    cfg_scale = 1.1
    guided = pred_uncond + cfg_scale * (pred_cond - pred_uncond)

    per_element = torch.nn.functional.mse_loss(
        guided.float(),
        noise.detach().float(),
        reduction="none"
    )
    if mask_multiplier is not None:
        per_element = per_element * mask_multiplier

    loss = per_element.mean([1, 2, 3]).mean()
    loss.backward()

    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
|
|
|
|
|
|
|
def get_targeted_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Targeted guidance loss: stay near the base model's prior while separating halves.

    Steps:
      1. Prior pass. With the network disabled and the unet in eval mode,
         predict noise for the noised unconditional latents and record its
         per-element MSE against ``noise``.
      2. Training pass. With the network enabled, predict for
         ``[conditional, unconditional]`` noised latents in one batch. The second
         half's multipliers are each network weight plus -1.0 (a weight of 1.0
         becomes 0.0).
      3. Compute the loss as
         ``mean(|cond_mse - prior_mse| * differential_scaler) + mean(|cond_mse - uncond_mse|)``.
         ``differential_scaler`` is ``|unconditional - conditional|``, rescaled
         by ``value_map`` from each sample's ``[min, max]`` to ``[1.0, 2.0]``.

    ``backward()`` is called here. A detached copy with ``requires_grad=True``
    is returned.

    ``sd.network.is_active`` (True), ``sd.network.multiplier``
    (``network_weight_list``) and unet train mode are restored even if a
    prediction raises.

    ``noisy_latents`` and ``match_adapter_assist`` are unused.
    """
    with torch.no_grad():
        dtype = get_torch_dtype(sd.torch_dtype)
        device = sd.device_torch

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
            unconditional_latents,
            noise,
            timesteps
        )
        conditional_noisy_latents = sd.noise_scheduler.add_noise(
            conditional_latents,
            noise,
            timesteps
        )

    try:
        with torch.no_grad():
            # Prior pass runs on the base model only.
            sd.network.is_active = False
            sd.unet.eval()

            target_differential_abs = (unconditional_latents - conditional_latents).abs()

            # Per-sample extrema over dims 1-3. The minimum previously mixed
            # min(dim=1) with max(dims 2, 3), which is not the true minimum.
            # Assuming value_map is a linear rescale, that let scalers fall
            # below min_guidance.
            target_differential_abs_min = target_differential_abs.amin(dim=(1, 2, 3), keepdim=True)
            target_differential_abs_max = target_differential_abs.amax(dim=(1, 2, 3), keepdim=True)

            min_guidance = 1.0
            max_guidance = 2.0

            # NOTE(review): if min == max for a sample, the result depends on how
            # value_map handles a zero-width input range — confirm.
            differential_scaler = value_map(
                target_differential_abs,
                target_differential_abs_min,
                target_differential_abs_max,
                min_guidance,
                max_guidance
            ).detach()

            target_unconditional = sd.predict_noise(
                latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
                conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
                timestep=timesteps,
                guidance_scale=1.0,
                **pred_kwargs
            ).detach()

            prior_prediction_loss = torch.nn.functional.mse_loss(
                target_unconditional.float(),
                noise.float(),
                reduction="none"
            ).detach().clone()

        sd.unet.train()
        sd.network.is_active = True
        # First half: the caller's weights. Second half: each weight shifted by -1.0.
        sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]

        prediction = sd.predict_noise(
            latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
            conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
            timestep=torch.cat([timesteps, timesteps], dim=0),
            guidance_scale=1.0,
            **pred_kwargs
        )

        prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)

        conditional_loss = torch.nn.functional.mse_loss(
            prediction_conditional.float(),
            noise.float(),
            reduction="none"
        )
        unconditional_loss = torch.nn.functional.mse_loss(
            prediction_unconditional.float(),
            noise.float(),
            reduction="none"
        )

        # Keep the conditional error close to the prior's, weighted more heavily
        # where the pair differs more.
        positive_loss = torch.abs(conditional_loss.float() - prior_prediction_loss.float())
        positive_loss = (positive_loss * differential_scaler).mean([1, 2, 3])

        polar_loss = torch.abs(
            conditional_loss.float() - unconditional_loss.float(),
        ).mean([1, 2, 3])

        positive_loss = positive_loss.mean() + polar_loss.mean()
        positive_loss.backward()

        loss = positive_loss.detach()
        loss.requires_grad_(True)
    finally:
        # Leave the model in the same state as the success path, even on error.
        sd.unet.train()
        sd.network.is_active = True
        sd.network.multiplier = network_weight_list

    return loss
|
|
|
def get_guided_loss_polarity(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        scaler=None,
        **kwargs
):
    """Polarity guidance loss: +weight on conditional latents, -weight on unconditional.

    The noised conditional and unconditional latents are predicted in one batch,
    both with ``conditional_embeds``. The network multipliers are
    ``[+w...] + [-w...]``.

    Targets are ``noise`` or, when ``sd.is_flow_matching``, ``noise - latents``
    of the matching half. The loss is the sum of both per-element MSEs, averaged
    over dims 1-3 and then over the batch.

    Args:
        scaler: Optional object with a ``.scale()`` method (presumably an AMP
            gradient scaler — confirm). When given, backward runs on
            ``scaler.scale(loss)``.

    Returns:
        Detached loss with ``requires_grad=True``; ``backward()`` has already run.

    ``sd.network.multiplier`` is restored to ``network_weight_list`` afterwards,
    even on error, as ``get_targeted_guidance_loss`` does.

    ``noisy_latents`` and ``match_adapter_assist`` are unused.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        target_pos = noise
        target_neg = noise
        if sd.is_flow_matching:
            # Note: this reconfigures the shared scheduler on every call.
            sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
            target_pos = (noise - conditional_latents).detach()
            target_neg = (noise - unconditional_latents).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
        cat_network_weight_list = positive_network_weights + negative_network_weights

    sd.unet.train()
    sd.network.is_active = True

    try:
        sd.network.multiplier = cat_network_weight_list

        prediction = sd.predict_noise(
            latents=cat_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
            timestep=cat_timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        )

        pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)

        pred_loss = torch.nn.functional.mse_loss(
            pred_pos.float(),
            target_pos.float(),
            reduction="none"
        )
        pred_neg_loss = torch.nn.functional.mse_loss(
            pred_neg.float(),
            target_neg.float(),
            reduction="none"
        )

        loss = (pred_loss + pred_neg_loss).mean([1, 2, 3]).mean()

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
    finally:
        # Do not leave the doubled +/- multiplier list on the shared network.
        sd.network.multiplier = network_weight_list

    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
|
|
|
|
|
|
|
def get_guided_tnt(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        prior_pred: torch.Tensor = None,
        **kwargs
):
    """TNT guidance loss: fit the conditional half, push the unconditional half away.

    The noised ``[conditional, unconditional]`` latents are predicted in one
    batch, both with ``conditional_embeds`` and the same network weights.

    * ``this_loss`` is the per-sample MSE of the conditional half against
      ``noise``.
    * ``that_loss`` is the *negated* per-sample MSE of the unconditional half.
      It is multiplied by ``|this_loss| / |that_loss|`` (computed without
      gradient) and by 0.01, so its magnitude is 1% of ``this_loss``.

    A sample whose unconditional MSE is exactly 0 gets a scale of 0 instead of
    inf/NaN.

    ``backward()`` is called here. A detached copy with ``requires_grad=True``
    is returned. When ``sd.network`` exists, its multiplier is restored to
    ``network_weight_list`` afterwards, even on error.

    ``noisy_latents``, ``match_adapter_assist`` and ``prior_pred`` are unused.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    sd.unet.train()
    network = sd.network
    if network is not None:
        # Same weights for both halves of the batch.
        network.multiplier = network_weight_list * 2
        network.is_active = True

    try:
        prediction = sd.predict_noise(
            latents=cat_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
            timestep=cat_timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        )
        this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

        this_loss = torch.nn.functional.mse_loss(
            this_prediction.float(),
            noise.float(),
            reduction="none"
        ).mean([1, 2, 3])

        that_loss = -torch.nn.functional.mse_loss(
            that_prediction.float(),
            noise.float(),
            reduction="none"
        ).mean([1, 2, 3])

        with torch.no_grad():
            this_abs = torch.abs(this_loss)
            that_abs = torch.abs(that_loss)
            # Match the push-away term's magnitude to this_loss. Guard against a
            # zero denominator, which previously yielded inf/NaN.
            that_loss_scaler = torch.where(
                that_abs > 0,
                this_abs / that_abs,
                torch.zeros_like(that_abs),
            )

        that_loss = that_loss * that_loss_scaler * 0.01

        loss = (this_loss + that_loss).mean()
        loss.backward()
    finally:
        if network is not None:
            network.multiplier = network_weight_list

    loss = loss.detach()
    loss.requires_grad_(True)
    return loss
|
|
|
|
|
|
|
|
|
def get_guidance_loss(
        noisy_latents: torch.Tensor,
        conditional_embeds: 'PromptEmbeds',
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        unconditional_embeds: Optional['PromptEmbeds'] = None,
        mask_multiplier=None,
        prior_pred=None,
        scaler=None,
        **kwargs
):
    """Dispatch to the guidance loss selected by the batch's dataset config.

    The guidance type is read from
    ``batch.file_items[0].dataset_config.guidance_type``:

    * ``"targeted"``: ``get_targeted_guidance_loss``
    * ``"polarity"``: ``get_guided_loss_polarity`` (receives ``scaler``)
    * ``"tnt"``: ``get_guided_tnt`` (receives ``prior_pred``)
    * ``"targeted_polarity"``: ``get_targeted_polarity_loss``
    * ``"direct"``: ``get_direct_guidance_loss`` (receives
      ``unconditional_embeds``, ``mask_multiplier`` and ``prior_pred``)

    Returns:
        Whatever the selected implementation returns.

    Raises:
        AssertionError: ``unconditional_embeds`` is given for a type other than
            ``"direct"``.
        NotImplementedError: The guidance type is unknown.
    """
    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type

    # Positional arguments shared by every implementation.
    common_args = (
        noisy_latents,
        conditional_embeds,
        match_adapter_assist,
        network_weight_list,
        timesteps,
        pred_kwargs,
        batch,
        noise,
        sd,
    )

    if guidance_type == "targeted":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
        return get_targeted_guidance_loss(*common_args, **kwargs)
    elif guidance_type == "polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
        return get_guided_loss_polarity(*common_args, scaler=scaler, **kwargs)
    elif guidance_type == "tnt":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(*common_args, prior_pred=prior_pred, **kwargs)
    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
        return get_targeted_polarity_loss(*common_args, **kwargs)
    elif guidance_type == "direct":
        return get_direct_guidance_loss(
            *common_args,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
|
|