import random

import numpy as np
import torch

from TTS.vocoder.models.wavernn import WaveRNN


def test_wavernn():
    model = WaveRNN(
        rnn_dims=512,
        fc_dims=512,
        mode=10,  # 10-bit output -> 2**10 = 1024 softmax classes
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],  # product equals hop_length (256)
        feat_dims=80,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
        hop_length=256,
        sample_rate=22050,
    )
    # pad=2 trims 2 mel frames from each side of the 9-frame input,
    # leaving 5 valid frames, i.e. 5 * hop_length = 1280 audio samples.
    dummy_x = torch.rand((2, 1280))
    dummy_m = torch.rand((2, 80, 9))
    y_size = random.randrange(20, 60)
    dummy_y = torch.rand((80, y_size))

    # forward pass: logits of shape (batch, time, 2**bits)
    output = model(dummy_x, dummy_m)
    assert np.all(output.shape == (2, 1280, 4 * 256)), output.shape

    # batched inference: mels, batched=True, target=5500, overlap=550
    output = model.inference(dummy_y, True, 5500, 550)
    assert np.all(output.shape == (256 * (y_size - 1),))