# Author: nickovchinnikov — commit 9d61c9b ("Init").
import os
import unittest
from training.datasets.libritts_r import LIBRITTS_R, load_libritts_item
class TestLibriTTS(unittest.TestCase):
    """Tests for ``load_libritts_item`` and ``LIBRITTS_R`` from ``training.datasets.libritts_r``.

    The tests read from a local dataset cache under ``datasets_cache/LIBRITTS``
    and expect that data to be available on disk when they run.
    """

    # Dataset root and split used by the ``LIBRITTS_R`` tests.
    _ROOT = "datasets_cache/LIBRITTS"
    _URL = "train-clean-100"
    # Speaker ids the filtered ``LIBRITTS_R`` datasets are restricted to.
    _SPEAKER_IDS = (19, 26)

    def setUp(self):
        """Describe the single utterance exercised by ``test_load_libritts_item``."""
        # The first two "_"-separated fields of the file id are the speaker
        # and chapter directory names (see ``tearDown``).
        self.fileid = "1061_146197_000015_000000"
        self.path = "datasets_cache/LIBRITTS/LibriTTS/train-clean-360"
        self.ext_audio = ".wav"
        self.ext_original_txt = ".original.txt"
        self.ext_normalized_txt = ".normalized.txt"

    def _text_paths(self, speaker_id, chapter_id):
        """Return ``(original_path, normalized_path)`` for ``self.fileid``'s transcripts.

        Args:
            speaker_id: Speaker directory name; formatted with an f-string, so
                an int or a str both work.
            chapter_id: Chapter directory name; formatted the same way.
        """
        base_path = os.path.join(self.path, f"{speaker_id}", f"{chapter_id}")
        return (
            os.path.join(base_path, self.fileid + self.ext_original_txt),
            os.path.join(base_path, self.fileid + self.ext_normalized_txt),
        )

    def test_load_libritts_item(self):
        """Both transcript files exist after ``load_libritts_item`` runs.

        The paths are rebuilt from the speaker and chapter ids the loader
        *returns*, so this also checks that those ids point back at the
        directory holding the utterance.
        NOTE(review): whether the loader writes these files or only reads
        pre-existing ones is decided in ``libritts_r`` — confirm there.
        """
        # Only the ids matter here; audio and text payloads are ignored.
        _, _, _, _, speaker_id, chapter_id, _ = load_libritts_item(
            self.fileid,
            self.path,
            self.ext_audio,
            self.ext_original_txt,
            self.ext_normalized_txt,
        )
        for text_path in self._text_paths(speaker_id, chapter_id):
            self.assertTrue(os.path.exists(text_path), f"missing transcript file: {text_path}")

    def test_selected_speaker_ids(self):
        """Every item from a speaker-filtered dataset has one of the selected speaker ids."""
        dataset = LIBRITTS_R(
            root=self._ROOT,
            url=self._URL,
            selected_speaker_ids=list(self._SPEAKER_IDS),
        )
        # NOTE(review): if the dataset is empty, this loop (and those in the
        # length tests below) asserts nothing and the test passes vacuously.
        for _, _, _, _, speaker_id, _, _ in dataset:
            self.assertIn(speaker_id, self._SPEAKER_IDS)

    def test_max_audio_length(self):
        """With ``max_audio_length`` set, every clip is no longer than the limit."""
        max_seconds = 3.0
        dataset = LIBRITTS_R(
            root=self._ROOT,
            url=self._URL,
            max_audio_length=max_seconds,
            selected_speaker_ids=list(self._SPEAKER_IDS),
        )
        for waveform, sample_rate, _, _, speaker_id, _, _ in dataset:
            # Assumes waveform is (channels, samples) — TODO confirm.
            duration = waveform.shape[1] / sample_rate
            # The speaker filter must still apply alongside the length filter.
            self.assertIn(speaker_id, self._SPEAKER_IDS)
            self.assertLessEqual(duration, max_seconds)

    def test_min_audio_length(self):
        """With ``min_audio_length`` set, every clip is at least the limit long."""
        min_seconds = 30.0
        dataset = LIBRITTS_R(
            root=self._ROOT,
            url=self._URL,
            min_audio_length=min_seconds,
        )
        for waveform, sample_rate, _, _, _, _, _ in dataset:
            # Assumes waveform is (channels, samples) — TODO confirm.
            duration = waveform.shape[1] / sample_rate
            self.assertGreaterEqual(duration, min_seconds)

    def tearDown(self):
        """Delete ``self.fileid``'s transcript files if they exist.

        NOTE(review): this runs after *every* test and removes the
        ``.normalized.txt`` / ``.original.txt`` files next to the cached audio.
        If ``load_libritts_item`` only reads (rather than generates) them,
        this deletes dataset files — confirm against ``libritts_r``.
        """
        speaker_id, chapter_id, _, _ = self.fileid.split("_")
        original_text_path, normalized_text_path = self._text_paths(speaker_id, chapter_id)
        # Normalized transcript first, then original.
        for text_path in (normalized_text_path, original_text_path):
            try:
                os.remove(text_path)
            except FileNotFoundError:
                pass  # Nothing to clean up for this test.
if __name__ == "__main__":
    # When this file is executed directly as a script, discover and run
    # the TestCase classes defined in it.
    unittest.main(module="__main__")