|
|
|
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" |
|
|
|
import argparse |
|
import time |
|
import librosa |
|
from tqdm import tqdm |
|
import sys |
|
import os |
|
import glob |
|
import torch |
|
import numpy as np |
|
import soundfile as sf |
|
import torch.nn as nn |
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append(current_dir) |
|
from utils import demix, get_model_from_config |
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
class Args:
    """Simple holder for inference options.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name, so an instance can stand in for an
    ``argparse`` namespace carrying these options.
    """

    def __init__(
        self,
        input_file,
        store_dir,
        model_type,
        extract_instrumental,
        disable_detailed_pbar,
        flac_file,
        pcm_type,
        use_tta,
    ):
        # Input / output locations.
        self.input_file = input_file
        self.store_dir = store_dir

        # Model selection and inference behaviour.
        self.model_type = model_type
        self.use_tta = use_tta
        self.extract_instrumental = extract_instrumental
        self.disable_detailed_pbar = disable_detailed_pbar

        # Output file format.
        self.flac_file = flac_file
        self.pcm_type = pcm_type
|
|
|
|
|
def run_file(model, args, config, device, verbose=False):
    """Separate one audio file into stems and write every stem to disk.

    Args:
        model: Separation network. It is switched to eval mode and handed
            to ``demix``.
        args: Options object exposing ``input_file``, ``store_dir``,
            ``model_type``, ``extract_instrumental``, ``flac_file``,
            ``pcm_type`` and ``use_tta``.
        config: Model configuration. ``config.training.instruments``,
            ``config.training.target_instrument`` and ``config.inference``
            are read here; the whole object is also forwarded to ``demix``.
        device: Device identifier forwarded to ``demix``.
        verbose: Forwarded to ``demix`` as its ``pbar`` argument.

    Returns:
        None. A missing or undecodable input file is reported on stdout and
        the function returns early without writing anything.
    """
    start_time = time.time()
    model.eval()

    if not os.path.isfile(args.input_file):
        print("File not found: {}".format(args.input_file))
        return

    instruments = config.training.instruments.copy()
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # An empty store_dir makes os.path.join below yield a bare file name
    # (i.e. the current directory), so there is nothing to create then.
    # makedirs also copes with nested, not-yet-existing parent directories.
    if args.store_dir:
        os.makedirs(args.store_dir, exist_ok=True)

    print("Starting processing track: ", args.input_file)
    try:
        mix, sr = librosa.load(args.input_file, sr=44100, mono=False)
    except Exception as e:
        # Best effort by design: report the failure and give up on this file.
        print("Cannot read track: {}".format(args.input_file))
        print("Error message: {}".format(str(e)))
        return

    # A 1-D result is a mono track; duplicate it into two identical channels.
    if mix.ndim == 1:
        mix = np.stack([mix, mix], axis=0)

    mix_orig = mix.copy()

    normalize = (
        "normalize" in config.inference
        and config.inference["normalize"] is True
    )
    if normalize:
        mono = mix.mean(0)
        mean = mono.mean()
        std = mono.std()
        if std == 0:
            # Constant (e.g. fully silent) input: dividing by zero would
            # fill the mixture with NaN/inf, so fall back to a unit scale.
            std = 1.0
        mix = (mix - mean) / std

    # Each entry pairs an input variant with the function that maps a stem
    # predicted from that variant back to the original orientation.
    if args.use_tta:
        augmentations = [
            (mix.copy(), lambda stem: stem),
            # Reversed first axis (channel order) ...
            (mix[::-1].copy(), lambda stem: stem[::-1].copy()),
            # ... and inverted polarity.
            (-1.0 * mix.copy(), lambda stem: -1.0 * stem),
        ]
    else:
        augmentations = [(mix.copy(), lambda stem: stem)]

    # Accumulate predictions as they arrive instead of keeping all of them.
    waveforms = None
    for augmented_mix, undo in augmentations:
        prediction = demix(
            config, model, augmented_mix, device, pbar=verbose, model_type=args.model_type
        )
        if waveforms is None:
            waveforms = prediction
            continue
        for name in prediction:
            waveforms[name] += undo(prediction[name])
    for name in waveforms:
        waveforms[name] = waveforms[name] / len(augmentations)

    # Bring every stem back to the scale of the input *before* deriving the
    # instrumental: mix_orig is un-normalized, so subtracting a stem that is
    # still in normalized space (and rescaling afterwards) mixes two domains.
    if normalize:
        for name in waveforms:
            waveforms[name] = waveforms[name] * std + mean

    if args.extract_instrumental:
        instr = "vocals" if "vocals" in instruments else instruments[0]
        waveforms["instrumental"] = mix_orig - waveforms[instr]
        # Avoid writing the same file twice when the list already has it.
        if "instrumental" not in instruments:
            instruments.append("instrumental")

    file_name, _ = os.path.splitext(os.path.basename(args.input_file))
    for instr in instruments:
        # Transposed so the array layout matches what sf.write is given
        # in the original code (assumes stems are channels-first — TODO confirm).
        estimates = waveforms[instr].T
        if args.flac_file:
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.flac")
            subtype = "PCM_16" if args.pcm_type == "PCM_16" else "PCM_24"
            sf.write(output_file, estimates, sr, subtype=subtype)
        else:
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.wav")
            sf.write(output_file, estimates, sr, subtype="FLOAT")

    # NOTE(review): the reason for this one-second pause is not evident from
    # this file; it is kept as-is and is included in the elapsed time below.
    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser used by ``proc_file``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        default="mdx23c",
        help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer, scnet, scnet_unofficial, segm_models, swin_upernet, torchseg",
    )
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument(
        "--start_check_point",
        type=str,
        default="",
        help="Initial checkpoint to valid weights",
    )
    parser.add_argument(
        "--input_file", type=str, help="path to the audio file to process"
    )
    parser.add_argument(
        "--store_dir", default="", type=str, help="path to store results as wav file"
    )
    parser.add_argument(
        "--device_ids", nargs="+", type=int, default=0, help="list of gpu ids"
    )
    parser.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="invert vocals to get instrumental if provided",
    )
    parser.add_argument(
        "--disable_detailed_pbar",
        action="store_true",
        help="disable detailed progress bar",
    )
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force the use of CPU even if CUDA is available",
    )
    parser.add_argument(
        "--flac_file", action="store_true", help="Output flac file instead of wav"
    )
    parser.add_argument(
        "--pcm_type",
        type=str,
        choices=["PCM_16", "PCM_24"],
        default="PCM_24",
        help="PCM type for FLAC files (PCM_16 or PCM_24)",
    )
    parser.add_argument(
        "--use_tta",
        action="store_true",
        help="Flag adds test time augmentation during inference (polarity and channel inverse). While this triples the runtime, it reduces noise and slightly improves prediction quality.",
    )
    return parser


def _select_device(args):
    """Return the torch device string implied by the options and the available backends."""
    if args.force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        print("CUDA is available, use --force_cpu to disable it.")
        # --device_ids is a list when given on the command line and the
        # plain int default (0) otherwise.
        if isinstance(args.device_ids, list):
            return f"cuda:{args.device_ids[0]}"
        return f"cuda:{args.device_ids}"
    # Guarded lookup so a torch build without the ``mps`` backend module
    # falls through to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def _load_checkpoint(model, checkpoint_path, model_type, device):
    """Load the weights stored at ``checkpoint_path`` into ``model``."""
    if model_type == "htdemucs":
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only use checkpoints from a
        # trusted source with this model type.
        state_dict = torch.load(
            checkpoint_path, map_location=device, weights_only=False
        )
        # The weights may be nested under a "state" key.
        if "state" in state_dict:
            state_dict = state_dict["state"]
    else:
        state_dict = torch.load(
            checkpoint_path, map_location=device, weights_only=True
        )
    model.load_state_dict(state_dict)


def proc_file(args):
    """Parse options, build the model and run separation on one input file.

    Args:
        args: List of command-line style strings to parse, or ``None`` to
            let ``argparse`` read them from ``sys.argv``.

    Returns:
        None. Results are written by ``run_file``.

    Raises:
        SystemExit: Raised by ``argparse`` when the options are invalid.
    """
    # parse_args(None) reads sys.argv, so both call styles share one path.
    args = _build_arg_parser().parse_args(args)

    device = _select_device(args)
    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != "":
        print("Start from checkpoint: {}".format(args.start_check_point))
        _load_checkpoint(model, args.start_check_point, args.model_type, device)
    print("Instruments: {}".format(config.training.instruments))

    # Spread the model over several GPUs only when more than one id was given.
    if (
        isinstance(args.device_ids, list)
        and len(args.device_ids) > 1
        and not args.force_cpu
    ):
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    # Honour --disable_detailed_pbar instead of always showing the bar.
    run_file(model, args, config, device, verbose=not args.disable_detailed_pbar)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: with ``None`` the options are taken from the
    # command line by the argument parser inside ``proc_file``.
    proc_file(args=None)
|
|