# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Transformer speech recognition model (pytorch)."""

from argparse import Namespace
from distutils.util import strtobool
import logging
import math

import numpy
import torch

from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.scorers.ctc import CTCPrefixScorer


class E2E(torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="how to initialize transformer parameters",
        )
        group.add_argument(
            "--transformer-input-layer",
            type=str,
            default="conv2d",
            choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
            help="transformer input layer type",
        )
        group.add_argument(
            "--transformer-encoder-attn-layer-type",
            type=str,
            default="mha",
            choices=["mha", "rel_mha", "legacy_rel_mha"],
            help="transformer encoder attention layer type",
        )
        group.add_argument(
            "--transformer-attn-dropout-rate",
            default=None,
            type=float,
            help="dropout in transformer attention. use --dropout-rate if None is set",
        )
        group.add_argument(
            "--transformer-lr",
            default=10.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=25000,
            type=int,
            help="optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-length-normalized-loss",
            default=True,
            type=strtobool,
            help="normalize loss by length",
        )
        group.add_argument(
            "--dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate for the encoder",
        )
        group.add_argument(
            "--macaron-style",
            default=False,
            type=strtobool,
            help="Whether to use macaron style for positionwise layer",
        )
        # -- input
        group.add_argument(
            "--a-upsample-ratio",
            default=1,
            type=int,
            help="Upsample rate for audio",
        )
        group.add_argument(
            "--relu-type",
            default="swish",
            type=str,
            help="the type of activation layer",
        )
        # Encoder
        group.add_argument(
            "--elayers",
            default=4,
            type=int,
            help="Number of encoder layers (for shared recognition part "
            "in multi-speaker asr mode)",
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=300,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--use-cnn-module",
            default=False,
            type=strtobool,
            help="Use convolution module or not",
        )
        group.add_argument(
            "--cnn-module-kernel",
            default=31,
            type=int,
            help="Kernel size of convolution module.",
        )
        # Attention
        group.add_argument(
            "--adim",
            default=320,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--zero-triu",
            default=False,
            type=strtobool,
help="If true, zero the uppper triangular part of attention matrix.", | |
) | |
        # Relative positional encoding
        group.add_argument(
            "--rel-pos-type",
            type=str,
            default="legacy",
            choices=["legacy", "latest"],
            help="Whether to use the latest relative positional encoding or the legacy one. "
            "The legacy relative positional encoding will be deprecated in the future. "
            "More details can be found in https://github.com/espnet/espnet/pull/2816.",
        )
        # Decoder
        group.add_argument(
            "--dlayers", default=1, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=320, type=int, help="Number of decoder hidden units"
        )
        # -- pretrain
        group.add_argument(
            "--pretrain-dataset",
            default="",
            type=str,
            help="pre-trained dataset for encoder",
        )
        # -- custom name
        group.add_argument(
            "--custom-pretrain-name",
            default="",
            type=str,
            help="pre-trained model for encoder",
        )
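        # Example invocation (illustrative values only, not recommended settings):
        #   --transformer-input-layer conv1d --elayers 12 --eunits 2048 \
        #   --adim 256 --aheads 4 --dlayers 6 --dunits 2048 \
        #   --macaron-style 1 --use-cnn-module 1
        # Options read later in __init__ (args.mtlalpha, args.lsm_weight,
        # args.ctc_type, args.report_cer, args.report_wer, ...) are assumed to
        # be registered elsewhere, e.g. by the training frontend, not here.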
        return parser

    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        # Check the relative positional encoding type
        self.rel_pos_type = getattr(args, "rel_pos_type", None)
        if (
            self.rel_pos_type is None
            and args.transformer_encoder_attn_layer_type == "rel_mha"
        ):
            args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        idim = 80
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
            zero_triu=getattr(args, "zero_triu", False),
            a_upsample_ratio=args.a_upsample_ratio,
            relu_type=getattr(args, "relu_type", "swish"),
        )
        self.transformer_input_layer = args.transformer_input_layer
        self.a_upsample_ratio = args.a_upsample_ratio

        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
        else:
            self.decoder = None
        self.blank = 0
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
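
    # Usage sketch (an assumption for illustration, not part of this module's
    # interface): the scorer dict above is in the format expected by ESPnet's
    # beam search; weights and beam size here are placeholders.
    #
    #     from espnet.nets.beam_search import BeamSearch
    #     beam_search = BeamSearch(
    #         scorers=model.scorers(),
    #         weights={"decoder": 0.9, "ctc": 0.1},
    #         beam_size=10,
    #         vocab_size=model.odim,
    #         sos=model.sos,
    #         eos=model.eos,
    #     )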

    def encode(self, x, extract_resnet_feats=False):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :param bool extract_resnet_feats: if True, return the front-end
            (ResNet) features instead of the encoder outputs
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        if extract_resnet_feats:
            resnet_feats = self.encoder(
                x,
                None,
                extract_resnet_feats=extract_resnet_feats,
            )
            return resnet_feats.squeeze(0)
        else:
            enc_output, _ = self.encoder(x, None)
            return enc_output.squeeze(0)
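
    # Minimal end-to-end sketch (illustrative assumption: the extra Namespace
    # fields mirror the attributes read in __init__, and the dummy input shape
    # assumes an 80-dimensional feature front-end; neither is dictated by this
    # file, and the values are not recommended settings):
    #
    #     import argparse
    #     parser = argparse.ArgumentParser()
    #     E2E.add_arguments(parser)
    #     args = parser.parse_args([])
    #     args.mtlalpha, args.lsm_weight, args.ctc_type = 0.1, 0.1, "builtin"
    #     args.report_cer = args.report_wer = False
    #     model = E2E(odim=5000, args=args)
    #     enc = model.encode(torch.randn(200, 80))  # -> (T', adim)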