|
import math |
|
import torch |
|
import typing as tp |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers.utils import ModelOutput |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
from .helpers_xvector import Fbank |
|
from .configuration_xvector import XvectorConfig |
|
|
|
|
|
class InputNormalization(nn.Module):
    """Mean/variance normalization of a batch of feature sequences.

    Statistics are computed over the valid (non-padded) time steps of each
    sentence, i.e. along dim 0 of ``x[snt_id, :actual_size]``, and applied
    as ``(x - mean) / std``. ``norm_type`` selects where they come from:

    * ``"sentence"``: each sentence's own statistics;
    * ``"batch"``: the average of the per-sentence statistics of the batch;
    * ``"speaker"``: running per-speaker averages, updated in training mode;
    * ``"global"``: running averages over training batches. The first
      training batch always initializes them; later batches update them only
      while ``epoch < update_until_epoch``.

    Any other ``norm_type`` returns the input unchanged.

    Arguments
    ---------
    mean_norm : bool
        Subtract the mean (otherwise the mean is taken as 0).
    std_norm : bool
        Divide by the standard deviation (otherwise the std is taken as 1).
    norm_type : str
        One of "sentence", "batch", "speaker", "global".
    avg_factor : float, optional
        Fixed weight of the newest statistics in the running averages. When
        None, a cumulative average (weight ``1 / count``) is used.
    requires_grad : bool
        Stored on the instance; not read by this class.
    update_until_epoch : int
        Epoch from which the global statistics stop being updated.
    """

    spk_dict_mean: tp.Dict[int, torch.Tensor]

    spk_dict_std: tp.Dict[int, torch.Tensor]

    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        # Identity statistics (float): until the first training update,
        # "global" normalization leaves the input unchanged instead of
        # dividing by a zero std.
        self.glob_mean = torch.tensor([0.0])
        self.glob_std = torch.tensor([1.0])
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        # Weight of the newest statistics in the last running-average update.
        self.weight = 1.0
        # Number of training batches folded into the global statistics.
        self.count = 0
        # Lower bound on the std, avoids division by zero.
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Normalize a batch of feature sequences.

        Arguments
        ---------
        input_values : torch.Tensor
            Batch of sequences; dim 0 indexes sentences, dim 1 time steps.
            For "sentence" and "speaker" normalization the tensor is
            modified in place and returned.
        lengths : torch.Tensor or sequence of float, optional
            Relative length of each sentence (e.g. [0.7, 0.9, 1.0]), used to
            exclude zero-padded steps from the statistics. When None, every
            sentence is treated as full length.
        spk_ids : torch.Tensor
            Speaker id of each sentence, either 1-D (e.g. [0, 10, 6]) or
            2-D with the id in the first column. Only used when
            norm_type == "speaker".
        epoch : int
            Current epoch, compared with ``update_until_epoch`` for
            "global" normalization.

        Returns
        -------
        torch.Tensor
            The normalized features.
        """
        x = input_values
        n_sentences = x.shape[0]
        n_steps = x.shape[1]
        if lengths is None:
            lengths = torch.ones(n_sentences)
        else:
            lengths = torch.as_tensor(lengths)

        current_means = []
        current_stds = []
        for snt_id in range(n_sentences):
            # Only the first round(length * n_steps) frames hold real data.
            actual_size = int(torch.round(lengths[snt_id] * n_steps))
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )
            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean) / current_std

            if self.norm_type == "speaker":
                # reshape(-1) accepts both 1-D ids and (batch, k) ids.
                spk_id = int(spk_ids[snt_id].reshape(-1)[0])
                speaker_mean, speaker_std = self._speaker_stats(
                    spk_id, current_mean, current_std
                )
                x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type in ("batch", "global"):
            current_mean = torch.mean(torch.stack(current_means), dim=0)
            current_std = torch.mean(torch.stack(current_stds), dim=0)
            if self.norm_type == "batch":
                x = (x - current_mean) / current_std
            else:
                if self.training:
                    self._update_global_stats(current_mean, current_std, epoch)
                x = (x - self.glob_mean.data) / self.glob_std.data

        return x

    def _speaker_stats(self, spk_id, current_mean, current_std):
        """Return the (mean, std) to use for ``spk_id``.

        In training mode the running statistics of the speaker are created
        or updated with the current sentence statistics first. In eval mode
        stored statistics are used when available, otherwise the current
        sentence statistics.
        """
        if not self.training:
            if spk_id in self.spk_dict_mean:
                return self.spk_dict_mean[spk_id], self.spk_dict_std[spk_id]
            return current_mean, current_std

        if spk_id not in self.spk_dict_mean:
            self.spk_dict_mean[spk_id] = current_mean
            self.spk_dict_std[spk_id] = current_std
            self.spk_dict_count[spk_id] = 1
        else:
            self.spk_dict_count[spk_id] += 1
            if self.avg_factor is None:
                self.weight = 1 / self.spk_dict_count[spk_id]
            else:
                self.weight = self.avg_factor
            self.spk_dict_mean[spk_id] = (
                (1 - self.weight) * self.spk_dict_mean[spk_id]
                + self.weight * current_mean
            ).detach()
            self.spk_dict_std[spk_id] = (
                (1 - self.weight) * self.spk_dict_std[spk_id]
                + self.weight * current_std
            ).detach()
        return self.spk_dict_mean[spk_id], self.spk_dict_std[spk_id]

    def _update_global_stats(self, current_mean, current_std, epoch):
        """Fold the batch statistics into the global running statistics."""
        if self.count == 0:
            self.glob_mean = current_mean
            self.glob_std = current_std
        elif epoch < self.update_until_epoch:
            if self.avg_factor is None:
                self.weight = 1 / (self.count + 1)
            else:
                self.weight = self.avg_factor
            self.glob_mean = (
                (1 - self.weight) * self.glob_mean + self.weight * current_mean
            ).detach()
            self.glob_std = (
                (1 - self.weight) * self.glob_std + self.weight * current_std
            ).detach()
        self.count = self.count + 1

    def _compute_current_stats(self, x):
        """Return the detached (mean, std) of ``x`` along dim 0.

        Disabled statistics are replaced by identity values (0 for the mean,
        1 for the std). The std is floored at ``self.eps``.

        Arguments
        ---------
        x : torch.Tensor
            The valid time steps of one sentence.
        """
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach()
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        if self.std_norm:
            current_std = torch.std(x, dim=0).detach()
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Avoid division by zero for constant inputs.
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )
        return current_mean, current_std

    def _statistics_dict(self):
        """Return a dictionary with the normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count
        return state

    def _load_statistics_dict(self, state):
        """Load the normalization statistics from ``state``.

        Speaker statistics are moved to ``self.device_inp`` when a caller
        has set that attribute; otherwise to the device of the current
        global mean.

        Arguments
        ---------
        state : dict
            A dictionary as produced by ``_statistics_dict``.

        Returns
        -------
        dict
            The same ``state`` object.
        """
        device = getattr(self, "device_inp", None)
        if device is None and isinstance(self.glob_mean, torch.Tensor):
            device = self.glob_mean.device

        def _move(tensor):
            return tensor if device is None else tensor.to(device)

        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]
        self.spk_dict_mean = {
            spk: _move(mean) for spk, mean in state["spk_dict_mean"].items()
        }
        self.spk_dict_std = {
            spk: _move(std) for spk, std in state["spk_dict_std"].items()
        }
        # Copy so later updates do not mutate the caller's dictionary.
        self.spk_dict_count = dict(state["spk_dict_count"])
        return state

    def to(self, *args, **kwargs):
        """Move/cast the module and its statistics tensors.

        The statistics are plain attributes, not parameters or buffers, so
        ``nn.Module.to`` does not handle them and they are converted here
        with the same arguments. Note: this override runs only when ``to``
        is called on this module directly.
        """
        module = super().to(*args, **kwargs)

        def _convert(value):
            if isinstance(value, torch.Tensor):
                return value.to(*args, **kwargs)
            return value

        module.glob_mean = _convert(module.glob_mean)
        module.glob_std = _convert(module.glob_std)
        module.spk_dict_mean = {
            spk: _convert(mean) for spk, mean in module.spk_dict_mean.items()
        }
        module.spk_dict_std = {
            spk: _convert(std) for spk, std in module.spk_dict_std.items()
        }
        return module
