# Author: Haohe Liu
# Email: haoheliu@gmail.com
# Date: 11 Feb 2023

import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
from scipy.io import wavfile
from matplotlib import pyplot as plt


matplotlib.use("Agg")

import hashlib
import os

import requests
from tqdm import tqdm

# Common prefix of every download URL below (the "specvqgan_public" container).
_SPECVQGAN_PUBLIC_ROOT = (
    "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/"
)

# Artifact name -> file name used both remotely and for the local copy.
CKPT_MAP = {
    "vggishish_lpaps": "vggishish16.pt",
    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
    "melception": "melception-21-05-10T09-28-40.pt",
}

# Artifact name -> full download URL (shared root + file name).
URL_MAP = {
    artifact: _SPECVQGAN_PUBLIC_ROOT + filename
    for artifact, filename in CKPT_MAP.items()
}

# Artifact name -> expected MD5 hex digest of the downloaded file.
MD5_MAP = {
    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
    "melception": "a71a41041e945b457c7d3d814bbcf72d",
}

# Module-wide default device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def load_json(fname):
    """Parse the JSON document stored at ``fname`` and return the decoded object."""
    with open(fname, "r") as handle:
        return json.load(handle)


def read_json(dataset_json_file):
    """Return the value stored under the top-level ``"data"`` key of a JSON file.

    Raises ``KeyError`` when the decoded document has no ``"data"`` entry.
    """
    with open(dataset_json_file, "r") as handle:
        payload = json.load(handle)
    entries = payload["data"]
    return entries


