from ..configuration_videobase import VideoBaseConfiguration
from typing import Union, Tuple


class VQVAEConfiguration(VideoBaseConfiguration):
    """Configuration object holding the constructor arguments as attributes.

    Every keyword accepted by ``__init__`` is stored on the instance under
    the same name. In addition, ``hidden_size`` is exposed as an alias of
    ``n_hiddens``. Any extra keyword arguments are forwarded untouched to
    ``VideoBaseConfiguration.__init__``.

    NOTE(review): the parameter descriptions below are inferred from the
    parameter names only; confirm them against the model code that consumes
    this configuration.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        resolution: int = 128,
        sequence_length: int = 16,
        downsample: Union[Tuple[int, int, int], str] = (4, 4, 4),
        no_pos_embd: bool = True,
        **kwargs,
    ):
        """Store the given hyperparameters on the instance.

        Args:
            embedding_dim: Stored as ``self.embedding_dim`` (presumably the
                codebook embedding width -- TODO confirm).
            n_codes: Stored as ``self.n_codes`` (presumably the codebook
                size -- TODO confirm).
            n_hiddens: Stored as ``self.n_hiddens`` and mirrored to
                ``self.hidden_size``.
            n_res_layers: Stored as ``self.n_res_layers``.
            resolution: Stored as ``self.resolution``.
            sequence_length: Stored as ``self.sequence_length``.
            downsample: Either a value stored as-is, or a comma-separated
                string of integers (e.g. ``"4,4,4"``) that is parsed into a
                tuple of ``int``.
            no_pos_embd: Stored as ``self.no_pos_embd``.
            **kwargs: Passed through to the parent class constructor.

        Raises:
            ValueError: If ``downsample`` is a string containing a component
                that ``int()`` cannot parse.
        """
        # The parent constructor runs first, before any attribute is set here.
        super().__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        self.resolution = resolution
        self.sequence_length = sequence_length

        # Only strings are parsed; any other value is kept exactly as given.
        if not isinstance(downsample, str):
            self.downsample = downsample
        else:
            self.downsample = tuple(int(factor) for factor in downsample.split(","))

        self.no_pos_embd = no_pos_embd

        # Alias so that code reading ``hidden_size`` sees the same value.
        self.hidden_size = n_hiddens