Spaces:
Sleeping
Sleeping
"""Building blocks for speech SSL models supporting pruning. | |
Originally from: | |
https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py | |
""" | |
from collections import defaultdict | |
from typing import List, Optional, Tuple | |
import math | |
import torch | |
from torch import nn, Tensor | |
from torch.nn import Module, Parameter | |
from .hardconcrete import HardConcrete | |
from .pruning_utils import ( | |
prune_linear_layer, | |
prune_conv1d_layer, | |
prune_layer_norm, | |
) | |
def _init_transformer_params(module): | |
""" | |
Initialize the weights of Transformer module in Wav2Vec2/HuBERT. | |
If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. | |
If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. | |
If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. | |
If ``padding_idx`` is not None, set the weight of padding to 0. | |
Note: | |
Ths method corresponds to | |
`init_bert_params | |
<https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_sentence_encoder.py#L21>`__ | |
in the original ``fairseq`` implementation. | |
""" | |
def normal_(data): | |
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) | |
if isinstance(module, nn.Linear): | |
normal_(module.weight.data) | |
if module.bias is not None: | |
module.bias.data.zero_() | |
if isinstance(module, nn.Embedding): | |
normal_(module.weight.data) | |
if module.padding_idx is not None: | |
module.weight.data[module.padding_idx].zero_() | |
class LayerNorm(nn.LayerNorm):
    """Layer norm applied to a channels-first tensor.

    The last two axes are swapped, the standard layer norm is applied over
    ``normalized_shape``, and the axes are swapped back, so the output has the
    same shape as the input.
    """

    def forward(self, input: Tensor) -> Tensor:
        channels_last = input.transpose(-2, -1)
        normalized = nn.functional.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normalized.transpose(-2, -1)
class ConvLayerBlock(Module):
    """Convolution unit of ``FeatureExtractor``.

    Computes ``Conv1d -> (optional) normalization -> GELU``. When channel pruning is
    enabled, every output channel is then scaled by a ``HardConcrete`` gate.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels of the convolution.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        bias (bool): Whether the convolution has a bias term.
        layer_norm (Module or None): Normalization applied right after the convolution.
        prune_conv_channels (bool, optional): If ``True``, attach a ``HardConcrete`` gate
            over the output channels. (Default: ``False``)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
        prune_conv_channels: bool = False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )
        self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) if prune_conv_channels else None

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``.
        """
        feats = self.conv(x)
        if self.layer_norm is not None:
            feats = self.layer_norm(feats)
        feats = nn.functional.gelu(feats)

        if self.hard_concrete is not None:
            # One gate per output channel, broadcast over the frame axis.
            gates = self.hard_concrete()
            feats = feats * gates.unsqueeze(-1)

        if length is None:
            return feats, None

        # Output length of an unpadded, undilated convolution.
        out_length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
        # A very short (e.g. zero-length) input yields a negative length; clip it at 0.
        return feats, torch.clamp(out_length, min=0)

    def get_num_params_and_out_channels(self, in_channels):
        """Count this block's parameters for ``in_channels`` effective input channels.

        When a ``HardConcrete`` gate is attached, its ``l0_norm()`` (presumably the
        expected number of kept channels) replaces the static output-channel count.

        Returns:
            (num_params, out_channels)
        """
        if self.hard_concrete is None:
            n_out = self.conv.out_channels
        else:
            n_out = self.hard_concrete.l0_norm()

        total = in_channels * n_out * self.kernel_size  # convolution weight
        if self.conv.bias is not None:
            total += n_out
        if self.layer_norm is not None:
            total += n_out * 2  # normalization weight and bias
        return total, n_out
class FeatureExtractor(Module):
    """Extract features from audio

    Args:
        conv_layers (nn.ModuleList):
            convolution layers, applied in order. Each one is called as
            ``layer(x, length)`` and exposes its convolution as ``layer.conv``.
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers
        # NOTE: a dummy weight used to save the soft mask of the last conv layer.
        # It starts as all ones and is not trained (requires_grad=False).
        self.dummy_weight = nn.Parameter(
            torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32),
            requires_grad=False,
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.
        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.
        Raises:
            ValueError: If ``x`` is not 2D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time), but received {list(x.shape)}")
        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x, length = layer(x, length)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        # Per-feature scale; all ones unless ``prune`` folded a soft mask into it.
        x = x * self.dummy_weight
        return x, length

    def get_num_params_and_final_out_channels(self):
        """Count parameters of all conv layers plus ``dummy_weight``.

        Returns:
            (num_params, out_channels of the last layer)
        """
        in_channels = 1
        num_params = 0
        for layer in self.conv_layers:
            layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels)
            num_params += layer_params
        num_params += in_channels  # dummy weight
        return num_params, in_channels

    def prune(self):
        """Prune conv layers and dummy weight based on hardconcrete parameters.

        This is an in-place operation; every ``HardConcrete`` gate must be in eval mode.
        For each gated layer, the output channels whose gate is zero are removed from
        that layer's conv (and its normalization). The gate values are folded into the
        consumer of those channels (the next conv's input weights, or ``dummy_weight``
        for the last layer) before the matching entries are removed there too. The
        gate is then dropped.

        Returns:
            (list, Tensor): ``[(out_channels, kernel_size, stride), ...]`` for every layer,
            and the indices of the kept output channels of the last layer.
        """
        new_config = []  # [(output_channel, kernel_size, stride), ...]
        for idx, layer in enumerate(self.conv_layers):
            if layer.hard_concrete is not None:
                assert not layer.hard_concrete.training
                mask = layer.hard_concrete()  # (out_features,)
                index = mask.nonzero().squeeze(-1)  # 2D -> 1D
                assert len(index) > 0, f"Conv channels pruned to zero at index {idx}"
                new_config.append((len(index), layer.kernel_size, layer.stride))

                # prune the current layer
                prune_conv1d_layer(layer.conv, index, "output")
                if layer.layer_norm is not None:
                    prune_layer_norm(layer.layer_norm, index)

                # prune the next layer (or the dummy weight after the last layer)
                if idx == len(self.conv_layers) - 1:
                    self.dummy_weight.data *= mask
                    self.dummy_weight = nn.Parameter(
                        self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False
                    )
                else:
                    # Scale the input channels of the next conv by the soft gate values.
                    self.conv_layers[idx + 1].conv.weight.data *= mask.unsqueeze(-1)
                    prune_conv1d_layer(self.conv_layers[idx + 1].conv, index, dim="input")

                layer.hard_concrete = None
            else:
                new_config.append((layer.conv.out_channels, layer.kernel_size, layer.stride))
                index = torch.arange(layer.conv.out_channels, dtype=torch.long)

        return new_config, index
class FeatureProjection(Module):
    """Bridge between ``FeatureExtractor`` and the encoder.

    Normalizes the extracted features with ``nn.LayerNorm``, projects them to the
    encoder dimension with ``nn.Linear`` and applies dropout.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        normed = self.layer_norm(x)
        projected = self.projection(normed)
        return self.dropout(projected)

    def get_num_params(self, in_features):
        """Parameter count when the input has ``in_features`` (possibly pruned) features."""
        layer_norm_params = in_features * 2  # weight + bias
        projection_params = (in_features + 1) * self.projection.out_features  # weight + bias
        return layer_norm_params + projection_params
