from transformers import BertConfig | |
from typing import List | |
class BertVAEConfig(BertConfig):
    """Configuration for the ``bert_vae`` model type, built on ``BertConfig``.

    Extends ``BertConfig`` with a ``position_num`` field and a smaller default
    for ``num_hidden_layers``. Every other keyword argument is forwarded
    unchanged to ``BertConfig.__init__``.

    Args:
        num_hidden_layers: Number of hidden layers; defaults to 3 here instead
            of the ``BertConfig`` default.
        position_num: Model-specific integer stored on the config as-is.
            NOTE(review): its meaning is defined by whichever model code reads
            ``config.position_num`` — confirm and document it there.
        **kwargs: Passed through to ``BertConfig.__init__`` (including an
            explicit ``is_encoder_decoder``, which overrides the class default).
    """

    model_type = "bert_vae"
    # Declared default for the per-instance flag. PretrainedConfig.__init__
    # assigns ``self.is_encoder_decoder`` from its kwargs (defaulting to False),
    # which shadows this class attribute on every instance — so __init__ must
    # pass the value through explicitly for it to take effect.
    is_encoder_decoder = True

    def __init__(
        self,
        num_hidden_layers: int = 3,
        position_num: int = 4,
        **kwargs,
    ) -> None:
        # Respect a caller-supplied value (e.g. one restored from a saved
        # config.json) and otherwise fall back to the class-level declaration.
        kwargs.setdefault("is_encoder_decoder", type(self).is_encoder_decoder)
        super().__init__(num_hidden_layers=num_hidden_layers, **kwargs)
        self.position_num = position_num