nguyenvulebinh commited on
Commit
fd06d88
·
1 Parent(s): 847e2fc

upload infer utils

Browse files
Files changed (2) hide show
  1. wav2filterbank.py +313 -0
  2. xvector_sincnet.py +223 -0
wav2filterbank.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import logging
6
+ import math
7
+ import random
8
+
9
+ CONSTANT = 1e-5
10
+
11
+ def normalize_batch(x, seq_len, normalize_type):
12
+ x_mean = None
13
+ x_std = None
14
+ if normalize_type == "per_feature":
15
+ x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
16
+ x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
17
+ for i in range(x.shape[0]):
18
+ if x[i, :, : seq_len[i]].shape[1] == 1:
19
+ raise ValueError(
20
+ "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
21
+ "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
22
+ "feature (ex. at least `hop_length` for Mel Spectrograms)."
23
+ )
24
+ x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
25
+ x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
26
+ # make sure x_std is not zero
27
+ x_std += CONSTANT
28
+ return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
29
+ elif normalize_type == "all_features":
30
+ x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
31
+ x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
32
+ for i in range(x.shape[0]):
33
+ x_mean[i] = x[i, :, : seq_len[i].item()].mean()
34
+ x_std[i] = x[i, :, : seq_len[i].item()].std()
35
+ # make sure x_std is not zero
36
+ x_std += CONSTANT
37
+ return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
38
+ elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
39
+ x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
40
+ x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
41
+ return (
42
+ (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
43
+ x_mean,
44
+ x_std,
45
+ )
46
+ else:
47
+ return x, x_mean, x_std
48
+
49
+ def splice_frames(x, frame_splicing):
50
+ """ Stacks frames together across feature dim
51
+
52
+ input is batch_size, feature_dim, num_frames
53
+ output is batch_size, feature_dim*frame_splicing, num_frames
54
+
55
+ """
56
+ seq = [x]
57
+ for n in range(1, frame_splicing):
58
+ seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
59
+ return torch.cat(seq, dim=1)
60
+
61
+ class FilterbankFeatures(nn.Module):
62
+ """Featurizer that converts wavs to Mel Spectrograms.
63
+ See AudioToMelSpectrogramPreprocessor for args.
64
+
65
+ "normalize": "per_feature",
66
+ "window_size": 0.025,
67
+ "sample_rate": 16000,
68
+ "window_stride": 0.01,
69
+ "window": "hann",
70
+ "features": 80,
71
+ "n_fft": 512,
72
+ "frame_splicing": 1,
73
+ "dither": 1e-05
74
+
75
+ n_window_size=window_size * sample_rate,
76
+ n_window_stride = window_stride * sample_rate,
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ sample_rate=16000,
82
+ n_window_size=400,
83
+ n_window_stride=160,
84
+ window="hann",
85
+ normalize="per_feature",
86
+ n_fft=512,
87
+ preemph=0.97,
88
+ nfilt=80,
89
+ lowfreq=0,
90
+ highfreq=None,
91
+ log=True,
92
+ log_zero_guard_type="add",
93
+ log_zero_guard_value=2 ** -24,
94
+ dither=CONSTANT,
95
+ pad_to=16,
96
+ max_duration=16.7,
97
+ frame_splicing=1,
98
+ exact_pad=False,
99
+ pad_value=0,
100
+ mag_power=2.0,
101
+ use_grads=False,
102
+ rng=None,
103
+ nb_augmentation_prob=0.0,
104
+ nb_max_freq=4000,
105
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
106
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
107
+ ):
108
+ super().__init__()
109
+ if stft_conv or stft_exact_pad:
110
+ logging.warning(
111
+ "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
112
+ "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
113
+ "as needed."
114
+ )
115
+ if exact_pad and n_window_stride % 2 == 1:
116
+ raise NotImplementedError(
117
+ f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
118
+ "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
119
+ )
120
+ self.log_zero_guard_value = log_zero_guard_value
121
+ if (
122
+ n_window_size is None
123
+ or n_window_stride is None
124
+ or not isinstance(n_window_size, int)
125
+ or not isinstance(n_window_stride, int)
126
+ or n_window_size <= 0
127
+ or n_window_stride <= 0
128
+ ):
129
+ raise ValueError(
130
+ f"{self} got an invalid value for either n_window_size or "
131
+ f"n_window_stride. Both must be positive ints."
132
+ )
133
+ logging.info(f"PADDING: {pad_to}")
134
+
135
+ self.win_length = n_window_size
136
+ self.hop_length = n_window_stride
137
+ self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
138
+ self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
139
+
140
+ if exact_pad:
141
+ logging.info("STFT using exact pad")
142
+ torch_windows = {
143
+ 'hann': torch.hann_window,
144
+ 'hamming': torch.hamming_window,
145
+ 'blackman': torch.blackman_window,
146
+ 'bartlett': torch.bartlett_window,
147
+ 'none': None,
148
+ }
149
+ window_fn = torch_windows.get(window, None)
150
+ window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
151
+ self.register_buffer("window", window_tensor)
152
+ self.stft = lambda x: torch.stft(
153
+ x,
154
+ n_fft=self.n_fft,
155
+ hop_length=self.hop_length,
156
+ win_length=self.win_length,
157
+ center=False if exact_pad else True,
158
+ window=self.window.to(dtype=torch.float),
159
+ return_complex=True,
160
+ )
161
+
162
+ self.normalize = normalize
163
+ self.log = log
164
+ self.dither = dither
165
+ self.frame_splicing = frame_splicing
166
+ self.nfilt = nfilt
167
+ self.preemph = preemph
168
+ self.pad_to = pad_to
169
+ highfreq = highfreq or sample_rate / 2
170
+
171
+ filterbanks = torch.tensor(
172
+ librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
173
+ dtype=torch.float,
174
+ ).unsqueeze(0)
175
+ self.register_buffer("fb", filterbanks)
176
+
177
+ # Calculate maximum sequence length
178
+ max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
179
+ max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
180
+ self.max_length = max_length + max_pad
181
+ self.pad_value = pad_value
182
+ self.mag_power = mag_power
183
+
184
+ # We want to avoid taking the log of zero
185
+ # There are two options: either adding or clamping to a small value
186
+ if log_zero_guard_type not in ["add", "clamp"]:
187
+ raise ValueError(
188
+ f"{self} received {log_zero_guard_type} for the "
189
+ f"log_zero_guard_type parameter. It must be either 'add' or "
190
+ f"'clamp'."
191
+ )
192
+
193
+ self.use_grads = use_grads
194
+ if not use_grads:
195
+ self.forward = torch.no_grad()(self.forward)
196
+ self._rng = random.Random() if rng is None else rng
197
+ self.nb_augmentation_prob = nb_augmentation_prob
198
+ if self.nb_augmentation_prob > 0.0:
199
+ if nb_max_freq >= sample_rate / 2:
200
+ self.nb_augmentation_prob = 0.0
201
+ else:
202
+ self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
203
+
204
+ # log_zero_guard_value is the the small we want to use, we support
205
+ # an actual number, or "tiny", or "eps"
206
+ self.log_zero_guard_type = log_zero_guard_type
207
+ logging.debug(f"sr: {sample_rate}")
208
+ logging.debug(f"n_fft: {self.n_fft}")
209
+ logging.debug(f"win_length: {self.win_length}")
210
+ logging.debug(f"hop_length: {self.hop_length}")
211
+ logging.debug(f"n_mels: {nfilt}")
212
+ logging.debug(f"fmin: {lowfreq}")
213
+ logging.debug(f"fmax: {highfreq}")
214
+ logging.debug(f"using grads: {use_grads}")
215
+ logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")
216
+
217
+ def log_zero_guard_value_fn(self, x):
218
+ if isinstance(self.log_zero_guard_value, str):
219
+ if self.log_zero_guard_value == "tiny":
220
+ return torch.finfo(x.dtype).tiny
221
+ elif self.log_zero_guard_value == "eps":
222
+ return torch.finfo(x.dtype).eps
223
+ else:
224
+ raise ValueError(
225
+ f"{self} received {self.log_zero_guard_value} for the "
226
+ f"log_zero_guard_type parameter. It must be either a "
227
+ f"number, 'tiny', or 'eps'"
228
+ )
229
+ else:
230
+ return self.log_zero_guard_value
231
+
232
+ def get_seq_len(self, seq_len):
233
+ # Assuming that center is True is stft_pad_amount = 0
234
+ pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
235
+ seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
236
+ return seq_len.to(dtype=torch.long)
237
+
238
+ @property
239
+ def filter_banks(self):
240
+ return self.fb
241
+
242
+ def forward(self, x, seq_len, linear_spec=False):
243
+ seq_len = self.get_seq_len(seq_len.float())
244
+
245
+ if self.stft_pad_amount is not None:
246
+ x = torch.nn.functional.pad(
247
+ x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
248
+ ).squeeze(1)
249
+
250
+ # dither (only in training mode for eval determinism)
251
+ if self.training and self.dither > 0:
252
+ x += self.dither * torch.randn_like(x)
253
+
254
+ # do preemphasis
255
+ if self.preemph is not None:
256
+ x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
257
+
258
+ # disable autocast to get full range of stft values
259
+ with torch.cuda.amp.autocast(enabled=False):
260
+ x = self.stft(x)
261
+
262
+ # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
263
+ # guard is needed for sqrt if grads are passed through
264
+ guard = 0 if not self.use_grads else CONSTANT
265
+ x = torch.view_as_real(x)
266
+ x = torch.sqrt(x.pow(2).sum(-1) + guard)
267
+
268
+ if self.training and self.nb_augmentation_prob > 0.0:
269
+ for idx in range(x.shape[0]):
270
+ if self._rng.random() < self.nb_augmentation_prob:
271
+ x[idx, self._nb_max_fft_bin :, :] = 0.0
272
+
273
+ # get power spectrum
274
+ if self.mag_power != 1.0:
275
+ x = x.pow(self.mag_power)
276
+
277
+ # return plain spectrogram if required
278
+ if linear_spec:
279
+ return x, seq_len
280
+
281
+ # dot with filterbank energies
282
+ x = torch.matmul(self.fb.to(x.dtype), x)
283
+ # log features if required
284
+ if self.log:
285
+ if self.log_zero_guard_type == "add":
286
+ x = torch.log(x + self.log_zero_guard_value_fn(x))
287
+ elif self.log_zero_guard_type == "clamp":
288
+ x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
289
+ else:
290
+ raise ValueError("log_zero_guard_type was not understood")
291
+
292
+ # frame splicing if required
293
+ if self.frame_splicing > 1:
294
+ x = splice_frames(x, self.frame_splicing)
295
+
296
+ # normalize if required
297
+ if self.normalize:
298
+ x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
299
+
300
+ # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
301
+ max_len = x.size(-1)
302
+ mask = torch.arange(max_len).to(x.device)
303
+ mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
304
+ x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
305
+ del mask
306
+ pad_to = self.pad_to
307
+ if pad_to == "max":
308
+ x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
309
+ elif pad_to > 0:
310
+ pad_amt = x.size(-1) % pad_to
311
+ if pad_amt != 0:
312
+ x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
313
+ return x, seq_len
xvector_sincnet.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import warnings
7
+ from asteroid_filterbanks import Encoder, ParamSincFB
8
+
9
+ def merge_dict(defaults: dict, custom: dict = None):
10
+ params = dict(defaults)
11
+ if custom is not None:
12
+ params.update(custom)
13
+ return params
14
+
15
+ class StatsPool(nn.Module):
16
+ """Statistics pooling
17
+ Compute temporal mean and (unbiased) standard deviation
18
+ and returns their concatenation.
19
+ Reference
20
+ ---------
21
+ https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
22
+ """
23
+
24
+ def forward(
25
+ self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
26
+ ) -> torch.Tensor:
27
+ """Forward pass
28
+ Parameters
29
+ ----------
30
+ sequences : (batch, channel, frames) torch.Tensor
31
+ Sequences.
32
+ weights : (batch, frames) torch.Tensor, optional
33
+ When provided, compute weighted mean and standard deviation.
34
+ Returns
35
+ -------
36
+ output : (batch, 2 * channel) torch.Tensor
37
+ Concatenation of mean and (unbiased) standard deviation.
38
+ """
39
+
40
+ if weights is None:
41
+ mean = sequences.mean(dim=2)
42
+ std = sequences.std(dim=2, unbiased=True)
43
+
44
+ else:
45
+ weights = weights.unsqueeze(dim=1)
46
+ # (batch, 1, frames)
47
+
48
+ num_frames = sequences.shape[2]
49
+ num_weights = weights.shape[2]
50
+ if num_frames != num_weights:
51
+ warnings.warn(
52
+ f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
53
+ )
54
+ weights = F.interpolate(
55
+ weights, size=num_frames, mode="linear", align_corners=False
56
+ )
57
+
58
+ v1 = weights.sum(dim=2)
59
+ mean = torch.sum(sequences * weights, dim=2) / v1
60
+
61
+ dx2 = torch.square(sequences - mean.unsqueeze(2))
62
+ v2 = torch.square(weights).sum(dim=2)
63
+
64
+ var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1)
65
+ std = torch.sqrt(var)
66
+
67
+ return torch.cat([mean, std], dim=1)
68
+
69
+ class SincNet(nn.Module):
70
+ def __init__(self, sample_rate: int = 16000, stride: int = 1):
71
+ super().__init__()
72
+
73
+ if sample_rate != 16000:
74
+ raise NotImplementedError("PyanNet only supports 16kHz audio for now.")
75
+ # TODO: add support for other sample rate. it should be enough to multiply
76
+ # kernel_size by (sample_rate / 16000). but this needs to be double-checked.
77
+
78
+ self.stride = stride
79
+
80
+ self.wav_norm1d = nn.InstanceNorm1d(1, affine=True)
81
+
82
+ self.conv1d = nn.ModuleList()
83
+ self.pool1d = nn.ModuleList()
84
+ self.norm1d = nn.ModuleList()
85
+
86
+ self.conv1d.append(
87
+ Encoder(
88
+ ParamSincFB(
89
+ 80,
90
+ 251,
91
+ stride=self.stride,
92
+ sample_rate=sample_rate,
93
+ min_low_hz=50,
94
+ min_band_hz=50,
95
+ )
96
+ )
97
+ )
98
+ self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
99
+ self.norm1d.append(nn.InstanceNorm1d(80, affine=True))
100
+
101
+ self.conv1d.append(nn.Conv1d(80, 60, 5, stride=1))
102
+ self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
103
+ self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
104
+
105
+ self.conv1d.append(nn.Conv1d(60, 60, 5, stride=1))
106
+ self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
107
+ self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
108
+
109
+ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
110
+ """Pass forward
111
+ Parameters
112
+ ----------
113
+ waveforms : (batch, channel, sample)
114
+ """
115
+
116
+ outputs = self.wav_norm1d(waveforms)
117
+
118
+ for c, (conv1d, pool1d, norm1d) in enumerate(
119
+ zip(self.conv1d, self.pool1d, self.norm1d)
120
+ ):
121
+
122
+ outputs = conv1d(outputs)
123
+
124
+ # https://github.com/mravanelli/SincNet/issues/4
125
+ if c == 0:
126
+ outputs = torch.abs(outputs)
127
+
128
+ outputs = F.leaky_relu(norm1d(pool1d(outputs)))
129
+
130
+ return outputs
131
+
132
+ class XVectorSincNet(nn.Module):
133
+
134
+ SINCNET_DEFAULTS = {"stride": 10}
135
+
136
+ def __init__(
137
+ self,
138
+ sample_rate: int = 16000,
139
+ # num_channels: int = 1,
140
+ sincnet: dict = dict(
141
+ stride=10,
142
+ sample_rate=16000
143
+ ),
144
+ dimension: int = 512,
145
+ # task: Optional[Task] = None,
146
+ ):
147
+ super(XVectorSincNet, self).__init__()
148
+
149
+ sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
150
+ sincnet["sample_rate"] = sample_rate
151
+
152
+ # self.save_hyperparameters("sincnet", "dimension")
153
+
154
+ self.sincnet = SincNet(**sincnet)
155
+ in_channel = 60
156
+
157
+ self.tdnns = nn.ModuleList()
158
+ out_channels = [512, 512, 512, 512, 1500]
159
+ kernel_sizes = [5, 3, 3, 1, 1]
160
+ dilations = [1, 2, 3, 1, 1]
161
+
162
+ for out_channel, kernel_size, dilation in zip(
163
+ out_channels, kernel_sizes, dilations
164
+ ):
165
+ self.tdnns.extend(
166
+ [
167
+ nn.Conv1d(
168
+ in_channels=in_channel,
169
+ out_channels=out_channel,
170
+ kernel_size=kernel_size,
171
+ dilation=dilation,
172
+ ),
173
+ nn.LeakyReLU(),
174
+ nn.BatchNorm1d(out_channel),
175
+ ]
176
+ )
177
+ in_channel = out_channel
178
+
179
+ self.stats_pool = StatsPool()
180
+
181
+ self.embedding = nn.Linear(in_channel * 2, dimension)
182
+
183
+ def forward(
184
+ self, waveforms: torch.Tensor, weights: torch.Tensor = None
185
+ ) -> torch.Tensor:
186
+ """
187
+ Parameters
188
+ ----------
189
+ waveforms : torch.Tensor
190
+ Batch of waveforms with shape (batch, channel, sample)
191
+ weights : torch.Tensor, optional
192
+ Batch of weights with shape (batch, frame).
193
+ """
194
+
195
+ outputs = self.sincnet(waveforms).squeeze(dim=1)
196
+ for tdnn in self.tdnns:
197
+ outputs = tdnn(outputs)
198
+ outputs = self.stats_pool(outputs, weights=weights)
199
+ return self.embedding(outputs)
200
+
201
+
202
+ """ Load model
203
+
204
+ def cal_xvector_sincnet_embedding(xvector_model, ref_wav, max_length=5, sr=16000):
205
+ wavs = []
206
+ for i in range(0, len(ref_wav), max_length*sr):
207
+ wav = ref_wav[i:i + max_length*sr]
208
+ wav = np.concatenate([wav, np.zeros(max(0, max_length * sr - len(wav)))])
209
+ wavs.append(wav)
210
+ wavs = torch.from_numpy(np.stack(wavs))
211
+ if use_gpu:
212
+ wavs = wavs.cuda()
213
+ embed = xvector_model(wavs.unsqueeze(1).float())
214
+ return torch.mean(embed, dim=0).detach().cpu()
215
+
216
+ xvector_model = XVectorSincNet()
217
+ model_file = "model-bin/speaker_embedding/xvector_sincnet.pt"
218
+ meta = torch.load(model_file, map_location='cpu')['state_dict']
219
+ print('load_xvector_sincnet_model', xvector_model.load_state_dict(meta, strict=False))
220
+ xvector_model = xvector_model.eval()
221
+ for param in xvector_model.parameters():
222
+ param.requires_grad = False
223
+ """