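# NOTE: the lines below are hosting-page metadata captured together with this
# file; they are kept as comments so the module remains valid Python.
# anchorxia's picture
# add musev
# 96d7ad8
# raw
# history blame
# 655 Bytes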
from einops import rearrange
from torch import nn
import torch
def decode_unet_latents_with_vae(vae: nn.Module, latents: torch.Tensor) -> torch.Tensor:
    """Decode UNet latents with a VAE and map the result into ``[0, 1]``.

    Args:
        vae: Autoencoder exposing ``config.scaling_factor`` and
            ``decode(latents, return_dict=False)``; the first item of the
            value returned by ``decode`` is used as the decoded tensor.
        latents: Latents to decode. A 5-D tensor is treated as
            ``(b, c, f, h, w)`` and its frame axis is folded into the batch
            axis before decoding. Any other rank is handed to the VAE as is
            (presumably ``(b, c, h, w)`` -- confirm against callers).

    Returns:
        The decoded tensor after ``x / 2 + 0.5`` and clamping to ``[0, 1]``.
        For 5-D input the frame axis is restored, giving ``(b, c, f, h, w)``;
        otherwise the layout is whatever ``vae.decode`` produced. No dtype
        cast is applied here.
    """
    is_video = latents.ndim == 5
    if is_video:
        batch_size, channels, n_frames, height, width = latents.shape
        # "b c f h w -> (b f) c h w": fold frames into the batch axis so the
        # VAE receives a 4-D tensor.
        latents = latents.permute(0, 2, 1, 3, 4).reshape(
            batch_size * n_frames, channels, height, width
        )

    latents = 1 / vae.config.scaling_factor * latents
    video = vae.decode(latents, return_dict=False)[0]
    video = (video / 2 + 0.5).clamp(0, 1)

    if is_video:
        # "(b f) c h w -> b c f h w": restore the frame axis on the decoded
        # output. The original code rearranged ``latents`` here and discarded
        # the result, so the returned video stayed flattened as (b*f, c, h, w).
        # Assumes the decoded tensor is 4-D channels-first -- TODO confirm for
        # custom VAEs.
        video = video.reshape(batch_size, n_frames, *video.shape[1:]).permute(
            0, 2, 1, 3, 4
        )
    return video