# mikeross's picture
# Duplicate from nateraw/deepafx-st
# c983126
import os
import json
import glob
import torch
import random
from tqdm import tqdm
# from deepafx_st.plugins.channel import Channel
from deepafx_st.processors.processor import Processor
from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils
class DSPProxyDataset(torch.utils.data.Dataset):
    """Dataset generating input/target audio pairs from a Python DSP effect.

    Each item is a random patch of input audio, the same patch processed by
    ``processor`` with uniformly random control parameters, and those
    parameters. Audio files are read into a per-worker RAM buffer that is
    refreshed every ``buffer_reload_rate`` items.

    Args:
        input_dir (str): Directory that is searched (non-recursively) for
            ``*.{ext}`` input audio files.
        processor (Processor): Processor object to create a proxy of. It must
            be callable as ``processor(audio, params)`` and expose
            ``num_control_params``.
        processor_type (str): Processor name. One of ["channel", "peq", "comp"].
        subset (str, optional): Dataset subset passed to
            ``utils.split_dataset``. Default: "train"
        length (int, optional): Number of samples to load for each example.
            Files with fewer frames are skipped. Default: 65536
        buffer_size_gb (float, optional): Size of audio to read into RAM in GB
            at any given time. Default: 1.0
            Note: This is the buffer size PER DataLoader worker. So
            total RAM = buffer_size_gb * num_workers
        buffer_reload_rate (int, optional): Number of items to generate before
            loading the next chunk of the dataset. Default: 1000
        half (bool, optional): Forwarded to ``AudioFile`` as ``half=``
            (presumably selects float16 storage -- confirm in ``AudioFile``).
            Default: False
        num_examples_per_epoch (int, optional): Define an epoch as a certain
            number of audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
        soft_clip (bool, optional): Soft clip targets with a scaled ``tanh``
            so they stay within (-2, 2). Default: True

    Raises:
        RuntimeError: If no files match in ``input_dir``, or none of the files
            in the selected subset has at least ``length`` frames.
    """

    # Processor types that ``__getitem__`` knows how to drive.
    SUPPORTED_PROCESSOR_TYPES = ("channel", "peq", "comp")

    # Metadata indexing in ``__init__`` stops once more than this many usable
    # files have been collected.
    MAX_INDEXED_FILES = 1000

    def __init__(
        self,
        input_dir: str,
        processor: "Processor",
        processor_type: str,
        subset="train",
        length=65536,
        buffer_size_gb=1.0,
        buffer_reload_rate=1000,
        half=False,
        num_examples_per_epoch=10000,
        ext="wav",
        soft_clip=True,
    ):
        super().__init__()
        self.input_dir = input_dir
        self.processor = processor
        self.processor_type = processor_type
        self.subset = subset
        self.length = length
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext
        self.soft_clip = soft_clip

        search_path = os.path.join(input_dir, f"*.{ext}")
        self.input_filepaths = sorted(glob.glob(search_path))

        if not self.input_filepaths:
            raise RuntimeError(f"No files found in {input_dir}.")

        # get training split
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths, self.subset, 0.9
        )

        # get details about audio files (headers only, nothing is preloaded)
        self.input_files = {}
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )
            if audio_file.num_frames < self.length:
                continue  # too short to cut a full patch from
            self.input_files[file_id] = audio_file
            # the sample rate of the most recently indexed file wins
            self.sample_rate = audio_file.sample_rate
            if len(self.input_files) > self.MAX_INDEXED_FILES:
                break

        if not self.input_files:
            # Fail here instead of leaving `sample_rate` undefined and
            # crashing later on an empty buffer.
            raise RuntimeError(
                f"No usable files for subset '{subset}' in {input_dir}: "
                f"none has at least {length} frames."
            )

        # some setup for iterable loading of the dataset into RAM;
        # starting at the reload rate forces a buffer load on the first item
        self.input_files_loaded = {}
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        """Return the configured number of examples per epoch."""
        return self.num_examples_per_epoch

    def load_audio_buffer(self):
        """Refill the RAM buffer with a fresh random selection of files.

        Shuffles ``self.input_filepaths`` in place, then preloads files until
        the accumulated tensor size exceeds ``buffer_size_gb`` (the file that
        crosses the limit is kept).

        Raises:
            RuntimeError: If no file with at least ``length`` frames was loaded.
        """
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM
        byte_budget = self.buffer_size_gb * 1e9

        # different subset in each reload
        random.shuffle(self.input_filepaths)

        # load files into RAM
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )
            if audio_file.num_frames < self.length:
                continue

            self.input_files_loaded[file_id] = audio_file
            nbytes_loaded += (
                audio_file.audio.element_size() * audio_file.audio.nelement()
            )
            if nbytes_loaded > byte_budget:
                break

        if not self.input_files_loaded:
            raise RuntimeError(
                f"No files with at least {self.length} frames could be "
                f"loaded from {self.input_dir}."
            )

    @staticmethod
    def _peak_normalize(audio):
        """Return ``audio`` scaled to unit peak; silent audio is returned as is."""
        peak = audio.abs().max()
        # A silent patch has peak 0; dividing would fill the example with NaN.
        if peak > 0:
            return audio / peak
        return audio

    def __getitem__(self, _):
        """Generate one random training example (the index is ignored).

        Returns:
            tuple: ``(input_audio, target_audio, params)`` where
            ``input_audio`` is the randomly attenuated input patch,
            ``target_audio`` is the processor output viewed as ``(1, -1)``
            and ``params`` holds the control parameters drawn uniformly
            from [0, 1).

        Raises:
            ValueError: If ``processor_type`` is not a supported type.
        """
        if self.processor_type not in self.SUPPORTED_PROCESSOR_TYPES:
            raise ValueError(
                f"Unsupported processor_type '{self.processor_type}'. "
                f"Expected one of {list(self.SUPPORTED_PROCESSOR_TYPES)}."
            )

        # increment counter
        self.items_since_load += 1

        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        # pick a random buffered file
        rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
        input_file = self.input_files_loaded[rand_input_file_id]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`
        start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()

        # random scaling: normalize the peak, then attenuate by 12 to 24 dB
        input_audio = self._peak_normalize(input_audio)
        scale_dB = torch.rand(1).item() * 12 + 12
        input_audio *= 10 ** (-scale_dB / 20.0)

        # generate random parameters (uniform) over 0 to 1
        params = torch.rand(self.processor.num_control_params)

        if self.processor_type in ("channel", "comp"):
            params[-1] = 0.5  # set makeup gain to 0dB

        # the processor expects a batch dim
        if self.processor_type == "channel":
            # the channel processor is called with torch tensors
            target_audio = self.processor(
                input_audio.view(1, 1, -1),
                params.view(1, -1),
            )
        else:
            # "peq" and "comp" are called with NumPy arrays
            target_audio = torch.tensor(
                self.processor(
                    input_audio.view(1, 1, -1).numpy(),
                    params.view(1, -1).numpy(),
                )
            )
        target_audio = target_audio.view(1, -1)

        # clip
        if self.soft_clip:
            target_audio = torch.tanh(target_audio / 2.0) * 2.0

        return input_audio, target_audio, params