from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig

from einops import rearrange

from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel


# NOTE: the Config class and the Model class apparently need to live in the same file; otherwise, model loading seems to break after the model is pushed to the Hugging Face (HF) Hub.
class AudioMAEConfig(PretrainedConfig):
    """Hugging Face configuration object for the AudioMAE encoder.

    Args:
        img_size: input spectrogram size given as (temporal_length, n_freq_bins).
        in_chans: number of input channels (1 for mono audio).
        num_classes: value passed through to the encoder's `num_classes` argument.
        **kwargs: any remaining options, forwarded to `PretrainedConfig`.
    """

    model_type = "audiomae"

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 128),
        in_chans: int = 1,
        num_classes: int = 0,
        **kwargs,
    ):
        # let the base class consume the generic HF options first
        super().__init__(**kwargs)

        # encoder-specific hyper-parameters
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes


class AudioMAEEncoder(VisionTransformer):
    """Vision Transformer used as the AudioMAE encoder, plus audio pre-processing helpers.

    - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
    - AudioMAE accepts a mono-channel input (i.e., in_chans=1)
    """

    _SAMPLE_RATE = 16000  # AudioMAE accepts 16khz audio
    _N_MEL_BINS = 128  # number of mel-frequency bins of the spectrogram
    _N_FRAMES = 1024  # temporal length of the spectrogram, as described in the paper

    def __init__(self, *args, **kwargs):
        """Build the underlying `VisionTransformer` and store the normalization statistics.

        All positional and keyword arguments are forwarded unchanged to
        `VisionTransformer.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.MEAN = -4.2677393  # written on the paper
        self.STD = 4.5689974  # written on the paper

    def load_wav_file(self, file_path: str) -> torch.Tensor:
        """Load an audio file as a mono, 16khz waveform.

        to use this, `torchaudio` and `ffmpeg` must be installed
        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        This method only *warns* when the audio is longer than 10s; the waveform is
        returned at full length. The actual truncation happens in
        `waveform_to_melspec`, which keeps at most 1024 spectrogram frames.

        Args:
            file_path: path of the audio file to read.

        Returns:
            The waveform with shape (1, length), resampled to 16khz.
        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length)

        # duration in seconds, measured at the file's native sample rate
        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('current audio length is:', audio_len)
            print('[WARNING] AudioMAE only accepts audio length up to 10s. The audio frames exceeding 10s will be clipped.')

        # AudioMAE accepts a mono channel: down-mix by averaging across channels
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # AudioMAE accepts 16khz: resample anything else
        if sample_rate != self._SAMPLE_RATE:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=self._SAMPLE_RATE)
            audio = converter(audio)
        return audio

    def waveform_to_melspec(self, waveform: torch.Tensor) -> torch.Tensor:
        """Convert a waveform into a normalized, fixed-size log-mel spectrogram.

        Args:
            waveform: 16khz waveform, as returned by `load_wav_file`.

        Returns:
            The normalized spectrogram with shape (length, n_freq_bins) = (1024, 128).
        """
        # Compute the Mel spectrogram using Kaldi-compatible features
        # the parameters are chosen as described in the audioMAE paper (4.2 implementation details)
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=self._N_MEL_BINS,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=self._SAMPLE_RATE,
            window_type='hanning',
            dither=0.0,
        )

        # Force the time dimension to exactly 1024 frames by trimming or zero-padding
        current_frames = mel_spectrogram.shape[0]
        if current_frames > self._N_FRAMES:
            mel_spectrogram = mel_spectrogram[:self._N_FRAMES, :]
        elif current_frames < self._N_FRAMES:
            padding = self._N_FRAMES - current_frames
            # F.pad reads the pad tuple starting from the LAST dimension:
            #   (0, 0)       -> frequency axis (dim 1) is left untouched
            #   (0, padding) -> zero frames are appended at the end of the time axis (dim 0)
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))

        # scale
        # as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (length, n_freq_bins) = (1024, 128)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path: str, device) -> torch.Tensor:
        """Encode an audio file into a 2-D grid of patch embeddings.

        Side effect: the module is switched to eval mode and stays in eval mode.

        Args:
            file_path: path of the audio file to encode.
            device: device the spectrogram is moved to before the forward pass;
                the result is moved back to the CPU.

        Returns:
            Patch embeddings with shape (d h' w'), where h' indexes frequency
            patches and w' indexes temporal patches.
        """
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
        melspec = melspec[None, None, :, :]  # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
        z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d)
        z = z[:, 1:, :]  # (b n d); remove [CLS], the class token

        # number of patches along the frequency axis; integer division matches the
        # patch count of a patch embedding whose stride equals its patch size
        n_freq_bins = melspec.shape[-1]
        hprime = n_freq_bins // self.patch_embed.patch_size[1]

        # reconstruct the temporal and freq dims; the token sequence is laid out
        # with the temporal patch index varying slowest
        z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b d h' w')

        # remove the batch dim
        return z[0]  # (d h' w')



class PretrainedAudioMAEEncoder(PreTrainedModel):
    """`PreTrainedModel` wrapper that exposes `AudioMAEEncoder` through the Hugging Face API."""

    config_class = AudioMAEConfig

    def __init__(self, config):
        """Create the wrapped encoder from the fields stored in `config`."""
        super().__init__(config)
        encoder_kwargs = dict(
            img_size=config.img_size,
            in_chans=config.in_chans,
            num_classes=config.num_classes,
        )
        self.encoder = AudioMAEEncoder(**encoder_kwargs)

    def forward(self, file_path: str):
        """Encode the audio file at `file_path` on this model's device; returns (d h' w')."""
        return self.encoder.encode(file_path, self.device)