# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Dataset modules."""
import logging
import os
from multiprocessing import Manager
import numpy as np
from torch.utils.data import Dataset
from parallel_wavegan.utils import find_files
from parallel_wavegan.utils import read_hdf5
class AudioMelDataset(Dataset):
"""PyTorch compatible audio and mel dataset."""
def __init__(
self,
root_dir,
audio_query="*.h5",
mel_query="*.h5",
audio_load_fn=lambda x: read_hdf5(x, "wave"),
mel_load_fn=lambda x: read_hdf5(x, "feats"),
audio_length_threshold=None,
mel_length_threshold=None,
return_utt_id=False,
allow_cache=False,
):
"""Initialize dataset.
Args:
root_dir (str): Root directory including dumped files.
audio_query (str): Query to find audio files in root_dir.
mel_query (str): Query to find feature files in root_dir.
audio_load_fn (func): Function to load audio file.
mel_load_fn (func): Function to load feature file.
audio_length_threshold (int): Threshold to remove short audio files.
mel_length_threshold (int): Threshold to remove short feature files.
return_utt_id (bool): Whether to return the utterance id with arrays.
allow_cache (bool): Whether to allow cache of the loaded files.
"""
        # find all audio and mel files
audio_files = sorted(find_files(root_dir, audio_query))
mel_files = sorted(find_files(root_dir, mel_query))
# filter by threshold
if audio_length_threshold is not None:
audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
idxs = [
idx
for idx in range(len(audio_files))
if audio_lengths[idx] > audio_length_threshold
]
if len(audio_files) != len(idxs):
logging.warning(
f"Some files are filtered by audio length threshold "
f"({len(audio_files)} -> {len(idxs)})."
)
audio_files = [audio_files[idx] for idx in idxs]
mel_files = [mel_files[idx] for idx in idxs]
if mel_length_threshold is not None:
mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
idxs = [
idx
for idx in range(len(mel_files))
if mel_lengths[idx] > mel_length_threshold
]
if len(mel_files) != len(idxs):
logging.warning(
f"Some files are filtered by mel length threshold "
f"({len(mel_files)} -> {len(idxs)})."
)
audio_files = [audio_files[idx] for idx in idxs]
mel_files = [mel_files[idx] for idx in idxs]
# assert the number of files
        assert len(audio_files) != 0, f"No audio files found in {root_dir}."
        assert len(audio_files) == len(
            mel_files
        ), f"Number of audio and mel files differs ({len(audio_files)} vs {len(mel_files)})."
self.audio_files = audio_files
self.audio_load_fn = audio_load_fn
self.mel_load_fn = mel_load_fn
self.mel_files = mel_files
if ".npy" in audio_query:
self.utt_ids = [
os.path.basename(f).replace("-wave.npy", "") for f in audio_files
]
else:
self.utt_ids = [
os.path.splitext(os.path.basename(f))[0] for f in audio_files
]
self.return_utt_id = return_utt_id
self.allow_cache = allow_cache
if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in DataLoader with num_workers > 0
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(audio_files))]

    def __getitem__(self, idx):
        """Get the item at the specified index.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only when return_utt_id=True).
ndarray: Audio signal (T,).
ndarray: Feature (T', C).
"""
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
utt_id = self.utt_ids[idx]
audio = self.audio_load_fn(self.audio_files[idx])
mel = self.mel_load_fn(self.mel_files[idx])
if self.return_utt_id:
items = utt_id, audio, mel
else:
items = audio, mel
if self.allow_cache:
self.caches[idx] = items
return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of the dataset.
"""
return len(self.audio_files)
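
# Example usage (a minimal sketch, kept in comments so the module stays
# import-safe). The "dump/train" directory, batch size, and worker count are
# illustrative assumptions, not part of this module.
#
#   from torch.utils.data import DataLoader
#
#   dataset = AudioMelDataset("dump/train", return_utt_id=True, allow_cache=True)
#   # batch_size=1 sidesteps the need for a length-aware collate_fn, since
#   # utterances generally differ in length.
#   loader = DataLoader(dataset, batch_size=1, num_workers=2)
#   for utt_id, audio, mel in loader:
#       # utt_id: list with a single id string (default collation),
#       # audio: (1, T) waveform tensor, mel: (1, T', C) feature tensor
#       pass
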
class AudioDataset(Dataset):
    """PyTorch compatible audio dataset."""

    def __init__(
self,
root_dir,
audio_query="*-wave.npy",
audio_length_threshold=None,
audio_load_fn=np.load,
return_utt_id=False,
allow_cache=False,
):
"""Initialize dataset.
Args:
root_dir (str): Root directory including dumped files.
audio_query (str): Query to find audio files in root_dir.
audio_load_fn (func): Function to load audio file.
audio_length_threshold (int): Threshold to remove short audio files.
return_utt_id (bool): Whether to return the utterance id with arrays.
allow_cache (bool): Whether to allow cache of the loaded files.
"""
        # find all audio files
audio_files = sorted(find_files(root_dir, audio_query))
# filter by threshold
if audio_length_threshold is not None:
audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
idxs = [
idx
for idx in range(len(audio_files))
if audio_lengths[idx] > audio_length_threshold
]
if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)})."
                )
audio_files = [audio_files[idx] for idx in idxs]
# assert the number of files
        assert len(audio_files) != 0, f"No audio files found in {root_dir}."
self.audio_files = audio_files
self.audio_load_fn = audio_load_fn
self.return_utt_id = return_utt_id
if ".npy" in audio_query:
self.utt_ids = [
os.path.basename(f).replace("-wave.npy", "") for f in audio_files
]
else:
self.utt_ids = [
os.path.splitext(os.path.basename(f))[0] for f in audio_files
]
self.allow_cache = allow_cache
if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in DataLoader with num_workers > 0
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(audio_files))]

    def __getitem__(self, idx):
        """Get the item at the specified index.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only when return_utt_id=True).
ndarray: Audio (T,).
"""
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
utt_id = self.utt_ids[idx]
audio = self.audio_load_fn(self.audio_files[idx])
if self.return_utt_id:
items = utt_id, audio
else:
items = audio
if self.allow_cache:
self.caches[idx] = items
return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of the dataset.
"""
return len(self.audio_files)
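
# Example usage (a minimal sketch in comments; the directory name is an
# illustrative assumption). With the default query "*-wave.npy", items are
# raw waveforms loaded via np.load:
#
#   dataset = AudioDataset("dump/train", return_utt_id=True)
#   utt_id, audio = dataset[0]  # audio: (T,) numpy array
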
class MelDataset(Dataset):
    """PyTorch compatible mel dataset."""

    def __init__(
self,
root_dir,
mel_query="*-feats.npy",
mel_length_threshold=None,
mel_load_fn=np.load,
return_utt_id=False,
allow_cache=False,
):
"""Initialize dataset.
Args:
root_dir (str): Root directory including dumped files.
mel_query (str): Query to find feature files in root_dir.
mel_load_fn (func): Function to load feature file.
mel_length_threshold (int): Threshold to remove short feature files.
return_utt_id (bool): Whether to return the utterance id with arrays.
allow_cache (bool): Whether to allow cache of the loaded files.
"""
        # find all mel files
mel_files = sorted(find_files(root_dir, mel_query))
# filter by threshold
if mel_length_threshold is not None:
mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
idxs = [
idx
for idx in range(len(mel_files))
if mel_lengths[idx] > mel_length_threshold
]
if len(mel_files) != len(idxs):
logging.warning(
f"Some files are filtered by mel length threshold "
f"({len(mel_files)} -> {len(idxs)})."
)
mel_files = [mel_files[idx] for idx in idxs]
# assert the number of files
        assert len(mel_files) != 0, f"No mel files found in {root_dir}."
self.mel_files = mel_files
self.mel_load_fn = mel_load_fn
        if ".npy" in mel_query:
            self.utt_ids = [
                os.path.basename(f).replace("-feats.npy", "") for f in mel_files
            ]
        else:
            self.utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in mel_files
            ]
self.return_utt_id = return_utt_id
self.allow_cache = allow_cache
if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in DataLoader with num_workers > 0
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(mel_files))]

    def __getitem__(self, idx):
        """Get the item at the specified index.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only when return_utt_id=True).
ndarray: Feature (T', C).
"""
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
utt_id = self.utt_ids[idx]
mel = self.mel_load_fn(self.mel_files[idx])
if self.return_utt_id:
items = utt_id, mel
else:
items = mel
if self.allow_cache:
self.caches[idx] = items
return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of the dataset.
"""
return len(self.mel_files)
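
# Example usage (a minimal sketch in comments; the directory name is an
# illustrative assumption). With allow_cache=True, repeated access to the same
# index is served from the Manager-backed cache, which is shared across
# DataLoader worker processes:
#
#   dataset = MelDataset("dump/train", allow_cache=True)
#   mel_first = dataset[0]  # loaded from disk via np.load
#   mel_again = dataset[0]  # returned from the shared cache
#   assert (mel_first == mel_again).all()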