Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,826 Bytes
6dd488f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
# everything that can improve v-prediction model
# dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
# written by lvmin at stanford 2024
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from diffusers_vdm.basics import extract_into_tensor
def to_torch(*args, **kwargs):
    """Build a ``torch.Tensor`` from the given data, defaulting to float32.

    Thin wrapper around ``torch.tensor``. All positional and keyword
    arguments are forwarded unchanged; ``dtype`` is set to
    ``torch.float32`` only when the caller did not supply one.
    """
    kwargs.setdefault('dtype', torch.float32)
    return torch.tensor(*args, **kwargs)
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so that the last timestep has zero SNR.

    The schedule is mapped to sqrt(alpha_bar), shifted so its final entry
    becomes exactly zero, rescaled so its first entry keeps its original
    value, and then mapped back to betas.

    Args:
        betas: 1-D NumPy array of per-step betas. It is not modified.

    Returns:
        A new array of betas with the same length. The last entry is 1
        because the last cumulative alpha is 0.
    """
    # betas -> alphas -> cumulative product -> sqrt(alpha_bar)
    sqrt_abar = np.sqrt(np.cumprod(1.0 - betas, axis=0))

    # Remember the end points before the schedule is altered.
    first = sqrt_abar[0].copy()
    last = sqrt_abar[-1].copy()

    # Shift the curve so the terminal value is zero ...
    sqrt_abar = sqrt_abar - last
    # ... then stretch it so the initial value is restored.
    sqrt_abar = sqrt_abar * (first / (first - last))

    # sqrt(alpha_bar) -> alpha_bar -> per-step alphas (undo the cumprod).
    abar = sqrt_abar ** 2
    step_alphas = abar[1:] / abar[:-1]
    step_alphas = np.concatenate([abar[0:1], step_alphas])

    return 1 - step_alphas
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Pull the guided prediction's per-sample std towards the text branch's.

    Args:
        noise_cfg: prediction after classifier-free guidance.
        noise_pred_text: prediction of the text-conditioned branch.
        guidance_rescale: blend factor. 0 returns ``noise_cfg`` unchanged in
            value; 1 returns the fully std-matched prediction.

    Returns:
        Tensor with the shape of ``noise_cfg``.
    """
    # Standard deviation over every axis except the leading one.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_axes, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_axes, keepdim=True)

    # Guidance inflates the std (overexposure); scale it back down.
    matched = noise_cfg * (text_std / cfg_std)

    # Blend with the raw guided result to avoid "plain looking" images.
    return guidance_rescale * matched + (1 - guidance_rescale) * noise_cfg
