# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import contextlib
import logging
import math
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn

from fairseq.data.data_utils import compute_mask_indices, lengths_to_padding_mask
from fairseq.modules import (
    PositionalEmbedding,
    Fp32GroupNorm,
    FairseqDropout,
    SamePad,
    GradMultiply,
    LayerNorm,
    Fp32LayerNorm,
    TransposeLast,
)

logger = logging.getLogger(__name__)

class LinearLayer(nn.Module):
    """A linear speech prenet: projection + LayerNorm + dropout + ReLU."""

    def __init__(self, idim, odim, dropout=0):
        super(LinearLayer, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim),
            nn.Dropout(dropout),
            nn.ReLU(),
        )

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        # A linear prenet does not downsample, so lengths pass through unchanged.
        out = in_seq_lens_tensor.clone()
        return out

    def forward(self, src_tokens, src_lengths):
        """
        src_tokens: [B, T, C]
        src_lengths: [B]
        """
        x = self.linear(src_tokens)
        x = x.transpose(0, 1).contiguous()  # -> T x B x C
        return x, src_lengths
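
# Usage sketch for LinearLayer (hypothetical dimensions, not taken from any
# ArTST configuration):
#
#     prenet = LinearLayer(idim=80, odim=768, dropout=0.1)
#     feats = torch.randn(4, 100, 80)            # [B, T, C] filterbank frames
#     lengths = torch.tensor([100, 96, 80, 72])  # [B]
#     out, out_lengths = prenet(feats, lengths)  # out: [100, 4, 768] (T x B x C)
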
class SpeechEncoderPrenet(nn.Module):
    """Speech encoder prenet: a convolutional feature extractor with optional
    convolutional/sinusoidal/learned positional embeddings and HuBERT-style
    span masking.

    Args:
        args: model configuration; the fields used here include
            encoder_speech_prenet, conv_feature_layers, extractor_mode,
            encoder_embed_dim, label_rates, sample_rate, and the mask_*
            hyperparameters.
    """

    def __init__(self, args):
        super(SpeechEncoderPrenet, self).__init__()
        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = 1
        self.freeze_encoder_updates = args.freeze_encoder_updates
        self.num_updates = 0
        assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet
        feature_enc_layers = eval(args.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = (
            args.label_rates * feature_ds_rate / args.sample_rate
        )
        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim
            else None
        )
        self.use_conv_pos = args.use_conv_pos
        self.use_sinc_pos = args.use_sinc_pos
        self.use_abs_pos = getattr(args, "use_abs_pos", False)
        self.feature_grad_mult = args.feature_grad_mult
        # layer_norm is applied unconditionally in _forward, so it must be
        # created regardless of which positional-embedding variant is used.
        self.layer_norm = LayerNorm(self.embed)
        if self.use_conv_pos:
            self.pos_conv = nn.Conv1d(
                args.encoder_embed_dim,
                args.encoder_embed_dim,
                kernel_size=args.conv_pos,
                padding=args.conv_pos // 2,
                groups=args.conv_pos_groups,
            )
            dropout = 0
            std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim))
            nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
            nn.init.constant_(self.pos_conv.bias, 0)
            self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
            self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
        assert not (self.use_sinc_pos and self.use_abs_pos), f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}"
        if self.use_sinc_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx
            )
        if self.use_abs_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True
            )
        # HuBERT-style masking hyperparameters
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.hubert_mask_length = args.hubert_mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space
        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )
    def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        # Run without gradients until the prenet has seen
        # freeze_encoder_updates parameter updates.
        ft = self.freeze_encoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask)
    def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        if self.feature_grad_mult > 0:
            x = self.feature_extractor(src_tokens)
            x = x.transpose(1, 2).transpose(0, 1)  # [length, batch, hidden_size]
            if self.feature_grad_mult != 1.0:
                x = GradMultiply.apply(x, self.feature_grad_mult)
        else:
            with torch.no_grad():
                x = self.feature_extractor(src_tokens)
                x = x.transpose(1, 2).transpose(0, 1)  # [length, batch, hidden_size]
        x = x.transpose(0, 1)  # [batch, length, hidden_size]
        encoder_padding_mask = padding_mask
        x = x.transpose(1, 2)  # [batch, hidden_size, length]
        if target_list is not None:
            x, target_list = self.forward_targets(x, target_list)
        features_pen = x.float().pow(2).mean()
        x = x.transpose(1, 2)  # [batch, length, hidden_size]
        x = self.layer_norm(x)
        encoder_padding_mask = self.forward_padding_mask(x, encoder_padding_mask)
        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)
        x = self.dropout_module(x)
        if mask:
            x, mask_indices = self.apply_hubert_mask(
                x, encoder_padding_mask
            )
        else:
            mask_indices = None
        if self.use_conv_pos:
            positions = self.pos_conv(x.transpose(1, 2))
            positions = positions.transpose(1, 2)
            x = x + positions
        if self.use_sinc_pos:
            positions = self.embed_positions(encoder_padding_mask)
            x = x + positions
        # x = self.dropout_module(x)
        if require_feat_pen:
            return (x, features_pen, mask_indices, target_list), encoder_padding_mask
        else:
            # For consistency with the encoder's return signature
            return x, encoder_padding_mask
    def forward_targets(
        self, features: torch.Tensor, target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Trim features to ensure labels exist, then select the aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list
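
    # Worked example of the alignment above (illustrative numbers, not an
    # ArTST setting): with label_rates=50, sample_rate=16000 and an overall
    # conv downsampling of 320x, feat2tar_ratio = 50 * 320 / 16000 = 1.0, so
    # feature frame i is paired with label index floor(i * 1.0) = i; a ratio
    # of 0.5 would instead pair two consecutive frames with the same label.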
    def forward_padding_mask(
        self, features: torch.Tensor, padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Downsample the sample-level padding mask to the frame rate of the
        # extracted features.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask
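
    # Example (hypothetical sizes): for a sample-level mask of width 3200 and
    # 10 extracted frames, the mask is truncated to a multiple of 10, viewed
    # as [B, 10, 320], and a frame counts as padding only if all 320 of its
    # underlying samples are padding.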
    def get_src_lengths(self, src_lengths):
        return self.feature_extractor.get_out_seq_lens_tensor(src_lengths)
    def apply_hubert_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.hubert_mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0
        return x, mask_indices
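
    # Shape sketch (hypothetical values): for x of shape [B=4, T=100, C=768],
    # mask_indices is a [4, 100] bool tensor marking the sampled time spans
    # whose frames were replaced by the learned mask_emb vector; channel
    # masking instead zeroes whole feature dimensions across all time steps.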
    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()
        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"
            # Note: `dim` below closes over the loop variable assigned in the
            # layer loop; block() is only ever called from inside that loop.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        self.conv_layers_infos = conv_layers
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl
            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim
    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x
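
    # Usage sketch: the layer spec below mirrors the common wav2vec 2.0 base
    # extractor (an assumption for illustration, not necessarily an ArTST
    # setting):
    #
    #     spec = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    #     fe = ConvFeatureExtractionModel(conv_layers=spec, mode="default")
    #     wav = torch.randn(2, 16000)  # [B, T] raw waveform, 1 s at 16 kHz
    #     feats = fe(wav)              # [B, 512, 49], ~320x downsampling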
    def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i):
        """Compute output lengths and the corresponding 0/1 non-padding mask
        after convolutional layer i.

        Args:
            in_seq_lens_tensor (LongTensor): input lengths, shape [B]
        Returns:
            (FloatTensor, LongTensor): 0/1 non-padding mask and output lengths
        """
        out_lengths = in_seq_lens_tensor.clone()
        out_lengths = ((out_lengths.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
        out_nonmask = (~lengths_to_padding_mask(out_lengths)).float()
        return out_nonmask, out_lengths
    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        out = in_seq_lens_tensor.clone()
        for i in range(len(self.conv_layers)):
            out = ((out.float() - (self.conv_layers_infos[i][1] - 1) - 1) / self.conv_layers_infos[i][-1] + 1).floor().long()
        return out
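
    # The recurrence above is the standard Conv1d output-length formula with
    # no padding and dilation 1: L_out = floor((L_in - kernel_size) / stride) + 1.
    # Worked example (hypothetical first layer with kernel 10, stride 5): an
    # input of 16000 samples yields floor((16000 - 10) / 5) + 1 = 3199 frames,
    # and 8000 samples yield floor((8000 - 10) / 5) + 1 = 1599.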