class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    A grouped, weight-normalized (``dim=2``) 1-D convolution over time followed by GELU.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be used.
        groups (int): The number of groups in feature dimensions.
    """

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        # With padding ``kernel_size // 2``, an even kernel produces one extra output
        # frame; ``forward`` drops it so the output length equals the input length.
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        """Remove the weight-norm forward pre-hook before TorchScript scripting.

        Returns:
            ConvolutionalPositionalEmbedding: ``self``.
        """
        # Iterate over a snapshot: ``remove_weight_norm`` deletes the hook from
        # ``_forward_pre_hooks``, and mutating that dict while iterating over it
        # raises ``RuntimeError`` as soon as another pre-hook follows it.
        for hook in list(self.conv._forward_pre_hooks.values()):
            # The hook we want to remove is an instance of WeightNorm class, so
            # normally we would do `if isinstance(...)` but this class is not accessible
            # because of shadowing, so we check the module name directly.
            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
                torch.nn.utils.remove_weight_norm(self.conv)
                break
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.
        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x
class SelfAttention(Module):
    """Multihead Self Attention module

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        head_dim (int): The dimension of each head.
        dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
        prune_heads (bool, optional): If ``True``, attach a ``HardConcrete`` gate with one
            entry per head. Default: ``False``
        prune_layer (bool, optional): If ``True``, attach a single ``HardConcrete`` gate that
            scales the whole attention output. Default: ``False``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        prune_heads: bool = False,  # whether to prune attention heads
        prune_layer: bool = False,  # whether to prune entire attention layers
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = torch.nn.Dropout(dropout)

        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True)
        self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True)

        if prune_heads:
            self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01)
        else:
            self.hard_concrete_for_heads = None

        if prune_layer:
            self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01)
        else:
            self.hard_concrete_for_layer = None

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
            attention_mask (Tensor or ``None``, optional):
                Additive mask, shape: ``[batch_size, 1, sequence_length, sequence_length]``
            position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
            key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
                :py:class:`WavLMSelfAttention`.
        Returns:
            (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
                with :py:class:`WavLMSelfAttention`).
                Attention output shape: ``[batch, sequence_length, embed_dim]``.
        Raises:
            ValueError: If ``x`` is not 3D or its last dimension is not ``embed_dim``.
        """
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        batch_size, length, embed_dim = x.size()

        shape = (batch_size, length, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
        k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
        v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd

        # scale down q to avoid value overflow.
        weights = (self.scaling * q) @ k  # B, nH, L, L
        if attention_mask is not None:
            weights += attention_mask
        # subtracting a constant value from the tensor won't change the output of softmax.
        # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
        # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
        weights = weights - weights.max(dim=-1, keepdim=True)[0]

        weights = torch.nn.functional.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        output = weights @ v  # B, nH, L, Hd

        if self.hard_concrete_for_heads is not None:
            head_mask = self.hard_concrete_for_heads()  # (nH,)
            output = output * head_mask.unsqueeze(-1).unsqueeze(-1)

        output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim)

        output = self.out_proj(output)

        if self.hard_concrete_for_layer is not None:
            layer_mask = self.hard_concrete_for_layer()  # (1,)
            output = output * layer_mask

        return output, None  # Necessary for compatibility with WavLMSelfAttention

    def get_num_params(self):
        """Parameter count of the q/k/v/out projections, using gate ``l0_norm()`` values when gated."""
        if self.hard_concrete_for_heads is not None:
            num_heads = self.hard_concrete_for_heads.l0_norm()
        else:
            num_heads = self.num_heads
        num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \
            + (num_heads * self.head_dim + 1) * self.embed_dim

        if self.hard_concrete_for_layer is not None:
            num_params *= self.hard_concrete_for_layer.l0_norm()

        return num_params

    def prune(self):
        """Prune this layer in place according to its (eval-mode) hard-concrete gates.

        The layer gate is folded into ``out_proj``; heads whose gate is zero are removed
        from the q/k/v projections and from the input of ``out_proj`` (after scaling its
        columns by the gate values). Both gates are dropped afterwards.

        Returns:
            dict: ``{"use_attention": bool, "num_heads": int}``. ``use_attention`` is ``False``
            when the layer gate is 0 or every head is pruned.
        """
        new_config = {
            "use_attention": True,
            "num_heads": self.num_heads,
        }
        if self.hard_concrete_for_layer is not None:
            assert not self.hard_concrete_for_layer.training
            layer_mask = self.hard_concrete_for_layer()  # (1,)
            self.out_proj.weight.data *= layer_mask
            self.out_proj.bias.data *= layer_mask
            if layer_mask == 0:
                new_config["use_attention"] = False
            self.hard_concrete_for_layer = None

        if self.hard_concrete_for_heads is not None:
            assert not self.hard_concrete_for_heads.training
            head_mask = self.hard_concrete_for_heads()  # (num_heads,)
            new_config["num_heads"] = len(head_mask.nonzero())
            if new_config["num_heads"] == 0:
                new_config["use_attention"] = False
            else:
                full_mask = head_mask.repeat_interleave(self.head_dim)
                full_index = full_mask.nonzero().squeeze(-1)  # 1D

                prune_linear_layer(self.k_proj, full_index, "output")
                prune_linear_layer(self.v_proj, full_index, "output")
                prune_linear_layer(self.q_proj, full_index, "output")

                self.out_proj.weight.data *= full_mask
                prune_linear_layer(self.out_proj, full_index, "input")
                # Keep the module consistent with its shrunken projections: forward()
                # reshapes with ``self.num_heads`` and get_num_params() counts with it.
                self.num_heads = new_config["num_heads"]
            self.hard_concrete_for_heads = None

        return new_config
class WavLMSelfAttention(SelfAttention):
    """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.

    Args:
        embed_dim (int): Total dimension of the model.
        total_num_heads (int): The number of heads of the unpruned model; each head has
            ``embed_dim // total_num_heads`` dimensions.
        remaining_heads (list of int or None, optional): Indices (into ``range(total_num_heads)``) of
            the heads kept by this layer. ``None`` keeps every head. (Default: ``None``)
        dropout (float, optional): Dropout probability on attn_output_weights. (Default: ``0.0``)
        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
        has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
            Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
        num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
        max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
        gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``True``)
        prune_heads (bool, optional): If ``True``, attach a hard-concrete gate over the remaining heads.
            (Default: ``False``)
        prune_layer (bool, optional): If ``True``, attach a hard-concrete gate over the whole layer.
            (Default: ``False``)
    """

    def __init__(
        self,
        embed_dim: int,
        total_num_heads: int,
        remaining_heads: Optional[List[int]] = None,
        dropout: float = 0.0,
        bias: bool = True,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 32,
        max_distance: int = 128,
        gru_rel_pos: bool = True,
        prune_heads: bool = False,
        prune_layer: bool = False,
    ):
        # These are needed to compute the arguments of SelfAttention.__init__,
        # so they are set before it runs.
        self.total_num_heads = total_num_heads
        if remaining_heads is None:
            self.remaining_heads = list(range(total_num_heads))
        else:
            self.remaining_heads = remaining_heads  # list of indices

        self.head_dim = embed_dim // total_num_heads

        super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        if has_relative_attention_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads)
        else:
            self.rel_attn_embed = None

        # override linear layers to customize bias
        self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias)
        self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias)

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
            self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1))
        self.has_position_bias = True

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        """Compute relative position embeddings for WavLM model.

        Args:
            query_length (int): Query position can take values between 0 and ``query_length - 1``.
            key_length (int): Key position can take values between 0 and ``key_length - 1``.
        Returns:
            Tensor of shape `(total_num_heads, query_length, key_length)`, relative positions embeddings
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # Shape (query_length, key_length)
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
        """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
           paper :cite:`chen2022wavlm`.

        Args:
            relative_positions (Tensor): Relative offsets (key position minus query position),
                of shape ``(query_length, key_length)``.
            bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the
                resulting matrix. If ``False``, the elements above the diagonal (i.e. with positive relative
                offsets) will be set to zero. (Default ``True``)
        Returns:
            Tensor of shape ``(query_length, key_length)`` filled with bucketed relative positions.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        # Shape (query_length, key_length)
        relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        # Logarithmically spaced buckets for distances >= max_exact (values for small
        # distances are computed too but discarded by torch.where below).
        relative_position_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
        return relative_buckets

    def forward(
        self,
        query: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
            attention_mask (Tensor or None, optional): Additive attention mask, broadcastable to
                ``(batch_size, num_heads, src_len, src_len)``. It is added to the relative position
                bias (if any) and passed on to :py:meth:`SelfAttention.forward`. (Default: ``None``)
            position_bias (Tensor or None, optional): Position bias of shape
                ``(batch_size * total_num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
                generated in the first layer and then passed from each encoder layer to the next one.
                (Default: ``None``)
            key_padding_mask (Tensor or None, optional): Must be ``None`` (asserted). The argument exists for
                compatibility with ``EncoderLayer``. (Default: ``None``)
        Returns:
            attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
            position_bias (Tensor or None): Position bias of shape
                ``(batch_size * total_num_heads, src_len, src_len)``.
        """
        bsz, seq_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert key_padding_mask is None

        # only for the first layer
        if self.rel_attn_embed is not None and position_bias is None:
            position_bias = self.compute_bias(seq_len, seq_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len)

        attn_mask_rel_pos: Optional[Tensor] = None
        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos:  # Apply gating on relative position bias
                query_layer = query.view(bsz, seq_len, self.total_num_heads, -1)
                query_layer = query_layer.permute(0, 2, 1, 3)

                gate_a, gate_b = torch.sigmoid(
                    self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len))
            # Keep only the bias rows of the heads this layer still has.
            attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :]

        attn_mask = attn_mask_rel_pos
        if attention_mask is not None:
            # Without a position bias (e.g. a later layer whose predecessor was pruned away)
            # the caller's mask is used on its own instead of failing on ``None + Tensor``.
            attn_mask = attention_mask if attn_mask is None else attn_mask + attention_mask

        # NOTE(review): unreachable while the assert above is active (i.e. unless run with -O).
        if key_padding_mask is not None:
            attn_mask = attn_mask.masked_fill(
                key_padding_mask.reshape(bsz, 1, 1, seq_len),
                float("-inf")
            )
        attn_output, _ = super().forward(query, attention_mask=attn_mask)

        return attn_output, position_bias

    def prune(self):
        """Prune this layer in place according to its (eval-mode) hard-concrete gates.

        Returns:
            dict: ``{"use_attention": bool, "remaining_heads": list of int}``, where
            ``remaining_heads`` are indices into ``range(total_num_heads)`` and can be fed
            back to the constructor.
        """
        new_config = {
            "use_attention": True,
            "remaining_heads": self.remaining_heads,
        }
        if self.hard_concrete_for_layer is not None:
            assert not self.hard_concrete_for_layer.training
            layer_mask = self.hard_concrete_for_layer()  # (1,)
            self.out_proj.weight.data *= layer_mask
            self.out_proj.bias.data *= layer_mask
            if layer_mask == 0:
                new_config["use_attention"] = False
            self.hard_concrete_for_layer = None

        if self.hard_concrete_for_heads is not None:
            assert not self.hard_concrete_for_heads.training
            head_mask = self.hard_concrete_for_heads()  # (num_heads,), one gate per remaining head
            kept = head_mask.nonzero().squeeze(-1).tolist()
            # Gate ``i`` belongs to head ``self.remaining_heads[i]``: report absolute head
            # indices so the config stays valid when an already-pruned layer is pruned again.
            new_config["remaining_heads"] = [self.remaining_heads[i] for i in kept]
            if len(kept) == 0:
                new_config["use_attention"] = False
            else:
                full_mask = head_mask.repeat_interleave(self.head_dim)
                full_index = full_mask.nonzero().squeeze(-1)  # 1D

                prune_linear_layer(self.k_proj, full_index, "output")
                prune_linear_layer(self.v_proj, full_index, "output")
                prune_linear_layer(self.q_proj, full_index, "output")

                self.out_proj.weight.data *= full_mask
                prune_linear_layer(self.out_proj, full_index, "input")
                # Keep the module consistent with its shrunken projections.
                self.remaining_heads = list(new_config["remaining_heads"])
                self.num_heads = len(kept)
            self.hard_concrete_for_heads = None

        return new_config
