File size: 399 Bytes
f4b9f63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import BertConfig
from typing import List
class BertVAEConfig(BertConfig):
    """Configuration for a BERT-based VAE model, extending ``BertConfig``.

    Registers the ``"bert_vae"`` model type, marks the model as an
    encoder-decoder, lowers the default number of hidden layers to 3 and adds
    a ``position_num`` setting. Every other keyword argument is forwarded
    unchanged to ``BertConfig.__init__``.

    Args:
        num_hidden_layers: Number of hidden layers; defaults to 3.
        position_num: Extra model-specific integer stored as
            ``self.position_num``. NOTE(review): its meaning is defined by the
            model that consumes this config, which is not visible here.
        **kwargs: Forwarded to ``BertConfig.__init__``. An explicit
            ``is_encoder_decoder`` value is respected.
    """

    model_type = "bert_vae"
    # Kept for class-level access; the per-instance value is set in __init__.
    is_encoder_decoder = True

    def __init__(
        self,
        num_hidden_layers=3,
        position_num=4,
        **kwargs,
    ):
        # PretrainedConfig.__init__ assigns an *instance* attribute
        # `is_encoder_decoder` from kwargs (defaulting to False). That shadows
        # the class attribute above, so default the kwarg to True instead.
        kwargs.setdefault("is_encoder_decoder", True)
        # Let BertConfig store num_hidden_layers itself rather than
        # overwriting its default after the fact.
        super().__init__(num_hidden_layers=num_hidden_layers, **kwargs)
        self.position_num = position_num
|