# Serhiy Stetskovych — "New multispeaker model" (commit 3d2700d)
import datetime as dt
import math
import random
import time

import torch
import torch.nn.functional as F

from pflow import utils
from pflow.models.baselightningmodule import BaseLightningClass
from pflow.models.components.flow_matching import CFM
from pflow.models.components.speech_prompt_encoder import TextEncoder
from pflow.utils.model import (
    denormalize,
    duration_loss,
    fix_len_compatibility,
    generate_path,
    sequence_mask,
)
from pflow.models.components import commons
from pflow.models.components.aligner import Aligner, ForwardSumLoss, BinLoss
# Module-level logger obtained from the project helper ``utils.get_pylogger``
# (not referenced by the code in this class).
log = utils.get_pylogger(__name__)
class pflowTTS(BaseLightningClass):
    """Speech-prompt-conditioned text-to-speech model.

    A ``TextEncoder`` encodes a token sequence conditioned on a speech prompt
    and predicts per-token log-durations. A flow-matching decoder (``CFM``)
    turns the duration-expanded encoder output into acoustic frames. During
    training, the text/frame alignment comes from monotonic alignment search
    (``maximum_path``) under a unit-variance Gaussian prior.
    """

    def __init__(
        self,
        n_vocab,
        n_feats,
        encoder,
        decoder,
        cfm,
        data_statistics,
        prompt_size=264,
        dur_p_use_log=False,
        optimizer=None,
        scheduler=None,
        **kwargs,
    ):
        """Build the encoder, decoder and prompt projection.

        Args:
            n_vocab: vocabulary size, forwarded to ``TextEncoder``.
            n_feats: number of acoustic feature channels of the targets and
                prompts (also the ``proj_prompt`` output width and the
                normaliser of the prior loss).
            encoder: config exposing ``encoder_type``, ``encoder_params``
                (with ``n_feats`` and ``n_channels``) and
                ``duration_predictor_params``.
            decoder: decoder params, forwarded to ``CFM`` as ``decoder_params``.
            cfm: flow-matching params, forwarded to ``CFM`` as ``cfm_params``.
            data_statistics: forwarded to ``self.update_data_statistics``
                (defined outside this class).
            prompt_size: number of frames sliced from the target as the
                prompt in ``forward`` when no explicit prompt is given.
            dur_p_use_log: forwarded to ``duration_loss`` as ``use_log``.
            optimizer, scheduler, **kwargs: not used directly here; they are
                captured by ``save_hyperparameters``.
        """
        super().__init__()

        # Lightning-style capture of all constructor arguments as
        # hyperparameters; logger=False keeps them out of the experiment logger.
        self.save_hyperparameters(logger=False)

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.prompt_size = prompt_size
        self.dur_p_use_log = dur_p_use_log
        speech_in_channels = n_feats

        self.encoder = TextEncoder(
            encoder.encoder_type,
            encoder.encoder_params,
            encoder.duration_predictor_params,
            n_vocab,
            speech_in_channels,
        )

        # Disabled alternative to MAS: a learned Aligner trained with
        # ForwardSumLoss / BinLoss.
        # self.aligner = Aligner(
        #     dim_in=encoder.encoder_params.n_feats,
        #     dim_hidden=encoder.encoder_params.n_feats,
        #     attn_channels=encoder.encoder_params.n_feats,
        # )
        # self.aligner_loss = ForwardSumLoss()
        # self.bin_loss = BinLoss()
        # self.aligner_bin_loss_weight = 0.0

        self.decoder = CFM(
            in_channels=encoder.encoder_params.n_feats,
            out_channel=encoder.encoder_params.n_feats,
            cfm_params=cfm,
            decoder_params=decoder,
        )
        # NOTE(review): not referenced by the methods of this class; kept so
        # existing checkpoints keep a matching state dict.
        self.proj_prompt = torch.nn.Conv1d(encoder.encoder_params.n_channels, self.n_feats, 1)

        self.update_data_statistics(data_statistics)

    @torch.inference_mode()
    def synthesise(
        self,
        x,
        x_lengths,
        prompt,
        n_timesteps,
        temperature=1.0,
        length_scale=1.0,
        guidance_scale=0.0,
        *,
        sample_rate=22050,
        hop_length=256,
    ):
        """Generate acoustic features for a batch of token sequences.

        Args:
            x: token ids (presumably ``(batch, max_text_len)`` — TODO confirm).
            x_lengths: valid length of each sequence in ``x``.
            prompt: speech prompt conditioning the encoder; required.
            n_timesteps: forwarded to ``self.decoder`` (number of sampling steps).
            temperature: forwarded to ``self.decoder``.
            length_scale: multiplier on the rounded-up durations; values > 1
                lengthen the output.
            guidance_scale: forwarded to ``self.decoder`` as ``guidance_scale``.
            sample_rate: audio sample rate, used only for ``rtf``.
            hop_length: audio samples per frame, used only for ``rtf``.

        Returns:
            dict with:
              - ``encoder_outputs`` / ``decoder_outputs``: normalized
                features, trimmed to the longest predicted length.
              - ``attn``: hard alignment ``(batch, 1, text_len, frames)``,
                trimmed the same way.
              - ``mel``: denormalized ``decoder_outputs``.
              - ``mel_lengths``: predicted frame count per item.
              - ``rtf``: elapsed time divided by generated audio duration.

        Raises:
            AssertionError: if ``prompt`` is None.
        """
        if prompt is None:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is unchanged for callers.
            raise AssertionError("Prompt must be provided for synthesis")

        # Monotonic clock for the RTF measurement (wall-clock datetime can
        # jump with NTP/DST adjustments).
        start = time.perf_counter()

        # Encoder outputs `mu_x` and log-scaled token durations `logw`.
        mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt)
        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()
        # Padded frame count; presumably rounded up for the decoder's
        # down/up-sampling — see fix_len_compatibility.
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Build the hard alignment map `attn` from the durations.
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Expand token encodings along the frame axis: mu_y = attn^T @ mu_x.
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Sample from the flow-matching decoder.
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, guidance_scale=guidance_scale)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        elapsed = time.perf_counter() - start
        rtf = elapsed * sample_rate / (decoder_outputs.shape[-1] * hop_length)

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            # The frame axis is the last dim of attn; trim it like the
            # other outputs (dim 2 is the text axis).
            "attn": attn[..., :y_max_length],
            "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
            "mel_lengths": y_lengths,
            "rtf": rtf,
        }

    def forward(self, x, x_lengths, y, y_lengths, prompt=None, cond=None, **kwargs):
        """Compute the training losses for one batch.

        Args:
            x: token ids.
            x_lengths: valid length of each sequence in ``x``.
            y: target features ``(batch, feats, frames)``, with the frame
                axis last.
            y_lengths: valid number of frames per item in ``y``.
            prompt: optional speech prompt. When None, a random
                ``prompt_size``-frame segment of ``y`` is used as the prompt
                and those frames are excluded from the prior and decoder
                losses.
            cond: forwarded to ``self.decoder.compute_loss``.
            **kwargs: ignored.

        Returns:
            ``(dur_loss, prior_loss, diff_loss, attn)``, where ``attn`` is
            the hard MAS alignment shaped ``(batch, text_len, frames)``.
        """
        if prompt is None:
            prompt_slice, ids_slice = commons.rand_slice_segments(
                y, y_lengths, self.prompt_size
            )
        else:
            prompt_slice = prompt
        mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt_slice)

        y_max_length = y.shape[-1]
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)

        with torch.no_grad():
            # Negative cross-entropy between every frame y[:, :, t] and every
            # token mean mu_x[:, :, s] under a unit-variance Gaussian, split
            # into four terms; neg_cent is (batch, frames, text_len).
            # A learned log-std `logx` would give s_p_sq_r = exp(-2 * logx)
            # and subtract `logx` inside neg_cent1; here it is fixed to zero.
            s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            # Mask laid out as (batch, 1, frames, text_len) to match neg_cent.
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            from pflow.utils.monotonic_align import maximum_path

            attn = maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Target log-durations: number of frames MAS assigned to each token.
        # Kept outside no_grad so dur_loss back-propagates into `logw`.
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths, use_log=self.dur_p_use_log)

        # Disabled alternative to the MAS alignment above (learned Aligner):
        # aln_hard, aln_soft, aln_log, aln_mask = self.aligner(
        #     mu_x.transpose(1,2), x_mask, y, y_mask
        # )
        # attn = aln_mask.transpose(1,2).unsqueeze(1)
        # align_loss = self.aligner_loss(aln_log, x_lengths, y_lengths)
        # if self.aligner_bin_loss_weight > 0.:
        #     align_bin_loss = self.bin_loss(aln_mask, aln_log, x_lengths) * self.aligner_bin_loss_weight
        #     align_loss = align_loss + align_bin_loss
        # dur_loss = F.l1_loss(logw, attn.sum(2))
        # dur_loss = dur_loss + align_loss

        # (batch, 1, frames, text_len) -> (batch, text_len, frames). Do not
        # squeeze(1) again afterwards: with text_len == 1 that would drop the
        # text axis and break the matmul below.
        attn = attn.squeeze(1).transpose(1, 2)
        # Expand token encodings to frames: mu_y = attn^T @ mu_x.
        mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Same frame mask as y_mask, with the sampled prompt frames zeroed
        # (presumably so the model is not scored on its own conditioning).
        y_loss_mask = y_mask.clone()
        if prompt is None:
            for i, start in enumerate(ids_slice):
                y_loss_mask[i, :, start:start + self.prompt_size] = False

        # Decoder (flow-matching) loss.
        diff_loss, _ = self.decoder.compute_loss(x1=y.detach(), mask=y_mask, mu=mu_y, cond=cond, loss_mask=y_loss_mask)

        # Gaussian NLL of y under N(mu_y, I), averaged over unmasked elements.
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_loss_mask)
        prior_loss = prior_loss / (torch.sum(y_loss_mask) * self.n_feats)

        return dur_loss, prior_loss, diff_loss, attn