# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer speech recognition model (pytorch)."""

from argparse import Namespace
from distutils.util import strtobool
import logging
import math

import numpy
import torch

from espnet.asr.asr_utils import PlotAttentionReport
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.scorers.ctc import CTCPrefixScorer


class E2E(torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="how to initialize transformer parameters",
        )
        group.add_argument(
            "--transformer-input-layer",
            type=str,
            default="conv2d",
            choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
            help="transformer input layer type",
        )
        group.add_argument(
            "--transformer-encoder-attn-layer-type",
            type=str,
            default="mha",
            choices=["mha", "rel_mha", "legacy_rel_mha"],
            help="transformer encoder attention layer type",
        )
        group.add_argument(
            "--transformer-attn-dropout-rate",
            default=None,
            type=float,
            help="dropout in transformer attention. use --dropout-rate if None is set",
        )
        group.add_argument(
            "--transformer-lr",
            default=10.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=25000,
            type=int,
            help="optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-length-normalized-loss",
            default=True,
            type=strtobool,
            help="normalize loss by length",
        )
        group.add_argument(
            "--dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate for the encoder",
        )
        group.add_argument(
            "--macaron-style",
            default=False,
            type=strtobool,
            help="Whether to use macaron style for positionwise layer",
        )
        # -- input
        group.add_argument(
            "--a-upsample-ratio",
            default=1,
            type=int,
            help="Upsample rate for audio",
        )
        group.add_argument(
            "--relu-type",
            default="swish",
            type=str,
            help="the type of activation layer",
        )
        # Encoder
        group.add_argument(
            "--elayers",
            default=4,
            type=int,
            help="Number of encoder layers (for shared recognition part "
            "in multi-speaker asr mode)",
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=300,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--use-cnn-module",
            default=False,
            type=strtobool,
            help="Use convolution module or not",
        )
        group.add_argument(
            "--cnn-module-kernel",
            default=31,
            type=int,
            help="Kernel size of convolution module.",
        )
        # Attention
        group.add_argument(
            "--adim",
            default=320,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--zero-triu",
            default=False,
            type=strtobool,
            help="If true, zero the upper triangular part of the attention matrix.",
        )
        # Relative positional encoding
        group.add_argument(
            "--rel-pos-type",
            type=str,
            default="legacy",
            choices=["legacy", "latest"],
            help="Whether to use the latest relative positional encoding or the legacy one. "
            "The legacy relative positional encoding will be deprecated in the future. "
            "More details can be found in https://github.com/espnet/espnet/pull/2816.",
        )
        # Decoder
        group.add_argument(
            "--dlayers", default=1, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=320, type=int, help="Number of decoder hidden units"
        )
        # -- pretrain
        group.add_argument(
            "--pretrain-dataset",
            default="",
            type=str,
            help="pre-trained dataset for encoder",
        )
        # -- custom name
        group.add_argument(
            "--custom-pretrain-name",
            default="",
            type=str,
            help="pre-trained model for encoder",
        )
        return parser
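    # Example invocation (illustrative only: the flag names come from the
    # options registered above, but the values are assumptions, not
    # recommended settings):
    #
    #   --transformer-input-layer conv2d --transformer-encoder-attn-layer-type mha \
    #   --elayers 12 --eunits 2048 --adim 256 --aheads 4 --dlayers 6 --dunits 2048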
    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        # Check the relative positional encoding type
        self.rel_pos_type = getattr(args, "rel_pos_type", None)
        if (
            self.rel_pos_type is None
            and args.transformer_encoder_attn_layer_type == "rel_mha"
        ):
            args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
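        # NOTE: the encoder input dimension is hard-coded below rather than
        # taken from the data; 80 matches the 80-dimensional log-mel
        # filterbank features commonly used as an ASR front-end (an
        # assumption about the upstream feature extractor).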
        idim = 80

        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
            zero_triu=getattr(args, "zero_triu", False),
            a_upsample_ratio=args.a_upsample_ratio,
            relu_type=getattr(args, "relu_type", "swish"),
        )

        self.transformer_input_layer = args.transformer_input_layer
        self.a_upsample_ratio = args.a_upsample_ratio

        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
        else:
            self.decoder = None
        self.blank = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x, extract_resnet_feats=False):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :param bool extract_resnet_feats: if True, return the ResNet front-end
            features instead of the encoder outputs
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        if extract_resnet_feats:
            resnet_feats = self.encoder(
                x,
                None,
                extract_resnet_feats=extract_resnet_feats,
            )
            return resnet_feats.squeeze(0)
        else:
            enc_output, _ = self.encoder(x, None)
            return enc_output.squeeze(0)
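
if __name__ == "__main__":
    # Minimal construction sketch (illustrative only, not part of the original
    # module). `add_arguments` registers just the transformer options; the
    # remaining attributes read by `E2E.__init__` (mtlalpha, lsm_weight,
    # ctc_type, report_cer/report_wer) normally come from other espnet
    # argument groups, so they are filled in by hand here with assumed values.
    import argparse

    parser = argparse.ArgumentParser()
    E2E.add_arguments(parser)
    args = parser.parse_args([])  # take the defaults registered above

    args.mtlalpha = 0.1        # assumed CTC/attention multitask weight
    args.lsm_weight = 0.1      # assumed label-smoothing weight
    args.ctc_type = "builtin"  # assumed CTC implementation
    args.report_cer = False
    args.report_wer = False

    model = E2E(odim=500, args=args)  # odim=500 is an assumed vocabulary size
    n_params = sum(p.numel() for p in model.parameters())
    print(f"constructed E2E with {n_params} parameters")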