import math

import numpy as np
import torch
from torch import nn

from .vits_config import VitsConfig

#.............................................
def _rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    reverse,
    tail_bound,
    min_bin_width,
    min_bin_height,
    min_derivative,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike
    `_unconstrained_rational_quadratic_spline`, this function requires all inputs to lie within
    `[-tail_bound, tail_bound]` and applies the spline transform everywhere.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        reverse (`bool`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`):
            Upper and lower limit bound for the rational quadratic function; inputs must lie within
            `[-tail_bound, tail_bound]`.
        min_bin_width (`float`):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.

    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs`.
    """
    upper_bound = tail_bound
    lower_bound = -tail_bound

    if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
    if min_bin_height * num_bins > 1.0:
        raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")

    widths = nn.functional.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
    cumwidths[..., 0] = lower_bound
    cumwidths[..., -1] = upper_bound
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)

    heights = nn.functional.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
    cumheights[..., 0] = lower_bound
    cumheights[..., -1] = upper_bound
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    bin_locations = cumheights if reverse else cumwidths
    bin_locations[..., -1] += 1e-6
    bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
    bin_idx = bin_idx[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
    input_heights = heights.gather(-1, bin_idx)[..., 0]

    intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
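
    # Forward evaluation of the monotone rational quadratic spline (Durkan et al., "Neural Spline Flows").
    # With theta = (x - x_k) / w_k, bin slope delta = h_k / w_k and knot derivatives d_k, d_{k+1}, the bin-local map is
    #     y = y_k + h_k * (delta * theta^2 + d_k * theta * (1 - theta))
    #             / (delta + (d_k + d_{k+1} - 2 * delta) * theta * (1 - theta))
    # which is what the `not reverse` branch below computes, together with the log of its derivative.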
    if not reverse:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, log_abs_det
    else:
        # find the roots of a quadratic equation
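        # Rearranging the forward map for theta, with intermediate2 = y - y_k and intermediate1 = d_k + d_{k+1} - 2 * delta,
        # gives a quadratic a * theta^2 + b * theta + c = 0 with the coefficients below. The root is evaluated in the
        # numerically stable form 2c / (-b - sqrt(b^2 - 4ac)), which equals (-b + sqrt(b^2 - 4ac)) / (2a).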
        intermediate2 = inputs - input_cumheights
        intermediate3 = intermediate2 * intermediate1
        a = input_heights * (input_delta - input_derivatives) + intermediate3
        b = input_heights * input_derivatives - intermediate3
        c = -input_delta * intermediate2

        discriminant = b.pow(2) - 4 * a * c
        if not (discriminant >= 0).all():
            raise RuntimeError(f"invalid discriminant {discriminant}")

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, -log_abs_det
#.............................................
def _unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    reverse=False,
    tail_bound=5.0,
    min_bin_width=1e-3,
    min_bin_height=1e-3,
    min_derivative=1e-3,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
    `tail_bound`, the transform behaves as an identity function.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module.
        reverse (`bool`, *optional*, defaults to `False`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`, *optional*, defaults to 5.0):
            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
            transform behaves as an identity function.
        min_bin_width (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.

    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
            applied.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
            limits applied.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    log_abs_det = torch.zeros_like(inputs)
    constant = np.log(np.exp(1 - min_derivative) - 1)

    unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    log_abs_det[outside_interval_mask] = 0.0

    outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        reverse=reverse,
        tail_bound=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    return outputs, log_abs_det
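
# Illustrative sketch, not part of the original VITS code: a small demo of how the spline helper is called and how
# the forward and reverse directions compose. Sizes are arbitrary demo values, not VitsConfig defaults.
def _demo_unconstrained_spline():
    batch_size, channels, seq_len, num_bins = 2, 4, 6, 10
    inputs = torch.randn(batch_size, channels, seq_len)
    unnormalized_widths = torch.randn(batch_size, channels, seq_len, num_bins)
    unnormalized_heights = torch.randn(batch_size, channels, seq_len, num_bins)
    # only `num_bins - 1` raw derivatives: the function pads one constant derivative on each side
    unnormalized_derivatives = torch.randn(batch_size, channels, seq_len, num_bins - 1)

    forward, log_abs_det = _unconstrained_rational_quadratic_spline(
        inputs, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=False
    )
    recovered, _ = _unconstrained_rational_quadratic_spline(
        forward, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, reverse=True
    )
    # the inverse transform should recover the inputs up to numerical error
    print(forward.shape, log_abs_det.shape, (recovered - inputs).abs().max())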
#.............................................................................................


class VitsConvFlow(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.filter_channels = config.hidden_size
        self.half_channels = config.depth_separable_channels // 2
        self.num_bins = config.duration_predictor_flow_bins
        self.tail_bound = config.duration_predictor_tail_bound

        self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(config)
        self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)

        hidden_states = self.conv_pre(first_half)
        hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
        hidden_states = self.conv_proj(hidden_states) * padding_mask

        batch_size, channels, length = first_half.shape
        hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)

        unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]

        second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
            second_half,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            reverse=reverse,
            tail_bound=self.tail_bound,
        )

        outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
        if not reverse:
            log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            return outputs, None
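
# Illustrative sketch, not part of the original VITS code: how a single coupling flow step is exercised. The config
# stand-in below carries only the fields read by VitsConvFlow and VitsDilatedDepthSeparableConv; the values are
# assumptions for the demo, not necessarily the VitsConfig defaults.
def _demo_conv_flow():
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=192,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
    )
    flow = VitsConvFlow(config)

    inputs = torch.randn(1, config.depth_separable_channels, 7)  # (batch, channels, seq_len)
    padding_mask = torch.ones(1, 1, 7)

    # forward direction returns the transformed tensor and the log-determinant of the Jacobian
    outputs, log_determinant = flow(inputs, padding_mask, reverse=False)
    # reverse direction inverts the transform (no log-determinant is returned)
    recovered, _ = flow(outputs, padding_mask, reverse=True)
    print(outputs.shape, log_determinant.shape, (recovered - inputs).abs().max())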
#.............................................................................................


class VitsElementwiseAffine(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.channels = config.depth_separable_channels
        self.translate = nn.Parameter(torch.zeros(self.channels, 1))
        self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        if not reverse:
            outputs = self.translate + torch.exp(self.log_scale) * inputs
            outputs = outputs * padding_mask
            log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
            return outputs, None
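
# Illustrative note, not part of the original VITS code: this flow applies y = translate + exp(log_scale) * x per
# channel, so the log-determinant of its Jacobian is simply the sum of `log_scale` over the unmasked positions, as
# computed in `VitsElementwiseAffine.forward` above.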
#.............................................................................................


class VitsDilatedDepthSeparableConv(nn.Module):
    def __init__(self, config: VitsConfig, dropout_rate=0.0):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        channels = config.hidden_size
        self.num_layers = config.depth_separable_num_layers

        self.dropout = nn.Dropout(dropout_rate)
        self.convs_dilated = nn.ModuleList()
        self.convs_pointwise = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(self.num_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_dilated.append(
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(nn.LayerNorm(channels))
            self.norms_2.append(nn.LayerNorm(channels))

    def forward(self, inputs, padding_mask, global_conditioning=None):
        if global_conditioning is not None:
            inputs = inputs + global_conditioning

        for i in range(self.num_layers):
            hidden_states = self.convs_dilated[i](inputs * padding_mask)
            hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.convs_pointwise[i](hidden_states)
            hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.dropout(hidden_states)
            inputs = inputs + hidden_states

        return inputs * padding_mask
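
# Illustrative sketch, not part of the original VITS code: each block above is a depthwise conv (groups=channels)
# whose dilation grows as kernel_size**i, followed by a pointwise 1x1 conv; the padding keeps the sequence length
# unchanged. A quick shape check, using an assumed stand-in config rather than real VitsConfig defaults:
def _demo_dilated_depth_separable_conv():
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=192, duration_predictor_kernel_size=3, depth_separable_num_layers=3)
    conv = VitsDilatedDepthSeparableConv(config)
    hidden_states = torch.randn(2, config.hidden_size, 11)
    padding_mask = torch.ones(2, 1, 11)
    print(conv(hidden_states, padding_mask).shape)  # torch.Size([2, 192, 11])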
#.............................................................................................


class VitsStochasticDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.speaker_embedding_size
        filter_channels = config.hidden_size

        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        if embed_dim != 0:
            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)

        self.flows = nn.ModuleList()
        self.flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.flows.append(VitsConvFlow(config))

        self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        self.post_flows = nn.ModuleList()
        self.post_flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.post_flows.append(VitsConvFlow(config))

        self.filter_channels = filter_channels

    def resize_speaker_embeddings(self, speaker_embedding_size):
        self.cond = nn.Conv1d(speaker_embedding_size, self.filter_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
        inputs = torch.detach(inputs)
        inputs = self.conv_pre(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_dds(inputs, padding_mask)
        inputs = self.conv_proj(inputs) * padding_mask

        if not reverse:
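            # Training branch (hedged summary, not original source comments): the ground-truth durations are
            # dequantized with a flow-based posterior (`post_flows`) and their log-likelihood is evaluated under the
            # main flows (`flows`); the returned value `nll + logq` is the negative variational lower bound used as
            # the stochastic duration loss in VITS.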
            hidden_states = self.post_conv_pre(durations)
            hidden_states = self.post_conv_dds(hidden_states, padding_mask)
            hidden_states = self.post_conv_proj(hidden_states) * padding_mask

            random_posterior = (
                torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * padding_mask
            )
            latents_posterior = random_posterior
            latents_posterior, log_determinant = self.post_flows[0](
                latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
            )
            log_determinant_posterior_sum = log_determinant
            for flow in self.post_flows[1:]:
                latents_posterior, log_determinant = flow(
                    latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
                )
                latents_posterior = torch.flip(latents_posterior, [1])
                log_determinant_posterior_sum += log_determinant

            first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)

            log_determinant_posterior_sum += torch.sum(
                (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
                - log_determinant_posterior_sum
            )

            first_half = (durations - torch.sigmoid(first_half)) * padding_mask
            first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
            log_determinant_sum = torch.sum(-first_half, [1, 2])

            latents = torch.cat([first_half, second_half], dim=1)
            latents, log_determinant = self.flows[0](latents, padding_mask, global_conditioning=inputs)
            log_determinant_sum += log_determinant
            for flow in self.flows[1:]:
                latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
                latents = torch.flip(latents, [1])
                log_determinant_sum += log_determinant

            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow

            latents = (
                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * noise_scale
            )
            for flow in flows:
                latents = torch.flip(latents, [1])
                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)

            log_duration, _ = torch.split(latents, [1, 1], dim=1)
            return log_duration
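
# Illustrative sketch, not part of the original VITS code: running the stochastic duration predictor in reverse
# (inference) mode to sample log-durations from noise. The stand-in config carries only the fields this module and
# its submodules read; the values are assumptions for the demo, not necessarily the VitsConfig defaults.
def _demo_stochastic_duration_predictor():
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=192,
        speaker_embedding_size=0,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
    )
    predictor = VitsStochasticDurationPredictor(config).eval()

    text_hidden_states = torch.randn(1, config.hidden_size, 13)  # encoder output as (batch, channels, seq_len)
    padding_mask = torch.ones(1, 1, 13)

    with torch.no_grad():
        log_duration = predictor(text_hidden_states, padding_mask, reverse=True, noise_scale=0.8)
    durations = torch.ceil(torch.exp(log_duration) * padding_mask)  # one duration per input position
    print(log_duration.shape, durations.shape)  # torch.Size([1, 1, 13]) twice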
#.............................................................................................


class VitsDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        filter_channels = config.duration_predictor_filter_channels

        self.dropout = nn.Dropout(config.duration_predictor_dropout)
        self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if config.speaker_embedding_size != 0:
            self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)

        self.hidden_size = config.hidden_size

    def resize_speaker_embeddings(self, speaker_embedding_size):
        self.cond = nn.Conv1d(speaker_embedding_size, self.hidden_size, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        inputs = torch.detach(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_1(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.conv_2(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.proj(inputs * padding_mask)
        return inputs * padding_mask
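
# Illustrative sketch, not part of the original VITS code: the deterministic duration predictor simply regresses one
# log-duration per input position. Field values in the stand-in config are assumptions for the demo, not necessarily
# the VitsConfig defaults.
def _demo_duration_predictor():
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=192,
        duration_predictor_filter_channels=256,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        layer_norm_eps=1e-5,
        speaker_embedding_size=0,
    )
    predictor = VitsDurationPredictor(config).eval()

    text_hidden_states = torch.randn(1, config.hidden_size, 13)
    padding_mask = torch.ones(1, 1, 13)
    log_duration = predictor(text_hidden_states, padding_mask)
    print(log_duration.shape)  # torch.Size([1, 1, 13])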
#.............................................................................................