""" Taken from ESPNet, but heavily modified """ import torch from Modules.GeneralLayers.Attention import RelPositionMultiHeadedAttention from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm from Modules.GeneralLayers.Convolution import ConvolutionModule from Modules.GeneralLayers.EncoderLayer import EncoderLayer from Modules.GeneralLayers.LayerNorm import LayerNorm from Modules.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d from Modules.GeneralLayers.MultiSequential import repeat from Modules.GeneralLayers.PositionalEncoding import RelPositionalEncoding from Modules.GeneralLayers.Swish import Swish from Utility.utils import integrate_with_utt_embed class Conformer(torch.nn.Module): """ Conformer encoder module. Args: idim (int): Input dimension. attention_dim (int): Dimension of attention. attention_heads (int): The number of heads of multi head attention. linear_units (int): The number of units of position-wise feed forward. num_blocks (int): The number of decoder blocks. dropout_rate (float): Dropout rate. positional_dropout_rate (float): Dropout rate after adding positional encoding. attention_dropout_rate (float): Dropout rate in attention. input_layer (Union[str, torch.nn.Module]): Input layer type. normalize_before (bool): Whether to use layer_norm before the first block. concat_after (bool): Whether to concat attention layer's input and output. if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x) positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. macaron_style (bool): Whether to use macaron style for positionwise layer. use_cnn_module (bool): Whether to use convolution module. cnn_module_kernel (int): Kernel size of convolution module. """ def __init__(self, conformer_type, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, lang_emb_size=16, use_output_norm=True, embedding_integration="AdaIN"): super(Conformer, self).__init__() activation = Swish() self.conv_subsampling_factor = 1 self.use_output_norm = use_output_norm if isinstance(input_layer, torch.nn.Module): self.embed = input_layer self.art_embed_norm = LayerNorm(attention_dim) self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate) elif input_layer is None: self.embed = None self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate)) else: raise ValueError("unknown input_layer: " + input_layer) if self.use_output_norm: self.output_norm = LayerNorm(attention_dim) self.utt_embed = utt_embed self.conformer_type = conformer_type self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"] if utt_embed is not None: if conformer_type == "encoder": # the encoder gets an additional conditioning signal added to its output if embedding_integration == "AdaIN": self.encoder_embedding_projection = AdaIN1d(style_dim=utt_embed, num_features=attention_dim) elif embedding_integration == "ConditionalLayerNorm": self.encoder_embedding_projection = ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim) else: self.encoder_embedding_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim) else: if embedding_integration == "AdaIN": self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: AdaIN1d(style_dim=utt_embed, num_features=attention_dim)) elif embedding_integration == "ConditionalLayerNorm": self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim)) else: self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim)) if lang_embs is not None: self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size) if lang_emb_size == attention_dim: self.language_embedding_projection = lambda x: x else: self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim) self.language_emb_norm = LayerNorm(attention_dim) # self-attention module definition encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) # feed-forward module definition positionwise_layer = MultiLayeredConv1d positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,) # convolution module definition convolution_layer = ConvolutionModule convolution_layer_args = (attention_dim, cnn_module_kernel, activation) self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args), positionwise_layer(*positionwise_layer_args), positionwise_layer(*positionwise_layer_args) if macaron_style else None, convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate, normalize_before, concat_after)) def forward(self, xs, masks, utterance_embedding=None, lang_ids=None): """ Encode input sequence. Args: utterance_embedding: embedding containing lots of conditioning signals lang_ids: ids of the languages per sample in the batch xs (torch.Tensor): Input tensor (#batch, time, idim). masks (torch.Tensor): Mask tensor (#batch, time). Returns: torch.Tensor: Output tensor (#batch, time, attention_dim). torch.Tensor: Mask tensor (#batch, time). """ if self.embed is not None: xs = self.embed(xs) xs = self.art_embed_norm(xs) if lang_ids is not None: lang_embs = self.language_embedding(lang_ids) projected_lang_embs = self.language_embedding_projection(lang_embs).unsqueeze(-1).transpose(1, 2) projected_lang_embs = self.language_emb_norm(projected_lang_embs) xs = xs + projected_lang_embs # offset phoneme representation by language specific offset xs = self.pos_enc(xs) for encoder_index, encoder in enumerate(self.encoders): if self.utt_embed: if isinstance(xs, tuple): x, pos_emb = xs[0], xs[1] if self.conformer_type != "encoder": x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) xs = (x, pos_emb) else: if self.conformer_type != "encoder": xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) xs, masks = encoder(xs, masks) if isinstance(xs, tuple): xs = xs[0] if self.utt_embed and self.conformer_type == "encoder": xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration) elif self.use_output_norm: xs = self.output_norm(xs) return xs, masks