artst-demo-asr / artst /models /modules /text_decoder_postnet.py
amupd's picture
initial commit
8b33290
raw
history blame
3.57 kB
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import torch.nn as nn
import torch
import contextlib
from fairseq import utils
from fairseq.modules import (
AdaptiveSoftmax,
)
class TextDecoderPostnet(nn.Module):
"""
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(self, embed_tokens, dictionary, args, output_projection=None,):
super(TextDecoderPostnet, self).__init__()
self.output_embed_dim = args.decoder_output_dim
self.output_projection = output_projection
self.adaptive_softmax = None
self.share_input_output_embed = args.share_input_output_embed
if self.output_projection is None:
self.build_output_projection(args, dictionary, embed_tokens)
self.freeze_decoder_updates = args.freeze_decoder_updates
self.num_updates = 0
def output_layer(self, features):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features)
else:
return features
def build_output_projection(self, args, dictionary, embed_tokens):
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
# num_base_layers = getattr(args, "base_layers", 0)
# for i in range(num_base_layers):
# self.layers.insert(
# ((i + 1) * args.decoder_layers) // (num_base_layers + 1),
# BaseLayer(args),
# )
def forward(self, x):
ft = self.freeze_decoder_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
return self._forward(x)
def _forward(self, x):
# embed positions
x = self.output_layer(x)
return x
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self.num_updates = num_updates