# radtts-uk-bigvgan-demo / autoregressive_flow.py
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# AR_Back_Step and AR_Step based on implementation from
# https://github.com/NVIDIA/flowtron/blob/master/flowtron.py
# Original license text:
###############################################################################
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################
# Original Author and Contact: Rafael Valle
# Modification by Rafael Valle
import torch
from torch import nn
from common import DenseLayer, SplineTransformationLayerAR


class AR_Back_Step(torch.nn.Module):
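    """Wraps AR_Step on the time-reversed sequence, letting the flow condition on future frames."""
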
def __init__(
self,
n_attr_channels,
n_speaker_dim,
n_text_dim,
n_hidden,
n_lstm_layers,
scaling_fn,
spline_flow_params=None,
):
super(AR_Back_Step, self).__init__()
self.ar_step = AR_Step(
n_attr_channels,
n_speaker_dim,
n_text_dim,
n_hidden,
n_lstm_layers,
scaling_fn,
spline_flow_params,
)

    def forward(self, mel, context, lens):
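        # flip the time axis so the autoregression runs back-to-front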
mel = torch.flip(mel, (0,))
context = torch.flip(context, (0,))
# backwards flow, send padded zeros back to end
for k in range(mel.size(1)):
mel[:, k] = mel[:, k].roll(lens[k].item(), dims=0)
context[:, k] = context[:, k].roll(lens[k].item(), dims=0)
mel, log_s = self.ar_step(mel, context, lens)
# move padded zeros back to beginning
for k in range(mel.size(1)):
mel[:, k] = mel[:, k].roll(-lens[k].item(), dims=0)
return torch.flip(mel, (0,)), log_s

    def infer(self, residual, context):
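        # invert by flipping time, running the inverse AR step, then flipping back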
residual = self.ar_step.infer(
torch.flip(residual, (0,)), torch.flip(context, (0,))
)
residual = torch.flip(residual, (0,))
return residual


class AR_Step(torch.nn.Module):
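    """One autoregressive flow step: an LSTM-conditioned affine (or spline) transform of the input frames."""
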
def __init__(
self,
n_attr_channels,
n_speaker_dim,
n_text_channels,
n_hidden,
n_lstm_layers,
scaling_fn,
spline_flow_params=None,
):
super(AR_Step, self).__init__()
if spline_flow_params is not None:
self.spline_flow = SplineTransformationLayerAR(**spline_flow_params)
else:
self.n_out_dims = n_attr_channels
self.conv = torch.nn.Conv1d(n_hidden, 2 * n_attr_channels, 1)
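            # zero-initialize the projection so this flow step starts near the
            # identity transform (s ≈ 1, b = 0)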
self.conv.weight.data = 0.0 * self.conv.weight.data
self.conv.bias.data = 0.0 * self.conv.bias.data
self.attr_lstm = torch.nn.LSTM(n_attr_channels, n_hidden)
self.lstm = torch.nn.LSTM(
n_hidden + n_text_channels + n_speaker_dim, n_hidden, n_lstm_layers
)
if spline_flow_params is None:
self.dense_layer = DenseLayer(in_dim=n_hidden, sizes=[n_hidden, n_hidden])
self.scaling_fn = scaling_fn

    def run_padded_sequence(
self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
):
"""Sorts input data by previded ordering (and un-ordering) and runs the
packed data through the recurrent model
Args:
sorted_idx (torch.tensor): 1D sorting index
unsort_idx (torch.tensor): 1D unsorting index (inverse sorted_idx)
lens: lengths of input data (sorted in descending order)
padded_data (torch.tensor): input sequences (padded)
recurrent_model (nn.Module): recurrent model to run data through
Returns:
hidden_vectors (torch.tensor): outputs of the RNN, in the original,
unsorted, ordering
"""
# sort the data by decreasing length using provided index
# we assume batch index is in dim=1
padded_data = padded_data[:, sorted_idx]
padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens.cpu())
hidden_vectors = recurrent_model(padded_data)[0]
hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
# unsort the results at dim=1 and return
hidden_vectors = hidden_vectors[:, unsort_idx]
return hidden_vectors

    def get_scaling_and_logs(self, scale_unconstrained):
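        # map the unconstrained network output to a positive scale s and its log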
if self.scaling_fn == "translate":
s = torch.exp(scale_unconstrained * 0)
log_s = scale_unconstrained * 0
elif self.scaling_fn == "exp":
s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) = x
elif self.scaling_fn == "tanh":
s = torch.tanh(scale_unconstrained) + 1 + 1e-6
log_s = torch.log(s)
elif self.scaling_fn == "sigmoid":
s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
log_s = torch.log(s)
else:
            raise ValueError("Scaling fn {} is not supported".format(self.scaling_fn))
return s, log_s

    def forward(self, mel, context, lens):
        dummy = torch.zeros(
            1, mel.size(1), mel.size(2), dtype=mel.dtype, device=mel.device
        )
# seq_len x batch x dim
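        # shift right by one frame so frame t is predicted from frames < t (teacher forcing)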
mel0 = torch.cat([dummy, mel[:-1]], 0)
self.lstm.flatten_parameters()
self.attr_lstm.flatten_parameters()
if lens is not None:
# collect decreasing length indices
lens, ids = torch.sort(lens, descending=True)
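            # build the inverse permutation so outputs can be restored to input order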
original_ids = [0] * lens.size(0)
for i, ids_i in enumerate(ids):
original_ids[ids_i] = i
# mel_seq_len x batch x hidden_dim
mel_hidden = self.run_padded_sequence(
ids, original_ids, lens, mel0, self.attr_lstm
)
else:
mel_hidden = self.attr_lstm(mel0)[0]
decoder_input = torch.cat((mel_hidden, context), -1)
if lens is not None:
# reorder, run padded sequence and undo reordering
lstm_hidden = self.run_padded_sequence(
ids, original_ids, lens, decoder_input, self.lstm
)
else:
lstm_hidden = self.lstm(decoder_input)[0]
if hasattr(self, "spline_flow"):
# spline flow fn expects inputs to be batch, channel, time
lstm_hidden = lstm_hidden.permute(1, 2, 0)
mel = mel.permute(1, 2, 0)
mel, log_s = self.spline_flow(mel, lstm_hidden, inverse=False)
mel = mel.permute(2, 0, 1)
log_s = log_s.permute(2, 0, 1)
else:
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
scale, log_s = self.get_scaling_and_logs(
decoder_output[:, :, : self.n_out_dims]
)
bias = decoder_output[:, :, self.n_out_dims :]
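            # affine flow: z = s * x + b; log_s feeds the flow's log-determinant term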
mel = scale * mel + bias
return mel, log_s

    def infer(self, residual, context):
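        # inversion is sequential: each recovered frame feeds the next LSTM step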
total_output = [] # seems 10FPS faster than pre-allocation
output = None
        # avoid hard-coding CUDA so inference also runs on CPU
        dummy = torch.zeros(
            1, residual.size(1), residual.size(2),
            dtype=residual.dtype, device=residual.device,
        )
self.attr_lstm.flatten_parameters()
        for i in range(residual.size(0)):
if i == 0:
output = dummy
mel_hidden, (h, c) = self.attr_lstm(output)
else:
mel_hidden, (h, c) = self.attr_lstm(output, (h, c))
decoder_input = torch.cat((mel_hidden, context[i][None]), -1)
if i == 0:
lstm_hidden, (h1, c1) = self.lstm(decoder_input)
else:
lstm_hidden, (h1, c1) = self.lstm(decoder_input, (h1, c1))
if hasattr(self, "spline_flow"):
# expects inputs to be batch, channel, time
lstm_hidden = lstm_hidden.permute(1, 2, 0)
output = residual[i : i + 1].permute(1, 2, 0)
output = self.spline_flow(output, lstm_hidden, inverse=True)
output = output.permute(2, 0, 1)
else:
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
s, log_s = self.get_scaling_and_logs(
decoder_output[:, :, : decoder_output.size(2) // 2]
)
b = decoder_output[:, :, decoder_output.size(2) // 2 :]
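                # invert the training-time affine transform: x = (z - b) / s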
output = (residual[i : i + 1] - b) / s
total_output.append(output)
total_output = torch.cat(total_output, 0)
return total_output
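

# Minimal smoke-test sketch (not part of the original module). The sizes below
# are illustrative assumptions, not values from any RADTTS configuration; the
# real context tensor would come from the text encoder and speaker embedding.
if __name__ == "__main__":
    seq_len, batch = 8, 2
    n_attr, n_speaker, n_text, n_hidden = 80, 16, 32, 64
    step = AR_Step(
        n_attr_channels=n_attr,
        n_speaker_dim=n_speaker,
        n_text_channels=n_text,
        n_hidden=n_hidden,
        n_lstm_layers=1,
        scaling_fn="exp",
    )
    mel = torch.randn(seq_len, batch, n_attr)  # seq_len x batch x dim
    context = torch.randn(seq_len, batch, n_text + n_speaker)
    lens = torch.tensor([seq_len, seq_len - 2])
    z, log_s = step(mel, context, lens)
    # the conv is zero-initialized, so at init z == mel and log_s == 0
    print(z.shape, log_s.shape)  # both: (seq_len, batch, n_attr)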