Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import torch | |
from src.diffusion.base.scheduling import * | |
class DDPMScheduler(BaseScheduler):
    """Discrete-time DDPM noise schedule backed by precomputed lookup tables.

    A linear beta schedule with ``num_steps`` entries is built once in the
    constructor. ``beta``, ``alpha`` and ``sigma`` look values up by integer
    timestep and return them reshaped to ``(-1, 1, 1, 1)``.

    The remaining scheduler hooks (derivatives, drift/diffusion coefficients
    and ``w``) are not provided by this class and raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        beta_min=0.0001,
        beta_max=0.02,
        num_steps=1000,
        device="cuda",
    ):
        """Build the schedule tables.

        Args:
            beta_min: First value of the linear beta schedule.
            beta_max: Last value of the linear beta schedule.
            num_steps: Number of discrete timesteps (length of each table).
            device: Device the tables are allocated on. Defaults to
                ``"cuda"``, the value that used to be hard-coded; pass
                ``"cpu"`` (or any ``torch.device``) to run without a GPU.
        """
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps
        self.betas_table = torch.linspace(
            self.beta_min, self.beta_max, self.num_steps, device=device
        )
        # Running product of (1 - beta). alpha() returns the square root of
        # these entries and sigma() the square root of their complement.
        self.alphas_table = torch.cumprod(1 - self.betas_table, dim=0)
        self.sigmas_table = 1 - self.alphas_table

    @staticmethod
    def _lookup(table, t):
        """Return ``table[t]`` reshaped to ``(-1, 1, 1, 1)``.

        ``t`` is cast to ``torch.long`` and moved to the table's device so
        that indexing also works when the tables live on a different device
        than the timestep tensor. The result is on the table's device.
        """
        index = t.to(device=table.device, dtype=torch.long)
        return table[index].view(-1, 1, 1, 1)

    def beta(self, t) -> Tensor:
        """Return the beta table entries at timesteps ``t``."""
        return self._lookup(self.betas_table, t)

    def alpha(self, t) -> Tensor:
        """Return the square root of the cumulative (1 - beta) product at ``t``."""
        return self._lookup(self.alphas_table, t) ** 0.5

    def sigma(self, t) -> Tensor:
        """Return the square root of one minus the cumulative product at ``t``."""
        return self._lookup(self.sigmas_table, t) ** 0.5

    def dsigma(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def w(self, t):
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")
class VPScheduler(BaseScheduler):
    """Continuous-time schedule with a beta that is linear in ``t``.

    ``beta(t) = beta_min + (beta_max - beta_min) * t``. Every method that
    takes ``t`` first clamps it to ``[1e-3, 1]``. ``alpha`` and ``sigma`` are
    derived from the closed-form integral of ``beta`` and satisfy
    ``alpha**2 + sigma**2 == 1``.

    The derivative and drift/diffusion hooks are not provided by this class
    and raise ``NotImplementedError``.
    """

    def __init__(
        self,
        beta_min=0.1,
        beta_max=20,
    ):
        """Store the schedule endpoints.

        Args:
            beta_min: Value of ``beta`` at ``t = 0``.
            beta_max: Value of ``beta`` at ``t = 1``.
        """
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.beta_d = beta_max - beta_min

    def _integrated_beta(self, t) -> Tensor:
        """Return the integral of ``beta`` from 0 to ``t`` (``t`` clamped)."""
        t = torch.clamp(t, min=1e-3, max=1)
        return 0.5 * self.beta_d * t ** 2 + self.beta_min * t

    def beta(self, t) -> Tensor:
        """Return ``beta(t)`` reshaped to ``(-1, 1, 1, 1)``."""
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        """Return ``sqrt(1 - exp(-integral of beta))`` as ``(-1, 1, 1, 1)``."""
        inter_beta = self._integrated_beta(t)
        # -expm1(-x) equals 1 - exp(-x) but avoids the cancellation that
        # costs precision when x is small (t is clamped as low as 1e-3).
        return (-torch.expm1(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def alpha(self, t) -> Tensor:
        """Return ``exp(-0.5 * integral of beta)`` as ``(-1, 1, 1, 1)``."""
        inter_beta = self._integrated_beta(t)
        return torch.exp(-0.5 * inter_beta).view(-1, 1, 1, 1)

    def drift_coefficient(self, t):
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        """Not provided by this scheduler; always raises."""
        raise NotImplementedError("wrong usage")

    def w(self, t):
        """Return ``diffuse_coefficient(t)``.

        ``diffuse_coefficient`` raises ``NotImplementedError`` in this class,
        so this call raises as well unless a subclass overrides it.
        """
        return self.diffuse_coefficient(t)