feat!: add inference and utils files for model features and sound separation
- inference.py +187 -0
- utils.py +194 -0
inference.py
ADDED
@@ -0,0 +1,187 @@
import argparse
import time
import librosa
from tqdm import tqdm
import sys
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from utils import demix_track, demix_track_demucs, get_model_from_config

import warnings
warnings.filterwarnings("ignore")


def run_folder(model, args, config, device, verbose=False):
    start_time = time.time()
    model.eval()
    all_mixtures_path = glob.glob(args.input_folder + '/*.*')
    all_mixtures_path.sort()
    print('Total files found: {}'.format(len(all_mixtures_path)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    if not os.path.isdir(args.store_dir):
        os.mkdir(args.store_dir)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    if args.disable_detailed_pbar:
        detailed_pbar = False
    else:
        detailed_pbar = True

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()
        if 'normalize' in config.inference:
            if config.inference['normalize'] is True:
                mono = mix.mean(0)
                mean = mono.mean()
                std = mono.std()
                mix = (mix - mean) / std

        if args.use_tta:
            # orig, channel inverse, polarity inverse
            track_proc_list = [mix.copy(), mix[::-1].copy(), -1. * mix.copy()]
        else:
            track_proc_list = [mix.copy()]

        full_result = []
        for single_track in track_proc_list:
            mixture = torch.tensor(single_track, dtype=torch.float32)
            if args.model_type == 'htdemucs':
                waveforms = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
            else:
                waveforms = demix_track(config, model, mixture, device, pbar=detailed_pbar)
            full_result.append(waveforms)

        # Average all TTA variants into a single dict, undoing each augmentation first
        waveforms = full_result[0]
        for i in range(1, len(full_result)):
            d = full_result[i]
            for el in d:
                if i == 2:
                    waveforms[el] += -1.0 * d[el]
                elif i == 1:
                    waveforms[el] += d[el][::-1].copy()
                else:
                    waveforms[el] += d[el]
        for el in waveforms:
            waveforms[el] = waveforms[el] / len(full_result)

        file_name, _ = os.path.splitext(os.path.basename(path))
        song_dir = os.path.join(args.store_dir, file_name)
        if not os.path.exists(song_dir):
            os.makedirs(song_dir)

        model_dir = os.path.join(song_dir, args.model_type)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        for instr in instruments:
            estimates = waveforms[instr].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            if args.flac_file:
                output_file = os.path.join(model_dir, f"{file_name}_{instr}.flac")
                subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
                sf.write(output_file, estimates, sr, subtype=subtype)
            else:
                output_file = os.path.join(model_dir, f"{file_name}_{instr}.wav")
                sf.write(output_file, estimates, sr, subtype='FLOAT')

        # Output "instrumental": the mixture minus 'vocals' (or minus the first stem in the list if 'vocals' is absent)
        if args.extract_instrumental:
            if 'vocals' in instruments:
                estimates = waveforms['vocals'].T
            else:
                estimates = waveforms[instruments[0]].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            if args.flac_file:
                instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.flac")
                subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
                sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype=subtype)
            else:
                instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.wav")
                sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT')

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))


def proc_folder_direct(model_type, config_path, start_check_point, input_folder, store_dir, device_ids=[0], extract_instrumental=False, disable_detailed_pbar=False, force_cpu=False, flac_file=False, pcm_type='PCM_24', use_tta=False):
    device = "cpu"
    if force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        device = f'cuda:{device_ids}' if type(device_ids) == int else f'cuda:{device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(model_type, config_path)
    if start_check_point != '':
        print('Start from checkpoint: {}'.format(start_check_point))
        if model_type == 'htdemucs':
            state_dict = torch.load(start_check_point, map_location=device, weights_only=False)
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    if type(device_ids) != int:
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    args = argparse.Namespace(
        model_type=model_type,
        config_path=config_path,
        start_check_point=start_check_point,
        input_folder=input_folder,
        store_dir=store_dir,
        device_ids=device_ids,
        extract_instrumental=extract_instrumental,
        disable_detailed_pbar=disable_detailed_pbar,
        force_cpu=force_cpu,
        flac_file=flac_file,
        pcm_type=pcm_type,
        use_tta=use_tta
    )

    run_folder(model, args, config, device, verbose=True)
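Usage note: a minimal sketch of how the proc_folder_direct entry point might be invoked; the checkpoint path, config path, and folder names below are hypothetical placeholders, not files shipped with this commit.

from inference import proc_folder_direct

# Separate every track in ./input_songs with a (hypothetical) BS-Roformer checkpoint,
# writing one stem per instrument plus an "instrumental" to ./separated/<track>/<model_type>/
proc_folder_direct(
    model_type='bs_roformer',
    config_path='configs/config_bs_roformer.yaml',     # hypothetical path
    start_check_point='checkpoints/bs_roformer.ckpt',  # hypothetical path
    input_folder='input_songs',
    store_dir='separated',
    extract_instrumental=True,
    flac_file=True,
    pcm_type='PCM_24',
    use_tta=False,
)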
utils.py
ADDED
@@ -0,0 +1,194 @@
import time
import numpy as np
import torch
import torch.nn as nn
import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm import tqdm


def get_model_from_config(model_type, config_path):
    if model_type == 'htdemucs':
        config = OmegaConf.load(config_path)
    else:
        with open(config_path) as f:
            config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    if model_type == 'htdemucs':
        from models.demucs4ht import get_model
        model = get_model(config)
    elif model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'scnet':
        from models.scnet import SCNet
        model = SCNet(**dict(config.model))
    else:
        print('Unknown model: {}'.format(model_type))
        model = None

    return model, config


def _getWindowingArray(window_size, fade_size):
    # Linear fade-in/fade-out window used to crossfade overlapping chunks
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window


def demix_track(config, model, mix, device, pbar=False):
    C = config.audio.chunk_size
    N = config.inference.num_overlap
    fade_size = C // 10
    step = int(C // N)
    border = C - step
    batch_size = config.inference.batch_size

    length_init = mix.shape[-1]

    # Pad at both ends so the sliding window handles the track borders better
    if length_init > 2 * border and (border > 0):
        mix = nn.functional.pad(mix, (border, border), mode='reflect')

    # windowingArray crossfades at segment boundaries to mitigate clicking artifacts
    windowingArray = _getWindowingArray(C, fade_size)

    with torch.cuda.amp.autocast(enabled=config.training.use_amp):
        with torch.inference_mode():
            if config.training.target_instrument is not None:
                req_shape = (1, ) + tuple(mix.shape)
            else:
                req_shape = (len(config.training.instruments),) + tuple(mix.shape)

            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    if length > C // 2 + 1:
                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                    else:
                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    # Clone so the edge-chunk tweaks below do not mutate the shared window
                    window = windowingArray.clone()
                    if i - step == 0:  # First audio chunk, no fadein
                        window[:fade_size] = 1
                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
                        window[-fade_size:] = 1

                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
                        counter[..., start:start+l] += window[..., :l]

                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            if length_init > 2 * border and (border > 0):
                # Remove pad
                estimated_sources = estimated_sources[..., border:-border]

            if config.training.target_instrument is None:
                return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
            else:
                return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}


def demix_track_demucs(config, model, mix, device, pbar=False):
    S = len(config.training.instruments)
    C = config.training.samplerate * config.training.segment
    N = config.inference.num_overlap
    batch_size = config.inference.batch_size
    step = C // N

    with torch.cuda.amp.autocast(enabled=config.training.use_amp):
        with torch.inference_mode():
            req_shape = (S, ) + tuple(mix.shape)
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)
                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        result[..., start:start+l] += x[j][..., :l].cpu()
                        counter[..., start:start+l] += 1.
                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

    if S > 1:
        return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
    else:
        return estimated_sources


def sdr(references, estimates):
    # Compute SDR (signal-to-distortion ratio, in dB) for one song, one value per source
    delta = 1e-7  # avoid numerical errors
    num = np.sum(np.square(references), axis=(1, 2))
    den = np.sum(np.square(references - estimates), axis=(1, 2))
    num += delta
    den += delta
    return 10 * np.log10(num / den)
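Usage note: a small self-contained sketch of the sdr metric on toy arrays, to make the expected shapes concrete. References and estimates are stacked as (num_sources, channels, samples), and the function returns one SDR value in dB per source; the random data here is purely illustrative.

import numpy as np
from utils import sdr

rng = np.random.default_rng(0)
references = rng.standard_normal((2, 2, 44100))  # 2 sources, stereo, 1 s at 44.1 kHz
estimates = references + 0.1 * rng.standard_normal(references.shape)  # noisy estimates

print(sdr(references, estimates))  # array of 2 dB values; higher means a closer estimate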