# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
    TransformerConfig,
)

from speechlm.modules import transformer_layer, LearnedPositionalEmbedding
from speechlm.modules.relative_pos_enc import RelativePositionalEncoding


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name


class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        cfg (TransformerConfig): transformer configuration
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        use_rel_pos_enc (bool, optional): if True, add a relative positional
            attention bias to every encoder layer (default: False)
        scaling_for_att (float, optional): attention scaling factor forwarded
            to each encoder layer (default: 1.0)
    """

    def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])

        self.use_rel_pos_enc = use_rel_pos_enc
        self.scaling_for_att = scaling_for_att
        self.layers.extend(
            [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

        if self.use_rel_pos_enc:
            # relative-position embeddings sized to the per-head dimension
            self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)
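    # Illustrative construction sketch (names below are placeholders, not
    # APIs defined in this file): the encoder expects a fairseq
    # ``TransformerConfig``-style ``cfg``, a fairseq dictionary, and an
    # embedding table whose ``padding_idx`` matches the dictionary's pad id.
    #
    #   embed_tokens = nn.Embedding(len(dictionary), cfg.encoder.embed_dim,
    #                               padding_idx=dictionary.pad())
    #   encoder = TransformerEncoderBase(cfg, dictionary, embed_tokens,
    #                                    use_rel_pos_enc=True)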
    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg,
            has_relative_attention_bias=self.use_rel_pos_enc,
            scaling_for_att=self.scaling_for_att,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                the default `None` recomputes the embeddings
            uniformity_layers (List[int], optional): indices of layer outputs
                to L2-normalize and collect in `uniformity_hiddens`
                (0 refers to the embedding output)

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): normalized hidden
                  states collected at *uniformity_layers*
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
        )
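    # Illustrative usage sketch for ``forward`` (``encoder``, ``src_tokens``
    # and ``src_lengths`` are placeholders):
    #
    #   out = encoder(src_tokens, src_lengths=src_lengths,
    #                 return_all_hiddens=True)
    #   x = out["encoder_out"][0]             # (src_len, batch, embed_dim)
    #   pad = out["encoder_padding_mask"][0]  # (batch, src_len), True at pads
    #   states = out["encoder_states"]        # list of (src_len, batch, embed_dim)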
    # TorchScript doesn't support super(), so the scriptable subclass can't
    # access the base class model in TorchScript.
    # The current workaround is to add a helper function with a different name
    # and call that helper from the scriptable subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                the default `None` recomputes the embeddings
            uniformity_layers (List[int], optional): indices of layer outputs
                to L2-normalize and collect in `uniformity_hiddens`
                (0 refers to the embedding output)

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): normalized hidden
                  states collected at *uniformity_layers*
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.use_rel_pos_enc:
            # signed relative distances between all pairs of positions
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        encoder_states = []
        uniformity_hiddens = []

        if return_all_hiddens:
            encoder_states.append(x)
        if uniformity_layers is not None and 0 in uniformity_layers:
            x = F.normalize(x.float(), dim=-1).type_as(x)
            uniformity_hiddens.append(x)

        # encoder layers
        for i, layer in enumerate(self.layers):
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                pos_bias=pos_k,
            )
            if uniformity_layers is not None and i + 1 in uniformity_layers:
                x = F.normalize(x.float(), dim=-1).type_as(x)
                uniformity_hiddens.append(x)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The PyTorch Mobile lite interpreter does not support returning
        # NamedTuple in `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "uniformity_hiddens": uniformity_hiddens,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }
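    # Worked example for the relative-position indices built in
    # ``forward_scriptable`` when ``use_rel_pos_enc`` is enabled: for a
    # sequence of length 3, ``pos_seq[:, None] - pos_seq[None, :]`` gives
    #
    #   [[ 0, -1, -2],
    #    [ 1,  0, -1],
    #    [ 2,  1,  0]]
    #
    # i.e. entry (i, j) is the signed distance i - j, which indexes the
    # relative positional embeddings passed to each layer as ``pos_bias``.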
    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }
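    # Illustrative sketch of how ``reorder_encoder_out`` is typically used
    # during beam search (``batch_size``, ``beam_size`` and ``encoder`` are
    # placeholders, not defined in this file): each batch entry is expanded to
    # ``beam_size`` hypotheses, so time-major tensors are gathered along dim 1
    # and batch-major tensors along dim 0, as in the method above.
    #
    #   new_order = torch.arange(batch_size).repeat_interleave(beam_size)
    #   new_order = new_order.to(src_tokens.device).long()
    #   reordered = encoder.reorder_encoder_out(encoder_out, new_order)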
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                print("deleting {0}".format(weights_key))
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict


class TransformerEncoder(TransformerEncoderBase):
    """Legacy entry point that builds the config-based encoder from an
    argparse namespace (``args``)."""

    def __init__(self, args, dictionary, embed_tokens):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            use_rel_pos_enc=getattr(args, "use_rel_pos_enc", False),
            scaling_for_att=getattr(args, "scaling_for_att", 1.0),
        )

    def build_encoder_layer(self, args):
        return super().build_encoder_layer(
            TransformerConfig.from_namespace(args),
        )


def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: int,
    learned: bool = False,
):
    if learned:
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        # TODO: The right place for this offset would be inside
        # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
        if padding_idx is not None:
            num_embeddings = num_embeddings + padding_idx + 1
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings + padding_idx + 1,
        )
    return m
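

if __name__ == "__main__":
    # Minimal, illustrative smoke check (a sketch; it assumes the fairseq and
    # speechlm imports above resolve). It only exercises the standalone
    # ``PositionalEmbedding`` helper defined in this file.
    pad_idx = 1
    pos_emb = PositionalEmbedding(
        num_embeddings=1024, embedding_dim=768, padding_idx=pad_idx, learned=True
    )
    # fairseq counts positions from ``padding_idx + 1``, so the learned table
    # holds ``1024 + pad_idx + 1`` rows and the row at ``pad_idx`` stays zero.
    print(type(pos_emb).__name__, tuple(pos_emb.weight.shape))
    assert pos_emb.weight.shape[0] == 1024 + pad_idx + 1
    assert bool(torch.all(pos_emb.weight[pad_idx] == 0))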