from pathlib import Path
import json
from math import sqrt
import numpy as np
import torch
from abc import ABCMeta, abstractmethod


class ScoreAdapter(metaclass=ABCMeta):
    """Abstract interface that exposes a denoiser as a score function.

    Concrete adapters provide the denoiser, the data shape, the supported
    noise range and a few capability flags; this base class derives the score
    from the denoiser output and offers small shared helpers.
    """

    @abstractmethod
    def denoise(self, xs, σ, **kwargs):
        """Return the denoiser output for `xs` at noise level `σ`.

        `score` uses the result as Ds in (Ds - xs) / σ ** 2 and forwards its
        own keyword arguments here unchanged.
        """

    def score(self, xs, σ, **kwargs):
        """Return the estimate of grad log p_t: (denoise(xs, σ) - xs) / σ ** 2."""
        denoised = self.denoise(xs, σ, **kwargs)
        return (denoised - xs) / (σ ** 2)

    @abstractmethod
    def data_shape(self):
        """Return the per-sample shape (without the batch dimension)."""
        return (3, 256, 256)  # for example

    def samps_centered(self):
        """Tell whether samples are centered.

        If centered, samples are expected to be in range [-1, 1], else [0, 1].
        """
        return True

    @property
    @abstractmethod
    def σ_max(self):
        """Largest noise level of the model (must be provided by subclasses)."""

    @property
    @abstractmethod
    def σ_min(self):
        """Smallest noise level of the model (must be provided by subclasses)."""

    def cond_info(self, batch_size):
        """Return conditioning inputs for a batch; an empty dict by default."""
        return {}

    @abstractmethod
    def unet_is_cond(self):
        """Tell whether the conditioning inputs should be passed to `score`."""
        return False

    @abstractmethod
    def use_cls_guidance(self):
        """Tell whether `classifier_grad` should be added to the score."""
        return False  # most models do not use cls guidance

    def classifier_grad(self, xs, σ, ys):
        """Classifier gradient hook; not implemented in the base class.

        Raises:
            NotImplementedError: always, unless a subclass overrides it.
        """
        raise NotImplementedError()

    @abstractmethod
    def snap_t_to_nearest_tick(self, t):
        """Return a tuple whose first element replaces `t` as the noise level.

        Needs to be confirmed for each model; a continuous-time model does
        not need snapping. The default returns `t` unchanged, paired with
        None.
        """
        return t, None

    @property
    def device(self):
        """Return `self._device`.

        This base class never assigns `_device`; a subclass has to set it
        before the property is read.
        """
        return self._device

    def checkpoint_root(self):
        """the path at which the pretrained checkpoints are stored

        The value is read from the `env.json` file located next to this
        source file; the parsed JSON content is returned as-is.
        """
        env_file = Path(__file__).resolve().with_name("env.json")
        with env_file.open("r") as f:
            return json.load(f)


def karras_t_schedule(ρ=7, N=10, σ_max=80, σ_min=0.002):
    """Return N noise levels interpolated from σ_max to σ_min.

    The interpolation is linear in σ ** (1 / ρ); each interpolated value is
    then raised back to the power ρ.

    Args:
        ρ: exponent that warps the spacing of the levels.
        N: number of noise levels to produce.
        σ_max: first noise level of the schedule.
        σ_min: last noise level of the schedule; only reached when N >= 2.

    Returns:
        A list of N floats. N <= 0 gives an empty list and N == 1 gives a
        single level equal to σ_max (up to floating-point round-off).
    """
    inv_ρ = 1 / ρ
    root_max = σ_max ** inv_ρ
    root_min = σ_min ** inv_ρ
    # With a single level the fraction i / (N - 1) would be 0 / 0. Clamping
    # the denominator keeps the fraction at 0, so the only level is σ_max.
    last_index = max(N - 1, 1)
    return [
        (root_max + (i / last_index) * (root_min - root_max)) ** ρ
        for i in range(N)
    ]


def power_schedule(σ_max, σ_min, num_stages):
    """Return `num_stages` noise levels evenly spaced in log-space.

    Args:
        σ_max: first noise level.
        σ_min: last noise level.
        num_stages: number of levels to produce.

    Returns:
        A NumPy array obtained by exponentiating a linear ramp from
        log(σ_max) to log(σ_min).
    """
    log_start, log_stop = np.log(σ_max), np.log(σ_min)
    log_levels = np.linspace(log_start, log_stop, num_stages)
    return np.exp(log_levels)


