from transformers import BartConfig


class BartCustomConfig(BartConfig):
    """BART-large-sized configuration extended with custom relation-aware options
    (num_relation_kinds, use_same_relation_kv_emb, is_simple_mask_commonsense,
    should_embed_positions, heads_mask)."""

    def __init__(
        self,
        model_type="bart",
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,
        # Custom additions
        num_hidden_layers=12,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        early_stopping=True,
        gradient_checkpointing=False,
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=False,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            # Custom additions
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        # Persist the custom fields on the config so they round-trip through
        # to_dict()/save_pretrained() like any other BartConfig attribute.
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions
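# Usage sketch (the model class name below is hypothetical; the relation-aware
# fields are presumably consumed by this repo's custom BART attention code):
#
#   config = BartCustomConfig(num_relation_kinds=40, use_same_relation_kv_emb=True)
#   model = BartCustomForConditionalGeneration(config)  # hypothetical model class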
class BartSmallCustomConfig(BartConfig):
    """BART-base-sized configuration with the same custom relation-aware options
    as BartCustomConfig."""

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=6,
        encoder_ffn_dim=3072,
        encoder_attention_heads=12,
        decoder_layers=12,
        decoder_ffn_dim=3072,
        decoder_attention_heads=12,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=768,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        init_std=0.02,
        classifier_dropout=0.0,
        classif_dropout=0.1,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        forced_bos_token_id=0,
        no_repeat_ngram_size=3,
        # Custom additions
        num_hidden_layers=6,
        normalize_before=False,
        num_beams=4,
        add_bias_logits=False,
        add_final_layer_norm=False,
        _name_or_path="bart-base",
        early_stopping=True,
        gradient_checkpointing=False,
        num_relation_kinds=0,
        use_same_relation_kv_emb=True,
        is_simple_mask_commonsense=False,
        should_embed_positions=True,
        heads_mask=None,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=encoder_layers,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layers=decoder_layers,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_attention_heads=decoder_attention_heads,
            encoder_layerdrop=encoder_layerdrop,
            decoder_layerdrop=decoder_layerdrop,
            activation_function=activation_function,
            d_model=d_model,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            init_std=init_std,
            classifier_dropout=classifier_dropout,
            classif_dropout=classif_dropout,
            scale_embedding=scale_embedding,
            use_cache=use_cache,
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            no_repeat_ngram_size=no_repeat_ngram_size,
            # Custom additions
            normalize_before=normalize_before,
            num_hidden_layers=num_hidden_layers,
            num_beams=num_beams,
            add_bias_logits=add_bias_logits,
            add_final_layer_norm=add_final_layer_norm,
            _name_or_path=_name_or_path,
            early_stopping=early_stopping,
            gradient_checkpointing=gradient_checkpointing,
            num_relation_kinds=num_relation_kinds,
            use_same_relation_kv_emb=use_same_relation_kv_emb,
            is_simple_mask_commonsense=is_simple_mask_commonsense,
            heads_mask=heads_mask,
            should_embed_positions=should_embed_positions,
            **kwargs
        )
        # Persist the custom fields so they serialize with the rest of the config.
        self.num_relation_kinds = num_relation_kinds
        self.use_same_relation_kv_emb = use_same_relation_kv_emb
        self.is_simple_mask_commonsense = is_simple_mask_commonsense
        self.heads_mask = heads_mask
        self.should_embed_positions = should_embed_positions
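if __name__ == "__main__":
    # Minimal smoke test (a sketch with illustrative values, not part of any
    # training pipeline): build both configs and check that the custom fields
    # survive the standard Hugging Face save/load round trip.
    import tempfile

    large = BartCustomConfig(num_relation_kinds=40)
    small = BartSmallCustomConfig(num_relation_kinds=40, should_embed_positions=True)
    print(large.d_model, large.num_relation_kinds)  # 1024 40
    print(small.d_model, small.num_relation_kinds)  # 768 40

    with tempfile.TemporaryDirectory() as tmp:
        small.save_pretrained(tmp)
        reloaded = BartSmallCustomConfig.from_pretrained(tmp)
        assert reloaded.num_relation_kinds == 40
        assert reloaded.should_embed_positions is True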