import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.0001,
            beta_max=0.02,
            num_steps=1000,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps

        self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
        self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
        self.sigmas_table = 1-self.alphas_table


    def beta(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.betas_table[t].view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.alphas_table[t].view(-1, 1, 1, 1)**0.5

    def sigma(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) ->Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) ->Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")


class VPScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.1,
            beta_max=20,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min
    def beta(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t
        return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) ->Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) ->Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def alpha(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
       raise NotImplementedError("wrong usage")

    def w(self, t):
        return self.diffuse_coefficient(t)