import os
import glob

import torch
import torchaudio
from tqdm import tqdm
class StyleDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        audio_dir: str,
        subset: str = "train",
        sample_rate: int = 24000,
        length: int = 131072,
    ) -> None:
        super().__init__()
        self.audio_dir = audio_dir
        self.subset = subset
        self.sample_rate = sample_rate
        self.length = length
        # discover one subdirectory per style under audio_dir/subset
        self.style_dirs = glob.glob(os.path.join(audio_dir, subset, "*"))
        self.style_dirs = [sd for sd in self.style_dirs if os.path.isdir(sd)]
        self.num_classes = len(self.style_dirs)
        # map style-directory names to integer class targets
        self.class_labels = {
            "broadcast": 0,
            "telephone": 1,
            "neutral": 2,
            "bright": 3,
            "warm": 4,
        }
        self.examples = []
        # preload every clip into memory, converted to mono at the target rate
        for style_dir in self.style_dirs:
            # get all files in style dir
            style_filepaths = glob.glob(os.path.join(style_dir, "*.wav"))
            style_name = os.path.basename(style_dir)
            for style_filepath in tqdm(style_filepaths, ncols=120):
                # load audio file
                x, sr = torchaudio.load(style_filepath)
                # sum to mono if needed
                if x.shape[0] > 1:
                    x = x.mean(dim=0, keepdim=True)
                # resample to the target sample rate
                if sr != self.sample_rate:
                    x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
                # crop longer clips and zero-pad shorter ones so every
                # example has exactly self.length samples
                if x.shape[-1] >= self.length:
                    x = x[..., : self.length]
                else:
                    x = torch.nn.functional.pad(x, (0, self.length - x.shape[-1]))
                # store example as (waveform, class index)
                example = (x, self.class_labels[style_name])
                self.examples.append(example)
        print(f"Loaded {len(self.examples)} examples for {subset} subset.")
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        x, y = self.examples[idx]
        return x, y
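

# A minimal usage sketch, not part of the original file: it assumes a
# directory layout of audio_dir/<subset>/<style>/*.wav, where each <style>
# folder name matches a key in class_labels, and "data/styles" is a
# hypothetical placeholder path.
if __name__ == "__main__":
    dataset = StyleDataset(
        audio_dir="data/styles",  # hypothetical root directory
        subset="train",
        sample_rate=24000,
        length=131072,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
    )
    # every example was cropped/padded to a fixed length, so default
    # collation works: x has shape (batch, 1, length), y holds class indices
    x, y = next(iter(loader))
    print(x.shape, y)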