class FeedForward(Module):
    """Layer that follows attention layer in encoder layer.

    Computes ``output_dense(dropout(gelu(intermediate_dense(x))))`` followed by output
    dropout. Optional ``HardConcrete`` gates scale the intermediate units and the whole
    layer output.

    Args:
        io_features (int): Input/output feature dimension.
        intermediate_features (int): Hidden dimension.
        intermediate_dropout (float): Dropout probability after the activation.
        output_dropout (float): Dropout probability on the output.
        prune_intermediate (bool, optional): Gate the intermediate units. (Default: ``False``)
        prune_layer (bool, optional): Gate the whole layer output. (Default: ``False``)
    """

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
        prune_intermediate: bool = False,
        prune_layer: bool = False,
    ):
        super().__init__()
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

        self.hard_concrete_for_intermediate = (
            HardConcrete(n_in=intermediate_features, init_mean=0.5) if prune_intermediate else None
        )
        self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) if prune_layer else None

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        hidden = self.intermediate_dropout(torch.nn.functional.gelu(self.intermediate_dense(x)))
        interm_gate = self.hard_concrete_for_intermediate
        if interm_gate is not None:
            hidden = hidden * interm_gate()  # (intermediate_features,), broadcast over the last axis

        out = self.output_dropout(self.output_dense(hidden))
        layer_gate = self.hard_concrete_for_layer
        if layer_gate is not None:
            out = out * layer_gate()  # (1,)
        return out

    def get_num_params(self):
        """Parameter count of both dense layers, using gate ``l0_norm()`` values when gated."""
        n_io = self.intermediate_dense.in_features
        if self.hard_concrete_for_intermediate is None:
            n_hidden = self.intermediate_dense.out_features
        else:
            n_hidden = self.hard_concrete_for_intermediate.l0_norm()

        total = (n_io + 1) * n_hidden + (n_hidden + 1) * n_io
        if self.hard_concrete_for_layer is not None:
            total *= self.hard_concrete_for_layer.l0_norm()
        return total

    def prune(self):
        """Prune this layer in place according to its (eval-mode) hard-concrete gates.

        Returns:
            dict: ``{"use_feed_forward": bool, "ff_interm_features": int}``.
        """
        config = {
            "use_feed_forward": True,
            "ff_interm_features": self.intermediate_dense.out_features
        }

        layer_gate = self.hard_concrete_for_layer
        if layer_gate is not None:
            assert not layer_gate.training
            layer_mask = layer_gate()
            # Fold the layer gate into the (linear) output projection.
            self.output_dense.weight.data *= layer_mask
            self.output_dense.bias.data *= layer_mask
            if layer_mask == 0:
                config["use_feed_forward"] = False
            self.hard_concrete_for_layer = None

        interm_gate = self.hard_concrete_for_intermediate
        if interm_gate is None:
            return config

        assert not interm_gate.training
        interm_mask = interm_gate()
        kept = interm_mask.nonzero().squeeze(-1)  # NOTE: must specify dim=-1
        config["ff_interm_features"] = len(kept)
        if config["ff_interm_features"] == 0:
            config["use_feed_forward"] = False
        else:
            prune_linear_layer(self.intermediate_dense, kept, "output")
            # Fold the soft gate values into the output projection before dropping columns.
            self.output_dense.weight.data *= interm_mask
            prune_linear_layer(self.output_dense, kept, "input")
        self.hard_concrete_for_intermediate = None
        return config
