# RA-BART / custom_bart / config.py
from transformers import BartConfig

class BartCustomConfig(BartConfig):
    """BartConfig with bart-large sized defaults plus extra, project-specific
    fields (relation kinds, commonsense masking, position-embedding toggle,
    attention-head mask) consumed by the custom RA-BART model code."""

    def __init__(
        self,
        model_type='bart',
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,  # added generation default
        num_hidden_layers=12,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        early_stopping=True,
        gradient_checkpointing=False,
        # Extra options beyond the stock BartConfig (consumed by the custom model code):
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=False,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        # Keep the extra fields as plain attributes so they are serialized with the config.
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions


class BartSmallCustomConfig(BartConfig):
    """BartConfig with (roughly) bart-base sized defaults plus the same extra
    RA-BART fields as BartCustomConfig."""

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=6,
        encoder_ffn_dim=3072,
        encoder_attention_heads=12,
        decoder_layers=12,
        decoder_ffn_dim=3072,
        decoder_attention_heads=12,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=768,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,  # added generation default
        num_hidden_layers=6,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        _name_or_path="bart-base",
        early_stopping=True,
        gradient_checkpointing=False,
        # Extra options beyond the stock BartConfig (consumed by the custom model code):
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=True,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            _name_or_path=_name_or_path,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        # Keep the extra fields as plain attributes so they are serialized with the config.
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions
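

# Minimal usage sketch: instantiate the two config classes and round-trip the
# extra fields through to_dict()/from_dict(). The concrete values below (e.g.
# 40 relation kinds) are illustrative assumptions, not taken from the RA-BART
# training setup; the relation-aware model that consumes these fields lives
# elsewhere in the repo.
if __name__ == "__main__":
    # bart-large sized config with a hypothetical number of relation kinds.
    large_cfg = BartCustomConfig(num_relation_kinds=40)
    print(large_cfg.d_model, large_cfg.num_relation_kinds, large_cfg.should_embed_positions)

    # bart-base sized config; the extra fields survive serialization because
    # they are plain attributes on a PretrainedConfig subclass.
    small_cfg = BartSmallCustomConfig(num_relation_kinds=40, use_same_relation_kv_emb=False)
    restored = BartSmallCustomConfig.from_dict(small_cfg.to_dict())
    assert restored.num_relation_kinds == small_cfg.num_relation_kinds
    assert restored.use_same_relation_kv_emb == small_cfg.use_same_relation_kv_emb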