class Karras():
    """Sampler that walks a noise schedule from σ_max down to 0."""

    @classmethod
    @torch.no_grad()
    def inference(
        cls, model, batch_size, num_t, *,
        σ_max=80, cls_scaling=1,
        init_xs=None, heun=True,
        langevin=False,
        S_churn=80, S_min=0.05, S_max=50, S_noise=1.003,
    ):
        """Generate samples step by step along a `karras_t_schedule`.

        This is a generator: it first yields the starting `xs`, then yields
        the updated `xs` after each of the `num_t` steps.

        Args:
            model: object providing `σ_max`, `σ_min`, `device`, `score`,
                `cond_info`, `data_shape`, `unet_is_cond`,
                `use_cls_guidance`, `classifier_grad` and
                `snap_t_to_nearest_tick` (e.g. a `ScoreAdapter`).
            batch_size: passed to `model.cond_info` and used as the leading
                dimension of the initial noise.
            num_t: number of noise levels, i.e. number of steps taken.
            σ_max: upper bound on the starting noise level; clipped to
                `model.σ_max`.
            cls_scaling: multiplier applied to the classifier gradient.
            init_xs: optional starting tensor; it is moved to `model.device`
                and used as-is (no noise scaling is applied to it). When
                None, the start is Gaussian noise scaled by the first
                (snapped) noise level.
            heun: apply the second-order correction on every step except the
                one that lands on 0.
            langevin: when True, noise is injected before every step whose
                noise level lies strictly between `S_min` and `S_max`.
            S_churn: total churn budget; each step receives `S_churn / num_t`.
            S_min: lower bound (exclusive) of the noise-injection range.
            S_max: upper bound (exclusive) of the noise-injection range.
            S_noise: scale of the injected noise.

        Yields:
            The current samples `xs`.
        """
        σ_max = min(σ_max, model.σ_max)
        schedule = karras_t_schedule(
            ρ=7, N=num_t, σ_max=σ_max, σ_min=model.σ_min
        )
        assert len(schedule) == num_t
        # Snap every level to what the model supports; 0 is the destination.
        ts = [model.snap_t_to_nearest_tick(t)[0] for t in schedule] + [0]
        σ_max = ts[0]

        # Fetched before any noise is drawn, as in the original ordering.
        cond_inputs = model.cond_info(batch_size)

        def ode_direction(xs, σ):
            # Direction d = -σ * score, with optional classifier guidance.
            unet_kwargs = cond_inputs if model.unet_is_cond() else {}
            score = model.score(xs, σ, **unet_kwargs)
            if model.use_cls_guidance():
                guidance = model.classifier_grad(xs, σ, cond_inputs["y"])
                score += guidance * cls_scaling
            return -1 * σ * score

        if init_xs is None:
            xs = σ_max * torch.randn(
                batch_size, *model.data_shape(), device=model.device
            )
        else:
            xs = init_xs.to(model.device)

        yield xs

        for t_cur, t_next in zip(ts[:-1], ts[1:]):
            if langevin and (S_min < t_cur < S_max):
                # Raise the noise level before stepping down from it.
                xs, t_cur = cls.noise_backward_in_time(
                    model, xs, t_cur, S_noise, S_churn / num_t
                )

            Δt = t_next - t_cur

            d_cur = ode_direction(xs, t_cur)
            xs_euler = xs + Δt * d_cur

            # Heun's 2nd order method; don't apply on the last step
            if heun and t_next != 0:
                d_next = ode_direction(xs_euler, t_next)
                xs = xs + Δt * (d_cur + d_next) / 2
            else:
                xs = xs_euler

            yield xs

    @staticmethod
    def noise_backward_in_time(model, xs, t_i, S_noise, S_churn_i):
        """Add noise to `xs` so that its noise level rises from `t_i`.

        The target level is `t_i * (1 + γ)` with
        `γ = min(sqrt(2) - 1, S_churn_i)`, snapped through
        `model.snap_t_to_nearest_tick`.

        Returns:
            A tuple `(noised_xs, new_level)`.
        """
        # Draw the noise first so the random stream is consumed before the
        # model is consulted for snapping.
        noise = S_noise * torch.randn_like(xs)
        γ = min(sqrt(2)-1, S_churn_i)
        t_hat = model.snap_t_to_nearest_tick(t_i * (1 + γ))[0]
        # Standard deviation needed to go from level t_i to level t_hat.
        extra_std = sqrt(t_hat ** 2 - t_i ** 2)
        return xs + noise * extra_std, t_hat


def test():
    """Placeholder self-test; it currently does nothing."""
    return None


# Script entry point: run the self-test only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test()