Awell00 committed on
Commit d3a31f9 · verified · 1 Parent(s): b5c212a

feat!: add inference and utils files for model features and sound separation

Files changed (2)
  1. inference.py +187 -0
  2. utils.py +194 -0
inference.py ADDED
@@ -0,0 +1,187 @@
+ import argparse
+ import time
+ import librosa
+ from tqdm import tqdm
+ import sys
+ import os
+ import glob
+ import torch
+ import numpy as np
+ import soundfile as sf
+ import torch.nn as nn
+
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append(current_dir)
+ from utils import demix_track, demix_track_demucs, get_model_from_config
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ def run_folder(model, args, config, device, verbose=False):
+     start_time = time.time()
+     model.eval()
+     all_mixtures_path = glob.glob(args.input_folder + '/*.*')
+     all_mixtures_path.sort()
+     print('Total files found: {}'.format(len(all_mixtures_path)))
+
+     instruments = config.training.instruments
+     if config.training.target_instrument is not None:
+         instruments = [config.training.target_instrument]
+
+     if not os.path.isdir(args.store_dir):
+         os.mkdir(args.store_dir)
+
+     if not verbose:
+         all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")
+
+     if args.disable_detailed_pbar:
+         detailed_pbar = False
+     else:
+         detailed_pbar = True
+
+     for path in all_mixtures_path:
+         print("Starting processing track: ", path)
+         if not verbose:
+             all_mixtures_path.set_postfix({'track': os.path.basename(path)})
+         try:
+             mix, sr = librosa.load(path, sr=44100, mono=False)
+         except Exception as e:
+             print('Cannot read track: {}'.format(path))
+             print('Error message: {}'.format(str(e)))
+             continue
+
+         # Convert mono to stereo if needed
+         if len(mix.shape) == 1:
+             mix = np.stack([mix, mix], axis=0)
+
+         mix_orig = mix.copy()
+         if 'normalize' in config.inference:
+             if config.inference['normalize'] is True:
+                 mono = mix.mean(0)
+                 mean = mono.mean()
+                 std = mono.std()
+                 mix = (mix - mean) / std
+
+         if args.use_tta:
+             # orig, channel inverse, polarity inverse
+             track_proc_list = [mix.copy(), mix[::-1].copy(), -1. * mix.copy()]
+         else:
+             track_proc_list = [mix.copy()]
+
+         full_result = []
+         for single_track in track_proc_list:
+             mixture = torch.tensor(single_track, dtype=torch.float32)
+             if args.model_type == 'htdemucs':
+                 waveforms = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
+             else:
+                 waveforms = demix_track(config, model, mixture, device, pbar=detailed_pbar)
+             full_result.append(waveforms)
+
+         # Average the TTA variants back into a single dict
+         waveforms = full_result[0]
+         for i in range(1, len(full_result)):
+             d = full_result[i]
+             for el in d:
+                 if i == 2:
+                     waveforms[el] += -1.0 * d[el]
+                 elif i == 1:
+                     waveforms[el] += d[el][::-1].copy()
+                 else:
+                     waveforms[el] += d[el]
+         for el in waveforms:
+             waveforms[el] = waveforms[el] / len(full_result)
+
+         file_name, _ = os.path.splitext(os.path.basename(path))
+         song_dir = os.path.join(args.store_dir, file_name)
+         if not os.path.exists(song_dir):
+             os.makedirs(song_dir)
+
+         model_dir = os.path.join(song_dir, args.model_type)
+         if not os.path.exists(model_dir):
+             os.makedirs(model_dir)
+
+         for instr in instruments:
+             estimates = waveforms[instr].T
+             if 'normalize' in config.inference:
+                 if config.inference['normalize'] is True:
+                     estimates = estimates * std + mean
+             if args.flac_file:
+                 output_file = os.path.join(model_dir, f"{file_name}_{instr}.flac")
+                 subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
+                 sf.write(output_file, estimates, sr, subtype=subtype)
+             else:
+                 output_file = os.path.join(model_dir, f"{file_name}_{instr}.wav")
+                 sf.write(output_file, estimates, sr, subtype='FLOAT')
+
+         # Write an "instrumental" stem: the mixture minus 'vocals' (or minus the first stem if 'vocals' is absent)
+         if args.extract_instrumental:
+             if 'vocals' in instruments:
+                 estimates = waveforms['vocals'].T
+             else:
+                 estimates = waveforms[instruments[0]].T
+             if 'normalize' in config.inference:
+                 if config.inference['normalize'] is True:
+                     estimates = estimates * std + mean
+             if args.flac_file:
+                 instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.flac")
+                 subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
+                 sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype=subtype)
+             else:
+                 instrum_file_name = os.path.join(model_dir, f"{file_name}_instrumental.wav")
+                 sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT')
+
+     time.sleep(1)
+     print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
+
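A toy sketch (NumPy only, synthetic data) of the test-time-augmentation averaging in run_folder above: the channel-swapped and polarity-inverted passes are undone before averaging, so an ideal separator would return exactly the plain-pass result.

import numpy as np

mix = np.random.randn(2, 8).astype(np.float32)               # stereo mixture, 8 samples
variants = [mix.copy(), mix[::-1].copy(), -1. * mix.copy()]  # orig, channel inverse, polarity inverse
outputs = [v.copy() for v in variants]                       # pretend the model returned its input unchanged

avg = outputs[0].copy()
avg += outputs[1][::-1]       # undo the channel swap
avg += -1.0 * outputs[2]      # undo the polarity inversion
avg /= len(outputs)
print(np.allclose(avg, mix))  # True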
+ def proc_folder_direct(model_type, config_path, start_check_point, input_folder, store_dir, device_ids=[0], extract_instrumental=False, disable_detailed_pbar=False, force_cpu=False, flac_file=False, pcm_type='PCM_24', use_tta=False):
+     device = "cpu"
+     if force_cpu:
+         device = "cpu"
+     elif torch.cuda.is_available():
+         print('CUDA is available, use --force_cpu to disable it.')
+         device = "cuda"
+         device = f'cuda:{device_ids}' if type(device_ids) == int else f'cuda:{device_ids[0]}'
+     elif torch.backends.mps.is_available():
+         device = "mps"
+
+     print("Using device: ", device)
+
+     model_load_start_time = time.time()
+     torch.backends.cudnn.benchmark = True
+
+     model, config = get_model_from_config(model_type, config_path)
+     if start_check_point != '':
+         print('Start from checkpoint: {}'.format(start_check_point))
+         if model_type == 'htdemucs':
+             state_dict = torch.load(start_check_point, map_location=device, weights_only=False)
+             if 'state' in state_dict:
+                 state_dict = state_dict['state']
+         else:
+             state_dict = torch.load(start_check_point, map_location=device, weights_only=True)
+         model.load_state_dict(state_dict)
+     print("Instruments: {}".format(config.training.instruments))
+
+     if type(device_ids) != int:
+         model = nn.DataParallel(model, device_ids=device_ids)
+
+     model = model.to(device)
+
+     print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))
+
+     args = argparse.Namespace(
+         model_type=model_type,
+         config_path=config_path,
+         start_check_point=start_check_point,
+         input_folder=input_folder,
+         store_dir=store_dir,
+         device_ids=device_ids,
+         extract_instrumental=extract_instrumental,
+         disable_detailed_pbar=disable_detailed_pbar,
+         force_cpu=force_cpu,
+         flac_file=flac_file,
+         pcm_type=pcm_type,
+         use_tta=use_tta
+     )
+
+     run_folder(model, args, config, device, verbose=True)
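For reference, a minimal sketch of calling proc_folder_direct from another script; the model type is one of the values handled in utils.get_model_from_config, and all paths below are hypothetical placeholders, not files included in this commit.

from inference import proc_folder_direct

proc_folder_direct(
    model_type='mel_band_roformer',              # or 'htdemucs', 'bs_roformer', 'scnet'
    config_path='configs/model_config.yaml',     # placeholder config path
    start_check_point='checkpoints/model.ckpt',  # placeholder checkpoint; '' skips loading weights
    input_folder='input_tracks',                 # folder with mixtures to separate
    store_dir='separated',                       # output folder, one subfolder per track
    extract_instrumental=True,                   # also write mixture minus the vocals estimate
    flac_file=False,                             # write 32-bit float WAV instead of FLAC
    use_tta=False                                # enable the test-time-augmentation averaging
)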
utils.py ADDED
@@ -0,0 +1,194 @@
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import yaml
+ from ml_collections import ConfigDict
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+
+ def get_model_from_config(model_type, config_path):
+     with open(config_path) as f:
+         if model_type == 'htdemucs':
+             config = OmegaConf.load(config_path)
+         else:
+             config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+     if model_type == 'htdemucs':
+         from models.demucs4ht import get_model
+         model = get_model(config)
+     elif model_type == 'mel_band_roformer':
+         from models.bs_roformer import MelBandRoformer
+         model = MelBandRoformer(
+             **dict(config.model)
+         )
+     elif model_type == 'bs_roformer':
+         from models.bs_roformer import BSRoformer
+         model = BSRoformer(
+             **dict(config.model)
+         )
+     elif model_type == 'scnet':
+         from models.scnet import SCNet
+         model = SCNet(
+             **dict(config.model)
+         )
+     else:
+         print('Unknown model: {}'.format(model_type))
+         model = None
+
+     return model, config
+
+ def _getWindowingArray(window_size, fade_size):
+     fadein = torch.linspace(0, 1, fade_size)
+     fadeout = torch.linspace(1, 0, fade_size)
+     window = torch.ones(window_size)
+     window[-fade_size:] *= fadeout
+     window[:fade_size] *= fadein
+     return window
+
+
+ def demix_track(config, model, mix, device, pbar=False):
+     C = config.audio.chunk_size
+     N = config.inference.num_overlap
+     fade_size = C // 10
+     step = int(C // N)
+     border = C - step
+     batch_size = config.inference.batch_size
+
+     length_init = mix.shape[-1]
+
+     # Pad the beginning and end so the sliding window handles the track edges better
+     if length_init > 2 * border and (border > 0):
+         mix = nn.functional.pad(mix, (border, border), mode='reflect')
+
+     # windowingArray crossfades at segment boundaries to mitigate clicking artifacts
+     windowingArray = _getWindowingArray(C, fade_size)
+
+     with torch.cuda.amp.autocast(enabled=config.training.use_amp):
+         with torch.inference_mode():
+             if config.training.target_instrument is not None:
+                 req_shape = (1, ) + tuple(mix.shape)
+             else:
+                 req_shape = (len(config.training.instruments),) + tuple(mix.shape)
+
+             result = torch.zeros(req_shape, dtype=torch.float32)
+             counter = torch.zeros(req_shape, dtype=torch.float32)
+             i = 0
+             batch_data = []
+             batch_locations = []
+             progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None
+
+             while i < mix.shape[1]:
+                 # print(i, i + C, mix.shape[1])
+                 part = mix[:, i:i + C].to(device)
+                 length = part.shape[-1]
+                 if length < C:
+                     if length > C // 2 + 1:
+                         part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
+                     else:
+                         part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+                 batch_data.append(part)
+                 batch_locations.append((i, length))
+                 i += step
+
+                 if len(batch_data) >= batch_size or (i >= mix.shape[1]):
+                     arr = torch.stack(batch_data, dim=0)
+                     x = model(arr)
+
+                     window = windowingArray
+                     if i - step == 0:  # First audio chunk, no fadein
+                         window[:fade_size] = 1
+                     elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
+                         window[-fade_size:] = 1
+
+                     for j in range(len(batch_locations)):
+                         start, l = batch_locations[j]
+                         result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
+                         counter[..., start:start+l] += window[..., :l]
+
+                     batch_data = []
+                     batch_locations = []
+
+                 if progress_bar:
+                     progress_bar.update(step)
+
+             if progress_bar:
+                 progress_bar.close()
+
+             estimated_sources = result / counter
+             estimated_sources = estimated_sources.cpu().numpy()
+             np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+             if length_init > 2 * border and (border > 0):
+                 # Remove pad
+                 estimated_sources = estimated_sources[..., border:-border]
+
+     if config.training.target_instrument is None:
+         return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
+     else:
+         return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}
+
+
+ def demix_track_demucs(config, model, mix, device, pbar=False):
+     S = len(config.training.instruments)
+     C = config.training.samplerate * config.training.segment
+     N = config.inference.num_overlap
+     batch_size = config.inference.batch_size
+     step = C // N
+     # print(S, C, N, step, mix.shape, mix.device)
+
+     with torch.cuda.amp.autocast(enabled=config.training.use_amp):
+         with torch.inference_mode():
+             req_shape = (S, ) + tuple(mix.shape)
+             result = torch.zeros(req_shape, dtype=torch.float32)
+             counter = torch.zeros(req_shape, dtype=torch.float32)
+             i = 0
+             batch_data = []
+             batch_locations = []
+             progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None
+
+             while i < mix.shape[1]:
+                 # print(i, i + C, mix.shape[1])
+                 part = mix[:, i:i + C].to(device)
+                 length = part.shape[-1]
+                 if length < C:
+                     part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
+                 batch_data.append(part)
+                 batch_locations.append((i, length))
+                 i += step
+
+                 if len(batch_data) >= batch_size or (i >= mix.shape[1]):
+                     arr = torch.stack(batch_data, dim=0)
+                     x = model(arr)
+                     for j in range(len(batch_locations)):
+                         start, l = batch_locations[j]
+                         result[..., start:start+l] += x[j][..., :l].cpu()
+                         counter[..., start:start+l] += 1.
+                     batch_data = []
+                     batch_locations = []
+
+                 if progress_bar:
+                     progress_bar.update(step)
+
+             if progress_bar:
+                 progress_bar.close()
+
+             estimated_sources = result / counter
+             estimated_sources = estimated_sources.cpu().numpy()
+             np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+     if S > 1:
+         return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
+     else:
+         return estimated_sources
+
+
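As a quick illustration of the crossfade window used in demix_track (toy sizes, not the real chunk_size), _getWindowingArray returns a flat window with linear fades at both ends; counter accumulates the same weights, so dividing result by counter renormalizes the overlapped regions.

from utils import _getWindowingArray

print(_getWindowingArray(window_size=12, fade_size=3).tolist())
# [0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0]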
+ def sdr(references, estimates):
+     # compute SDR for one song
+     delta = 1e-7  # avoid numerical errors
+     num = np.sum(np.square(references), axis=(1, 2))
+     den = np.sum(np.square(references - estimates), axis=(1, 2))
+     num += delta
+     den += delta
+     return 10 * np.log10(num / den)
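The sdr helper keeps the source index on axis 0 and sums over the other two axes, returning one SDR value in dB per source. A small self-contained check with synthetic arrays (not data from this repo):

import numpy as np
from utils import sdr

rng = np.random.default_rng(0)
references = rng.standard_normal((2, 44100, 2))   # 2 sources, ~1 s of stereo audio
estimates = references + 0.1 * rng.standard_normal(references.shape)

print(sdr(references, estimates))  # roughly [20. 20.], i.e. 10 * log10(1 / 0.1**2) per source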