nguyenvulebinh
commited on
Commit
·
fd06d88
1
Parent(s):
847e2fc
upload infer utils
Browse files- wav2filterbank.py +313 -0
- xvector_sincnet.py +223 -0
wav2filterbank.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
CONSTANT = 1e-5
|
10 |
+
|
11 |
+
def normalize_batch(x, seq_len, normalize_type):
|
12 |
+
x_mean = None
|
13 |
+
x_std = None
|
14 |
+
if normalize_type == "per_feature":
|
15 |
+
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
16 |
+
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
17 |
+
for i in range(x.shape[0]):
|
18 |
+
if x[i, :, : seq_len[i]].shape[1] == 1:
|
19 |
+
raise ValueError(
|
20 |
+
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
|
21 |
+
"in torch.std() returning nan. Make sure your audio length has enough samples for a single "
|
22 |
+
"feature (ex. at least `hop_length` for Mel Spectrograms)."
|
23 |
+
)
|
24 |
+
x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
|
25 |
+
x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
|
26 |
+
# make sure x_std is not zero
|
27 |
+
x_std += CONSTANT
|
28 |
+
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
|
29 |
+
elif normalize_type == "all_features":
|
30 |
+
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
31 |
+
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
32 |
+
for i in range(x.shape[0]):
|
33 |
+
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
|
34 |
+
x_std[i] = x[i, :, : seq_len[i].item()].std()
|
35 |
+
# make sure x_std is not zero
|
36 |
+
x_std += CONSTANT
|
37 |
+
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
|
38 |
+
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
|
39 |
+
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
|
40 |
+
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
|
41 |
+
return (
|
42 |
+
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
|
43 |
+
x_mean,
|
44 |
+
x_std,
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
return x, x_mean, x_std
|
48 |
+
|
49 |
+
def splice_frames(x, frame_splicing):
|
50 |
+
""" Stacks frames together across feature dim
|
51 |
+
|
52 |
+
input is batch_size, feature_dim, num_frames
|
53 |
+
output is batch_size, feature_dim*frame_splicing, num_frames
|
54 |
+
|
55 |
+
"""
|
56 |
+
seq = [x]
|
57 |
+
for n in range(1, frame_splicing):
|
58 |
+
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
|
59 |
+
return torch.cat(seq, dim=1)
|
60 |
+
|
61 |
+
class FilterbankFeatures(nn.Module):
|
62 |
+
"""Featurizer that converts wavs to Mel Spectrograms.
|
63 |
+
See AudioToMelSpectrogramPreprocessor for args.
|
64 |
+
|
65 |
+
"normalize": "per_feature",
|
66 |
+
"window_size": 0.025,
|
67 |
+
"sample_rate": 16000,
|
68 |
+
"window_stride": 0.01,
|
69 |
+
"window": "hann",
|
70 |
+
"features": 80,
|
71 |
+
"n_fft": 512,
|
72 |
+
"frame_splicing": 1,
|
73 |
+
"dither": 1e-05
|
74 |
+
|
75 |
+
n_window_size=window_size * sample_rate,
|
76 |
+
n_window_stride = window_stride * sample_rate,
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
sample_rate=16000,
|
82 |
+
n_window_size=400,
|
83 |
+
n_window_stride=160,
|
84 |
+
window="hann",
|
85 |
+
normalize="per_feature",
|
86 |
+
n_fft=512,
|
87 |
+
preemph=0.97,
|
88 |
+
nfilt=80,
|
89 |
+
lowfreq=0,
|
90 |
+
highfreq=None,
|
91 |
+
log=True,
|
92 |
+
log_zero_guard_type="add",
|
93 |
+
log_zero_guard_value=2 ** -24,
|
94 |
+
dither=CONSTANT,
|
95 |
+
pad_to=16,
|
96 |
+
max_duration=16.7,
|
97 |
+
frame_splicing=1,
|
98 |
+
exact_pad=False,
|
99 |
+
pad_value=0,
|
100 |
+
mag_power=2.0,
|
101 |
+
use_grads=False,
|
102 |
+
rng=None,
|
103 |
+
nb_augmentation_prob=0.0,
|
104 |
+
nb_max_freq=4000,
|
105 |
+
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
|
106 |
+
stft_conv=False, # Deprecated arguments; kept for config compatibility
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
if stft_conv or stft_exact_pad:
|
110 |
+
logging.warning(
|
111 |
+
"Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
|
112 |
+
"for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
|
113 |
+
"as needed."
|
114 |
+
)
|
115 |
+
if exact_pad and n_window_stride % 2 == 1:
|
116 |
+
raise NotImplementedError(
|
117 |
+
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
|
118 |
+
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
|
119 |
+
)
|
120 |
+
self.log_zero_guard_value = log_zero_guard_value
|
121 |
+
if (
|
122 |
+
n_window_size is None
|
123 |
+
or n_window_stride is None
|
124 |
+
or not isinstance(n_window_size, int)
|
125 |
+
or not isinstance(n_window_stride, int)
|
126 |
+
or n_window_size <= 0
|
127 |
+
or n_window_stride <= 0
|
128 |
+
):
|
129 |
+
raise ValueError(
|
130 |
+
f"{self} got an invalid value for either n_window_size or "
|
131 |
+
f"n_window_stride. Both must be positive ints."
|
132 |
+
)
|
133 |
+
logging.info(f"PADDING: {pad_to}")
|
134 |
+
|
135 |
+
self.win_length = n_window_size
|
136 |
+
self.hop_length = n_window_stride
|
137 |
+
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
|
138 |
+
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
|
139 |
+
|
140 |
+
if exact_pad:
|
141 |
+
logging.info("STFT using exact pad")
|
142 |
+
torch_windows = {
|
143 |
+
'hann': torch.hann_window,
|
144 |
+
'hamming': torch.hamming_window,
|
145 |
+
'blackman': torch.blackman_window,
|
146 |
+
'bartlett': torch.bartlett_window,
|
147 |
+
'none': None,
|
148 |
+
}
|
149 |
+
window_fn = torch_windows.get(window, None)
|
150 |
+
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
|
151 |
+
self.register_buffer("window", window_tensor)
|
152 |
+
self.stft = lambda x: torch.stft(
|
153 |
+
x,
|
154 |
+
n_fft=self.n_fft,
|
155 |
+
hop_length=self.hop_length,
|
156 |
+
win_length=self.win_length,
|
157 |
+
center=False if exact_pad else True,
|
158 |
+
window=self.window.to(dtype=torch.float),
|
159 |
+
return_complex=True,
|
160 |
+
)
|
161 |
+
|
162 |
+
self.normalize = normalize
|
163 |
+
self.log = log
|
164 |
+
self.dither = dither
|
165 |
+
self.frame_splicing = frame_splicing
|
166 |
+
self.nfilt = nfilt
|
167 |
+
self.preemph = preemph
|
168 |
+
self.pad_to = pad_to
|
169 |
+
highfreq = highfreq or sample_rate / 2
|
170 |
+
|
171 |
+
filterbanks = torch.tensor(
|
172 |
+
librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
|
173 |
+
dtype=torch.float,
|
174 |
+
).unsqueeze(0)
|
175 |
+
self.register_buffer("fb", filterbanks)
|
176 |
+
|
177 |
+
# Calculate maximum sequence length
|
178 |
+
max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
|
179 |
+
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
|
180 |
+
self.max_length = max_length + max_pad
|
181 |
+
self.pad_value = pad_value
|
182 |
+
self.mag_power = mag_power
|
183 |
+
|
184 |
+
# We want to avoid taking the log of zero
|
185 |
+
# There are two options: either adding or clamping to a small value
|
186 |
+
if log_zero_guard_type not in ["add", "clamp"]:
|
187 |
+
raise ValueError(
|
188 |
+
f"{self} received {log_zero_guard_type} for the "
|
189 |
+
f"log_zero_guard_type parameter. It must be either 'add' or "
|
190 |
+
f"'clamp'."
|
191 |
+
)
|
192 |
+
|
193 |
+
self.use_grads = use_grads
|
194 |
+
if not use_grads:
|
195 |
+
self.forward = torch.no_grad()(self.forward)
|
196 |
+
self._rng = random.Random() if rng is None else rng
|
197 |
+
self.nb_augmentation_prob = nb_augmentation_prob
|
198 |
+
if self.nb_augmentation_prob > 0.0:
|
199 |
+
if nb_max_freq >= sample_rate / 2:
|
200 |
+
self.nb_augmentation_prob = 0.0
|
201 |
+
else:
|
202 |
+
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
|
203 |
+
|
204 |
+
# log_zero_guard_value is the the small we want to use, we support
|
205 |
+
# an actual number, or "tiny", or "eps"
|
206 |
+
self.log_zero_guard_type = log_zero_guard_type
|
207 |
+
logging.debug(f"sr: {sample_rate}")
|
208 |
+
logging.debug(f"n_fft: {self.n_fft}")
|
209 |
+
logging.debug(f"win_length: {self.win_length}")
|
210 |
+
logging.debug(f"hop_length: {self.hop_length}")
|
211 |
+
logging.debug(f"n_mels: {nfilt}")
|
212 |
+
logging.debug(f"fmin: {lowfreq}")
|
213 |
+
logging.debug(f"fmax: {highfreq}")
|
214 |
+
logging.debug(f"using grads: {use_grads}")
|
215 |
+
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")
|
216 |
+
|
217 |
+
def log_zero_guard_value_fn(self, x):
|
218 |
+
if isinstance(self.log_zero_guard_value, str):
|
219 |
+
if self.log_zero_guard_value == "tiny":
|
220 |
+
return torch.finfo(x.dtype).tiny
|
221 |
+
elif self.log_zero_guard_value == "eps":
|
222 |
+
return torch.finfo(x.dtype).eps
|
223 |
+
else:
|
224 |
+
raise ValueError(
|
225 |
+
f"{self} received {self.log_zero_guard_value} for the "
|
226 |
+
f"log_zero_guard_type parameter. It must be either a "
|
227 |
+
f"number, 'tiny', or 'eps'"
|
228 |
+
)
|
229 |
+
else:
|
230 |
+
return self.log_zero_guard_value
|
231 |
+
|
232 |
+
def get_seq_len(self, seq_len):
|
233 |
+
# Assuming that center is True is stft_pad_amount = 0
|
234 |
+
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
|
235 |
+
seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
|
236 |
+
return seq_len.to(dtype=torch.long)
|
237 |
+
|
238 |
+
@property
|
239 |
+
def filter_banks(self):
|
240 |
+
return self.fb
|
241 |
+
|
242 |
+
def forward(self, x, seq_len, linear_spec=False):
|
243 |
+
seq_len = self.get_seq_len(seq_len.float())
|
244 |
+
|
245 |
+
if self.stft_pad_amount is not None:
|
246 |
+
x = torch.nn.functional.pad(
|
247 |
+
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
|
248 |
+
).squeeze(1)
|
249 |
+
|
250 |
+
# dither (only in training mode for eval determinism)
|
251 |
+
if self.training and self.dither > 0:
|
252 |
+
x += self.dither * torch.randn_like(x)
|
253 |
+
|
254 |
+
# do preemphasis
|
255 |
+
if self.preemph is not None:
|
256 |
+
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
|
257 |
+
|
258 |
+
# disable autocast to get full range of stft values
|
259 |
+
with torch.cuda.amp.autocast(enabled=False):
|
260 |
+
x = self.stft(x)
|
261 |
+
|
262 |
+
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
|
263 |
+
# guard is needed for sqrt if grads are passed through
|
264 |
+
guard = 0 if not self.use_grads else CONSTANT
|
265 |
+
x = torch.view_as_real(x)
|
266 |
+
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
267 |
+
|
268 |
+
if self.training and self.nb_augmentation_prob > 0.0:
|
269 |
+
for idx in range(x.shape[0]):
|
270 |
+
if self._rng.random() < self.nb_augmentation_prob:
|
271 |
+
x[idx, self._nb_max_fft_bin :, :] = 0.0
|
272 |
+
|
273 |
+
# get power spectrum
|
274 |
+
if self.mag_power != 1.0:
|
275 |
+
x = x.pow(self.mag_power)
|
276 |
+
|
277 |
+
# return plain spectrogram if required
|
278 |
+
if linear_spec:
|
279 |
+
return x, seq_len
|
280 |
+
|
281 |
+
# dot with filterbank energies
|
282 |
+
x = torch.matmul(self.fb.to(x.dtype), x)
|
283 |
+
# log features if required
|
284 |
+
if self.log:
|
285 |
+
if self.log_zero_guard_type == "add":
|
286 |
+
x = torch.log(x + self.log_zero_guard_value_fn(x))
|
287 |
+
elif self.log_zero_guard_type == "clamp":
|
288 |
+
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
|
289 |
+
else:
|
290 |
+
raise ValueError("log_zero_guard_type was not understood")
|
291 |
+
|
292 |
+
# frame splicing if required
|
293 |
+
if self.frame_splicing > 1:
|
294 |
+
x = splice_frames(x, self.frame_splicing)
|
295 |
+
|
296 |
+
# normalize if required
|
297 |
+
if self.normalize:
|
298 |
+
x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
|
299 |
+
|
300 |
+
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
|
301 |
+
max_len = x.size(-1)
|
302 |
+
mask = torch.arange(max_len).to(x.device)
|
303 |
+
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
|
304 |
+
x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
|
305 |
+
del mask
|
306 |
+
pad_to = self.pad_to
|
307 |
+
if pad_to == "max":
|
308 |
+
x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
|
309 |
+
elif pad_to > 0:
|
310 |
+
pad_amt = x.size(-1) % pad_to
|
311 |
+
if pad_amt != 0:
|
312 |
+
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
|
313 |
+
return x, seq_len
|
xvector_sincnet.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import warnings
|
7 |
+
from asteroid_filterbanks import Encoder, ParamSincFB
|
8 |
+
|
9 |
+
def merge_dict(defaults: dict, custom: dict = None):
|
10 |
+
params = dict(defaults)
|
11 |
+
if custom is not None:
|
12 |
+
params.update(custom)
|
13 |
+
return params
|
14 |
+
|
15 |
+
class StatsPool(nn.Module):
|
16 |
+
"""Statistics pooling
|
17 |
+
Compute temporal mean and (unbiased) standard deviation
|
18 |
+
and returns their concatenation.
|
19 |
+
Reference
|
20 |
+
---------
|
21 |
+
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
|
22 |
+
"""
|
23 |
+
|
24 |
+
def forward(
|
25 |
+
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
|
26 |
+
) -> torch.Tensor:
|
27 |
+
"""Forward pass
|
28 |
+
Parameters
|
29 |
+
----------
|
30 |
+
sequences : (batch, channel, frames) torch.Tensor
|
31 |
+
Sequences.
|
32 |
+
weights : (batch, frames) torch.Tensor, optional
|
33 |
+
When provided, compute weighted mean and standard deviation.
|
34 |
+
Returns
|
35 |
+
-------
|
36 |
+
output : (batch, 2 * channel) torch.Tensor
|
37 |
+
Concatenation of mean and (unbiased) standard deviation.
|
38 |
+
"""
|
39 |
+
|
40 |
+
if weights is None:
|
41 |
+
mean = sequences.mean(dim=2)
|
42 |
+
std = sequences.std(dim=2, unbiased=True)
|
43 |
+
|
44 |
+
else:
|
45 |
+
weights = weights.unsqueeze(dim=1)
|
46 |
+
# (batch, 1, frames)
|
47 |
+
|
48 |
+
num_frames = sequences.shape[2]
|
49 |
+
num_weights = weights.shape[2]
|
50 |
+
if num_frames != num_weights:
|
51 |
+
warnings.warn(
|
52 |
+
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
|
53 |
+
)
|
54 |
+
weights = F.interpolate(
|
55 |
+
weights, size=num_frames, mode="linear", align_corners=False
|
56 |
+
)
|
57 |
+
|
58 |
+
v1 = weights.sum(dim=2)
|
59 |
+
mean = torch.sum(sequences * weights, dim=2) / v1
|
60 |
+
|
61 |
+
dx2 = torch.square(sequences - mean.unsqueeze(2))
|
62 |
+
v2 = torch.square(weights).sum(dim=2)
|
63 |
+
|
64 |
+
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1)
|
65 |
+
std = torch.sqrt(var)
|
66 |
+
|
67 |
+
return torch.cat([mean, std], dim=1)
|
68 |
+
|
69 |
+
class SincNet(nn.Module):
|
70 |
+
def __init__(self, sample_rate: int = 16000, stride: int = 1):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
if sample_rate != 16000:
|
74 |
+
raise NotImplementedError("PyanNet only supports 16kHz audio for now.")
|
75 |
+
# TODO: add support for other sample rate. it should be enough to multiply
|
76 |
+
# kernel_size by (sample_rate / 16000). but this needs to be double-checked.
|
77 |
+
|
78 |
+
self.stride = stride
|
79 |
+
|
80 |
+
self.wav_norm1d = nn.InstanceNorm1d(1, affine=True)
|
81 |
+
|
82 |
+
self.conv1d = nn.ModuleList()
|
83 |
+
self.pool1d = nn.ModuleList()
|
84 |
+
self.norm1d = nn.ModuleList()
|
85 |
+
|
86 |
+
self.conv1d.append(
|
87 |
+
Encoder(
|
88 |
+
ParamSincFB(
|
89 |
+
80,
|
90 |
+
251,
|
91 |
+
stride=self.stride,
|
92 |
+
sample_rate=sample_rate,
|
93 |
+
min_low_hz=50,
|
94 |
+
min_band_hz=50,
|
95 |
+
)
|
96 |
+
)
|
97 |
+
)
|
98 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
99 |
+
self.norm1d.append(nn.InstanceNorm1d(80, affine=True))
|
100 |
+
|
101 |
+
self.conv1d.append(nn.Conv1d(80, 60, 5, stride=1))
|
102 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
103 |
+
self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
|
104 |
+
|
105 |
+
self.conv1d.append(nn.Conv1d(60, 60, 5, stride=1))
|
106 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
107 |
+
self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
|
108 |
+
|
109 |
+
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
|
110 |
+
"""Pass forward
|
111 |
+
Parameters
|
112 |
+
----------
|
113 |
+
waveforms : (batch, channel, sample)
|
114 |
+
"""
|
115 |
+
|
116 |
+
outputs = self.wav_norm1d(waveforms)
|
117 |
+
|
118 |
+
for c, (conv1d, pool1d, norm1d) in enumerate(
|
119 |
+
zip(self.conv1d, self.pool1d, self.norm1d)
|
120 |
+
):
|
121 |
+
|
122 |
+
outputs = conv1d(outputs)
|
123 |
+
|
124 |
+
# https://github.com/mravanelli/SincNet/issues/4
|
125 |
+
if c == 0:
|
126 |
+
outputs = torch.abs(outputs)
|
127 |
+
|
128 |
+
outputs = F.leaky_relu(norm1d(pool1d(outputs)))
|
129 |
+
|
130 |
+
return outputs
|
131 |
+
|
132 |
+
class XVectorSincNet(nn.Module):
|
133 |
+
|
134 |
+
SINCNET_DEFAULTS = {"stride": 10}
|
135 |
+
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
sample_rate: int = 16000,
|
139 |
+
# num_channels: int = 1,
|
140 |
+
sincnet: dict = dict(
|
141 |
+
stride=10,
|
142 |
+
sample_rate=16000
|
143 |
+
),
|
144 |
+
dimension: int = 512,
|
145 |
+
# task: Optional[Task] = None,
|
146 |
+
):
|
147 |
+
super(XVectorSincNet, self).__init__()
|
148 |
+
|
149 |
+
sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
|
150 |
+
sincnet["sample_rate"] = sample_rate
|
151 |
+
|
152 |
+
# self.save_hyperparameters("sincnet", "dimension")
|
153 |
+
|
154 |
+
self.sincnet = SincNet(**sincnet)
|
155 |
+
in_channel = 60
|
156 |
+
|
157 |
+
self.tdnns = nn.ModuleList()
|
158 |
+
out_channels = [512, 512, 512, 512, 1500]
|
159 |
+
kernel_sizes = [5, 3, 3, 1, 1]
|
160 |
+
dilations = [1, 2, 3, 1, 1]
|
161 |
+
|
162 |
+
for out_channel, kernel_size, dilation in zip(
|
163 |
+
out_channels, kernel_sizes, dilations
|
164 |
+
):
|
165 |
+
self.tdnns.extend(
|
166 |
+
[
|
167 |
+
nn.Conv1d(
|
168 |
+
in_channels=in_channel,
|
169 |
+
out_channels=out_channel,
|
170 |
+
kernel_size=kernel_size,
|
171 |
+
dilation=dilation,
|
172 |
+
),
|
173 |
+
nn.LeakyReLU(),
|
174 |
+
nn.BatchNorm1d(out_channel),
|
175 |
+
]
|
176 |
+
)
|
177 |
+
in_channel = out_channel
|
178 |
+
|
179 |
+
self.stats_pool = StatsPool()
|
180 |
+
|
181 |
+
self.embedding = nn.Linear(in_channel * 2, dimension)
|
182 |
+
|
183 |
+
def forward(
|
184 |
+
self, waveforms: torch.Tensor, weights: torch.Tensor = None
|
185 |
+
) -> torch.Tensor:
|
186 |
+
"""
|
187 |
+
Parameters
|
188 |
+
----------
|
189 |
+
waveforms : torch.Tensor
|
190 |
+
Batch of waveforms with shape (batch, channel, sample)
|
191 |
+
weights : torch.Tensor, optional
|
192 |
+
Batch of weights with shape (batch, frame).
|
193 |
+
"""
|
194 |
+
|
195 |
+
outputs = self.sincnet(waveforms).squeeze(dim=1)
|
196 |
+
for tdnn in self.tdnns:
|
197 |
+
outputs = tdnn(outputs)
|
198 |
+
outputs = self.stats_pool(outputs, weights=weights)
|
199 |
+
return self.embedding(outputs)
|
200 |
+
|
201 |
+
|
202 |
+
""" Load model
|
203 |
+
|
204 |
+
def cal_xvector_sincnet_embedding(xvector_model, ref_wav, max_length=5, sr=16000):
|
205 |
+
wavs = []
|
206 |
+
for i in range(0, len(ref_wav), max_length*sr):
|
207 |
+
wav = ref_wav[i:i + max_length*sr]
|
208 |
+
wav = np.concatenate([wav, np.zeros(max(0, max_length * sr - len(wav)))])
|
209 |
+
wavs.append(wav)
|
210 |
+
wavs = torch.from_numpy(np.stack(wavs))
|
211 |
+
if use_gpu:
|
212 |
+
wavs = wavs.cuda()
|
213 |
+
embed = xvector_model(wavs.unsqueeze(1).float())
|
214 |
+
return torch.mean(embed, dim=0).detach().cpu()
|
215 |
+
|
216 |
+
xvector_model = XVectorSincNet()
|
217 |
+
model_file = "model-bin/speaker_embedding/xvector_sincnet.pt"
|
218 |
+
meta = torch.load(model_file, map_location='cpu')['state_dict']
|
219 |
+
print('load_xvector_sincnet_model', xvector_model.load_state_dict(meta, strict=False))
|
220 |
+
xvector_model = xvector_model.eval()
|
221 |
+
for param in xvector_model.parameters():
|
222 |
+
param.requires_grad = False
|
223 |
+
"""
|