import argparse
import time
import librosa
from tqdm import tqdm
import sys
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from utils import demix_track, demix_track_demucs, get_model_from_config

import warnings
warnings.filterwarnings("ignore")


def run_folder(model, args, config, device, verbose=False):
    start_time = time.time()
    model.eval()
    all_mixtures_path = glob.glob(args.input_folder + '/*.*')
    all_mixtures_path.sort()
    print('Total files found: {}'.format(len(all_mixtures_path)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    os.makedirs(args.store_dir, exist_ok=True)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()
        if 'normalize' in config.inference:
            if config.inference['normalize'] is True:
                mono = mix.mean(0)
                mean = mono.mean()
                std = mono.std()
                mix = (mix - mean) / std

        if args.use_tta:
            # original, channel inverse, polarity inverse
            track_proc_list = [mix.copy(), mix[::-1].copy(), -1. * mix.copy()]
        else:
            track_proc_list = [mix.copy()]

        full_result = []
        for single_track in track_proc_list:
            mixture = torch.tensor(single_track, dtype=torch.float32)
            if args.model_type == 'htdemucs':
                waveforms = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
            else:
                waveforms = demix_track(config, model, mixture, device, pbar=detailed_pbar)
            full_result.append(waveforms)

        # Average the TTA variants, undoing each augmentation first
        waveforms = full_result[0]
        for i in range(1, len(full_result)):
            d = full_result[i]
            for el in d:
                if i == 2:
                    # polarity-inverted input: flip the sign back
                    waveforms[el] += -1.0 * d[el]
                elif i == 1:
                    # channel-inverted input: swap the channels back
                    waveforms[el] += d[el][::-1].copy()
                else:
                    waveforms[el] += d[el]
        for el in waveforms:
            waveforms[el] = waveforms[el] / len(full_result)

        file_name, _ = os.path.splitext(os.path.basename(path))

        song_dir = os.path.join(args.store_dir, file_name)
        os.makedirs(song_dir, exist_ok=True)

        model_dir = os.path.join(song_dir, args.model_type)
        os.makedirs(model_dir, exist_ok=True)

        for instr in instruments:
            estimates = waveforms[instr].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            if args.flac_file:
                output_file = os.path.join(model_dir, f"{file_name}_{instr}.flac")
                subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
                sf.write(output_file, estimates, sr, subtype=subtype)
            else:
                output_file = os.path.join(model_dir, f"{file_name}_{instr}.wav")
                sf.write(output_file, estimates, sr, subtype='FLOAT')

        # Output "instrumental": the mixture minus 'vocals'
        # (or minus the first stem in the list if 'vocals' is absent)
        if args.extract_instrumental:
            if 'vocals' in instruments:
                estimates = waveforms['vocals'].T
            else:
                estimates = waveforms[instruments[0]].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            if args.flac_file:
                instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.flac")
                subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
                sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype=subtype)
            else:
                instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.wav")
                sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT')

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def proc_folder_direct(model_type, config_path, start_check_point, input_folder, store_dir, device_ids=[0],
                       extract_instrumental=False, disable_detailed_pbar=False, force_cpu=False,
                       flac_file=False, pcm_type='PCM_24', use_tta=False):
    device = "cpu"
    if force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        device = f'cuda:{device_ids}' if isinstance(device_ids, int) else f'cuda:{device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(model_type, config_path)
    if start_check_point != '':
        print('Start from checkpoint: {}'.format(start_check_point))
        if model_type == 'htdemucs':
            state_dict = torch.load(start_check_point, map_location=device, weights_only=False)
            # htdemucs checkpoints may wrap the weights in a 'state' key
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    if not isinstance(device_ids, int):
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    args = argparse.Namespace(
        model_type=model_type,
        config_path=config_path,
        start_check_point=start_check_point,
        input_folder=input_folder,
        store_dir=store_dir,
        device_ids=device_ids,
        extract_instrumental=extract_instrumental,
        disable_detailed_pbar=disable_detailed_pbar,
        force_cpu=force_cpu,
        flac_file=flac_file,
        pcm_type=pcm_type,
        use_tta=use_tta
    )

    run_folder(model, args, config, device, verbose=True)
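
# Minimal usage sketch. The config and checkpoint paths below are hypothetical
# placeholders, not files shipped with this script; substitute any valid
# model_type / config / checkpoint triple supported by get_model_from_config.
if __name__ == "__main__":
    proc_folder_direct(
        model_type='htdemucs',
        config_path='configs/config_htdemucs.yaml',    # hypothetical path
        start_check_point='checkpoints/htdemucs.ckpt',  # hypothetical path
        input_folder='input_tracks',
        store_dir='separated',
        extract_instrumental=True,  # also write mixture-minus-vocals stems
    )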