Spaces:
Running
Running
from typing import Optional | |
from piq import SSIMLoss | |
import torch | |
from torch import Tensor, nn | |
from torch.nn import functional | |
from models.config import AcousticModelConfigType | |
from training.loss.utils import sample_wise_min_max | |
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 | |
def sequence_mask(sequence_length: Tensor, max_len: Optional[int] = None) -> Tensor:
    """Build a boolean mask that is ``True`` on valid (non-padding) positions.

    Position ``t`` of row ``b`` is ``True`` exactly when
    ``t < sequence_length[b]``.

    Args:
        sequence_length (torch.Tensor): Length of every sequence in the batch
            (assumed to be a 1-D tensor of shape ``[B]``).
        max_len (int, Optional): Number of columns of the mask. When ``None``
            the largest value found in ``sequence_length`` is used.

    Returns:
        Tensor: The boolean mask.

    Shapes:
        - mask: :math:`[B, T_max]`
    """
    if max_len is None:
        width = sequence_length.max().item()
    else:
        width = max_len

    # Position indices 0..width-1, kept on the same device/dtype as the lengths
    # so the comparison below needs no conversion.
    positions = torch.arange(
        width,
        dtype=sequence_length.dtype,
        device=sequence_length.device,
    )

    # (1, T_max) < (B, 1) broadcasts to (B, T_max).
    return positions[None, :] < sequence_length[:, None]
class ForwardSumLoss(nn.Module):
    r"""Forward-sum alignment loss built on top of ``torch.nn.CTCLoss``.

    Every batch item is scored on its own: the attention log-probabilities are
    cropped to the item's real lengths, a "blank" column is prepended, and the
    CTC loss of the monotonic target ``1, 2, ..., key_len`` is computed.

    Attributes:
        log_softmax (torch.nn.LogSoftmax): Log-softmax applied along dimension 3.
        ctc_loss (torch.nn.CTCLoss): CTC loss created with ``zero_infinity=True``.
        blank_logprob (int): Value used to fill the prepended blank column
            (default ``-1``).

    Methods:
        forward(attn_logprob: Tensor, in_lens: Tensor, out_lens: Tensor)
            Compute the forward sum loss.
    """

    def __init__(self, blank_logprob: int = -1):
        r"""Create the sub-modules and remember the blank fill value.

        Args:
            blank_logprob (int, optional): Value written into the blank column
                that is prepended to the attention log-probabilities
                (default is -1).
        """
        super().__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
        self.blank_logprob = blank_logprob

    def forward(self, attn_logprob: Tensor, in_lens: Tensor, out_lens: Tensor):
        r"""Compute the forward sum loss.

        Args:
            attn_logprob (Tensor): The attention log probabilities. The code
                indexes it as ``[batch, 1, query, key]``.
            in_lens (Tensor): The input (key) lengths, one per batch item.
            out_lens (Tensor): The output (query) lengths, one per batch item.

        Returns:
            total_loss (Tensor): Sum of the per-item CTC losses divided by the
                batch size.
        """
        batch_size = attn_logprob.shape[0]

        # Prepend one column on the last axis; index 0 is the CTC blank class.
        padded = functional.pad(
            input=attn_logprob,
            pad=(1, 0),
            value=self.blank_logprob,
        )

        accumulated = 0.0
        for idx in range(batch_size):
            n_keys = in_lens[idx]
            n_queries = out_lens[idx]

            # Target is simply every key position in order: 1..n_keys.
            targets = torch.arange(1, n_keys.item() + 1).unsqueeze(0)

            # (1, T_query, T_key + 1) -> (T_query, 1, T_key + 1), then crop to
            # the valid queries and to the valid keys plus the blank column.
            sample = padded[idx].permute(1, 0, 2)
            sample = sample[:n_queries, :, : n_keys + 1]

            # LogSoftmax is configured for dim=3, so add and drop a leading axis.
            sample = self.log_softmax(sample.unsqueeze(0)).squeeze(0)

            accumulated = accumulated + self.ctc_loss(
                sample,
                targets,
                input_lengths=out_lens[idx : idx + 1],
                target_lengths=in_lens[idx : idx + 1],
            )

        return accumulated / batch_size
