from dataclasses import dataclass
import json
from typing import Optional, Tuple, Union
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint  # used by DecoderCausal3D.forward; import explicitly rather than rely on `import torch`

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d

import logging

logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every
# program importing this module; kept as-is for backward compatibility.
logging.basicConfig(level=logging.INFO)

SCALING_FACTOR = 0.476986
VAE_VER = "884-16c-hy"

# Default checkpoint paths keyed by VAE type, consulted by `load_vae` when no
# explicit `vae_path` is given. The original code referenced this name without
# defining it (an unconditional NameError); it is empty by default and callers
# may register entries, e.g. ``VAE_PATH["884-16c-hy"] = "/path/to/vae.pt"``.
VAE_PATH = {}


def load_vae(
    vae_type: str = "884-16c-hy",
    vae_dtype: Optional[Union[str, torch.dtype]] = None,
    sample_size: Optional[tuple] = None,
    vae_path: Optional[str] = None,
    device=None,
):
    """Load the Hunyuan causal 3D VAE from a checkpoint file.

    The architecture is always built from the fixed Hunyuan config embedded
    below; ``vae_type`` only selects the default checkpoint path (through
    ``VAE_PATH``) and appears in the log message.

    Args:
        vae_type (str): Key into ``VAE_PATH`` used when ``vae_path`` is None.
            Defaults to "884-16c-hy".
        vae_dtype (str or torch.dtype, optional): Passed to ``vae.to(...)`` after
            the weights are loaded. Defaults to None (dtype left unchanged).
        sample_size (tuple, optional): Forwarded to
            ``AutoencoderKLCausal3D.from_config`` as ``sample_size`` (the tiling
            size). Defaults to None (use the config's value).
        vae_path (str, optional): Checkpoint file read with ``torch.load``.
            Defaults to None (look up ``VAE_PATH[vae_type]``).
        device (optional): Device the model is moved to. Defaults to None
            (not moved).

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio, time_compression_ratio)``,
        where ``vae`` is in eval mode with gradients disabled and the two ratios
        are read from ``vae.config``.

    Raises:
        ValueError: If ``vae_path`` is None and ``VAE_PATH`` has no entry for
            ``vae_type``.
    """
    if vae_path is None:
        if vae_type not in VAE_PATH:
            raise ValueError(
                f"No vae_path given and no default checkpoint registered in VAE_PATH for vae_type={vae_type!r}."
            )
        vae_path = VAE_PATH[vae_type]

    logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    # use fixed config for Hunyuan's VAE
    CONFIG_JSON = """{
        "_class_name": "AutoencoderKLCausal3D",
        "_diffusers_version": "0.4.2",
        "act_fn": "silu",
        "block_out_channels": [128, 256, 512, 512],
        "down_block_types": [
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D"
        ],
        "in_channels": 3,
        "latent_channels": 16,
        "layers_per_block": 2,
        "norm_num_groups": 32,
        "out_channels": 3,
        "sample_size": 256,
        "sample_tsize": 64,
        "up_block_types": [
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D"
        ],
        "scaling_factor": 0.476986,
        "time_compression_ratio": 4,
        "mid_block_add_attention": true
    }"""
    config = json.loads(CONFIG_JSON)

    # import here to avoid circular import
    from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D

    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    # weights_only=True restricts unpickling to tensors and plain containers.
    ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # Checkpoints saved from a wrapper module prefix the VAE weights with "vae.".
    # Keep only those entries and strip the leading prefix only; str.replace
    # would also remove "vae." occurring later in a key.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt.keys()):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_dtype is not None:
        # NOTE(review): a str vae_dtype is passed straight to nn.Module.to(),
        # which does not resolve dtype names like "bfloat16" — confirm callers.
        vae = vae.to(vae_dtype)

    vae.requires_grad_(False)

    logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor`):
            The decoded output sample from the last layer of the model.
            Presumably a 5-D video tensor (batch, channels, frames, height, width)
            for this 3D VAE — TODO confirm against the decode path.
    """

    sample: torch.FloatTensor


class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Only ``time_compression_ratio == 4`` is supported. Down block ``i`` halves the
    spatial size when ``i < log2(spatial_compression_ratio)``. It halves the
    temporal size when it is one of the ``log2(time_compression_ratio)`` blocks
    immediately preceding the final block.

    Args:
        in_channels: Channels of the 5-D input tensor.
        out_channels: Latent channels; the output has ``2 * out_channels``
            channels when ``double_z`` is True (presumably mean and log-variance).
        down_block_types: Block type names passed to ``get_down_block3d``.
        block_out_channels: Output channels of each down block.
        layers_per_block: Layers in each down block.
        norm_num_groups: Groups for the group norms.
        act_fn: Activation name passed to the resnet blocks.
        double_z: Whether to double the output channels.
        mid_block_add_attention: Whether the mid block includes attention.
        time_compression_ratio: Total temporal downsampling (must be 4).
        spatial_compression_ratio: Total spatial downsampling.

    Raises:
        ValueError: If ``time_compression_ratio`` is not 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # Loop-invariant configuration, validated once up front.
        if time_compression_ratio != 4:
            raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
        num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_downsample_layers = int(np.log2(time_compression_ratio))
        num_blocks = len(block_out_channels)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_blocks - 1

            add_spatial_downsample = i < num_spatial_downsample_layers
            add_time_downsample = i >= num_blocks - 1 - num_time_downsample_layers and not is_final_block

            # Stride layout is (T, H, W).
            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = downsample_stride_T + downsample_stride_HW

            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=add_spatial_downsample or add_time_downsample,
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        # NOTE(review): the output activation is always SiLU, regardless of act_fn.
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class. Expects a 5-D tensor."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    This mirrors `EncoderCausal3D`. Up block ``i`` (counting from the deepest
    level) doubles the spatial size when ``i < log2(spatial_compression_ratio)``.
    It doubles the temporal size when it is one of the
    ``log2(time_compression_ratio)`` blocks immediately preceding the final block.

    Args:
        in_channels: Latent channels of the 5-D input tensor.
        out_channels: Channels of the decoded output.
        up_block_types: Block type names passed to ``get_up_block3d``.
        block_out_channels: Per-level channels, listed from shallow to deep.
        layers_per_block: Each up block gets ``layers_per_block + 1`` layers.
        norm_num_groups: Groups for the group norms.
        act_fn: Activation name passed to the resnet blocks.
        norm_type: "group" or "spatial"; "spatial" conditions the norms on
            ``latent_embeds`` in ``forward``.
        mid_block_add_attention: Whether the mid block includes attention.
        time_compression_ratio: Total temporal upsampling (must be 4).
        spatial_compression_ratio: Total spatial upsampling.

    Raises:
        ValueError: If ``time_compression_ratio`` is not 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # With spatial norm the blocks receive an embedding whose width is in_channels.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # Loop-invariant configuration, validated once up front.
        if time_compression_ratio != 4:
            raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
        num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_upsample_layers = int(np.log2(time_compression_ratio))
        num_blocks = len(block_out_channels)

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == num_blocks - 1

            add_spatial_upsample = i < num_spatial_upsample_layers
            add_time_upsample = i >= num_blocks - 1 - num_time_upsample_layers and not is_final_block

            # Scale-factor layout is (T, H, W).
            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = upsample_scale_factor_T + upsample_scale_factor_HW

            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=add_spatial_upsample or add_time_upsample,
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class.

        Args:
            sample: 5-D latent tensor.
            latent_embeds: Optional conditioning passed to every block and, when
                given, to the output norm (used with ``norm_type="spatial"``).
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        # The mid-block output is cast to the up blocks' parameter dtype.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:
            # Non-reentrant checkpointing is only available from torch 1.11.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, latent_embeds, **ckpt_kwargs
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), sample, latent_embeds, **ckpt_kwargs
                )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class DiagonalGaussianDistribution:
    """Diagonal Gaussian whose mean and log-variance are concatenated in one tensor.

    ``parameters`` is split into two halves along the channel axis: axis 2 for
    3-D inputs and axis 1 for 4-D/5-D inputs. Log-variance is clamped to
    [-30, 20]. When ``deterministic`` is True, std/var are zero, so ``sample()``
    returns the mean and ``kl()``/``nll()`` return zero.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim in (4, 5):
            dim = 1  # (B, C, T, H ,W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Return ``[0.0]`` (default dtype) on the parameters' device.

        ``torch.Tensor([0.0])`` always lives on the CPU, and a 1-element CPU tensor
        cannot be combined with CUDA tensors.
        """
        return torch.zeros(1, device=self.parameters.device)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw ``mean + std * eps`` with ``eps`` ~ N(0, I)."""
        # make sure sample is on the same device as the parameters and has same dtype
        eps = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * eps

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to ``other``, or to N(0, I) when ``other`` is None.

        Summed over every axis except the first (batch) axis.
        """
        if self.deterministic:
            return self._zero()
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``.

        NOTE(review): the default ``dims`` does not reduce axis 4 of 5-D inputs;
        pass ``dims`` explicitly for video latents.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the distribution's mode, which is its mean."""
        return self.mean