|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data."""
|
|
|
import logging |
|
import torch |
|
from torch import nn |
|
from enum import Enum |
|
import math |
|
|
|
from .cosmos_tokenizer.layers3d import ( |
|
EncoderFactorized, |
|
DecoderFactorized, |
|
CausalConv3d, |
|
) |
|
|
|
|
|
class IdentityDistribution(torch.nn.Module):
    """Pass-through "distribution" used for the AE formulation.

    ``forward`` returns its input unchanged, together with a placeholder
    ``(mean, logvar)`` pair of single-element zero tensors. This mirrors the
    ``(sample, (mean, logvar))`` return structure of ``GaussianDistribution``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        """Return ``parameters`` unchanged plus zero placeholder posteriors.

        Args:
            parameters: Tensor produced upstream (e.g. by a quant conv).

        Returns:
            ``(parameters, (zeros, zeros))``. Each placeholder is a
            one-element zero tensor with the same device and dtype as
            ``parameters``. This avoids CPU/GPU or dtype mismatches if a
            caller combines the placeholders with model outputs.
        """
        # Two distinct tensors, as in the original, so neither aliases the other.
        return parameters, (parameters.new_zeros(1), parameters.new_zeros(1))
|
|
|
|
|
class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian posterior used for the VAE formulation.

    ``forward`` splits its input into two halves along dim 1, interpreted as
    ``mean`` and ``logvar``. It clamps ``logvar`` to
    ``[min_logvar, max_logvar]`` and returns a reparameterized sample
    together with ``(mean, logvar)``.
    """

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        # Bounds applied to the predicted log-variance before sampling.
        self.min_logvar, self.max_logvar = min_logvar, max_logvar

    def sample(self, mean, logvar):
        """Draw ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        ``eps`` has the shape, dtype and device of ``mean``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(mean)
        return mean + sigma * noise

    def forward(self, parameters):
        """Split ``parameters`` into mean/log-variance and sample.

        Returns:
            ``(sample, (mean, clamped_logvar))``.
        """
        mean, raw_logvar = parameters.chunk(2, dim=1)
        logvar = raw_logvar.clamp(min=self.min_logvar, max=self.max_logvar)
        drawn = self.sample(mean, logvar)
        return drawn, (mean, logvar)
|
|
|
|
|
class ContinuousFormulation(Enum):
    """Maps a continuous-latent formulation name to its distribution class.

    ``VAE`` uses ``GaussianDistribution``, which samples from a Gaussian
    parameterized by its input. ``AE`` uses ``IdentityDistribution``, which
    passes its input through unchanged. Member values are the classes
    themselves, not instances; instantiate with ``member.value()``.
    """

    VAE = GaussianDistribution

    AE = IdentityDistribution
|
|
|
|
|
class CausalContinuousVideoTokenizer(nn.Module):
    """Causal continuous video tokenizer built from factorized 3D enc/decoders.

    ``encode`` maps an input video to a normalized continuous latent.
    ``decode`` undoes that normalization and runs the decoder. The
    normalization subtracts ``latent_mean``, divides by ``latent_std`` (both
    tiled along the latent time axis) and scales by ``sigma_data``.

    Args:
        z_channels: Latent feature width fed to the decoder. The encoder is
            built with ``z_factor * z_channels`` output channels.
        z_factor: Multiplier applied to the encoder and quant-conv widths.
        latent_channels: Number of channels of the normalized latent.
        **kwargs: Forwarded to ``EncoderFactorized`` and ``DecoderFactorized``.
            ``name`` sets the name used in debug logs. When
            ``temporal_compression`` is 4 (the default if absent), the
            decoder is built with ``channels_mult=[2, 4]``.
    """

    def __init__(
        self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
    ) -> None:
        super().__init__()
        self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
        self.latent_channels = latent_channels
        # Scale applied to the normalized latent in encode (inverted in decode).
        self.sigma_data = 0.5

        self.encoder = EncoderFactorized(
            z_channels=z_factor * z_channels, **kwargs
        )

        # With 4x temporal compression the decoder uses a different
        # channels_mult than the encoder. Build a separate dict rather than
        # mutating kwargs, so the encoder/decoder configs stay independent.
        decoder_kwargs = dict(kwargs)
        if decoder_kwargs.get("temporal_compression", 4) == 4:
            decoder_kwargs["channels_mult"] = [2, 4]
        self.decoder = DecoderFactorized(
            z_channels=z_channels, **decoder_kwargs
        )

        # 1x1x1 causal convs between the encoder/decoder feature width and
        # the latent width.
        self.quant_conv = CausalConv3d(
            z_factor * z_channels,
            z_factor * latent_channels,
            kernel_size=1,
            padding=0,
        )
        self.post_quant_conv = CausalConv3d(
            latent_channels, z_channels, kernel_size=1, padding=0
        )

        # Only the AE (identity) formulation is wired up here. The VAE
        # alternative is GaussianDistribution (see ContinuousFormulation).
        self.distribution = IdentityDistribution()

        # NOTE: counted before latent_mean / latent_std are registered below,
        # so those statistics are not included in this number.
        num_parameters = sum(param.numel() for param in self.parameters())
        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.debug(
            f"z_channels={z_channels}, latent_channels={self.latent_channels}."
        )

        # Latent normalization statistics, laid out channel-major as
        # (latent_channels, latent_temporal_chunk) when viewed in
        # _latent_stats. They are tiled along time for longer latents.
        # Initialized to identity (mean 0, std 1); presumably overwritten
        # from a checkpoint.
        latent_temporal_chunk = 16
        self.latent_mean = nn.Parameter(
            torch.zeros(
                [self.latent_channels * latent_temporal_chunk],
                dtype=torch.float32,
            )
        )
        self.latent_std = nn.Parameter(
            torch.ones(
                [self.latent_channels * latent_temporal_chunk],
                dtype=torch.float32,
            )
        )

    def _latent_stats(self, z):
        """Return ``(mean, std)`` broadcastable against latent ``z``.

        ``z`` is indexed as (batch, channels, time, ...), consistent with the
        (1, C, T, 1, 1) reshape below. Each stored statistic is viewed as
        (C, chunk) and repeated along time until it covers ``z.shape[2]``
        frames, then truncated. It is reshaped to (1, C, T, 1, 1) and cast to
        ``z``'s dtype and device.
        """
        latent_ch = z.shape[1]
        latent_t = z.shape[2]

        def expand(stat):
            stat = stat.view(latent_ch, -1)
            reps = math.ceil(latent_t / stat.shape[-1])
            stat = stat.repeat(1, reps)[:, :latent_t]
            return stat.reshape([1, latent_ch, -1, 1, 1]).to(
                dtype=z.dtype, device=z.device
            )

        return expand(self.latent_mean), expand(self.latent_std)

    def encode(self, x):
        """Encode ``x`` into a normalized latent.

        Returns ``((z - mean) / std) * sigma_data``, where ``z`` is the
        encoder -> quant_conv -> distribution output.
        """
        h = self.encoder(x)
        moments = self.quant_conv(h)
        # The posterior placeholders returned by the distribution are unused.
        z, _ = self.distribution(moments)
        mean, std = self._latent_stats(z)
        return ((z - mean) / std) * self.sigma_data

    def decode(self, z):
        """Invert the latent normalization of ``encode`` and decode ``z``."""
        mean, std = self._latent_stats(z)
        z = z / self.sigma_data
        z = z * std + mean
        z = self.post_quant_conv(z)
        return self.decoder(z)
|
|
|
|