class DelightfulTTSLoss(nn.Module):
    r"""Weighted combination of the losses used to train the DelightfulTTS acoustic model.

    Attributes:
        mse_loss (nn.MSELoss): The mean squared error loss function.
        mae_loss (nn.L1Loss): The mean absolute error loss function.
        forward_sum_loss (ForwardSumLoss): The forward sum loss function.
        ssim_loss (SSIMLoss): The structural similarity loss function.
        mel_loss_alpha (float): The weight for the mel loss.
        ssim_loss_alpha (float): The weight for the SSIM loss.
        aligner_loss_alpha (float): The weight for the aligner loss.
        pitch_loss_alpha (float): The weight for the pitch loss.
        energy_loss_alpha (float): The weight for the energy loss.
        u_prosody_loss_alpha (float): The weight for the u prosody loss.
        p_prosody_loss_alpha (float): The weight for the p prosody loss.
        dur_loss_alpha (float): The weight for the duration loss.
        binary_alignment_loss_alpha (float): The weight for the binary alignment loss.

    Methods:
        _binary_alignment_loss(alignment_hard: Tensor, alignment_soft: Tensor)
            Compute the binary alignment loss.

        forward(
            mel_output: Tensor,
            mel_target: Tensor,
            mel_lens: Tensor,
            dur_output: Tensor,
            dur_target: Tensor,
            pitch_output: Tensor,
            pitch_target: Tensor,
            energy_output: Tensor,
            energy_target: Tensor,
            src_lens: Tensor,
            p_prosody_ref: Tensor,
            p_prosody_pred: Tensor,
            u_prosody_ref: Tensor,
            u_prosody_pred: Tensor,
            aligner_logprob: Tensor,
            aligner_hard: Tensor,
            aligner_soft: Tensor,
        )
            Compute the delightful TTS loss.
    """

    def __init__(self, config: "AcousticModelConfigType"):
        r"""Constructs all the necessary attributes for the DelightfulTTSLoss object.

        Args:
            config (AcousticModelConfigType): Configuration parameters for the loss
                function; every weight is read from ``config.loss``.
        """
        super().__init__()

        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()
        self.forward_sum_loss = ForwardSumLoss()
        self.ssim_loss = SSIMLoss()

        self.mel_loss_alpha = config.loss.mel_loss_alpha
        self.ssim_loss_alpha = config.loss.ssim_loss_alpha
        self.aligner_loss_alpha = config.loss.aligner_loss_alpha
        self.pitch_loss_alpha = config.loss.pitch_loss_alpha
        self.energy_loss_alpha = config.loss.energy_loss_alpha
        self.u_prosody_loss_alpha = config.loss.u_prosody_loss_alpha
        self.p_prosody_loss_alpha = config.loss.p_prosody_loss_alpha
        self.dur_loss_alpha = config.loss.dur_loss_alpha
        self.binary_alignment_loss_alpha = config.loss.binary_align_loss_alpha

    @staticmethod
    def _binary_alignment_loss(alignment_hard: Tensor, alignment_soft: Tensor) -> Tensor:
        """Binary loss that forces soft alignments to match the hard alignments as
        explained in `https://arxiv.org/pdf/2108.10447.pdf`.

        Declared as a static method because it uses no instance state; this is
        what makes the ``self._binary_alignment_loss(hard, soft)`` call in
        ``forward`` valid.

        Args:
            alignment_hard (Tensor): The hard alignment tensor.
            alignment_soft (Tensor): The soft alignment tensor.

        Returns:
            loss (Tensor): Negative sum of ``log(alignment_soft)`` over the
                positions where ``alignment_hard == 1``, divided by
                ``alignment_hard.sum()``.
        """
        # Clamp before the log so a soft probability of exactly 0 does not yield -inf.
        log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
        return -log_sum / alignment_hard.sum()

    def forward(
        self,
        mel_output: Tensor,
        mel_target: Tensor,
        mel_lens: Tensor,
        dur_output: Tensor,
        dur_target: Tensor,
        pitch_output: Tensor,
        pitch_target: Tensor,
        energy_output: Tensor,
        energy_target: Tensor,
        src_lens: Tensor,
        p_prosody_ref: Tensor,
        p_prosody_pred: Tensor,
        u_prosody_ref: Tensor,
        u_prosody_pred: Tensor,
        aligner_logprob: Tensor,
        aligner_hard: Tensor,
        aligner_soft: Tensor,
    ):
        r"""Compute the delightful TTS loss.

        Args:
            mel_output (Tensor): The mel output tensor.
            mel_target (Tensor): The mel target tensor.
            mel_lens (Tensor): The mel lengths tensor.
            dur_output (Tensor): The duration output tensor.
            dur_target (Tensor): The duration target tensor.
            pitch_output (Tensor): The pitch output tensor.
            pitch_target (Tensor): The pitch target tensor.
            energy_output (Tensor): The energy output tensor.
            energy_target (Tensor): The energy target tensor.
            src_lens (Tensor): The source lengths tensor.
            p_prosody_ref (Tensor): The p prosody reference tensor.
            p_prosody_pred (Tensor): The p prosody prediction tensor.
            u_prosody_ref (Tensor): The u prosody reference tensor.
            u_prosody_pred (Tensor): The u prosody prediction tensor.
            aligner_logprob (Tensor): The aligner log probabilities tensor.
            aligner_hard (Tensor): The hard aligner tensor.
            aligner_soft (Tensor): The soft aligner tensor.

        Returns:
            losses (Tuple): The already-weighted losses, in this order:
                ``(total_loss, mel_loss, ssim_loss, duration_loss, u_prosody_loss,
                p_prosody_loss, pitch_loss, forward_sum_loss, binary_alignment_loss,
                energy_loss)``.

        Shapes:
            - mel_output: :math:`(B, C_mel, T_mel)`
            - mel_target: :math:`(B, C_mel, T_mel)`
            - mel_lens: :math:`(B)`
            - dur_output: :math:`(B, T_src)`
            - dur_target: :math:`(B, T_src)`
            - pitch_output: :math:`(B, 1, T_src)`
            - pitch_target: :math:`(B, 1, T_src)`
            - energy_output: :math:`(B, 1, T_src)`
            - energy_target: :math:`(B, 1, T_src)`
            - src_lens: :math:`(B)`
            - p_prosody_ref: :math:`(B, T_src, 4)`
            - p_prosody_pred: :math:`(B, T_src, 4)`
            - u_prosody_ref: :math:`(B, 1, 256)`
            - u_prosody_pred: :math:`(B, 1, 256)`
            - aligner_logprob: :math:`(B, 1, T_mel, T_src)`
            - aligner_hard: :math:`(B, T_mel, T_src)`
            - aligner_soft: :math:`(B, T_mel, T_src)`
        """
        src_mask = sequence_mask(src_lens).to(mel_output.device)  # (B, T_src)
        mel_mask = sequence_mask(mel_lens).to(mel_output.device)  # (B, T_mel)

        # Targets are constants: no gradient must flow into them.
        dur_target.requires_grad = False
        mel_target.requires_grad = False
        pitch_target.requires_grad = False

        # SSIM is computed on the per-sample min-max normalized, unmasked spectrograms.
        mel_predictions_normalized = sample_wise_min_max(mel_output).float().to(mel_output.device)
        mel_targets_normalized = sample_wise_min_max(mel_target).float().to(mel_target.device)

        # The L1 mel loss only looks at the valid (non-padding) frames.
        masked_mel_predictions = mel_output.masked_select(mel_mask[:, None])
        mel_targets = mel_target.masked_select(mel_mask[:, None])
        mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) * self.mel_loss_alpha

        ssim_loss: torch.Tensor = self.ssim_loss(
            mel_predictions_normalized.unsqueeze(1),
            mel_targets_normalized.unsqueeze(1),
        ) * self.ssim_loss_alpha

        # Read the scalar once; the range check is done on the weighted value.
        ssim_value = ssim_loss.item()
        if ssim_value > 1.0 or ssim_value < 0.0:
            print(
                f"Overflow in ssim loss detected, which was {ssim_value}, setting to 1.0",
            )
            # NOTE(review): the replacement has shape (1,), so total_loss is
            # broadcast to shape (1,) whenever this branch is taken.
            ssim_loss = torch.tensor([1.0], device=mel_output.device)

        # Prosody references are detached so only the predictors receive gradients.
        p_prosody_ref = p_prosody_ref.detach()
        p_prosody_loss = self.mae_loss(
            p_prosody_ref.masked_select(src_mask.unsqueeze(-1)),
            p_prosody_pred.masked_select(src_mask.unsqueeze(-1)),
        ) * self.p_prosody_loss_alpha

        u_prosody_ref = u_prosody_ref.detach()
        u_prosody_loss = self.mae_loss(u_prosody_ref, u_prosody_pred) * self.u_prosody_loss_alpha

        duration_loss = self.mse_loss(dur_output, dur_target) * self.dur_loss_alpha

        pitch_output = pitch_output.masked_select(src_mask[:, None])
        pitch_target = pitch_target.masked_select(src_mask[:, None])
        pitch_loss = self.mse_loss(pitch_output, pitch_target) * self.pitch_loss_alpha

        energy_output = energy_output.masked_select(src_mask[:, None])
        energy_target = energy_target.masked_select(src_mask[:, None])
        energy_loss = self.mse_loss(energy_output, energy_target) * self.energy_loss_alpha

        forward_sum_loss = self.forward_sum_loss(
            aligner_logprob,
            src_lens,
            mel_lens,
        ) * self.aligner_loss_alpha

        binary_alignment_loss = self._binary_alignment_loss(
            aligner_hard,
            aligner_soft,
        ) * self.binary_alignment_loss_alpha

        total_loss = (
            mel_loss
            + ssim_loss
            + duration_loss
            + u_prosody_loss
            + p_prosody_loss
            + pitch_loss
            + forward_sum_loss
            + binary_alignment_loss
            + energy_loss
        )

        return (
            total_loss,
            mel_loss,
            ssim_loss,
            duration_loss,
            u_prosody_loss,
            p_prosody_loss,
            pitch_loss,
            forward_sum_loss,
            binary_alignment_loss,
            energy_loss,
        )