File size: 3,575 Bytes
2dbfc9c
b106ffd
2dbfc9c
 
 
 
 
1f9348b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59907c6
 
 
 
 
 
 
1f9348b
59907c6
 
 
 
 
1f9348b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ead7fe5
 
 
 
1f9348b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pesq import pesq
from pystoi import stoi
import mir_eval
REFERENCE_CHANNEL = 0
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Compute the log-magnitude (dB) array of ``stft``.

    The rendering code below is currently commented out, so this function
    only performs the conversion and implicitly returns ``None``. ``title``
    and ``xlim`` are accepted but not used while plotting is disabled.

    Args:
        stft: Tensor supporting ``.abs()``; the result must be convertible
            with ``.numpy()``.
        title: Figure title intended for the (disabled) plot.
        xlim: Unused.
    """
    # 1e-8 keeps log10 finite where the magnitude is exactly zero.
    log_magnitude = torch.log10(stft.abs() + 1e-8).numpy()
    spectrogram = 20 * log_magnitude
    # NOTE: plotting is disabled; the original rendering code is kept below.
    # figure, axis = plt.subplots(1, 1)
    # img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    # figure.suptitle(title)
    # plt.colorbar(img, ax=axis)
    # plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    """Display a mask as an image with a colorbar.

    Args:
        mask: Tensor convertible with ``.numpy()``; passed to ``imshow``.
        title: Text used as the figure's super-title.
        xlim: Optional x-axis limits. When given, they are applied with
            ``axis.set_xlim``; when ``None`` the default extent is kept.

    Calls ``plt.show()``, so the figure is shown as a side effect.
    """
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    if xlim is not None:
        # Previously this argument was accepted but silently ignored.
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()

def si_snr(estimate, reference, epsilon=1e-8):
    """Return the scale-invariant signal-to-noise ratio in dB as a float.

    Both inputs have their overall mean removed, the reference is rescaled
    by its projection coefficient onto the estimate, and the power ratio
    between that scaled reference and the residual is reported.

    Args:
        estimate: Tensor with at least two dimensions (means are taken
            along dimension 1).
        reference: Tensor of matching shape.
        epsilon: Added to the reference power before dividing.

    Returns:
        A Python float. ``.item()`` is used, so the per-row result must
        contain exactly one element.
    """
    centered_est = estimate - estimate.mean()
    centered_ref = reference - reference.mean()

    # Projection coefficient of the estimate onto the reference.
    ref_energy = centered_ref.pow(2).mean(axis=1, keepdim=True)
    cross_energy = (centered_est * centered_ref).mean(axis=1, keepdim=True)
    projection = cross_energy / (ref_energy + epsilon)

    target = projection * centered_ref
    residual = centered_est - target

    target_power = target.pow(2).mean(axis=1)
    residual_power = residual.pow(2).mean(axis=1)

    ratio_db = 10 * torch.log10(target_power) - 10 * torch.log10(residual_power)
    return ratio_db.item()

# def generate_mixture(waveform_clean, waveform_noise, target_snr):
#     power_clean_signal = waveform_clean.pow(2).mean()
#     power_noise_signal = waveform_noise.pow(2).mean()
#     current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
#     waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
#     return waveform_clean + waveform_noise

def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix a clean waveform with noise scaled to reach ``target_snr`` dB.

    The shorter input is zero-padded at the end of its last dimension so
    both have the same length. Powers are measured after padding, over the
    whole tensor. Neither input tensor is modified.

    Args:
        waveform_clean: Clean signal tensor.
        waveform_noise: Noise signal tensor.
        target_snr: Desired signal-to-noise ratio in dB.

    Returns:
        ``waveform_clean + gain * waveform_noise`` after length alignment.
    """
    # Compare the last dimension, since that is the one F.pad extends.
    length_gap = waveform_clean.size(-1) - waveform_noise.size(-1)
    if length_gap > 0:
        waveform_noise = F.pad(waveform_noise, (0, length_gap))
    elif length_gap < 0:
        waveform_clean = F.pad(waveform_clean, (0, -length_gap))

    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    noise_gain = 10 ** (-(target_snr - current_snr) / 20)
    # Out-of-place multiply: an in-place `*=` would overwrite the caller's
    # noise tensor whenever no padding copy was made above.
    return waveform_clean + waveform_noise * noise_gain

def evaluate(estimate, reference, sample_rate=None):
    """Print SDR, Si-SNR, PESQ and STOI scores for an estimate.

    Args:
        estimate: Estimated signal tensor; row 0 is used for PESQ and STOI.
        reference: Reference signal tensor; row 0 is used for PESQ and STOI.
        sample_rate: Sampling rate passed to ``pesq`` and ``stoi``. When
            ``None`` the module-level ``SAMPLE_RATE`` is used.
            NOTE(review): ``SAMPLE_RATE`` is not defined in the visible part
            of this file - confirm it exists, or pass ``sample_rate``.

    Returns:
        None. The four scores are written to stdout.
    """
    if sample_rate is None:
        sample_rate = SAMPLE_RATE

    reference_np = reference.numpy()
    estimate_np = estimate.numpy()

    si_snr_score = si_snr(estimate, reference)
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(reference_np, estimate_np, False)
    # pesq's signature is pesq(fs, ref, deg, mode): the reference goes first.
    # The previous call passed the estimate in the reference slot.
    pesq_mix = pesq(sample_rate, reference_np[0], estimate_np[0], "wb")
    stoi_mix = stoi(reference_np[0], estimate_np[0], sample_rate, extended=False)

    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_mix}")
    print(f"STOI score: {stoi_mix}")

def get_irms(stft_clean, stft_noise, reference_channel=None):
    """Compute ideal ratio masks for speech and noise.

    The shorter input is zero-padded at the end of its last dimension so
    both have the same length there. Each mask is that source's squared
    magnitude divided by the sum of both squared magnitudes. Bins where
    both magnitudes are zero divide 0 by 0.

    Args:
        stft_clean: Clean-speech tensor supporting ``.abs()``.
        stft_noise: Noise tensor supporting ``.abs()``.
        reference_channel: Index taken along dimension 0 of each mask.
            Defaults to the module-level ``REFERENCE_CHANNEL``.

    Returns:
        Tuple ``(irm_speech, irm_noise)`` for the selected channel.
    """
    if reference_channel is None:
        reference_channel = REFERENCE_CHANNEL

    # F.pad with a 2-tuple extends the LAST dimension, so that is the one
    # that must be compared (the old size(1) check looked at another axis
    # for inputs with more than two dimensions).
    length_gap = stft_clean.size(-1) - stft_noise.size(-1)
    if length_gap > 0:
        stft_noise = F.pad(stft_noise, (0, length_gap))
    elif length_gap < 0:
        stft_clean = F.pad(stft_clean, (0, -length_gap))

    mag_clean = stft_clean.abs() ** 2
    mag_noise = stft_noise.abs() ** 2
    irm_speech = mag_clean / (mag_clean + mag_noise)
    irm_noise = mag_noise / (mag_clean + mag_noise)
    return irm_speech[reference_channel], irm_noise[reference_channel]