|
|
|
|
|
class TdnnLayer(nn.Module):
    """Time-delay layer: pad -> Conv1d -> activation -> BatchNorm1d.

    Arguments
    ---------
    in_channels : int
        Number of input channels of the convolution.
    out_channels : int
        Number of output channels of the convolution.
    kernel_size : int
        Convolution kernel size.
    dilation : int
        Convolution dilation.
    stride : int
        Convolution stride; also used to choose the automatic padding.
    padding : int
        Extra padding applied by ``nn.Conv1d`` itself, on top of the
        automatic padding done in ``_manage_padding``.
    padding_mode : str
        ``mode`` passed to ``F.pad`` for the automatic padding.
    activation : type
        Activation module class, instantiated once in ``__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.activation = activation

        # ``stride`` is forwarded so the convolution matches the padding
        # computed for it in ``_manage_padding``.
        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
        )
        # Build the activation once rather than on every forward call.
        self.activation_fn = activation()
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        """Apply the layer to ``x`` laid out as nn.Conv1d expects:
        (batch, in_channels, time)."""
        x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        out = self.conv(x)
        out = self.activation_fn(out)
        return self.norm(out)

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        """Pad the last (time) axis of ``x`` with ``self.padding_mode``."""
        # The padding is derived from the sequence length on the last axis.
        L_in = x.shape[-1]
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        return F.pad(x, padding, mode=self.padding_mode)
|
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Return the ``[left, right]`` padding to apply before a 1-D convolution.

    For ``stride > 1`` each side gets ``floor(kernel_size / 2)``. Otherwise
    each side gets half of the length a valid (unpadded) convolution would
    lose, ``floor(dilation * (kernel_size - 1) / 2)``. That preserves the
    sequence length whenever ``dilation * (kernel_size - 1)`` is even.

    Arguments
    ---------
    L_in : int
        Input sequence length.
    stride : int
        Convolution stride.
    kernel_size : int
        Convolution kernel size.
    dilation : int
        Convolution dilation.
    """
    if stride > 1:
        half_kernel = math.floor(kernel_size / 2)
        return [half_kernel, half_kernel]

    span = dilation * (kernel_size - 1)
    valid_length = math.floor((L_in - span - 1) / stride) + 1
    side = math.floor((L_in - valid_length) / 2)
    return [side, side]
