File size: 4,280 Bytes
d661b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from pathlib import Path
import json
from math import sqrt
import numpy as np
import torch
from abc import ABCMeta, abstractmethod


class ScoreAdapter(metaclass=ABCMeta):
    """Abstract interface between a pretrained denoiser and the samplers in this module.

    Subclasses supply the denoiser (``denoise``), the noise range it supports
    (``σ_max`` / ``σ_min``), the per-sample data shape, conditioning hooks, and
    a ``_device`` attribute; ``score`` and the remaining helpers build on those.
    """

    @abstractmethod
    def denoise(self, xs, σ, **kwargs):
        """Return the model's denoised estimate of ``xs`` at noise level ``σ``.

        Extra keyword arguments carry conditioning inputs (see ``cond_info``).
        """
        pass

    def score(self, xs, σ, **kwargs):
        """Estimate the score ∇ₓ log p(xs; σ) from the denoiser.

        Computed as ``(denoise(xs, σ) - xs) / σ**2``; ``kwargs`` are forwarded
        to ``denoise`` unchanged.
        """
        Ds = self.denoise(xs, σ, **kwargs)
        grad_log_p_t = (Ds - xs) / (σ ** 2)
        return grad_log_p_t

    @abstractmethod
    def data_shape(self):
        """Return the shape of one sample, without the batch dimension."""
        return (3, 256, 256)  # for example

    def samps_centered(self):
        """Return True if samples are expected in [-1, 1], False for [0, 1]."""
        # if centered, samples expected to be in range [-1, 1], else [0, 1]
        return True

    @property
    @abstractmethod
    def σ_max(self):
        """Largest noise level the model supports (samplers cap their start σ here)."""
        pass

    @property
    @abstractmethod
    def σ_min(self):
        """Smallest noise level the model supports (the schedule's last level)."""
        pass

    def cond_info(self, batch_size):
        """Return conditioning inputs for a batch of ``batch_size`` samples.

        The dict is passed as keyword arguments to ``score`` when
        ``unet_is_cond()`` is true, and classifier guidance reads its ``"y"``
        entry. The base implementation is unconditional and returns ``{}``.
        """
        return {}

    @abstractmethod
    def unet_is_cond(self):
        """Return True if ``score`` should receive the ``cond_info`` kwargs."""
        return False

    @abstractmethod
    def use_cls_guidance(self):
        """Return True if ``classifier_grad`` should be added to the score."""
        return False  # most models do not use cls guidance

    def classifier_grad(self, xs, σ, ys):
        """Return the classifier-guidance gradient for labels ``ys``.

        Presumably ∇ₓ log p(ys | xs; σ) — TODO confirm per subclass. Not
        implemented in the base class.

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError()

    @abstractmethod
    def snap_t_to_nearest_tick(self, t):
        """Map time/noise level ``t`` onto a value the model supports.

        Returns a pair whose first element is the (possibly snapped) ``t``;
        the second element is model-specific (``None`` in this base version).
        """
        # need to confirm for each model; continuous time model doesn't need this
        return t, None

    @property
    def device(self):
        """Device the model lives on; subclasses must set ``self._device``."""
        return self._device

    def checkpoint_root(self):
        """Load ``env.json`` next to this module and return its parsed contents.

        The file is meant to locate the pretrained checkpoints. The parsed JSON
        value is returned as-is, so it is not necessarily a bare path string.
        """
        env_file = Path(__file__).resolve().with_name("env.json")
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
        # encoding, which differs across platforms.
        with env_file.open("r", encoding="utf-8") as f:
            root = json.load(f)
        return root


def karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002):
    """Return ``N`` noise levels decreasing from ``σ_max`` to ``σ_min``.

    Levels are evenly spaced in ``σ ** (1 / ρ)`` space and mapped back with
    ``** ρ``; larger ``ρ`` packs more levels near ``σ_min``.

    Args:
        ρ: warping exponent (``ρ = 1`` gives linear spacing in σ).
        N: number of levels. ``N == 1`` yields only ``σ_max`` (up to float
            round-off); ``N <= 0`` yields ``[]``.
        σ_max: first (largest) level.
        σ_min: last (smallest) level.

    Returns:
        list of floats, first ≈ ``σ_max`` and last ≈ ``σ_min`` when ``N > 1``.
    """
    hi = σ_max ** (1 / ρ)
    lo = σ_min ** (1 / ρ)
    denom = N - 1
    # A single level has no interval to subdivide: use fraction 0 (i.e. σ_max)
    # instead of dividing by zero.
    return [
        (hi + ((i / denom) if denom else 0.0) * (lo - hi)) ** ρ
        for i in range(N)
    ]


def power_schedule(σ_max, σ_min, num_stages):
    """Return ``num_stages`` noise levels geometrically spaced from σ_max to σ_min.

    The levels are evenly spaced in log space (constant ratio between
    neighbours), include both endpoints, and come back as a NumPy array.
    """
    log_bounds = np.log(σ_max), np.log(σ_min)
    log_levels = np.linspace(*log_bounds, num_stages)
    return np.exp(log_levels)


class Karras:
    """Sampler in the style of Karras et al. (2022): Karras schedule + Heun steps.

    Integrates dx/dσ = -σ · ∇ₓ log p(x; σ) from the starting noise level down
    to 0, optionally injecting extra "churn" noise before each step.
    """

    @classmethod
    @torch.no_grad()
    def inference(
        cls, model, batch_size, num_t, *,
        σ_max=80, cls_scaling=1,
        init_xs=None, heun=True,
        langevin=False,
        S_churn=80, S_min=0.05, S_max=50, S_noise=1.003,
    ):
        """Generate samples, yielding the batch after every step (no autograd).

        Args:
            model: a ``ScoreAdapter``-like object.
            batch_size: number of samples drawn when ``init_xs`` is None.
            num_t: number of schedule levels, i.e. number of steps taken.
            σ_max: requested starting noise level, capped at ``model.σ_max``.
            cls_scaling: multiplier on ``model.classifier_grad`` when
                ``model.use_cls_guidance()`` is true.
            init_xs: optional starting batch; moved to ``model.device`` and
                used as-is (not rescaled by the start level).
            heun: apply Heun's 2nd-order correction (skipped on the step to 0).
            langevin: add churn noise on steps with ``S_min < t < S_max``.
            S_churn, S_min, S_max, S_noise: churn parameters; ``S_churn`` is
                divided evenly over the ``num_t`` steps.

        Yields:
            The initial batch, then the batch after each of the ``num_t`` steps.
        """
        σ_max = min(σ_max, model.σ_max)
        σ_min = model.σ_min
        schedule = karras_t_schedule(ρ=7, N=num_t, σ_max=σ_max, σ_min=σ_min)
        assert len(schedule) == num_t
        # Snap every level onto a time the model supports, then finish at 0.
        ts = [model.snap_t_to_nearest_tick(t)[0] for t in schedule] + [0]
        σ_max = ts[0]  # the snapped start level scales the initial noise

        cond_inputs = model.cond_info(batch_size)

        def ode_direction(x, sigma):
            # dx/dσ = -σ · score(x; σ), with optional classifier guidance.
            kwargs = cond_inputs if model.unet_is_cond() else {}
            grad_log_p = model.score(x, sigma, **kwargs)
            if model.use_cls_guidance():
                guidance = model.classifier_grad(x, sigma, cond_inputs["y"])
                grad_log_p += guidance * cls_scaling
            return -1 * sigma * grad_log_p

        if init_xs is None:
            xs = σ_max * torch.randn(
                batch_size, *model.data_shape(), device=model.device
            )
        else:
            xs = init_xs.to(model.device)

        yield xs

        for t_cur, t_next in zip(ts[:-1], ts[1:]):
            if langevin and S_min < t_cur < S_max:
                xs, t_cur = cls.noise_backward_in_time(
                    model, xs, t_cur, S_noise, S_churn / num_t
                )

            dt = t_next - t_cur
            slope = ode_direction(xs, t_cur)
            euler_xs = xs + dt * slope

            # Heun's 2nd-order correction; the final step to t = 0 stays Euler.
            if heun and t_next != 0:
                slope_next = ode_direction(euler_xs, t_next)
                xs = xs + dt * (slope + slope_next) / 2
            else:
                xs = euler_xs

            yield xs

    @staticmethod
    def noise_backward_in_time(model, xs, t_i, S_noise, S_churn_i):
        """Move ``xs`` from level ``t_i`` up to a snapped level ``t_i_hat``.

        ``t_i`` is scaled by ``1 + γ`` with ``γ = min(√2 - 1, S_churn_i)``,
        snapped with ``model.snap_t_to_nearest_tick``, and Gaussian noise of
        std ``S_noise * sqrt(t_i_hat² - t_i²)`` is added.

        Returns:
            ``(noisier_xs, t_i_hat)``.
        """
        eps = S_noise * torch.randn_like(xs)
        gamma = min(sqrt(2) - 1, S_churn_i)
        t_i_hat = model.snap_t_to_nearest_tick(t_i * (1 + gamma))[0]
        # Needs t_i_hat >= t_i, otherwise math.sqrt raises ValueError.
        return xs + eps * sqrt(t_i_hat ** 2 - t_i ** 2), t_i_hat


def test():
    """Module self-test hook; currently a no-op."""
    return None


if __name__ == "__main__":
    test()