File size: 3,234 Bytes
5c31d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import random
from os import environ

import numpy as np
import torch


# Module-level constant that is not referenced by any function in this part of
# the file; presumably the quality setting used when saving JPEG images
# elsewhere in the project (TODO confirm against callers).
JPEG_QUALITY = 100


def seed_everything(seed):
    """Seed the random number generators used by this module.

    Seeds Python's ``random`` module, NumPy and PyTorch with the same value,
    stores the seed in the ``PYTHONHASHSEED`` environment variable, and puts
    cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Seed value passed unchanged to each library; its ``str()`` form
            is written to ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    environ["PYTHONHASHSEED"] = str(seed)

    # Remaining libraries all take the seed as their single argument.
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def exists(x):
    """Return ``True`` when *x* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if x is None:
        return False
    return True


def get(x, default):
    """Return *x* unless it is ``None``, in which case return *default*.

    Args:
        x: Candidate value.
        default: Fallback returned only when ``exists(x)`` is false.
    """
    return x if exists(x) else default


def get_self_recurrence_schedule(schedule, num_inference_steps):
    """Expand fractional schedule entries into one repeat count per step.

    Every non-empty entry of *schedule* is unpacked as
    ``[start, end, repeat]``. ``start`` and ``end`` are multiplied by
    *num_inference_steps* and rounded to give a half-open step range, and
    each step in that range is assigned ``repeat``.

    Entries are applied from last to first, so when ranges overlap the entry
    that appears earlier in *schedule* wins.

    Args:
        schedule: Iterable of ``[start, end, repeat]`` triples; ``None`` or
            empty entries are skipped.
        num_inference_steps: Length of the returned list.

    Returns:
        A list of length *num_inference_steps*; steps not covered by any
        entry hold ``0``.
    """
    repeats_per_step = [0] * num_inference_steps

    # Walk backwards so earlier entries overwrite later ones.
    for entry in reversed(schedule):
        if entry is not None and len(entry) != 0:
            begin_frac, end_frac, count = entry
            first_step = round(num_inference_steps * begin_frac)
            stop_step = round(num_inference_steps * end_frac)
            for step in range(first_step, stop_step):
                repeats_per_step[step] = count

    return repeats_per_step


def batch_dict_to_tensor(batch_dict, batch_order):
    """Concatenate the tensors of *batch_dict* along dim 0.

    Args:
        batch_dict: Mapping from batch type to tensor.
        batch_order: Keys of *batch_dict* in the order they should be
            stacked; only these keys are read.

    Returns:
        A single tensor produced by ``torch.cat(..., dim=0)``.
    """
    ordered_parts = [batch_dict[key] for key in batch_order]
    return torch.cat(ordered_parts, dim=0)


def batch_tensor_to_dict(batch_tensor, batch_order):
    """Split *batch_tensor* into one chunk per entry of *batch_order*.

    The tensor is split with ``Tensor.chunk(len(batch_order))`` and the i-th
    chunk is stored under the i-th key.

    Args:
        batch_tensor: Tensor to split.
        batch_order: Keys for the resulting dict, in chunk order.

    Returns:
        A dict mapping each key of *batch_order* to its chunk.
    """
    pieces = batch_tensor.chunk(len(batch_order))
    # Index explicitly so a missing chunk raises instead of being dropped.
    return {key: pieces[pos] for pos, key in enumerate(batch_order)}


def noise_prev(scheduler, timestep, x_0, noise=None):
    """Noise *x_0* to the timestep that follows *timestep* in the schedule.

    Looks up *timestep* in ``scheduler.timesteps``, takes the entry right
    after it, and returns ``scheduler.add_noise(x_0, noise, <that entry>)``.
    When *timestep* is the last entry, *x_0* is returned unchanged.

    Args:
        scheduler: Object exposing ``num_inference_steps``, a ``timesteps``
            tensor and an ``add_noise(sample, noise, timesteps)`` method.
        timestep: Value expected to occur in ``scheduler.timesteps``.
        x_0: Sample tensor to be noised.
        noise: Optional noise tensor; drawn with ``torch.randn_like(x_0)``
            when omitted.

    Returns:
        The noised sample, or *x_0* itself at the end of the schedule.

    Raises:
        ValueError: If ``scheduler.num_inference_steps`` is ``None``.
    """
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    noise = torch.randn_like(x_0).to(x_0) if noise is None else noise

    # Adapted from the DDIMScheduler step function: find where `timestep`
    # sits in the schedule, then move one entry further along.
    matches = (scheduler.timesteps == timestep).nonzero(as_tuple=True)[0]
    next_i = matches[0].item() + 1
    if next_i >= scheduler.timesteps.shape[0]:
        # Already at the final entry (t = 0 or close to it).
        return x_0

    # Slice instead of index so the timestep keeps a dimension.
    following_timestep = scheduler.timesteps[next_i:next_i + 1]
    return scheduler.add_noise(x_0, noise, following_timestep)


def noise_t2t(scheduler, timestep, timestep_target, x_t, noise=None):
    """Add noise to move a sample from *timestep* to *timestep_target*.

    Computes ``sqrt(a) * x_t + sqrt(1 - a) * noise`` where
    ``a = alphas_cumprod[timestep_target] / alphas_cumprod[timestep]``.

    Args:
        scheduler: Object exposing an ``alphas_cumprod`` tensor that is
            indexed by timestep.
        timestep: Timestep(s) *x_t* currently sits at. May be a Python int,
            a 0-dim tensor, or a tensor with one entry per element along
            dim 0 of *x_t*. Values are cast to ``torch.long``.
        timestep_target: Timestep(s) to noise to, in the same forms as
            *timestep*; must be ``>= timestep`` element-wise.
        x_t: Sample tensor at *timestep*.
        noise: Optional noise tensor; drawn with ``torch.randn_like(x_t)``
            when omitted.

    Returns:
        The sample noised to *timestep_target*.

    Raises:
        ValueError: If any target timestep is smaller than its source
            timestep.
    """
    # Normalise to long tensors so ints and (batched) tensors share one path.
    timestep = torch.as_tensor(timestep, dtype=torch.long, device=x_t.device)
    timestep_target = torch.as_tensor(
        timestep_target, dtype=torch.long, device=x_t.device
    )

    # Explicit check instead of `assert`: it survives `python -O` and also
    # works for tensors holding more than one timestep.
    if bool((timestep_target < timestep).any()):
        raise ValueError(
            "timestep_target must be greater than or equal to timestep"
        )

    if noise is None:
        noise = torch.randn_like(x_t)

    alphas_cumprod = scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
    alpha_ratio = alphas_cumprod[timestep_target] / alphas_cumprod[timestep]

    # One coefficient per leading-dim entry, trailing singleton dims so it
    # broadcasts against x_t.
    broadcast_shape = (-1,) + (1,) * (x_t.dim() - 1)
    signal_scale = (alpha_ratio ** 0.5).reshape(broadcast_shape)
    noise_scale = ((1 - alpha_ratio) ** 0.5).reshape(broadcast_shape)

    return signal_scale * x_t + noise_scale * noise