PeechTTSv22050 / training /loss /delightful_tts_loss.py
nickovchinnikov's picture
Init
9d61c9b
from typing import Optional
from piq import SSIMLoss
import torch
from torch import Tensor, nn
from torch.nn import functional
from models.config import AcousticModelConfigType
from training.loss.utils import sample_wise_min_max
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length: Tensor, max_len: Optional[int] = None) -> Tensor:
"""Create a sequence mask for filtering padding in a sequence tensor.
Args:
sequence_length (torch.tensor): Sequence lengths.
max_len (int, Optional): Maximum sequence length. Defaults to None.
Shapes:
- mask: :math:`[B, T_max]`
"""
max_len_ = max_len if max_len is not None else sequence_length.max().item()
seq_range = torch.arange(max_len_, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
class ForwardSumLoss(nn.Module):
r"""A class used to compute the forward sum loss.
Attributes:
log_softmax (torch.nn.LogSoftmax): The log softmax function applied along dimension 3.
ctc_loss (torch.nn.CTCLoss): The CTC loss function with zero infinity set to True.
blank_logprob (int): The log probability of a blank, default is -1.
Methods:
forward(attn_logprob: Tensor, in_lens: Tensor, out_lens: Tensor)
Compute the forward sum loss.
"""
def __init__(self, blank_logprob: int = -1):
r"""Constructs all the necessary attributes for the ForwardSumLoss object.
Args:
blank_logprob (int, optional): The log probability of a blank (default is -1).
"""
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
def forward(self, attn_logprob: Tensor, in_lens: Tensor, out_lens: Tensor):
r"""Compute the forward sum loss.
Args:
attn_logprob (Tensor): The attention log probabilities.
in_lens (Tensor): The input lengths.
out_lens (Tensor): The output lengths.
Returns:
total_loss (float): The total loss computed.
"""
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid].item() + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
class DelightfulTTSLoss(nn.Module):
r"""A class used to compute the delightful TTS loss.
Attributes:
mse_loss (nn.MSELoss): The mean squared error loss function.
mae_loss (nn.L1Loss): The mean absolute error loss function.
forward_sum_loss (ForwardSumLoss): The forward sum loss function.
mel_loss_alpha (float): The weight for the mel loss.
aligner_loss_alpha (float): The weight for the aligner loss.
pitch_loss_alpha (float): The weight for the pitch loss.
energy_loss_alpha (float): The weight for the energy loss.
u_prosody_loss_alpha (float): The weight for the u prosody loss.
p_prosody_loss_alpha (float): The weight for the p prosody loss.
dur_loss_alpha (float): The weight for the duration loss.
binary_alignment_loss_alpha (float): The weight for the binary alignment loss.
Methods:
_binary_alignment_loss(alignment_hard: Tensor, alignment_soft: Tensor)
Compute the binary alignment loss.
forward(
mel_output: Tensor,
mel_target: Tensor,
mel_lens: Tensor,
dur_output: Tensor,
dur_target: Tensor,
pitch_output: Tensor,
pitch_target: Tensor,
energy_output: Tensor,
energy_target: Tensor,
src_lens: Tensor,
p_prosody_ref: Tensor,
p_prosody_pred: Tensor,
u_prosody_ref: Tensor,
u_prosody_pred: Tensor,
aligner_logprob: Tensor,
aligner_hard: Tensor,
aligner_soft: Tensor,
binary_loss_weight: Optional[Tensor] = None,
)
Compute the delightful TTS loss.
"""
def __init__(self, config: AcousticModelConfigType):
r"""Constructs all the necessary attributes for the DelightfulTTSLoss object.
Args:
config (AcousticModelConfigType): Configuration parameters for the loss function.
"""
super().__init__()
self.mse_loss = nn.MSELoss()
self.mae_loss = nn.L1Loss()
self.forward_sum_loss = ForwardSumLoss()
self.ssim_loss = SSIMLoss()
self.mel_loss_alpha = config.loss.mel_loss_alpha
self.ssim_loss_alpha = config.loss.ssim_loss_alpha
self.aligner_loss_alpha = config.loss.aligner_loss_alpha
self.pitch_loss_alpha = config.loss.pitch_loss_alpha
self.energy_loss_alpha = config.loss.energy_loss_alpha
self.u_prosody_loss_alpha = config.loss.u_prosody_loss_alpha
self.p_prosody_loss_alpha = config.loss.p_prosody_loss_alpha
self.dur_loss_alpha = config.loss.dur_loss_alpha
self.binary_alignment_loss_alpha = config.loss.binary_align_loss_alpha
@staticmethod
def _binary_alignment_loss(alignment_hard: Tensor, alignment_soft: Tensor) -> Tensor:
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
Args:
alignment_hard (Tensor): The hard alignment tensor.
alignment_soft (Tensor): The soft alignment tensor.
Returns:
loss (float): The computed binary alignment loss.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
def forward(
self,
mel_output: Tensor,
mel_target: Tensor,
mel_lens: Tensor,
dur_output: Tensor,
dur_target: Tensor,
pitch_output: Tensor,
pitch_target: Tensor,
energy_output: Tensor,
energy_target: Tensor,
src_lens: Tensor,
p_prosody_ref: Tensor,
p_prosody_pred: Tensor,
u_prosody_ref: Tensor,
u_prosody_pred: Tensor,
aligner_logprob: Tensor,
aligner_hard: Tensor,
aligner_soft: Tensor,
):
r"""Compute the delightful TTS loss.
Args:
mel_output (Tensor): The mel output tensor.
mel_target (Tensor): The mel target tensor.
mel_lens (Tensor): The mel lengths tensor.
dur_output (Tensor): The duration output tensor.
dur_target (Tensor): The duration target tensor.
pitch_output (Tensor): The pitch output tensor.
pitch_target (Tensor): The pitch target tensor.
energy_output (Tensor): The energy output tensor.
energy_target (Tensor): The energy target tensor.
src_lens (Tensor): The source lengths tensor.
p_prosody_ref (Tensor): The p prosody reference tensor.
p_prosody_pred (Tensor): The p prosody prediction tensor.
u_prosody_ref (Tensor): The u prosody reference tensor.
u_prosody_pred (Tensor): The u prosody prediction tensor.
aligner_logprob (Tensor): The aligner log probabilities tensor.
aligner_hard (Tensor): The hard aligner tensor.
aligner_soft (Tensor): The soft aligner tensor.
Returns:
loss_dict (Tupple): A dictionary containing all the loss values.
Shapes:
- mel_output: :math:`(B, C_mel, T_mel)`
- mel_target: :math:`(B, C_mel, T_mel)`
- mel_lens: :math:`(B)`
- dur_output: :math:`(B, T_src)`
- dur_target: :math:`(B, T_src)`
- pitch_output: :math:`(B, 1, T_src)`
- pitch_target: :math:`(B, 1, T_src)`
- energy_output: :math:`(B, 1, T_src)`
- energy_target: :math:`(B, 1, T_src)`
- src_lens: :math:`(B)`
- p_prosody_ref: :math:`(B, T_src, 4)`
- p_prosody_pred: :math:`(B, T_src, 4)`
- u_prosody_ref: :math:`(B, 1, 256)
- u_prosody_pred: :math:`(B, 1, 256)
- aligner_logprob: :math:`(B, 1, T_mel, T_src)`
- aligner_hard: :math:`(B, T_mel, T_src)`
- aligner_soft: :math:`(B, T_mel, T_src)`
"""
src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src)
mel_mask = sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel)
dur_target.requires_grad = False
mel_target.requires_grad = False
pitch_target.requires_grad = False
mel_predictions_normalized = sample_wise_min_max(mel_output).float().to(mel_output.device)
mel_targets_normalized = sample_wise_min_max(mel_target).float().to(mel_target.device)
masked_mel_predictions = mel_output.masked_select(mel_mask[:, None])
mel_targets = mel_target.masked_select(mel_mask[:, None])
mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) * self.mel_loss_alpha
ssim_loss: torch.Tensor = self.ssim_loss(
mel_predictions_normalized.unsqueeze(1), mel_targets_normalized.unsqueeze(1),
) * self.ssim_loss_alpha
if ssim_loss.item() > 1.0 or ssim_loss.item() < 0.0:
print(
f"Overflow in ssim loss detected, which was {ssim_loss.item()}, setting to 1.0",
)
ssim_loss = torch.tensor([1.0], device=mel_output.device)
p_prosody_ref = p_prosody_ref.detach()
p_prosody_loss = self.mae_loss(
p_prosody_ref.masked_select(src_mask.unsqueeze(-1)),
p_prosody_pred.masked_select(src_mask.unsqueeze(-1)),
) * self.p_prosody_loss_alpha
u_prosody_ref = u_prosody_ref.detach()
u_prosody_loss = self.mae_loss(u_prosody_ref, u_prosody_pred) * self.u_prosody_loss_alpha
duration_loss = self.mse_loss(dur_output, dur_target) * self.dur_loss_alpha
pitch_output = pitch_output.masked_select(src_mask[:, None])
pitch_target = pitch_target.masked_select(src_mask[:, None])
pitch_loss = self.mse_loss(pitch_output, pitch_target) * self.pitch_loss_alpha
energy_output = energy_output.masked_select(src_mask[:, None])
energy_target = energy_target.masked_select(src_mask[:, None])
energy_loss = self.mse_loss(energy_output, energy_target) * self.energy_loss_alpha
forward_sum_loss = self.forward_sum_loss(
aligner_logprob,
src_lens,
mel_lens,
) * self.aligner_loss_alpha
binary_alignment_loss = self._binary_alignment_loss(
aligner_hard,
aligner_soft,
) * self.binary_alignment_loss_alpha
total_loss = (
mel_loss
+ ssim_loss
+ duration_loss
+ u_prosody_loss
+ p_prosody_loss
+ pitch_loss
+ forward_sum_loss
+ binary_alignment_loss
+ energy_loss
)
return (
total_loss,
mel_loss,
ssim_loss,
duration_loss,
u_prosody_loss,
p_prosody_loss,
pitch_loss,
forward_sum_loss,
binary_alignment_loss,
energy_loss,
)