from transformers import BertConfig
from typing import List


class BertVAEConfig(BertConfig):
    """Configuration for a BERT-based VAE model.

    Extends ``transformers.BertConfig`` with one extra field,
    ``position_num``, and a smaller default number of hidden layers.
    Instances default to ``is_encoder_decoder=True``.

    Args:
        num_hidden_layers: Number of hidden layers. Defaults to 3, which
            overrides ``BertConfig``'s own default.
        position_num: Extra integer stored on the config and serialized with
            it. NOTE(review): its meaning is defined by the model code that
            reads this config (not visible here) — confirm and document.
        **kwargs: Forwarded unchanged to ``BertConfig`` / ``PretrainedConfig``.
    """

    # Identifier used by transformers' Auto* registries and written into
    # the serialized config.
    model_type = "bert_vae"
    # Class-level declaration of intent. On instances this is shadowed,
    # because PretrainedConfig.__init__ assigns self.is_encoder_decoder
    # from kwargs with a default of False. __init__ below therefore passes
    # this value as the default explicitly.
    is_encoder_decoder = True

    def __init__(
        self,
        num_hidden_layers: int = 3,
        position_num: int = 4,
        **kwargs,
    ):
        # Without this, PretrainedConfig would reset the instance value to
        # False. setdefault lets an explicit value (e.g. one loaded from a
        # previously saved config.json) take precedence.
        kwargs.setdefault("is_encoder_decoder", type(self).is_encoder_decoder)
        # Hand the layer count to the base class instead of overwriting it
        # afterwards, so the base initializer sees the final value.
        super().__init__(num_hidden_layers=num_hidden_layers, **kwargs)
        self.position_num = position_num