|
import argparse |
|
import os |
|
|
|
import librosa |
|
import numpy as np |
|
import soundfile as sf |
|
import torch |
|
|
|
from lib import dataset |
|
from lib import nets |
|
from lib import spec_utils |
|
|
|
import inference |
|
|
|
|
|
def main(): |
|
p = argparse.ArgumentParser() |
|
p.add_argument('--gpu', '-g', type=int, default=-1) |
|
p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth') |
|
p.add_argument('--mixtures', '-m', required=True) |
|
p.add_argument('--instruments', '-i', required=True) |
|
p.add_argument('--sr', '-r', type=int, default=44100) |
|
p.add_argument('--n_fft', '-f', type=int, default=2048) |
|
p.add_argument('--hop_length', '-H', type=int, default=1024) |
|
p.add_argument('--batchsize', '-B', type=int, default=4) |
|
p.add_argument('--cropsize', '-c', type=int, default=256) |
|
p.add_argument('--postprocess', '-p', action='store_true') |
|
args = p.parse_args() |
|
|
|
print('loading model...', end=' ') |
|
device = torch.device('cpu') |
|
model = nets.CascadedNet(args.n_fft, args.hop_length) |
|
model.load_state_dict(torch.load(args.pretrained_model, map_location=device)) |
|
if torch.cuda.is_available() and args.gpu >= 0: |
|
device = torch.device('cuda:{}'.format(args.gpu)) |
|
model.to(device) |
|
print('done') |
|
|
|
filelist = dataset.make_pair(args.mixtures, args.instruments) |
|
for mix_path, inst_path in filelist: |
|
|
|
|
|
|
|
|
|
|
|
basename = os.path.splitext(os.path.basename(mix_path))[0] |
|
print(basename) |
|
|
|
print('loading wave source...', end=' ') |
|
X, sr = librosa.load( |
|
mix_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast') |
|
y, sr = librosa.load( |
|
inst_path, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast') |
|
print('done') |
|
|
|
if X.ndim == 1: |
|
|
|
X = np.asarray([X, X]) |
|
|
|
print('stft of wave source...', end=' ') |
|
X, y = spec_utils.align_wave_head_and_tail(X, y, sr) |
|
X = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft) |
|
y = spec_utils.wave_to_spectrogram(y, args.hop_length, args.n_fft) |
|
print('done') |
|
|
|
sp = inference.Separator(model, device, args.batchsize, args.cropsize, args.postprocess) |
|
a_spec, _ = sp.separate_tta(X - y) |
|
|
|
print('inverse stft of pseudo instruments...', end=' ') |
|
pseudo_inst = y + a_spec |
|
print('done') |
|
|
|
sf.write('pseudo/{}_PseudoInstruments.wav'.format(basename), [0], sr) |
|
np.save('pseudo/{}_PseudoInstruments.npy'.format(basename), pseudo_inst) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|