# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer speech recognition model (pytorch)."""
from argparse import Namespace
from distutils.util import strtobool
import logging
import math
import numpy
import torch
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.scorers.ctc import CTCPrefixScorer


class E2E(torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
group = parser.add_argument_group("transformer model setting")
group.add_argument(
"--transformer-init",
type=str,
default="pytorch",
choices=[
"pytorch",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
],
help="how to initialize transformer parameters",
)
group.add_argument(
"--transformer-input-layer",
type=str,
default="conv2d",
choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
help="transformer input layer type",
)
group.add_argument(
"--transformer-encoder-attn-layer-type",
type=str,
default="mha",
choices=["mha", "rel_mha", "legacy_rel_mha"],
help="transformer encoder attention layer type",
)
group.add_argument(
"--transformer-attn-dropout-rate",
default=None,
type=float,
help="dropout in transformer attention. use --dropout-rate if None is set",
)
group.add_argument(
"--transformer-lr",
default=10.0,
type=float,
help="Initial value of learning rate",
)
group.add_argument(
"--transformer-warmup-steps",
default=25000,
type=int,
help="optimizer warmup steps",
)
group.add_argument(
"--transformer-length-normalized-loss",
default=True,
type=strtobool,
help="normalize loss by length",
)
group.add_argument(
"--dropout-rate",
default=0.0,
type=float,
help="Dropout rate for the encoder",
)
group.add_argument(
"--macaron-style",
default=False,
type=strtobool,
help="Whether to use macaron style for positionwise layer",
)
# -- input
group.add_argument(
"--a-upsample-ratio",
default=1,
type=int,
help="Upsample rate for audio",
)
group.add_argument(
"--relu-type",
default="swish",
type=str,
help="the type of activation layer",
)
# Encoder
group.add_argument(
"--elayers",
default=4,
type=int,
help="Number of encoder layers (for shared recognition part "
"in multi-speaker asr mode)",
)
group.add_argument(
"--eunits",
"-u",
default=300,
type=int,
help="Number of encoder hidden units",
)
group.add_argument(
"--use-cnn-module",
default=False,
type=strtobool,
help="Use convolution module or not",
)
group.add_argument(
"--cnn-module-kernel",
default=31,
type=int,
help="Kernel size of convolution module.",
)
# Attention
group.add_argument(
"--adim",
default=320,
type=int,
help="Number of attention transformation dimensions",
)
group.add_argument(
"--aheads",
default=4,
type=int,
help="Number of heads for multi head attention",
)
group.add_argument(
"--zero-triu",
default=False,
type=strtobool,
help="If true, zero the uppper triangular part of attention matrix.",
)
# Relative positional encoding
group.add_argument(
"--rel-pos-type",
type=str,
default="legacy",
choices=["legacy", "latest"],
help="Whether to use the latest relative positional encoding or the legacy one."
"The legacy relative positional encoding will be deprecated in the future."
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
)
# Decoder
group.add_argument(
"--dlayers", default=1, type=int, help="Number of decoder layers"
)
group.add_argument(
"--dunits", default=320, type=int, help="Number of decoder hidden units"
)
# -- pretrain
group.add_argument("--pretrain-dataset",
default="",
type=str,
help='pre-trained dataset for encoder'
)
# -- custom name
group.add_argument("--custom-pretrain-name",
default="",
type=str,
help='pre-trained model for encoder'
)
return parser
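
    # Minimal wiring sketch (a hedged example, not part of the original file):
    # the options above are consumed together with fields defined by other
    # argument groups in the full training script (e.g. args.mtlalpha,
    # args.lsm_weight, args.ctc_type, args.report_cer, args.report_wer,
    # args.char_list), so defaults alone are not enough to build the model:
    #
    #     import argparse
    #     parser = argparse.ArgumentParser()
    #     E2E.add_arguments(parser)
    #     args = parser.parse_args([])  # all defaults; extra fields still needed
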
@property
def attention_plot_class(self):
"""Return PlotAttentionReport."""
return PlotAttentionReport

    def __init__(self, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
torch.nn.Module.__init__(self)
if args.transformer_attn_dropout_rate is None:
args.transformer_attn_dropout_rate = args.dropout_rate
# Check the relative positional encoding type
self.rel_pos_type = getattr(args, "rel_pos_type", None)
        if (
            self.rel_pos_type is None
            and args.transformer_encoder_attn_layer_type == "rel_mha"
        ):
            args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
            logging.warning(
                "Using legacy_rel_mha; the legacy relative positional encoding "
                "will be deprecated in the future."
            )
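        # NOTE: the encoder input dimension is hard-coded here rather than
        # passed in by the caller.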
idim = 80
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=args.transformer_input_layer,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
attention_dropout_rate=args.transformer_attn_dropout_rate,
encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
macaron_style=args.macaron_style,
use_cnn_module=args.use_cnn_module,
cnn_module_kernel=args.cnn_module_kernel,
zero_triu=getattr(args, "zero_triu", False),
a_upsample_ratio=args.a_upsample_ratio,
relu_type=getattr(args, "relu_type", "swish"),
)
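        # Assumption, based on the --transformer-input-layer choices above:
        # conv3d handles video streams, conv1d raw audio, and linear/embed
        # precomputed features.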
self.transformer_input_layer = args.transformer_input_layer
self.a_upsample_ratio = args.a_upsample_ratio
if args.mtlalpha < 1:
self.decoder = Decoder(
odim=odim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
self_attention_dropout_rate=args.transformer_attn_dropout_rate,
src_attention_dropout_rate=args.transformer_attn_dropout_rate,
)
else:
self.decoder = None
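        # CTC blank is index 0; <sos> and <eos> share the last vocabulary index.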
self.blank = 0
self.sos = odim - 1
self.eos = odim - 1
self.odim = odim
self.ignore_id = ignore_id
self.subsample = get_subsample(args, mode="asr", arch="transformer")
self.criterion = LabelSmoothingLoss(
self.odim,
self.ignore_id,
args.lsm_weight,
args.transformer_length_normalized_loss,
)
self.adim = args.adim
self.mtlalpha = args.mtlalpha
if args.mtlalpha > 0.0:
self.ctc = CTC(
odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
)
else:
self.ctc = None
if args.report_cer or args.report_wer:
self.error_calculator = ErrorCalculator(
args.char_list,
args.sym_space,
args.sym_blank,
args.report_cer,
args.report_wer,
)
else:
self.error_calculator = None
self.rnnlm = None

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
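
    # Decoding sketch (a hedged example; assumes ESPnet's BeamSearch interface,
    # and the weights here are illustrative only):
    #
    #     from espnet.nets.beam_search import BeamSearch
    #     beam_search = BeamSearch(
    #         scorers=model.scorers(),
    #         weights={"decoder": 0.9, "ctc": 0.1},
    #         beam_size=10,
    #         vocab_size=model.odim,
    #         sos=model.sos,
    #         eos=model.eos,
    #     )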

    def encode(self, x, extract_resnet_feats=False):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :param bool extract_resnet_feats: if True, return the ResNet frontend
            features instead of the full encoder output
        :return: encoder outputs
        :rtype: torch.Tensor
        """
self.eval()
x = torch.as_tensor(x).unsqueeze(0)
if extract_resnet_feats:
resnet_feats = self.encoder(
x,
None,
extract_resnet_feats=extract_resnet_feats,
)
return resnet_feats.squeeze(0)
else:
enc_output, _ = self.encoder(x, None)
return enc_output.squeeze(0)
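
    # Example call (a sketch; assumes a trained model instance and a (T, D)
    # float32 feature array matching the configured frontend):
    #
    #     feats = numpy.zeros((100, 80), dtype=numpy.float32)  # dummy input
    #     enc = model.encode(feats)  # -> (T', adim) encoder output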