#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Dataset modules."""

import logging
import os
from multiprocessing import Manager

import numpy as np
from torch.utils.data import Dataset

from parallel_wavegan.utils import find_files
from parallel_wavegan.utils import read_hdf5
class AudioMelDataset(Dataset):
    """PyTorch compatible audio and mel dataset."""

    def __init__(
        self,
        root_dir,
        audio_query="*.h5",
        mel_query="*.h5",
        audio_load_fn=lambda x: read_hdf5(x, "wave"),
        mel_load_fn=lambda x: read_hdf5(x, "feats"),
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
            mel_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return the utterance id with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # find all of audio and mel files
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx
                for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)})."
                )
            # keep audio and mel lists aligned: drop the same indices from both
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]
        if mel_length_threshold is not None:
            mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
            idxs = [
                idx
                for idx in range(len(mel_files))
                if mel_lengths[idx] > mel_length_threshold
            ]
            if len(mel_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_files)} -> {len(idxs)})."
                )
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]

        # assert the number of files
        # NOTE: fixed stray "$" in the message (was "${root_dir}")
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(
            mel_files
        ), f"Number of audio and mel files are different ({len(audio_files)} vs {len(mel_files)})."

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.mel_files = mel_files
        # utterance id: strip the dump suffix for npy dumps, the extension otherwise
        if ".npy" in audio_query:
            self.utt_ids = [
                os.path.basename(f).replace("-wave.npy", "") for f in audio_files
            ]
        else:
            self.utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Audio signal (T,).
            ndarray: Feature (T', C).

        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        utt_id = self.utt_ids[idx]
        audio = self.audio_load_fn(self.audio_files[idx])
        mel = self.mel_load_fn(self.mel_files[idx])
        if self.return_utt_id:
            items = utt_id, audio, mel
        else:
            items = audio, mel
        if self.allow_cache:
            self.caches[idx] = items
        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.audio_files)
class AudioDataset(Dataset):
    """PyTorch compatible audio dataset."""

    def __init__(
        self,
        root_dir,
        audio_query="*-wave.npy",
        audio_length_threshold=None,
        audio_load_fn=np.load,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            audio_load_fn (func): Function to load audio file.
            audio_length_threshold (int): Threshold to remove short audio files.
            return_utt_id (bool): Whether to return the utterance id with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # find all of audio files
        audio_files = sorted(find_files(root_dir, audio_query))

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx
                for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                # NOTE: fixed "logging.waning" typo, which raised AttributeError
                # whenever any file was filtered out
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)})."
                )
            audio_files = [audio_files[idx] for idx in idxs]

        # assert the number of files
        # NOTE: fixed stray "$" in the message (was "${root_dir}")
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.return_utt_id = return_utt_id
        # utterance id: strip the dump suffix for npy dumps, the extension otherwise
        if ".npy" in audio_query:
            self.utt_ids = [
                os.path.basename(f).replace("-wave.npy", "") for f in audio_files
            ]
        else:
            self.utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Audio (T,).

        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        utt_id = self.utt_ids[idx]
        audio = self.audio_load_fn(self.audio_files[idx])
        if self.return_utt_id:
            items = utt_id, audio
        else:
            items = audio
        if self.allow_cache:
            self.caches[idx] = items
        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.audio_files)
class MelDataset(Dataset):
    """PyTorch compatible mel dataset."""

    def __init__(
        self,
        root_dir,
        mel_query="*-feats.npy",
        mel_length_threshold=None,
        mel_load_fn=np.load,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            mel_query (str): Query to find feature files in root_dir.
            mel_load_fn (func): Function to load feature file.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return the utterance id with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # find all of the mel files
        mel_files = sorted(find_files(root_dir, mel_query))

        # filter by threshold
        if mel_length_threshold is not None:
            mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
            idxs = [
                idx
                for idx in range(len(mel_files))
                if mel_lengths[idx] > mel_length_threshold
            ]
            if len(mel_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_files)} -> {len(idxs)})."
                )
            mel_files = [mel_files[idx] for idx in idxs]

        # assert the number of files
        # NOTE: fixed stray "$" in the message (was "${root_dir}")
        assert len(mel_files) != 0, f"Not found any mel files in {root_dir}."

        self.mel_files = mel_files
        self.mel_load_fn = mel_load_fn
        # utterance id: strip the dump suffix for npy dumps, the extension otherwise
        # NOTE: removed a dead pre-assignment of self.utt_ids that was
        # unconditionally overwritten by this if/else
        if ".npy" in mel_query:
            self.utt_ids = [
                os.path.basename(f).replace("-feats.npy", "") for f in mel_files
            ]
        else:
            self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in mel_files]
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(mel_files))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Feature (T', C).

        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        utt_id = self.utt_ids[idx]
        mel = self.mel_load_fn(self.mel_files[idx])
        if self.return_utt_id:
            items = utt_id, mel
        else:
            items = mel
        if self.allow_cache:
            self.caches[idx] = items
        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.mel_files)