|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import attrs |
|
|
|
from .discrete_video import DiscreteVideoFSQStateDictTokenizer |
|
from .ar_networks import CausalDiscreteVideoTokenizer |
|
from .lazy_config_init import LazyCall as L |
|
from .lazy_config_init import LazyDict |
|
|
|
|
|
def create_discrete_video_fsq_tokenizer_state_dict_config(
    ckpt_path, pixel_chunk_duration=33, compression_ratio=None
) -> LazyDict:
    """Build a lazy config for a state-dict-based discrete video FSQ tokenizer.

    Args:
        ckpt_path (str): Path to the EMA JIT checkpoint (expected to contain
            "ema.jit"); encoder/decoder checkpoint paths are derived from it by
            substring replacement.
        pixel_chunk_duration (int): Number of pixel frames per chunk. Defaults to 33.
        compression_ratio (list[int] | None): Per-axis compression factors
            [temporal, height, width]. Defaults to [8, 16, 16].

    Returns:
        LazyDict: Lazy-instantiation config for DiscreteVideoFSQStateDictTokenizer.
    """
    # NOTE: default handled here instead of a mutable default argument, which
    # would be shared across calls.
    if compression_ratio is None:
        compression_ratio = [8, 16, 16]

    # FSQ quantization levels; shared between the tokenizer module and the
    # wrapping state-dict tokenizer so they cannot drift apart.
    fsq_levels = [8, 8, 8, 5, 5, 5]

    CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
        attn_resolutions=[32],
        channels=128,
        channels_mult=[2, 4, 4],
        dropout=0.0,
        in_channels=3,
        num_res_blocks=2,
        out_channels=3,
        resolution=1024,
        patch_size=4,
        patch_method="haar",
        z_channels=16,
        z_factor=1,
        num_groups=1,
        legacy_mode=False,
        spatial_compression=16,
        temporal_compression=8,
        embedding_dim=6,
        levels=fsq_levels,
        name="CausalDiscreteFactorizedVideoTokenizer",
    )

    return L(DiscreteVideoFSQStateDictTokenizer)(
        # Encoder/decoder JIT paths are siblings of the EMA checkpoint.
        enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
        dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
        tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
        name="discrete_video_fsq",
        latent_ch=6,
        is_bf16=True,
        pixel_chunk_duration=pixel_chunk_duration,
        # Causal tokenizer: first frame maps to one latent, the rest are
        # compressed by the temporal factor (compression_ratio[0]).
        latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
        max_enc_batch_size=8,
        max_dec_batch_size=4,
        levels=fsq_levels,
        compression_ratio=compression_ratio,
    )
|
|
|
|
|
@attrs.define(slots=False)
class TextTokenizerConfig:
    """
    Text tokenizer config

    Args:
        config: Config file to define the text tokenizer class.
        data_key (str): The input key from data_dict that will be passed to the text tokenizer.
        tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
        tokenizer_offset (int): Offset that is added to the tokens.
        vocab_size (int): Vocabulary size of the tokenizer.
    """

    # Lazy config used to instantiate the underlying text tokenizer class.
    config: LazyDict
    # Key in data_dict whose value is passed to the text tokenizer.
    data_key: str = ""
    # When True, tokenization happens online inside the pipeline (off by default).
    tokenize_here: bool = False
    # Offset added to every produced token id.
    tokenizer_offset: int = 0
    # Vocabulary size of the tokenizer.
    vocab_size: int = 0
|
|
|
|
|
@attrs.define(slots=False)
class VideoTokenizerConfig:
    """
    Video tokenizer config

    Args:
        config: Config file to define the video tokenizer class.
        data_key (str): The input key from data_dict that will be passed to the video tokenizer.
        tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
        tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
            add an offset to make sure that video tokens and text tokens don't overlap.
        vocab_size (int): Vocabulary size of the tokenizer.
        max_seq_len (int): Maximum token length for an input video.
    """

    # Lazy config used to instantiate the underlying video tokenizer class.
    config: LazyDict
    # Key in data_dict whose value is passed to the video tokenizer.
    data_key: str = ""
    # When True, tokenization happens online inside the pipeline (on by default,
    # unlike TextTokenizerConfig).
    tokenize_here: bool = True
    # Offset added to every produced token id so video tokens don't collide with
    # text tokens in joint text-video vocabularies.
    tokenizer_offset: int = 0
    # Vocabulary size of the tokenizer.
    vocab_size: int = 0
    # Maximum token length for an input video; -1 means no explicit limit is set here.
    max_seq_len: int = -1
|
|
|
|
|
@attrs.define(slots=False)
class TokenizerConfig:
    """
    Joint tokenizer config

    Args:
        text_tokenizer (TextTokenizerConfig): Text tokenizer config file
        video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
        seq_len (int): Final token sequence length
        training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
        add_special_tokens (bool): Whether to add special tokens to the output tokens
        pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
    """

    # Optional per-modality tokenizer configs; None disables that modality.
    text_tokenizer: Optional[TextTokenizerConfig] = None
    video_tokenizer: Optional[VideoTokenizerConfig] = None
    # Final token sequence length produced by the joint tokenizer.
    seq_len: int = 4096
    # Training type; annotated Optional because the default is None (the original
    # `str = None` annotation was inconsistent with its default).
    training_type: Optional[str] = None
    # Whether special tokens are appended to the output tokens.
    add_special_tokens: bool = True
    # Pad the token sequence length to the nearest multiple of this number.
    pad_to_multiple_of: Optional[int] = 64
|
|