import torch
import numpy as np
from torch.utils.data import Dataset


class WaveRNNDataset(Dataset):
    """Dataset producing ``(mel, signal, wav_path)`` triples for WaveRNN.

    ``items`` is either a sequence of wav paths, in which case features are
    computed on the fly through ``ap``, or a sequence of
    ``(wav_path, feature_path)`` pairs, in which case precomputed ``.npy``
    features are loaded from disk. Which of the two applies is decided once,
    from the type of ``items[0]``.

    Args:
        ap: audio processor. This class calls ``load_wav``,
            ``melspectrogram``, ``mulaw_encode`` and ``quantize`` on it.
        items: wav paths, or ``(wav_path, feature_path)`` tuples/lists.
            Must be non-empty because ``items[0]`` is inspected.
        seq_len: number of signal samples per training window. Must be a
            multiple of ``hop_len``.
        hop_len: number of signal samples per mel frame.
        pad: number of extra mel frames kept on each side of a window.
        mode: ``"gauss"`` or ``"mold"`` to train on the raw waveform, or an
            ``int`` that is forwarded to ``ap`` as ``qc``/``bits`` and used
            as the bit depth when ``collate`` rescales the quantized input.
        mulaw: if true (and ``mode`` is an int) use ``ap.mulaw_encode``
            instead of ``ap.quantize``.
        is_training: stored on the instance; not used by this class.
        verbose: stored on the instance; not used by this class.
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad,
        mode,
        mulaw,
        is_training=True,
        verbose=False,
    ):
        super().__init__()
        self.ap = ap
        # Plain entries are wav paths -> compute features on the fly;
        # tuple/list entries carry a precomputed feature path.
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose

        assert self.seq_len % self.hop_len == 0, "seq_len must be a multiple of hop_len"

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        return self.load_item(index)

    def load_item(self, index):
        """Return ``(mel, x_input, wavpath)`` for the item at ``index``.

        Features are computed from the wav file when the dataset was built
        from bare wav paths, and loaded from ``.npy`` files otherwise.

        Raises:
            RuntimeError: if ``self.mode`` is not ``"gauss"``, ``"mold"`` or
                an ``int``, or if no precomputed item is long enough.
        """
        if self.compute_feat:
            return self._compute_item(index)
        return self._load_precomputed_item(index)

    def _compute_item(self, index):
        """Load the wav at ``index`` and derive mel and model input from it."""
        wavpath = self.item_list[index]
        audio = self.ap.load_wav(wavpath)

        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
        if audio.shape[0] < min_audio_len:
            print(" [!] Instance is too short! : {}".format(wavpath))
            # Right-pad with zeros up to the minimum length plus one hop.
            audio = np.pad(
                audio, [0, min_audio_len - audio.shape[0] + self.hop_len]
            )

        mel = self.ap.melspectrogram(audio)

        if self.mode in ["gauss", "mold"]:
            x_input = audio
        elif isinstance(self.mode, int):
            if self.mulaw:
                x_input = self.ap.mulaw_encode(audio, qc=self.mode)
            else:
                x_input = self.ap.quantize(audio, bits=self.mode)
        else:
            raise RuntimeError(
                "Unknown dataset mode - {}".format(self.mode)
            )
        return mel, x_input, wavpath

    def _load_precomputed_item(self, index):
        """Load precomputed features, substituting a later item if too short.

        When the mel at ``index`` has fewer than ``mel_len + 2 * pad``
        frames, the following items are tried in order (wrapping around the
        end of the list) and the first long-enough one replaces the entry at
        ``index`` in ``self.item_list``.
        """
        min_mel_len = self.mel_len + 2 * self.pad
        n_items = len(self.item_list)

        for step in range(n_items):
            candidate = (index + step) % n_items
            wavpath, feat_path = self.item_list[candidate]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))
            if mel.shape[-1] >= min_mel_len:
                break
            print(" [!] Instance is too short! : {}".format(wavpath))
        else:
            raise RuntimeError(
                "No item with at least {} mel frames found".format(min_mel_len)
            )

        if candidate != index:
            # Remember the substitution so the short item is not reloaded.
            self.item_list[index] = self.item_list[candidate]

        if self.mode in ["gauss", "mold"]:
            x_input = self.ap.load_wav(wavpath)
        elif isinstance(self.mode, int):
            x_input = np.load(feat_path.replace("/mel/", "/quant/"))
        else:
            raise RuntimeError(
                "Unknown dataset mode - {}".format(self.mode)
            )
        return mel, x_input, wavpath

    def collate(self, batch):
        """Cut one random aligned window out of every item and batch them.

        Args:
            batch: sequence of items as returned by ``load_item``; only the
                first two entries of each item (mel, signal) are used. Mels
                are sliced along their second axis, signals along their
                first.

        Returns:
            ``(x_input, mels, y_coarse)`` tensors. ``x_input`` holds the
            first ``seq_len`` samples of each signal window (rescaled to
            ``[-1, 1]`` in integer mode), ``mels`` the stacked mel windows
            of ``seq_len // hop_len + 2 * pad`` frames, and ``y_coarse`` the
            signal window shifted by one sample.

        Raises:
            ValueError: if a mel has fewer frames than one mel window.
            RuntimeError: if ``self.mode`` is not supported.
        """
        mel_win = self.seq_len // self.hop_len + 2 * self.pad

        mel_offsets = []
        for item in batch:
            n_frames = item[0].shape[-1]
            if n_frames < mel_win:
                raise ValueError(
                    "Mel with {} frames is shorter than the {}-frame window".format(
                        n_frames, mel_win
                    )
                )
            max_offset = n_frames - (mel_win + 2 * self.pad)
            # randint needs low < high; with no slack the only valid start
            # is frame 0.
            mel_offsets.append(
                np.random.randint(0, max_offset) if max_offset > 0 else 0
            )

        # The signal window starts `pad` frames after the mel window does.
        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]

        mels = [
            item[0][:, start : start + mel_win]
            for item, start in zip(batch, mel_offsets)
        ]
        # One extra sample so input and target can be shifted by one.
        coarse = [
            item[1][start : start + self.seq_len + 1]
            for item, start in zip(batch, sig_offsets)
        ]

        mels = np.stack(mels).astype(np.float32)

        if self.mode in ["gauss", "mold"]:
            coarse = torch.FloatTensor(np.stack(coarse).astype(np.float32))
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = torch.LongTensor(np.stack(coarse).astype(np.int64))
            # Map integer classes 0 .. 2**mode - 1 onto [-1, 1].
            x_input = (
                2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
            )
        else:
            raise RuntimeError(
                "Unknown dataset mode - {}".format(self.mode)
            )

        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse