|
import math |
|
import random |
|
|
|
import torch |
|
|
|
import monotonic_align |
|
from model.base import BaseModule |
|
from model.text_encoder import TextEncoder |
|
from model.diffusion import Diffusion |
|
from model.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility |
|
|
|
|
|
class GradTTSWithEmo(BaseModule): |
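    """Grad-TTS acoustic model conditioned on speaker and emotion ids.

    Illustrative usage sketch (tensor names, shapes and the speaker/emotion
    counts below are assumptions for the example, not requirements):

        model = GradTTSWithEmo(n_vocab=148, n_spks=10, n_emos=5)
        # x: (B, T_text) phoneme ids; x_lengths, spk, emo: (B,) int tensors
        enc_out, dec_out, attn = model(x, x_lengths, n_timesteps=50, spk=spk, emo=emo)
    """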
|
    def __init__(self, n_vocab=148, n_spks=1, n_emos=5, spk_emb_dim=64,
|
n_enc_channels=192, filter_channels=768, filter_channels_dp=256, |
|
n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4, |
|
n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000, |
|
use_classifier_free=False, dummy_spk_rate=0.5, |
|
**kwargs): |
|
super(GradTTSWithEmo, self).__init__() |
|
self.n_vocab = n_vocab |
|
self.n_spks = n_spks |
|
self.n_emos = n_emos |
|
self.spk_emb_dim = spk_emb_dim |
|
self.n_enc_channels = n_enc_channels |
|
self.filter_channels = filter_channels |
|
self.filter_channels_dp = filter_channels_dp |
|
self.n_heads = n_heads |
|
self.n_enc_layers = n_enc_layers |
|
self.enc_kernel = enc_kernel |
|
self.enc_dropout = enc_dropout |
|
self.window_size = window_size |
|
self.n_feats = n_feats |
|
self.dec_dim = dec_dim |
|
self.beta_min = beta_min |
|
self.beta_max = beta_max |
|
self.pe_scale = pe_scale |
|
self.use_classifier_free = use_classifier_free |
|
|
|
|
|
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) |
|
self.emo_emb = torch.nn.Embedding(n_emos, spk_emb_dim) |
|
self.merge_spk_emo = torch.nn.Sequential( |
|
torch.nn.Linear(spk_emb_dim*2, spk_emb_dim), |
|
torch.nn.ReLU(), |
|
torch.nn.Linear(spk_emb_dim, spk_emb_dim) |
|
) |
|
self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels, |
|
filter_channels, filter_channels_dp, n_heads, |
|
n_enc_layers, enc_kernel, enc_dropout, window_size, |
|
spk_emb_dim=spk_emb_dim, n_spks=n_spks) |
|
self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale) |
|
|
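        # Classifier-free guidance: keep a learnable "dummy" emotion vector that
        # stands in for the unconditional case and is mixed in during training.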
|
if self.use_classifier_free: |
|
self.dummy_xv = torch.nn.Parameter(torch.randn(size=(spk_emb_dim, ))) |
|
self.dummy_rate = dummy_spk_rate |
|
            print(f"Using classifier-free guidance with dummy embedding rate {self.dummy_rate}")
|
|
|
@torch.no_grad() |
|
def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None, |
|
length_scale=1.0, classifier_free_guidance=1., force_dur=None): |
|
""" |
|
Generates mel-spectrogram from text. Returns: |
|
1. encoder outputs |
|
2. decoder outputs |
|
3. generated alignment |
|
|
|
Args: |
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. |
|
x_lengths (torch.Tensor): lengths of texts in batch. |
|
n_timesteps (int): number of steps to use for reverse diffusion in decoder. |
|
temperature (float, optional): controls variance of terminal distribution. |
|
stoc (bool, optional): flag that adds stochastic term to the decoder sampler. |
|
Usually, does not provide synthesis improvements. |
|
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.
            spk (torch.Tensor): speaker ids for the batch.
            emo (torch.Tensor): emotion ids for the batch.
            classifier_free_guidance (float, optional): guidance scale used when
                classifier-free guidance is enabled.
            force_dur (torch.Tensor, optional): externally provided per-phoneme durations
                used instead of the predicted ones.
        """
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
|
|
spk = self.spk_emb(spk) |
|
emo = self.emo_emb(emo) |
|
|
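        # Match training: unit-normalize the emotion embedding when classifier-free guidance is used.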
|
if self.use_classifier_free: |
|
emo = emo / torch.sqrt(torch.sum(emo**2, dim=1, keepdim=True)) |
|
|
|
spk_merged = self.merge_spk_emo(torch.cat([spk, emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged) |
|
|
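        # Turn predicted log-durations into per-phoneme frame counts and the total output length.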
|
w = torch.exp(logw) * x_mask |
|
w_ceil = torch.ceil(w) * length_scale |
|
if force_dur is not None: |
|
w_ceil = force_dur.unsqueeze(1) |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
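        # Build the frame-level mask and a hard monotonic alignment from the durations.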
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
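        # Expand phoneme-level means to frame level through the alignment to get the prior mu_y.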
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
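        # Sample the terminal latent z ~ N(mu_y, I / temperature^2); higher temperature means less noise.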
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
|
|
|
|
        if self.use_classifier_free:
            # Build the unconditional ("dummy") condition for classifier-free guidance
            # from the unit-normalized dummy emotion embedding.
            unit_dummy_emo = self.dummy_xv / torch.sqrt(torch.sum(self.dummy_xv**2))
            dummy_spk = self.merge_spk_emo(
                torch.cat([spk, unit_dummy_emo.unsqueeze(0).repeat(len(spk), 1)], dim=-1))
        else:
            dummy_spk = None
|
|
|
decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, |
|
use_classifier_free=self.use_classifier_free, |
|
classifier_free_guidance=classifier_free_guidance, |
|
dummy_spk=dummy_spk) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
|
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def classifier_guidance_decode(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None, |
|
length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'): |
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
|
|
spk = self.spk_emb(spk) |
|
dummy_emo = self.emo_emb(torch.zeros_like(emo).long()) |
|
|
|
spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged) |
|
|
|
w = torch.exp(logw) * x_mask |
|
|
|
w_ceil = torch.ceil(w) * length_scale |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
        if classifier_type in ('CNN', 'CNN-with-time'):
|
y_max_length = max(y_max_length, 180) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
|
|
decoder_outputs = self.decoder.classifier_decode(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, |
|
classifier_func, guidance, |
|
control_emo=emo, classifier_type=classifier_type) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def classifier_guidance_decode_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo=None, |
|
length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'): |
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
|
|
spk = self.spk_emb(spk) |
|
dummy_emo = self.emo_emb(torch.zeros_like(emo).long()) |
|
|
|
spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged) |
|
|
|
w = torch.exp(logw) * x_mask |
|
w_ceil = torch.ceil(w) * length_scale |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
        if classifier_type in ('CNN', 'CNN-with-time'):
|
y_max_length = max(y_max_length, 180) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
|
|
decoder_outputs = self.decoder.classifier_decode_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, |
|
classifier_func, guidance, |
|
control_emo=emo, classifier_type=classifier_type) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def classifier_guidance_decode_two_mixture(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None, |
|
length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'): |
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
|
|
spk = self.spk_emb(spk) |
|
dummy_emo = self.emo_emb(torch.zeros_like(emo1).long()) |
|
|
|
spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged) |
|
|
|
w = torch.exp(logw) * x_mask |
|
w_ceil = torch.ceil(w) * length_scale |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
        if classifier_type in ('CNN', 'CNN-with-time'):
|
y_max_length = max(y_max_length, 180) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
|
|
decoder_outputs = self.decoder.classifier_decode_mixture(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, |
|
classifier_func, guidance, |
|
control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def classifier_guidance_decode_two_mixture_DPS(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, emo1=None, emo2=None, emo1_weight=None, |
|
length_scale=1.0, classifier_func=None, guidance=1.0, classifier_type='conformer'): |
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
|
|
spk = self.spk_emb(spk) |
|
dummy_emo = self.emo_emb(torch.zeros_like(emo1).long()) |
|
|
|
spk_merged = self.merge_spk_emo(torch.cat([spk, dummy_emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk_merged) |
|
|
|
w = torch.exp(logw) * x_mask |
|
w_ceil = torch.ceil(w) * length_scale |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
        if classifier_type in ('CNN', 'CNN-with-time'):
|
y_max_length = max(y_max_length, 180) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
|
|
decoder_outputs = self.decoder.classifier_decode_mixture_DPS(z, y_mask, mu_y, n_timesteps, stoc, spk_merged, |
|
classifier_func, guidance, |
|
control_emo1=emo1, control_emo2=emo2, emo1_weight=emo1_weight, classifier_type=classifier_type) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, emo=None, out_size=None, use_gt_dur=False, durs=None): |
|
""" |
|
Computes 3 losses: |
|
        1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
|
2. prior loss: loss between mel-spectrogram and encoder outputs. |
|
3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. |
|
|
|
Args: |
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. |
|
x_lengths (torch.Tensor): lengths of texts in batch. |
|
y (torch.Tensor): batch of corresponding mel-spectrograms. |
|
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. |
|
            out_size (int, optional): length (in mel frames) of the segment to cut, on which the decoder will be trained.
                Should be divisible by 2^{num of UNet downsamplings}. Cutting segments keeps memory bounded so larger batches fit.
|
            spk (torch.Tensor): speaker ids for the batch.
            emo (torch.Tensor): emotion ids for the batch.
            use_gt_dur (bool, optional): if True, build the alignment from ground-truth durations instead of MAS.
            durs (torch.Tensor, optional): ground-truth durations, required when use_gt_dur is True.
|
""" |
|
x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths]) |
|
|
|
spk = self.spk_emb(spk) |
|
emo = self.emo_emb(emo) |
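        # Classifier-free guidance training: unit-normalize the emotion embedding and,
        # with probability dummy_rate, replace it with the normalized dummy embedding.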
|
if self.use_classifier_free: |
|
emo = emo / torch.sqrt(torch.sum(emo ** 2, dim=1, keepdim=True)) |
|
use_dummy_per_sample = torch.distributions.Binomial(1, torch.tensor( |
|
[self.dummy_rate] * len(emo))).sample().bool() |
|
emo[use_dummy_per_sample] = (self.dummy_xv / torch.sqrt( |
|
torch.sum(self.dummy_xv ** 2))) |
|
|
|
spk = self.merge_spk_emo(torch.cat([spk, emo], dim=-1)) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) |
|
y_max_length = y.shape[-1] |
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
|
|
|
|
if use_gt_dur: |
|
attn = generate_path(durs, attn_mask.squeeze(1)).detach() |
|
else: |
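            # MAS: score each (phoneme, frame) pair by the log-likelihood of the frame under
            # the prior N(y; mu_x, I), i.e. -0.5*(y^2 - 2*mu_x*y + mu_x^2) summed over mel bins
            # plus the Gaussian normalization constant, then search for the best monotonic path.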
|
with torch.no_grad(): |
|
const = -0.5 * math.log(2 * math.pi) * self.n_feats |
|
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) |
|
y_square = torch.matmul(factor.transpose(1, 2), y ** 2) |
|
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) |
|
mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1) |
|
log_prior = y_square - y_mu_double + mu_square + const |
|
|
|
|
|
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) |
|
attn = attn.detach() |
|
|
|
|
|
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask |
|
dur_loss = duration_loss(logw, logw_, x_lengths) |
|
|
|
|
|
|
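        # Randomly cut an out_size-frame segment from each target mel (and its alignment)
        # so decoder training memory stays bounded regardless of utterance length.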
|
        if out_size is not None:
|
clip_size = min(out_size, y_max_length) |
|
clip_size = -fix_len_compatibility(-clip_size) |
|
max_offset = (y_lengths - clip_size).clamp(0) |
|
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) |
|
out_offset = torch.LongTensor([ |
|
torch.tensor(random.choice(range(start, end)) if end > start else 0) |
|
for start, end in offset_ranges |
|
]).to(y_lengths) |
|
|
|
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], clip_size, dtype=attn.dtype, device=attn.device) |
|
y_cut = torch.zeros(y.shape[0], self.n_feats, clip_size, dtype=y.dtype, device=y.device) |
|
y_cut_lengths = [] |
|
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): |
|
y_cut_length = clip_size + (y_lengths[i] - clip_size).clamp(None, 0) |
|
y_cut_lengths.append(y_cut_length) |
|
cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length |
|
y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] |
|
attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] |
|
y_cut_lengths = torch.LongTensor(y_cut_lengths) |
|
y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) |
|
|
|
attn = attn_cut |
|
y = y_cut |
|
y_mask = y_cut_mask |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
|
|
|
|
|
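        # Diffusion (score-matching) loss from the decoder, conditioned on the merged speaker/emotion embedding.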
|
diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk) |
|
|
|
|
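        # Prior loss: negative log-likelihood of the target mel under N(mu_y, I), averaged over valid frames and bins.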
|
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) |
|
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) |
|
|
|
return dur_loss, prior_loss, diff_loss |
|
|
|
|
|
class GradTTSXvector(BaseModule): |
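    """Grad-TTS acoustic model conditioned on pretrained speaker x-vectors instead of speaker/emotion ids."""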
|
def __init__(self, n_vocab=148, spk_emb_dim=64, |
|
n_enc_channels=192, filter_channels=768, filter_channels_dp=256, |
|
n_heads=2, n_enc_layers=6, enc_kernel=3, enc_dropout=0.1, window_size=4, |
|
n_feats=80, dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000, xvector_dim=512, **kwargs): |
|
super(GradTTSXvector, self).__init__() |
|
self.n_vocab = n_vocab |
|
|
|
self.spk_emb_dim = spk_emb_dim |
|
self.n_enc_channels = n_enc_channels |
|
self.filter_channels = filter_channels |
|
self.filter_channels_dp = filter_channels_dp |
|
self.n_heads = n_heads |
|
self.n_enc_layers = n_enc_layers |
|
self.enc_kernel = enc_kernel |
|
self.enc_dropout = enc_dropout |
|
self.window_size = window_size |
|
self.n_feats = n_feats |
|
self.dec_dim = dec_dim |
|
self.beta_min = beta_min |
|
self.beta_max = beta_max |
|
self.pe_scale = pe_scale |
|
|
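        # Project pretrained x-vectors down to the internal speaker embedding size.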
|
self.xvector_proj = torch.nn.Linear(xvector_dim, spk_emb_dim) |
|
self.encoder = TextEncoder(n_vocab, n_feats, n_enc_channels, |
|
filter_channels, filter_channels_dp, n_heads, |
|
n_enc_layers, enc_kernel, enc_dropout, window_size, |
|
spk_emb_dim=spk_emb_dim, n_spks=999) |
|
self.decoder = Diffusion(n_feats, dec_dim, spk_emb_dim, beta_min, beta_max, pe_scale) |
|
|
|
@torch.no_grad() |
|
def forward(self, x, x_lengths, n_timesteps, temperature=1.0, stoc=False, spk=None, length_scale=1.0): |
|
""" |
|
Generates mel-spectrogram from text. Returns: |
|
1. encoder outputs |
|
2. decoder outputs |
|
3. generated alignment |
|
|
|
Args: |
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. |
|
x_lengths (torch.Tensor): lengths of texts in batch. |
|
n_timesteps (int): number of steps to use for reverse diffusion in decoder. |
|
temperature (float, optional): controls variance of terminal distribution. |
|
stoc (bool, optional): flag that adds stochastic term to the decoder sampler. |
|
Usually, does not provide synthesis improvements. |
|
length_scale (float, optional): controls speech pace. |
|
Increase value to slow down generated speech and vice versa. |
|
            spk (torch.Tensor): speaker x-vectors (not ids) for the batch.
|
""" |
|
x, x_lengths = self.relocate_input([x, x_lengths]) |
|
|
|
spk = self.xvector_proj(spk) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) |
|
|
|
w = torch.exp(logw) * x_mask |
|
w_ceil = torch.ceil(w) * length_scale |
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() |
|
y_max_length = int(y_lengths.max()) |
|
y_max_length_ = fix_len_compatibility(y_max_length) |
|
|
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
encoder_outputs = mu_y[:, :, :y_max_length] |
|
|
|
|
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature |
|
|
|
decoder_outputs = self.decoder(z, y_mask, mu_y, n_timesteps, stoc, spk) |
|
decoder_outputs = decoder_outputs[:, :, :y_max_length] |
|
|
|
return encoder_outputs, decoder_outputs, attn[:, :, :y_max_length] |
|
|
|
def compute_loss(self, x, x_lengths, y, y_lengths, spk=None, out_size=None, use_gt_dur=False, durs=None): |
|
""" |
|
Computes 3 losses: |
|
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS). |
|
2. prior loss: loss between mel-spectrogram and encoder outputs. |
|
3. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. |
|
|
|
Args: |
|
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. |
|
x_lengths (torch.Tensor): lengths of texts in batch. |
|
y (torch.Tensor): batch of corresponding mel-spectrograms. |
|
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. |
|
            out_size (int, optional): length (in mel frames) of the segment to cut, on which the decoder will be trained.
                Should be divisible by 2^{num of UNet downsamplings}. Cutting segments keeps memory bounded so larger batches fit.
|
            spk (torch.Tensor): speaker x-vectors for the batch.
            use_gt_dur (bool, optional): if True, build the alignment from ground-truth durations instead of MAS.
            durs (torch.Tensor, optional): ground-truth durations, required when use_gt_dur is True.
|
""" |
|
x, x_lengths, y, y_lengths = self.relocate_input([x, x_lengths, y, y_lengths]) |
|
|
|
spk = self.xvector_proj(spk) |
|
|
|
|
|
mu_x, logw, x_mask = self.encoder(x, x_lengths, spk) |
|
y_max_length = y.shape[-1] |
|
|
|
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) |
|
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) |
|
|
|
|
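        # Either search for the alignment with MAS against the encoder prior, or build it from ground-truth durations.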
|
if not use_gt_dur: |
|
with torch.no_grad(): |
|
const = -0.5 * math.log(2 * math.pi) * self.n_feats |
|
factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) |
|
y_square = torch.matmul(factor.transpose(1, 2), y ** 2) |
|
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) |
|
mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1) |
|
log_prior = y_square - y_mu_double + mu_square + const |
|
|
|
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) |
|
attn = attn.detach() |
|
else: |
|
with torch.no_grad(): |
|
attn = generate_path(durs, attn_mask.squeeze(1)).detach() |
|
|
|
|
|
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask |
|
dur_loss = duration_loss(logw, logw_, x_lengths) |
|
|
|
|
|
|
|
|
|
        if out_size is not None:
|
max_offset = (y_lengths - out_size).clamp(0) |
|
offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) |
|
out_offset = torch.LongTensor([ |
|
torch.tensor(random.choice(range(start, end)) if end > start else 0) |
|
for start, end in offset_ranges |
|
]).to(y_lengths) |
|
|
|
attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) |
|
y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) |
|
y_cut_lengths = [] |
|
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): |
|
y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) |
|
y_cut_lengths.append(y_cut_length) |
|
cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length |
|
y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] |
|
attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] |
|
y_cut_lengths = torch.LongTensor(y_cut_lengths) |
|
y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) |
|
|
|
attn = attn_cut |
|
y = y_cut |
|
y_mask = y_cut_mask |
|
|
|
|
|
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) |
|
mu_y = mu_y.transpose(1, 2) |
|
|
|
|
|
diff_loss, xt = self.decoder.compute_loss(y, y_mask, mu_y, spk) |
|
|
|
|
|
prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) |
|
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) |
|
|
|
return dur_loss, prior_loss, diff_loss |
|
|