File size: 6,826 Bytes
6dd488f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# everything that can improve v-prediction model
# dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
# written by lvmin at stanford 2024

import torch
import numpy as np

from tqdm import tqdm
from functools import partial
from diffusers_vdm.basics import extract_into_tensor


# Shorthand for building float32 tensors (used for the schedule buffers of
# SamplerDynamicTSNR). Any other torch.tensor keyword arguments can still be
# passed at call time, and an explicit dtype= overrides the float32 default.
to_torch = partial(torch.tensor, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so that its final timestep has zero signal.

    ``sqrt(cumprod(1 - betas))`` is shifted so its last entry becomes 0, then
    scaled so its first entry keeps its original value. The result is converted
    back into per-step betas. The input array is not modified.

    Args:
        betas: numpy array of betas, with the timestep axis at axis 0.

    Returns:
        A new numpy array of rescaled betas. The last entry is exactly 1.0,
        i.e. the cumulative alpha product reaches 0 at the final step.
    """
    # Square root of the cumulative alpha product for every timestep.
    sqrt_ab = np.sqrt(np.cumprod(1.0 - betas, axis=0))
    head, tail = sqrt_ab[0], sqrt_ab[-1]

    # Shift so the tail lands on 0, then stretch so the head is restored.
    sqrt_ab = (sqrt_ab - tail) * (head / (head - tail))

    # Undo the square root, then undo the cumulative product.
    # The first alpha equals the first cumulative value.
    alphas_bar = sqrt_ab ** 2
    step_alphas = np.concatenate([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - step_alphas


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Pull a classifier-free-guidance prediction back toward the conditional spread.

    Each sample of ``noise_cfg`` is multiplied so that its standard deviation
    (over every axis except axis 0) matches that of ``noise_pred_text``. This
    counteracts the over-exposure that strong guidance causes. The fully
    rescaled tensor is then blended with the untouched ``noise_cfg`` to avoid
    flat-looking results.

    Args:
        noise_cfg: guided prediction tensor.
        noise_pred_text: conditional (positive-prompt) prediction tensor.
        guidance_rescale: blend weight; 0.0 returns ``noise_cfg`` unchanged,
            1.0 returns the fully rescaled prediction.

    Returns:
        Tensor of the same shape as ``noise_cfg``.
    """
    def per_sample_std(tensor):
        # Std over all axes but axis 0, kept as size-1 dims for broadcasting.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    fully_rescaled = noise_cfg * (per_sample_std(noise_pred_text) / per_sample_std(noise_cfg))
    return guidance_rescale * fully_rescaled + (1 - guidance_rescale) * noise_cfg


class SamplerDynamicTSNR(torch.nn.Module):
    """Training-target builder and sampler for a v-prediction UNet.

    It combines three pieces:

    * a scaled-linear beta schedule (0.00085 -> 0.012 over 1000 timesteps),
      rescaled with ``rescale_zero_terminal_snr`` so the last timestep has
      zero signal;
    * "dynamic TSNR" scaling: the clean latent is multiplied by
      ``scale_arr[t]``, which falls linearly from 1.0 to ``terminal_scale``
      over the first 400 timesteps and stays at ``terminal_scale`` afterwards.
      ``get_ground_truth`` applies this scaling and ``forward`` undoes it step
      by step while sampling;
    * classifier-free guidance followed by std rescaling (``model_apply``).
    """

    @torch.no_grad()
    def __init__(self, unet, terminal_scale=0.7):
        """Build the noise-schedule and dynamic-scaling buffers.

        Args:
            unet: denoiser called as ``unet(x, t, **kwargs)``. Its ``device``
                and ``dtype`` attributes are read to place buffers and the
                sampling noise.
            terminal_scale: value that ``scale_arr`` reaches at timestep 399
                and keeps for all later timesteps.
                NOTE(review): 0 would make the ``prev_scale_t / scale_t``
                ratio in ``forward`` divide by zero.
        """
        super().__init__()
        self.unet = unet

        self.is_v = True  # the UNet predicts v rather than eps
        self.n_timestep = 1000
        self.guidance_rescale = 0.7  # blend weight passed to rescale_noise_cfg

        linear_start = 0.00085
        linear_end = 0.012

        # Scaled-linear schedule: linear in sqrt(beta), then squared.
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
        betas = rescale_zero_terminal_snr(betas)
        alphas = 1. - betas

        alphas_cumprod = np.cumprod(alphas, axis=0)

        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))

        # Dynamic TSNR: per-timestep multiplier for the clean latent.
        # 1.0 at t=0, linearly down to terminal_scale at t=turning_step-1, then flat.
        turning_step = 400
        scale_arr = np.concatenate([
            np.linspace(1.0, terminal_scale, turning_step),
            np.full(self.n_timestep - turning_step, terminal_scale)
        ])
        self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Convert a v prediction at timestep ``t`` into an eps prediction.

        ``t`` indexes the schedule buffers directly, so it is expected to be a
        scalar index (``forward`` passes a Python int).
        """
        return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Convert a v prediction at scalar timestep ``t`` into an x0 prediction."""
        return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v

    def q_sample(self, x0, t, noise):
        """Noise ``x0`` to timestep ``t``: sqrt(ab_t) * x0 + sqrt(1 - ab_t) * noise.

        Per-sample coefficients are gathered with ``extract_into_tensor``
        (presumably ``t`` holds one timestep per batch element; confirm against
        that helper).
        """
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    def get_v(self, x0, t, noise):
        """Return the v target: sqrt(ab_t) * noise - sqrt(1 - ab_t) * x0."""
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)

    def dynamic_x0_rescale(self, x0, t):
        """Multiply the clean latent by the dynamic-TSNR factor ``scale_arr[t]``."""
        return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)

    @torch.no_grad()
    def get_ground_truth(self, x0, noise, t):
        """Build a training pair from clean latents.

        Returns:
            ``(xt, target)``: the noised, dynamically scaled latent, and either
            the v target (``is_v``) or the raw noise.
        """
        x0 = self.dynamic_x0_rescale(x0, t)
        xt = self.q_sample(x0, t, noise)
        target = self.get_v(x0, t, noise) if self.is_v else noise
        return xt, target

    def get_uniform_trailing_steps(self, steps):
        """Return ``steps`` evenly spaced, ascending timesteps ending at ``n_timestep - 1``.

        For example, ``steps=4`` gives ``[249, 499, 749, 999]``.

        Raises:
            ValueError: if ``steps`` is not positive.
        """
        if steps <= 0:
            raise ValueError(f"steps must be a positive integer, got {steps}")
        c = self.n_timestep / steps
        ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
        steps_out = ddim_timesteps - 1
        return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)

    @torch.no_grad()
    def forward(self, latent_shape, steps, extra_args, progress_tqdm=None, eta=1.0):
        """Sample a latent from pure Gaussian noise with a DDIM update rule.

        Args:
            latent_shape: shape of the initial noise. Axis 0 is the batch
                (one timestep value is fed to the UNet per element of it).
            steps: number of denoising steps (positive), spaced with
                ``get_uniform_trailing_steps``.
            extra_args: forwarded to ``model_apply``. Must contain
                ``'cfg_scale'``, ``'positive'`` and ``'negative'`` (the latter
                two are keyword-argument dicts for the UNet).
            progress_tqdm: optional tqdm-compatible wrapper for the step loop;
                defaults to ``tqdm``.
            eta: DDIM stochasticity. 1.0 (the previous hard-coded value) adds
                the full per-step noise; 0.0 removes it.

        Returns:
            The denoised latent with shape ``latent_shape``.
        """
        bar = tqdm if progress_tqdm is None else progress_tqdm

        timesteps = self.get_uniform_trailing_steps(steps)
        # timesteps_prev[i] is where step i lands; the final step (index 0) lands on timestep 0.
        timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))

        x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)

        alphas = self.alphas_cumprod[timesteps]
        alphas_prev = self.alphas_cumprod[timesteps_prev]
        scale_arr = self.scale_arr[timesteps]
        scale_arr_prev = self.scale_arr[timesteps_prev]

        sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        # DDIM noise level, computed in torch (float32) rather than mixing numpy arrays and tensors.
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        s_in = x.new_ones((x.shape[0]))
        s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
        for i in bar(range(len(timesteps))):
            # Walk the ascending timesteps from the noisiest to the cleanest.
            index = len(timesteps) - 1 - i
            t = timesteps[index].item()

            model_output = self.model_apply(x, t * s_in, **extra_args)

            if self.is_v:
                e_t = self.predict_eps_from_z_and_v(x, t, model_output)
            else:
                e_t = model_output

            a_prev = alphas_prev[index].item() * s_x
            sigma_t = sigmas[index].item() * s_x

            if self.is_v:
                pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
            else:
                # NOTE(review): a_t is 0 at the zero-SNR final timestep, so this branch divides by zero there.
                a_t = alphas[index].item() * s_x
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
                pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

            # Dynamic rescale: the prediction lives on the scale_arr[t] scale used in
            # training; move it to the scale of the target timestep.
            scale_t = scale_arr[index].item() * s_x
            prev_scale_t = scale_arr_prev[index].item() * s_x
            rescale = (prev_scale_t / scale_t)
            pred_x0 = pred_x0 * rescale

            # 1 - a_prev - sigma_t**2 is mathematically >= 0, and exactly 0 when alpha_t == 0
            # (the zero-terminal-SNR first step with eta=1). Clamp so rounding cannot make it
            # negative, which would turn the whole latent into NaN via sqrt.
            dir_xt = (1. - a_prev - sigma_t ** 2).clamp(min=0).sqrt() * e_t
            noise = sigma_t * torch.randn_like(x)
            x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x

    @torch.no_grad()
    def model_apply(self, x, t, **extra_args):
        """Run the UNet with classifier-free guidance and std rescaling.

        Calls the UNet once with ``extra_args['positive']`` and once with
        ``extra_args['negative']``. It combines them as
        ``n + cfg_scale * (p - n)`` and rescales the result toward ``p`` with
        ``rescale_noise_cfg`` using ``self.guidance_rescale``.
        """
        x = x.to(device=self.unet.device, dtype=self.unet.dtype)
        cfg_scale = extra_args['cfg_scale']
        p = self.unet(x, t, **extra_args['positive'])
        n = self.unet(x, t, **extra_args['negative'])
        o = n + cfg_scale * (p - n)
        o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
        return o_better