DDT / src /diffusion /base /scheduling.py
wangshuai6
init space
9e426da
raw
history blame contribute delete
870 Bytes
import torch
from torch import Tensor
class BaseScheduler:
def alpha(self, t) -> Tensor:
...
def sigma(self, t) -> Tensor:
...
def dalpha(self, t) -> Tensor:
...
def dsigma(self, t) -> Tensor:
...
def dalpha_over_alpha(self, t) -> Tensor:
return self.dalpha(t) / self.alpha(t)
def dsigma_mul_sigma(self, t) -> Tensor:
return self.dsigma(t)*self.sigma(t)
def drift_coefficient(self, t):
alpha, sigma = self.alpha(t), self.sigma(t)
dalpha, dsigma = self.dalpha(t), self.dsigma(t)
return dalpha/(alpha + 1e-6)
def diffuse_coefficient(self, t):
alpha, sigma = self.alpha(t), self.sigma(t)
dalpha, dsigma = self.dalpha(t), self.dsigma(t)
return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2
def w(self, t):
return self.sigma(t)