File size: 655 Bytes
96d7ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from einops import rearrange

from torch import nn
import torch


def decode_unet_latents_with_vae(vae: nn.Module, latents: torch.Tensor) -> torch.Tensor:
    """Decode UNet latents with a VAE and rescale the result to ``[0, 1]``.

    Five-dimensional latents are treated as ``(b, c, f, h, w)``: the frame
    axis is folded into the batch axis before decoding and unfolded again
    afterwards. Latents of any other rank are handed to the VAE as-is.

    Args:
        vae: Module exposing ``config.scaling_factor`` and
            ``decode(latents, return_dict=False)``; the first element of the
            value returned by ``decode`` is used as the decoded tensor.
        latents: Latents to decode.

    Returns:
        The decoded tensor mapped through ``x / 2 + 0.5``, clamped to
        ``[0, 1]`` and cast to float32. For 5D input the result is laid out
        as ``(b, c, f, h, w)``; otherwise the layout is whatever the VAE
        returns.
    """
    is_video = latents.ndim == 5
    # Captured before folding so the frame axis can be recovered afterwards.
    batch_size = latents.shape[0]
    if is_video:
        latents = rearrange(latents, "b c f h w -> (b f) c h w")

    # Undo the scaling applied to the latents before they reach the decoder.
    latents = 1 / vae.config.scaling_factor * latents
    video = vae.decode(latents, return_dict=False)[0]
    video = (video / 2 + 0.5).clamp(0, 1)

    if is_video:
        # Unfold the frames out of the batch axis of the *decoded* tensor.
        # NOTE(review): assumes the VAE returns channels-first output,
        # i.e. ((b f), c, h, w) -- confirm against the VAE in use.
        video = rearrange(video, "(b f) c h w -> b c f h w", b=batch_size)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return video.float()