# Source: ONNXServies / VitsModelSplit / duration_predictor.py
# Uploaded by wasmdashai in commit 38f004a ("model push").
import math
import numpy as np
import torch
from torch import nn
from .vits_config import VitsConfig
#.............................................
def _rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
reverse,
tail_bound,
min_bin_width,
min_bin_height,
min_derivative,
):
"""
This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
Args:
inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Second half of the hidden-states input to the Vits convolutional flow module.
unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
reverse (`bool`):
Whether the model is being run in reverse mode.
tail_bound (`float`):
Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
transform behaves as an identity function.
min_bin_width (`float`):
Minimum bin value across the width dimension for the piecewise rational quadratic function.
min_bin_height (`float`):
Minimum bin value across the height dimension for the piecewise rational quadratic function.
min_derivative (`float`):
Minimum bin value across the derivatives for the piecewise rational quadratic function.
Returns:
outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Hidden-states as transformed by the piecewise rational quadratic function.
log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Logarithm of the absolute value of the determinants corresponding to the `outputs`.
"""
upper_bound = tail_bound
lower_bound = -tail_bound
if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
if min_bin_height * num_bins > 1.0:
raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
widths = nn.functional.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
cumwidths[..., 0] = lower_bound
cumwidths[..., -1] = upper_bound
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
heights = nn.functional.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
cumheights[..., 0] = lower_bound
cumheights[..., -1] = upper_bound
heights = cumheights[..., 1:] - cumheights[..., :-1]
bin_locations = cumheights if reverse else cumwidths
bin_locations[..., -1] += 1e-6
bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
bin_idx = bin_idx[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
if not reverse:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + intermediate1 * theta_one_minus_theta
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, log_abs_det
else:
# find the roots of a quadratic equation
intermediate2 = inputs - input_cumheights
intermediate3 = intermediate2 * intermediate1
a = input_heights * (input_delta - input_derivatives) + intermediate3
b = input_heights * input_derivatives - intermediate3
c = -input_delta * intermediate2
discriminant = b.pow(2) - 4 * a * c
if not (discriminant >= 0).all():
raise RuntimeError(f"invalid discriminant {discriminant}")
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + intermediate1 * theta_one_minus_theta
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -log_abs_det
#.............................................
def _unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
reverse=False,
tail_bound=5.0,
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3,
):
"""
This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
`tail_bound`, the transform behaves as an identity function.
Args:
inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Second half of the hidden-states input to the Vits convolutional flow module.
unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
layer in the convolutional flow module
reverse (`bool`, *optional*, defaults to `False`):
Whether the model is being run in reverse mode.
tail_bound (`float`, *optional* defaults to 5):
Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
transform behaves as an identity function.
min_bin_width (`float`, *optional*, defaults to 1e-3):
Minimum bin value across the width dimension for the piecewise rational quadratic function.
min_bin_height (`float`, *optional*, defaults to 1e-3):
Minimum bin value across the height dimension for the piecewise rational quadratic function.
min_derivative (`float`, *optional*, defaults to 1e-3):
Minimum bin value across the derivatives for the piecewise rational quadratic function.
Returns:
outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
applied.
log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
limits applied.
"""
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
log_abs_det = torch.zeros_like(inputs)
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
log_abs_det[outside_interval_mask] = 0.0
outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
reverse=reverse,
tail_bound=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
return outputs, log_abs_det
#.............................................................................................
class VitsConvFlow(nn.Module):
    """
    Coupling layer of the stochastic duration predictor's flow.

    The input is split along the channel dimension into two halves of `half_channels` each. The first half is
    passed through unchanged and, via `conv_pre` -> `conv_dds` -> `conv_proj`, predicts the parameters of a
    rational-quadratic spline. That spline is applied elementwise to the second half.
    """

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.filter_channels = config.hidden_size
        self.half_channels = config.depth_separable_channels // 2
        self.num_bins = config.duration_predictor_flow_bins
        self.tail_bound = config.duration_predictor_tail_bound

        # Per half-channel the projection emits num_bins widths, num_bins heights and (num_bins - 1)
        # interior-knot derivatives.
        spline_params_per_channel = 3 * self.num_bins - 1

        self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(config)
        self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * spline_params_per_channel, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        """
        Apply the coupling transform (or its inverse when `reverse` is True).

        Returns:
            `(outputs, log_determinant)`: `log_determinant` is summed over channels and time (dims 1 and 2),
            or `None` in reverse mode.
        """
        conditioner, transformed = torch.split(inputs, [self.half_channels] * 2, dim=1)

        params = self.conv_pre(conditioner)
        params = self.conv_dds(params, padding_mask, global_conditioning)
        params = self.conv_proj(params) * padding_mask

        # (batch, channels * P, length) -> (batch, channels, length, P), with P spline parameters per element.
        batch_size, channels, length = conditioner.shape
        params = params.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)

        scale = math.sqrt(self.filter_channels)
        bins = self.num_bins
        widths = params[..., :bins] / scale
        heights = params[..., bins : 2 * bins] / scale
        derivatives = params[..., 2 * bins :]

        transformed, log_abs_det = _unconstrained_rational_quadratic_spline(
            transformed,
            widths,
            heights,
            derivatives,
            reverse=reverse,
            tail_bound=self.tail_bound,
        )

        outputs = torch.cat([conditioner, transformed], dim=1) * padding_mask
        if reverse:
            return outputs, None
        return outputs, torch.sum(log_abs_det * padding_mask, [1, 2])
