# Spaces:
# Running
# Running
import itertools
import unittest

import torch
from torch.utils.data import DataLoader

from training.datasets import LibriTTSDatasetVocoder
class TestLibriTTSDatasetAcoustic(unittest.TestCase):
    """Smoke tests for ``LibriTTSDatasetVocoder`` backed by a local cache.

    NOTE(review): despite ``Acoustic`` in the class name, every test here
    exercises ``LibriTTSDatasetVocoder``; the name is kept unchanged so that
    existing test selectors keep working.

    The dataset is built from ``datasets_cache/LIBRITTS`` with
    ``download=False``, and the expected length, shapes and speaker id below
    are pinned to that cache, so the cache must already be present.
    """

    def setUp(self):
        self.batch_size = 2
        self.lang = "en"  # NOTE(review): not read by any test in this class.
        self.download = False
        self.dataset = LibriTTSDatasetVocoder(
            root="datasets_cache/LIBRITTS",
            batch_size=self.batch_size,
            download=self.download,
        )

    def test_len(self):
        """The cached dataset has the expected number of items."""
        self.assertEqual(len(self.dataset), 33236)

    def test_getitem(self):
        """Item 0 has the pinned ``mel``/``audio`` shapes and ``speaker_id``."""
        sample = self.dataset[0]
        self.assertEqual(sample["mel"].shape, torch.Size([100, 64]))
        self.assertEqual(sample["audio"].shape, torch.Size([16384]))
        self.assertEqual(sample["speaker_id"], 1034)

    def test_collate_fn(self):
        """``collate_fn`` returns 4 fields with one entry per collated sample."""
        data = [
            self.dataset[0],
            self.dataset[2],
        ]

        result = self.dataset.collate_fn(data)

        self.assertEqual(len(result), 4)
        # Compare against the number of samples actually collated, not
        # self.batch_size: the two only coincided because both were 2.
        # NOTE(review): assumes collate_fn emits one entry per input sample
        # in every field — confirm against LibriTTSDatasetVocoder.collate_fn.
        for batch in result:
            self.assertEqual(len(batch), len(data))

    def test_dataloader(self):
        """The first two DataLoader batches have 4 fields of ``batch_size``."""
        dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.dataset.collate_fn,
        )

        # islice stops cleanly when the loader is shorter than expected, so a
        # too-small dataset shows up as an assertion failure below rather than
        # an opaque StopIteration error from bare next() calls.
        batches = list(itertools.islice(dataloader, 2))
        self.assertEqual(len(batches), 2)

        for items in batches:
            self.assertEqual(len(items), 4)
            for it in items:
                self.assertEqual(len(it), self.batch_size)
if __name__ == "__main__":
    # Let the module be executed directly as a script; unittest.main then
    # collects and runs every TestCase defined in this module.
    unittest.main(module="__main__")