Spaces:
Running
Running
import os | |
import unittest | |
from training.datasets.libritts_r import LIBRITTS_R, load_libritts_item | |
class TestLibriTTS(unittest.TestCase):
    """Tests for ``load_libritts_item`` and the ``LIBRITTS_R`` dataset.

    These tests read audio/text from ``datasets_cache/LIBRITTS`` relative to
    the working directory. NOTE(review): presumably that data must already be
    present on disk -- confirm whether ``LIBRITTS_R`` downloads on demand.
    """

    # Root directory handed to every LIBRITTS_R instance below.
    DATASET_ROOT = "datasets_cache/LIBRITTS"
    # Speakers used by the speaker-filtering tests.
    SELECTED_SPEAKER_IDS = (19, 26)

    def setUp(self):
        # Utterance exercised by test_load_libritts_item. Its first two
        # underscore-separated fields are the speaker and chapter directory
        # names (see _text_file_paths).
        self.fileid = "1061_146197_000015_000000"
        self.path = "datasets_cache/LIBRITTS/LibriTTS/train-clean-360"
        self.ext_audio = ".wav"
        self.ext_original_txt = ".original.txt"
        self.ext_normalized_txt = ".normalized.txt"
        # Text files that did not exist before the current test ran. Only
        # these are removed in tearDown, so files that were already on disk
        # (and tests that never create files) are left untouched.
        self._created_files = []

    def _text_file_paths(self):
        """Return the original- and normalized-text paths for ``self.fileid``.

        Paths have the form ``<self.path>/<speaker>/<chapter>/<fileid><ext>``,
        where speaker and chapter are the first two fields of the file id.
        """
        speaker_id, chapter_id, _, _ = self.fileid.split("_")
        base_path = os.path.join(self.path, speaker_id, chapter_id)
        return [
            os.path.join(base_path, self.fileid + self.ext_original_txt),
            os.path.join(base_path, self.fileid + self.ext_normalized_txt),
        ]

    def test_load_libritts_item(self):
        """After loading, both text files exist under the speaker/chapter dir."""
        # Record which text files are missing *before* the call, so tearDown
        # only deletes files this test caused to exist.
        self._created_files = [
            path for path in self._text_file_paths() if not os.path.exists(path)
        ]
        waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id = load_libritts_item(
            self.fileid,
            self.path,
            self.ext_audio,
            self.ext_original_txt,
            self.ext_normalized_txt,
        )
        # Build the directory from the *returned* ids, so a wrong speaker or
        # chapter id also makes the existence checks fail.
        base_path = os.path.join(
            self.path,
            f"{speaker_id}",
            f"{chapter_id}",
        )
        self.assertTrue(
            os.path.exists(os.path.join(base_path, self.fileid + self.ext_original_txt)),
        )
        self.assertTrue(
            os.path.exists(os.path.join(base_path, self.fileid + self.ext_normalized_txt)),
        )

    def test_selected_speaker_ids(self):
        """Every item yielded by a speaker-filtered dataset has a selected id."""
        dataset = LIBRITTS_R(
            root=self.DATASET_ROOT,
            url="train-clean-100",
            selected_speaker_ids=list(self.SELECTED_SPEAKER_IDS),
        )
        n_items = 0
        for _, _, _, _, speaker_id, _, _ in dataset:
            self.assertIn(speaker_id, self.SELECTED_SPEAKER_IDS)
            n_items += 1
        # Without this, an empty dataset (e.g. a filter that drops everything)
        # would pass the loop above vacuously.
        self.assertGreater(n_items, 0)

    def test_max_audio_length(self):
        """Items respect both the speaker filter and ``max_audio_length``."""
        max_seconds = 3.0
        dataset = LIBRITTS_R(
            root=self.DATASET_ROOT,
            url="train-clean-100",
            max_audio_length=max_seconds,
            selected_speaker_ids=list(self.SELECTED_SPEAKER_IDS),
        )
        for waveform, sample_rate, _, _, speaker_id, _, _ in dataset:
            # Duration in seconds = samples along axis 1 / sample rate.
            duration = waveform.shape[1] / sample_rate
            self.assertIn(speaker_id, self.SELECTED_SPEAKER_IDS)
            self.assertLessEqual(duration, max_seconds)

    def test_min_audio_length(self):
        """Items are at least ``min_audio_length`` seconds long."""
        min_seconds = 30.0
        dataset = LIBRITTS_R(
            root=self.DATASET_ROOT,
            url="train-clean-100",
            min_audio_length=min_seconds,
        )
        for waveform, sample_rate, _, _, _, _, _ in dataset:
            # Duration in seconds = samples along axis 1 / sample rate.
            duration = waveform.shape[1] / sample_rate
            self.assertGreaterEqual(duration, min_seconds)

    def tearDown(self):
        """Remove the text files the current test created, and nothing else."""
        for path in self._created_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Never written (e.g. the loader raised first): nothing to undo.
                pass
# Entry point for running this file directly (e.g. `python <this file>`);
# it is a no-op when the module is imported by a test runner.
if __name__ == "__main__":
    unittest.main(module="__main__")