File size: 5,508 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from pathlib import Path
import unittest

from torch import Tensor
import torchaudio
from voicefixer import Vocoder

from training.datasets.hifi_libri_dataset import HifiLibriDataset, HifiLibriItem


class TestHifiLibriDataset(unittest.TestCase):
    """Integration tests for ``HifiLibriDataset``.

    Each test builds the real dataset rooted at ``cache_dir`` and indexes into
    it, so these tests need the underlying corpora to be reachable (assumption —
    see ``HifiLibriDataset`` for how data is located) and they leave cache files
    on disk.
    """

    def setUp(self):
        """Create a fresh caching dataset instance before each test."""
        self.cache_dir = "datasets_cache"
        self.dataset = HifiLibriDataset(cache_dir=self.cache_dir, cache=True)
        # NOTE: a voicefixer ``Vocoder(44100)`` used to be built here, but it was
        # only referenced by commented-out code, so constructing it for every
        # test was pure overhead. An audible round-trip check would need the mel
        # spectrogram prepared the way that vocoder expects.

    def _assert_item(self, idx, dataset_type, dataset=None):
        """Fetch ``dataset[idx]`` and check its type and source corpus.

        ``dataset`` defaults to ``self.dataset``. Returns the fetched item.
        """
        source = self.dataset if dataset is None else dataset
        item = source[idx]
        self.assertIsInstance(item, HifiLibriItem)
        self.assertEqual(item.dataset_type, dataset_type)
        return item

    def _assert_cached(self, idx):
        """Assert that the cache file for ``idx`` exists on disk."""
        cache_file = self.dataset.get_cache_file_path(idx)
        self.assertTrue(cache_file.exists(), f"cache file not created: {cache_file}")

    def test_init(self):
        """The combined cutset has the expected number of cuts."""
        self.assertEqual(len(self.dataset.cutset), 129751)

    def test_get_cache_subdir_path(self):
        """Index 1234 maps to the ``2000`` subdirectory of the cache root."""
        idx = 1234
        expected_path = Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000"
        self.assertEqual(self.dataset.get_cache_subdir_path(idx), expected_path)

    def test_get_cache_file_path(self):
        """The cache file for an index is ``<subdir>/<idx>.pt``."""
        idx = 1234
        expected_path = (
            Path(self.cache_dir) / "cache-hifitts-librittsr" / "2000" / f"{idx}.pt"
        )
        self.assertEqual(self.dataset.get_cache_file_path(idx), expected_path)

    def test_getitem(self):
        """Items are ``HifiLibriItem`` instances tagged with the right corpus.

        HiFi-TTS items are expected at the start of the dataset and LibriTTS
        items at the end. Accessed items must be written to the on-disk cache.
        """
        # HiFi-TTS items from the beginning of the dataset.
        self._assert_item(0, "hifitts")
        self._assert_cached(0)
        # Fetch the same index again; this time the cache file already exists.
        self._assert_item(0, "hifitts")

        self._assert_item(10, "hifitts")
        self._assert_cached(10)
        self._assert_item(20, "hifitts")

        # LibriTTS items from the end of the dataset.
        n_items = len(self.dataset)
        self._assert_item(n_items - 20, "libritts")
        self._assert_cached(n_items - 20)
        self._assert_item(n_items - 10, "libritts")
        self._assert_item(n_items - 5, "libritts")

    def test_collate_fn(self):
        """``collate_fn`` returns a list whose fields have the expected types."""
        data = [self.dataset[0] for _ in range(10)]
        collated = self.dataset.collate_fn(data)
        self.assertIsInstance(collated, list)

        # Positional layout of the collated batch.
        expected_types = [
            list,  # ids
            list,  # raw_texts
            Tensor,  # speakers
            Tensor,  # texts
            Tensor,  # src_lens
            Tensor,  # mels
            Tensor,  # pitches
            list,  # pitches_stat
            Tensor,  # mel_lens
            Tensor,  # langs
            Tensor,  # attn_priors
            Tensor,  # wavs
            Tensor,  # energy
        ]
        for position, expected_type in enumerate(expected_types):
            with self.subTest(position=position):
                self.assertIsInstance(collated[position], expected_type)

    def test_include_libri(self):
        """``include_libri`` toggles whether LibriTTS items are part of the data."""
        dataset_with_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=True,
        )
        dataset_without_libri = HifiLibriDataset(
            cache_dir=self.cache_dir,
            include_libri=False,
        )

        # Adding LibriTTS must make the dataset strictly larger.
        self.assertGreater(len(dataset_with_libri), len(dataset_without_libri))

        # The tail of the combined dataset holds LibriTTS items ...
        self._assert_item(
            len(dataset_with_libri) - 10, "libritts", dataset=dataset_with_libri
        )
        # ... while the tail of the HiFi-only dataset is still HiFi-TTS.
        self._assert_item(
            len(dataset_without_libri) - 10, "hifitts", dataset=dataset_without_libri
        )

    def test_dur_filter(self):
        """Reject too-short (0.2) and too-long (30.0) durations; keep 1.0 and 2.0."""
        cases = {0.2: False, 1.0: True, 2.0: True, 30.0: False}
        for duration, expected in cases.items():
            with self.subTest(duration=duration):
                # ``bool(...)`` keeps the original truthiness semantics of
                # assertTrue/assertFalse.
                self.assertEqual(bool(self.dataset.dur_filter(duration)), expected)


if __name__ == "__main__":
    # Allow running this test module directly, e.g. ``python <this file>``;
    # ``module="__main__"`` is unittest's default, spelled out for clarity.
    unittest.main(module="__main__")