#.............................................................................................
class VitsElementwiseAffine(nn.Module):
    """
    Invertible per-channel affine flow: `y = translate + exp(log_scale) * x`, masked by `padding_mask`.
    """

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.channels = config.depth_separable_channels
        # Both parameters start at zero, so the layer is initially the identity (apart from masking).
        self.translate = nn.Parameter(torch.zeros(self.channels, 1))
        self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        """
        Apply the affine map, or its inverse when `reverse` is True.

        `global_conditioning` is accepted for interface parity with the other flows and is not used.
        Returns `(outputs, log_determinant)`; `log_determinant` is `None` in reverse mode.
        """
        if reverse:
            inverse = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
            return inverse, None

        scaled = self.translate + torch.exp(self.log_scale) * inputs
        masked = scaled * padding_mask
        # The Jacobian is diagonal, so log|det| is just the sum of log_scale over unmasked positions.
        log_det = torch.sum(self.log_scale * padding_mask, [1, 2])
        return masked, log_det
#.............................................................................................
class VitsDilatedDepthSeparableConv(nn.Module):
    """
    Stack of residual, dilated, depth-separable 1-D convolution layers.

    Layer `i` uses dilation `kernel_size ** i`. Each layer computes: depthwise dilated conv -> LayerNorm -> GELU
    -> pointwise conv -> LayerNorm -> GELU -> dropout. The result is added back onto the layer input.
    """

    def __init__(self, config: VitsConfig, dropout_rate=0.0):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        channels = config.hidden_size
        self.num_layers = config.depth_separable_num_layers

        self.dropout = nn.Dropout(dropout_rate)
        self.convs_dilated = nn.ModuleList()
        self.convs_pointwise = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()

        for layer_index in range(self.num_layers):
            dilation = kernel_size**layer_index
            # Keeps the sequence length unchanged when dilation * (kernel_size - 1) is even (e.g. any odd
            # kernel_size). The residual addition in `forward` relies on this.
            same_padding = dilation * (kernel_size - 1) // 2
            self.convs_dilated.append(
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    groups=channels,  # one filter per channel (depthwise)
                    dilation=dilation,
                    padding=same_padding,
                )
            )
            self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(nn.LayerNorm(channels))
            self.norms_2.append(nn.LayerNorm(channels))

    def forward(self, inputs, padding_mask, global_conditioning=None):
        """Run every residual layer; `global_conditioning`, if given, is added to the input first."""
        if global_conditioning is not None:
            inputs = inputs + global_conditioning

        layers = zip(self.convs_dilated, self.norms_1, self.convs_pointwise, self.norms_2)
        for depthwise, first_norm, pointwise, second_norm in layers:
            residual = depthwise(inputs * padding_mask)
            # LayerNorm normalizes the last dim, so move channels there and back.
            residual = first_norm(residual.transpose(1, -1)).transpose(1, -1)
            residual = nn.functional.gelu(residual)
            residual = pointwise(residual)
            residual = second_norm(residual.transpose(1, -1)).transpose(1, -1)
            residual = nn.functional.gelu(residual)
            inputs = inputs + self.dropout(residual)

        return inputs * padding_mask
