import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    """Warp a time value (or tensor of time values) with a shift factor.

    Evaluates ``t / (t + (1 - t) * shift)``. For a non-zero ``shift`` the
    endpoints ``t = 0`` and ``t = 1`` map to themselves.

    Args:
        t: Time value(s); anything supporting ``+``, ``-``, ``*`` and ``/``.
        shift: Multiplier applied to the ``(1 - t)`` term of the denominator.

    Returns:
        The warped time value(s), same kind of object as ``t``.
    """
    remaining = 1 - t
    denominator = t + remaining * shift
    return t / denominator

def ode_step_fn(x, v, dt, s, w):
    """Advance ``x`` by one explicit Euler step along velocity ``v``.

    Args:
        x: Current sample.
        v: Velocity evaluated at the current sample.
        dt: Step size.
        s: Unused by this step function; accepted so the call signature
            matches what the sampler passes to every step function.
        w: Unused by this step function (same reason as ``s``).

    Returns:
        ``x + v * dt``.
    """
    displacement = v * dt
    return x + displacement


import logging
# Module-level logger named after this module; used by EulerSampler to warn
# about questionable configurations.
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    """Euler sampler with classifier-free guidance and cached-state reuse.

    The sampler integrates from ``t = 0`` to ``t = 1`` over ``num_steps``
    steps. At every step the network is called as
    ``net(x, t, condition, state)`` and returns ``(out, state)``; the returned
    ``state`` is fed back into the next call unless the step index is listed
    in ``recompute_timesteps``, in which case ``None`` is passed so the
    network recomputes it. Which steps recompute is chosen by
    :meth:`sharing_dp` when ``state_refresh_rate`` asks for fewer recomputes
    than there are steps.

    ``num_steps``, ``scheduler``, ``guidance`` and ``guidance_fn`` are read
    from ``self`` but not set here; they are presumably provided by
    ``BaseSampler.__init__`` -- verify against the base class.
    """

    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            guidance_interval_min: float = 0.0,
            guidance_interval_max: float = 1.0,
            state_refresh_rate=1,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        """Configure the sampler and precompute its time grid.

        Args:
            w_scheduler: Optional scheduler; required (non-``None``) whenever
                ``step_fn`` is not ``ode_step_fn``.
            timeshift: ``shift`` argument handed to ``shift_respace_fn`` when
                warping the time grid.
            guidance_interval_min: Guidance is applied only for times strictly
                greater than this value.
            guidance_interval_max: Guidance is applied only for times strictly
                smaller than this value.
            state_refresh_rate: ``num_steps / state_refresh_rate`` (truncated)
                is the number of steps at which the network state is
                recomputed.
            step_fn: Update rule for every step except the last one.
            last_step: Size of the final step on the unwarped grid; defaults
                to ``1 / num_steps`` (also forced when ``num_steps == 1``).
            last_step_fn: Update rule for the final step.
            *args: Forwarded to ``BaseSampler.__init__``.
            **kwargs: Forwarded to ``BaseSampler.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        # num_steps points on [0, 1 - last_step], then the end point 1.0, so
        # the final interval has length last_step before warping.
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None and self.step_fn == ode_step_fn:
            logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

        # Until a DP schedule is solved, recompute the state at every step.
        self.recompute_timesteps = list(range(self.num_steps))

    def sharing_dp(self, net, noise, condition, uncondition):
        """Choose the steps at which the network state is recomputed.

        A fixed, seeded template batch is sampled with the state recomputed at
        every step. The per-step states are compared by cosine similarity
        (averaged over batch and token positions) and ``1 - similarity`` is
        used as the error of reusing step ``j``'s state at step ``i``. A
        dynamic program then picks ``self.num_recompute_timesteps`` recompute
        steps, always including step 0, that minimise the accumulated error.

        Args:
            net: Network, called as in :meth:`_impl_sampling`.
            noise: Only its shape (channels, height, width) and device are
                used to build the template noise.
            condition: Only its device is used.
            uncondition: Unused here; template labels are generated instead.

        Returns:
            List of step indices at which the state must be recomputed.
        """
        _, channels, height, width = noise.shape
        template_batch = 8
        # Build each generator on the device of the tensor it feeds; a
        # hard-coded "cuda" generator fails for CPU tensors and for tensors on
        # a CUDA device other than the current one.
        template_noise = torch.randn(
            (template_batch, channels, height, width),
            generator=torch.Generator(device=noise.device).manual_seed(0),
            device=noise.device,
        )
        # NOTE(review): labels in [0, 1000) with 1000 as the unconditional
        # label look like ImageNet-1k class ids plus a null class -- confirm.
        template_condition = torch.randint(
            0, 1000, (template_batch,),
            generator=torch.Generator(device=condition.device).manual_seed(0),
            device=condition.device,
        )
        template_uncondition = torch.full((template_batch, ), 1000, device=condition.device)

        # The template trajectory must recompute the state at every step,
        # otherwise reused states would look artificially identical. Restore
        # the caller's schedule afterwards.
        previous_schedule = self.recompute_timesteps
        self.recompute_timesteps = list(range(self.num_steps))
        try:
            _, state_list = self._impl_sampling(
                net, template_noise, template_condition, template_uncondition
            )
        finally:
            self.recompute_timesteps = previous_schedule

        # Each state is unpacked as (batch, tokens, dim); after the reshape
        # and permute the layout is (batch * tokens, steps, dim).
        states = torch.stack(state_list)
        num_states, batch, tokens, dim = states.shape
        states = states.view(num_states, batch * tokens, dim)
        states = states.permute(1, 0, 2)
        states = torch.nn.functional.normalize(states, dim=-1)
        # Cast explicitly: float64 is not a dtype CUDA autocast accepts, so
        # relying on autocast may leave the product in the native precision.
        states = states.to(torch.float64)
        sim = torch.bmm(states, states.transpose(1, 2))
        sim = torch.mean(sim, dim=0).cpu()
        error_map = (1 - sim).tolist()

        # Accumulate down each column: error_map[i][j] becomes the total error
        # of reusing step j's state for every step from j through i.
        for i in range(1, self.num_steps):
            for j in range(0, i):
                error_map[i][j] = error_map[i-1][j] + error_map[i][j]

        # cost[k][i]: minimal error of covering the first i steps with k
        # recomputes; parent[k][i]: index of the last recompute in that
        # solution. The first recompute is forced to be step 0.
        num_rows = self.num_recompute_timesteps + 1
        cost = [[0.0] * (self.num_steps + 1) for _ in range(num_rows)]
        parent = [[-1] * (self.num_steps + 1) for _ in range(num_rows)]
        for i in range(1, self.num_steps + 1):
            cost[1][i] = error_map[i - 1][0]
            parent[1][i] = 0

        for step in range(2, self.num_recompute_timesteps + 1):
            for i in range(step, self.num_steps + 1):
                best_cost = float("inf")
                best_prev = -1
                for j in range(step - 1, i):
                    value = cost[step - 1][j] + error_map[i - 1][j]
                    if value < best_cost:
                        best_cost = value
                        best_prev = j
                cost[step][i] = best_cost
                parent[step][i] = best_prev

        # Walk the parent pointers back from the full trajectory.
        timesteps = [self.num_steps, ]
        for k in range(self.num_recompute_timesteps, 0, -1):
            timesteps.append(parent[k][timesteps[-1]])
        timesteps.reverse()

        print("recompute timesteps solved by DP: ", timesteps)
        # Drop the trailing num_steps sentinel; keep only the recompute steps.
        return timesteps[:-1][:self.num_recompute_timesteps]

    def _impl_sampling(self, net, noise, condition, uncondition):
        """Run the Euler sampling loop.

        Args:
            net: Callable ``net(x, t, condition, state) -> (out, state)``.
            noise: Initial sample; its first dimension is the batch size.
            condition: Conditional inputs.
            uncondition: Unconditional inputs, concatenated in front of
                ``condition`` for the guided forward pass.

        Returns:
            Tuple ``(x, states)``: the final sample and the list of states
            returned by ``net``, one per step.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        # Unconditional half first, conditional half second.
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        pooled_state_list = []
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i in self.recompute_timesteps:
                # Passing None makes the network recompute its state.
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            # Guidance is applied only strictly inside the configured time
            # interval; outside it guidance_fn is called with scale 1.0.
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=0.0, w=0.0)
            else:
                x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
            pooled_state_list.append(state)
        return x, pooled_state_list

    def __call__(self, net, noise, condition, uncondition):
        """Sample from ``noise``, solving the recompute schedule if needed.

        The schedule is re-solved with :meth:`sharing_dp` whenever its length
        differs from the number of recomputes implied by
        ``state_refresh_rate``.

        Returns:
            The final sample produced by :meth:`_impl_sampling`.
        """
        # At least one recompute is required: the DP table indexes row 1
        # unconditionally, so a count of 0 would raise IndexError.
        self.num_recompute_timesteps = max(1, int(self.num_steps / self.state_refresh_rate))
        if len(self.recompute_timesteps) != self.num_recompute_timesteps:
            self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
        denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
        return denoised