# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.modules import TransformerSentenceEncoder
from fairseq.modules.sparse_transformer_sentence_encoder_layer import (
    SparseTransformerSentenceEncoderLayer,
)


class SparseTransformerSentenceEncoder(TransformerSentenceEncoder):
    """
    Sparse implementation of the TransformerSentenceEncoder
    - see SparseMultiheadAttention
    """

    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_bidirectional: bool = True,
        stride: int = 32,
        expressivity: int = 8,
    ) -> None:
        super().__init__(
            padding_idx,
            vocab_size,
            num_encoder_layers,
            embedding_dim,
            ffn_embedding_dim,
            num_attention_heads,
            dropout,
            attention_dropout,
            activation_dropout,
            max_seq_len,
            num_segments,
            use_position_embeddings,
            offset_positions_by_padding,
            encoder_normalize_before,
            apply_bert_init,
            activation_fn,
            learned_pos_embedding,
            embed_scale,
            freeze_embeddings,
            n_trans_layers_to_freeze,
            export,
        )

        self.layers = nn.ModuleList(
            [
                SparseTransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    export=export,
                    is_bidirectional=is_bidirectional,
                    stride=stride,
                    expressivity=expressivity,
                )
                for _ in range(num_encoder_layers)
            ]
        )
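
        # Helper to turn off gradient updates for every parameter of a submodule.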
        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False
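
        # Freeze the bottom n_trans_layers_to_freeze encoder layers, if requested.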
        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
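

# Illustrative usage sketch (not part of the original fairseq module): build a
# small sparse sentence encoder and run a forward pass on random token ids.
# All hyperparameter values below are arbitrary; `stride` and `expressivity`
# control the fixed sparse attention pattern used by SparseMultiheadAttention.
if __name__ == "__main__":
    import torch

    encoder = SparseTransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        ffn_embedding_dim=128,
        num_attention_heads=4,
        max_seq_len=64,
        stride=16,
        expressivity=4,
    )
    # Dummy batch of token ids, shape (batch_size, seq_len); values avoid padding_idx.
    tokens = torch.randint(2, 1000, (2, 32))
    # TransformerSentenceEncoder.forward returns the per-layer inner states
    # (each seq_len x batch x embed_dim) and a sentence representation (batch x embed_dim).
    inner_states, sentence_rep = encoder(tokens)
    print(inner_states[-1].shape, sentence_rep.shape)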