class SamplerDynamicTSNR(torch.nn.Module):
    """DDIM-style sampler with zero-terminal-SNR betas and dynamic x0 scaling.

    The wrapped ``unet`` is treated as a v-prediction model (``is_v`` is set
    to True). The class precomputes a 1000-step scaled-linear beta schedule,
    rescales it with ``rescale_zero_terminal_snr``, and keeps a per-timestep
    scale table that decays linearly from 1.0 to ``terminal_scale`` over the
    first 400 timesteps and stays constant afterwards.

    ``unet`` must expose ``.device`` and ``.dtype`` and be callable as
    ``unet(x, t, **kwargs)``.
    """

    @torch.no_grad()
    def __init__(self, unet, terminal_scale=0.7):
        """Precompute the noise schedule and the dynamic scale table.

        Args:
            unet: the denoising network; buffers are placed on ``unet.device``.
            terminal_scale: value the scale table reaches at the turning step
                and holds for all later timesteps.
        """
        super().__init__()
        self.unet = unet

        self.is_v = True  # model output is interpreted as v, not epsilon
        self.n_timestep = 1000  # length of the training schedule
        self.guidance_rescale = 0.7  # blend factor passed to rescale_noise_cfg

        linear_start = 0.00085
        linear_end = 0.012

        # Linear in sqrt(beta), then squared.
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
        betas = rescale_zero_terminal_snr(betas)
        alphas = 1. - betas

        alphas_cumprod = np.cumprod(alphas, axis=0)

        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))

        # Dynamic TSNR: linear ramp 1.0 -> terminal_scale, then flat.
        turning_step = 400
        scale_arr = np.concatenate([
            np.linspace(1.0, terminal_scale, turning_step),
            np.full(self.n_timestep - turning_step, terminal_scale)
        ])
        self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Return the epsilon estimate implied by ``x_t`` and a v-prediction.

        ``t`` indexes the schedule buffers directly; ``forward`` passes a
        Python int, so the coefficients are scalars.
        """
        return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Return the x0 estimate implied by ``x_t`` and a v-prediction.

        ``t`` indexes the schedule buffers directly (see
        ``predict_eps_from_z_and_v``).
        """
        return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v

    def q_sample(self, x0, t, noise):
        """Noise ``x0`` to timestep ``t``: sqrt(abar_t)*x0 + sqrt(1-abar_t)*noise.

        Coefficients are looked up with the project helper
        ``extract_into_tensor`` (presumably gathers per-sample values and
        reshapes them to broadcast against ``x0`` -- TODO confirm).
        """
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    def get_v(self, x0, t, noise):
        """Return the v target: sqrt(abar_t)*noise - sqrt(1-abar_t)*x0."""
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)

    def dynamic_x0_rescale(self, x0, t):
        """Multiply ``x0`` by the scale-table entry for timestep ``t``."""
        return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)

    @torch.no_grad()
    def get_ground_truth(self, x0, noise, t):
        """Build a (noisy input, target) pair for timestep ``t``.

        ``x0`` is first rescaled by the dynamic scale table; both the noisy
        sample and the v target are computed from the rescaled ``x0``. The
        target is ``noise`` itself when ``is_v`` is False.
        """
        x0 = self.dynamic_x0_rescale(x0, t)
        xt = self.q_sample(x0, t, noise)
        target = self.get_v(x0, t, noise) if self.is_v else noise
        return xt, target

    def get_uniform_trailing_steps(self, steps):
        """Return evenly spaced timesteps, ascending, ending at ``n_timestep - 1``.

        The spacing is ``n_timestep / steps``; positions are rounded to the
        nearest integer. The result is a ``torch.long`` tensor on
        ``unet.device``.
        """
        c = self.n_timestep / steps
        ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
        steps_out = ddim_timesteps - 1  # convert 1-based positions to 0-based indices
        return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)

    @torch.no_grad()
    def forward(self, latent_shape, steps, extra_args, progress_tqdm=None, eta=1.0):
        """Sample a latent of shape ``latent_shape`` in ``steps`` denoising steps.

        Args:
            latent_shape: shape of the latent to draw; the first axis is
                treated as the batch axis.
            steps: number of sampling steps.
            extra_args: dict forwarded to ``model_apply``; must provide
                ``'cfg_scale'``, ``'positive'`` and ``'negative'``.
            progress_tqdm: optional tqdm-compatible wrapper for the step
                iterator; ``tqdm`` is used when None.
            eta: multiplier on the per-step noise level. The default 1.0
                matches the previously hard-coded value; 0.0 injects no noise.

        Returns:
            The final latent, on ``unet.device`` with ``unet.dtype``.
        """
        bar = tqdm if progress_tqdm is None else progress_tqdm

        timesteps = self.get_uniform_trailing_steps(steps)
        # Previous timestep for each entry; the first one is padded with 0.
        timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))

        x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)

        alphas = self.alphas_cumprod[timesteps]
        alphas_prev = self.alphas_cumprod[timesteps_prev]
        scale_arr = self.scale_arr[timesteps]
        scale_arr_prev = self.scale_arr[timesteps_prev]

        sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        # Same formula as before, but kept in torch on the buffers' device
        # instead of mixing NumPy arrays with CPU tensors.
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        s_in = x.new_ones((x.shape[0]))  # per-sample vector for the timestep input
        s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))  # broadcastable ones

        for i in bar(range(len(timesteps))):
            # Walk the schedule from the highest timestep down to the lowest.
            index = len(timesteps) - 1 - i
            t = timesteps[index].item()

            model_output = self.model_apply(x, t * s_in, **extra_args)

            a_prev = alphas_prev[index].item() * s_x
            sigma_t = sigmas[index].item() * s_x

            if self.is_v:
                e_t = self.predict_eps_from_z_and_v(x, t, model_output)
                pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
            else:
                e_t = model_output
                a_t = alphas[index].item() * s_x
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
                pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

            # Dynamic rescale: move pred_x0 from this step's scale to the
            # previous (less noisy) step's scale.
            scale_t = scale_arr[index].item() * s_x
            prev_scale_t = scale_arr_prev[index].item() * s_x
            rescale = (prev_scale_t / scale_t)
            pred_x0 = pred_x0 * rescale

            # The radicand is mathematically >= 0 for eta <= 1 but is ~0 at
            # the zero-SNR terminal step, where rounding (notably in fp16) can
            # make it slightly negative; clamp so sqrt() cannot return NaN.
            dir_xt = torch.clamp(1. - a_prev - sigma_t ** 2, min=0.0).sqrt() * e_t
            noise = sigma_t * torch.randn_like(x)
            x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x

    @torch.no_grad()
    def model_apply(self, x, t, **extra_args):
        """Run the UNet with classifier-free guidance and std rescaling.

        Calls ``unet`` once with ``extra_args['positive']`` and once with
        ``extra_args['negative']`` as keyword arguments, combines them as
        ``n + cfg_scale * (p - n)``, and passes the result through
        ``rescale_noise_cfg`` with ``self.guidance_rescale``.
        """
        x = x.to(device=self.unet.device, dtype=self.unet.dtype)
        cfg_scale = extra_args['cfg_scale']
        p = self.unet(x, t, **extra_args['positive'])
        n = self.unet(x, t, **extra_args['negative'])
        o = n + cfg_scale * (p - n)
        o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
        return o_better
|