import os
import types

import librosa
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

import htsat_config
from cog import BasePredictor, Input, Path
from data_processor import MusdbDataset

# Class and function names below (e.g. AutoTaggingWarpper, prepprocess_audio)
# keep the spelling used by the upstream modules they are imported from.
from models.asp_model import AutoTaggingWarpper, SeparatorModel, ZeroShotASP
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
from utils import prepprocess_audio


def get_inference_configs():
    """Build an inference-only config namespace (no training options needed)."""
    config = types.SimpleNamespace()
    config.ckpt_path = "pretrained/zeroshot_asp_full.ckpt"
    config.sed_ckpt_path = "pretrained/htsat_audioset_2048d.ckpt"
    config.wave_output_path = "predict_outputs"
    config.test_key = "query_name"
    config.test_type = "mix"
    config.loss_type = "mae"
    config.infer_type = "mean"
    config.sample_rate = 32000
    config.segment_frames = 200
    config.hop_samples = 320
    config.energy_thres = 0.1
    config.using_whiting = False
    config.latent_dim = 2048
    config.classes_num = 527
    config.overlap_rate = 0.5
    config.num_workers = 1
    return config


def load_models(config):
    """Load the HTS-AT audio-tagging backbone and the zero-shot separator."""
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head,
    )
    at_model = SEDWrapper(sed_model=sed_model, config=htsat_config, dataset=None)
    ckpt = torch.load(config.sed_ckpt_path, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    at_wrapper = AutoTaggingWarpper(
        at_model=at_model, config=config, target_keys=[config.test_key]
    )

    asp_model = ZeroShotASP(channels=1, config=config, at_model=at_model, dataset=None)
    ckpt = torch.load(config.ckpt_path, map_location="cpu")
    asp_model.load_state_dict(ckpt["state_dict"], strict=False)
    return at_wrapper, asp_model


def get_dataloader_from_sound_file(sound_file_path, config):
    """Load a sound file, resample it to the model rate, and wrap it in a DataLoader."""
    signal, sampling_rate = librosa.load(str(sound_file_path), sr=None)
    signal = prepprocess_audio(
        signal[:, None], sampling_rate, config.sample_rate, config.test_type
    )
    # Duplicate the signal so MusdbDataset receives a (mixture, source) pair.
    signal = np.array([signal, signal])
    dataset = MusdbDataset(tracks=[signal])
    return DataLoader(
        dataset, num_workers=config.num_workers, batch_size=1, shuffle=False
    )


class Predictor(BasePredictor):
    def setup(self):
        self.config = get_inference_configs()
        os.makedirs(self.config.wave_output_path, exist_ok=True)
        self.at_wrapper, self.asp_model = load_models(self.config)

    def predict(
        self,
        mix_file: Path = Input(description="Reference sound to extract the source from"),
        query_file: Path = Input(description="Query sound to search for and extract from the mix"),
    ) -> Path:
        ref_loader = get_dataloader_from_sound_file(str(mix_file), self.config)
        query_loader = get_dataloader_from_sound_file(str(query_file), self.config)

        trainer = pl.Trainer(gpus=1)
        # First pass: compute the average audio-tagging embedding of the query.
        trainer.test(self.at_wrapper, test_dataloaders=query_loader)
        avg_at = self.at_wrapper.avg_at

        # Second pass: separate the mixture conditioned on the query embedding.
        exp_model = SeparatorModel(
            model=self.asp_model,
            config=self.config,
            target_keys=[self.config.test_key],
            avg_at=avg_at,
            using_wiener=False,
            calc_sdr=False,
            output_wav=True,
        )
        trainer.test(exp_model, test_dataloaders=ref_loader)

        prediction_path = os.path.join(
            self.config.wave_output_path, f"0_{self.config.test_key}_pred_(0.0).wav"
        )
        # The declared return type is cog's Path, so wrap the string path.
        return Path(prediction_path)
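

# --- Local usage sketch (not part of the cog entry point) ---
# A minimal way to exercise Predictor outside of `cog predict`, assuming the
# pretrained checkpoints listed in get_inference_configs() are present and a
# GPU is available. "mix.wav" and "query.wav" are hypothetical input files;
# cog's Path subclasses pathlib.Path, so plain file names work here.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    output = predictor.predict(
        mix_file=Path("mix.wav"), query_file=Path("query.wav")
    )
    print(f"Separated source written to: {output}")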