Spaces:
Running
Running
File size: 5,508 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from pathlib import Path
import unittest
from torch import Tensor
import torchaudio
from voicefixer import Vocoder
from training.datasets.hifi_libri_dataset import HifiLibriDataset, HifiLibriItem
class TestHifiLibriDataset(unittest.TestCase):
    """Tests for ``HifiLibriDataset`` run against the on-disk dataset.

    These look like integration tests against the real corpora (``test_init``
    pins the number of cuts to 129751). Indexing the dataset with
    ``cache=True`` is expected to write cache files under ``cache_dir``.
    """

    def setUp(self):
        """Build a caching dataset (and a vocoder) before every test."""
        self.cache_dir = "datasets_cache"
        self.dataset = HifiLibriDataset(cache_dir=self.cache_dir, cache=True)
        # Only referenced by the disabled mel-to-wav snippet in test_getitem.
        # 44100 presumably is the vocoder's sample rate -- confirm in voicefixer.
        self.vocoder_vf = Vocoder(44100)

    # ------------------------------------------------------------------ helpers

    def _assert_item(self, idx, expected_type):
        """Fetch ``self.dataset[idx]``; check its class and ``dataset_type``."""
        item = self.dataset[idx]
        self.assertIsInstance(item, HifiLibriItem)
        self.assertEqual(item.dataset_type, expected_type)
        return item

    def _assert_cached(self, idx):
        """Check that the cache file for ``idx`` exists on disk."""
        cache_file = self.dataset.get_cache_file_path(idx)
        self.assertTrue(cache_file.exists(), f"missing cache file: {cache_file}")

    # -------------------------------------------------------------------- tests

    def test_init(self):
        """The full dataset exposes the expected number of cuts."""
        self.assertEqual(len(self.dataset.cutset), 129751)

    def test_get_cache_subdir_path(self):
        """Index 1234 maps to subdirectory "2000" of the cache root."""
        # NOTE(review): presumably indices are bucketed into subdirectories by
        # thousands (rounded up); confirm against get_cache_subdir_path.
        idx = 1234
        expected_path = Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000"
        self.assertEqual(self.dataset.get_cache_subdir_path(idx), expected_path)

    def test_get_cache_file_path(self):
        """The cache file for an index is ``<subdir>/<idx>.pt``."""
        idx = 1234
        expected_path = (
            Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000" / f"{idx}.pt"
        )
        self.assertEqual(self.dataset.get_cache_file_path(idx), expected_path)

    def test_getitem(self):
        """Items come back typed correctly and get written to the cache."""
        # HiFi-TTS items are taken from the beginning of the dataset.
        self._assert_item(0, "hifitts")
        # Debugging aid (disabled): turn the mel spectrogram back into audio.
        # NOTE: the vocoder expects the mel spectrogram prepared in a specific way.
        # item = self.dataset[0]
        # wav = self.vocoder_vf.forward(item.mel.permute((1, 0)).unsqueeze(0))
        # torchaudio.save(str(Path(f"results/{item.id}.wav")), wav, 44100)
        self._assert_cached(0)

        # Re-reading the same index should be served from the cache.
        self._assert_item(0, "hifitts")

        self._assert_item(10, "hifitts")
        self._assert_cached(10)
        self._assert_item(20, "hifitts")

        # LibriTTS items are taken from the end of the dataset.
        total = len(self.dataset)
        self._assert_item(total - 20, "libritts")
        self._assert_cached(total - 20)
        self._assert_item(total - 10, "libritts")
        self._assert_item(total - 5, "libritts")

    def test_collate_fn(self):
        """``collate_fn`` returns a list whose fields have the expected types."""
        batch = [self.dataset[0] for _ in range(10)]
        collated = self.dataset.collate_fn(batch)
        self.assertIsInstance(collated, list)

        # Positional layout of the collated batch.
        expected_fields = (
            ("ids", list),
            ("raw_texts", list),
            ("speakers", Tensor),
            ("texts", Tensor),
            ("src_lens", Tensor),
            ("mels", Tensor),
            ("pitches", Tensor),
            ("pitches_stat", list),
            ("mel_lens", Tensor),
            ("langs", Tensor),
            ("attn_priors", Tensor),
            ("wavs", Tensor),
            ("energy", Tensor),
        )
        for position, (field, expected_type) in enumerate(expected_fields):
            with self.subTest(position=position, field=field):
                self.assertIsInstance(collated[position], expected_type)

    def test_include_libri(self):
        """``include_libri`` controls whether LibriTTS items are present."""
        # Reuse the same cache root as setUp instead of repeating the literal.
        dataset_with_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=True,
        )
        dataset_without_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=False,
        )

        # Adding LibriTTS must make the dataset strictly larger.
        self.assertGreater(len(dataset_with_libri), len(dataset_without_libri))

        # With LibriTTS, the tail of the dataset holds 'libritts' items.
        libri_item = dataset_with_libri[len(dataset_with_libri) - 10]
        self.assertIsInstance(libri_item, HifiLibriItem)
        self.assertEqual(libri_item.dataset_type, "libritts")

        # Without LibriTTS, even the tail holds only 'hifitts' items.
        hifi_item = dataset_without_libri[len(dataset_without_libri) - 10]
        self.assertIsInstance(hifi_item, HifiLibriItem)
        self.assertEqual(hifi_item.dataset_type, "hifitts")

    def test_dur_filter(self):
        """Durations that are too short or too long are rejected."""
        cases = (
            (0.2, False),
            (1.0, True),
            (2.0, True),
            (30.0, False),
        )
        for duration, accepted in cases:
            with self.subTest(duration=duration):
                self.assertEqual(bool(self.dataset.dur_filter(duration)), accepted)
if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main(module="__main__")
|