Awell00 committed
Commit 0672c99 · 1 Parent(s): d9d3a0c

feat!: add configuration for inference process
models/bs_roformer/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from models.bs_roformer.bs_roformer import BSRoformer
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
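The package __init__ simply re-exports the two model classes, so inference code can import them from the package root rather than the full module path. A minimal sketch (nothing beyond the two imports above is assumed):

from models.bs_roformer import BSRoformer, MelBandRoformer
from models.bs_roformer.bs_roformer import BSRoformer as BSRoformerDirect

assert BSRoformer is BSRoformerDirect  # the re-export and the direct import resolve to the same class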
models/bs_roformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (347 Bytes)

models/bs_roformer/__pycache__/attend.cpython-311.pyc ADDED
Binary file (6.14 kB)

models/bs_roformer/__pycache__/bs_roformer.cpython-311.pyc ADDED
Binary file (25.4 kB)

models/bs_roformer/__pycache__/mel_band_roformer.cpython-311.pyc ADDED
Binary file (26.9 kB)

models/bs_roformer/attend.py ADDED
@@ -0,0 +1,126 @@
+ from functools import wraps
+ from packaging import version
+ from collections import namedtuple
+
+ import os
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+
+ from einops import rearrange, reduce
+
+ # constants
+
+ FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def once(fn):
+     called = False
+     @wraps(fn)
+     def inner(x):
+         nonlocal called
+         if called:
+             return
+         called = True
+         return fn(x)
+     return inner
+
+ print_once = once(print)
+
+ # main class
+
+ class Attend(nn.Module):
+     def __init__(
+         self,
+         dropout = 0.,
+         flash = False,
+         scale = None
+     ):
+         super().__init__()
+         self.scale = scale
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+
+         self.flash = flash
+         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+         # determine efficient attention configs for cuda and cpu
+
+         self.cpu_config = FlashAttentionConfig(True, True, True)
+         self.cuda_config = None
+
+         if not torch.cuda.is_available() or not flash:
+             return
+
+         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+         device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
+
+         if device_version >= version.parse('8.0'):
+             if os.name == 'nt':
+                 print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
+                 self.cuda_config = FlashAttentionConfig(False, True, True)
+             else:
+                 print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
+                 self.cuda_config = FlashAttentionConfig(True, False, False)
+         else:
+             print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
+             self.cuda_config = FlashAttentionConfig(False, True, True)
+
+     def flash_attn(self, q, k, v):
+         _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+         if exists(self.scale):
+             default_scale = q.shape[-1] ** -0.5
+             q = q * (self.scale / default_scale)
+
+         # Check if there is a compatible device for flash attention
+
+         config = self.cuda_config if is_cuda else self.cpu_config
+
+         # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+         with torch.backends.cuda.sdp_kernel(**config._asdict()):
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 dropout_p = self.dropout if self.training else 0.
+             )
+
+         return out
+
+     def forward(self, q, k, v):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+         scale = default(self.scale, q.shape[-1] ** -0.5)
+
+         if self.flash:
+             return self.flash_attn(q, k, v)
+
+         # similarity
+
+         sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+         # attention
+
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         # aggregate values
+
+         out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+         return out
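Attend switches between two equivalent attention paths: F.scaled_dot_product_attention under a torch.backends.cuda.sdp_kernel context when flash=True (PyTorch 2.0+), and a plain einsum/softmax implementation otherwise. A small sketch comparing the two on random tensors (shapes are illustrative, not taken from this commit):

import torch
from models.bs_roformer.attend import Attend

q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq_len, dim_head)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

eager = Attend(flash=False).eval()(q, k, v)   # einsum similarity -> softmax -> einsum aggregation
sdpa = Attend(flash=True).eval()(q, k, v)     # delegates to F.scaled_dot_product_attention

# with dropout at 0 the two paths should agree to within float32 tolerance
assert torch.allclose(eager, sdpa, atol=1e-4)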
models/bs_roformer/bs_roformer.py ADDED
@@ -0,0 +1,590 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack
16
+ from einops.layers.torch import Rearrange
17
+
18
+ # helper functions
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(v, d):
25
+ return v if exists(v) else d
26
+
27
+
28
+ def pack_one(t, pattern):
29
+ return pack([t], pattern)
30
+
31
+
32
+ def unpack_one(t, ps, pattern):
33
+ return unpack(t, ps, pattern)[0]
34
+
35
+
36
+ # norm
37
+
38
+ def l2norm(t):
39
+ return F.normalize(t, dim = -1, p = 2)
40
+
41
+
42
+ class RMSNorm(Module):
43
+ def __init__(self, dim):
44
+ super().__init__()
45
+ self.scale = dim ** 0.5
46
+ self.gamma = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x):
49
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
50
+
51
+
52
+ # attention
53
+
54
+ class FeedForward(Module):
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ mult=4,
59
+ dropout=0.
60
+ ):
61
+ super().__init__()
62
+ dim_inner = int(dim * mult)
63
+ self.net = nn.Sequential(
64
+ RMSNorm(dim),
65
+ nn.Linear(dim, dim_inner),
66
+ nn.GELU(),
67
+ nn.Dropout(dropout),
68
+ nn.Linear(dim_inner, dim),
69
+ nn.Dropout(dropout)
70
+ )
71
+
72
+ def forward(self, x):
73
+ return self.net(x)
74
+
75
+
76
+ class Attention(Module):
77
+ def __init__(
78
+ self,
79
+ dim,
80
+ heads=8,
81
+ dim_head=64,
82
+ dropout=0.,
83
+ rotary_embed=None,
84
+ flash=True
85
+ ):
86
+ super().__init__()
87
+ self.heads = heads
88
+ self.scale = dim_head ** -0.5
89
+ dim_inner = heads * dim_head
90
+
91
+ self.rotary_embed = rotary_embed
92
+
93
+ self.attend = Attend(flash=flash, dropout=dropout)
94
+
95
+ self.norm = RMSNorm(dim)
96
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
+
98
+ self.to_gates = nn.Linear(dim, heads)
99
+
100
+ self.to_out = nn.Sequential(
101
+ nn.Linear(dim_inner, dim, bias=False),
102
+ nn.Dropout(dropout)
103
+ )
104
+
105
+ def forward(self, x):
106
+ x = self.norm(x)
107
+
108
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
109
+
110
+ if exists(self.rotary_embed):
111
+ q = self.rotary_embed.rotate_queries_or_keys(q)
112
+ k = self.rotary_embed.rotate_queries_or_keys(k)
113
+
114
+ out = self.attend(q, k, v)
115
+
116
+ gates = self.to_gates(x)
117
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
118
+
119
+ out = rearrange(out, 'b h n d -> b n (h d)')
120
+ return self.to_out(out)
121
+
122
+
123
+ class LinearAttention(Module):
124
+ """
125
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
126
+ """
127
+
128
+ @beartype
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim,
133
+ dim_head=32,
134
+ heads=8,
135
+ scale=8,
136
+ flash=False,
137
+ dropout=0.
138
+ ):
139
+ super().__init__()
140
+ dim_inner = dim_head * heads
141
+ self.norm = RMSNorm(dim)
142
+
143
+ self.to_qkv = nn.Sequential(
144
+ nn.Linear(dim, dim_inner * 3, bias=False),
145
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
146
+ )
147
+
148
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
+
150
+ self.attend = Attend(
151
+ scale=scale,
152
+ dropout=dropout,
153
+ flash=flash
154
+ )
155
+
156
+ self.to_out = nn.Sequential(
157
+ Rearrange('b h d n -> b n (h d)'),
158
+ nn.Linear(dim_inner, dim, bias=False)
159
+ )
160
+
161
+ def forward(
162
+ self,
163
+ x
164
+ ):
165
+ x = self.norm(x)
166
+
167
+ q, k, v = self.to_qkv(x)
168
+
169
+ q, k = map(l2norm, (q, k))
170
+ q = q * self.temperature.exp()
171
+
172
+ out = self.attend(q, k, v)
173
+
174
+ return self.to_out(out)
175
+
176
+
177
+ class Transformer(Module):
178
+ def __init__(
179
+ self,
180
+ *,
181
+ dim,
182
+ depth,
183
+ dim_head=64,
184
+ heads=8,
185
+ attn_dropout=0.,
186
+ ff_dropout=0.,
187
+ ff_mult=4,
188
+ norm_output=True,
189
+ rotary_embed=None,
190
+ flash_attn=True,
191
+ linear_attn=False
192
+ ):
193
+ super().__init__()
194
+ self.layers = ModuleList([])
195
+
196
+ for _ in range(depth):
197
+ if linear_attn:
198
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
199
+ else:
200
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
201
+ rotary_embed=rotary_embed, flash=flash_attn)
202
+
203
+ self.layers.append(ModuleList([
204
+ attn,
205
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
206
+ ]))
207
+
208
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
209
+
210
+ def forward(self, x):
211
+
212
+ for attn, ff in self.layers:
213
+ x = attn(x) + x
214
+ x = ff(x) + x
215
+
216
+ return self.norm(x)
217
+
218
+
219
+ # bandsplit module
220
+
221
+ class BandSplit(Module):
222
+ @beartype
223
+ def __init__(
224
+ self,
225
+ dim,
226
+ dim_inputs: Tuple[int, ...]
227
+ ):
228
+ super().__init__()
229
+ self.dim_inputs = dim_inputs
230
+ self.to_features = ModuleList([])
231
+
232
+ for dim_in in dim_inputs:
233
+ net = nn.Sequential(
234
+ RMSNorm(dim_in),
235
+ nn.Linear(dim_in, dim)
236
+ )
237
+
238
+ self.to_features.append(net)
239
+
240
+ def forward(self, x):
241
+ x = x.split(self.dim_inputs, dim=-1)
242
+
243
+ outs = []
244
+ for split_input, to_feature in zip(x, self.to_features):
245
+ split_output = to_feature(split_input)
246
+ outs.append(split_output)
247
+
248
+ return torch.stack(outs, dim=-2)
249
+
250
+
251
+ def MLP(
252
+ dim_in,
253
+ dim_out,
254
+ dim_hidden=None,
255
+ depth=1,
256
+ activation=nn.Tanh
257
+ ):
258
+ dim_hidden = default(dim_hidden, dim_in)
259
+
260
+ net = []
261
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
262
+
263
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
264
+ is_last = ind == (len(dims) - 2)
265
+
266
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
267
+
268
+ if is_last:
269
+ continue
270
+
271
+ net.append(activation())
272
+
273
+ return nn.Sequential(*net)
274
+
275
+
276
+ class MaskEstimator(Module):
277
+ @beartype
278
+ def __init__(
279
+ self,
280
+ dim,
281
+ dim_inputs: Tuple[int, ...],
282
+ depth,
283
+ mlp_expansion_factor=4
284
+ ):
285
+ super().__init__()
286
+ self.dim_inputs = dim_inputs
287
+ self.to_freqs = ModuleList([])
288
+ dim_hidden = dim * mlp_expansion_factor
289
+
290
+ for dim_in in dim_inputs:
291
+ net = []
292
+
293
+ mlp = nn.Sequential(
294
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
295
+ nn.GLU(dim=-1)
296
+ )
297
+
298
+ self.to_freqs.append(mlp)
299
+
300
+ def forward(self, x):
301
+ x = x.unbind(dim=-2)
302
+
303
+ outs = []
304
+
305
+ for band_features, mlp in zip(x, self.to_freqs):
306
+ freq_out = mlp(band_features)
307
+ outs.append(freq_out)
308
+
309
+ return torch.cat(outs, dim=-1)
310
+
311
+
312
+ # main class
313
+
314
+ DEFAULT_FREQS_PER_BANDS = (
315
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
316
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
317
+ 2, 2, 2, 2,
318
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
319
+ 12, 12, 12, 12, 12, 12, 12, 12,
320
+ 24, 24, 24, 24, 24, 24, 24, 24,
321
+ 48, 48, 48, 48, 48, 48, 48, 48,
322
+ 128, 129,
323
+ )
324
+
325
+
326
+ class BSRoformer(Module):
327
+
328
+ @beartype
329
+ def __init__(
330
+ self,
331
+ dim,
332
+ *,
333
+ depth,
334
+ stereo=False,
335
+ num_stems=1,
336
+ time_transformer_depth=2,
337
+ freq_transformer_depth=2,
338
+ linear_transformer_depth=0,
339
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
340
+ # in the paper, they divide into ~60 bands, test with 1 for starters
341
+ dim_head=64,
342
+ heads=8,
343
+ attn_dropout=0.,
344
+ ff_dropout=0.,
345
+ flash_attn=True,
346
+ dim_freqs_in=1025,
347
+ stft_n_fft=2048,
348
+ stft_hop_length=512,
349
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
350
+ stft_win_length=2048,
351
+ stft_normalized=False,
352
+ stft_window_fn: Optional[Callable] = None,
353
+ mask_estimator_depth=2,
354
+ multi_stft_resolution_loss_weight=1.,
355
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
356
+ multi_stft_hop_size=147,
357
+ multi_stft_normalized=False,
358
+ multi_stft_window_fn: Callable = torch.hann_window
359
+ ):
360
+ super().__init__()
361
+
362
+ self.stereo = stereo
363
+ self.audio_channels = 2 if stereo else 1
364
+ self.num_stems = num_stems
365
+
366
+ self.layers = ModuleList([])
367
+
368
+ transformer_kwargs = dict(
369
+ dim=dim,
370
+ heads=heads,
371
+ dim_head=dim_head,
372
+ attn_dropout=attn_dropout,
373
+ ff_dropout=ff_dropout,
374
+ flash_attn=flash_attn,
375
+ norm_output=False
376
+ )
377
+
378
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
379
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
380
+
381
+ for _ in range(depth):
382
+ tran_modules = []
383
+ if linear_transformer_depth > 0:
384
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
385
+ tran_modules.append(
386
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
387
+ )
388
+ tran_modules.append(
389
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
390
+ )
391
+ self.layers.append(nn.ModuleList(tran_modules))
392
+
393
+ self.final_norm = RMSNorm(dim)
394
+
395
+ self.stft_kwargs = dict(
396
+ n_fft=stft_n_fft,
397
+ hop_length=stft_hop_length,
398
+ win_length=stft_win_length,
399
+ normalized=stft_normalized
400
+ )
401
+
402
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
403
+
404
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
405
+
406
+ assert len(freqs_per_bands) > 1
407
+ assert sum(
408
+ freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
409
+
410
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
411
+
412
+ self.band_split = BandSplit(
413
+ dim=dim,
414
+ dim_inputs=freqs_per_bands_with_complex
415
+ )
416
+
417
+ self.mask_estimators = nn.ModuleList([])
418
+
419
+ for _ in range(num_stems):
420
+ mask_estimator = MaskEstimator(
421
+ dim=dim,
422
+ dim_inputs=freqs_per_bands_with_complex,
423
+ depth=mask_estimator_depth
424
+ )
425
+
426
+ self.mask_estimators.append(mask_estimator)
427
+
428
+ # for the multi-resolution stft loss
429
+
430
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
431
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
432
+ self.multi_stft_n_fft = stft_n_fft
433
+ self.multi_stft_window_fn = multi_stft_window_fn
434
+
435
+ self.multi_stft_kwargs = dict(
436
+ hop_length=multi_stft_hop_size,
437
+ normalized=multi_stft_normalized
438
+ )
439
+
440
+ def forward(
441
+ self,
442
+ raw_audio,
443
+ target=None,
444
+ return_loss_breakdown=False
445
+ ):
446
+ """
447
+ einops
448
+
449
+ b - batch
450
+ f - freq
451
+ t - time
452
+ s - audio channel (1 for mono, 2 for stereo)
453
+ n - number of 'stems'
454
+ c - complex (2)
455
+ d - feature dimension
456
+ """
457
+
458
+ device = raw_audio.device
459
+
460
+ # defining whether model is loaded on MPS (MacOS GPU accelerator)
461
+ x_is_mps = True if device.type == "mps" else False
462
+
463
+ if raw_audio.ndim == 2:
464
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
465
+
466
+ channels = raw_audio.shape[1]
467
+ assert (not self.stereo and channels == 1) or (
468
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
469
+
470
+ # to stft
471
+
472
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
473
+
474
+ stft_window = self.stft_window_fn(device=device)
475
+
476
+ # RuntimeError: FFT operations are only supported on MacOS 14+
477
+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
478
+ try:
479
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
480
+ except:
481
+ stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
482
+
483
+ stft_repr = torch.view_as_real(stft_repr)
484
+
485
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
486
+ stft_repr = rearrange(stft_repr,
487
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
488
+
489
+ x = rearrange(stft_repr, 'b f t c -> b t (f c)')
490
+
491
+ x = self.band_split(x)
492
+
493
+ # axial / hierarchical attention
494
+
495
+ for transformer_block in self.layers:
496
+
497
+ if len(transformer_block) == 3:
498
+ linear_transformer, time_transformer, freq_transformer = transformer_block
499
+
500
+ x, ft_ps = pack([x], 'b * d')
501
+ x = linear_transformer(x)
502
+ x, = unpack(x, ft_ps, 'b * d')
503
+ else:
504
+ time_transformer, freq_transformer = transformer_block
505
+
506
+ x = rearrange(x, 'b t f d -> b f t d')
507
+ x, ps = pack([x], '* t d')
508
+
509
+ x = time_transformer(x)
510
+
511
+ x, = unpack(x, ps, '* t d')
512
+ x = rearrange(x, 'b f t d -> b t f d')
513
+ x, ps = pack([x], '* f d')
514
+
515
+ x = freq_transformer(x)
516
+
517
+ x, = unpack(x, ps, '* f d')
518
+
519
+ x = self.final_norm(x)
520
+
521
+ num_stems = len(self.mask_estimators)
522
+
523
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
524
+ mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
525
+
526
+ # modulate frequency representation
527
+
528
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
529
+
530
+ # complex number multiplication
531
+
532
+ stft_repr = torch.view_as_complex(stft_repr)
533
+ mask = torch.view_as_complex(mask)
534
+
535
+ stft_repr = stft_repr * mask
536
+
537
+ # istft
538
+
539
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
540
+
541
+ # same as torch.stft() fix for MacOS MPS above
542
+ try:
543
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
544
+ except:
545
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device)
546
+
547
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
548
+
549
+ if num_stems == 1:
550
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
551
+
552
+ # if a target is passed in, calculate loss for learning
553
+
554
+ if not exists(target):
555
+ return recon_audio
556
+
557
+ if self.num_stems > 1:
558
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
559
+
560
+ if target.ndim == 2:
561
+ target = rearrange(target, '... t -> ... 1 t')
562
+
563
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
564
+
565
+ loss = F.l1_loss(recon_audio, target)
566
+
567
+ multi_stft_resolution_loss = 0.
568
+
569
+ for window_size in self.multi_stft_resolutions_window_sizes:
570
+ res_stft_kwargs = dict(
571
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
572
+ win_length=window_size,
573
+ return_complex=True,
574
+ window=self.multi_stft_window_fn(window_size, device=device),
575
+ **self.multi_stft_kwargs,
576
+ )
577
+
578
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
579
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
580
+
581
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
582
+
583
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
584
+
585
+ total_loss = loss + weighted_multi_resolution_loss
586
+
587
+ if not return_loss_breakdown:
588
+ return total_loss
589
+
590
+ return total_loss, (loss, multi_stft_resolution_loss)
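For inference, BSRoformer's forward returns the reconstructed stem audio when no target is given, and the loss (optionally broken down) when a target is passed. A hedged sketch with deliberately small, illustrative hyperparameters (a real checkpoint's config defines dim, depth, etc.; the band layout is left at DEFAULT_FREQS_PER_BANDS, which sums to 1025 = stft_n_fft // 2 + 1 and therefore matches the default STFT settings):

import torch
from models.bs_roformer import BSRoformer

model = BSRoformer(dim=64, depth=2, stereo=False, num_stems=1, flash_attn=False).eval()

audio = torch.randn(1, 44100)                 # (batch, samples), mono

with torch.no_grad():
    stem = model(audio)                       # no target -> reconstructed audio, (batch, channels, samples)

target = torch.randn(1, 1, 44100)             # (batch, stems/channels, samples)
total_loss, (l1, multi_stft) = model(audio, target=target, return_loss_breakdown=True)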
models/bs_roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,637 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack, reduce, repeat
16
+ from einops.layers.torch import Rearrange
17
+
18
+ from librosa import filters
19
+
20
+
21
+ # helper functions
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def default(v, d):
28
+ return v if exists(v) else d
29
+
30
+
31
+ def pack_one(t, pattern):
32
+ return pack([t], pattern)
33
+
34
+
35
+ def unpack_one(t, ps, pattern):
36
+ return unpack(t, ps, pattern)[0]
37
+
38
+
39
+ def pad_at_dim(t, pad, dim=-1, value=0.):
40
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
41
+ zeros = ((0, 0) * dims_from_right)
42
+ return F.pad(t, (*zeros, *pad), value=value)
43
+
44
+
45
+ def l2norm(t):
46
+ return F.normalize(t, dim=-1, p=2)
47
+
48
+
49
+ # norm
50
+
51
+ class RMSNorm(Module):
52
+ def __init__(self, dim):
53
+ super().__init__()
54
+ self.scale = dim ** 0.5
55
+ self.gamma = nn.Parameter(torch.ones(dim))
56
+
57
+ def forward(self, x):
58
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
59
+
60
+
61
+ # attention
62
+
63
+ class FeedForward(Module):
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ mult=4,
68
+ dropout=0.
69
+ ):
70
+ super().__init__()
71
+ dim_inner = int(dim * mult)
72
+ self.net = nn.Sequential(
73
+ RMSNorm(dim),
74
+ nn.Linear(dim, dim_inner),
75
+ nn.GELU(),
76
+ nn.Dropout(dropout),
77
+ nn.Linear(dim_inner, dim),
78
+ nn.Dropout(dropout)
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.net(x)
83
+
84
+
85
+ class Attention(Module):
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ heads=8,
90
+ dim_head=64,
91
+ dropout=0.,
92
+ rotary_embed=None,
93
+ flash=True
94
+ ):
95
+ super().__init__()
96
+ self.heads = heads
97
+ self.scale = dim_head ** -0.5
98
+ dim_inner = heads * dim_head
99
+
100
+ self.rotary_embed = rotary_embed
101
+
102
+ self.attend = Attend(flash=flash, dropout=dropout)
103
+
104
+ self.norm = RMSNorm(dim)
105
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
106
+
107
+ self.to_gates = nn.Linear(dim, heads)
108
+
109
+ self.to_out = nn.Sequential(
110
+ nn.Linear(dim_inner, dim, bias=False),
111
+ nn.Dropout(dropout)
112
+ )
113
+
114
+ def forward(self, x):
115
+ x = self.norm(x)
116
+
117
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
118
+
119
+ if exists(self.rotary_embed):
120
+ q = self.rotary_embed.rotate_queries_or_keys(q)
121
+ k = self.rotary_embed.rotate_queries_or_keys(k)
122
+
123
+ out = self.attend(q, k, v)
124
+
125
+ gates = self.to_gates(x)
126
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
127
+
128
+ out = rearrange(out, 'b h n d -> b n (h d)')
129
+ return self.to_out(out)
130
+
131
+
132
+ class LinearAttention(Module):
133
+ """
134
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
135
+ """
136
+
137
+ @beartype
138
+ def __init__(
139
+ self,
140
+ *,
141
+ dim,
142
+ dim_head=32,
143
+ heads=8,
144
+ scale=8,
145
+ flash=False,
146
+ dropout=0.
147
+ ):
148
+ super().__init__()
149
+ dim_inner = dim_head * heads
150
+ self.norm = RMSNorm(dim)
151
+
152
+ self.to_qkv = nn.Sequential(
153
+ nn.Linear(dim, dim_inner * 3, bias=False),
154
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
155
+ )
156
+
157
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
+
159
+ self.attend = Attend(
160
+ scale=scale,
161
+ dropout=dropout,
162
+ flash=flash
163
+ )
164
+
165
+ self.to_out = nn.Sequential(
166
+ Rearrange('b h d n -> b n (h d)'),
167
+ nn.Linear(dim_inner, dim, bias=False)
168
+ )
169
+
170
+ def forward(
171
+ self,
172
+ x
173
+ ):
174
+ x = self.norm(x)
175
+
176
+ q, k, v = self.to_qkv(x)
177
+
178
+ q, k = map(l2norm, (q, k))
179
+ q = q * self.temperature.exp()
180
+
181
+ out = self.attend(q, k, v)
182
+
183
+ return self.to_out(out)
184
+
185
+
186
+ class Transformer(Module):
187
+ def __init__(
188
+ self,
189
+ *,
190
+ dim,
191
+ depth,
192
+ dim_head=64,
193
+ heads=8,
194
+ attn_dropout=0.,
195
+ ff_dropout=0.,
196
+ ff_mult=4,
197
+ norm_output=True,
198
+ rotary_embed=None,
199
+ flash_attn=True,
200
+ linear_attn=False
201
+ ):
202
+ super().__init__()
203
+ self.layers = ModuleList([])
204
+
205
+ for _ in range(depth):
206
+ if linear_attn:
207
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
208
+ else:
209
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
210
+ rotary_embed=rotary_embed, flash=flash_attn)
211
+
212
+ self.layers.append(ModuleList([
213
+ attn,
214
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
215
+ ]))
216
+
217
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
218
+
219
+ def forward(self, x):
220
+
221
+ for attn, ff in self.layers:
222
+ x = attn(x) + x
223
+ x = ff(x) + x
224
+
225
+ return self.norm(x)
226
+
227
+
228
+ # bandsplit module
229
+
230
+ class BandSplit(Module):
231
+ @beartype
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ dim_inputs: Tuple[int, ...]
236
+ ):
237
+ super().__init__()
238
+ self.dim_inputs = dim_inputs
239
+ self.to_features = ModuleList([])
240
+
241
+ for dim_in in dim_inputs:
242
+ net = nn.Sequential(
243
+ RMSNorm(dim_in),
244
+ nn.Linear(dim_in, dim)
245
+ )
246
+
247
+ self.to_features.append(net)
248
+
249
+ def forward(self, x):
250
+ x = x.split(self.dim_inputs, dim=-1)
251
+
252
+ outs = []
253
+ for split_input, to_feature in zip(x, self.to_features):
254
+ split_output = to_feature(split_input)
255
+ outs.append(split_output)
256
+
257
+ return torch.stack(outs, dim=-2)
258
+
259
+
260
+ def MLP(
261
+ dim_in,
262
+ dim_out,
263
+ dim_hidden=None,
264
+ depth=1,
265
+ activation=nn.Tanh
266
+ ):
267
+ dim_hidden = default(dim_hidden, dim_in)
268
+
269
+ net = []
270
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
271
+
272
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
273
+ is_last = ind == (len(dims) - 2)
274
+
275
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
276
+
277
+ if is_last:
278
+ continue
279
+
280
+ net.append(activation())
281
+
282
+ return nn.Sequential(*net)
283
+
284
+
285
+ class MaskEstimator(Module):
286
+ @beartype
287
+ def __init__(
288
+ self,
289
+ dim,
290
+ dim_inputs: Tuple[int, ...],
291
+ depth,
292
+ mlp_expansion_factor=4
293
+ ):
294
+ super().__init__()
295
+ self.dim_inputs = dim_inputs
296
+ self.to_freqs = ModuleList([])
297
+ dim_hidden = dim * mlp_expansion_factor
298
+
299
+ for dim_in in dim_inputs:
300
+ net = []
301
+
302
+ mlp = nn.Sequential(
303
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
304
+ nn.GLU(dim=-1)
305
+ )
306
+
307
+ self.to_freqs.append(mlp)
308
+
309
+ def forward(self, x):
310
+ x = x.unbind(dim=-2)
311
+
312
+ outs = []
313
+
314
+ for band_features, mlp in zip(x, self.to_freqs):
315
+ freq_out = mlp(band_features)
316
+ outs.append(freq_out)
317
+
318
+ return torch.cat(outs, dim=-1)
319
+
320
+
321
+ # main class
322
+
323
+ class MelBandRoformer(Module):
324
+
325
+ @beartype
326
+ def __init__(
327
+ self,
328
+ dim,
329
+ *,
330
+ depth,
331
+ stereo=False,
332
+ num_stems=1,
333
+ time_transformer_depth=2,
334
+ freq_transformer_depth=2,
335
+ linear_transformer_depth=0,
336
+ num_bands=60,
337
+ dim_head=64,
338
+ heads=8,
339
+ attn_dropout=0.1,
340
+ ff_dropout=0.1,
341
+ flash_attn=True,
342
+ dim_freqs_in=1025,
343
+ sample_rate=44100, # needed for mel filter bank from librosa
344
+ stft_n_fft=2048,
345
+ stft_hop_length=512,
346
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
347
+ stft_win_length=2048,
348
+ stft_normalized=False,
349
+ stft_window_fn: Optional[Callable] = None,
350
+ mask_estimator_depth=1,
351
+ multi_stft_resolution_loss_weight=1.,
352
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
353
+ multi_stft_hop_size=147,
354
+ multi_stft_normalized=False,
355
+ multi_stft_window_fn: Callable = torch.hann_window,
356
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
357
+ ):
358
+ super().__init__()
359
+
360
+ self.stereo = stereo
361
+ self.audio_channels = 2 if stereo else 1
362
+ self.num_stems = num_stems
363
+
364
+ self.layers = ModuleList([])
365
+
366
+ transformer_kwargs = dict(
367
+ dim=dim,
368
+ heads=heads,
369
+ dim_head=dim_head,
370
+ attn_dropout=attn_dropout,
371
+ ff_dropout=ff_dropout,
372
+ flash_attn=flash_attn
373
+ )
374
+
375
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
376
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
377
+
378
+ for _ in range(depth):
379
+ tran_modules = []
380
+ if linear_transformer_depth > 0:
381
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
382
+ tran_modules.append(
383
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
384
+ )
385
+ tran_modules.append(
386
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
387
+ )
388
+ self.layers.append(nn.ModuleList(tran_modules))
389
+
390
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
391
+
392
+ self.stft_kwargs = dict(
393
+ n_fft=stft_n_fft,
394
+ hop_length=stft_hop_length,
395
+ win_length=stft_win_length,
396
+ normalized=stft_normalized
397
+ )
398
+
399
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
400
+
401
+ # create mel filter bank
402
+ # with librosa.filters.mel as in section 2 of paper
403
+
404
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
405
+
406
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
407
+
408
+ # for some reason, it doesn't include the first freq? just force a value for now
409
+
410
+ mel_filter_bank[0][0] = 1.
411
+
412
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
413
+ # so let's force a positive value
414
+
415
+ mel_filter_bank[-1, -1] = 1.
416
+
417
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
418
+
419
+ freqs_per_band = mel_filter_bank > 0
420
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
421
+
422
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
423
+ freq_indices = repeated_freq_indices[freqs_per_band]
424
+
425
+ if stereo:
426
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
427
+ freq_indices = freq_indices * 2 + torch.arange(2)
428
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
429
+
430
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
431
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
432
+
433
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
434
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
435
+
436
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
437
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
438
+
439
+ # band split and mask estimator
440
+
441
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
442
+
443
+ self.band_split = BandSplit(
444
+ dim=dim,
445
+ dim_inputs=freqs_per_bands_with_complex
446
+ )
447
+
448
+ self.mask_estimators = nn.ModuleList([])
449
+
450
+ for _ in range(num_stems):
451
+ mask_estimator = MaskEstimator(
452
+ dim=dim,
453
+ dim_inputs=freqs_per_bands_with_complex,
454
+ depth=mask_estimator_depth
455
+ )
456
+
457
+ self.mask_estimators.append(mask_estimator)
458
+
459
+ # for the multi-resolution stft loss
460
+
461
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
462
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
463
+ self.multi_stft_n_fft = stft_n_fft
464
+ self.multi_stft_window_fn = multi_stft_window_fn
465
+
466
+ self.multi_stft_kwargs = dict(
467
+ hop_length=multi_stft_hop_size,
468
+ normalized=multi_stft_normalized
469
+ )
470
+
471
+ self.match_input_audio_length = match_input_audio_length
472
+
473
+ def forward(
474
+ self,
475
+ raw_audio,
476
+ target=None,
477
+ return_loss_breakdown=False
478
+ ):
479
+ """
480
+ einops
481
+
482
+ b - batch
483
+ f - freq
484
+ t - time
485
+ s - audio channel (1 for mono, 2 for stereo)
486
+ n - number of 'stems'
487
+ c - complex (2)
488
+ d - feature dimension
489
+ """
490
+
491
+ device = raw_audio.device
492
+
493
+ if raw_audio.ndim == 2:
494
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
495
+
496
+ batch, channels, raw_audio_length = raw_audio.shape
497
+
498
+ istft_length = raw_audio_length if self.match_input_audio_length else None
499
+
500
+ assert (not self.stereo and channels == 1) or (
501
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
502
+
503
+ # to stft
504
+
505
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
506
+
507
+ stft_window = self.stft_window_fn(device=device)
508
+
509
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
510
+ stft_repr = torch.view_as_real(stft_repr)
511
+
512
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
513
+ stft_repr = rearrange(stft_repr,
514
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
515
+
516
+ # index out all frequencies for all frequency ranges across bands ascending in one go
517
+
518
+ batch_arange = torch.arange(batch, device=device)[..., None]
519
+
520
+ # account for stereo
521
+
522
+ x = stft_repr[batch_arange, self.freq_indices]
523
+
524
+ # fold the complex (real and imag) into the frequencies dimension
525
+
526
+ x = rearrange(x, 'b f t c -> b t (f c)')
527
+
528
+ x = self.band_split(x)
529
+
530
+ # axial / hierarchical attention
531
+
532
+ for transformer_block in self.layers:
533
+
534
+ if len(transformer_block) == 3:
535
+ linear_transformer, time_transformer, freq_transformer = transformer_block
536
+
537
+ x, ft_ps = pack([x], 'b * d')
538
+ x = linear_transformer(x)
539
+ x, = unpack(x, ft_ps, 'b * d')
540
+ else:
541
+ time_transformer, freq_transformer = transformer_block
542
+
543
+ x = rearrange(x, 'b t f d -> b f t d')
544
+ x, ps = pack([x], '* t d')
545
+
546
+ x = time_transformer(x)
547
+
548
+ x, = unpack(x, ps, '* t d')
549
+ x = rearrange(x, 'b f t d -> b t f d')
550
+ x, ps = pack([x], '* f d')
551
+
552
+ x = freq_transformer(x)
553
+
554
+ x, = unpack(x, ps, '* f d')
555
+
556
+ num_stems = len(self.mask_estimators)
557
+
558
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
559
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
560
+
561
+ # modulate frequency representation
562
+
563
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
564
+
565
+ # complex number multiplication
566
+
567
+ stft_repr = torch.view_as_complex(stft_repr)
568
+ masks = torch.view_as_complex(masks)
569
+
570
+ masks = masks.type(stft_repr.dtype)
571
+
572
+ # need to average the estimated mask for the overlapped frequencies
573
+
574
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
575
+
576
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
577
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
578
+
579
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
580
+
581
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
582
+
583
+ # modulate stft repr with estimated mask
584
+
585
+ stft_repr = stft_repr * masks_averaged
586
+
587
+ # istft
588
+
589
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
590
+
591
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
592
+ length=istft_length)
593
+
594
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
595
+
596
+ if num_stems == 1:
597
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
598
+
599
+ # if a target is passed in, calculate loss for learning
600
+
601
+ if not exists(target):
602
+ return recon_audio
603
+
604
+ if self.num_stems > 1:
605
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
606
+
607
+ if target.ndim == 2:
608
+ target = rearrange(target, '... t -> ... 1 t')
609
+
610
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
611
+
612
+ loss = F.l1_loss(recon_audio, target)
613
+
614
+ multi_stft_resolution_loss = 0.
615
+
616
+ for window_size in self.multi_stft_resolutions_window_sizes:
617
+ res_stft_kwargs = dict(
618
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
619
+ win_length=window_size,
620
+ return_complex=True,
621
+ window=self.multi_stft_window_fn(window_size, device=device),
622
+ **self.multi_stft_kwargs,
623
+ )
624
+
625
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
626
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
627
+
628
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
629
+
630
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
631
+
632
+ total_loss = loss + weighted_multi_resolution_loss
633
+
634
+ if not return_loss_breakdown:
635
+ return total_loss
636
+
637
+ return total_loss, (loss, multi_stft_resolution_loss)
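MelBandRoformer follows the same interface but derives its frequency bands from a librosa mel filter bank, so librosa must be installed; match_input_audio_length=True pads the iSTFT output back to the input length. Another illustrative sketch under the same assumptions (hyperparameters chosen only for the example):

import torch
from models.bs_roformer import MelBandRoformer

model = MelBandRoformer(dim=64, depth=2, num_bands=60, flash_attn=False,
                        match_input_audio_length=True).eval()

audio = torch.randn(1, 44100)

with torch.no_grad():
    stem = model(audio)                       # (batch, channels, samples), padded/trimmed to the input length

print(stem.shape)                             # torch.Size([1, 1, 44100])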
models/demucs4ht.py ADDED
@@ -0,0 +1,713 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ from omegaconf import OmegaConf
10
+ from demucs.demucs import Demucs
11
+ from demucs.hdemucs import HDemucs
12
+
13
+ import math
14
+ from openunmix.filtering import wiener
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ from fractions import Fraction
18
+ from einops import rearrange
19
+
20
+ from demucs.transformer import CrossTransformerEncoder
21
+
22
+ from demucs.demucs import rescale_module
23
+ from demucs.states import capture_init
24
+ from demucs.spec import spectro, ispectro
25
+ from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
26
+
27
+
28
+ class HTDemucs(nn.Module):
29
+ """
30
+ Spectrogram and hybrid Demucs model.
31
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
32
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
33
+ Frequency layers can still access information across time steps thanks to the DConv residual.
34
+
35
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
36
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
37
+
38
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
39
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
40
+ Open Unmix implementation [Stoter et al. 2019].
41
+
42
+ The loss is always on the temporal domain, by backpropagating through the above
43
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
44
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
45
+ contribution, without changing the one from the waveform, which will lead to worse performance.
46
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
47
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
48
+ hybrid models.
49
+
50
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
51
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
52
+
53
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
54
+ """
55
+
56
+ @capture_init
57
+ def __init__(
58
+ self,
59
+ sources,
60
+ # Channels
61
+ audio_channels=2,
62
+ channels=48,
63
+ channels_time=None,
64
+ growth=2,
65
+ # STFT
66
+ nfft=4096,
67
+ num_subbands=1,
68
+ wiener_iters=0,
69
+ end_iters=0,
70
+ wiener_residual=False,
71
+ cac=True,
72
+ # Main structure
73
+ depth=4,
74
+ rewrite=True,
75
+ # Frequency branch
76
+ multi_freqs=None,
77
+ multi_freqs_depth=3,
78
+ freq_emb=0.2,
79
+ emb_scale=10,
80
+ emb_smooth=True,
81
+ # Convolutions
82
+ kernel_size=8,
83
+ time_stride=2,
84
+ stride=4,
85
+ context=1,
86
+ context_enc=0,
87
+ # Normalization
88
+ norm_starts=4,
89
+ norm_groups=4,
90
+ # DConv residual branch
91
+ dconv_mode=1,
92
+ dconv_depth=2,
93
+ dconv_comp=8,
94
+ dconv_init=1e-3,
95
+ # Before the Transformer
96
+ bottom_channels=0,
97
+ # Transformer
98
+ t_layers=5,
99
+ t_emb="sin",
100
+ t_hidden_scale=4.0,
101
+ t_heads=8,
102
+ t_dropout=0.0,
103
+ t_max_positions=10000,
104
+ t_norm_in=True,
105
+ t_norm_in_group=False,
106
+ t_group_norm=False,
107
+ t_norm_first=True,
108
+ t_norm_out=True,
109
+ t_max_period=10000.0,
110
+ t_weight_decay=0.0,
111
+ t_lr=None,
112
+ t_layer_scale=True,
113
+ t_gelu=True,
114
+ t_weight_pos_embed=1.0,
115
+ t_sin_random_shift=0,
116
+ t_cape_mean_normalize=True,
117
+ t_cape_augment=True,
118
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
119
+ t_sparse_self_attn=False,
120
+ t_sparse_cross_attn=False,
121
+ t_mask_type="diag",
122
+ t_mask_random_seed=42,
123
+ t_sparse_attn_window=500,
124
+ t_global_window=100,
125
+ t_sparsity=0.95,
126
+ t_auto_sparsity=False,
127
+ # ------ Particular parameters
128
+ t_cross_first=False,
129
+ # Weight init
130
+ rescale=0.1,
131
+ # Metadata
132
+ samplerate=44100,
133
+ segment=10,
134
+ use_train_segment=False,
135
+ ):
136
+ """
137
+ Args:
138
+ sources (list[str]): list of source names.
139
+ audio_channels (int): input/output audio channels.
140
+ channels (int): initial number of hidden channels.
141
+ channels_time: if not None, use a different `channels` value for the time branch.
142
+ growth: increase the number of hidden channels by this factor at each layer.
143
+ nfft: number of fft bins. Note that changing this require careful computation of
144
+ various shape parameters and will not work out of the box for hybrid models.
145
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
146
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
147
+ wiener_residual: add residual source before wiener filtering.
148
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
149
+ in input and output. no further processing is done before ISTFT.
150
+ depth (int): number of layers in the encoder and in the decoder.
151
+ rewrite (bool): add 1x1 convolution to each layer.
152
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
153
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
154
+ layers will be wrapped.
155
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
156
+ the actual value controls the weight of the embedding.
157
+ emb_scale: equivalent to scaling the embedding learning rate
158
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
159
+ kernel_size: kernel_size for encoder and decoder layers.
160
+ stride: stride for encoder and decoder layers.
161
+ time_stride: stride for the final time layer, after the merge.
162
+ context: context for 1x1 conv in the decoder.
163
+ context_enc: context for 1x1 conv in the encoder.
164
+ norm_starts: layer at which group norm starts being used.
165
+ decoder layers are numbered in reverse order.
166
+ norm_groups: number of groups for group norm.
167
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
168
+ dconv_depth: depth of residual DConv branch.
169
+ dconv_comp: compression of DConv branch.
170
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
171
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
172
+ dconv_init: initial scale for the DConv branch LayerScale.
173
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
174
+ transformer in order to change the number of channels
175
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
176
+ t_emb: "sin", "cape" or "scaled"
177
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
178
+ for instance if C = 384 (the number of channels in the transformer) and
179
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
180
+ 384 * 4 = 1536
181
+ t_heads: number of heads for the transformer
182
+ t_dropout: dropout in the transformer
183
+ t_max_positions: max_positions for the "scaled" positional embedding, only
184
+ useful if t_emb="scaled"
185
+ t_norm_in: (bool) norm before addinf positional embedding and getting into the
186
+ transformer layers
187
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
190
+ timesteps (GroupNorm with group=1)
191
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
192
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
193
+ t_max_period: (float) denominator in the sinusoidal embedding expression
194
+ t_weight_decay: (float) weight decay for the transformer
195
+ t_lr: (float) specific learning rate for the transformer
196
+ t_layer_scale: (bool) Layer Scale for the transformer
197
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
198
+ t_weight_pos_embed: (float) weighting of the positional embedding
199
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
200
+ see: https://arxiv.org/abs/2106.03143
201
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
202
+ during the inference, see: https://arxiv.org/abs/2106.03143
203
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
204
+ see: https://arxiv.org/abs/2106.03143
205
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
206
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
207
+ unless you designed really specific masks)
208
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
209
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
210
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
211
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
212
+ that generated the random part of the mask
213
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
214
+ a key (j), the mask is True id |i-j|<=t_sparse_attn_window
215
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
216
+ and mask[:, :t_global_window] will be True
217
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
218
+ level of the random part of the mask.
219
+ t_cross_first: (bool) if True cross attention is the first layer of the
220
+ transformer (False seems to be better)
221
+ rescale: weight rescaling trick
222
+ use_train_segment: (bool) if True, the actual size that is used during the
223
+ training is used during inference.
224
+ """
225
+ super().__init__()
226
+ self.num_subbands = num_subbands
227
+ self.cac = cac
228
+ self.wiener_residual = wiener_residual
229
+ self.audio_channels = audio_channels
230
+ self.sources = sources
231
+ self.kernel_size = kernel_size
232
+ self.context = context
233
+ self.stride = stride
234
+ self.depth = depth
235
+ self.bottom_channels = bottom_channels
236
+ self.channels = channels
237
+ self.samplerate = samplerate
238
+ self.segment = segment
239
+ self.use_train_segment = use_train_segment
240
+ self.nfft = nfft
241
+ self.hop_length = nfft // 4
242
+ self.wiener_iters = wiener_iters
243
+ self.end_iters = end_iters
244
+ self.freq_emb = None
245
+ assert wiener_iters == end_iters
246
+
247
+ self.encoder = nn.ModuleList()
248
+ self.decoder = nn.ModuleList()
249
+
250
+ self.tencoder = nn.ModuleList()
251
+ self.tdecoder = nn.ModuleList()
252
+
253
+ chin = audio_channels
254
+ chin_z = chin # number of channels for the freq branch
255
+ if self.cac:
256
+ chin_z *= 2
257
+ if self.num_subbands > 1:
258
+ chin_z *= self.num_subbands
259
+ chout = channels_time or channels
260
+ chout_z = channels
261
+ freqs = nfft // 2
262
+
263
+ for index in range(depth):
264
+ norm = index >= norm_starts
265
+ freq = freqs > 1
266
+ stri = stride
267
+ ker = kernel_size
268
+ if not freq:
269
+ assert freqs == 1
270
+ ker = time_stride * 2
271
+ stri = time_stride
272
+
273
+ pad = True
274
+ last_freq = False
275
+ if freq and freqs <= kernel_size:
276
+ ker = freqs
277
+ pad = False
278
+ last_freq = True
279
+
280
+ kw = {
281
+ "kernel_size": ker,
282
+ "stride": stri,
283
+ "freq": freq,
284
+ "pad": pad,
285
+ "norm": norm,
286
+ "rewrite": rewrite,
287
+ "norm_groups": norm_groups,
288
+ "dconv_kw": {
289
+ "depth": dconv_depth,
290
+ "compress": dconv_comp,
291
+ "init": dconv_init,
292
+ "gelu": True,
293
+ },
294
+ }
295
+ kwt = dict(kw)
296
+ kwt["freq"] = 0
297
+ kwt["kernel_size"] = kernel_size
298
+ kwt["stride"] = stride
299
+ kwt["pad"] = True
300
+ kw_dec = dict(kw)
301
+ multi = False
302
+ if multi_freqs and index < multi_freqs_depth:
303
+ multi = True
304
+ kw_dec["context_freq"] = False
305
+
306
+ if last_freq:
307
+ chout_z = max(chout, chout_z)
308
+ chout = chout_z
309
+
310
+ enc = HEncLayer(
311
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
312
+ )
313
+ if freq:
314
+ tenc = HEncLayer(
315
+ chin,
316
+ chout,
317
+ dconv=dconv_mode & 1,
318
+ context=context_enc,
319
+ empty=last_freq,
320
+ **kwt
321
+ )
322
+ self.tencoder.append(tenc)
323
+
324
+ if multi:
325
+ enc = MultiWrap(enc, multi_freqs)
326
+ self.encoder.append(enc)
327
+ if index == 0:
328
+ chin = self.audio_channels * len(self.sources)
329
+ chin_z = chin
330
+ if self.cac:
331
+ chin_z *= 2
332
+ if self.num_subbands > 1:
333
+ chin_z *= self.num_subbands
334
+ dec = HDecLayer(
335
+ chout_z,
336
+ chin_z,
337
+ dconv=dconv_mode & 2,
338
+ last=index == 0,
339
+ context=context,
340
+ **kw_dec
341
+ )
342
+ if multi:
343
+ dec = MultiWrap(dec, multi_freqs)
344
+ if freq:
345
+ tdec = HDecLayer(
346
+ chout,
347
+ chin,
348
+ dconv=dconv_mode & 2,
349
+ empty=last_freq,
350
+ last=index == 0,
351
+ context=context,
352
+ **kwt
353
+ )
354
+ self.tdecoder.insert(0, tdec)
355
+ self.decoder.insert(0, dec)
356
+
357
+ chin = chout
358
+ chin_z = chout_z
359
+ chout = int(growth * chout)
360
+ chout_z = int(growth * chout_z)
361
+ if freq:
362
+ if freqs <= kernel_size:
363
+ freqs = 1
364
+ else:
365
+ freqs //= stride
366
+ if index == 0 and freq_emb:
367
+ self.freq_emb = ScaledEmbedding(
368
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
369
+ )
370
+ self.freq_emb_scale = freq_emb
371
+
372
+ if rescale:
373
+ rescale_module(self, reference=rescale)
374
+
375
+ transformer_channels = channels * growth ** (depth - 1)
376
+ if bottom_channels:
377
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
378
+ self.channel_downsampler = nn.Conv1d(
379
+ bottom_channels, transformer_channels, 1
380
+ )
381
+ self.channel_upsampler_t = nn.Conv1d(
382
+ transformer_channels, bottom_channels, 1
383
+ )
384
+ self.channel_downsampler_t = nn.Conv1d(
385
+ bottom_channels, transformer_channels, 1
386
+ )
387
+
388
+ transformer_channels = bottom_channels
389
+
390
+ if t_layers > 0:
391
+ self.crosstransformer = CrossTransformerEncoder(
392
+ dim=transformer_channels,
393
+ emb=t_emb,
394
+ hidden_scale=t_hidden_scale,
395
+ num_heads=t_heads,
396
+ num_layers=t_layers,
397
+ cross_first=t_cross_first,
398
+ dropout=t_dropout,
399
+ max_positions=t_max_positions,
400
+ norm_in=t_norm_in,
401
+ norm_in_group=t_norm_in_group,
402
+ group_norm=t_group_norm,
403
+ norm_first=t_norm_first,
404
+ norm_out=t_norm_out,
405
+ max_period=t_max_period,
406
+ weight_decay=t_weight_decay,
407
+ lr=t_lr,
408
+ layer_scale=t_layer_scale,
409
+ gelu=t_gelu,
410
+ sin_random_shift=t_sin_random_shift,
411
+ weight_pos_embed=t_weight_pos_embed,
412
+ cape_mean_normalize=t_cape_mean_normalize,
413
+ cape_augment=t_cape_augment,
414
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
415
+ sparse_self_attn=t_sparse_self_attn,
416
+ sparse_cross_attn=t_sparse_cross_attn,
417
+ mask_type=t_mask_type,
418
+ mask_random_seed=t_mask_random_seed,
419
+ sparse_attn_window=t_sparse_attn_window,
420
+ global_window=t_global_window,
421
+ sparsity=t_sparsity,
422
+ auto_sparsity=t_auto_sparsity,
423
+ )
424
+ else:
425
+ self.crosstransformer = None
426
+
427
+ def _spec(self, x):
428
+ hl = self.hop_length
429
+ nfft = self.nfft
430
+ x0 = x # noqa
431
+
432
+ # We re-pad the signal in order to keep the property
433
+ # that the size of the output is exactly the size of the input
434
+ # divided by the stride (here hop_length), when divisible.
435
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
436
+ # which is not directly supported by torch.stft.
437
+ # Having all convolution operations follow this convention allows us to easily
438
+ # align the time and frequency branches later on.
439
+ assert hl == nfft // 4
440
+ le = int(math.ceil(x.shape[-1] / hl))
441
+ pad = hl // 2 * 3
442
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
443
+
444
+ z = spectro(x, nfft, hl)[..., :-1, :]
445
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
446
+ z = z[..., 2: 2 + le]
447
+ return z
448
+
449
+ def _ispec(self, z, length=None, scale=0):
450
+ hl = self.hop_length // (4**scale)
451
+ z = F.pad(z, (0, 0, 0, 1))
452
+ z = F.pad(z, (2, 2))
453
+ pad = hl // 2 * 3
454
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
455
+ x = ispectro(z, hl, length=le)
456
+ x = x[..., pad: pad + length]
457
+ return x
458
+
459
+ def _magnitude(self, z):
460
+ # return the magnitude of the spectrogram, except when cac is True,
461
+ # in which case we just move the complex dimension to the channel one.
462
+ if self.cac:
463
+ B, C, Fr, T = z.shape
464
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
465
+ m = m.reshape(B, C * 2, Fr, T)
466
+ else:
467
+ m = z.abs()
468
+ return m
469
+
470
+ def _mask(self, z, m):
471
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
472
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
473
+ niters = self.wiener_iters
474
+ if self.cac:
475
+ B, S, C, Fr, T = m.shape
476
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
477
+ out = torch.view_as_complex(out.contiguous())
478
+ return out
479
+ if self.training:
480
+ niters = self.end_iters
481
+ if niters < 0:
482
+ z = z[:, None]
483
+ return z / (1e-8 + z.abs()) * m
484
+ else:
485
+ return self._wiener(m, z, niters)
486
+
487
+ def _wiener(self, mag_out, mix_stft, niters):
488
+ # apply wiener filtering from OpenUnmix.
489
+ init = mix_stft.dtype
490
+ wiener_win_len = 300
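+ # the Wiener/EM step below is applied on chunks of at most 300 STFT frames to keep memory usage bounded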
491
+ residual = self.wiener_residual
492
+
493
+ B, S, C, Fq, T = mag_out.shape
494
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
495
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
496
+
497
+ outs = []
498
+ for sample in range(B):
499
+ pos = 0
500
+ out = []
501
+ for pos in range(0, T, wiener_win_len):
502
+ frame = slice(pos, pos + wiener_win_len)
503
+ z_out = wiener(
504
+ mag_out[sample, frame],
505
+ mix_stft[sample, frame],
506
+ niters,
507
+ residual=residual,
508
+ )
509
+ out.append(z_out.transpose(-1, -2))
510
+ outs.append(torch.cat(out, dim=0))
511
+ out = torch.view_as_complex(torch.stack(outs, 0))
512
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
513
+ if residual:
514
+ out = out[:, :-1]
515
+ assert list(out.shape) == [B, S, C, Fq, T]
516
+ return out.to(init)
517
+
518
+ def valid_length(self, length: int):
519
+ """
520
+ Return a length that is appropriate for evaluation.
521
+ In our case, always return the training length, unless
522
+ it is smaller than the given length, in which case this
523
+ raises an error.
524
+ """
525
+ if not self.use_train_segment:
526
+ return length
527
+ training_length = int(self.segment * self.samplerate)
528
+ if training_length < length:
529
+ raise ValueError(
530
+ f"Given length {length} is longer than "
531
+ f"training length {training_length}")
532
+ return training_length
533
+
534
+ def cac2cws(self, x):
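+ # Fold the `num_subbands` sub-bands of the frequency axis into the channel axis: (B, C, F, T) -> (B, C*k, F//k, T).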
535
+ k = self.num_subbands
536
+ b, c, f, t = x.shape
537
+ x = x.reshape(b, c, k, f // k, t)
538
+ x = x.reshape(b, c * k, f // k, t)
539
+ return x
540
+
541
+ def cws2cac(self, x):
542
+ k = self.num_subbands
543
+ b, c, f, t = x.shape
544
+ x = x.reshape(b, c // k, k, f, t)
545
+ x = x.reshape(b, c // k, f * k, t)
546
+ return x
547
+
548
+ def forward(self, mix):
549
+ length = mix.shape[-1]
550
+ length_pre_pad = None
551
+ if self.use_train_segment:
552
+ if self.training:
553
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
554
+ else:
555
+ training_length = int(self.segment * self.samplerate)
556
+ # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate))
557
+ if mix.shape[-1] < training_length:
558
+ length_pre_pad = mix.shape[-1]
559
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
560
+ # print("Mix: {}".format(mix.shape))
561
+ # print("Length: {}".format(length))
562
+ z = self._spec(mix)
563
+ # print("Z: {} Type: {}".format(z.shape, z.dtype))
564
+ mag = self._magnitude(z)
565
+ x = mag
566
+ # print("MAG: {} Type: {}".format(x.shape, x.dtype))
567
+
568
+ if self.num_subbands > 1:
569
+ x = self.cac2cws(x)
570
+ # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype))
571
+
572
+ B, C, Fq, T = x.shape
573
+
574
+ # unlike previous Demucs, we always normalize because it is easier.
575
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
576
+ std = x.std(dim=(1, 2, 3), keepdim=True)
577
+ x = (x - mean) / (1e-5 + std)
578
+ # x will be the freq. branch input.
579
+
580
+ # Prepare the time branch input.
581
+ xt = mix
582
+ meant = xt.mean(dim=(1, 2), keepdim=True)
583
+ stdt = xt.std(dim=(1, 2), keepdim=True)
584
+ xt = (xt - meant) / (1e-5 + stdt)
585
+
586
+ # print("XT: {}".format(xt.shape))
587
+
588
+ # okay, this is a giant mess I know...
589
+ saved = [] # skip connections, freq.
590
+ saved_t = [] # skip connections, time.
591
+ lengths = [] # saved lengths to properly remove padding, freq branch.
592
+ lengths_t = [] # saved lengths for time branch.
593
+ for idx, encode in enumerate(self.encoder):
594
+ lengths.append(x.shape[-1])
595
+ inject = None
596
+ if idx < len(self.tencoder):
597
+ # we have not yet merged branches.
598
+ lengths_t.append(xt.shape[-1])
599
+ tenc = self.tencoder[idx]
600
+ xt = tenc(xt)
601
+ # print("Encode XT {}: {}".format(idx, xt.shape))
602
+ if not tenc.empty:
603
+ # save for skip connection
604
+ saved_t.append(xt)
605
+ else:
606
+ # tenc contains just the first conv., so that now time and freq.
607
+ # branches have the same shape and can be merged.
608
+ inject = xt
609
+ x = encode(x, inject)
610
+ # print("Encode X {}: {}".format(idx, x.shape))
611
+ if idx == 0 and self.freq_emb is not None:
612
+ # add frequency embedding to allow for non equivariant convolutions
613
+ # over the frequency axis.
614
+ frs = torch.arange(x.shape[-2], device=x.device)
615
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
616
+ x = x + self.freq_emb_scale * emb
617
+
618
+ saved.append(x)
619
+ if self.crosstransformer:
620
+ if self.bottom_channels:
621
+ b, c, f, t = x.shape
622
+ x = rearrange(x, "b c f t-> b c (f t)")
623
+ x = self.channel_upsampler(x)
624
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
625
+ xt = self.channel_upsampler_t(xt)
626
+
627
+ x, xt = self.crosstransformer(x, xt)
628
+ # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape))
629
+
630
+ if self.bottom_channels:
631
+ x = rearrange(x, "b c f t-> b c (f t)")
632
+ x = self.channel_downsampler(x)
633
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
634
+ xt = self.channel_downsampler_t(xt)
635
+
636
+ for idx, decode in enumerate(self.decoder):
637
+ skip = saved.pop(-1)
638
+ x, pre = decode(x, skip, lengths.pop(-1))
639
+ # print('Decode {} X: {}'.format(idx, x.shape))
640
+ # `pre` contains the output just before final transposed convolution,
641
+ # which is used when the freq. and time branch separate.
642
+
643
+ offset = self.depth - len(self.tdecoder)
644
+ if idx >= offset:
645
+ tdec = self.tdecoder[idx - offset]
646
+ length_t = lengths_t.pop(-1)
647
+ if tdec.empty:
648
+ assert pre.shape[2] == 1, pre.shape
649
+ pre = pre[:, :, 0]
650
+ xt, _ = tdec(pre, None, length_t)
651
+ else:
652
+ skip = saved_t.pop(-1)
653
+ xt, _ = tdec(xt, skip, length_t)
654
+ # print('Decode {} XT: {}'.format(idx, xt.shape))
655
+
656
+ # Let's make sure we used all stored skip connections.
657
+ assert len(saved) == 0
658
+ assert len(lengths_t) == 0
659
+ assert len(saved_t) == 0
660
+
661
+ S = len(self.sources)
662
+
663
+ if self.num_subbands > 1:
664
+ x = x.view(B, -1, Fq, T)
665
+ # print("X view 1: {}".format(x.shape))
666
+ x = self.cws2cac(x)
667
+ # print("X view 2: {}".format(x.shape))
668
+
669
+ x = x.view(B, S, -1, Fq * self.num_subbands, T)
670
+ x = x * std[:, None] + mean[:, None]
671
+ # print("X returned: {}".format(x.shape))
672
+
673
+ zout = self._mask(z, x)
674
+ if self.use_train_segment:
675
+ if self.training:
676
+ x = self._ispec(zout, length)
677
+ else:
678
+ x = self._ispec(zout, training_length)
679
+ else:
680
+ x = self._ispec(zout, length)
681
+
682
+ if self.use_train_segment:
683
+ if self.training:
684
+ xt = xt.view(B, S, -1, length)
685
+ else:
686
+ xt = xt.view(B, S, -1, training_length)
687
+ else:
688
+ xt = xt.view(B, S, -1, length)
689
+ xt = xt * stdt[:, None] + meant[:, None]
690
+ x = xt + x
691
+ if length_pre_pad:
692
+ x = x[..., :length_pre_pad]
693
+ return x
694
+
695
+
696
+ def get_model(args):
697
+ extra = {
698
+ 'sources': list(args.training.instruments),
699
+ 'audio_channels': args.training.channels,
700
+ 'samplerate': args.training.samplerate,
701
+ # 'segment': args.model_segment or 4 * args.dset.segment,
702
+ 'segment': args.training.segment,
703
+ }
704
+ klass = {
705
+ 'demucs': Demucs,
706
+ 'hdemucs': HDemucs,
707
+ 'htdemucs': HTDemucs,
708
+ }[args.model]
709
+ kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
710
+ model = klass(**extra, **kw)
711
+ return model
712
+
713
+
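Note on the `get_model` helper above: it instantiates the selected class from an OmegaConf configuration, pulling shared fields from `args.training` and forwarding the per-model section as keyword arguments. A minimal usage sketch (the configuration values below are illustrative assumptions, not taken from this commit):

    from omegaconf import OmegaConf

    # Hypothetical config; the 'htdemucs' section is left empty so the
    # constructor defaults are used for everything except the fields in `extra`.
    args = OmegaConf.create({
        'model': 'htdemucs',
        'training': {
            'instruments': ['vocals', 'other'],
            'channels': 2,
            'samplerate': 44100,
            'segment': 11,
        },
        'htdemucs': {},
    })

    model = get_model(args)  # -> HTDemucs(sources=['vocals', 'other'], audio_channels=2, ...)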
models/scnet/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .scnet import SCNet
models/scnet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (219 Bytes). View file
 
models/scnet/__pycache__/scnet.cpython-311.pyc ADDED
Binary file (20.7 kB). View file
 
models/scnet/__pycache__/separation.cpython-311.pyc ADDED
Binary file (8.43 kB). View file
 
models/scnet/scnet.py ADDED
@@ -0,0 +1,373 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import deque
5
+ from .separation import SeparationNet
6
+ import typing as tp
7
+ import math
8
+
9
+
10
+ class Swish(nn.Module):
11
+ def forward(self, x):
12
+ return x * x.sigmoid()
13
+
14
+
15
+ class ConvolutionModule(nn.Module):
16
+ """
17
+ Convolution Module in SD block.
18
+
19
+ Args:
20
+ channels (int): input/output channels.
21
+ depth (int): number of layers in the residual branch. Each layer has its own residual connection.
22
+ compress (float): amount of channel compression.
23
+ kernel (int): kernel size for the convolutions.
24
+ """
25
+
26
+ def __init__(self, channels, depth=2, compress=4, kernel=3):
27
+ super().__init__()
28
+ assert kernel % 2 == 1
29
+ self.depth = abs(depth)
30
+ hidden_size = int(channels / compress)
31
+ norm = lambda d: nn.GroupNorm(1, d)
32
+ self.layers = nn.ModuleList([])
33
+ for _ in range(self.depth):
34
+ padding = (kernel // 2)
35
+ mods = [
36
+ norm(channels),
37
+ nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding),
38
+ nn.GLU(1),
39
+ nn.Conv1d(hidden_size, hidden_size, kernel, padding=padding, groups=hidden_size),
40
+ norm(hidden_size),
41
+ Swish(),
42
+ nn.Conv1d(hidden_size, channels, 1),
43
+ ]
44
+ layer = nn.Sequential(*mods)
45
+ self.layers.append(layer)
46
+
47
+ def forward(self, x):
48
+ for layer in self.layers:
49
+ x = x + layer(x)
50
+ return x
51
+
52
+
53
+ class FusionLayer(nn.Module):
54
+ """
55
+ A FusionLayer within the decoder.
56
+
57
+ Args:
58
+ - channels (int): Number of input channels.
59
+ - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3.
60
+ - stride (int, optional): Stride for the convolutional layer, defaults to 1.
61
+ - padding (int, optional): Padding for the convolutional layer, defaults to 1.
62
+ """
63
+
64
+ def __init__(self, channels, kernel_size=3, stride=1, padding=1):
65
+ super(FusionLayer, self).__init__()
66
+ self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding)
67
+
68
+ def forward(self, x, skip=None):
69
+ if skip is not None:
70
+ x += skip
71
+ x = x.repeat(1, 2, 1, 1)
72
+ x = self.conv(x)
73
+ x = F.glu(x, dim=1)
74
+ return x
75
+
76
+
77
+ class SDlayer(nn.Module):
78
+ """
79
+ Implements a Sparse Down-sample Layer for processing different frequency bands separately.
80
+
81
+ Args:
82
+ - channels_in (int): Input channel count.
83
+ - channels_out (int): Output channel count.
84
+ - band_configs (dict): A dictionary containing configuration for each frequency band.
85
+ Keys are 'low', 'mid', 'high' for each band, and values are
86
+ dictionaries with keys 'SR', 'stride', and 'kernel' for proportion,
87
+ stride, and kernel size, respectively.
88
+ """
89
+
90
+ def __init__(self, channels_in, channels_out, band_configs):
91
+ super(SDlayer, self).__init__()
92
+
93
+ # Initializing convolutional layers for each band
94
+ self.convs = nn.ModuleList()
95
+ self.strides = []
96
+ self.kernels = []
97
+ for config in band_configs.values():
98
+ self.convs.append(
99
+ nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0)))
100
+ self.strides.append(config['stride'])
101
+ self.kernels.append(config['kernel'])
102
+
103
+ # Saving rate proportions for determining splits
104
+ self.SR_low = band_configs['low']['SR']
105
+ self.SR_mid = band_configs['mid']['SR']
106
+
107
+ def forward(self, x):
108
+ B, C, Fr, T = x.shape
109
+ # Define splitting points based on sampling rates
110
+ splits = [
111
+ (0, math.ceil(Fr * self.SR_low)),
112
+ (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
113
+ (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr)
114
+ ]
115
+
116
+ # Processing each band with the corresponding convolution
117
+ outputs = []
118
+ original_lengths = []
119
+ for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits):
120
+ extracted = x[:, :, start:end, :]
121
+ original_lengths.append(end - start)
122
+ current_length = extracted.shape[2]
123
+
124
+ # padding
125
+ if stride == 1:
126
+ total_padding = kernel - stride
127
+ else:
128
+ total_padding = (stride - current_length % stride) % stride
129
+ pad_left = total_padding // 2
130
+ pad_right = total_padding - pad_left
131
+
132
+ padded = F.pad(extracted, (0, 0, pad_left, pad_right))
133
+
134
+ output = conv(padded)
135
+ outputs.append(output)
136
+
137
+ return outputs, original_lengths
138
+
139
+
140
+ class SUlayer(nn.Module):
141
+ """
142
+ Implements a Sparse Up-sample Layer in decoder.
143
+
144
+ Args:
145
+ - channels_in: The number of input channels.
146
+ - channels_out: The number of output channels.
147
+ - band_configs: Dictionary containing the transposed-convolution configuration for each band.
148
+ """
149
+
150
+ def __init__(self, channels_in, channels_out, band_configs):
151
+ super(SUlayer, self).__init__()
152
+
153
+ # Initializing convolutional layers for each band
154
+ self.convtrs = nn.ModuleList([
155
+ nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1])
156
+ for _, config in band_configs.items()
157
+ ])
158
+
159
+ def forward(self, x, lengths, origin_lengths):
160
+ B, C, Fr, T = x.shape
161
+ # Define splitting points based on input lengths
162
+ splits = [
163
+ (0, lengths[0]),
164
+ (lengths[0], lengths[0] + lengths[1]),
165
+ (lengths[0] + lengths[1], None)
166
+ ]
167
+ # Processing each band with the corresponding convolution
168
+ outputs = []
169
+ for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)):
170
+ out = convtr(x[:, :, start:end, :])
171
+ # Calculate the distance to trim the output symmetrically to original length
172
+ current_Fr_length = out.shape[2]
173
+ dist = abs(origin_lengths[idx] - current_Fr_length) // 2
174
+
175
+ # Trim the output to the original length symmetrically
176
+ trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :]
177
+
178
+ outputs.append(trimmed_out)
179
+
180
+ # Concatenate trimmed outputs along the frequency dimension to return the final tensor
181
+ x = torch.cat(outputs, dim=2)
182
+
183
+ return x
184
+
185
+
186
+ class SDblock(nn.Module):
187
+ """
188
+ Implements a simplified Sparse Down-sample block in encoder.
189
+
190
+ Args:
191
+ - channels_in (int): Number of input channels.
192
+ - channels_out (int): Number of output channels.
193
+ - band_configs (dict): Configuration for the SDlayer specifying band splits and convolutions.
194
+ - conv_config (dict): Configuration for convolution modules applied to each band.
195
+ - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
196
+ """
197
+
198
+ def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3):
199
+ super(SDblock, self).__init__()
200
+ self.SDlayer = SDlayer(channels_in, channels_out, band_configs)
201
+
202
+ # Dynamically create convolution modules for each band based on depths
203
+ self.conv_modules = nn.ModuleList([
204
+ ConvolutionModule(channels_out, depth, **conv_config) for depth in depths
205
+ ])
206
+ # Set the kernel_size to an odd number.
207
+ self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2)
208
+
209
+ def forward(self, x):
210
+ bands, original_lengths = self.SDlayer(x)
211
+ # B, C, f, T = band.shape
212
+ bands = [
213
+ F.gelu(
214
+ conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3]))
215
+ .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3])
216
+ .permute(0, 2, 1, 3)
217
+ )
218
+ for conv, band in zip(self.conv_modules, bands)
219
+
220
+ ]
221
+ lengths = [band.size(-2) for band in bands]
222
+ full_band = torch.cat(bands, dim=2)
223
+ skip = full_band
224
+
225
+ output = self.globalconv(full_band)
226
+
227
+ return output, skip, lengths, original_lengths
228
+
229
+
230
+ class SCNet(nn.Module):
231
+ """
232
+ The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276
233
+
234
+ Args:
235
+ - sources (List[str]): List of sources to be separated.
236
+ - audio_channels (int): Number of audio channels.
237
+ - nfft (int): Number of FFTs to determine the frequency dimension of the input.
238
+ - hop_size (int): Hop size for the STFT.
239
+ - win_size (int): Window size for STFT.
240
+ - normalized (bool): Whether to normalize the STFT.
241
+ - dims (List[int]): List of channel dimensions for each block.
242
+ - band_SR (List[float]): The proportion of each frequency band.
243
+ - band_stride (List[int]): The down-sampling ratio of each frequency band.
244
+ - band_kernel (List[int]): The kernel sizes for down-sampling convolution in each frequency band
245
+ - conv_depths (List[int]): List specifying the number of convolution modules in each SD block.
246
+ - compress (int): Compression factor for convolution module.
247
+ - conv_kernel (int): Kernel size for convolution layer in convolution module.
248
+ - num_dplayer (int): Number of dual-path layers.
249
+ - expand (int): Expansion factor in the dual-path RNN, default is 1.
250
+
251
+ """
252
+
253
+ def __init__(self,
254
+ sources=['drums', 'bass', 'other', 'vocals'],
255
+ audio_channels=2,
256
+ # Main structure
257
+ dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
258
+ # STFT
259
+ nfft=4096,
260
+ hop_size=1024,
261
+ win_size=4096,
262
+ normalized=True,
263
+ # SD/SU layer
264
+ band_SR=[0.175, 0.392, 0.433],
265
+ band_stride=[1, 4, 16],
266
+ band_kernel=[3, 4, 16],
267
+ # Convolution Module
268
+ conv_depths=[3, 2, 1],
269
+ compress=4,
270
+ conv_kernel=3,
271
+ # Dual-path RNN
272
+ num_dplayer=6,
273
+ expand=1,
274
+ ):
275
+ super().__init__()
276
+ self.sources = sources
277
+ self.audio_channels = audio_channels
278
+ self.dims = dims
279
+ band_keys = ['low', 'mid', 'high']
280
+ self.band_configs = {band_keys[i]: {'SR': band_SR[i], 'stride': band_stride[i], 'kernel': band_kernel[i]} for i
281
+ in range(len(band_keys))}
282
+ self.hop_length = hop_size
283
+ self.conv_config = {
284
+ 'compress': compress,
285
+ 'kernel': conv_kernel,
286
+ }
287
+
288
+ self.stft_config = {
289
+ 'n_fft': nfft,
290
+ 'hop_length': hop_size,
291
+ 'win_length': win_size,
292
+ 'center': True,
293
+ 'normalized': normalized
294
+ }
295
+
296
+ self.encoder = nn.ModuleList()
297
+ self.decoder = nn.ModuleList()
298
+
299
+ for index in range(len(dims) - 1):
300
+ enc = SDblock(
301
+ channels_in=dims[index],
302
+ channels_out=dims[index + 1],
303
+ band_configs=self.band_configs,
304
+ conv_config=self.conv_config,
305
+ depths=conv_depths
306
+ )
307
+ self.encoder.append(enc)
308
+
309
+ dec = nn.Sequential(
310
+ FusionLayer(channels=dims[index + 1]),
311
+ SUlayer(
312
+ channels_in=dims[index + 1],
313
+ channels_out=dims[index] if index != 0 else dims[index] * len(sources),
314
+ band_configs=self.band_configs,
315
+ )
316
+ )
317
+ self.decoder.insert(0, dec)
318
+
319
+ self.separation_net = SeparationNet(
320
+ channels=dims[-1],
321
+ expand=expand,
322
+ num_layers=num_dplayer,
323
+ )
324
+
325
+ def forward(self, x):
326
+ # B, C, L = x.shape
327
+ B = x.shape[0]
328
+ # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even,
329
+ # so that the RFFT operation can be used in the separation network.
330
+ padding = self.hop_length - x.shape[-1] % self.hop_length
331
+ if (x.shape[-1] + padding) // self.hop_length % 2 == 0:
332
+ padding += self.hop_length
333
+ x = F.pad(x, (0, padding))
334
+
335
+ # STFT
336
+ L = x.shape[-1]
337
+ x = x.reshape(-1, L)
338
+ x = torch.stft(x, **self.stft_config, return_complex=True)
339
+ x = torch.view_as_real(x)
340
+ x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels,
341
+ x.shape[1], x.shape[2])
342
+
343
+ B, C, Fr, T = x.shape
344
+
345
+ save_skip = deque()
346
+ save_lengths = deque()
347
+ save_original_lengths = deque()
348
+ # encoder
349
+ for sd_layer in self.encoder:
350
+ x, skip, lengths, original_lengths = sd_layer(x)
351
+ save_skip.append(skip)
352
+ save_lengths.append(lengths)
353
+ save_original_lengths.append(original_lengths)
354
+
355
+ # separation
356
+ x = self.separation_net(x)
357
+
358
+ # decoder
359
+ for fusion_layer, su_layer in self.decoder:
360
+ x = fusion_layer(x, save_skip.pop())
361
+ x = su_layer(x, save_lengths.pop(), save_original_lengths.pop())
362
+
363
+ # output
364
+ n = self.dims[0]
365
+ x = x.view(B, n, -1, Fr, T)
366
+ x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1)
367
+ x = torch.view_as_complex(x.contiguous())
368
+ x = torch.istft(x, **self.stft_config)
369
+ x = x.reshape(B, len(self.sources), self.audio_channels, -1)
370
+
371
+ x = x[:, :, :, :-padding]
372
+
373
+ return x
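A quick shape check for `SCNet.forward` above, as a sketch with random input and the default constructor arguments (sizes are illustrative, not from this commit):

    import torch

    model = SCNet()                    # defaults: 4 sources, stereo, nfft=4096, hop 1024
    mix = torch.randn(1, 2, 44100)     # (batch, audio_channels, samples)
    with torch.no_grad():
        out = model(mix)
    print(out.shape)                   # expected: torch.Size([1, 4, 2, 44100])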
models/scnet/separation.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.modules.rnn import LSTM
4
+
5
+
6
+ class FeatureConversion(nn.Module):
7
+ """
8
+ Converts features between adjacent Dual-Path layers by applying an RFFT (or its inverse) along the time-frame axis.
9
+
10
+ Args:
11
+ channels (int): Number of input channels.
12
+ inverse (bool): If True, uses ifft; otherwise, uses rfft.
13
+ """
14
+
15
+ def __init__(self, channels, inverse):
16
+ super().__init__()
17
+ self.inverse = inverse
18
+ self.channels = channels
19
+
20
+ def forward(self, x):
21
+ # B, C, F, T = x.shape
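+ # the (i)rfft runs along the last axis (time frames); real and imaginary parts are split from / merged into the channel axis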
22
+ if self.inverse:
23
+ x = x.float()
24
+ x_r = x[:, :self.channels // 2, :, :]
25
+ x_i = x[:, self.channels // 2:, :, :]
26
+ x = torch.complex(x_r, x_i)
27
+ x = torch.fft.irfft(x, dim=3, norm="ortho")
28
+ else:
29
+ x = x.float()
30
+ x = torch.fft.rfft(x, dim=3, norm="ortho")
31
+ x_real = x.real
32
+ x_imag = x.imag
33
+ x = torch.cat([x_real, x_imag], dim=1)
34
+ return x
35
+
36
+
37
+ class DualPathRNN(nn.Module):
38
+ """
39
+ Dual-Path RNN in Separation Network.
40
+
41
+ Args:
42
+ d_model (int): The number of expected features in the input (input_size).
43
+ expand (int): Expansion factor used to calculate the hidden_size of LSTM.
44
+ bidirectional (bool): If True, becomes a bidirectional LSTM.
45
+ """
46
+
47
+ def __init__(self, d_model, expand, bidirectional=True):
48
+ super(DualPathRNN, self).__init__()
49
+
50
+ self.d_model = d_model
51
+ self.hidden_size = d_model * expand
52
+ self.bidirectional = bidirectional
53
+ # Initialize LSTM layers and normalization layers
54
+ self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)])
55
+ self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)])
56
+ self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])
57
+
58
+ def _init_lstm_layer(self, d_model, hidden_size):
59
+ return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True)
60
+
61
+ def forward(self, x):
62
+ B, C, F, T = x.shape
63
+
64
+ # Process dual-path rnn
65
+ original_x = x
66
+ # Frequency-path
67
+ x = self.norm_layers[0](x)
68
+ x = x.transpose(1, 3).contiguous().view(B * T, F, C)
69
+ x, _ = self.lstm_layers[0](x)
70
+ x = self.linear_layers[0](x)
71
+ x = x.view(B, T, F, C).transpose(1, 3)
72
+ x = x + original_x
73
+
74
+ original_x = x
75
+ # Time-path
76
+ x = self.norm_layers[1](x)
77
+ x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
78
+ x, _ = self.lstm_layers[1](x)
79
+ x = self.linear_layers[1](x)
80
+ x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
81
+ x = x + original_x
82
+
83
+ return x
84
+
85
+
86
+ class SeparationNet(nn.Module):
87
+ """
88
+ Implements the separation network: a stack of Dual-Path RNN layers with feature conversion between them.
89
+
90
+ Args:
91
+ - channels (int): Number of input channels.
92
+ - expand (int): Expansion factor used to calculate the hidden_size of LSTM.
93
+ - num_layers (int): Number of dual-path layers.
94
+ """
95
+
96
+ def __init__(self, channels, expand=1, num_layers=6):
97
+ super(SeparationNet, self).__init__()
98
+
99
+ self.num_layers = num_layers
100
+
101
+ self.dp_modules = nn.ModuleList([
102
+ DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers)
103
+ ])
104
+
105
+ self.feature_conversion = nn.ModuleList([
106
+ FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers)
107
+ ])
108
+
109
+ def forward(self, x):
110
+ for i in range(self.num_layers):
111
+ x = self.dp_modules[i](x)
112
+ x = self.feature_conversion[i](x)
113
+ return x
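To make the alternating layer widths concrete: even-indexed `DualPathRNN` layers operate on `channels` features and odd-indexed ones on `channels * 2`, because the preceding `FeatureConversion` (rfft) stacks real and imaginary parts along the channel axis; the closing irfft restores the original shape. A toy-shape sketch (arbitrary sizes, not from this commit):

    import torch

    net = SeparationNet(channels=8, expand=1, num_layers=2)
    x = torch.randn(1, 8, 16, 10)   # (B, C, F, T); T even so the final irfft restores it
    y = net(x)
    print(y.shape)                  # expected: torch.Size([1, 8, 16, 10])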