Spaces:
Paused
Paused
import torch | |
def linear(n_timestep=1000, start=1e-4, end=2e-2):
    """Create a ``Schedule`` whose betas are evenly spaced in square-root space.

    The square roots of ``start`` and ``end`` are interpolated linearly over
    ``n_timestep`` points in float64, and the result is squared to obtain the
    betas.

    Args:
        n_timestep: number of diffusion steps (length of the beta table).
        start: first beta value.
        end: last beta value.

    Returns:
        A ``Schedule`` built from the resulting float64 beta tensor.
    """
    root_start = start ** 0.5
    root_end = end ** 0.5
    root_betas = torch.linspace(root_start, root_end, n_timestep, dtype=torch.float64)
    return Schedule(root_betas.pow(2))
class Schedule:
    """Precomputed noise-schedule tables derived from a sequence of betas.

    ``_alphas`` holds the per-step values ``1 - betas`` and ``alphas`` holds
    their cumulative product (the "alpha-bar" of DDPM, written alpha in the
    DDIM paper).  The remaining attributes are element-wise functions of
    ``alphas``.
    """

    def __init__(self, betas):
        """Store ``betas`` and derive the per-step and cumulative alpha tables.

        Args:
            betas: per-step noise variances; presumably a 1-D float tensor of
                length T (the cumulative product is taken along dim 0) —
                confirm against callers.
        """
        self.betas = betas
        # Per-step alpha_t = 1 - beta_t.
        self._alphas = 1 - betas
        # Cumulative alpha-bar_t = prod_{s <= t} (1 - beta_s).
        self.alphas = torch.cumprod(self._alphas, 0)
        self.one_minus_alphas = 1 - self.alphas
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_one_minus_alphas = torch.sqrt(self.one_minus_alphas)
        self.sqrt_noise_signal_ratio = self.sqrt_one_minus_alphas / self.sqrt_alphas
        self.noise_signal_ratio = self.one_minus_alphas / self.alphas

    def range(self, dt):
        """Return descending timesteps ``T-1, T-1-dt, ...`` stopping before 0.

        Timestep 0 is never yielded.  The last value may be smaller than
        ``dt``, in which case ``t - dt`` is negative (see ``sigma``).

        Args:
            dt: positive integer stride between consecutive timesteps.
        """
        return range(len(self.betas) - 1, 0, -dt)

    def sigma(self, t, dt):
        """Return the DDIM (eta = 1) noise std for stepping from ``t`` to ``t - dt``.

        Computes ``sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))`` using
        the cumulative ``alphas`` table.  For ``dt == 1`` this equals the
        square root of the DDPM posterior variance.  When ``t - dt < 0`` the
        previous level is treated as noise-free (``a_prev = 1``), so the
        result is 0 instead of silently wrapping to the end of the table via
        negative indexing.

        Args:
            t: current timestep index; presumably a Python int as produced by
                ``self.range`` — confirm for tensor callers.
            dt: stride to the previous timestep.

        Returns:
            A 0-d tensor with the same dtype as ``alphas``.
        """
        alpha_t = self.alphas[t]
        prev = t - dt
        if prev >= 0:
            alpha_prev = self.alphas[prev]
        else:
            # Stepping past the start of the schedule: the target is clean data.
            alpha_prev = self.alphas.new_ones(())
        return torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))