import typing

import julius
import numpy as np
import torch

from . import util


class DSPMixin:
    """Mixin with windowing, filtering, STFT-masking and pre-emphasis helpers.

    The methods operate on attributes supplied by the host class, such as
    ``audio_data``, ``sample_rate``, ``signal_length`` and ``batch_size``.
    """

    # Windowing state written by ``_preprocess_signal_for_windowing`` and read
    # back by ``overlap_and_add`` to restore the pre-windowing layout.
    _original_batch_size = None
    _original_num_channels = None
    _padded_signal_length = None

    def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
        """Pad the signal for windowing and convert durations to sample counts.

        Records the current batch size and channel count, truncates both
        durations to whole samples, shrinks the window to the largest multiple
        of the hop that fits in it, zero-pads the signal by one hop on each
        side and records the padded length.

        Parameters
        ----------
        window_duration : float
            Duration of every window in seconds.
        hop_duration : float
            Hop between windows in seconds.

        Returns
        -------
        tuple
            ``(window_length, hop_length)`` in samples.
        """
        # Remember the layout before any padding / reshaping takes place.
        self._original_batch_size = self.batch_size
        self._original_num_channels = self.num_channels

        win = int(window_duration * self.sample_rate)
        hop = int(hop_duration * self.sample_rate)

        # Drop the leftover samples so the window is a whole number of hops.
        leftover = win % hop
        if leftover:
            win -= leftover

        self.zero_pad(hop, hop)
        self._padded_signal_length = self.signal_length

        return win, hop

    def windows(
        self, window_duration: float, hop_duration: float, preprocess: bool = True
    ):
        """Generator which yields windows of specified duration from signal with a specified
        hop length.

        Parameters
        ----------
        window_duration : float
            Duration of every window in seconds.
        hop_duration : float
            Hop between windows in seconds.
        preprocess : bool, optional
            Whether to preprocess the signal, so that the first sample is in
            the middle of the first window, by default True

        Yields
        ------
        AudioSignal
            Each window is returned as an AudioSignal.
        """
        if preprocess:
            window_length, hop_length = self._preprocess_signal_for_windowing(
                window_duration, hop_duration
            )

        self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)

        for b in range(self.batch_size):
            i = 0
            start_idx = i * hop_length
            while True:
                start_idx = i * hop_length
                i += 1
                end_idx = start_idx + window_length
                if end_idx > self.signal_length:
                    break
                yield self[b, ..., start_idx:end_idx]

    def collect_windows(
        self, window_duration: float, hop_duration: float, preprocess: bool = True
    ):
        """Reshapes signal into windows of specified duration from signal with a specified
        hop length. Window are placed along the batch dimension. Use with
        :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
        original signal.

        Parameters
        ----------
        window_duration : float
            Duration of every window in seconds.
        hop_duration : float
            Hop between windows in seconds.
        preprocess : bool, optional
            Whether to preprocess the signal, so that the first sample is in
            the middle of the first window, by default True

        Returns
        -------
        AudioSignal
            AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
        """
        if preprocess:
            window_length, hop_length = self._preprocess_signal_for_windowing(
                window_duration, hop_duration
            )

        # self.audio_data: (nb, nch, nt).
        unfolded = torch.nn.functional.unfold(
            self.audio_data.reshape(-1, 1, 1, self.signal_length),
            kernel_size=(1, window_length),
            stride=(1, hop_length),
        )
        # unfolded: (nb * nch, window_length, num_windows).
        # -> (nb * nch * num_windows, 1, window_length)
        unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
        self.audio_data = unfolded
        return self

    def overlap_and_add(self, hop_duration: float):
        """Function which takes a list of windows and overlap adds them into a
        signal the same length as ``audio_signal``.

        Parameters
        ----------
        hop_duration : float
            How much to shift for each window
            (overlap is window_duration - hop_duration) in seconds.

        Returns
        -------
        AudioSignal
            overlap-and-added signal.
        """
        hop_length = int(hop_duration * self.sample_rate)
        window_length = self.signal_length

        nb, nch = self._original_batch_size, self._original_num_channels

        unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
        folded = torch.nn.functional.fold(
            unfolded,
            output_size=(1, self._padded_signal_length),
            kernel_size=(1, window_length),
            stride=(1, hop_length),
        )

        norm = torch.ones_like(unfolded, device=unfolded.device)
        norm = torch.nn.functional.fold(
            norm,
            output_size=(1, self._padded_signal_length),
            kernel_size=(1, window_length),
            stride=(1, hop_length),
        )

        folded = folded / norm

        folded = folded.reshape(nb, nch, -1)
        self.audio_data = folded
        self.trim(hop_length, hop_length)
        return self

    def low_pass(
        self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
    ):
        """Apply a low-pass filter to every batch item, replacing ``audio_data``.

        A separate ``julius.LowPassFilter`` is built per batch item, so an
        array or tensor of cutoffs gives each item its own filter while a
        single float is shared by all items. ``stft_data`` is reset to
        ``None`` afterwards.

        Parameters
        ----------
        cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
            Cutoff in Hz of low-pass filter.
        zeros : int, optional
            Forwarded to ``julius.LowPassFilter`` as ``zeros``, by default 51

        Returns
        -------
        AudioSignal
            Low-passed AudioSignal.
        """
        per_item_cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
        # The filter receives the cutoff divided by the sample rate, not Hz.
        per_item_cutoffs = per_item_cutoffs / self.sample_rate

        output = torch.empty_like(self.audio_data)
        for idx, normalized_cutoff in enumerate(per_item_cutoffs):
            low_pass_filter = julius.LowPassFilter(normalized_cutoff.cpu(), zeros=zeros)
            low_pass_filter = low_pass_filter.to(self.device)
            output[idx] = low_pass_filter(self.audio_data[idx])

        self.audio_data = output
        self.stft_data = None
        return self

    def high_pass(
        self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
    ):
        """Apply a high-pass filter to every batch item, replacing ``audio_data``.

        A separate ``julius.HighPassFilter`` is built per batch item, so an
        array or tensor of cutoffs gives each item its own filter while a
        single float is shared by all items. ``stft_data`` is reset to
        ``None`` afterwards.

        Parameters
        ----------
        cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
            Cutoff in Hz of high-pass filter.
        zeros : int, optional
            Forwarded to ``julius.HighPassFilter`` as ``zeros``, by default 51

        Returns
        -------
        AudioSignal
            High-passed AudioSignal.
        """
        per_item_cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
        # The filter receives the cutoff divided by the sample rate, not Hz.
        per_item_cutoffs = per_item_cutoffs / self.sample_rate

        output = torch.empty_like(self.audio_data)
        for idx, normalized_cutoff in enumerate(per_item_cutoffs):
            high_pass_filter = julius.HighPassFilter(
                normalized_cutoff.cpu(), zeros=zeros
            )
            high_pass_filter = high_pass_filter.to(self.device)
            output[idx] = high_pass_filter(self.audio_data[idx])

        self.audio_data = output
        self.stft_data = None
        return self

    def mask_frequencies(
        self,
        fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
        fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
        val: float = 0.0,
    ):
        """Fill the STFT bins whose frequency lies in ``[fmin_hz, fmax_hz)`` with ``val``.

        Both magnitude and phase are filled with ``val`` inside the band, and
        ``stft_data`` is rebuilt from them. Useful for implementing SpecAug.
        The band edges can be different for every item in the batch.

        Parameters
        ----------
        fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
            Lower end of band to mask out.
        fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
            Upper end of band to mask out.
        val : float, optional
            Value to fill in, by default 0.0

        Returns
        -------
        AudioSignal
            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
            masked audio data.
        """
        # SpecAug
        magnitude, phase = self.magnitude, self.phase
        ndim = magnitude.ndim
        band_low = util.ensure_tensor(fmin_hz, ndim=ndim)
        band_high = util.ensure_tensor(fmax_hz, ndim=ndim)
        assert torch.all(band_low < band_high)

        # Centre frequency of each bin, evenly spaced from 0 to Nyquist, then
        # tiled to (batch, 1, bins, frames) so it can be compared elementwise.
        num_bins, num_frames = magnitude.shape[-2], magnitude.shape[-1]
        freqs = torch.linspace(0, self.sample_rate / 2, num_bins, device=self.device)
        freqs = freqs[None, None, :, None].repeat(self.batch_size, 1, 1, num_frames)

        in_band = ((band_low <= freqs) & (freqs < band_high)).to(self.device)

        masked_magnitude = magnitude.masked_fill(in_band, val)
        masked_phase = phase.masked_fill(in_band, val)
        self.stft_data = masked_magnitude * torch.exp(1j * masked_phase)
        return self

    def mask_timesteps(
        self,
        tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
        tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
        val: float = 0.0,
    ):
        """Fill the STFT frames whose time lies in ``[tmin_s, tmax_s)`` with ``val``.

        Both magnitude and phase are filled with ``val`` inside the span, and
        ``stft_data`` is rebuilt from them. Useful for implementing SpecAug.
        The span edges can be different for every item in the batch.

        Parameters
        ----------
        tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
            Lower end of timesteps to mask out.
        tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
            Upper end of timesteps to mask out.
        val : float, optional
            Value to fill in, by default 0.0

        Returns
        -------
        AudioSignal
            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
            masked audio data.
        """
        # SpecAug
        magnitude, phase = self.magnitude, self.phase
        ndim = magnitude.ndim
        span_start = util.ensure_tensor(tmin_s, ndim=ndim)
        span_end = util.ensure_tensor(tmax_s, ndim=ndim)

        assert torch.all(span_start < span_end)

        # Timestamp of each frame, evenly spaced over the signal duration, then
        # tiled to (batch, 1, bins, frames) so it can be compared elementwise.
        num_bins, num_frames = magnitude.shape[-2], magnitude.shape[-1]
        times = torch.linspace(0, self.signal_duration, num_frames, device=self.device)
        times = times[None, None, None, :].repeat(self.batch_size, 1, num_bins, 1)

        in_span = (span_start <= times) & (times < span_end)

        masked_magnitude = magnitude.masked_fill(in_span, val)
        masked_phase = phase.masked_fill(in_span, val)
        self.stft_data = masked_magnitude * torch.exp(1j * masked_phase)
        return self

    def mask_low_magnitudes(
        self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
    ):
        """Replace magnitudes whose log-magnitude is below ``db_cutoff`` with ``val``.

        The threshold can be different for every item in the batch. The
        result is written back through the ``magnitude`` attribute.

        Parameters
        ----------
        db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
            Decibel value for which things below it will be masked away.
        val : float, optional
            Value to fill in for masked portions, by default 0.0

        Returns
        -------
        AudioSignal
            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
            masked audio data.
        """
        magnitude = self.magnitude
        log_magnitude = self.log_magnitude()

        threshold = util.ensure_tensor(db_cutoff, ndim=magnitude.ndim)
        too_quiet = log_magnitude < threshold

        self.magnitude = magnitude.masked_fill(too_quiet, val)
        return self

    def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
        """Add ``shift`` to the phase and write the result back.

        Parameters
        ----------
        shift : typing.Union[torch.Tensor, np.ndarray, float]
            What to shift the phase by.

        Returns
        -------
        AudioSignal
            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
            masked audio data.
        """
        current_phase = self.phase
        offset = util.ensure_tensor(shift, ndim=current_phase.ndim)
        self.phase = current_phase + offset
        return self

    def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
        """Add standard-normal noise multiplied by ``scale`` to the phase.

        Parameters
        ----------
        scale : typing.Union[torch.Tensor, np.ndarray, float]
            Standard deviation of noise to add to the phase.

        Returns
        -------
        AudioSignal
            Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
            masked audio data.
        """
        current_phase = self.phase
        noise_std = util.ensure_tensor(scale, ndim=current_phase.ndim)
        noise = torch.randn_like(current_phase)
        self.phase = current_phase + noise_std * noise
        return self

    def preemphasis(self, coef: float = 0.85):
        """Applies pre-emphasis to audio signal.

        Parameters
        ----------
        coef : float, optional
            How much pre-emphasis to apply, lower values do less. 0 does nothing.
            by default 0.85

        Returns
        -------
        AudioSignal
            Pre-emphasized signal.
        """
        kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
        x = self.audio_data.reshape(-1, 1, self.signal_length)
        x = torch.nn.functional.conv1d(x, kernel, padding=1)
        self.audio_data = x.reshape(*self.audio_data.shape)
        return self