|
|
|
|
|
class StatisticsPooling(nn.Module):
    """Pool a batch of sequences into their mean and/or std over dim 1.

    A small random offset in ``[eps, 9 * eps]`` is added to the mean and
    ``eps`` is added to the std. The noise is added in both training and
    eval mode.

    Arguments
    ---------
    return_mean : bool
        Include the mean in the pooled statistics.
    return_std : bool
        Include the standard deviation in the pooled statistics.

    Raises
    ------
    ValueError
        If both ``return_mean`` and ``return_std`` are False.
    """

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()
        # Added to the std and used to scale the noise added to the mean.
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, lengths=None):
        """Calculate mean and/or std of each sequence in the batch.

        Arguments
        ---------
        input_values : torch.Tensor
            Batch of sequences; statistics are taken over dim 1.
        lengths : torch.Tensor or sequence of float, optional
            Relative valid length of each sequence. When given, only the
            first ``round(length * input_values.shape[1])`` steps are used.

        Returns
        -------
        torch.Tensor
            Shape (batch, 1, D): mean and std concatenated along the last
            axis (or only the enabled one).
        """
        x = input_values
        mean = std = None
        if lengths is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            lengths = torch.as_tensor(lengths)
            means, stds = [], []
            for snt_id in range(x.shape[0]):
                actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
                valid = x[snt_id, 0:actual_size, ...]
                if self.return_mean:
                    means.append(torch.mean(valid, dim=0))
                if self.return_std:
                    stds.append(torch.std(valid, dim=0))
            if self.return_mean:
                mean = torch.stack(means)
            if self.return_std:
                std = torch.stack(stds)

        if self.return_mean:
            # In-place add keeps the dtype of ``mean``.
            mean += self._get_gauss_noise(mean.size(), device=mean.device)
        if self.return_std:
            std = std + self.eps

        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
        elif self.return_mean:
            pooled_stats = mean
        else:
            pooled_stats = std
        return pooled_stats.unsqueeze(1)

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Return Gaussian noise rescaled into ``[eps, 9 * eps]``.

        Arguments
        ---------
        shape_of_tensor : torch.Size or tuple
            Shape of the noise tensor to generate.
        device : str or torch.device
            Device of the generated noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        # Guard against 0 / 0 when all samples are equal (e.g. a single
        # element); otherwise this divides by the true maximum.
        gnoise /= torch.max(gnoise).clamp_min(torch.finfo(gnoise.dtype).tiny)
        # Map [0, 1] onto [9 * eps, eps].
        gnoise = self.eps * ((1 - 9) * gnoise + 9)
        return gnoise
