File size: 5,170 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import glob
import os
import random
from multiprocessing import Manager
import numpy as np
import torch
from torch.utils.data import Dataset
class GANDataset(Dataset):
"""
GAN Dataset searchs for all the wav files under root path
and converts them to acoustic features on the fly and returns
random segments of (audio, feature) couples.
"""
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
return_pairs=False,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
self.seq_len = seq_len
self.hop_len = hop_len
self.pad_short = pad_short
self.conv_pad = conv_pad
self.return_pairs = return_pairs
self.is_training = is_training
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# map G and D instances
self.G_to_D_mappings = list(range(len(self.item_list)))
self.shuffle_mapping()
# cache acoustic features
if use_cache:
self.create_feature_cache()
def create_feature_cache(self):
self.manager = Manager()
self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))]
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
def __len__(self):
return len(self.item_list)
def __getitem__(self, idx):
"""Return different items for Generator and Discriminator and
cache acoustic features"""
# set the seed differently for each worker
if torch.utils.data.get_worker_info():
random.seed(torch.utils.data.get_worker_info().seed)
if self.return_segments:
item1 = self.load_item(idx)
if self.return_pairs:
idx2 = self.G_to_D_mappings[idx]
item2 = self.load_item(idx2)
return item1, item2
return item1
item1 = self.load_item(idx)
return item1
def _pad_short_samples(self, audio, mel=None):
"""Pad samples shorter than the output sequence length"""
if len(audio) < self.seq_len:
audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)
if mel is not None and mel.shape[1] < self.feat_frame_len:
pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
mel = np.pad(
mel,
([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
mode="constant",
constant_values=pad_value.mean(),
)
return audio, mel
def shuffle_mapping(self):
random.shuffle(self.G_to_D_mappings)
def load_item(self, idx):
"""load (audio, feat) couple"""
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
audio, mel = self._pad_short_samples(audio, mel)
else:
# load precomputed features
wavpath, feat_path = self.item_list[idx]
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]
else:
audio = self.ap.load_wav(wavpath)
mel = np.load(feat_path)
audio, mel = self._pad_short_samples(audio, mel)
# correct the audio length wrt padding applied in stft
audio = np.pad(audio, (0, self.hop_len), mode="edge")
audio = audio[: mel.shape[-1] * self.hop_len]
assert (
mel.shape[-1] * self.hop_len == audio.shape[-1]
), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
audio = torch.from_numpy(audio).float().unsqueeze(0)
mel = torch.from_numpy(mel).float().squeeze(0)
if self.return_segments:
max_mel_start = mel.shape[1] - self.feat_frame_len
mel_start = random.randint(0, max_mel_start)
mel_end = mel_start + self.feat_frame_len
mel = mel[:, mel_start:mel_end]
audio_start = mel_start * self.hop_len
audio = audio[:, audio_start : audio_start + self.seq_len]
if self.use_noise_augment and self.is_training and self.return_segments:
audio = audio + (1 / 32768) * torch.randn_like(audio)
return (mel, audio)
|