import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pesq import pesq
from pystoi import stoi
import mir_eval

# Channel index whose masks get_irms() returns.
REFERENCE_CHANNEL = 0

# Sample rate (Hz) used by evaluate() when no ``sample_rate`` is passed.
# evaluate() previously read this name without it ever being defined (NameError).
# 16 kHz is the rate wide-band PESQ ("wb") expects.
# NOTE(review): confirm it matches the audio being evaluated.
SAMPLE_RATE = 16000


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached from autograd and moved to CPU."""
    return tensor.detach().cpu().numpy()


def plot_spectrogram(stft, title="Spectrogram", xlim=None, *, show=False):
    """Convert a complex STFT to a dB magnitude spectrogram, optionally plotting it.

    Plotting was disabled in this file (the drawing code had been commented out),
    so by default this only computes and returns the spectrogram. Pass
    ``show=True`` to draw it.

    Args:
        stft: Complex STFT tensor. It is passed straight to ``imshow``, so it
            presumably must be 2-D (freq, time) -- TODO confirm with callers.
        title: Figure title; used only when ``show`` is true.
        xlim: Optional x-axis limits ``(left, right)``; used only when ``show`` is true.
        show: When true, draw the spectrogram and call ``plt.show()``.

    Returns:
        NumPy array ``20 * log10(|stft| + 1e-8)``.
    """
    magnitude = stft.abs()
    # The 1e-8 floor keeps log10 finite on silent bins.
    spectrogram = _to_numpy(20 * torch.log10(magnitude + 1e-8))
    if show:
        figure, axis = plt.subplots(1, 1)
        img = axis.imshow(
            spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto"
        )
        if xlim is not None:
            axis.set_xlim(xlim)
        figure.suptitle(title)
        plt.colorbar(img, ax=axis)
        plt.show()
    return spectrogram


def plot_mask(mask, title="Mask", xlim=None):
    """Display a time-frequency mask as an image with a colour bar.

    Args:
        mask: Mask tensor; presumably 2-D (freq, time) -- TODO confirm.
        title: Figure title.
        xlim: Optional x-axis limits ``(left, right)``; ignored when ``None``.
    """
    mask = _to_numpy(mask)
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    if xlim is not None:
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def si_snr(estimate, reference, epsilon=1e-8):
    """Return the scale-invariant SNR (dB) of *estimate* against *reference*.

    Rows of the 2-D inputs are treated as independent signals and reduced along
    axis 1. The final ``.item()`` means the result must contain a single value,
    i.e. a single-row input.

    Args:
        estimate: Estimated signal, shape (rows, time).
        reference: Reference signal, same shape as ``estimate``.
        epsilon: Added to the reference power when computing the projection scale.

    Returns:
        Si-SNR in dB as a Python float.
    """
    # Centre each row on its own mean, matching the per-row (axis=1) reductions
    # below. For a single row this equals the global mean used previously.
    estimate = estimate - estimate.mean(axis=1, keepdim=True)
    reference = reference - reference.mean(axis=1, keepdim=True)

    # Project the estimate onto the reference to obtain the scaled target.
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    target = scale * reference
    error = estimate - target

    target_pow = target.pow(2).mean(axis=1)
    error_pow = error.pow(2).mean(axis=1)

    score = 10 * torch.log10(target_pow) - 10 * torch.log10(error_pow)
    return score.item()


def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech with noise at a requested signal-to-noise ratio.

    The shorter waveform is zero-padded at the end so both have the same length.
    The noise is then rescaled so that ``mean(clean**2) / mean(noise**2)``
    equals ``target_snr`` dB. Both powers are measured over the padded length.

    Args:
        waveform_clean: Clean waveform; assumed 2-D with time on dim 1, since
            dim 1 is compared but ``F.pad`` pads the last dim -- TODO confirm.
        waveform_noise: Noise waveform with the same layout.
        target_snr: Desired SNR in dB.

    Returns:
        ``waveform_clean + rescaled_noise``. Neither input tensor is modified.

    Raises:
        ValueError: If the noise is all zeros, so no finite scaling reaches the
            target SNR. Previously this silently produced a NaN mixture.
    """
    clean_len = waveform_clean.size(1)
    noise_len = waveform_noise.size(1)
    if clean_len > noise_len:
        waveform_noise = F.pad(waveform_noise, (0, clean_len - noise_len))
    elif noise_len > clean_len:
        waveform_clean = F.pad(waveform_clean, (0, noise_len - clean_len))

    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    if power_noise_signal == 0:
        raise ValueError("waveform_noise is silent; cannot scale it to a target SNR")
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)

    # Out-of-place on purpose: the previous `*=` rescaled the caller's noise
    # tensor whenever no padding copy had been made.
    waveform_noise = waveform_noise * 10 ** (-(target_snr - current_snr) / 20)
    return waveform_clean + waveform_noise


def evaluate(estimate, reference, sample_rate=None):
    """Print SDR, Si-SNR, PESQ (wide-band) and STOI of *estimate* against *reference*.

    Args:
        estimate: Enhanced waveform. Row 0 is used for PESQ and STOI, so it is
            presumably (1, time) -- TODO confirm.
        reference: Clean reference waveform, same shape as ``estimate``.
        sample_rate: Sampling rate in Hz. Defaults to the module-level
            ``SAMPLE_RATE``, looked up at call time.

    Returns:
        Dict with keys ``"sdr"``, ``"si_snr"``, ``"pesq"`` and ``"stoi"``.
        The scores are also printed, as before.
    """
    if sample_rate is None:
        sample_rate = SAMPLE_RATE
    estimate_np = _to_numpy(estimate)
    reference_np = _to_numpy(reference)

    si_snr_score = si_snr(estimate, reference)
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(
        reference_np, estimate_np, compute_permutation=False
    )
    # pesq(fs, ref, deg, mode): the clean reference comes first. The previous
    # code passed the estimate as the reference.
    pesq_score = pesq(sample_rate, reference_np[0], estimate_np[0], "wb")
    stoi_score = stoi(reference_np[0], estimate_np[0], sample_rate, extended=False)

    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_score}")
    print(f"STOI score: {stoi_score}")
    return {"sdr": sdr[0], "si_snr": si_snr_score, "pesq": pesq_score, "stoi": stoi_score}


def get_irms(stft_clean, stft_noise):
    """Compute ideal ratio masks (IRMs) for speech and noise on REFERENCE_CHANNEL.

    The shorter STFT is zero-padded at the end of its last dim (dim 2) so the
    two align. Each mask is that source's power divided by the summed power of
    both sources.

    Args:
        stft_clean: Complex STFT of the clean speech; presumably
            (channel, freq, time) -- TODO confirm.
        stft_noise: Complex STFT of the noise, same layout.

    Returns:
        ``(irm_speech, irm_noise)`` for channel ``REFERENCE_CHANNEL``. Bins where
        both sources have zero power get 0 in both masks instead of NaN.
    """
    clean_frames = stft_clean.size(2)
    noise_frames = stft_noise.size(2)
    if clean_frames > noise_frames:
        stft_noise = F.pad(stft_noise, (0, clean_frames - noise_frames))
    elif noise_frames > clean_frames:
        stft_clean = F.pad(stft_clean, (0, noise_frames - clean_frames))

    psd_clean = stft_clean.abs() ** 2
    psd_noise = stft_noise.abs() ** 2
    psd_total = psd_clean + psd_noise
    # Where the total power is 0 both numerators are 0 too, so dividing by 1
    # there yields 0 rather than NaN. Every other bin is unchanged.
    safe_total = psd_total.masked_fill(psd_total == 0, 1.0)
    irm_speech = psd_clean / safe_total
    irm_noise = psd_noise / safe_total
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]