def copy_test_subset_data(metadata, testset_copy_target_path):
    """Mirror the audio files listed in ``metadata`` into a flat target folder.

    Parameters
    ----------
        metadata : sequence of dict
            Each entry must provide the source file path under the ``"wav"``
            key. ``len(metadata)`` is compared with the folder's entry count.
        testset_copy_target_path : str
            Destination directory; created if it does not exist.

    The copy is skipped when the target folder already holds exactly
    ``len(metadata)`` entries. Otherwise the folder's files are removed and
    every listed file is copied again. Failures on individual files are
    printed and do not abort the run.
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    os.makedirs(testset_copy_target_path, exist_ok=True)
    existing = os.listdir(testset_copy_target_path)
    if len(existing) == len(metadata):
        return

    # Entry count differs: drop whatever is there and start from scratch.
    for file in existing:
        try:
            os.remove(os.path.join(testset_copy_target_path, file))
        except OSError as e:
            print(e)

    print("Copying test subset data to {}".format(testset_copy_target_path))
    for each in tqdm(metadata):
        # shutil.copy avoids building a shell command, so paths containing
        # spaces or shell metacharacters are copied verbatim.
        try:
            shutil.copy(each["wav"], testset_copy_target_path)
        except OSError as e:
            # Same best-effort behaviour as a failing `cp`: report and go on.
            print(e)


def listdir_nohidden(path):
    """Lazily yield the entries of ``path`` whose names do not start with a dot."""
    yield from (entry for entry in os.listdir(path) if not entry.startswith("."))


def get_restore_step(path):
    """Pick the checkpoint file in ``path`` that training should resume from.

    Selection order:

    1. ``final.ckpt`` if present.
    2. If there is no ``last.ckpt``, the file with the largest ``step=<N>``
       value in its name; the step number is returned alongside it.
    3. Otherwise the highest-numbered ``last-v<N>.ckpt``, falling back to
       ``last.ckpt`` when no versioned copy exists.

    Parameters
    ----------
        path : str
            Directory containing the checkpoint files.

    Returns
    -------
        ``(file_name, step)``. ``step`` is ``0`` for cases 1 and 3.

    Raises
    ------
        ValueError if case 2 applies and no ``step=<N>`` file is found.
    """
    checkpoints = os.listdir(path)

    if os.path.exists(os.path.join(path, "final.ckpt")):
        return "final.ckpt", 0

    if not os.path.exists(os.path.join(path, "last.ckpt")):
        names, steps = [], []
        for name in checkpoints:
            # Ignore unrelated files instead of crashing on them.
            if "step=" not in name:
                continue
            try:
                step = int(name.split(".ckpt")[0].split("step=")[1])
            except ValueError:
                continue
            names.append(name)
            steps.append(step)
        if not steps:
            raise ValueError(
                "No 'step=<N>' checkpoint found in {}".format(path)
            )
        return names[np.argmax(steps)], np.max(steps)

    fname = "last.ckpt"
    newest_version = None
    for name in checkpoints:
        if "last" in name and "-v" in name:
            try:
                this_version = int(name.split(".ckpt")[0].split("-v")[1])
            except ValueError:
                continue
            # Compare against the best version seen so far *before* recording
            # this one, so a higher version actually wins regardless of the
            # order in which os.listdir returns the files.
            if newest_version is None or this_version > newest_version:
                newest_version = this_version
                fname = "last-v%s.ckpt" % this_version
    return fname, 0


def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a byte progress bar.

    Parameters
    ----------
        url : str
            Address to fetch with an HTTP GET.
        local_path : str
            Destination file. Its parent directory is created when the path
            has one.
        chunk_size : int
            Number of bytes requested per read from the response stream.

    Raises
    ------
        requests.HTTPError if the server answers with an error status.

    The payload is written to ``<local_path>.part`` and moved into place only
    after the transfer finished, so an interrupted or failed download never
    leaves a truncated file at ``local_path``.
    """
    local_path = os.fspath(local_path)
    target_dir = os.path.dirname(local_path)
    if target_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(target_dir, exist_ok=True)

    partial_path = local_path + ".part"
    try:
        with requests.get(url, stream=True) as r:
            # Fail early instead of saving an HTTP error page as the payload.
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                with open(partial_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # The last chunk is usually shorter than chunk_size.
                            pbar.update(len(data))
        os.replace(partial_path, local_path)
    finally:
        # Only still present when the transfer did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def md5_hash(path):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed in fixed-size chunks so that large checkpoint files do
    not have to fit into memory at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(callable, sentinel) stops once read() returns the empty bytes.
        for block in iter(lambda: f.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()


def get_ckpt_path(name, root, check=False):
    """Return the local path of a known artifact, downloading it if needed.

    Parameters
    ----------
        name : str
            Key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
        root : str
            Directory in which the artifact is stored.
        check : bool
            When ``True`` an already existing file is re-hashed and fetched
            again if its MD5 digest does not match ``MD5_MAP``.

    Returns
    -------
        Path of the artifact inside ``root``.

    Raises
    ------
        AssertionError if ``name`` is unknown or if the freshly downloaded
        file does not match the expected MD5 digest.
    """
    assert name in URL_MAP, "Unknown checkpoint name: {}".format(name)
    path = os.path.join(root, CKPT_MAP[name])

    needs_download = not os.path.exists(path) or (
        check and md5_hash(path) != MD5_MAP[name]
    )
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            # Remove the corrupt file: a later call with check=False would
            # otherwise accept it just because it exists.
            os.remove(path)
            # Raised explicitly (not via `assert`) so the integrity check also
            # runs under `python -O`; the exception type is unchanged.
            raise AssertionError(
                "MD5 mismatch for {}: expected {}, got {}".format(
                    name, MD5_MAP[name], md5
                )
            )
    return path


class KeyNotFoundError(Exception):
    """Lookup failure that records the underlying cause and the lookup state.

    Attributes set on the instance: ``cause`` (the triggering error or
    description), ``keys`` (the requested key path, if given) and ``visited``
    (the part of the path already resolved, if given). The exception message
    is assembled from whichever of these are not ``None``.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional lines are emitted only for arguments that were supplied.
        optional_lines = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_lines
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.
        pass_success : bool
            If ``True`` return a ``(value, success)`` tuple, where ``success``
            is ``False`` when ``default`` was substituted.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``.

    Raises
    ------
        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
        ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # A callable root has no container to write the result into.
                if parent is not None:
                    parent[last_key] = list_or_dict

            parent = list_or_dict

            try:
                # Remember the index actually used (int for sequences) so the
                # in-place expansion can write back through the same index.
                if isinstance(list_or_dict, dict):
                    last_key = key
                else:
                    last_key = int(key)
                list_or_dict = list_or_dict[last_key]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if not pass_success:
        return list_or_dict
    return list_or_dict, success


def to_device(data, device):
    """Convert the numpy members of a 12- or 6-field batch tuple to tensors.

    For a 12-field batch, ``speakers``, ``texts`` and ``durations`` are cast
    to long, ``mels`` and ``pitches`` to float, and ``src_lens``,
    ``mel_lens`` and ``energies`` keep their dtype. For a 6-field batch only
    ``speakers``, ``texts`` and ``src_lens`` are converted. All converted
    tensors are moved to ``device``; the remaining fields pass through
    untouched. Any other tuple length yields ``None``.
    """

    def _as_long(array):
        return torch.from_numpy(array).long().to(device)

    def _as_float(array):
        return torch.from_numpy(array).float().to(device)

    def _as_is(array):
        return torch.from_numpy(array).to(device)

    if len(data) == 12:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
            _as_float(mels),
            _as_is(mel_lens),
            max_mel_len,
            _as_float(pitches),
            _as_is(energies),
            _as_long(durations),
        )

    if len(data) == 6:
        ids, raw_texts, speakers, texts, src_lens, max_src_len = data
        return (
            ids,
            raw_texts,
            _as_long(speakers),
            _as_long(texts),
            _as_is(src_lens),
            max_src_len,
        )


def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    """Send a figure and/or an audio clip to ``logger`` under ``tag``.

    Parameters
    ----------
        logger : obj
            Object exposing ``add_figure(tag, fig)`` and
            ``add_audio(tag, audio, sample_rate=...)``.
        step : obj
            Accepted for interface compatibility; it is not forwarded to the
            logger by this function.
        fig : obj
            Figure to log; skipped when ``None``.
        audio : array-like
            Waveform to log; skipped when ``None``. It is scaled so that its
            peak magnitude becomes ``1 / 1.1``.
        sampling_rate : int
            Sample rate passed on to ``add_audio``.
        tag : str
            Name under which both items are logged.
    """
    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is not None:
        peak = max(abs(audio))
        # A silent clip has a zero peak; dividing would fill it with NaN, so
        # it is logged unscaled instead.
        if peak > 0:
            audio = audio / (peak * 1.1)
        logger.add_audio(
            tag,
            audio,
            sample_rate=sampling_rate,
        )


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Parameters
    ----------
        lengths : torch.Tensor
            One length per batch item.
        max_len : int
            Width of the mask; defaults to the largest value in ``lengths``.

    Returns
    -------
        Tensor of shape ``(batch, max_len)`` on the same device as
        ``lengths`` that is ``True`` at positions ``>= length`` (padding)
        and ``False`` elsewhere.
    """
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    # Create the position indices on the device of `lengths` so the
    # comparison below never mixes devices.
    ids = (
        torch.arange(0, max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask


def expand(values, durations):
    """Repeat each value by its paired duration and return a numpy array.

    Durations are truncated with ``int()``; non-positive ones drop the value.
    """
    repeated = [
        value
        for value, d in zip(values, durations)
        for _ in range(max(0, int(d)))
    ]
    return np.array(repeated)


def synth_one_sample_val(
    targets, predictions, vocoder, model_config, preprocess_config
):
    """Plot one randomly chosen batch item and optionally vocode it.

    Parameters
    ----------
        targets : sequence
            ``targets[0]`` holds the item names and ``targets[6]`` the target
            mel batch (assumed ``(batch, time, n_mels)`` — TODO confirm).
        predictions : sequence
            ``predictions[0]`` / ``predictions[1]`` are the raw and postnet
            mel batches, ``predictions[9]`` the per-item mel lengths.
        vocoder : obj or None
            Passed to ``vocoder_infer``; when ``None`` no audio is produced.
        model_config, preprocess_config : obj
            Forwarded unchanged to ``vocoder_infer``.

    Returns
    -------
        ``(fig, wav_reconstruction, wav_prediction, basename)``; both
        waveforms are ``None`` when ``vocoder`` is ``None``.

    Only the three spectrograms are drawn: ``plot_mel(data, titles)`` takes
    plain mel arrays, so pitch/energy curves and ``stats.json`` are not used.
    """
    index = np.random.choice(list(np.arange(targets[6].size(0))))

    basename = targets[0][index]
    mel_len = predictions[9][index].item()

    def _trim(batch):
        # Cut the chosen item to its true length and put the mel axis first.
        return batch[index, :mel_len].detach().transpose(0, 1)

    mel_target = _trim(targets[6])
    mel_prediction = _trim(predictions[0])
    postnet_mel_prediction = _trim(predictions[1])

    # plot_mel accepts (data, titles): one 2-D array per panel.
    fig = plot_mel(
        [
            mel_prediction.cpu().numpy(),
            postnet_mel_prediction.cpu().numpy(),
            mel_target.cpu().numpy(),
        ],
        [
            "Raw mel spectrogram prediction",
            "Postnet mel prediction",
            "Ground-Truth Spectrogram",
        ],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            postnet_mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename


def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
    """Vocode an input mel batch and its predicted counterpart.

    Both tensors have their last two axes swapped via ``permute(0, 2, 1)``
    before being handed to ``vocoder_infer``. ``labels`` is not used. Returns
    ``(wav_reconstruction, wav_prediction)``, or ``(None, None)`` when no
    vocoder is supplied.
    """
    if vocoder is None:
        return None, None

    from .model import vocoder_infer

    reconstructed = vocoder_infer(mel_input.permute(0, 2, 1), vocoder)
    predicted = vocoder_infer(mel_prediction.permute(0, 2, 1), vocoder)
    return reconstructed, predicted


def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Save a spectrogram image and a vocoded wav file for every batch item.

    Parameters
    ----------
        targets : sequence
            ``targets[0]`` holds the item names used for the output files.
        predictions : sequence
            ``predictions[1]`` is the postnet mel batch (assumed
            ``(batch, time, n_mels)`` — TODO confirm) and ``predictions[9]``
            the per-item mel lengths.
        vocoder, model_config : obj
            Forwarded to ``vocoder_infer``.
        preprocess_config : dict
            Supplies ``preprocessing/stft/hop_length`` and
            ``preprocessing/audio/sampling_rate``; also forwarded to
            ``vocoder_infer``.
        path : str
            Existing directory receiving ``<name>_postnet_2.png`` and
            ``<name>.wav``.

    Only the spectrogram is drawn: ``plot_mel(data, titles)`` takes plain mel
    arrays, so pitch/energy curves and ``stats.json`` are not used.
    """
    basenames = targets[0]

    for i in range(len(predictions[1])):
        basename = basenames[i]
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)

        # plot_mel accepts (data, titles): one 2-D array per panel.
        fig = plot_mel(
            [mel_prediction.cpu().numpy()],
            ["Synthetized Spectrogram by PostNet"],
        )
        # Save and close this figure explicitly rather than relying on
        # pyplot's implicit "current figure".
        fig.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
        plt.close(fig)

    from .model import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)


def plot_mel(data, titles=None):
    """Draw each entry of ``data`` as an image in its own row of one figure.

    ``titles`` supplies one title per row; when omitted every row gets
    ``None`` as its title. Returns the created matplotlib figure.
    """
    n_panels = len(data)
    fig, axes = plt.subplots(n_panels, 1, squeeze=False)
    if titles is None:
        titles = [None] * n_panels

    for row, mel in enumerate(data):
        panel = axes[row][0]
        panel.imshow(mel, origin="lower", aspect="auto")
        panel.set_aspect(2.5, adjustable="box")
        panel.set_ylim(0, mel.shape[0])
        panel.set_title(titles[row], fontsize="medium")
        # Hide the y ticks and their labels; keep the rest small.
        panel.tick_params(labelsize="x-small", left=False, labelleft=False)
        panel.set_anchor("W")

    return fig


def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays with ``PAD`` to a common length and stack them.

    The common length is that of the longest input; the result has one row
    per input.
    """
    target_len = max(len(seq) for seq in inputs)

    rows = []
    for seq in inputs:
        missing = target_len - seq.shape[0]
        rows.append(
            np.pad(seq, (0, missing), mode="constant", constant_values=PAD)
        )
    return np.stack(rows)


def pad_2D(inputs, maxlen=None):
    """Zero-pad arrays along their first axis to a common length and stack.

    Parameters
    ----------
        inputs : sequence of np.ndarray
            Arrays that agree on every axis except the first.
        maxlen : int
            Target length of the first axis. When omitted (or falsy) the
            longest input determines it.

    Returns
    -------
        Array of shape ``(len(inputs), target_len, ...)``.

    Raises
    ------
        ValueError if an input is longer than the target length.
    """

    def pad(x, max_len):
        n_rows = np.shape(x)[0]
        if n_rows > max_len:
            raise ValueError("not max_len")

        # Pad only the leading axis; trailing axes keep their size, so no
        # oversized intermediate array has to be built and sliced back.
        pad_width = [(0, max_len - n_rows)] + [(0, 0)] * (np.ndim(x) - 1)
        return np.pad(x, pad_width, mode="constant", constant_values=0)

    if maxlen:
        target_len = maxlen
    else:
        target_len = max(np.shape(x)[0] for x in inputs)

    return np.stack([pad(x, target_len) for x in inputs])


def pad(input_ele, mel_max_length=None):
    """Zero-pad tensors along their first dimension and stack them.

    Parameters
    ----------
        input_ele : sequence of torch.Tensor
            Tensors that agree on every dimension except the first.
        mel_max_length : int
            Target size of the first dimension. When omitted (or falsy) the
            longest tensor determines it.

    Returns
    -------
        Tensor of shape ``(len(input_ele), max_len, ...)``.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(element.size(0) for element in input_ele)

    out_list = []
    for batch in input_ele:
        # F.pad lists (left, right) pairs starting from the LAST dimension,
        # so every trailing dimension gets (0, 0) and only dim 0 grows. This
        # handles any rank instead of just 1-D and 2-D tensors.
        trailing = (0, 0) * (batch.dim() - 1)
        pad_spec = trailing + (0, max_len - batch.size(0))
        out_list.append(F.pad(batch, pad_spec, "constant", 0.0))
    out_padded = torch.stack(out_list)
    return out_padded