from typing import Optional

import torch
from einops import rearrange

from .ar_tokenizer_quantizers import FSQuantizer

# Disable the TorchScript tensor-expression fuser before any JIT-compiled
# model is loaded.
torch._C._jit_set_texpr_fuser_enabled(False)


def load_jit_model(jit_filepath: str, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto. Defaults to "cuda".

    Returns:
        The JIT-compiled model, loaded onto `device` and set to eval mode.
    """
    # Disabled again here in case another module re-enabled the fuser.
    torch._C._jit_set_texpr_fuser_enabled(False)

    model = torch.jit.load(jit_filepath)
    return model.eval().to(device)
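

# Usage sketch for load_jit_model. The checkpoint path below is a hypothetical
# placeholder, not a file shipped with this module:
#
#   encoder = load_jit_model("checkpoints/encoder.jit", device="cuda")
#   with torch.no_grad():
#       output = encoder(frames.to(torch.bfloat16))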
|
|
|
|
|
class BaseDiscreteVideoFSQTokenizer(torch.nn.Module):
    """
    A base class for discrete video FSQ tokenizers that handles data type conversions and
    normalization using provided mean and standard deviation values for the latent space
    representation. Derived classes should load pre-trained encoder and decoder components
    into the `encoder` and `decoder` attributes.

    Attributes:
        encoder (Module | Callable): Encoder loaded from storage.
        decoder (Module | Callable): Decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int]): The levels defined in the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """
|
|
|
    def __init__(
        self,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__()
        self.channel = latent_ch
        self.name = name
        dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.dtype = dtype
        self.pixel_chunk_duration = pixel_chunk_duration
        self.latent_chunk_duration = latent_chunk_duration
        self.max_enc_batch_size = max_enc_batch_size
        self.max_dec_batch_size = max_dec_batch_size
        self.levels = levels
        self.compress_ratio = compression_ratio
        self.fsq_quantizer = FSQuantizer(levels)
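        # With the default levels [8, 8, 8, 5, 5, 5], the FSQ implicit codebook
        # has prod(levels) = 8**3 * 5**3 = 64,000 entries, i.e. a 64,000-token
        # vocabulary per spatiotemporal latent position.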
|
|
|
    @property
    def latent_ch(self) -> int:
        """
        Returns the number of latent channels in the tokenizer.
        """
        return self.channel
|
|
|
    @torch.no_grad()
    def encode(
        self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, C, T, H, W = state.shape
        if pixel_chunk_duration is None:
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            # Causal temporal compression: the first frame maps to one latent
            # frame; every compress_ratio[0] frames after it map to one more.
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]

        assert (
            T % pixel_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}"
        # Fold temporal chunks into the batch dimension so each chunk is
        # encoded independently.
        state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)

        # Encode in slices of at most max_enc_batch_size to avoid memory overflow.
        if state.shape[0] > self.max_enc_batch_size:
            quantized_out_list = []
            indices_list = []
            for i in range(0, state.shape[0], self.max_enc_batch_size):
                indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype))
                quantized_out_list.append(quantized_out)
                indices_list.append(indices)
            quantized_out = torch.cat(quantized_out_list, dim=0)
            indices = torch.cat(indices_list, dim=0)
        else:
            indices, quantized_out, _ = self.encoder(state.to(self.dtype))
        assert quantized_out.shape[2] == latent_chunk_duration
        return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange(
            indices, "(b n) t h w -> b (n t) h w", b=B
        )
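
    # Shape sketch for the defaults (input sizes below are hypothetical): a
    # state of shape (1, 3, 50, 512, 512) with pixel_chunk_duration=25 splits
    # into two 25-frame chunks; with compression_ratio=[8, 16, 16] each chunk
    # yields 1 + (25 - 1) // 8 = 4 latent frames, so encode() returns quantized
    # latents of shape (1, 6, 8, 32, 32) and integer indices of shape
    # (1, 8, 32, 32), with the channel count set by latent_ch.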
|
|
|
    @torch.no_grad()
    def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
        B, T, _, _ = indices.shape
        if pixel_chunk_duration is None:
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
        assert (
            T % latent_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}"
        # Fold latent chunks into the batch dimension, mirroring encode().
        indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration)

        # Decode in slices of at most max_dec_batch_size to avoid memory overflow.
        if indices.shape[0] > self.max_dec_batch_size:
            state = []
            for i in range(0, indices.shape[0], self.max_dec_batch_size):
                state.append(self.decoder(indices[i : i + self.max_dec_batch_size]))
            state = torch.cat(state, dim=0)
        else:
            state = self.decoder(indices)

        assert state.shape[2] == pixel_chunk_duration
        return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)
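
    # decode() inverts the sketch above: indices of shape (1, 8, 32, 32) with
    # the default settings reconstruct video of shape (1, C, 50, 512, 512),
    # where C is the decoder's output channel count (3 for RGB).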
|
|
|
    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.encoder.to(self.dtype)
|
|
|
|
|
class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A JIT-compiled discrete video FSQ tokenizer that loads pre-trained encoder
    and decoder components from a remote store, handles data type conversions, and performs
    normalization using provided mean and standard deviation values for the latent space
    representation.

    Attributes:
        encoder (Module): The JIT-compiled encoder loaded from storage.
        decoder (Module): The JIT-compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int]): The levels defined in the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """
|
|
|
    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )

        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)
|
|
|
    def load_encoder(self, enc_fp: str) -> None:
        """
        Load the encoder from the remote store.

        Args:
            enc_fp (str): File path to the encoder's JIT file on the remote store.
        """
        self.encoder = load_jit_model(enc_fp, device="cuda")
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.to(self.dtype)
|
|
|
    def load_decoder(self, dec_fp: str) -> None:
        """
        Load the decoder from the remote store.

        Args:
            dec_fp (str): File path to the decoder's JIT file on the remote store.
        """
        self.decoder = load_jit_model(dec_fp, device="cuda")
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)
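

# Usage sketch for DiscreteVideoFSQJITTokenizer; the checkpoint paths and the
# model name below are hypothetical placeholders:
#
#   tokenizer = DiscreteVideoFSQJITTokenizer(
#       enc_fp="checkpoints/encoder.jit",
#       dec_fp="checkpoints/decoder.jit",
#       name="discrete_video_fsq",
#   )
#   latents, indices = tokenizer.encode(video)  # video: (B, C, T, H, W) on cuda
#   recon = tokenizer.decode(indices)           # back to (B, C, T, H, W)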
|
|
|
|
|
class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A discrete video FSQ tokenizer that loads the weights of a pre-trained JIT-compiled
    encoder into a regular nn.Module, so that the encoder can be optimized with
    torch.compile(), while keeping the JIT-compiled decoder as-is. It handles data type
    conversions and normalization using provided mean and standard deviation values for
    the latent space representation.

    Attributes:
        tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints.
        encoder (Callable): tokenizer_module's encode method.
        decoder (Module): The JIT-compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        tokenizer_module (Module): Tokenizer module whose weights will be loaded.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int]): The levels defined in the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """
|
|
|
    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        tokenizer_module: torch.nn.Module,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )

        self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)
|
|
|
    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
        """
        Load the encoder and decoder from the remote store.

        Args:
            enc_fp (str): File path to the encoder's JIT file on the remote store.
            dec_fp (str): File path to the decoder's JIT file on the remote store.
            tokenizer_module (Module): Tokenizer module that was used to create the JIT checkpoints.
        """
        # The decoder stays JIT-compiled.
        self.decoder = load_jit_model(dec_fp)

        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)

        encoder_sd = load_jit_model(enc_fp).state_dict()

        # Only the encoder path of tokenizer_module is needed.
        del tokenizer_module.post_quant_conv
        del tokenizer_module.decoder

        # Skip constant, non-learned buffers (wavelet filters, FSQ level
        # tables); tokenizer_module registers these itself on construction.
        state_dict = {
            k: v
            for k, v in encoder_sd.items()
            if k
            not in (
                "encoder.patcher3d.wavelets",
                "encoder.patcher3d._arange",
                "encoder.patcher3d.patch_size_buffer",
                "quantizer._levels",
                "quantizer._basis",
                "quantizer.implicit_codebook",
            )
        }

        tokenizer_module.load_state_dict(state_dict)

        tokenizer_module.eval()
        for param in tokenizer_module.parameters():
            param.requires_grad = False
        tokenizer_module.to(self.dtype)

        self.tokenizer_module = tokenizer_module
        self.encoder = self.tokenizer_module.encode
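
        # With the encoder weights in a regular nn.Module, the encode path can
        # now be optimized, e.g. (a sketch; requires PyTorch 2.x):
        #
        #   self.encoder = torch.compile(self.tokenizer_module.encode)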
|
|
|
    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the decoder and the tokenizer module to the model's
        default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.tokenizer_module.to(self.dtype)
|
|