Spaces:
Build error
Build error
import os | |
import json | |
import glob | |
import torch | |
import random | |
from tqdm import tqdm | |
# from deepafx_st.plugins.channel import Channel | |
from deepafx_st.processors.processor import Processor | |
from deepafx_st.data.audio import AudioFile | |
import deepafx_st.utils as utils | |
class DSPProxyDataset(torch.utils.data.Dataset):
    """Dataset that generates input/output audio pairs from a Python DSP effect.

    Each item is a random patch of input audio, attenuated by a random gain,
    together with the result of running that patch through ``processor`` with
    uniformly random control parameters.

    Args:
        input_dir (str): Path to the directory containing input audio files.
        processor (Processor): Processor object to create proxy of.
        processor_type (str): Processor name. One of ["channel", "peq", "comp"].
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 1.0
            Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
        buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 1000
        half (bool, optional): Forwarded to ``AudioFile`` as its ``half`` argument. Default: False
        num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
        soft_clip (bool, optional): Soft clip targets with ``2 * tanh(x / 2)``, bounding them to [-2, 2]. Default: True

    Raises:
        RuntimeError: If ``input_dir`` contains no ``ext`` files, or none of the
            files in the selected subset has at least ``length`` frames.
    """

    # Processor names understood by `_apply_processor`.
    _SUPPORTED_PROCESSOR_TYPES = ("channel", "peq", "comp")

    # The metadata scan in `__init__` stops once more than this many usable
    # files have been indexed.
    _MAX_INDEXED_FILES = 1000

    def __init__(
        self,
        input_dir: str,
        processor: "Processor",
        processor_type: str,
        subset="train",
        length=65536,
        buffer_size_gb=1.0,
        buffer_reload_rate=1000,
        half=False,
        num_examples_per_epoch=10000,
        ext="wav",
        soft_clip=True,
    ):
        super().__init__()
        self.input_dir = input_dir
        self.processor = processor
        self.processor_type = processor_type
        self.subset = subset
        self.length = length
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext
        self.soft_clip = soft_clip

        # sorted so the split below sees the files in a reproducible order
        search_path = os.path.join(input_dir, f"*.{ext}")
        self.input_filepaths = sorted(glob.glob(search_path))

        if not self.input_filepaths:
            raise RuntimeError(f"No files found in {input_dir}.")

        # get training split
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths, self.subset, 0.9
        )

        # get details about audio files (metadata only, audio is not preloaded)
        cnt = 0
        self.input_files = {}
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )
            # files shorter than one example cannot provide a full patch
            if audio_file.num_frames < self.length:
                continue
            self.input_files[file_id] = audio_file
            # the sample rate of the most recently indexed file is kept
            self.sample_rate = audio_file.sample_rate
            cnt += 1
            if cnt > self._MAX_INDEXED_FILES:
                break

        # fail here rather than on the first __getitem__ call, where an empty
        # buffer and a missing `sample_rate` would be much harder to diagnose
        if not self.input_files:
            raise RuntimeError(
                f"No files with at least {self.length} samples found in "
                f"{input_dir} for subset '{self.subset}'."
            )

        # some setup for iterable loading of the dataset into RAM:
        # starting at the reload rate forces a buffer load on the first item
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        """Return the number of examples that make up one epoch."""
        return self.num_examples_per_epoch

    def load_audio_buffer(self):
        """Replace the in-RAM audio buffer with a new random subset of files.

        Files are visited in shuffled order and loaded until the buffer holds
        more than ``buffer_size_gb`` * 1e9 bytes of audio.

        Raises:
            RuntimeError: If no file with at least ``length`` frames was loaded.
        """
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM
        max_nbytes = self.buffer_size_gb * 1e9

        # different subset in each reload
        random.shuffle(self.input_filepaths)

        # load files into RAM
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )
            if audio_file.num_frames < self.length:
                continue
            self.input_files_loaded[file_id] = audio_file

            nbytes_loaded += (
                audio_file.audio.element_size() * audio_file.audio.nelement()
            )
            if nbytes_loaded > max_nbytes:
                break

        if not self.input_files_loaded:
            raise RuntimeError(
                f"No files with at least {self.length} samples could be "
                f"loaded from {self.input_dir}."
            )

    @staticmethod
    def _randomize_level(input_audio):
        """Peak-normalize ``input_audio`` in place, then attenuate it by 12 to 24 dB.

        A silent patch (peak of zero) is left as zeros instead of being divided
        by zero, which would fill it with NaNs.

        Returns:
            The same tensor that was passed in, after in-place scaling.
        """
        peak = input_audio.abs().max()
        if peak > 0:
            input_audio /= peak
        # one uniform draw in [0, 1) mapped to an attenuation in [12, 24) dB
        scale_dB = torch.rand(1).item() * 12 + 12
        input_audio *= 10 ** (-scale_dB / 20.0)
        return input_audio

    def _apply_processor(self, input_audio, params):
        """Run ``self.processor`` on ``input_audio`` and return a ``(1, N)`` tensor.

        ``params`` may be modified in place: for "channel" and "comp" the last
        control parameter is pinned to 0.5.

        Raises:
            ValueError: If ``self.processor_type`` is not a supported name.
        """
        # the processor expects a batch dim
        if self.processor_type == "channel":
            params[-1] = 0.5  # set makeup gain to 0dB
            target_audio = self.processor(
                input_audio.view(1, 1, -1),
                params.view(1, -1),
            )
            return target_audio.view(1, -1)

        if self.processor_type in ("peq", "comp"):
            if self.processor_type == "comp":
                params[-1] = 0.5  # set makeup gain to 0dB
            # these processors are called with numpy arrays, not tensors
            target_audio = self.processor(
                input_audio.view(1, 1, -1).numpy(),
                params.view(1, -1).numpy(),
            )
            return torch.tensor(target_audio).view(1, -1)

        raise ValueError(
            f"Invalid processor_type '{self.processor_type}'. "
            f"Expected one of {list(self._SUPPORTED_PROCESSOR_TYPES)}."
        )

    def __getitem__(self, _):
        """Generate one random training example; the index is ignored.

        Returns:
            Tuple ``(input_audio, target_audio, params)``: the level-randomized
            input patch, the processed (and optionally soft-clipped) target,
            and the control parameters used to produce it.
        """
        # increment counter
        self.items_since_load += 1

        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        # use a random key to retrieve an input file
        rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
        input_file = self.input_files_loaded[rand_input_file_id]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`; clone so the in-place
        # scaling below never touches the buffered audio
        start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()

        # random scaling
        input_audio = self._randomize_level(input_audio)

        # generate random parameters (uniform) over 0 to 1
        params = torch.rand(self.processor.num_control_params)

        # apply plugins with random parameters
        target_audio = self._apply_processor(input_audio, params)

        # clip
        if self.soft_clip:
            target_audio = torch.tanh(target_audio / 2.0) * 2.0

        return input_audio, target_audio, params