# Ke Chen
# knutchen@ucsd.edu
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The Main Script

import os
# limit thread counts so the SDR calculation does not occupy all CPUs
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "6"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "6"
import sys
import librosa
import numpy as np
import argparse
import logging

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from utils import collect_fn, dump_config, create_folder, prepprocess_audio
import musdb

from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
from data_processor import LGSPDataset, MusdbDataset
import config
import htsat_config
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from htsat_utils import process_idc

import warnings
warnings.filterwarnings("ignore")


class data_prep(pl.LightningDataModule):
    def __init__(self, train_dataset, eval_dataset, device_num, config):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num
        self.config = config

    def train_dataloader(self):
        train_sampler = DistributedSampler(self.train_dataset, shuffle=False) if self.device_num > 1 else None
        train_loader = DataLoader(
            dataset=self.train_dataset,
            num_workers=config.num_workers,
            batch_size=config.batch_size // self.device_num,
            shuffle=False,
            sampler=train_sampler,
            collate_fn=collect_fn
        )
        return train_loader

    def val_dataloader(self):
        eval_sampler = DistributedSampler(self.eval_dataset, shuffle=False) if self.device_num > 1 else None
        eval_loader = DataLoader(
            dataset=self.eval_dataset,
            num_workers=config.num_workers,
            batch_size=config.batch_size // self.device_num,
            shuffle=False,
            sampler=eval_sampler,
            collate_fn=collect_fn
        )
        return eval_loader

    def test_dataloader(self):
        test_sampler = DistributedSampler(self.eval_dataset, shuffle=False) if self.device_num > 1 else None
        test_loader = DataLoader(
            dataset=self.eval_dataset,
            num_workers=config.num_workers,
            batch_size=config.batch_size // self.device_num,
            shuffle=False,
            sampler=test_sampler,
            collate_fn=collect_fn
        )
        return test_loader


def save_idc():
    # build the index caches (idc files) for the training and evaluation indexes
    train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
    eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
    process_idc(eval_index_path, config.classes_num, "eval_idc.npy")


# Resample the musdb tracks from the original 44100 Hz to 32000 Hz
def process_musdb():
    # use musdb as the test set
    test_data = musdb.DB(
        root=config.musdb_path,
        download=False,
        subsets="test",
        is_wav=True
    )
    print(len(test_data.tracks))
    mus_tracks = []
    # in musdb, the sample rate of every track is the same (44100 Hz)
    orig_fs = test_data.tracks[0].rate
    print(orig_fs)
    for track in test_data.tracks:
        temp = {}
        mixture = prepprocess_audio(
            track.audio, orig_fs, config.sample_rate, config.test_type
        )
        temp["mixture"] = mixture
        for dickey in config.test_key:
            source = prepprocess_audio(
                track.targets[dickey].audio, orig_fs, config.sample_rate, config.test_type
            )
            temp[dickey] = source
        print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    # save the processed tracks as an npy file
    np.save("musdb-32000fs.npy", mus_tracks)
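
# A minimal sketch (not part of the original pipeline) of how the npy file written
# by process_musdb() can be inspected; the helper name and default path are
# illustrative only. Each entry is a dict with a "mixture" key plus one key per
# stem in config.test_key (for musdb, typically "vocals", "drums", "bass", "other").
def _example_inspect_processed_musdb(path="musdb-32000fs.npy"):
    tracks = np.load(path, allow_pickle=True)
    print(len(tracks))                 # number of test tracks
    print(sorted(tracks[0].keys()))    # "mixture" plus the stem keys
    print(tracks[0]["mixture"].shape)  # resampled audio at config.sample_rate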
# Perform weight averaging over all model checkpoints in the given folder
# It outputs one checkpoint whose weights are the average of all models in the folder
def weight_average():
    model_ckpt = []
    model_files = os.listdir(config.wa_model_folder)
    wa_ckpt = {
        "state_dict": {}
    }
    for model_file in model_files:
        model_file = os.path.join(config.wa_model_folder, model_file)
        model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
    keys = model_ckpt[0].keys()
    for key in keys:
        model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
        model_ckpt_key = torch.mean(model_ckpt_key, dim=0)
        assert model_ckpt_key.shape == model_ckpt[0][key].shape, \
            f"the shapes do not match: {model_ckpt_key.shape} vs. {model_ckpt[0][key].shape}"
        wa_ckpt["state_dict"][key] = model_ckpt_key
    torch.save(wa_ckpt, config.wa_model_path)


# Use the model to quickly separate a track given a query
# It requires four variables in config.py:
#   inference_file: the track you want to separate
#   inference_query: a **folder** containing all samples from the same source
#   test_key: ["name"] indicates the source name (only a label for the final output)
#   wave_output_path: the output folder
# Make sure the query folder contains samples from a single source.
# Each run separates one source from the track; to separate multiple sources,
# change the query folder or write a script to automate that.
def inference():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "there should be a separation key"
    create_folder(config.wave_output_path)
    test_track, fs = librosa.load(config.inference_file, sr=None)
    test_track = test_track[:, None]
    print(test_track.shape)
    print(fs)
    # convert the track into the 32000 Hz sample rate
    test_track = prepprocess_audio(
        test_track, fs, config.sample_rate, config.test_type
    )
    test_tracks = []
    temp = [test_track]
    for dickey in config.test_key:
        temp.append(test_track)
    temp = np.array(temp)
    test_tracks.append(temp)
    # the data layout matches MusdbDataset, so reuse it
    dataset = MusdbDataset(tracks=test_tracks)
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )
    # obtain the samples for the query
    queries = []
    for query_file in os.listdir(config.inference_query):
        f_path = os.path.join(config.inference_query, query_file)
        if query_file.endswith(".wav"):
            temp_q, fs = librosa.load(f_path, sr=None)
            temp_q = temp_q[:, None]
            temp_q = prepprocess_audio(
                temp_q, fs, config.sample_rate, config.test_type
            )
            temp = [temp_q]
            for dickey in config.test_key:
                temp.append(temp_q)
            temp = np.array(temp)
            queries.append(temp)

    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    trainer = pl.Trainer(
        gpus=1
    )
    avg_at = None
    # obtain the latent embedding as the query
    if config.infer_type == "mean":
        avg_dataset = MusdbDataset(tracks=queries)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at

    # build the separation model
    model = ZeroShotASP(
        channels=1,
        config=config,
        at_model=at_model,
        dataset=dataset
    )
    # resume checkpoint
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=False,
        calc_sdr=False,
        output_wav=True
    )
    trainer.test(exp_model, test_dataloaders=loader)
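
# A hypothetical example of the config.py settings inference() expects; the
# paths and the stem name below are placeholders, not values from this repo:
#
#   inference_file = "examples/mixture.wav"     # the track to separate
#   inference_query = "examples/query_vocal/"   # folder of .wav clips, all from ONE source
#   test_key = ["vocal"]                        # display name for the output stem
#   wave_output_path = "separation_output/"     # folder for the separated audio
#   infer_type = "mean"                         # average the query clips' embeddings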
# Test the separation model, mainly on musdb
def test():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "there should be a separation key"
    create_folder(config.wave_output_path)
    # use musdb as the test set
    test_data = np.load(config.testset_path, allow_pickle=True)
    print(len(test_data))
    mus_tracks = []
    # in musdb, the sample rate of every track is the same (44100 Hz)
    # load the dataset
    for track in test_data:
        temp = []
        mixture = track["mixture"]
        temp.append(mixture)
        for dickey in config.test_key:
            source = track[dickey]
            temp.append(source)
        temp = np.array(temp)
        print(temp.shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    dataset = MusdbDataset(tracks=mus_tracks)
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )
    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        gpus=1
    )
    avg_at = None
    # obtain the queries of the four stems from the training set
    if config.infer_type == "mean":
        avg_data = np.load(config.testavg_path, allow_pickle=True)[:90]
        print(len(avg_data))
        avgmus_tracks = []
        # load the dataset
        for track in avg_data:
            temp = []
            mixture = track["mixture"]
            temp.append(mixture)
            for dickey in config.test_key:
                source = track[dickey]
                temp.append(source)
            temp = np.array(temp)
            print(temp.shape)
            avgmus_tracks.append(temp)
        print(len(avgmus_tracks))
        avg_dataset = MusdbDataset(tracks=avgmus_tracks)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at

    model = ZeroShotASP(
        channels=1,
        config=config,
        at_model=at_model,
        dataset=dataset
    )
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=config.using_wiener
    )
    trainer.test(exp_model, test_dataloaders=loader)
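
# A conceptual sketch of what the "mean" infer_type computes: AutoTaggingWarpper
# averages the tagging model's latent vector over all query clips, so a single
# embedding represents the target source. This helper is illustrative only and
# assumes embed_fn maps one clip to one latent tensor; it is not the repo's API.
def _example_mean_query_embedding(embed_fn, clips):
    embeds = [embed_fn(clip) for clip in clips]  # one latent vector per query clip
    return torch.stack(embeds).mean(dim=0)       # averaged query embedding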
def train():
    # set exp settings
    # device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # device = torch.device("cuda")
    device_num = torch.cuda.device_count()
    print("each batch size:", config.batch_size // device_num)

    train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
    train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle=True)
    eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle=True)

    # set exp folder
    exp_dir = os.path.join(config.workspace, "results", config.exp_name)
    checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
    if not config.debug:
        create_folder(os.path.join(config.workspace, "results"))
        create_folder(exp_dir)
        create_folder(checkpoint_dir)
        dump_config(config, os.path.join(exp_dir, config.exp_name), False)

    # load data: the LGSPDataset (latent general source separation) dataset and sampler
    dataset = LGSPDataset(
        index_path=train_index_path,
        idc=train_idc,
        config=config,
        factor=0.05,
        eval_mode=False
    )
    eval_dataset = LGSPDataset(
        index_path=eval_index_path,
        idc=eval_idc,
        config=config,
        factor=0.05,
        eval_mode=True
    )
    audioset_data = data_prep(train_dataset=dataset, eval_dataset=eval_dataset, device_num=device_num, config=config)
    checkpoint_callback = ModelCheckpoint(
        monitor="mixture_sdr",
        filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
        save_top_k=10,
        mode="max"
    )

    # load the audio tagging (AT) model
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    # load the checkpoint
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    trainer = pl.Trainer(
        deterministic=True,
        default_root_dir=checkpoint_dir,
        gpus=device_num,
        val_check_interval=0.2,
        # check_val_every_n_epoch = 1,
        max_epochs=config.max_epoch,
        auto_lr_find=True,
        sync_batchnorm=True,
        callbacks=[checkpoint_callback],
        accelerator="ddp" if device_num > 1 else None,
        resume_from_checkpoint=None,  # config.resume_checkpoint,
        replace_sampler_ddp=False,
        gradient_clip_val=1.0,
        num_sanity_val_steps=0,
    )
    model = ZeroShotASP(
        channels=1,
        config=config,
        at_model=at_model,
        dataset=dataset
    )
    if config.resume_checkpoint is not None:
        ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
    # trainer.test(model, datamodule = audioset_data)
    trainer.fit(model, audioset_data)
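
# Example invocations, assuming this script is saved as main.py and config.py /
# htsat_config.py are filled in (the order below mirrors a typical workflow):
#
#   python main.py musdb_process   # resample musdb to 32000 Hz and save the npy
#   python main.py save_idc        # build the index caches for training
#   python main.py train           # train the ZeroShotASP model
#   python main.py weight_average  # average the checkpoints in config.wa_model_folder
#   python main.py test            # evaluate the separator on musdb
#   python main.py inference       # separate one track given a query folder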
def main():
    parser = argparse.ArgumentParser(description="latent general source separation parser")
    subparsers = parser.add_subparsers(dest="mode")
    parser_train = subparsers.add_parser("train")
    parser_test = subparsers.add_parser("test")
    parser_musdb = subparsers.add_parser("musdb_process")
    parser_saveidc = subparsers.add_parser("save_idc")
    parser_wa = subparsers.add_parser("weight_average")
    parser_infer = subparsers.add_parser("inference")
    args = parser.parse_args()
    # default settings
    logging.basicConfig(level=logging.INFO)
    pl.utilities.seed.seed_everything(seed=config.random_seed)
    if args.mode == "train":
        train()
    elif args.mode == "test":
        test()
    elif args.mode == "musdb_process":
        process_musdb()
    elif args.mode == "weight_average":
        weight_average()
    elif args.mode == "save_idc":
        save_idc()
    elif args.mode == "inference":
        inference()
    else:
        raise Exception("Error Mode!")


if __name__ == '__main__':
    main()