# artst-demo-asr/artst/models/modules/speech_encoder_prenet.py
# Uploaded by amupd in commit 8b33290 ("initial commit"); 13.7 kB.
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
import math
import torch
import contextlib
from typing import List, Tuple
import torch.nn as nn
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.data.data_utils import compute_mask_indices
from fairseq.modules import (
PositionalEmbedding,
Fp32GroupNorm,
FairseqDropout,
SamePad,
GradMultiply,
LayerNorm,
Fp32LayerNorm,
TransposeLast,
)
import numpy as np
# Module-level logger; not referenced by the code shown in this file.
logger = logging.getLogger(__name__)
class LinearLayer(nn.Module):
    """Position-wise Linear -> LayerNorm -> Dropout -> ReLU projection.

    Args:
        idim: size of the last dimension of the input.
        odom: size of the last dimension of the output.
        dropout: dropout probability applied before the ReLU.
    """

    def __init__(self, idim, odom, dropout=0):
        super().__init__()
        # The stage order fixes the state_dict keys (linear.0.*, linear.1.*),
        # so it must not change or saved checkpoints stop loading.
        stages = [
            nn.Linear(idim, odom),
            nn.LayerNorm(odom),
            nn.Dropout(dropout),
            nn.ReLU(),
        ]
        self.linear = nn.Sequential(*stages)

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Return a copy of the lengths; this layer never changes sequence length."""
        return in_seq_lens_tensor.clone()

    def forward(self, src_tokens, src_lengths):
        """Project the features and switch to a time-major layout.

        Args:
            src_tokens: [B, T, C] input features.
            src_lengths: [B] sequence lengths, passed through untouched.

        Returns:
            A contiguous [T, B, odom] tensor and ``src_lengths``.
        """
        projected = self.linear(src_tokens)
        time_major = projected.transpose(1, 0).contiguous()
        return time_major, src_lengths
class SpeechEncoderPrenet(nn.Module):
    """Convolutional waveform front-end for the speech encoder.

    Pipeline: strided ConvFeatureExtractionModel -> LayerNorm -> optional
    projection to ``encoder_embed_dim`` -> dropout -> optional HuBERT-style
    time/channel masking -> optional convolutional and/or sinusoidal
    positional information.

    Args:
        args: configuration namespace. Attributes read here: ``dropout``,
            ``encoder_embed_dim``, ``no_scale_embedding``,
            ``freeze_encoder_updates``, ``encoder_speech_prenet``,
            ``conv_feature_layers``, ``extractor_mode``, ``conv_bias``,
            ``label_rates``, ``sample_rate``, ``use_conv_pos``,
            ``use_sinc_pos``, ``use_abs_pos`` (optional), ``feature_grad_mult``,
            ``conv_pos``, ``conv_pos_groups``, ``max_speech_positions`` and the
            masking options (``mask_*``, ``hubert_mask_length``,
            ``no_mask_overlap``, ``no_mask_channel_overlap``).
    """

    def __init__(self, args):
        super(SpeechEncoderPrenet, self).__init__()
        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        # NOTE(review): embed_scale is stored but not applied anywhere in this class.
        self.embed_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        # Padding index handed to the positional embeddings below.
        self.padding_idx = 1
        self.freeze_encoder_updates = args.freeze_encoder_updates
        self.num_updates = 0
        # Only the convolutional front-end is built here, even for "linear".
        assert args.encoder_speech_prenet in ["conv", "linear"], args.encoder_speech_prenet

        # NOTE(review): eval() executes arbitrary code from the config string.
        # The spec may be an expression such as "[(512,10,5)] + [(512,3,2)] * 4",
        # which ast.literal_eval cannot parse; only load trusted configs.
        feature_enc_layers = eval(args.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=args.extractor_mode,
            conv_bias=args.conv_bias,
        )
        # Samples per feature frame = product of all conv strides.
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # Label frames per feature frame (assuming label_rates is labels/second).
        self.feat2tar_ratio = (
            args.label_rates * feature_ds_rate / args.sample_rate
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, args.encoder_embed_dim)
            if self.embed != args.encoder_embed_dim
            else None
        )

        self.use_conv_pos = args.use_conv_pos
        self.use_sinc_pos = args.use_sinc_pos
        self.use_abs_pos = getattr(args, "use_abs_pos", False)

        self.feature_grad_mult = args.feature_grad_mult
        # _forward always normalises the conv features, so the LayerNorm must
        # exist regardless of use_conv_pos. It is registered at the same point
        # as before, so the parameter order is unchanged when use_conv_pos is set.
        self.layer_norm = LayerNorm(self.embed)
        if self.use_conv_pos:
            self.pos_conv = nn.Conv1d(
                args.encoder_embed_dim,
                args.encoder_embed_dim,
                kernel_size=args.conv_pos,
                padding=args.conv_pos // 2,
                groups=args.conv_pos_groups,
            )
            dropout = 0
            std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * args.encoder_embed_dim))
            nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
            nn.init.constant_(self.pos_conv.bias, 0)
            self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
            self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        assert not (self.use_sinc_pos and self.use_abs_pos), f"sinc pos: {self.use_sinc_pos} abs pos: {self.use_abs_pos}"
        if self.use_sinc_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx
            )
        # NOTE(review): with use_abs_pos the learned table is created, but
        # _forward only adds embed_positions when use_sinc_pos is set.
        if self.use_abs_pos:
            self.embed_positions = PositionalEmbedding(
                args.max_speech_positions, args.encoder_embed_dim, self.padding_idx, learned=True
            )

        # HuBERT-style masking configuration.
        self.mask_prob = args.mask_prob
        self.mask_selection = args.mask_selection
        self.mask_other = args.mask_other
        self.hubert_mask_length = args.hubert_mask_length
        self.no_mask_overlap = args.no_mask_overlap
        self.mask_min_space = args.mask_min_space

        self.mask_channel_prob = args.mask_channel_prob
        self.mask_channel_selection = args.mask_channel_selection
        self.mask_channel_other = args.mask_channel_other
        self.mask_channel_length = args.mask_channel_length
        self.no_mask_channel_overlap = args.no_mask_channel_overlap
        self.mask_channel_min_space = args.mask_channel_min_space

        # Learned vector written into masked time steps.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(args.encoder_embed_dim).uniform_()
        )

    def forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        """Encode raw waveforms into frame embeddings.

        Gradients are disabled until ``num_updates`` reaches
        ``freeze_encoder_updates``.

        Args:
            src_tokens: [batch, samples] waveform batch.
            require_feat_pen: also return the feature penalty, mask indices
                and aligned targets.
            target_list: optional list of [batch, label_frames] label tensors.
            padding_mask: optional [batch, samples] bool mask, True at padding.
            mask: apply HuBERT-style masking.

        Returns:
            ``((x, features_pen, mask_indices, target_list), padding_mask)`` if
            ``require_feat_pen`` else ``(x, padding_mask)``, where ``x`` is
            [batch, frames, encoder_embed_dim] and ``padding_mask`` is the
            frame-level mask, or None when none was given.
        """
        grads_enabled = self.freeze_encoder_updates <= self.num_updates
        with contextlib.nullcontext() if grads_enabled else torch.no_grad():
            return self._forward(src_tokens, require_feat_pen, target_list, padding_mask, mask)

    def _forward(self, src_tokens, require_feat_pen=False, target_list=None, padding_mask=None, mask=True):
        """Body of :meth:`forward`; the caller chooses the gradient mode."""
        # Conv features: [batch, hidden_size, length].
        if self.feature_grad_mult > 0:
            x = self.feature_extractor(src_tokens)
            if self.feature_grad_mult != 1.0:
                # Rescale the gradient flowing back into the feature extractor.
                x = GradMultiply.apply(x, self.feature_grad_mult)
        else:
            # A non-positive multiplier disables gradients for the extractor.
            with torch.no_grad():
                x = self.feature_extractor(src_tokens)

        if target_list is not None:
            x, target_list = self.forward_targets(x, target_list)

        # Mean squared activation of the raw conv features.
        features_pen = x.float().pow(2).mean()

        x = x.transpose(1, 2)  # [batch, length, hidden_size]
        x = self.layer_norm(x)
        encoder_padding_mask = self.forward_padding_mask(x, padding_mask)

        if self.post_extract_proj is not None:
            x = self.post_extract_proj(x)
        x = self.dropout_module(x)

        mask_indices = None
        if mask:
            x, mask_indices = self.apply_hubert_mask(x, encoder_padding_mask)

        if self.use_conv_pos:
            positions = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
            x = x + positions

        if self.use_sinc_pos:
            pos_input = encoder_padding_mask
            if pos_input is None:
                # No padding information: use an all-False mask, the same input
                # a padding mask with no padded frames would give.
                pos_input = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)
            x = x + self.embed_positions(pos_input)

        if require_feat_pen:
            return (x, features_pen, mask_indices, target_list), encoder_padding_mask
        # Plain (x, padding_mask) pair for consistency with the encoder.
        return x, encoder_padding_mask

    def forward_targets(
        self, features: torch.Tensor, target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trim features so every frame has a label, then align the labels.

        Args:
            features: [batch, hidden_size, length] conv features.
            target_list: label tensors indexed as [batch, label_frames].

        Returns:
            The (possibly trimmed) features and the labels sampled at
            ``frame_index * feat2tar_ratio``.
        """
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self, features: torch.Tensor, padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Down-sample a sample-level padding mask to the feature frame rate.

        Args:
            features: [batch, frames, hidden_size] features; only ``size(1)`` is used.
            padding_mask: [batch, samples] bool mask (True = padding), or None.

        Returns:
            A [batch, frames] mask, True only where every sample of the frame
            is padding, or None when ``padding_mask`` is None.
        """
        if padding_mask is None:
            return None
        # Drop trailing samples so the length divides evenly into frames.
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        return padding_mask.all(-1)

    def get_src_lengths(self, src_lengths):
        """Map waveform lengths to feature-frame lengths."""
        return self.feature_extractor.get_out_seq_lens_tensor(src_lengths)

    def apply_hubert_mask(self, x, padding_mask):
        """Apply HuBERT-style time and channel masking to ``x`` in place.

        Masked time steps are overwritten with ``mask_emb`` and masked
        channels with zero.

        Args:
            x: [batch, time, channels] features (modified in place).
            padding_mask: [batch, time] bool mask or None.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is a [batch, time] bool
            tensor, or None when time masking is disabled.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.hubert_mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # Broadcast the per-channel mask across all time steps.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self.num_updates = num_updates