class EncoderLayer(Module):
    """A layer unit in encoder. Combines multihead self attention and feed forward.

    Either sub-module may be ``None`` when it has been pruned away; the residual path
    and (for post-norm) the layer norms are still applied.
    """

    def __init__(
        self,
        attention: Optional[Module],  # can be None if the entire layer is pruned
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Optional[Module],  # can be None if the entire layer is pruned
        embed_dim: int,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(embed_dim)
        self.embed_dim = embed_dim

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
            attention_mask (Tensor or ``None``, optional): attention mask
                of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
            position_bias (Tensor or ``None``, optional): position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``.
                Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
            key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
                Only used for WavLM model, ignored otherwise. (Default: ``None``)
        Returns:
            (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WavLM model,
                ``None`` otherwise. It is passed through unchanged when the attention is pruned.
        """
        if self.attention is not None:
            shortcut = x
            attn_input = self.layer_norm(x) if self.layer_norm_first else x
            attn_output, position_bias = self.attention(
                attn_input,
                attention_mask=attention_mask,
                position_bias=position_bias,
                key_padding_mask=key_padding_mask,
            )
            x = shortcut + self.dropout(attn_output)

        if self.layer_norm_first:
            # Pre-norm: normalization only happens on the sub-layer branches.
            if self.feed_forward is not None:
                x = x + self.feed_forward(self.final_layer_norm(x))
            return x, position_bias

        # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned.
        x = self.layer_norm(x)
        if self.feed_forward is not None:
            x = x + self.feed_forward(x)
        return self.final_layer_norm(x), position_bias

    def get_num_params(self):
        """Parameters of both layer norms plus those of the remaining sub-modules."""
        total = self.embed_dim * 2 * 2  # two layer norms, weight + bias each
        for sub in (self.attention, self.feed_forward):
            if sub is not None:
                total += sub.get_num_params()
        return total
class Transformer(Module):
    """Stack of encoder layers preceded by a convolutional positional embedding.

    Args:
        pos_conv_embed (Module): Positional embedding; its output is added to the input.
            Must expose ``embed_dim``.
        dropout (float): Dropout probability applied after the positional embedding.
        layers (Module): Iterable of encoder layers, each called as
            ``layer(x, attention_mask, position_bias=...) -> (x, position_bias)``.
        layer_norm_first (bool): Pre-norm if ``True`` (norm right after the embedding),
            post-norm otherwise (norm after the last layer).
        layer_drop (float): Probability of skipping each layer during training.
    """

    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        """Add the positional embedding, normalize (pre-norm only) and apply dropout."""
        x = x + self.pos_conv_embed(x)
        if self.layer_norm_first:
            x = self.layer_norm(x)
        return self.dropout(x)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tensor:
        x = self._preprocess(x)
        for layer in self.layers:
            # LayerDrop: a random draw is only made while training.
            dropped = self.training and torch.rand(1).item() <= self.layer_drop
            if dropped:
                continue
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
        return x if self.layer_norm_first else self.layer_norm(x)

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
        position_bias: Optional[Tensor] = None,
    ) -> List[Tensor]:
        """Return the output of each of the first ``num_layers`` layers (all when ``None``).

        Layer drop is not applied here.

        Raises:
            ValueError: If ``num_layers`` is outside ``[1, len(self.layers)]``.
        """
        if num_layers is not None and not 0 < num_layers <= len(self.layers):
            raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        outputs: List[Tensor] = []
        x = self._preprocess(x)
        for layer in self.layers:
            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
            outputs.append(x)
            if num_layers is not None and len(outputs) >= num_layers:
                break
        return outputs

    def get_num_params(self):
        """Positional-embedding parameters, the layer norm, and every layer's count."""
        # pos_conv_embed and layer_norm
        total = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2
        for layer in self.layers:
            total += layer.get_num_params()
        return total

    def prune(self):
        """Prune every layer's attention and feed-forward in place.

        Sub-modules reported as unused are replaced with ``None``.

        Returns:
            defaultdict(list): per-layer lists under ``"use_attention"``, ``"num_heads"`` or
            ``"remaining_heads"`` (whichever the attention reports), ``"use_feed_forward"``
            and ``"ff_interm_features"``.
        """
        config = defaultdict(list)
        for layer in self.layers:
            attn_cfg = layer.attention.prune()
            config["use_attention"].append(attn_cfg["use_attention"])
            head_key = "remaining_heads" if "remaining_heads" in attn_cfg else "num_heads"
            config[head_key].append(attn_cfg[head_key])
            if not attn_cfg["use_attention"]:
                layer.attention = None

            ff_cfg = layer.feed_forward.prune()
            for key in ("use_feed_forward", "ff_interm_features"):
                config[key].append(ff_cfg[key])
            if not ff_cfg["use_feed_forward"]:
                layer.feed_forward = None
        return config
class Encoder(Module):
    """Feature projection followed by a transformer stack.

    Args:
        feature_projection (Module): Applied to the raw features first. Must also
            expose ``get_num_params(in_features)`` and, for ``prune``, ``layer_norm``
            and ``projection`` sub-modules.
        transformer (Module): Called as ``transformer(x, attention_mask=...)``. Must also
            provide ``get_intermediate_outputs``, ``get_num_params`` and ``prune``.
    """

    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Project features and, if ``lengths`` is given, build an additive attention mask.

        The projected tensor is unpacked as 3-D ``(batch, time, feature)``. Frames at
        positions ``>= lengths[b]`` are zeroed in place. The returned mask adds
        ``-10000.0`` to those key positions and has shape ``(batch, 1, time, time)``.
        It is ``None`` when ``lengths`` is ``None``.
        """
        projected = self.feature_projection(features)
        attention_bias: Optional[Tensor] = None
        if lengths is not None:
            batch_size, max_len, _ = projected.shape
            frame_index = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len)
            is_padding = frame_index >= lengths[:, None]
            projected[is_padding] = 0.0
            key_penalty = is_padding[:, None, None, :].to(dtype=features.dtype) * -10000.0
            attention_bias = key_penalty.expand(batch_size, 1, max_len, max_len)
        return projected, attention_bias

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``features``, masking padded frames when ``lengths`` is given."""
        projected, attention_bias = self._preprocess(features, lengths)
        return self.transformer(projected, attention_mask=attention_bias)

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Return the projected input followed by each transformer layer's output."""
        projected, attention_bias = self._preprocess(features, lengths)
        layer_outputs = self.transformer.get_intermediate_outputs(
            projected, attention_mask=attention_bias, num_layers=num_layers
        )
        return [projected] + layer_outputs

    def get_num_params(self, in_features):
        """Calculate the current model size (projection + transformer parameters)."""
        projection_count = self.feature_projection.get_num_params(in_features)
        return projection_count + self.transformer.get_num_params()

    def prune(self, conv_out_index):
        """Prune in place and return the transformer's new configuration.

        ``conv_out_index`` selects the surviving input channels of the feature
        projection (its layer norm and the input side of its linear layer).
        """
        projection = self.feature_projection
        prune_layer_norm(projection.layer_norm, conv_out_index)
        prune_linear_layer(projection.projection, conv_out_index, "input")
        return self.transformer.prune()
################################################################################ | |
def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
    prune_conv_channels: bool = False,
) -> FeatureExtractor:
    """Build the convolutional feature extractor as a chain of ``ConvLayerBlock``.

    Args:
        norm_mode (str):
            ``"group_norm"`` or ``"layer_norm"``. With ``"group_norm"`` only the first
            convolution block is normalized, using a ``GroupNorm`` with one group per
            channel. With ``"layer_norm"`` every block gets a (transposed) ``LayerNorm``.
            Corresponds to "extractor_mode" in fairseq; "group_norm" is used by the
            Base architecture and "layer_norm" by the Large one.
        shapes (list of tuple of int):
            One ``(output_channel, kernel_size, stride)`` triple per convolution block.
            The first block reads a single input channel; each subsequent block reads
            the previous block's output channels. Corresponds to "conv_feature_layers"
            in fairseq, typically
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``.
        bias (bool):
            Whether each convolution has a bias term. Corresponds to "conv_bias" in
            fairseq (False for Base, True for Large).
        prune_conv_channels (bool, optional):
            Forwarded unchanged to every ``ConvLayerBlock`` as ``prune_conv_channels``
            (presumably enables learnable channel pruning there). Default: ``False``.

    Returns:
        FeatureExtractor: Wrapping an ``nn.ModuleList`` of the blocks, in order.

    Raises:
        ValueError: If ``norm_mode`` is not one of the two supported values.

    See Also:
        * Original implementation
          https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
        * "extractor_mode"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
        * "conv_feature_layers"
          - Def, base and large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
        * "conv_bias"
          - Def and base:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
          - Large:
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
    """
    if norm_mode not in ["group_norm", "layer_norm"]:
        raise ValueError("Invalid norm mode")

    def make_norm(block_index, channels):
        # Per-block normalization module, or None when the block is unnormalized.
        if norm_mode == "layer_norm":
            return LayerNorm(normalized_shape=channels, elementwise_affine=True)
        if block_index == 0:  # group_norm mode: only the first block is normalized
            return nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)
        return None

    conv_blocks = []
    prev_channels = 1  # raw waveform input has one channel
    for block_index, (channels, kernel_size, stride) in enumerate(shapes):
        conv_blocks.append(
            ConvLayerBlock(
                in_channels=prev_channels,
                out_channels=channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
                layer_norm=make_norm(block_index, channels),
                prune_conv_channels=prune_conv_channels,
            )
        )
        prev_channels = channels
    return FeatureExtractor(nn.ModuleList(conv_blocks))
