artst-demo-asr / artst /models /modules /text_encoder_prenet.py
amupd's picture
initial commit
8b33290
raw
history blame
1.65 kB
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import torch.nn as nn
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
class TextEncoderPrenet(nn.Module):
    """Text encoder pre-net: token embedding followed by positional encoding.

    The module chains the supplied embedding layer with an ESPnet positional
    encoding layer and, on the forward pass, also reports which input
    positions hold the padding token.

    Args:
        embed_tokens (nn.Module): token embedding layer. It must expose a
            ``padding_idx`` attribute (presumably an ``nn.Embedding`` --
            confirm against the caller).
        args: configuration namespace. The following attributes are read:

            * ``enc_use_scaled_pos_enc`` -- truthy selects
              ``ScaledPositionalEncoding``, otherwise ``PositionalEncoding``.
            * ``encoder_embed_dim`` -- first argument of the positional
              encoding constructor.
            * ``transformer_enc_positional_dropout_rate`` -- second argument
              of the positional encoding constructor.
            * ``max_text_positions`` -- forwarded as ``max_len``.
    """

    def __init__(
        self,
        embed_tokens,
        args,
    ):
        super().__init__()
        # Kept so that forward() can derive the padding mask from raw token ids.
        self.padding_idx = embed_tokens.padding_idx

        # Choose the positional encoding flavour requested by the config.
        pos_enc_class = (
            ScaledPositionalEncoding
            if args.enc_use_scaled_pos_enc
            else PositionalEncoding
        )

        # NOTE: the attribute name and the Sequential ordering determine the
        # state_dict keys, so they must stay as they are for checkpoints.
        self.encoder_prenet = nn.Sequential(
            embed_tokens,
            pos_enc_class(
                args.encoder_embed_dim,
                args.transformer_enc_positional_dropout_rate,
                max_len=args.max_text_positions,
            ),
        )

    def forward(self, src_tokens):
        """Run the pre-net and compute the padding mask.

        Args:
            src_tokens: tensor of token ids (assumed to be shaped
                ``(batch, seq_len)`` -- TODO confirm against the caller).

        Returns:
            tuple: ``(encoded, padding_mask)`` where ``encoded`` is the output
            of ``self.encoder_prenet(src_tokens)`` and ``padding_mask`` is a
            boolean tensor with the same shape as ``src_tokens`` that is
            ``True`` wherever the token equals ``self.padding_idx``.
        """
        return self.encoder_prenet(src_tokens), src_tokens.eq(self.padding_idx)