|
|
|
|
|
class XvectorEmbedder(nn.Module):
    """Stack of TDNN layers, statistics pooling and a linear projection.

    Arguments
    ---------
    in_channels : int
        Number of input features per time step.
    activation : type
        Activation class passed to every ``TdnnLayer``.
    tdnn_blocks : int
        Number of TDNN layers; the per-layer sequences below are indexed up
        to this count.
    tdnn_channels, tdnn_kernel_sizes, tdnn_dilations : sequence of int
        Output channels, kernel size and dilation of each TDNN layer.
    hidden_size : int
        Size of the output embedding.
    """

    def __init__(
        self,
        in_channels=40,
        activation=torch.nn.LeakyReLU,
        tdnn_blocks=5,
        tdnn_channels=(512, 512, 512, 512, 1500),
        tdnn_kernel_sizes=(5, 3, 3, 1, 1),
        tdnn_dilations=(1, 2, 3, 1, 1),
        hidden_size=512,
    ) -> None:
        super().__init__()
        self.activation = activation
        self.blocks = nn.ModuleList()
        for block_index in range(tdnn_blocks):
            out_channels = tdnn_channels[block_index]
            self.blocks.append(
                TdnnLayer(
                    in_channels,
                    out_channels,
                    kernel_size=tdnn_kernel_sizes[block_index],
                    dilation=tdnn_dilations[block_index],
                    activation=activation,
                )
            )
            in_channels = out_channels
        self.pooler = StatisticsPooling()
        # The pooler concatenates mean and std of the last layer's channels.
        # ``in_channels`` now holds that width (the input width when
        # tdnn_blocks == 0, where ``out_channels`` would be unbound).
        self.fc = nn.Linear(2 * in_channels, hidden_size)

    def forward(self, input_values, lengths=None):
        """Embed a batch of feature sequences.

        Arguments
        ---------
        input_values : torch.Tensor
            Features laid out (batch, time, features); they are permuted to
            the (batch, channels, time) layout nn.Conv1d expects.
        lengths : torch.Tensor, optional
            Relative valid length of each sequence, forwarded to the pooler.

        Returns
        -------
        ModelOutput
            ``last_hidden_state`` of shape (batch, time, channels) and
            ``pooler_output`` of shape (batch, hidden_size).
        """
        x = input_values.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        last_hidden_state = x.permute(0, 2, 1)
        pooled = self.pooler(last_hidden_state, lengths)
        pooler_output = self.fc(pooled.squeeze(1))
        return ModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )
|
|
|
|
|
class CosineSimilarityHead(torch.nn.Module):
    """
    Cosine-similarity classification head.

    ``lin_blocks`` optional (BatchNorm1d -> Linear) blocks project the input.
    The output is the cosine similarity between each projected input row and
    each row of a learnable ``(num_classes, features)`` weight matrix.

    Arguments
    ---------
    in_channels : int
        Input feature size.
    lin_blocks : int
        Number of (BatchNorm1d, Linear) projection blocks.
    hidden_size : int
        Output size of each projection block.
    num_classes : int
        Number of class prototype vectors.
    """

    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(lin_blocks):
            self.blocks.extend(
                [
                    nn.BatchNorm1d(num_features=in_channels),
                    nn.Linear(in_features=in_channels, out_features=hidden_size),
                ]
            )
            in_channels = hidden_size

        # torch.empty replaces the deprecated legacy ``torch.FloatTensor``
        # constructor; the values are overwritten by the Xavier init.
        self.weight = nn.Parameter(
            torch.empty(num_classes, in_channels, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Return the cosine similarity of ``x`` to every class prototype.

        Arguments
        ---------
        x : torch.Tensor
            Input features; presumably 2-D (batch, features), since
            ``F.normalize`` normalizes along dim 1.

        Returns
        -------
        torch.Tensor
            Similarities of shape (batch, num_classes), values in [-1, 1].
        """
        for layer in self.blocks:
            x = layer(x)
        return F.linear(F.normalize(x), F.normalize(self.weight))
|
|
|
|
|
class XvectorPreTrainedModel(PreTrainedModel):
    """Base class hooking the x-vector models into the Hugging Face
    ``PreTrainedModel`` machinery (config class, weight initialization)."""

    config_class = XvectorConfig
    base_model_prefix = "xvector"
    main_input_name = "input_values"
    # NOTE(review): no module in this file defines a ``gradient_checkpointing``
    # attribute or checkpointed forward; confirm this flag is accurate.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        * LayerNorm / GroupNorm: zero bias, unit weight.
        * Linear: weight ~ normal(0, config.initializer_range).
        * Conv1d: Kaiming-normal weight.
        * Linear / Conv1d biases, when present, are zeroed.

        Other module types are left untouched.
        """
        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)
        else:
            return

        if module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
class XvectorModel(XvectorPreTrainedModel):
    """X-vector speaker-embedding model.

    Pipeline: filterbank features (``Fbank``) -> input normalization ->
    TDNN embedder with statistics pooling.

    NOTE(review): Hugging Face models usually end ``__init__`` with
    ``self.post_init()`` so that ``_init_weights`` is applied; it is not
    called here. Confirm whether that is intentional.
    """

    def __init__(self, config):
        super().__init__(config)
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        self.embedding_model = XvectorEmbedder(
            in_channels=config.n_mels,
            activation=nn.LeakyReLU,
            tdnn_blocks=config.tdnn_blocks,
            tdnn_channels=config.tdnn_channels,
            tdnn_kernel_sizes=config.tdnn_kernel_sizes,
            tdnn_dilations=config.tdnn_dilations,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        """Compute speaker embeddings for a batch of inputs.

        Arguments
        ---------
        input_values : torch.Tensor
            Raw model input, converted to features by ``compute_features``.
        lengths : torch.Tensor, optional
            Relative valid length of each item (1.0 = unpadded). When None,
            every item is treated as full length.

        Returns
        -------
        ModelOutput
            The embedder output (``last_hidden_state``, ``pooler_output``).
        """
        x = self.compute_features(input_values)
        # InputNormalization indexes ``lengths`` per sentence (dim 0), so it
        # needs explicit full lengths instead of None.
        norm_lengths = lengths
        if norm_lengths is None:
            norm_lengths = torch.ones(x.shape[0], device=x.device)
        x = self.mean_var_norm(x, norm_lengths)
        # The pooler has its own vectorized path for lengths=None.
        return self.embedding_model(x, lengths)