voice-xtts2 / tests /test_vocoder_wavernn_datasets.py
antoniomae1234's picture
changes in flenema
2493d72 verified
import os
import shutil
import numpy as np
from tests import get_tests_path, get_tests_input_path, get_tests_output_path
from torch.utils.data import DataLoader
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files
# Directory that contains this test module.
file_path = os.path.dirname(os.path.realpath(__file__))

# Scratch directory for loader test outputs; created as soon as the module
# is imported.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# WaveRNN vocoder test configuration used by the test functions in this module.
C = load_config(
    os.path.join(get_tests_input_path(), "test_vocoder_wavernn_config.json")
)

# LJSpeech sample data, plus the two feature folders under it that the
# dataloader check removes once it has finished.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
test_mel_feat_path, test_quant_feat_path = (
    os.path.join(test_data_path, sub_dir) for sub_dir in ("mel", "quant")
)
ok_ljspeech = os.path.exists(test_data_path)
def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers):
    """Run the WaveRNN dataloader with the given parameters and check batch shapes.

    Up to 10 batches are drawn from the loader; for each one the mel tensor
    must have ``ap.num_mels`` channels and a frame count that matches the
    audio length divided by ``hop_len`` plus ``pad`` frames on each side.

    Args:
        batch_size: batch size handed to the ``DataLoader`` (also stored on
            the shared config).
        seq_len: forwarded to ``WaveRNNDataset`` as ``seq_len`` (also stored
            on the shared config).
        hop_len: forwarded to ``WaveRNNDataset`` and used in the shape checks
            to relate audio samples to mel frames.
        pad: forwarded to ``WaveRNNDataset``; the checks expect ``pad * 2``
            extra mel frames in total.
        mode: forwarded to ``WaveRNNDataset`` (also stored on the shared
            config).
        mulaw: forwarded to ``WaveRNNDataset``.
        num_workers: number of ``DataLoader`` worker processes.

    Raises:
        AssertionError: if any inspected batch has an unexpected shape.
    """
    ap = AudioProcessor(**C.audio)

    # NOTE: ``C`` is the module-level config object, so these overrides
    # persist across successive calls.
    C.batch_size = batch_size
    C.mode = mode
    C.seq_len = seq_len
    C.data_path = test_data_path

    max_iter = 10
    # Everything from preprocessing onwards runs inside the ``try`` so the
    # feature folders are removed even when setup (not only an assertion)
    # fails part-way through.
    try:
        # Presumably writes the mel/quant features consumed below -- the
        # implementation is not visible from this file.
        preprocess_wav_files(test_data_path, C, ap)
        _, train_items = load_wav_feat_data(
            test_data_path, test_mel_feat_path, 5)
        dataset = WaveRNNDataset(ap=ap,
                                 items=train_items,
                                 seq_len=seq_len,
                                 hop_len=hop_len,
                                 pad=pad,
                                 mode=mode,
                                 mulaw=mulaw
                                 )
        loader = DataLoader(dataset,
                            shuffle=True,
                            collate_fn=dataset.collate,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            pin_memory=True,
                            )

        for count_iter, data in enumerate(loader, start=1):
            x_input, mels, _ = data
            expected_feat_shape = (ap.num_mels,
                                   (x_input.shape[-1] // hop_len) + (pad * 2))
            assert np.all(
                mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}"
            assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1]
            # A handful of batches is enough; stop early to keep the test fast.
            if count_iter == max_iter:
                break
    finally:
        # Best-effort cleanup: ``ignore_errors`` keeps a missing folder from
        # raising here, which would hide the original failure and skip the
        # second removal.
        shutil.rmtree(test_mel_feat_path, ignore_errors=True)
        shutil.rmtree(test_quant_feat_path, ignore_errors=True)
def test_parametrized_wavernn_dataset():
    """Exercise the dataloader check over several parameter combinations."""
    hop = C.audio['hop_length']

    # Columns: batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers
    cases = [
        [16, hop * 10, hop, 2, 10, True, 0],
        [16, hop * 10, hop, 2, "mold", False, 4],
        [1, hop * 10, hop, 2, 9, False, 0],
        [1, hop, hop, 2, 10, True, 0],
        [1, hop, hop, 2, "mold", False, 0],
        [1, hop * 5, hop, 4, 10, False, 2],
        [1, hop * 5, hop, 2, "mold", False, 0],
    ]

    for case in cases:
        # Echo the active combination so a failure can be traced to its case.
        print(case)
        wavernn_dataset_case(*case)