from transformers import BartConfig


class BartCustomConfig(BartConfig):
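    """Configuration for the custom BART model, sized like facebook/bart-large.

    In addition to the standard ``BartConfig`` hyperparameters, it carries the
    extra fields consumed by the relation-aware / commonsense variant of the
    model: ``num_relation_kinds``, ``use_same_relation_kv_emb``,
    ``is_simple_mask_commonsense``, ``should_embed_positions`` and
    ``heads_mask``.
    """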
    def __init__(
        self,
        model_type='bart',
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,
        num_hidden_layers=12,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        early_stopping=True,
        gradient_checkpointing=False,
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=False,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions


class BartSmallCustomConfig(BartConfig):
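    """Configuration for the custom BART model with bart-base-sized defaults.

    Same extra relation-aware fields as :class:`BartCustomConfig`, but with
    smaller default dimensions (``d_model=768``, 12 attention heads,
    ``ffn_dim=3072``, 6 encoder and 12 decoder layers) and with
    ``should_embed_positions`` enabled by default.
    """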
    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=6,
        encoder_ffn_dim=3072,
        encoder_attention_heads=12,
        decoder_layers=12,
        decoder_ffn_dim=3072,
        decoder_attention_heads=12,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=768,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,
        num_hidden_layers=6,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        _name_or_path="bart-base",
        early_stopping=True,
        gradient_checkpointing=False,
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=True,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            _name_or_path=_name_or_path,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions
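

# A minimal usage sketch (the relation count of 40 below is a placeholder, not
# a value prescribed by this module). The custom fields are accepted like any
# other BartConfig argument and end up as attributes on the config object.
if __name__ == "__main__":
    config = BartCustomConfig(
        num_relation_kinds=40,  # hypothetical number of relation types
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=True,
    )
    print(config.d_model, config.num_relation_kinds, config.heads_mask)

    small_config = BartSmallCustomConfig(num_relation_kinds=40)
    print(small_config.d_model, small_config.encoder_layers)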