def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    use_attention: List[bool],
    use_feed_forward: List[bool],
    num_heads: List[int],
    head_dim: int,
    attention_dropout: float,
    ff_interm_features: List[int],
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
    prune_attention_heads: bool = False,
    prune_attention_layer: bool = False,
    prune_feed_forward_intermediate: bool = False,
    prune_feed_forward_layer: bool = False,
) -> Encoder:
    """Build the Wav2Vec2/HuBERT transformer encoder with optional pruning support.

    Unlike fairseq, the architecture may differ per layer. ``use_attention``,
    ``use_feed_forward``, ``num_heads`` and ``ff_interm_features`` are lists indexed by
    layer position, so a pruned configuration can be rebuilt exactly.

    Args:
        in_features (int): The number of input features.
        embed_dim (int):
            The dimension of embedding.
            This option corresponds to "encoder_embed_dim" from fairseq.
            Expected values are 768 for Base arch, and 1024 for Large arch.
        dropout_input (float):
            The dropout probability applied after the input feature is projected
            to ``embed_dim``.
            This option corresponds to "dropout_input" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.
            This option corresponds to "conv_pos" from fairseq.
            Expected values are 128 for both Base and Large arch.
        pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.
            This option corresponds to "conv_pos_groups" from fairseq.
            Expected values are 16 for both Base and Large arch.
        num_layers (int):
            The number of self attention layers in transformer block.
            This option corresponds to "encoder_layers" from fairseq.
            Expected values are 12 for Base and 24 for Large arch.
        use_attention (List[bool]):
            Per-layer flag. When ``use_attention[i]`` is False, layer ``i`` is built
            without a self-attention sub-module (``attention=None``).
            Must have at least ``num_layers`` entries.
        use_feed_forward (List[bool]):
            Per-layer flag. When ``use_feed_forward[i]`` is False, layer ``i`` is built
            without a feed-forward sub-module (``feed_forward=None``).
            Must have at least ``num_layers`` entries.
        num_heads (List[int]):
            Per-layer number of heads in self attention layers. Only read for layers
            whose ``use_attention`` entry is True.
            This option corresponds to "encoder_attention_heads" from fairseq
            (a single value for every layer there).
            Expected values are 12 for Base and 16 for Large arch.
        head_dim (int):
            Per-head dimension, passed to every ``SelfAttention`` as ``head_dim``.
        attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.
            This option corresponds to "attention_dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        ff_interm_features (List[int]):
            Per-layer dimension of hidden features in feed forward layer. Only read
            for layers whose ``use_feed_forward`` entry is True.
            This option corresponds to "encoder_ffn_embed_dim" from fairseq
            (a single value for every layer there).
            Expected values are 3072 for Base and 4096 for Large arch.
        ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.
            This option corresponds to "activation_dropout" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        dropout (float):
            The dropout probability applied at the end of feed forward layer.
            This option corresponds to "dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            If True, in transformer layer, layer norm is applied after features have
            passed through all encoder layers. In encoder layer, two layer norms are
            applied before and after self attention.
            If False, in transformer layer, layer norm is applied before features are
            fed to encoder layers. In encoder layer, two layer norms are applied after
            self attention, before and after feed forward.
            (The ``Transformer`` receives ``not layer_norm_first``; its own flag means
            "normalize before the layers".)
            This option corresponds to "layer_norm_first" from fairseq.
            Expected values are False for Base and True for Large arch.
        layer_drop (float):
            Probability to drop each encoder layer during training.
            This option corresponds to "layerdrop" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        prune_attention_heads (bool, optional):
            Passed to every ``SelfAttention`` as ``prune_heads``. (Default: ``False``)
        prune_attention_layer (bool, optional):
            Passed to every ``SelfAttention`` as ``prune_layer``. (Default: ``False``)
        prune_feed_forward_intermediate (bool, optional):
            Passed to every ``FeedForward`` as ``prune_intermediate``. (Default: ``False``)
        prune_feed_forward_layer (bool, optional):
            Passed to every ``FeedForward`` as ``prune_layer``. (Default: ``False``)

    Returns:
        Encoder: The feature projection plus the transformer stack.

    Raises:
        ValueError: If ``use_attention`` or ``use_feed_forward`` has fewer than
            ``num_layers`` entries.

    See Also:
        * "encoder_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
        * "dropout_input"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
        * "conv_pos"
          - Def, base and large
            NOTE: The description is wrong.
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
          - Usage
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
        * "conv_pos_groups"
          - Def, base and large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
        * "encoder_layers"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
        * "encoder_attention_heads"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
        * "attention_dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
        * "encoder_ffn_embed_dim"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
        * "activation_dropout"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
        * "dropout"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
        * "layer_norm_first"
          - Def and base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
        * "layerdrop"
          - Def
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
          - Base
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
          - Large
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
    """
    # Fail fast with a clear message instead of an IndexError halfway through
    # building the layers. These flags are read for every layer index.
    for flag_name, flags in (("use_attention", use_attention), ("use_feed_forward", use_feed_forward)):
        if len(flags) < num_layers:
            raise ValueError(
                f"`{flag_name}` must provide one entry per layer: "
                f"expected at least {num_layers}, got {len(flags)}."
            )

    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList()
    for idx in range(num_layers):
        if use_attention[idx]:
            attention = SelfAttention(
                embed_dim=embed_dim,
                num_heads=num_heads[idx],
                head_dim=head_dim,
                dropout=attention_dropout,
                prune_heads=prune_attention_heads,
                prune_layer=prune_attention_layer,
            )
        else:
            attention = None
        if use_feed_forward[idx]:
            feed_forward = FeedForward(
                io_features=embed_dim,
                intermediate_features=ff_interm_features[idx],
                intermediate_dropout=ff_interm_dropout,
                output_dropout=dropout,
                prune_intermediate=prune_feed_forward_intermediate,
                prune_layer=prune_feed_forward_layer,
            )
        else:
            feed_forward = None
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
                embed_dim=embed_dim,
            )
        )
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        # Deliberately inverted: Transformer normalizes *before* its layers when its
        # flag is True. Pre-LN layers (layer_norm_first=True) therefore get the
        # transformer-level norm after the stack.
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _get_wavlm_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    use_attention: List[bool],
    use_feed_forward: List[bool],
    total_num_heads: List[int],
    remaining_heads: List[List[int]],
    num_buckets: int,
    max_distance: int,
    attention_dropout: float,
    ff_interm_features: List[int],
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
    prune_attention_heads: bool = False,
    prune_attention_layer: bool = False,
    prune_feed_forward_intermediate: bool = False,
    prune_feed_forward_layer: bool = False,
) -> Encoder:
    """
    Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the arguments are
    the same as in :py:func:`_get_encoder`, so refer there for documentation. The differences from the Wav2Vec2 encoder
    are that it uses `WavLMSelfAttention` instead of `SelfAttention`, describes the attention heads with
    `total_num_heads` and `remaining_heads` instead of `num_heads` / `head_dim`, and takes two additional parameters:
    `num_buckets` and `max_distance`.

    Args:
        in_features (int): See :py:func:`_get_encoder`.
        embed_dim (int): See :py:func:`_get_encoder`.
        dropout_input (float): See :py:func:`_get_encoder`.
        pos_conv_kernel (int): See :py:func:`_get_encoder`.
        pos_conv_groups (int): See :py:func:`_get_encoder`.
        num_layers (int): See :py:func:`_get_encoder`.
        use_attention (List[bool]): Per-layer flag; when False the layer is built with ``attention=None``.
            Must have at least ``num_layers`` entries.
        use_feed_forward (List[bool]): Per-layer flag; when False the layer is built with ``feed_forward=None``.
            Must have at least ``num_layers`` entries.
        total_num_heads (List[int]): Per-layer value passed to ``WavLMSelfAttention`` as ``total_num_heads``.
            Only read for layers whose ``use_attention`` entry is True.
        remaining_heads (List[List[int]]): Per-layer value passed to ``WavLMSelfAttention`` as
            ``remaining_heads`` (presumably the indices of the heads kept out of ``total_num_heads[i]``).
            Only read for layers whose ``use_attention`` entry is True.
        num_buckets (int): Number of buckets for relative position embedding.
        max_distance (int): Maximum distance for relative position embedding.
        attention_dropout (float): See :py:func:`_get_encoder`.
        ff_interm_features (List[int]): See :py:func:`_get_encoder`.
        ff_interm_dropout (float): See :py:func:`_get_encoder`.
        dropout (float): See :py:func:`_get_encoder`.
        layer_norm_first (bool): See :py:func:`_get_encoder`.
        layer_drop (float): See :py:func:`_get_encoder`.
        prune_attention_heads (bool, optional): Passed to every ``WavLMSelfAttention`` as ``prune_heads``.
        prune_attention_layer (bool, optional): Passed to every ``WavLMSelfAttention`` as ``prune_layer``.
        prune_feed_forward_intermediate (bool, optional): Passed to every ``FeedForward`` as ``prune_intermediate``.
        prune_feed_forward_layer (bool, optional): Passed to every ``FeedForward`` as ``prune_layer``.

    Returns:
        Encoder: The feature projection plus the transformer stack.

    Raises:
        ValueError: If ``use_attention`` or ``use_feed_forward`` has fewer than ``num_layers`` entries.
    """
    # Fail fast with a clear message instead of an IndexError halfway through
    # building the layers. These flags are read for every layer index.
    for flag_name, flags in (("use_attention", use_attention), ("use_feed_forward", use_feed_forward)):
        if len(flags) < num_layers:
            raise ValueError(
                f"`{flag_name}` must provide one entry per layer: "
                f"expected at least {num_layers}, got {len(flags)}."
            )

    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList()
    for i in range(num_layers):
        if use_attention[i]:
            attention = WavLMSelfAttention(
                embed_dim=embed_dim,
                total_num_heads=total_num_heads[i],
                remaining_heads=remaining_heads[i],
                dropout=attention_dropout,
                # Position embedding is only necessary in the first layer.
                # NOTE(review): if use_attention[0] is False, no layer owns the
                # relative attention bias; confirm this case is handled downstream.
                has_relative_attention_bias=(i == 0),
                num_buckets=num_buckets,
                max_distance=max_distance,
                prune_heads=prune_attention_heads,
                prune_layer=prune_attention_layer,
            )
        else:
            attention = None
        if use_feed_forward[i]:
            feed_forward = FeedForward(
                io_features=embed_dim,
                intermediate_features=ff_interm_features[i],
                intermediate_dropout=ff_interm_dropout,
                output_dropout=dropout,
                prune_intermediate=prune_feed_forward_intermediate,
                prune_layer=prune_feed_forward_layer,
            )
        else:
            feed_forward = None
        encoder_layers.append(
            EncoderLayer(
                attention=attention,
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=feed_forward,
                embed_dim=embed_dim,
            )
        )
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        # Deliberately inverted: Transformer normalizes *before* its layers when its
        # flag is True, so pre-LN layers get the transformer-level norm after the stack.
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: | |
"""Generate the padding mask given the padded input and the lengths Tensors. | |
Args: | |
input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. | |
lengths (Tensor): The lengths Tensor of dimension `[batch,]`. | |
Returns: | |
(Tensor): The padding mask. | |
""" | |
batch_size, max_len, _ = input.shape | |
mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] | |
return mask | |
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by ``scale`` in the backward pass.

    Invoke as ``GradMultiply.apply(x, scale)``. ``torch.autograd.Function`` requires
    ``forward`` and ``backward`` to be static methods.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # A distinct tensor object with the same values (legacy ``Tensor.new``
        # constructor), so the output is not ``x`` itself.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # ``scale`` is not differentiable, so its gradient slot is None.
        return grad * ctx.scale, None