File size: 399 Bytes
f4b9f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import BertConfig
from typing import List


class BertVAEConfig(BertConfig):
    """Configuration for the ``bert_vae`` model type, built on ``BertConfig``.

    Args:
        num_hidden_layers: Forwarded to ``BertConfig`` and exposed as
            ``self.num_hidden_layers``. Defaults to 3 for this config.
        position_num: Stored as ``self.position_num``. Defaults to 4. Its
            meaning is defined by the model code that consumes this config,
            which is not part of this file.
        **kwargs: Any other configuration options, forwarded unchanged to
            ``BertConfig.__init__``.
    """

    model_type = "bert_vae"
    is_encoder_decoder = True

    def __init__(
        self,
        num_hidden_layers=3,
        position_num=4,
        **kwargs,
    ):
        # The base config constructor assigns ``self.is_encoder_decoder`` from
        # kwargs (falling back to False), and that instance attribute shadows
        # the class attribute declared above. Seed kwargs with the declared
        # value so it actually takes effect, while still letting an explicit
        # caller-supplied (or deserialized) value win.
        kwargs.setdefault("is_encoder_decoder", type(self).is_encoder_decoder)
        super().__init__(num_hidden_layers=num_hidden_layers, **kwargs)
        self.position_num = position_num