|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
|
|
class GE2ELoss(nn.Module): |
|
def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): |
|
""" |
|
Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] |
|
Accepts an input of size (N, M, D) |
|
where N is the number of speakers in the batch, |
|
M is the number of utterances per speaker, |
|
and D is the dimensionality of the embedding vector (e.g. d-vector) |
|
Args: |
|
- init_w (float): defines the initial value of w in Equation (5) of [1] |
|
- init_b (float): definies the initial value of b in Equation (5) of [1] |
|
""" |
|
super(GE2ELoss, self).__init__() |
|
|
|
self.w = nn.Parameter(torch.tensor(init_w)) |
|
|
|
self.b = nn.Parameter(torch.tensor(init_b)) |
|
self.loss_method = loss_method |
|
|
|
print(' > Initialised Generalized End-to-End loss') |
|
|
|
assert self.loss_method in ["softmax", "contrast"] |
|
|
|
if self.loss_method == "softmax": |
|
self.embed_loss = self.embed_loss_softmax |
|
if self.loss_method == "contrast": |
|
self.embed_loss = self.embed_loss_contrast |
|
|
|
|
|
def calc_new_centroids(self, dvecs, centroids, spkr, utt): |
|
""" |
|
Calculates the new centroids excluding the reference utterance |
|
""" |
|
excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) |
|
excl = torch.mean(excl, 0) |
|
new_centroids = [] |
|
for i, centroid in enumerate(centroids): |
|
if i == spkr: |
|
new_centroids.append(excl) |
|
else: |
|
new_centroids.append(centroid) |
|
return torch.stack(new_centroids) |
|
|
|
def calc_cosine_sim(self, dvecs, centroids): |
|
""" |
|
Make the cosine similarity matrix with dims (N,M,N) |
|
""" |
|
cos_sim_matrix = [] |
|
for spkr_idx, speaker in enumerate(dvecs): |
|
cs_row = [] |
|
for utt_idx, utterance in enumerate(speaker): |
|
new_centroids = self.calc_new_centroids( |
|
dvecs, centroids, spkr_idx, utt_idx |
|
) |
|
|
|
cs_row.append( |
|
torch.clamp( |
|
torch.mm( |
|
utterance.unsqueeze(1).transpose(0, 1), |
|
new_centroids.transpose(0, 1), |
|
) |
|
/ (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), |
|
1e-6, |
|
) |
|
) |
|
cs_row = torch.cat(cs_row, dim=0) |
|
cos_sim_matrix.append(cs_row) |
|
return torch.stack(cos_sim_matrix) |
|
|
|
|
|
def embed_loss_softmax(self, dvecs, cos_sim_matrix): |
|
""" |
|
Calculates the loss on each embedding $L(e_{ji})$ by taking softmax |
|
""" |
|
N, M, _ = dvecs.shape |
|
L = [] |
|
for j in range(N): |
|
L_row = [] |
|
for i in range(M): |
|
L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) |
|
L_row = torch.stack(L_row) |
|
L.append(L_row) |
|
return torch.stack(L) |
|
|
|
|
|
def embed_loss_contrast(self, dvecs, cos_sim_matrix): |
|
""" |
|
Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid |
|
""" |
|
N, M, _ = dvecs.shape |
|
L = [] |
|
for j in range(N): |
|
L_row = [] |
|
for i in range(M): |
|
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) |
|
excl_centroids_sigmoids = torch.cat( |
|
(centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]) |
|
) |
|
L_row.append( |
|
1.0 |
|
- torch.sigmoid(cos_sim_matrix[j, i, j]) |
|
+ torch.max(excl_centroids_sigmoids) |
|
) |
|
L_row = torch.stack(L_row) |
|
L.append(L_row) |
|
return torch.stack(L) |
|
|
|
def forward(self, dvecs): |
|
""" |
|
Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) |
|
""" |
|
centroids = torch.mean(dvecs, 1) |
|
cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids) |
|
torch.clamp(self.w, 1e-6) |
|
cos_sim_matrix = self.w * cos_sim_matrix + self.b |
|
L = self.embed_loss(dvecs, cos_sim_matrix) |
|
return L.mean() |
|
|
|
|
|
class AngleProtoLoss(nn.Module): |
|
""" |
|
Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 |
|
Accepts an input of size (N, M, D) |
|
where N is the number of speakers in the batch, |
|
M is the number of utterances per speaker, |
|
and D is the dimensionality of the embedding vector |
|
Args: |
|
- init_w (float): defines the initial value of w |
|
- init_b (float): definies the initial value of b |
|
""" |
|
def __init__(self, init_w=10.0, init_b=-5.0): |
|
super(AngleProtoLoss, self).__init__() |
|
|
|
self.w = nn.Parameter(torch.tensor(init_w)) |
|
|
|
self.b = nn.Parameter(torch.tensor(init_b)) |
|
self.criterion = torch.nn.CrossEntropyLoss() |
|
|
|
print(' > Initialised Angular Prototypical loss') |
|
|
|
def forward(self, x): |
|
""" |
|
Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) |
|
""" |
|
out_anchor = torch.mean(x[:, 1:, :], 1) |
|
out_positive = x[:, 0, :] |
|
num_speakers = out_anchor.size()[0] |
|
|
|
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2)) |
|
torch.clamp(self.w, 1e-6) |
|
cos_sim_matrix = cos_sim_matrix * self.w + self.b |
|
label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device) |
|
L = self.criterion(cos_sim_matrix, label) |
|
return L |
|
|