# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class AngularMargin(nn.Module):
    """
    An implementation of Angular Margin (AM) proposed in the following
    paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
    Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)

    Arguments
    ---------
    margin : float
        The margin for cosine similarity.
    scale : float
        The scale for cosine similarity.

    Return
    ---------
    predictions : torch.Tensor

    Example
    -------
    >>> pred = AngularMargin()
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False, True, False])
    """

    def __init__(self, margin=0.0, scale=1.0):
        super(AngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, outputs, targets):
        """Compute AM between two tensors.

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C]; cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C]; the margin is applied to the target classes.

        Return
        ---------
        predictions : torch.Tensor
        """
        # Subtract the margin from the cosine similarity of the target classes
        # only, then rescale all logits.
        outputs = outputs - self.margin * targets
        return self.scale * outputs
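

# A minimal sketch (not part of the original ArTST/SpeechT5 code) of how an
# AM-softmax head is typically assembled around AngularMargin: cosine logits
# from L2-normalized embeddings and class weights, the margin applied through
# one-hot targets, then cross-entropy. All tensor names and sizes below are
# hypothetical.
def _angular_margin_usage_sketch():
    embeddings = torch.randn(4, 16)                      # [N, D] speaker embeddings
    class_weight = torch.randn(8, 16)                    # [C, D] classifier weights
    labels = torch.randint(0, 8, (4,))
    cosine = F.linear(
        F.normalize(embeddings, p=2, dim=1),
        F.normalize(class_weight, p=2, dim=1),
    )                                                    # [N, C] cosine similarities
    one_hot = F.one_hot(labels, num_classes=8).float()
    logits = AngularMargin(margin=0.2, scale=30.0)(cosine, one_hot)
    return F.cross_entropy(logits, labels)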

class AdditiveAngularMargin(AngularMargin):
    """
    An implementation of Additive Angular Margin (AAM) proposed
    in the following paper: '''Margin Matters: Towards More Discriminative Deep
    Neural Network Embeddings for Speaker Recognition'''
    (https://arxiv.org/abs/1906.07317)

    Arguments
    ---------
    margin : float
        The margin for cosine similarity, usually 0.2.
    scale : float
        The scale for cosine similarity, usually 30.

    Returns
    -------
    predictions : torch.Tensor

    Example
    -------
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> pred = AdditiveAngularMargin()
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False, True, False])
    """

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        super(AdditiveAngularMargin, self).__init__(margin, scale)
        self.easy_margin = easy_margin
        # Precompute the constants needed to evaluate cos(theta + m).
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """
        Compute AAM between two tensors.

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C]; cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C]; the margin is applied to the target classes.

        Return
        ---------
        predictions : torch.Tensor
        """
        cosine = outputs.float()
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Use the penalized angle for the target classes and the plain cosine
        # for all other classes, then rescale.
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs
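

# A quick sanity check (not part of the original code) that the `phi` computed
# in AdditiveAngularMargin.forward equals cos(theta + m) via the angle-addition
# identity; the theta and margin values here are arbitrary.
def _aam_phi_sanity_check(theta=0.7, margin=0.2):
    cosine = math.cos(theta)
    sine = math.sqrt(max(0.0, 1.0 - cosine ** 2))        # equals sin(theta) for theta in [0, pi]
    phi = cosine * math.cos(margin) - sine * math.sin(margin)
    return phi, math.cos(theta + margin)                 # the two values should match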

class SpeakerDecoderPostnet(nn.Module):
    """Speaker Identification Postnet.

    Arguments
    ---------
    embed_dim : int
        The size of the speaker embedding.
    class_num : int
        The number of speaker classes.
    args : Namespace
        Model arguments (decoder_output_dim, sid_* options, softmax options).

    Return
    ---------
    output : torch.Tensor
    embed : torch.Tensor
    """

    def __init__(self, embed_dim, class_num, args):
        super(SpeakerDecoderPostnet, self).__init__()
        self.embed_dim = embed_dim
        self.class_num = class_num
        self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
        self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
        self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
        self.softmax_head = getattr(args, "sid_softmax_type", "softmax")
        if not self.no_pooling_bn:
            self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim)
        else:
            self.bn_pooling = None
        if not self.no_embed_postnet:
            self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False)
            self.bn_embedding = nn.BatchNorm1d(embed_dim)
        else:
            self.output_embedding = None
            self.bn_embedding = None
            self.embed_dim = args.decoder_output_dim
        self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)
        if self.softmax_head == "amsoftmax":
            self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale)
        elif self.softmax_head == "aamsoftmax":
            self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin)
        else:
            self.output_layer = None
        if self.output_embedding is not None:
            nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5)

    def forward(self, x, target=None):
        """
        Parameters
        ----------
        x : torch.Tensor of shape [batch, channel]
            Pooled decoder output.
        target : torch.Tensor of shape [batch, class_num], optional
            One-hot targets; only used by the margin heads during training.

        Returns
        -------
        output : torch.Tensor of shape [batch, class_num]
        embed : torch.Tensor of shape [batch, embed_dim]
        """
        if self.bn_pooling is not None:
            x = self.bn_pooling(x)
        if self.output_embedding is not None and self.bn_embedding is not None:
            embed = self.bn_embedding(self.output_embedding(x))
        else:
            embed = x
        if self.output_layer is not None or self.normalize_postnet:
            # Cosine-similarity logits: L2-normalize both the embeddings and
            # the projection weights before the linear projection.
            x_norm = F.normalize(embed, p=2, dim=1)
            w_norm = F.normalize(self.output_projection.weight, p=2, dim=1)  # [class_num, embed_dim]
            output = F.linear(x_norm, w_norm)
            if self.training and target is not None and self.output_layer is not None:
                output = self.output_layer(output, target)
        else:
            output = self.output_projection(embed)
        return output, embed
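

# A minimal usage sketch (not part of the original code) for SpeakerDecoderPostnet.
# The Namespace only sets the attributes this module actually reads; every value
# below (dimensions, class count, margin settings) is hypothetical.
def _speaker_postnet_usage_sketch():
    from argparse import Namespace

    args = Namespace(
        decoder_output_dim=768,
        sid_no_pooling_bn=False,
        sid_no_embed_postnet=False,
        sid_normalize_postnet=False,
        sid_softmax_type="aamsoftmax",
        softmax_margin=0.2,
        softmax_scale=30.0,
        softmax_easy_margin=False,
    )
    postnet = SpeakerDecoderPostnet(embed_dim=256, class_num=1000, args=args)
    x = torch.randn(4, 768)                               # pooled decoder outputs [batch, channel]
    target = F.one_hot(torch.randint(0, 1000, (4,)), num_classes=1000).float()
    output, embed = postnet(x, target)                    # [4, 1000], [4, 256]
    return output.shape, embed.shape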