|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
|
from ..constants import VAE_PATH, PRECISION_TO_TYPE
|
|
|
|
def load_vae(vae_type: str="884-16c-hy",
             vae_precision: str=None,
             sample_size: tuple=None,
             vae_path: str=None,
             logger=None,
             device=None,
             ):
    """Load the 3D causal VAE model and its pretrained weights.

    Args:
        vae_type (str): the type of the 3D VAE model; used as the key into
            ``VAE_PATH`` when ``vae_path`` is not given. Defaults to "884-16c-hy".
        vae_precision (str, optional): key into ``PRECISION_TO_TYPE`` selecting
            the dtype the VAE is cast to. Defaults to None (no cast).
        sample_size (tuple, optional): the tiling size; forwarded to
            ``AutoencoderKLCausal3D.from_config`` as ``sample_size`` only when
            truthy. Defaults to None.
        vae_path (str, optional): the path to the VAE; expected to contain
            ``pytorch_model.pt`` and the config read by ``load_config``.
            Defaults to None (resolved via ``VAE_PATH[vae_type]``).
        logger (optional): object with an ``info`` method used for progress
            messages. Defaults to None (no logging).
        device (optional): device to move the VAE to after loading.
            Defaults to None (the VAE is not moved).

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio, time_compression_ratio)``.
        ``vae`` has gradients disabled and is in eval mode. ``vae_path`` is the
        resolved path. The two ratios are read from ``vae.config``.

    Raises:
        FileNotFoundError: if ``<vae_path>/pytorch_model.pt`` does not exist.
    """
    if vae_path is None:
        vae_path = VAE_PATH[vae_type]

    if logger is not None:
        logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    config = AutoencoderKLCausal3D.load_config(vae_path)
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    vae_ckpt = Path(vae_path) / "pytorch_model.pt"
    # Raise explicitly rather than `assert`: asserts are stripped under `python -O`.
    if not vae_ckpt.exists():
        raise FileNotFoundError(f"VAE checkpoint not found: {vae_ckpt}")

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources; consider passing
    # `weights_only=True` if the installed torch version supports it.
    ckpt = torch.load(vae_ckpt, map_location=vae.device)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # If any key carries a leading "vae." prefix (presumably a checkpoint saved
    # from a wrapper model — verify), keep only those keys and strip just the
    # leading prefix. `str.replace` would also remove inner "vae." occurrences.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    # The VAE is used frozen: no parameter gradients.
    vae.requires_grad_(False)

    if logger is not None:
        logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio
|
|
|