File size: 850 Bytes
2493d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import torch
import random
from TTS.vocoder.models.wavernn import WaveRNN


def test_wavernn():
    """Smoke-test the output shapes of WaveRNN's forward pass and batched inference.

    The forward pass is fed a batch of raw samples plus a mel tensor and must
    return a ``(batch, samples, 4 * 256)`` tensor. Batched inference on a
    single mel of random length ``y_size`` must return a 1-D signal of
    ``hop_length * (y_size - 1)`` samples.

    ``y_size`` is drawn from the unseeded global ``random`` module, so it is
    included in the failure message to make a failing run reproducible.
    """
    batch_size = 2
    n_mels = 80
    hop_length = 256
    seq_len = 1280  # 5 * hop_length samples per batch item
    mel_frames = 9  # frames fed to the forward pass; 5 hops + 2 * pad (pad=2) -- TODO confirm
    model = WaveRNN(
        rnn_dims=512,
        fc_dims=512,
        mode=10,
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],  # product 4 * 8 * 8 == 256 == hop_length
        feat_dims=n_mels,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
        hop_length=hop_length,
        sample_rate=22050,
    )
    dummy_x = torch.rand((batch_size, seq_len))
    dummy_m = torch.rand((batch_size, n_mels, mel_frames))
    y_size = random.randrange(20, 60)
    dummy_y = torch.rand((n_mels, y_size))

    # Forward pass: one vector of 4 * 256 (== 2 ** 10) values per sample;
    # presumably 2 ** mode output classes for mode=10 -- TODO confirm.
    output = model(dummy_x, dummy_m)
    expected_shape = (batch_size, seq_len, 4 * 256)
    actual_shape = tuple(output.shape)
    assert actual_shape == expected_shape, (
        f"forward: got shape {actual_shape}, expected {expected_shape}"
    )

    # Batched inference (positional args: mels, batched=True, 5500, 550).
    # Presumably 5500 / 550 are the fold target / overlap lengths -- TODO confirm.
    output = model.inference(dummy_y, True, 5500, 550)
    expected_shape = (hop_length * (y_size - 1),)
    actual_shape = tuple(output.shape)
    assert actual_shape == expected_shape, (
        f"inference (y_size={y_size}): got shape {actual_shape}, "
        f"expected {expected_shape}"
    )