#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Test code for parallel_wavegan modules."""

import logging

import numpy as np
import pytest
import torch

from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.models import ParallelWaveGANDiscriminator
from parallel_wavegan.models import ParallelWaveGANGenerator
from parallel_wavegan.models import ResidualParallelWaveGANDiscriminator
from parallel_wavegan.optimizers import RAdam

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_generator_args(**kwargs):
    """Make default arguments for ParallelWaveGANGenerator, updated with kwargs."""
defaults = dict(
in_channels=1,
out_channels=1,
kernel_size=3,
layers=6,
stacks=3,
residual_channels=8,
gate_channels=16,
skip_channels=8,
aux_channels=10,
aux_context_window=0,
        dropout=1 - 0.95,  # dropout probability of 0.05, written as 1 - keep_prob
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_net="ConvInUpsampleNetwork",
upsample_params={"upsample_scales": [4, 4]},
)
defaults.update(kwargs)
return defaults


def make_discriminator_args(**kwargs):
    """Make default arguments for ParallelWaveGANDiscriminator, updated with kwargs."""
defaults = dict(
in_channels=1,
out_channels=1,
kernel_size=3,
layers=5,
conv_channels=16,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
use_weight_norm=True,
)
defaults.update(kwargs)
return defaults


def make_residual_discriminator_args(**kwargs):
    """Make default arguments for ResidualParallelWaveGANDiscriminator, updated with kwargs."""
defaults = dict(
in_channels=1,
out_channels=1,
kernel_size=3,
layers=10,
stacks=1,
residual_channels=8,
gate_channels=16,
skip_channels=8,
dropout=0.0,
use_weight_norm=True,
use_causal_conv=False,
nonlinear_activation_params={"negative_slope": 0.2},
)
defaults.update(kwargs)
return defaults


def make_multi_reso_stft_loss_args(**kwargs):
    """Make default arguments for MultiResolutionSTFTLoss, updated with kwargs."""
defaults = dict(
fft_sizes=[64, 128, 256],
hop_sizes=[32, 64, 128],
win_lengths=[48, 96, 192],
window="hann_window",
)
defaults.update(kwargs)
return defaults
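

# Editor's sketch (not in the upstream suite): a minimal shape check built
# from the factories above. It assumes the forward signature model_g(z, c)
# used throughout this file, and that the conditioning features are upsampled
# by prod(upsample_scales) to match the waveform length.
def test_generator_output_shape():
    args_g = make_generator_args()
    scale = int(np.prod(args_g["upsample_params"]["upsample_scales"]))
    frames = 64
    z = torch.randn(2, 1, frames * scale)
    c = torch.randn(2, args_g["aux_channels"], frames)
    model_g = ParallelWaveGANGenerator(**args_g)
    y = model_g(z, c)
    assert y.shape == z.shape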


@pytest.mark.parametrize(
"dict_g, dict_d, dict_loss",
[
({}, {}, {}),
({"layers": 1, "stacks": 1}, {}, {}),
({}, {"layers": 1}, {}),
({"kernel_size": 5}, {}, {}),
({}, {"kernel_size": 5}, {}),
({"gate_channels": 8}, {}, {}),
({"stacks": 1}, {}, {}),
({"use_weight_norm": False}, {"use_weight_norm": False}, {}),
({"aux_context_window": 2}, {}, {}),
({"upsample_net": "UpsampleNetwork"}, {}, {}),
(
{"upsample_params": {"upsample_scales": [4], "freq_axis_kernel_size": 3}},
{},
{},
),
(
{
"upsample_params": {
"upsample_scales": [4],
"nonlinear_activation": "ReLU",
}
},
{},
{},
),
(
{
"upsample_conditional_features": False,
"upsample_params": {"upsample_scales": [1]},
},
{},
{},
),
({}, {"nonlinear_activation": "ReLU", "nonlinear_activation_params": {}}, {}),
({"use_causal_conv": True}, {}, {}),
({"use_causal_conv": True, "upsample_net": "UpsampleNetwork"}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 1}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 2}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 3}, {}, {}),
(
{
"aux_channels": 16,
"upsample_net": "MelGANGenerator",
"upsample_params": {
"upsample_scales": [4, 4],
"in_channels": 16,
"out_channels": 16,
},
},
{},
{},
),
],
)
def test_parallel_wavegan_trainable(dict_g, dict_d, dict_loss):
    """Check that generator and discriminator each run one training step."""
    # setup
batch_size = 4
batch_length = 4096
args_g = make_generator_args(**dict_g)
args_d = make_discriminator_args(**dict_d)
    args_loss = make_multi_reso_stft_loss_args(**dict_loss)
z = torch.randn(batch_size, 1, batch_length)
y = torch.randn(batch_size, 1, batch_length)
    # Conditioning features: batch_length // prod(upsample_scales) frames,
    # plus aux_context_window extra frames on each side for the input conv.
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
model_g = ParallelWaveGANGenerator(**args_g)
model_d = ParallelWaveGANDiscriminator(**args_d)
aux_criterion = MultiResolutionSTFTLoss(**args_loss)
gen_adv_criterion = GeneratorAdversarialLoss()
dis_adv_criterion = DiscriminatorAdversarialLoss()
optimizer_g = RAdam(model_g.parameters())
optimizer_d = RAdam(model_d.parameters())
# check generator trainable
y_hat = model_g(z, c)
p_hat = model_d(y_hat)
adv_loss = gen_adv_criterion(p_hat)
sc_loss, mag_loss = aux_criterion(y_hat, y)
aux_loss = sc_loss + mag_loss
loss_g = adv_loss + aux_loss
optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()
# check discriminator trainable
p = model_d(y)
    p_hat = model_d(y_hat.detach())  # detach so the update does not reach the generator
real_loss, fake_loss = dis_adv_criterion(p_hat, p)
loss_d = real_loss + fake_loss
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()
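

# Editor's sketch (not in the upstream suite): the default configs enable
# weight normalization, and the parallel_wavegan models expose
# remove_weight_norm() for inference/export; the forward pass should still
# work after the hooks are removed. The method name is assumed from the
# use_weight_norm option above.
def test_parallel_wavegan_remove_weight_norm():
    args_g = make_generator_args()
    model_g = ParallelWaveGANGenerator(**args_g)
    model_g.remove_weight_norm()
    scale = int(np.prod(args_g["upsample_params"]["upsample_scales"]))
    z = torch.randn(2, 1, 64 * scale)
    c = torch.randn(2, args_g["aux_channels"], 64)
    y = model_g(z, c)
    assert y.shape == z.shape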


@pytest.mark.parametrize(
"dict_g, dict_d, dict_loss",
[
({}, {}, {}),
({"layers": 1, "stacks": 1}, {}, {}),
({}, {"layers": 1}, {}),
({"kernel_size": 5}, {}, {}),
({}, {"kernel_size": 5}, {}),
({"gate_channels": 8}, {}, {}),
({"stacks": 1}, {}, {}),
({"use_weight_norm": False}, {"use_weight_norm": False}, {}),
({"aux_context_window": 2}, {}, {}),
({"upsample_net": "UpsampleNetwork"}, {}, {}),
(
{"upsample_params": {"upsample_scales": [4], "freq_axis_kernel_size": 3}},
{},
{},
),
(
{
"upsample_params": {
"upsample_scales": [4],
"nonlinear_activation": "ReLU",
}
},
{},
{},
),
(
{
"upsample_conditional_features": False,
"upsample_params": {"upsample_scales": [1]},
},
{},
{},
),
({}, {"nonlinear_activation": "ReLU", "nonlinear_activation_params": {}}, {}),
({"use_causal_conv": True}, {}, {}),
({"use_causal_conv": True, "upsample_net": "UpsampleNetwork"}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 1}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 2}, {}, {}),
({"use_causal_conv": True, "aux_context_window": 3}, {}, {}),
(
{
"aux_channels": 16,
"upsample_net": "MelGANGenerator",
"upsample_params": {
"upsample_scales": [4, 4],
"in_channels": 16,
"out_channels": 16,
},
},
{},
{},
),
],
)
def test_parallel_wavegan_with_residual_discriminator_trainable(
    dict_g, dict_d, dict_loss
):
    """Check trainability with the residual (WaveNet-style) discriminator."""
    # setup
batch_size = 4
batch_length = 4096
args_g = make_generator_args(**dict_g)
args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_multi_reso_stft_loss_args(**dict_loss)
z = torch.randn(batch_size, 1, batch_length)
y = torch.randn(batch_size, 1, batch_length)
    # Conditioning features: batch_length // prod(upsample_scales) frames,
    # plus aux_context_window extra frames on each side for the input conv.
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_params"]["upsample_scales"])
        + 2 * args_g["aux_context_window"],
    )
model_g = ParallelWaveGANGenerator(**args_g)
model_d = ResidualParallelWaveGANDiscriminator(**args_d)
aux_criterion = MultiResolutionSTFTLoss(**args_loss)
gen_adv_criterion = GeneratorAdversarialLoss()
dis_adv_criterion = DiscriminatorAdversarialLoss()
optimizer_g = RAdam(model_g.parameters())
optimizer_d = RAdam(model_d.parameters())
# check generator trainable
y_hat = model_g(z, c)
p_hat = model_d(y_hat)
adv_loss = gen_adv_criterion(p_hat)
sc_loss, mag_loss = aux_criterion(y_hat, y)
aux_loss = sc_loss + mag_loss
loss_g = adv_loss + aux_loss
optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()
# check discriminator trainable
p = model_d(y)
    p_hat = model_d(y_hat.detach())  # detach so the update does not reach the generator
real_loss, fake_loss = dis_adv_criterion(p_hat, p)
loss_d = real_loss + fake_loss
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()
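

# Editor's sketch (not in the upstream suite): both discriminators are fully
# convolutional with same-length padding, so the output should keep the input
# resolution; this checks the residual variant with its default config.
def test_residual_discriminator_output_shape():
    args_d = make_residual_discriminator_args()
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    y = torch.randn(2, 1, 4096)
    p = model_d(y)
    assert p.shape == y.shape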


@pytest.mark.parametrize(
"upsample_net, aux_context_window",
[
("ConvInUpsampleNetwork", 0),
("ConvInUpsampleNetwork", 1),
("ConvInUpsampleNetwork", 2),
("ConvInUpsampleNetwork", 3),
("UpsampleNetwork", 0),
],
)
def test_causal_parallel_wavegan(upsample_net, aux_context_window):
    """Check that a causal generator's output is unaffected by future inputs."""
batch_size = 1
batch_length = 4096
args_g = make_generator_args(
use_causal_conv=True,
upsample_net=upsample_net,
aux_context_window=aux_context_window,
dropout=0.0,
)
model_g = ParallelWaveGANGenerator(**args_g)
z = torch.randn(batch_size, 1, batch_length)
c = torch.randn(
batch_size,
args_g["aux_channels"],
batch_length // np.prod(args_g["upsample_params"]["upsample_scales"]),
)
    # Perturb only the second half of the inputs.
    z_ = z.clone()
    c_ = c.clone()
    z_[..., z.size(-1) // 2 :] = torch.randn(z[..., z.size(-1) // 2 :].shape)
    c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape)
    # Zero-pad the conditioning by aux_context_window frames on each side,
    # matching the extra context the upsampling network consumes.
    c = torch.nn.ConstantPad1d(args_g["aux_context_window"], 0.0)(c)
    c_ = torch.nn.ConstantPad1d(args_g["aux_context_window"], 0.0)(c_)
try:
# check not equal
np.testing.assert_array_equal(c.numpy(), c_.numpy())
except AssertionError:
pass
else:
raise AssertionError("Must be different.")
try:
# check not equal
np.testing.assert_array_equal(z.numpy(), z_.numpy())
except AssertionError:
pass
else:
raise AssertionError("Must be different.")
    # check causality: the unchanged first half of the inputs must produce
    # an identical first half of the output
y = model_g(z, c)
y_ = model_g(z_, c_)
np.testing.assert_array_equal(
y[..., : y.size(-1) // 2].detach().cpu().numpy(),
y_[..., : y_.size(-1) // 2].detach().cpu().numpy(),
)
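

# Editor's companion sketch (not in the upstream suite): with
# use_causal_conv=False the same perturbation should leak into the first half
# of the output, since non-causal convolutions use symmetric padding and see
# future samples within their receptive field. Assumes dropout=0.0 so the
# forward pass is deterministic.
def test_noncausal_parallel_wavegan_leaks_future():
    args_g = make_generator_args(use_causal_conv=False, dropout=0.0)
    model_g = ParallelWaveGANGenerator(**args_g)
    scale = int(np.prod(args_g["upsample_params"]["upsample_scales"]))
    z = torch.randn(1, 1, 4096)
    c = torch.randn(1, args_g["aux_channels"], 4096 // scale)
    z_, c_ = z.clone(), c.clone()
    z_[..., z.size(-1) // 2 :] = torch.randn(z[..., z.size(-1) // 2 :].shape)
    c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape)
    y = model_g(z, c)
    y_ = model_g(z_, c_)
    try:
        np.testing.assert_array_equal(
            y[..., : y.size(-1) // 2].detach().cpu().numpy(),
            y_[..., : y_.size(-1) // 2].detach().cpu().numpy(),
        )
    except AssertionError:
        pass  # expected: future context leaks backward in the non-causal model
    else:
        raise AssertionError("Non-causal model must be affected by future inputs.")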