# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import torch.nn as nn

from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding


class TextEncoderPrenet(nn.Module):
    """Text encoder prenet: token embedding followed by positional encoding.

    Args:
        embed_tokens (nn.Embedding): token embedding module; its padding index
            is reused to build the encoder padding mask.
        args: model configuration providing ``encoder_embed_dim``,
            ``enc_use_scaled_pos_enc``, ``transformer_enc_positional_dropout_rate``,
            and ``max_text_positions``.
    """
    def __init__(
        self,
        embed_tokens,
        args,
    ):
        super(TextEncoderPrenet, self).__init__()
        self.padding_idx = embed_tokens.padding_idx
        # define encoder prenet
        # select the positional encoding class (scaled or plain sinusoidal)
        pos_enc_class = (
            ScaledPositionalEncoding if args.enc_use_scaled_pos_enc else PositionalEncoding
        )
        self.encoder_prenet = nn.Sequential(
            embed_tokens,
            pos_enc_class(
                args.encoder_embed_dim,
                args.transformer_enc_positional_dropout_rate,
                max_len=args.max_text_positions,
            ),
        )

    def forward(self, src_tokens):
        # Return the embedded, positionally encoded tokens together with a
        # boolean padding mask that is True at padding positions.
        return self.encoder_prenet(src_tokens), src_tokens.eq(self.padding_idx)
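

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): builds
# the prenet from a standalone nn.Embedding and a hypothetical args namespace
# whose field names mirror the attributes read in __init__. The vocabulary
# size, embedding dimension, and token values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    import torch

    args = Namespace(
        enc_use_scaled_pos_enc=True,
        encoder_embed_dim=768,
        transformer_enc_positional_dropout_rate=0.1,
        max_text_positions=450,
    )
    embed_tokens = nn.Embedding(100, args.encoder_embed_dim, padding_idx=1)

    prenet = TextEncoderPrenet(embed_tokens, args)
    src_tokens = torch.tensor([[5, 6, 7, 1, 1]])  # one sequence, padded with padding_idx=1
    x, padding_mask = prenet(src_tokens)
    print(x.shape)        # torch.Size([1, 5, 768])
    print(padding_mask)   # tensor([[False, False, False,  True,  True]])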