# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
    TransformerSentenceEncoderLayer,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    See the illustrative helper after this function for the usual way of
    applying this initializer to a model.
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(
            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
        )

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
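

# Illustrative sketch, not part of the original fairseq source: init_bert_params
# is typically passed to nn.Module.apply, which calls it on every submodule
# recursively, so nn.Linear, nn.Embedding and MultiheadAttention weights all
# receive the N(0, 0.02) initialization described in the docstring above. The
# helper name `_example_apply_bert_init` is made up for this example.
def _example_apply_bert_init(model: nn.Module) -> nn.Module:
    model.apply(init_bert_params)
    return model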


class TransformerSentenceEncoder(nn.Module):
    """
    Implementation for a Bi-directional Transformer based Sentence Encoder used
    in BERT/XLM style pre-trained models.

    This first computes the token embedding using the token embedding matrix,
    position embeddings (if specified) and segment embeddings
    (if specified). After applying the specified number of
    TransformerEncoderLayers, it outputs all the internal states of the
    encoder as well as the final representation associated with the first
    token (usually CLS token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment label for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions where each tensor has shape T x B x C
            - sentence representation associated with first input token
              in format B x C.

    See the usage sketch at the bottom of this file for an illustrative
    forward pass with these shapes.
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: Optional[float] = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
    ) -> None:
        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable

        self.embed_tokens = self.build_embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )
        self.embed_scale = embed_scale

        if q_noise > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
                q_noise,
                qn_block_size,
            )
        else:
            self.quant_noise = None

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
            if self.num_segments > 0
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                learned=self.learned_pos_embedding,
            )
            if self.use_position_embeddings
            else None
        )

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if self.layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.layerdrop)
        else:
            self.layers = nn.ModuleList([])

        self.layers.extend(
            [
                self.build_transformer_sentence_encoder_layer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout_module.p,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    q_noise=q_noise,
                    qn_block_size=qn_block_size,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return nn.Embedding(vocab_size, embedding_dim, padding_idx)

    def build_transformer_sentence_encoder_layer(
        self,
        embedding_dim,
        ffn_embedding_dim,
        num_attention_heads,
        dropout,
        attention_dropout,
        activation_dropout,
        activation_fn,
        export,
        q_noise,
        qn_block_size,
    ):
        return TransformerSentenceEncoderLayer(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            export=export,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        last_state_only: bool = False,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        is_tpu = tokens.device.type == "xla"

        # compute padding mask. This is needed for multi-head attention
        padding_mask = tokens.eq(self.padding_idx)
        if not self.traceable and not is_tpu and not padding_mask.any():
            padding_mask = None

        if token_embeddings is not None:
            x = token_embeddings
        else:
            x = self.embed_tokens(tokens)

        if self.embed_scale is not None:
            x = x * self.embed_scale

        if self.embed_positions is not None:
            x = x + self.embed_positions(tokens, positions=positions)

        if self.segment_embeddings is not None and segment_labels is not None:
            x = x + self.segment_embeddings(segment_labels)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.emb_layer_norm is not None:
            x = self.emb_layer_norm(x)

        x = self.dropout_module(x)

        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        inner_states = []
        if not last_state_only:
            inner_states.append(x)

        for layer in self.layers:
            x, _ = layer(
                x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
            )
            if not last_state_only:
                inner_states.append(x)

        # representation of the first time step (usually the CLS token), B x C
        sentence_rep = x[0, :, :]

        if last_state_only:
            inner_states = [x]

        if self.traceable:
            return torch.stack(inner_states), sentence_rep
        else:
            return inner_states, sentence_rep
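

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original fairseq module; the
    # hyper-parameters below are arbitrary small values chosen for this example.
    # Shapes follow the class docstring: tokens are B x T, each inner state is
    # T x B x C, and the sentence representation is B x C.
    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        max_seq_len=32,
    )
    tokens = torch.randint(2, 1000, (8, 16))  # B=8 sentences of length T=16, no padding
    inner_states, sentence_rep = encoder(tokens)
    print(len(inner_states))       # num_encoder_layers + 1 internal states
    print(inner_states[-1].shape)  # torch.Size([16, 8, 64]) -> T x B x C
    print(sentence_rep.shape)      # torch.Size([8, 64]) -> B x C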