# Path: PeechTTSv22050/training/datasets/tests/test_hifi_libri_dataset.py
# Repository page header: nickovchinnikov · "Init" · 9d61c9b
from pathlib import Path
import unittest
from torch import Tensor
import torchaudio
from voicefixer import Vocoder
from training.datasets.hifi_libri_dataset import HifiLibriDataset, HifiLibriItem
class TestHifiLibriDataset(unittest.TestCase):
    """Tests for ``HifiLibriDataset``: size, cache paths, items, collation, filters.

    NOTE(review): these tests index into a real dataset and check for cache
    files under ``datasets_cache``, so they presumably need the underlying
    corpora to be available locally -- confirm before running in CI.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared fixtures once for the whole test case.

        The fixtures are only read by the tests, so building them once per
        class instead of once per test method avoids repeating the dataset
        and vocoder construction for every test.
        """
        cls.cache_dir = "datasets_cache"
        cls.dataset = HifiLibriDataset(cache_dir=cls.cache_dir, cache=True)
        # Only used by the manual listening check sketched in ``test_getitem``.
        # NOTE(review): 44100 is presumably the sample rate in Hz -- confirm.
        cls.vocoder_vf = Vocoder(44100)

    def _check_item(self, dataset, idx, dataset_type, *, expect_cache_file=False):
        """Fetch ``dataset[idx]`` and assert its class and ``dataset_type``.

        Args:
            dataset: The dataset instance to index into.
            idx: Index of the item to fetch.
            dataset_type: Expected value of the item's ``dataset_type``.
            expect_cache_file: When true, additionally assert that the path
                returned by ``dataset.get_cache_file_path(idx)`` exists.

        Returns:
            The fetched item.
        """
        item = dataset[idx]
        self.assertIsInstance(item, HifiLibriItem)
        self.assertEqual(item.dataset_type, dataset_type)
        if expect_cache_file:
            self.assertTrue(dataset.get_cache_file_path(idx).exists())
        return item

    def test_init(self):
        """The combined cut set has the expected number of entries."""
        self.assertEqual(len(self.dataset.cutset), 129751)

    def test_get_cache_subdir_path(self):
        """Index 1234 maps to the ``2000`` cache sub-directory."""
        idx = 1234
        expected_path = Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000"
        self.assertEqual(self.dataset.get_cache_subdir_path(idx), expected_path)

    def test_get_cache_file_path(self):
        """Index 1234 maps to ``2000/1234.pt`` inside the cache directory."""
        idx = 1234
        expected_path = (
            Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000" / f"{idx}.pt"
        )
        self.assertEqual(self.dataset.get_cache_file_path(idx), expected_path)

    def test_getitem(self):
        """Items at the start are ``hifitts``, items at the end are ``libritts``."""
        dataset = self.dataset

        # Take the hifi items from the beginning of the dataset and check
        # that the cache file is created.
        self._check_item(dataset, 0, "hifitts", expect_cache_file=True)

        # Manual listening check (kept for reference, intentionally disabled).
        # NOTE: Vocoder expects the mel spectrogram to be prepared in a
        # specific way.
        # wav = self.vocoder_vf.forward(item.mel.permute((1, 0)).unsqueeze(0))
        # wav_path = Path(f"results/{item.id}.wav")
        # torchaudio.save(str(wav_path), wav, 44100)

        # Take the same id again to check if the cache is used.
        self._check_item(dataset, 0, "hifitts")

        self._check_item(dataset, 10, "hifitts", expect_cache_file=True)
        self._check_item(dataset, 20, "hifitts")

        # Take the libri items from the end of the dataset. Explicit
        # non-negative indices are used because the same index is passed to
        # ``get_cache_file_path``.
        total = len(dataset)
        self._check_item(dataset, total - 20, "libritts", expect_cache_file=True)
        self._check_item(dataset, total - 10, "libritts")
        self._check_item(dataset, total - 5, "libritts")

    def test_collate_fn(self):
        """``collate_fn`` returns a list whose fields have the expected types."""
        data = [self.dataset[0] for _ in range(10)]
        collated = self.dataset.collate_fn(data)
        self.assertIsInstance(collated, list)

        # Position in this list == position in the collated batch.
        expected_fields = [
            ("ids", list),
            ("raw_texts", list),
            ("speakers", Tensor),
            ("texts", Tensor),
            ("src_lens", Tensor),
            ("mels", Tensor),
            ("pitches", Tensor),
            ("pitches_stat", list),
            ("mel_lens", Tensor),
            ("langs", Tensor),
            ("attn_priors", Tensor),
            ("wavs", Tensor),
            ("energy", Tensor),
        ]
        for position, (field, expected_type) in enumerate(expected_fields):
            # subTest reports every mismatching field, not just the first.
            with self.subTest(field=field, position=position):
                self.assertIsInstance(collated[position], expected_type)

    def test_include_libri(self):
        """``include_libri`` controls whether LibriTTS items are appended."""
        dataset_with_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=True,
        )
        dataset_without_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=False,
        )

        # The dataset with LibriTTS is larger than the dataset without it.
        self.assertGreater(len(dataset_with_libri), len(dataset_without_libri))

        # The dataset with LibriTTS ends with items of type 'libritts'.
        self._check_item(
            dataset_with_libri,
            len(dataset_with_libri) - 10,
            "libritts",
        )

        # The dataset without LibriTTS still ends with items of type 'hifitts'.
        self._check_item(
            dataset_without_libri,
            len(dataset_without_libri) - 10,
            "hifitts",
        )

    def test_dur_filter(self):
        """``dur_filter`` rejects 0.2 and 30.0 and accepts 1.0 and 2.0."""
        cases = [
            (0.2, False),
            (1.0, True),
            (2.0, True),
            (30.0, False),
        ]
        for duration, should_accept in cases:
            with self.subTest(duration=duration):
                accepted = self.dataset.dur_filter(duration)
                if should_accept:
                    self.assertTrue(accepted)
                else:
                    self.assertFalse(accepted)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()