# NOTE: "Spaces: Runtime error" -- apparently page-status residue from the web capture of this file; not part of the module.
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import contextlib
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor

from .multihead_attention import MultiheadAttention
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.

    The layer is a self-attention block followed by a position-wise
    feed-forward block, each wrapped in dropout and a residual connection.
    ``layer_norm_first`` selects whether LayerNorm runs before each block
    (pre-norm) or after each residual addition (post-norm).

    Args:
        embedding_dim: model (input/output) feature size.
        ffn_embedding_dim: hidden size of the feed-forward block.
        num_attention_heads: number of self-attention heads.
        dropout: dropout probability applied to the output of the attention
            block and of the feed-forward block.
        attention_dropout: dropout probability handed to the attention module.
        activation_dropout: dropout probability applied after the
            feed-forward activation.
        activation_fn: activation name resolved through
            ``fairseq.utils.get_activation_fn``.
        layer_norm_first: apply LayerNorm before (``True``) or after
            (``False``) each sub-block.
        has_relative_attention_bias: forwarded to the attention module; when
            set, a ``norm_k`` LayerNorm over the per-head dimension
            (``embedding_dim // num_attention_heads``) is also created.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks.
        # NOTE: sub-modules are created in the same order as before so that
        # parameter registration order stays unchanged.
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

        if has_relative_attention_bias:
            # Normalises the relative position bias over the per-head dim.
            self.norm_k = LayerNorm(self.embedding_dim // num_attention_heads)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        att_args=None,
        pos_bias=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        Args:
            x: input features; used as query, key and value of the
                self-attention.
            self_attn_mask: optional attention mask, forwarded to the
                attention module as ``attn_mask`` in both the pre-norm and
                the post-norm path.
            self_attn_padding_mask: optional padding mask, forwarded as
                ``key_padding_mask``.
            need_weights: accepted for interface compatibility; not used by
                this layer.
            att_args: accepted for interface compatibility; not used by this
                layer.
            pos_bias: optional relative position bias, forwarded as
                ``position_bias``. It is passed through ``norm_k`` only in
                the pre-norm path.

        Returns:
            A tuple ``(x, attn)``: the layer output and the second value
            returned by the attention module.
        """
        residual = x

        if self.layer_norm_first:
            # Pre-norm: LayerNorm -> block -> dropout -> residual add.
            x = self.self_attn_layer_norm(x)
            if pos_bias is not None:
                pos_bias = self.norm_k(pos_bias)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-norm: block -> dropout -> residual add -> LayerNorm.
            # ``pos_bias`` is intentionally passed without ``norm_k`` here,
            # exactly as before; changing that could alter the results of
            # already-trained post-norm models.
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                # Previously omitted, so a caller-supplied mask was silently
                # ignored on this path. Passing ``None`` is unchanged.
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    While ``num_updates`` is below ``args.freeze_decoder_updates`` the
    self-attention and feed-forward blocks run under ``torch.no_grad()``; the
    encoder-attention block is never wrapped that way.

    Args:
        args (argparse.Namespace): parsed command-line arguments. Read here:
            ``decoder_embed_dim``, ``decoder_ffn_embed_dim``,
            ``decoder_attention_heads``, ``decoder_normalize_before``,
            ``dropout``, ``attention_dropout`` and, optionally,
            ``activation_fn``, ``activation_dropout`` / ``relu_dropout``,
            ``quant_noise_pq``, ``quant_noise_pq_block_size``,
            ``cross_self_attention``, ``freeze_decoder_updates``,
            ``encoder_embed_dim``, ``export``.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module.
        add_zero_attn (bool, optional): forwarded to the self-attention module.
        has_relative_attention_bias (bool, optional): when set, a ``norm_k``
            LayerNorm over the per-head dimension is created for normalising
            the ``pos_bias`` given to :meth:`forward`.
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, has_relative_attention_bias=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.num_updates = 0
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )

        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        export = getattr(args, "export", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

        self.has_relative_attention_bias = has_relative_attention_bias
        if self.has_relative_attention_bias:
            # Normalises the relative position bias over the per-head dim.
            self.norm_k = LayerNorm(self.embed_dim // args.decoder_attention_heads)

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped in ``quant_noise``."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the decoder self-attention module from ``args``."""
        # NOTE: ``has_relative_attention_bias`` is not forwarded to the
        # attention module (the keyword was disabled in the original code).
        # A relative bias, when used, reaches the attention through the
        # per-call ``position_bias`` argument in ``forward``.
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module from ``args``."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to also return the cached self-attention state."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Return ``residual + x``."""
        return residual + x

    def _freeze_context(self):
        """Return ``torch.no_grad()`` while the layer is frozen, else a no-op context."""
        if self.num_updates < self.freeze_decoder_updates:
            return torch.no_grad()
        return contextlib.nullcontext()

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output used as key/value
                of the encoder attention; the encoder-attention block is
                skipped when it is ``None``.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): incremental decoding state,
                handed to both attention modules.
            prev_self_attn_state (list, optional): ``[prev_key, prev_value]``
                plus an optional third ``prev_key_padding_mask`` entry,
                written into the self-attention input buffer; requires
                ``incremental_state``.
            prev_attn_state (list, optional): same layout, written into the
                encoder-attention input buffer.
            self_attn_mask (Tensor, optional): mask passed to self-attention
                as ``attn_mask``.
            self_attn_padding_mask (Tensor, optional): padding mask passed to
                self-attention as ``key_padding_mask``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
            pos_bias (Tensor, optional): relative position bias passed to
                self-attention as ``position_bias``.

        Returns:
            A tuple ``(x, attn, self_attn_state)``: the encoded output of
            shape `(seq_len, batch, embed_dim)`, the second value returned by
            the last attention call that ran, and -- only when ONNX tracing
            is enabled and ``incremental_state`` is given -- the cached
            self-attention state (otherwise ``None``).
        """
        # ---- self-attention block (skips gradient tracking while frozen) ----
        with self._freeze_context():
            if need_head_weights:
                need_attn = True

            residual = x
            if self.normalize_before:
                x = self.self_attn_layer_norm(x)
                # NOTE(review): nesting under ``normalize_before`` mirrors the
                # encoder layer in this file; the captured source lost its
                # indentation -- confirm against the upstream repository.
                if pos_bias is not None:
                    pos_bias = self.norm_k(pos_bias)
            if prev_self_attn_state is not None:
                prev_key, prev_value = prev_self_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_self_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
                assert incremental_state is not None
                self.self_attn._set_input_buffer(incremental_state, saved_state)
            _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
            if self.cross_self_attention and not (
                incremental_state is not None
                and _self_attn_input_buffer is not None
                and "prev_key" in _self_attn_input_buffer
            ):
                # Cross-self-attention without cached keys: prepend the
                # encoder output to the keys/values and widen both masks to
                # cover the extra encoder positions.
                if self_attn_mask is not None:
                    assert encoder_out is not None
                    self_attn_mask = torch.cat(
                        (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                    )
                if self_attn_padding_mask is not None:
                    if encoder_padding_mask is None:
                        assert encoder_out is not None
                        encoder_padding_mask = self_attn_padding_mask.new_zeros(
                            encoder_out.size(1), encoder_out.size(0)
                        )
                    self_attn_padding_mask = torch.cat(
                        (encoder_padding_mask, self_attn_padding_mask), dim=1
                    )
                assert encoder_out is not None
                y = torch.cat((encoder_out, x), dim=0)
            else:
                y = x

            x, attn = self.self_attn(
                query=x,
                key=y,
                value=y,
                key_padding_mask=self_attn_padding_mask,
                incremental_state=incremental_state,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.self_attn_layer_norm(x)

        # ---- encoder-attention block (never wrapped in no_grad) ----
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # ---- feed-forward block (skips gradient tracking while frozen) ----
        with self._freeze_context():
            residual = x
            if self.normalize_before:
                x = self.final_layer_norm(x)

            x = self.activation_fn(self.fc1(x))
            x = self.activation_dropout_module(x)
            x = self.fc2(x)
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:
            # Expose the cached self-attention state explicitly for export.
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Set whether encoder-attention weights are requested outside training."""
        self.need_attn = need_attn

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self.num_updates = num_updates