# Sebas / data_processor.py
# (hub upload metadata: Mudrock's picture — "Upload 18 files" — commit 530a7d1)
# Ke Chen
# [email protected]
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The dataset classes
import numpy as np
import torch
import logging
import os
import sys
import h5py
import csv
import time
import random
import json
from datetime import datetime
from utils import int16_to_float32
from torch.utils.data import Dataset, Sampler
# output the dict["index"].key form to save the memory in multi-GPU training
def reverse_dict(data_path, sed_path, output_dir):
    """Re-key packed eval HDF5 files so each clip is its own named dataset.

    Reads the packed waveform file (dataset "waveform", int16) and the SED
    file (dataset "sed_vector"), and writes two new HDF5 files in which clip
    i is stored under the string key str(i).  Per-index datasets let each
    worker read dict[str(index)] directly, saving memory in multi-GPU
    training (see module comment above).

    Args:
        data_path: path to the source HDF5 holding the "waveform" dataset.
        sed_path: path to the source HDF5 holding the "sed_vector" dataset.
        output_dir: directory that receives the two re-keyed output files.

    Raises:
        AssertionError: if the two source files have different clip counts.
    """
    # output filenames
    waveform_dir = os.path.join(output_dir, "audioset_eval_waveform_balanced.h5")
    sed_dir = os.path.join(output_dir, "audioset_eval_sed_balanced.h5")
    logging.info("Write Data...............")
    # context managers close the source files even if a write fails
    # (the original opened both handles and never closed them)
    with h5py.File(data_path, "r") as h_data, h5py.File(sed_path, "r") as h_sed:
        audio_num = len(h_data["waveform"])
        assert len(h_data["waveform"]) == len(h_sed["sed_vector"]), "waveform and sed should be in the same length"
        with h5py.File(waveform_dir, 'w') as hw:
            for i in range(audio_num):
                # stored as int16; convert to float32 once at packing time
                hw.create_dataset(str(i), data=int16_to_float32(h_data['waveform'][i]), dtype=np.float32)
        logging.info("Write Data Succeed...............")
        logging.info("Write Sed...............")
        with h5py.File(sed_dir, 'w') as hw:
            for i in range(audio_num):
                hw.create_dataset(str(i), data=h_sed['sed_vector'][i], dtype=np.float32)
        logging.info("Write Sed Succeed...............")
# A dataset for handling musdb
class MusdbDataset(Dataset):
    """Minimal Dataset wrapper over a pre-loaded list of musdb tracks."""

    def __init__(self, tracks):
        # keep the track list and cache its length for __len__
        self.tracks = tracks
        self.dataset_len = len(self.tracks)

    def __getitem__(self, index):
        """Load waveform and target of an audio clip.

        Args:
            index: the index number
        Return:
            track: [mixture + n_sources, n_samples]
        """
        return self.tracks[index]

    def __len__(self):
        return self.dataset_len
class InferDataset(Dataset):
    """Dataset exposing a pre-built list of tracks for inference."""

    def __init__(self, tracks):
        self.tracks = tracks
        # length is fixed at construction time
        self.dataset_len = len(tracks)

    def __getitem__(self, index):
        """Load waveform and target of an audio clip.

        Args:
            index: the index number
        Return:
            track: [mixture + n_sources, n_samples]
        """
        return self.tracks[index]

    def __len__(self):
        return self.dataset_len
# polished LGSPDataset, the main dataset for procssing the audioset files
class LGSPDataset(Dataset):
    """Pair-sampling dataset over indexed AudioSet clips.

    Each item is a dict holding two clips ("waveform_1"/"waveform_2"), their
    audio names and class ids, plus a "check_num" tail of the sampling queue
    used to verify the queue was regenerated.  generate_queue() pre-samples
    total_size pairs of (clip index, class id); __getitem__ then reads the
    two waveforms lazily from the per-shard HDF5 files listed in the index.
    """

    def __init__(self, index_path, idc, config, factor = 3, eval_mode = False):
        # index_path: HDF5 index with "audio_name", "hdf5_path" and
        #   "index_in_hdf5" datasets (one entry per clip)
        # idc: per-class clip-index collections — idc[class_id] lists the
        #   clips of that class (assumed non-empty for every sampled class)
        # config: experiment config; fields read here: classes_num, debug,
        #   balanced_data, eval_list, dataset_path
        # factor: oversampling factor — total_size = num_indexed_clips * factor
        # eval_mode: when True, sample only from config.eval_list classes
        self.index_path = index_path
        # NOTE(review): index file handle stays open for the dataset's lifetime
        self.fp = h5py.File(index_path, "r")
        self.config = config
        self.idc = idc
        self.factor = factor
        self.classes_num = self.config.classes_num
        self.eval_mode = eval_mode
        self.total_size = int(len(self.fp["audio_name"]) * self.factor)
        self.generate_queue()
        logging.info("total dataset size: %d" %(self.total_size))
        logging.info("class num: %d" %(self.classes_num))

    def generate_queue(self):
        """Pre-sample 2 * total_size (clip, class) entries and fold them into pairs."""
        self.queue = []
        self.class_queue = []
        if self.config.debug:
            # small fixed size for quick debug runs
            self.total_size = 1000
        if self.config.balanced_data:
            # balanced: repeatedly walk a shuffled class set, drawing one
            # random clip per class, so classes are sampled about equally
            while len(self.queue) < self.total_size * 2:
                if self.eval_mode:
                    if len(self.config.eval_list) == 0:
                        class_set = [*range(self.classes_num)]
                    else:
                        class_set = self.config.eval_list[:]
                else:
                    # training: exclude the held-out eval classes
                    class_set = [*range(self.classes_num)]
                    class_set = list(set(class_set) - set(self.config.eval_list))
                random.shuffle(class_set)
                self.queue += [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in class_set]
                self.class_queue += class_set[:]
            # trim the overshoot from the last pass, then fold consecutive
            # entries into [first, second] pairs
            self.queue = self.queue[:self.total_size * 2]
            self.class_queue = self.class_queue[:self.total_size * 2]
            self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        else:
            # unbalanced: classes drawn uniformly at random with replacement
            if self.eval_mode:
                if len(self.config.eval_list) == 0:
                    class_set = [*range(self.classes_num)]
                else:
                    class_set = self.config.eval_list[:]
            else:
                class_set = [*range(self.classes_num)]
                class_set = list(set(class_set) - set(self.config.eval_list))
            self.class_queue = random.choices(class_set, k = self.total_size * 2)
            self.queue = [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in self.class_queue]
            self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        logging.info("queue regenerated:%s" %(self.queue[-5:]))

    def __getitem__(self, index):
        """Load waveform and target of an audio clip pair.

        Args:
            index: the index number
        Return: {
            "audio_name_1": str,
            "waveform_1": (clip_samples,),
            "class_id_1": int,
            "audio_name_2": str,
            "waveform_2": (clip_samples,),
            "class_id_2": int,
            ...
            "check_num": int
        }
        """
        # put the right index here!!!
        data_dict = {}
        for k in range(2):
            # the k-th member of the pre-sampled pair
            s_index = self.queue[index][k]
            target = self.class_queue[index][k]
            audio_name = self.fp["audio_name"][s_index].decode()
            # remap the absolute shard path recorded at packing time onto
            # the local dataset root
            hdf5_path = self.fp["hdf5_path"][s_index].decode().replace("/home/tiger/DB/knut/data/audioset", self.config.dataset_path)
            r_idx = self.fp["index_in_hdf5"][s_index]
            with h5py.File(hdf5_path, "r") as f:
                waveform = int16_to_float32(f["waveform"][r_idx])
            data_dict["audio_name_" + str(k+1)] = audio_name
            data_dict["waveform_" + str(k+1)] = waveform
            data_dict["class_id_" + str(k+1)] = target
        # NOTE(review): stored as str(...) although the docstring says int —
        # consumers appear to use it only as a queue-regeneration check
        data_dict["check_num"] = str(self.queue[-5:])
        return data_dict

    def __len__(self):
        return self.total_size
# only for test
class TestDataset(Dataset):
    """Toy dataset of (even, odd) integer pairs, used only for testing."""

    def __init__(self, dataset_size):
        print("init")
        self.dataset_size = dataset_size
        self.base_num = 100
        self.dicts = self._build_pairs()

    def _build_pairs(self):
        # pair i is (base + 2i, base + 2i + 1)
        offset = self.base_num
        return [(offset + 2 * i, offset + 2 * i + 1) for i in range(self.dataset_size)]

    def get_new_list(self):
        """Draw a new random base offset and rebuild the pair list."""
        self.base_num = random.randint(0,10)
        print("base num changed:", self.base_num)
        self.dicts = self._build_pairs()

    def __getitem__(self, index):
        return self.dicts[index]

    def __len__(self):
        return self.dataset_size