import random

import numpy as np
import torch

from videosys.utils.logging import logger

PAB_MANAGER = None
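# PAB_MANAGER holds the process-global manager, created by set_pab_manager(); the
# module-level helpers below are thin wrappers around it. Illustrative usage sketch
# (the actual call sites live in the model code):
#   set_pab_manager(PABConfig(steps=..., ...))
#   if enable_pab():
#       broadcast, count = if_broadcast_spatial(timestep, count, block_idx)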


class PABConfig:
    def __init__(
        self,
        steps: int,
        cross_broadcast: bool,
        cross_threshold: list,
        cross_gap: int,
        spatial_broadcast: bool,
        spatial_threshold: list,
        spatial_gap: int,
        temporal_broadcast: bool,
        temporal_threshold: list,
        temporal_gap: int,
        diffusion_skip: bool,
        diffusion_timestep_respacing: list,
        diffusion_skip_timestep: list,
        mlp_skip: bool,
        mlp_spatial_skip_config: dict,
        mlp_temporal_skip_config: dict,
        full_broadcast: bool = False,
        full_threshold: list = None,
        full_gap: int = 1,
    ):
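        """Configuration for Pyramid Attention Broadcast (PAB).

        For each attention type (spatial, temporal, cross, full), the `*_broadcast`
        flag enables broadcasting, the `*_threshold` pair bounds the timestep range
        in which it may happen, and `*_gap` sets how often the output is recomputed
        (it is reused on the steps in between). The `mlp_*` and `diffusion_*`
        options configure the optional MLP-skip and timestep-skip mechanisms.
        """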
        self.steps = steps

        self.cross_broadcast = cross_broadcast
        self.cross_threshold = cross_threshold
        self.cross_gap = cross_gap

        self.spatial_broadcast = spatial_broadcast
        self.spatial_threshold = spatial_threshold
        self.spatial_gap = spatial_gap

        self.temporal_broadcast = temporal_broadcast
        self.temporal_threshold = temporal_threshold
        self.temporal_gap = temporal_gap

        self.diffusion_skip = diffusion_skip
        self.diffusion_timestep_respacing = diffusion_timestep_respacing
        self.diffusion_skip_timestep = diffusion_skip_timestep

        self.mlp_skip = mlp_skip
        self.mlp_spatial_skip_config = mlp_spatial_skip_config
        self.mlp_temporal_skip_config = mlp_temporal_skip_config

        self.temporal_mlp_outputs = {}
        self.spatial_mlp_outputs = {}
        
        self.full_broadcast = full_broadcast
        self.full_threshold = full_threshold
        self.full_gap = full_gap


class PABManager:
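    """Runtime manager that decides, per timestep, whether attention or MLP
    outputs can be broadcast (reused) instead of recomputed."""
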
    def __init__(self, config: PABConfig):
        self.config: PABConfig = config

        init_prompt = f"Init PABManager. steps: {config.steps}."
        init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
        init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
        init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
        init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
        logger.info(init_prompt)

    def if_broadcast_cross(self, timestep: int, count: int):
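        """Decide whether the cross-attention output can be broadcast (reused) at
        this step. Returns (flag, updated_count); `count` advances modulo `steps`."""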
        flag = bool(
            self.config.cross_broadcast
            and (timestep is not None)
            and (count % self.config.cross_gap != 0)
            and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
        )
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_temporal(self, timestep: int, count: int):
        flag = bool(
            self.config.temporal_broadcast
            and (timestep is not None)
            and (count % self.config.temporal_gap != 0)
            and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
        )
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
        # block_idx is accepted for API symmetry but does not affect the decision.
        flag = bool(
            self.config.spatial_broadcast
            and (timestep is not None)
            and (count % self.config.spatial_gap != 0)
            and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
        )
        count = (count + 1) % self.config.steps
        return flag, count

    def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
        flag = bool(
            self.config.full_broadcast
            and (timestep is not None)
            and (count % self.config.full_gap != 0)
            and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
        )
        count = (count + 1) % self.config.steps
        return flag, count

    @staticmethod
    def _is_t_in_skip_config(all_timesteps, timestep, config):
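        """Check whether `timestep` falls inside a skip range declared in `config`.

        `config` maps a key timestep to {"block": [...], "skip_count": n}; the range
        spans from that key through the next `n` entries of `all_timesteps`. Returns
        (True, [range_start, range_end]) on a match, otherwise (False, None).
        """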
        is_t_in_skip_config = False
        skip_range = None
        for key in config:
            if key not in all_timesteps:
                continue
            index = all_timesteps.index(key)
            # Candidate range: the key timestep plus the next `skip_count` timesteps.
            candidate_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
            if timestep in candidate_range:
                is_t_in_skip_config = True
                # Report only the endpoints of the matched range.
                skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
                break
        return is_t_in_skip_config, skip_range

    def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
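        """Decide whether this block's MLP output should be skipped at `timestep`.

        Returns (flag, count, next_flag, skip_range): `next_flag` signals that the
        freshly computed output should be cached via save_skip_output, while `flag`
        signals that a previously cached output should be reused via get_mlp_output.
        """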
        if not self.config.mlp_skip:
            return False, count, False, None

        if is_temporal:
            cur_config = self.config.mlp_temporal_skip_config
        else:
            cur_config = self.config.mlp_spatial_skip_config

        is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
        next_flag = False
        if (
            (timestep is not None)
            and (timestep in cur_config)
            and (block_idx in cur_config[timestep]["block"])
        ):
            # Key timestep of a skip range: compute the MLP normally and cache its output.
            flag = False
            next_flag = True
            count = count + 1
        elif (
            (timestep is not None)
            and is_t_in_skip_config
            and (block_idx in cur_config[skip_range[0]]["block"])
        ):
            # Inside a skip range: reuse the cached output instead of recomputing.
            flag = True
            count = 0
        else:
            flag = False

        return flag, count, next_flag, skip_range

    def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
        if is_temporal:
            self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
        else:
            self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output

    def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
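        """Fetch the MLP output cached at the start of `skip_range`; the cache entry
        is freed once the last timestep of the range has been served."""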
        skip_start_t = skip_range[0]
        outputs = self.config.temporal_mlp_outputs if is_temporal else self.config.spatial_mlp_outputs
        skip_output = outputs.get((skip_start_t, block_idx), None)

        if skip_output is None:
            raise ValueError(
                f"No stored MLP output found | t {timestep} | [{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
            )

        # Free the cached output once the last timestep in the skip range has consumed it.
        if timestep == skip_range[-1]:
            del outputs[(skip_start_t, block_idx)]

        return skip_output

    def get_spatial_mlp_outputs(self):
        return self.config.spatial_mlp_outputs

    def get_temporal_mlp_outputs(self):
        return self.config.temporal_mlp_outputs


def set_pab_manager(config: PABConfig):
    global PAB_MANAGER
    PAB_MANAGER = PABManager(config)


def enable_pab():
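    """Return True if a PAB manager has been set and at least one broadcast type is enabled."""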
    if PAB_MANAGER is None:
        return False
    return (
        PAB_MANAGER.config.cross_broadcast
        or PAB_MANAGER.config.spatial_broadcast
        or PAB_MANAGER.config.temporal_broadcast
        or PAB_MANAGER.config.full_broadcast
    )


def update_steps(steps: int):
    if PAB_MANAGER is not None:
        PAB_MANAGER.config.steps = steps


def if_broadcast_cross(timestep: int, count: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_cross(timestep, count)


def if_broadcast_temporal(timestep: int, count: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_temporal(timestep, count)


def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)


def if_broadcast_full(timestep: int, count: int, block_idx: int):
    if not enable_pab():
        return False, count
    return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)


def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    if not enable_pab():
        return False, count, False, None
    return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)


def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
    return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)


def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
    return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)


def get_diffusion_skip():
    return enable_pab() and PAB_MANAGER.config.diffusion_skip


def get_diffusion_timestep_respacing():
    return PAB_MANAGER.config.diffusion_timestep_respacing


def get_diffusion_skip_timestep():
    return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep


def space_timesteps(time_steps, time_bins):
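    """Split `time_steps` into len(time_bins) equal bins and take `time_bins[i]`
    evenly spaced timesteps from bin i, returned as a descending int32 tensor.

    For example, space_timesteps(100, [2, 2]) keeps {0, 25, 50, 75} and returns
    tensor([75, 50, 25, 0]).
    """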
    num_bins = len(time_bins)
    bin_size = time_steps // num_bins

    result = []

    for i, bin_count in enumerate(time_bins):
        start = i * bin_size
        end = start + bin_size

        bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
        result.extend(bin_steps)

    result_tensor = torch.tensor(result, dtype=torch.int32)
    sorted_tensor = torch.sort(result_tensor, descending=True).values

    return sorted_tensor


def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
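    """Drop timesteps bin-by-bin according to `diffusion_skip_timestep`.

    The timestep sequence is split into len(diffusion_skip_timestep) equal bins.
    A bin marked 0 is kept intact, a bin marked 1 drops its first timestep, and
    any other marker n randomly drops n timesteps from the bin. Accepts either a
    tensor of timesteps or a list of tensors and returns the same kind.
    """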
    if isinstance(timesteps, list):
        # If timesteps is a list, we assume each element is a tensor
        timesteps_np = [t.cpu().numpy() for t in timesteps]
        device = timesteps[0].device
    else:
        # If timesteps is a tensor
        timesteps_np = timesteps.cpu().numpy()
        device = timesteps.device

    num_bins = len(diffusion_skip_timestep)

    if isinstance(timesteps_np, list):
        bin_size = len(timesteps_np) // num_bins
        new_timesteps = []

        for i in range(num_bins):
            bin_start = i * bin_size
            bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
            bin_timesteps = timesteps_np[bin_start:bin_end]

            if diffusion_skip_timestep[i] == 0:
                # A bin marked 0 keeps all of its timesteps
                new_timesteps.extend(bin_timesteps)
            elif diffusion_skip_timestep[i] == 1:
                # A bin marked 1 drops the first timestep of the bin
                new_timesteps.extend(bin_timesteps[1:])
            else:
                # Any other marker n randomly drops n timesteps, mirroring the tensor branch below
                n_skip = diffusion_skip_timestep[i]
                if len(bin_timesteps) > n_skip:
                    indices_to_remove = set(random.sample(range(len(bin_timesteps)), n_skip))
                    new_timesteps.extend(t for idx, t in enumerate(bin_timesteps) if idx not in indices_to_remove)
                else:
                    # If the bin has no more than n timesteps, keep them all
                    new_timesteps.extend(bin_timesteps)

        new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
    else:
        bin_size = len(timesteps_np) // num_bins
        new_timesteps = []

        for i in range(num_bins):
            bin_start = i * bin_size
            bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
            bin_timesteps = timesteps_np[bin_start:bin_end]

            if diffusion_skip_timestep[i] == 0:
                # A bin marked 0 keeps all of its timesteps
                new_timesteps.extend(bin_timesteps)
            elif diffusion_skip_timestep[i] == 1:
                # A bin marked 1 drops the first timestep of the bin
                new_timesteps.extend(bin_timesteps[1:])
            else:
                # Any other marker n randomly drops n timesteps from the bin
                n_skip = diffusion_skip_timestep[i]
                if len(bin_timesteps) > n_skip:
                    indices_to_remove = set(random.sample(range(len(bin_timesteps)), n_skip))
                    timesteps_to_keep = [
                        timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
                    ]
                else:
                    # If the bin has no more than n timesteps, keep them all
                    timesteps_to_keep = bin_timesteps
                new_timesteps.extend(timesteps_to_keep)

        new_timesteps_tensor = torch.tensor(new_timesteps, device=device)

    return new_timesteps_tensor