File size: 3,538 Bytes
2493d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import shutil

import numpy as np
from tests import get_tests_path, get_tests_input_path, get_tests_output_path
from torch.utils.data import DataLoader

from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files

# Folder that holds this test module.
file_path = os.path.dirname(os.path.realpath(__file__))

# Output folder for the loader tests; created on import when it is missing.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# WaveRNN vocoder test configuration shared by every case in this module.
C = load_config(
    os.path.join(get_tests_input_path(), "test_vocoder_wavernn_config.json")
)

# Sample data folder plus the "mel" and "quant" feature sub-folders below it.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
test_mel_feat_path = os.path.join(test_data_path, "mel")
test_quant_feat_path = os.path.join(test_data_path, "quant")
# True when the sample data folder exists on disk.
ok_ljspeech = os.path.exists(test_data_path)


def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers):
    """Run the WaveRNN dataloader with the given parameters and check batch shapes.

    The shared config ``C`` is updated with ``batch_size``, ``mode``,
    ``seq_len`` and the test data path, ``preprocess_wav_files`` is run on the
    test data, and up to 10 batches are drawn from a shuffled ``DataLoader``.
    For every batch ``(x_input, mels, _)`` the following must hold:

    * ``mels.shape[1:] == (ap.num_mels, x_input.shape[-1] // hop_len + 2 * pad)``
    * ``(mels.shape[2] - 2 * pad) * hop_len == x_input.shape[1]``

    The ``mel`` and ``quant`` feature folders under the test data path are
    removed afterwards, whether the checks pass or not.

    Args:
        batch_size: batch size passed to the ``DataLoader`` and written to ``C``.
        seq_len: sequence length passed to ``WaveRNNDataset`` and written to ``C``.
        hop_len: hop length passed to ``WaveRNNDataset`` and used in the checks.
        pad: padding passed to ``WaveRNNDataset``; counted twice in the
            expected number of mel frames.
        mode: mode passed to ``WaveRNNDataset`` and written to ``C``.
        mulaw: value forwarded to ``WaveRNNDataset`` as ``mulaw``.
        num_workers: ``num_workers`` value passed to the ``DataLoader``.

    Raises:
        AssertionError: if a batch violates one of the shape relations above.
    """
    ap = AudioProcessor(**C.audio)

    C.batch_size = batch_size
    C.mode = mode
    C.seq_len = seq_len
    C.data_path = test_data_path

    max_iter = 10

    # Feature generation and loader construction live inside the ``try`` so the
    # feature folders (presumably written by ``preprocess_wav_files``) are
    # cleaned up even when setup fails before the first batch is read.
    try:
        preprocess_wav_files(test_data_path, C, ap)
        _, train_items = load_wav_feat_data(
            test_data_path, test_mel_feat_path, 5)

        dataset = WaveRNNDataset(ap=ap,
                                 items=train_items,
                                 seq_len=seq_len,
                                 hop_len=hop_len,
                                 pad=pad,
                                 mode=mode,
                                 mulaw=mulaw
                                 )
        loader = DataLoader(dataset,
                            shuffle=True,
                            collate_fn=dataset.collate,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            pin_memory=True,
                            )

        for step, (x_input, mels, _) in enumerate(loader, start=1):
            expected_feat_shape = (ap.num_mels,
                                   (x_input.shape[-1] // hop_len) + (pad * 2))
            assert np.all(
                mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}"

            assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1]
            # Only the first ``max_iter`` batches are checked.
            if step == max_iter:
                break
    finally:
        # Best-effort cleanup: a missing folder must neither hide the real
        # test failure nor prevent the other folder from being removed.
        for feat_path in (test_mel_feat_path, test_quant_feat_path):
            shutil.rmtree(feat_path, ignore_errors=True)


def test_parametrized_wavernn_dataset():
    """Run ``wavernn_dataset_case`` over a table of dataloader settings."""
    hop = C.audio['hop_length']
    # Columns: batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers
    cases = [
        [16, hop * 10, hop, 2, 10, True, 0],
        [16, hop * 10, hop, 2, "mold", False, 4],
        [1, hop * 10, hop, 2, 9, False, 0],
        [1, hop, hop, 2, 10, True, 0],
        [1, hop, hop, 2, "mold", False, 0],
        [1, hop * 5, hop, 4, 10, False, 2],
        [1, hop * 5, hop, 2, "mold", False, 0],
    ]
    for case in cases:
        # Echo the active settings so a failing case can be identified.
        print(case)
        wavernn_dataset_case(*case)