Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import time | |
import librosa | |
from tqdm import tqdm | |
import sys | |
import os | |
import glob | |
import torch | |
import numpy as np | |
import soundfile as sf | |
import torch.nn as nn | |
current_dir = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(current_dir) | |
from utils import demix_track, demix_track_demucs, get_model_from_config | |
import warnings | |
warnings.filterwarnings("ignore") | |
def _normalize_enabled(config):
    """Return True when ``config.inference['normalize']`` is present and is exactly True."""
    return 'normalize' in config.inference and config.inference['normalize'] is True


def _write_audio(path_without_ext, audio, sr, args):
    """Write ``audio`` as FLAC (PCM_16 / PCM_24) or as float WAV, depending on ``args.flac_file``."""
    if args.flac_file:
        # Anything other than an explicit 'PCM_16' falls back to 24-bit.
        subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
        sf.write(path_without_ext + '.flac', audio, sr, subtype=subtype)
    else:
        sf.write(path_without_ext + '.wav', audio, sr, subtype='FLOAT')


def _separate_track(model, args, config, device, mix, detailed_pbar):
    """Run the model on ``mix`` (optionally with test-time augmentation) and return averaged stems."""
    if args.use_tta:
        # orig, channel inverse, polarity inverse
        track_proc_list = [mix.copy(), mix[::-1].copy(), -1. * mix.copy()]
    else:
        track_proc_list = [mix.copy()]

    full_result = []
    for single_track in track_proc_list:
        mixture = torch.tensor(single_track, dtype=torch.float32)
        if args.model_type == 'htdemucs':
            result = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
        else:
            result = demix_track(config, model, mixture, device, pbar=detailed_pbar)
        full_result.append(result)

    # Average all values in a single dict, undoing each augmentation first.
    waveforms = full_result[0]
    for idx in range(1, len(full_result)):
        result = full_result[idx]
        for stem in result:
            if idx == 1:
                # Input had its first axis reversed; reverse the output back.
                waveforms[stem] += result[stem][::-1].copy()
            elif idx == 2:
                # Input had inverted polarity; invert the output back.
                waveforms[stem] += -1.0 * result[stem]
            else:
                waveforms[stem] += result[stem]
    for stem in waveforms:
        waveforms[stem] = waveforms[stem] / len(full_result)
    return waveforms


