# Author: Serhiy Stetskovych
# Commit 78e32cc: "Initial code"
import torch
from torch.nn.modules.loss import _Loss
class MultiSrcNegSDR(_Loss):
    """Negative SDR-family loss for multi-source estimates.

    Computes one of SNR, SI-SDR or SD-SDR between each estimated source and
    its reference, then returns the negated value averaged over sources and
    over the batch (so it can be minimised).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        zero_mean: If True, subtract the mean over the time axis from both
            estimates and targets before computing the ratio.
        take_log: If True, express the ratio in dB (``10 * log10``).
        EPS: Small constant added to denominators and inside the log to
            avoid division by zero and ``log10(0)``.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Use the caller's epsilon (it was previously hard-coded to 1e-8,
        # silently ignoring the EPS argument).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negative SDR between ``ests`` and ``targets``.

        Args:
            ests: Estimated sources, shape ``[batch, n_src, time]``.
            targets: Reference sources, same shape as ``ests``.

        Returns:
            Scalar tensor: the negated per-source SDR, averaged over sources
            and then over the batch.

        Raises:
            TypeError: If the shapes differ or the inputs are not 3-D.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm: remove each source's mean over time.
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est
        # Step 2. Target component.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # Project ests onto targets: alpha = <est, t> / (||t||^2 + EPS).
            # [batch, n_src, 1]
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            # [batch, n_src, 1]
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            # [batch, n_src, time]
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # Plain SNR: the target is used unscaled.
            # [batch, n_src, time]
            scaled_targets = targets
        # Step 3. Error term. SNR and SD-SDR measure the error against the
        # unscaled target; SI-SDR measures it against the scaled target.
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_targets
        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        # Average over sources (last dim), then over the batch, and negate.
        return -torch.mean(pair_wise_sdr, dim=-1).mean(0)