class ConvFeatureExtractionModel(nn.Module):
    """Stack of strided 1-D convolutions mapping raw waveforms to frame features.

    Args:
        conv_layers: one ``(dim, kernel_size, stride)`` tuple per layer.
        dropout: dropout probability applied after every convolution.
        mode: ``"default"`` adds a group norm after the first convolution only;
            ``"layer_norm"`` adds a layer norm after every convolution.
        conv_bias: whether the convolutions have a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build conv -> dropout -> [norm] -> GELU for one layer with n_out channels."""

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            # The norms are sized by this block's own n_out rather than the
            # enclosing loop variable, so the helper does not depend on late binding.
            if is_layer_norm:
                # TransposeLast presumably moves channels last so the layer
                # norm runs over channels — confirm against fairseq.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        # Raw (dim, kernel, stride) specs, used for output-length arithmetic.
        self.conv_layers_infos = conv_layers
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Map [batch, samples] waveforms to [batch, channels, frames] features."""
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x

    @staticmethod
    def _conv_out_lengths(lengths, kernel_size, stride):
        """Output lengths of an unpadded, undilated Conv1d (floored) as a LongTensor."""
        return ((lengths.float() - (kernel_size - 1) - 1) / stride + 1).floor().long()

    def get_out_seq_lens_nonmask_after_a_layer(self, in_seq_lens_tensor, i):
        """Return the non-padding mask and lengths after conv layer ``i``.

        Args:
            in_seq_lens_tensor (LongTensor): [batch] input lengths for layer ``i``.
            i (int): index of the conv layer.

        Returns:
            tuple: ``(out_nonmask, out_lengths)``. ``out_nonmask`` is a float
            tensor that is 1.0 at valid frames and 0.0 at padding;
            ``out_lengths`` is a LongTensor of output lengths.
        """
        _, kernel_size, stride = self.conv_layers_infos[i]
        out_lengths = self._conv_out_lengths(in_seq_lens_tensor, kernel_size, stride)
        out_nonmask = (~lengths_to_padding_mask(out_lengths)).float()
        return out_nonmask, out_lengths

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Return the output lengths after the whole conv stack (as a new tensor)."""
        out = in_seq_lens_tensor.clone()
        for _, kernel_size, stride in self.conv_layers_infos:
            out = self._conv_out_lengths(out, kernel_size, stride)
        return out