#.............................................................................................
class VitsStochasticDurationPredictor(nn.Module):
    """
    Flow-based stochastic duration predictor.

    With `reverse=False`, `forward` scores the given `durations` and returns the per-example quantity
    `nll + logq`. With `reverse=True`, it samples log-durations by running the main flows backwards from Gaussian
    noise scaled by `noise_scale`.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.speaker_embedding_size
        filter_channels = config.hidden_size

        # Encoder for the (detached) input hidden states.
        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        # Projection for `global_conditioning`; only created when `speaker_embedding_size` is non-zero.
        if embed_dim != 0:
            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)

        # Main flow: an elementwise affine layer followed by `duration_predictor_num_flows` coupling layers.
        self.flows = nn.ModuleList()
        self.flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.flows.append(VitsConvFlow(config))

        # Encoder for `durations` and the posterior flow; both are used only when `reverse=False`.
        self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        self.post_flows = nn.ModuleList()
        self.post_flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.post_flows.append(VitsConvFlow(config))

        self.filter_channels = filter_channels

    def resize_speaker_embeddings(self, speaker_embedding_size):
        """Replace `cond` with a freshly initialised projection from `speaker_embedding_size` channels."""
        self.cond = nn.Conv1d(speaker_embedding_size, self.filter_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
        """
        Args:
            inputs: Hidden states; detached, so no gradient flows back through them.
            padding_mask: Multiplied into intermediate results to zero out padded positions.
            global_conditioning: Optional conditioning, detached and projected by `self.cond`.
            durations: Required when `reverse=False`. Presumably `(batch_size, 1, seq_len)`, since `post_conv_pre`
                has one input channel (TODO confirm against callers).
            reverse: `False` to score `durations`, `True` to sample log-durations.
            noise_scale: Multiplier for the Gaussian noise sampled in reverse mode.

        Returns:
            `reverse=False`: `nll + logq`, summed over dims 1 and 2.
            `reverse=True`: log-durations of shape `(batch_size, 1, seq_len)`.

        Raises:
            ValueError: If `durations` is None while `reverse` is False.
        """
        if not reverse and durations is None:
            raise ValueError("`durations` must be provided when `reverse=False`")

        inputs = torch.detach(inputs)
        inputs = self.conv_pre(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_dds(inputs, padding_mask)
        inputs = self.conv_proj(inputs) * padding_mask

        if not reverse:
            hidden_states = self.post_conv_pre(durations)
            hidden_states = self.post_conv_dds(hidden_states, padding_mask)
            hidden_states = self.post_conv_proj(hidden_states) * padding_mask

            random_posterior = (
                torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * padding_mask
            )

            # The elementwise-affine flow (index 0) is applied without a following channel flip; each coupling
            # layer after it is followed by a flip.
            latents_posterior = random_posterior
            latents_posterior, log_determinant = self.post_flows[0](
                latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
            )
            log_determinant_posterior_sum = log_determinant
            for flow in self.post_flows[1:]:
                latents_posterior, log_determinant = flow(
                    latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
                )
                latents_posterior = torch.flip(latents_posterior, [1])
                log_determinant_posterior_sum += log_determinant

            first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)

            # log|d sigmoid(x)/dx| = log sigmoid(x) + log sigmoid(-x), for the sigmoid applied to first_half below.
            log_determinant_posterior_sum += torch.sum(
                (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
            )
            # Standard-normal log-density of the sampled noise minus the accumulated posterior log|det|.
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
                - log_determinant_posterior_sum
            )

            # Subtract sigmoid(first_half), a value in (0, 1), from the durations and move to log-space.
            # The log's Jacobian contributes -log(y), hence the sum of -first_half.
            first_half = (durations - torch.sigmoid(first_half)) * padding_mask
            first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
            log_determinant_sum = torch.sum(-first_half, [1, 2])

            latents = torch.cat([first_half, second_half], dim=1)
            # Same ordering as the posterior flow: no flip after the affine layer, a flip after each coupling layer.
            latents, log_determinant = self.flows[0](latents, padding_mask, global_conditioning=inputs)
            log_determinant_sum += log_determinant
            for flow in self.flows[1:]:
                latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
                latents = torch.flip(latents, [1])
                log_determinant_sum += log_determinant

            # Negative standard-normal log-likelihood of the latents minus the accumulated log|det|.
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            # Remove a useless vflow: drops `self.flows[1]` (the first coupling layer) from the reverse pass.
            flows = flows[:-2] + [flows[-1]]

            latents = (
                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * noise_scale
            )
            for flow in flows:
                latents = torch.flip(latents, [1])
                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)

            log_duration, _ = torch.split(latents, [1, 1], dim=1)
            return log_duration
#.............................................................................................
class VitsDurationPredictor(nn.Module):
    """
    Deterministic duration predictor.

    Runs two (conv -> ReLU -> LayerNorm -> dropout) stages, then a 1x1 projection to a single output channel.
    Every conv input and the final output are multiplied by `padding_mask`.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        filter_channels = config.duration_predictor_filter_channels
        # Keeps the sequence length unchanged for odd kernel sizes.
        same_padding = kernel_size // 2

        self.dropout = nn.Dropout(config.duration_predictor_dropout)
        self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=same_padding)
        self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_padding)
        self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # Projection for `global_conditioning`; only created when `speaker_embedding_size` is non-zero.
        if config.speaker_embedding_size != 0:
            self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)

        self.hidden_size = config.hidden_size

    def resize_speaker_embeddings(self, speaker_embedding_size):
        """Replace `cond` with a freshly initialised projection from `speaker_embedding_size` channels."""
        self.cond = nn.Conv1d(speaker_embedding_size, self.hidden_size, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        """Predict one value per position; `inputs` and `global_conditioning` are detached first."""
        hidden = torch.detach(inputs)
        if global_conditioning is not None:
            hidden = hidden + self.cond(torch.detach(global_conditioning))

        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            hidden = torch.relu(conv(hidden * padding_mask))
            # LayerNorm normalizes the last dim, so move channels there and back.
            hidden = norm(hidden.transpose(1, -1)).transpose(1, -1)
            hidden = self.dropout(hidden)

        return self.proj(hidden * padding_mask) * padding_mask
#.............................................................................................