"""
Alle transforms sind grundsätzlich auf batches bezogen!
Vae transforms sind invertierbar
"""
import pickle
from dataclasses import dataclass
from functools import partial, reduce, wraps

import numpy as np
import torch

# Allgemeine Funktionen -------------------------------------------------------------
# Transformations in Pytorch sind am einfachsten.


def load(p):
    """Unpickle and return the object stored in the file at path ``p``.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    that come from a trusted source.
    """
    with open(p, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj


def save(obj, p):
    """Pickle ``obj`` into the file at path ``p``, overwriting any existing content."""
    with open(p, mode="wb") as handle:
        pickle.dump(obj, handle)


def sequential_function(*functions):
    """Compose ``functions`` left to right into a single callable.

    The first function receives the input, each following function receives
    the previous result. With no functions, the returned callable is the
    identity.
    """

    def composed(x):
        result = x
        for fn in functions:
            result = fn(result)
        return result

    return composed


def np_sample(func):
    """Wrap a batch-wise torch transform so it can be applied to a single numpy sample.

    The returned callable:
      1. converts the numpy array ``x`` to a float32 tensor,
      2. adds a leading batch axis of size 1,
      3. applies ``func`` to that batch,
      4. returns the first (only) batch element as a numpy array.

    The result is detached and moved to the CPU before conversion. Without
    that, ``.numpy()`` raises when ``func`` produces a tensor that requires
    grad (e.g. a module with parameters) or that lives on a GPU.
    """

    def wrapped(x):
        batch = torch.unsqueeze(torch.from_numpy(x).float(), 0)
        out = func(batch)
        return out[0].detach().cpu().numpy()

    return wrapped


# Invertible transforms
class SequentialInversable(torch.nn.Sequential):
    """``torch.nn.Sequential`` whose children each provide an ``inv`` method.

    ``forward`` (inherited from ``Sequential``) applies the children in order.
    ``inv`` applies their ``inv`` methods in reverse order, undoing the forward
    pass. The inverse chain is derived from the *current* children on every
    call, so it stays correct after ``seq[i] = ...`` or ``append``.
    """

    def __init__(self, *functions):
        super().__init__(*functions)
        # Fail fast at construction, as before, if a child cannot be inverted.
        for module in self:
            if not callable(getattr(module, "inv", None)):
                raise AttributeError(
                    f"'{type(module).__name__}' object has no attribute 'inv'"
                )

    @property
    def inv_funcs(self):
        """Inverse methods of the children, in the order ``inv`` applies them."""
        return [module.inv for module in reversed(list(self))]

    def inv(self, x):
        """Invert ``forward``: apply each child's ``inv`` from last child to first."""
        for module in reversed(list(self)):
            x = module.inv(x)
        return x


class LatentSelector(torch.nn.Module):
    """Keep the first ``selectdim`` latent columns; ``inv`` zero-pads back to ``ldim``.

    Handles both ``torch.Tensor`` and ``numpy.ndarray`` batches. The column
    axis is axis 1.
    """

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim  # full number of latent columns restored by ``inv``
        self.selectdim = selectdim  # number of leading columns kept by ``forward``

    def forward(self, x):
        """Return the first ``selectdim`` columns of ``x``."""
        return x[:, : self.selectdim]

    def inv(self, x):
        """Append zero columns so that ``x`` has ``ldim`` columns again.

        The padding has the same dtype (and, for tensors, device) as ``x``, so
        the result is not silently cast to the default float dtype.
        """
        pad_shape = (x.shape[0], self.ldim - x.shape[1])
        if isinstance(x, np.ndarray):
            return np.concatenate([x, np.zeros(pad_shape, dtype=x.dtype)], axis=1)
        return torch.cat([x, x.new_zeros(pad_shape)], dim=1)


class MinMaxScaler(torch.nn.Module):
    """Affine rescaling from ``[_min, _max]`` to ``[min_norm, max_norm]``.

    ``_min``/``_max`` are broadcast against the input. With several signals,
    make sure their shapes line up with the input's signal axis.
    """

    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        # Tensor bounds become non-persistent buffers, so ``.to(device)`` /
        # ``.double()`` move them together with the module, while the
        # ``state_dict`` layout (and old checkpoints) stays unchanged.
        self._set_bound("_min", _min)
        self._set_bound("_max", _max)
        self.min_norm = min_norm
        self.max_norm = max_norm

    def _set_bound(self, name, value):
        """Store ``value`` as a non-persistent buffer if it is a tensor, else as a plain attribute."""
        if isinstance(value, torch.Tensor):
            self.register_buffer(name, value, persistent=False)
        else:
            setattr(self, name, value)

    def _span(self):
        """Return ``_max - _min``, with zero entries (constant signals) replaced by 1.

        Without this, a constant signal would divide by zero (NaN/inf). Using
        1 keeps ``forward`` and ``inv`` exact inverses of each other.
        """
        span = self._max - self._min
        if isinstance(span, torch.Tensor):
            span = torch.where(span == 0, torch.ones_like(span), span)
        return span

    def forward(self, ts):
        """Scale ``ts`` (shape (None, no_signals)) into ``[min_norm, max_norm]``."""
        std = (ts - self._min) / self._span()
        return std * (self.max_norm - self.min_norm) + self.min_norm

    def inv(self, ts):
        """Map values from ``[min_norm, max_norm]`` back to ``[_min, _max]``."""
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        return std * self._span() + self._min

    @classmethod
    def from_array(
        cls, arr: torch.Tensor, min_norm: float = 0.0, max_norm: float = 1.0
    ):
        """Build a scaler from the per-column min/max of ``arr`` (reduced over dim 0)."""
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values
        return cls(_min, _max, min_norm, max_norm)


class LatentSorter(torch.nn.Module):
    """Reorder latent columns according to the key order of ``kl_dict``.

    ``kl_dict`` maps a latent column index to a value. Given ``names``, the
    value is presumably the KL divergence of that dimension; confirm this
    with the caller.
    """

    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """
        unsorted -> sorted
        latent: (None, latent_dim)
        Output column j is input column ``keys[j]``.
        """
        order = [key for key in self.kl_dict]
        return latent[:, order]

    def inv(self, latent):
        """sorted -> unsorted, by indexing with the argsort of the keys (the inverse permutation)."""
        inverse_order = np.argsort(np.array(list(self.kl_dict)))
        return latent[:, torch.from_numpy(inverse_order)]

    @property
    def names(self):
        """Labels of the form ``"<key> KL<value:.2f>"``, in key order."""
        labels = []
        for key, value in self.kl_dict.items():
            labels.append("{} KL{:.2f}".format(key, value))
        return labels


def apply_along_axis(function, x, axis: int = 0):
    """Call ``function`` on each slice of ``x`` obtained by unbinding ``axis``.

    Each call receives ``x`` with ``axis`` removed. The results are stacked
    back along ``axis``, so they must all share one shape.
    """
    results = []
    for piece in x.unbind(dim=axis):
        results.append(function(piece))
    return torch.stack(results, dim=axis)


# Eingangsshapes bleiben wie sie sind!
class SumField(torch.nn.Module):
    """Pairwise-mean image transform for time series.

    time series: [idx, time_step, signal]
    image: [idx, signal, time_step, time_step]

    ``forward`` maps each 1-D series ``x`` to the matrix ``(x[i] + x[j]) / 2``.
    ``inv`` reads the series back off the diagonal, since ``(x[i] + x[i]) / 2 == x[i]``.
    """

    def forward(self, ts: torch.Tensor):
        """ts2img: [idx, time_step, signal] -> [idx, signal, time_step, time_step]."""
        # Move the time axis last: [idx, signal, time_step].
        series = torch.swapaxes(ts, 1, 2).contiguous()
        # A single broadcast replaces the former per-series Python loop. It
        # computes exactly what _mtf_forward computes for every (idx, signal)
        # pair, and it also works for an empty batch.
        rtn = (series.unsqueeze(-1) + series.unsqueeze(-2)) / 2
        return rtn.contiguous()

    def inv(self, img: torch.Tensor):
        """img2ts: [idx, signal, t, t] -> [idx, time_step, signal] via the diagonal."""
        rtn = torch.diagonal(img, dim1=2, dim2=3)  # [idx, signal, time_step]
        rtn = torch.swapaxes(rtn, 1, 2)  # swap signal and time axes
        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """Pairwise-mean matrix ``(ts[i] + ts[j]) / 2`` of a one-dimensional series ``ts``."""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2