def run_folder(model, args, config, device, verbose=False):
    """Separate every file matching ``<args.input_folder>/*.*`` and write the stems to disk.

    Results are written to
    ``<args.store_dir>/<track name>/<args.model_type>/<track name>_<stem>.(wav|flac)``.
    Files that cannot be loaded are reported and skipped.

    Args:
        model: Separation model; ``model.eval()`` is called before processing.
        args: Namespace providing ``input_folder``, ``store_dir``,
            ``disable_detailed_pbar``, ``use_tta``, ``model_type``,
            ``flac_file``, ``pcm_type`` and ``extract_instrumental``.
        config: Model config; reads ``training.instruments``,
            ``training.target_instrument`` and ``inference['normalize']``.
        device: Device passed through to the demix helpers.
        verbose: When False, the file list is wrapped in a tqdm progress bar.
    """
    start_time = time.time()
    model.eval()
    all_mixtures_path = sorted(glob.glob(args.input_folder + '/*.*'))
    print('Total files found: {}'.format(len(all_mixtures_path)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # makedirs also creates missing parent directories (os.mkdir would fail on them).
    os.makedirs(args.store_dir, exist_ok=True)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            # Best effort: report the unreadable file and move on to the next one.
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()
        normalize = _normalize_enabled(config)
        if normalize:
            mono = mix.mean(0)
            mean = mono.mean()
            std = mono.std()
            if std == 0:
                # Silent/constant track: avoid dividing by zero (NaN/inf input to the model).
                std = 1.0
            mix = (mix - mean) / std

        waveforms = _separate_track(model, args, config, device, mix, detailed_pbar)

        file_name, _ = os.path.splitext(os.path.basename(path))
        model_dir = os.path.join(args.store_dir, file_name, args.model_type)
        os.makedirs(model_dir, exist_ok=True)

        for instr in instruments:
            estimates = waveforms[instr].T
            if normalize:
                # Undo the input normalization on the separated stem.
                estimates = estimates * std + mean
            _write_audio(os.path.join(model_dir, f"{file_name}_{instr}"), estimates, sr, args)

        # Output "instrumental", which is an inverse of 'vocals' (or first stem in list if 'vocals' absent)
        if args.extract_instrumental:
            reference = 'vocals' if 'vocals' in instruments else instruments[0]
            estimates = waveforms[reference].T
            if normalize:
                estimates = estimates * std + mean
            _write_audio(
                os.path.join(model_dir, f"{file_name}_instrumental"),
                mix_orig.T - estimates,
                sr,
                args,
            )

    # Pause kept from the original implementation (purpose not documented; presumably
    # lets progress-bar output settle before the summary line — TODO confirm).
    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def proc_folder_direct(model_type, config_path, start_check_point, input_folder, store_dir, device_ids=None, extract_instrumental=False, disable_detailed_pbar=False, force_cpu=False, flac_file=False, pcm_type='PCM_24', use_tta=False):
    """Build the model, load its checkpoint and run separation over ``input_folder``.

    Programmatic entry point: it assembles the same ``args`` namespace that
    ``run_folder`` reads and calls it with ``verbose=True``.

    Args:
        model_type: Model family name passed to ``get_model_from_config``;
            ``'htdemucs'`` selects the htdemucs checkpoint/demix handling.
        config_path: Path of the model config passed to ``get_model_from_config``.
        start_check_point: Checkpoint path; an empty value skips weight loading.
        input_folder: Folder whose ``*.*`` files are separated.
        store_dir: Output root directory.
        device_ids: CUDA device index, or a list of indices (model is then
            wrapped in ``nn.DataParallel``). Defaults to ``[0]``.
        extract_instrumental: Also write the mixture minus the vocal (or first) stem.
        disable_detailed_pbar: Disable the per-track progress bar.
        force_cpu: Run on CPU even if CUDA or MPS is available.
        flac_file: Write FLAC instead of float WAV.
        pcm_type: FLAC subtype, ``'PCM_16'`` or ``'PCM_24'``.
        use_tta: Enable test-time augmentation.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if device_ids is None:
        device_ids = [0]

    use_cuda = not force_cpu and torch.cuda.is_available()
    if use_cuda:
        print('CUDA is available, use --force_cpu to disable it.')
        primary_id = device_ids if isinstance(device_ids, int) else device_ids[0]
        device = f'cuda:{primary_id}'
    elif not force_cpu and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True
    model, config = get_model_from_config(model_type, config_path)

    if start_check_point:
        print('Start from checkpoint: {}'.format(start_check_point))
        if model_type == 'htdemucs':
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # Only load htdemucs checkpoints from trusted sources.
            state_dict = torch.load(start_check_point, map_location=device, weights_only=False)
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # Only wrap for CUDA: DataParallel expects the module's parameters on
    # device_ids[0], which does not hold once the model is moved to CPU/MPS.
    if use_cuda and not isinstance(device_ids, int):
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)
    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    args = argparse.Namespace(
        model_type=model_type,
        config_path=config_path,
        start_check_point=start_check_point,
        input_folder=input_folder,
        store_dir=store_dir,
        device_ids=device_ids,
        extract_instrumental=extract_instrumental,
        disable_detailed_pbar=disable_detailed_pbar,
        force_cpu=force_cpu,
        flac_file=flac_file,
        pcm_type=pcm_type,
        use_tta=use_tta,
    )
    run_folder(model, args, config, device, verbose=True)