# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Défossez (adefossez)

import json
import logging
import math
from pathlib import Path
import os
import re

import librosa
import numpy as np
import torch
import torch.utils.data as data

from .preprocess import preprocess_one_dir
from .audio import Audioset

logger = logging.getLogger(__name__)

# Matches per-source metadata files: "s1.json", "s2.json", ...
# The dot is escaped and the pattern is applied with `fullmatch`, so
# look-alike names ("notes1.json", "s1.json.bak", "s1xjson") are ignored.
_SOURCE_JSON_RE = re.compile(r's([0-9]+)\.json')


def sort(infos):
    """Return `infos` ordered by decreasing `int(info[1])`.

    Each entry is indexed positionally; element 1 must be convertible to
    `int` (the number of samples, per the bucketing comment in the original
    code). The input list is not modified.
    """
    return sorted(infos, key=lambda info: int(info[1]), reverse=True)


def _load_json(path):
    """Read and parse the JSON file at `path`."""
    with open(path, 'r') as f:
        return json.load(f)


def _load_infos(json_dir):
    """Load the mixture and per-source metadata stored in `json_dir`.

    Args:
        json_dir: directory containing `mix.json` and one `sN.json` file
            per source (N being a decimal index).

    Returns:
        A `(mix_infos, s_infos)` tuple where `mix_infos` is the parsed
        content of `mix.json` and `s_infos` is a list with the parsed
        content of every `sN.json`, ordered by increasing N.
    """
    indexed_names = []
    for name in os.listdir(json_dir):
        match = _SOURCE_JSON_RE.fullmatch(name)
        if match:
            indexed_names.append((int(match.group(1)), name))
    # `os.listdir` returns entries in arbitrary order; sort numerically so
    # that the order of the stacked targets is reproducible (s2 before s10).
    indexed_names.sort()
    logger.debug("Source files found in %s: %s",
                 json_dir, [name for _, name in indexed_names])

    mix_infos = _load_json(os.path.join(json_dir, 'mix.json'))
    s_infos = [_load_json(os.path.join(json_dir, name))
               for _, name in indexed_names]
    return mix_infos, s_infos


def _fetch_example(mix_set, sets, index):
    """Build one `(mixture, length, targets)` example.

    The mixture is fetched from `mix_set` only once and reused both for the
    returned signal and for its length.

    Returns:
        A tuple `(mix_sig, length, targets)` where `length` is a
        `torch.LongTensor` holding `mix_sig.shape[0]` and `targets` is the
        `torch.stack` of `s[index]` for every `s` in `sets`.
    """
    mix_sig = mix_set[index]
    tgt_sig = [source_set[index] for source_set in sets]
    return mix_sig, torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)


class Trainset:
    """Training examples made of a mixture and its separated sources.

    Every underlying `Audioset` is built with `length`, `stride` and `pad`
    keyword arguments derived from the constructor parameters.
    """

    def __init__(self, json_dir, sample_rate=16000, segment=4.0, stride=1.0, pad=True):
        """
        Args:
            json_dir: directory containing `mix.json` and the `sN.json` files.
            sample_rate: multiplied with `segment` / `stride` to convert
                them to a number of samples.
            segment: segment duration, converted to `length` samples.
            stride: stride duration, converted to a stride in samples.
            pad: forwarded as-is to `Audioset`.
        """
        mix_infos, s_infos = _load_infos(json_dir)

        kw = {
            'length': int(sample_rate * segment),
            'stride': int(sample_rate * stride),
            'pad': pad,
        }
        self.mix_set = Audioset(sort(mix_infos), **kw)
        self.sets = [Audioset(sort(s_info), **kw) for s_info in s_infos]

        # verify all sets have the same size
        for s in self.sets:
            assert len(s) == len(self.mix_set), \
                "every source set must have the same size as the mixture set"

    def __getitem__(self, index):
        """Return `(mixture, mixture length, stacked targets)` for `index`."""
        return _fetch_example(self.mix_set, self.sets, index)

    def __len__(self):
        return len(self.mix_set)


class Validset:
    """
    load entire wav.

    Same layout as `Trainset`, but every `Audioset` is built with its
    default arguments (no explicit `length` / `stride` / `pad`).
    """

    def __init__(self, json_dir):
        """
        Args:
            json_dir: directory containing `mix.json` and the `sN.json` files.
        """
        mix_infos, s_infos = _load_infos(json_dir)

        self.mix_set = Audioset(sort(mix_infos))
        self.sets = [Audioset(sort(s_info)) for s_info in s_infos]

        # verify all sets have the same size
        for s in self.sets:
            assert len(s) == len(self.mix_set), \
                "every source set must have the same size as the mixture set"

    def __getitem__(self, index):
        """Return `(mixture, mixture length, stacked targets)` for `index`."""
        return _fetch_example(self.mix_set, self.sets, index)

    def __len__(self):
        return len(self.mix_set)


# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12

class EvalDataset(data.Dataset):
    """Dataset whose items are pre-built minibatches of mixture metadata."""

    def __init__(self, mix_dir, mix_json, batch_size, sample_rate=8000):
        """
        Args:
            mix_dir: directory including mixture wav files
            mix_json: json file including mixture wav files
            batch_size: maximum number of utterances per minibatch (>= 1).
            sample_rate: stored with every minibatch and forwarded to
                `preprocess_one_dir` when `mix_dir` is given.

        Raises:
            ValueError: if `batch_size` is smaller than 1 (the original
                bucketing loop never terminated in that case).
        """
        super().__init__()
        assert mix_dir is not None or mix_json is not None, \
            "either mix_dir or mix_json must be provided"
        if batch_size < 1:
            raise ValueError(
                "batch_size must be >= 1, got {!r}".format(batch_size))

        if mix_dir is not None:
            # Generate mix.json given mix_dir
            preprocess_one_dir(mix_dir, mix_dir, 'mix',
                               sample_rate=sample_rate)
            mix_json = os.path.join(mix_dir, 'mix.json')
        mix_infos = _load_json(mix_json)

        # sort it by #samples (impl bucket)
        sorted_mix_infos = sort(mix_infos)

        # generate minibatch information: consecutive slices of at most
        # `batch_size` entries; an empty json yields no minibatch at all.
        self.minibatch = [
            [sorted_mix_infos[start:start + batch_size], sample_rate]
            for start in range(0, len(sorted_mix_infos), batch_size)
        ]

    def __getitem__(self, index):
        """Return the `[mix_infos, sample_rate]` pair of minibatch `index`."""
        return self.minibatch[index]

    def __len__(self):
        return len(self.minibatch)


class EvalDataLoader(data.DataLoader):
    """
    NOTE: just use batchsize=1 here, so drop_last=True makes no sense here.

    The collate function is always replaced by `_collate_fn_eval`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.collate_fn = _collate_fn_eval


def _collate_fn_eval(batch):
    """
    Args:
        batch: list, len(batch) = 1.
               See EvalDataset.__getitem__()
    Returns:
        mixtures_pad: B x T, torch.Tensor
        ilens : B, torch.Tensor
        filenames: a list containing B strings
    """
    # batch should be located in list
    assert len(batch) == 1
    mixtures, filenames = load_mixtures(batch[0])

    # get batch of lengths of input sequences
    ilens = np.array([mix.shape[0] for mix in mixtures])

    # perform padding and convert to tensor
    pad_value = 0
    mixtures_pad = pad_list(
        [torch.from_numpy(mix).float() for mix in mixtures], pad_value)
    ilens = torch.from_numpy(ilens)
    return mixtures_pad, ilens, filenames


def load_mixtures(batch):
    """
    Args:
        batch: a `(mix_infos, sample_rate)` pair as stored by EvalDataset;
            element 0 of every info is the path of the wav file.
    Returns:
        mixtures: a list containing B items, each item is T np.ndarray
        filenames: a list containing B strings
        T varies from item to item.
    """
    mixtures, filenames = [], []
    mix_infos, sample_rate = batch
    # for each utterance
    for mix_info in mix_infos:
        mix_path = mix_info[0]
        # read wav file
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        mixtures.append(mix)
        filenames.append(mix_path)
    return mixtures, filenames


def pad_list(xs, pad_value):
    """Stack tensors of different first-dimension length into one tensor.

    Args:
        xs: non-empty list of tensors that may differ along dimension 0;
            trailing dimensions are taken from `xs[0]`.
        pad_value: value written to the positions not covered by an item.

    Returns:
        A tensor of shape `(len(xs), max_len, *xs[0].size()[1:])` with the
        same dtype and device as `xs[0]`.

    Raises:
        ValueError: if `xs` is empty.
    """
    if not xs:
        raise ValueError("pad_list() requires at least one tensor")
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    padded = xs[0].new_empty(
        (n_batch, max_len) + tuple(xs[0].size()[1:])).fill_(pad_value)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded