# coding: utf-8
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
import argparse
import time
import librosa
from tqdm import tqdm
import sys
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
# Add the script's directory to sys.path so the utils module imports correctly,
# even when running under an embedded Python distribution.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from utils import demix, get_model_from_config
import warnings
warnings.filterwarnings("ignore")


class Args:
    def __init__(
        self,
        input_file,
        store_dir,
        model_type,
        extract_instrumental,
        disable_detailed_pbar,
        flac_file,
        pcm_type,
        use_tta,
    ):
        self.input_file = input_file
        self.model_type = model_type
        self.store_dir = store_dir
        self.extract_instrumental = extract_instrumental
        self.disable_detailed_pbar = disable_detailed_pbar
        self.flac_file = flac_file
        self.pcm_type = pcm_type
        self.use_tta = use_tta
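
# Note: Args is not referenced elsewhere in this file; it mirrors the CLI flags
# defined below so external callers (e.g. a GUI wrapper) can build a namespace
# and pass it to run_file() directly. A minimal sketch with illustrative values:
#   args = Args("mixture.wav", "results", "mdx23c",
#               extract_instrumental=True, disable_detailed_pbar=False,
#               flac_file=False, pcm_type="PCM_24", use_tta=False)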


def run_file(model, args, config, device, verbose=False):
    start_time = time.time()
    model.eval()

    if not os.path.isfile(args.input_file):
        print("File not found: {}".format(args.input_file))
        return

    instruments = config.training.instruments.copy()
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    if not os.path.isdir(args.store_dir):
        os.mkdir(args.store_dir)

    print("Starting processing track: ", args.input_file)
    try:
        mix, sr = librosa.load(args.input_file, sr=44100, mono=False)
    except Exception as e:
        print("Cannot read track: {}".format(args.input_file))
        print("Error message: {}".format(str(e)))
        return

    # Convert mono to stereo if needed
    if len(mix.shape) == 1:
        mix = np.stack([mix, mix], axis=0)

    mix_orig = mix.copy()
    if "normalize" in config.inference:
        if config.inference["normalize"] is True:
            mono = mix.mean(0)
            mean = mono.mean()
            std = mono.std()
            mix = (mix - mean) / std
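    # mean and std come from the mono downmix of this track and are reused
    # further below to de-normalize the separated stems before writing them.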

    if args.use_tta:
        # orig, channel inverse, polarity inverse
        track_proc_list = [mix.copy(), mix[::-1].copy(), -1.0 * mix.copy()]
    else:
        track_proc_list = [mix.copy()]
    full_result = []
    for mix in track_proc_list:
        waveforms = demix(
            config, model, mix, device, pbar=verbose, model_type=args.model_type
        )
        full_result.append(waveforms)

    # Undo each augmentation (channel inverse for i == 1, polarity inverse for
    # i == 2) and average all results into a single dict
    waveforms = full_result[0]
    for i in range(1, len(full_result)):
        d = full_result[i]
        for el in d:
            if i == 2:
                waveforms[el] += -1.0 * d[el]
            elif i == 1:
                waveforms[el] += d[el][::-1].copy()
            else:
                waveforms[el] += d[el]
    for el in waveforms:
        waveforms[el] = waveforms[el] / len(full_result)

    # Optionally add an extra 'instrumental' stem to the instruments list
    if args.extract_instrumental:
        instr = "vocals" if "vocals" in instruments else instruments[0]
        instruments.append("instrumental")
        # 'instrumental' is the original mix minus 'vocals' (or minus the first
        # stem in the list if 'vocals' is absent)
        waveforms["instrumental"] = mix_orig - waveforms[instr]

    for instr in instruments:
        estimates = waveforms[instr].T
        if "normalize" in config.inference:
            if config.inference["normalize"] is True:
                estimates = estimates * std + mean
        file_name, _ = os.path.splitext(os.path.basename(args.input_file))
        if args.flac_file:
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.flac")
            subtype = "PCM_16" if args.pcm_type == "PCM_16" else "PCM_24"
            sf.write(output_file, estimates, sr, subtype=subtype)
        else:
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.wav")
            sf.write(output_file, estimates, sr, subtype="FLOAT")

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))


def proc_file(args):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        default="mdx23c",
        help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer, scnet, scnet_unofficial, segm_models, swin_upernet, torchseg",
    )
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument(
        "--start_check_point",
        type=str,
        default="",
        help="initial checkpoint with valid model weights",
    )
    parser.add_argument(
        "--input_file", type=str, help="path to the mixture audio file to process"
    )
    parser.add_argument(
        "--store_dir", default="", type=str, help="path to the folder where result files are stored"
    )
    parser.add_argument(
        "--device_ids", nargs="+", type=int, default=0, help="list of gpu ids"
    )
    parser.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="invert vocals to get instrumental if provided",
    )
    parser.add_argument(
        "--disable_detailed_pbar",
        action="store_true",
        help="disable detailed progress bar",
    )
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force the use of CPU even if CUDA is available",
    )
    parser.add_argument(
        "--flac_file", action="store_true", help="Output flac file instead of wav"
    )
    parser.add_argument(
        "--pcm_type",
        type=str,
        choices=["PCM_16", "PCM_24"],
        default="PCM_24",
        help="PCM type for FLAC files (PCM_16 or PCM_24)",
    )
    parser.add_argument(
        "--use_tta",
        action="store_true",
        help="Flag adds test time augmentation during inference (polarity and channel inverse). While this triples the runtime, it reduces noise and slightly improves prediction quality.",
    )

    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    device = "cpu"
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print("CUDA is available, use --force_cpu to disable it.")
        device = "cuda"
        device = (
            f"cuda:{args.device_ids[0]}"
            if type(args.device_ids) == list
            else f"cuda:{args.device_ids}"
        )
    elif torch.backends.mps.is_available():
        device = "mps"
    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != "":
        print("Start from checkpoint: {}".format(args.start_check_point))
        if args.model_type == "htdemucs":
            state_dict = torch.load(
                args.start_check_point, map_location=device, weights_only=False
            )
            # Fix for htdemucs pretrained models
            if "state" in state_dict:
                state_dict = state_dict["state"]
        else:
            state_dict = torch.load(
                args.start_check_point, map_location=device, weights_only=True
            )
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # in case multiple CUDA GPUs are used and --device_ids arg is passed
    if (
        type(args.device_ids) == list
        and len(args.device_ids) > 1
        and not args.force_cpu
    ):
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_file(model, args, config, device, verbose=True)


if __name__ == "__main__":
    proc_file(None)
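
# Example usage (a sketch; the script name, config and checkpoint paths are
# illustrative and depend on your setup):
#
#   python inference_file.py --model_type mdx23c \
#       --config_path configs/config.yaml \
#       --start_check_point model.ckpt \
#       --input_file mixture.wav \
#       --store_dir results \
#       --extract_instrumental
#
# proc_file() also accepts an argv-style list for programmatic calls, e.g.:
#
#   proc_file(["--model_type", "mdx23c", "--config_path", "configs/config.yaml",
#              "--start_check_point", "model.ckpt", "--input_file", "mixture.wav",
#              "--store_dir", "results"])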