# keysync-demo / vae_wrapper.py
import os
import torch
import torch.nn as nn
from einops import rearrange
from diffusers import (
AutoencoderKL,
AutoencoderKLTemporalDecoder,
StableDiffusionPipeline,
)


def default(value, default_value):
    """Return `value` unless it is None, otherwise fall back to `default_value`."""
    return default_value if value is None else value


def load_stable_model(model_path):
    """Load a Stable Diffusion pipeline and return only its VAE component."""
    vae_model = StableDiffusionPipeline.from_pretrained(model_path)
    vae_model.set_use_memory_efficient_attention_xformers(True)
    return vae_model.vae


def process_image(image: torch.Tensor, resolution=None) -> torch.Tensor:
    """
    Resize and normalize an image tensor.

    Args:
        image: Input image tensor with pixel values in [0, 255]
        resolution: Optional target spatial resolution for bilinear resizing

    Returns:
        Processed image tensor normalized to [-1, 1]
    """
    if resolution is not None:
        image = torch.nn.functional.interpolate(
            image.float(), size=resolution, mode="bilinear", align_corners=False
        )
    return image / 127.5 - 1.0
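
# Illustrative example (values assumed for the sketch, not prescribed by the module):
#   x = torch.randint(0, 256, (1, 3, 720, 1280)).float()  # pixels in [0, 255]
#   y = process_image(x, resolution=512)                   # (1, 3, 512, 512), values in [-1, 1]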


def encode_video_chunk(
    model,
    video,
    target_resolution,
) -> torch.Tensor:
    """
    Encode a chunk of video frames into latent space.

    Args:
        model: VAE wrapper exposing `encode_video` (e.g. VaeWrapper)
        video: Video tensor of shape (T, H, W, C) with pixel values in [0, 255]
        target_resolution: Target spatial resolution; falls back to the smaller
            of the input's height and width when None

    Returns:
        Encoded latent tensor of shape (T, C_latent, H_latent, W_latent)
    """
    video = rearrange(video, "t h w c -> c t h w")
    vid_rez = min(video.shape[-1], video.shape[-2])
    to_rez = default(target_resolution, vid_rez)
    video = process_image(video, to_rez)
    encoded = model.encode_video(video.cuda().unsqueeze(0)).squeeze(0)
    return rearrange(encoded, "c t h w -> t c h w")
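
# Illustrative sketch of how encode_video_chunk is meant to be called (assumes a CUDA
# device; VaeWrapper is defined below, and the shapes are examples, not requirements):
#   vae = VaeWrapper(latent_type="video")
#   frames = torch.randint(0, 256, (16, 512, 512, 3)).float()   # (T, H, W, C) in [0, 255]
#   latents = encode_video_chunk(vae, frames, target_resolution=512)
#   # latents: (16, 4, 64, 64) given the SVD VAE's 8x spatial downsampling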


class VaeWrapper(nn.Module):
    """Unified encode/decode wrapper around the "stable", "video" and "refiner" diffusers autoencoders."""

    def __init__(self, latent_type, max_chunk_decode=16, variant="fp16"):
        super().__init__()
        self.vae_model = self.get_vae(latent_type, variant)
        # self.latent_scale = latent_scale
        self.latent_type = latent_type
        self.max_chunk_decode = max_chunk_decode

    def get_vae(self, latent_type, variant="fp16"):
        """Instantiate the requested autoencoder, frozen and compiled, on the GPU."""
if latent_type == "stable":
vae_model = load_stable_model("stabilityai/stable-diffusion-x4-upscaler")
vae_model.enable_slicing()
vae_model.set_use_memory_efficient_attention_xformers(True)
self.down_factor = 4
elif latent_type == "video":
vae_model = AutoencoderKLTemporalDecoder.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid",
subfolder="vae",
torch_dtype=torch.float16 if variant == "fp16" else torch.float32,
variant="fp16" if variant == "fp16" else None,
)
vae_model.set_use_memory_efficient_attention_xformers(True)
self.down_factor = 8
elif latent_type == "refiner":
vae_model = AutoencoderKL.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
subfolder="vae",
revision=None,
)
vae_model.enable_slicing()
vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        else:
            raise NotImplementedError(f"Unknown latent_type: {latent_type}")
vae_model.eval()
vae_model.requires_grad_(False)
vae_model.cuda()
vae_model = torch.compile(vae_model)
return vae_model

    # def accelerate_model(self, example_shape):
    #     self.vae_model = torch.jit.trace(self.vae_model, torch.randn(example_shape).cuda())
    #     self.vae_model = torch.compile(self.vae_model)
    #     self.is_accelerated = True

    def disable_slicing(self):
        self.vae_model.disable_slicing()

    @torch.no_grad()
    def encode_video(self, video):
        """
        Encode pixel-space video into scaled VAE latents.

        video: (B, C, T, H, W) in [-1, 1]; a 4D (B, C, H, W) image batch is also
            accepted and encoded without a temporal axis.
        """
        is_video = False
        if len(video.shape) == 5:
            is_video = True
            T = video.shape[2]
            video = rearrange(video, "b c t h w -> (b t) c h w")
        or_dtype = video.dtype
# if not self.is_accelerated:
# self.accelerate_model(video.shape)
if self.latent_type in ["stable", "refiner", "video"]:
encoded_video = (
self.vae_model.encode(video.to(dtype=self.vae_model.dtype))
.latent_dist.sample()
.to(dtype=or_dtype)
* self.vae_model.config.scaling_factor
)
        elif self.latent_type == "ldm":
            # 0.18215 is the latent scaling factor of the original LDM / SD v1 first-stage VAE
            encoded_video = self.vae_model.encode_first_stage(video) * 0.18215
if not is_video:
return encoded_video
return rearrange(encoded_video, "(b t) c h w -> b c t h w", t=T)

    @torch.no_grad()
    def decode_video(self, encoded_video):
        """
        Decode scaled VAE latents back to pixel space in [-1, 1].

        encoded_video: (B, C, T, H, W); a 4D (B, C, H, W) batch is also accepted.
        Decoding is chunked by `max_chunk_decode` to limit peak memory.
        """
is_video = False
B, T = encoded_video.shape[0], 1
if len(encoded_video.shape) == 5:
is_video = True
T = encoded_video.shape[2]
encoded_video = rearrange(encoded_video, "b c t h w -> (b t) c h w")
decoded_full = []
or_dtype = encoded_video.dtype
for i in range(0, T * B, self.max_chunk_decode): # Slow but no memory issues
if self.latent_type in ["stable", "refiner"]:
decoded_full.append(
self.vae_model.decode(
(1 / self.vae_model.config.scaling_factor)
* encoded_video[i : i + self.max_chunk_decode]
).sample
)
elif self.latent_type == "video":
chunk = encoded_video[i : i + self.max_chunk_decode].to(
dtype=self.vae_model.dtype
)
                # The temporal decoder needs the chunk's frame count to rebuild the time axis
                decode_kwargs = {"num_frames": chunk.shape[0]}
decoded_full.append(
self.vae_model.decode(
1 / self.vae_model.config.scaling_factor * chunk,
**decode_kwargs,
).sample.to(or_dtype)
)
elif self.latent_type == "ldm":
decoded_full.append(
self.vae_model.decode_first_stage(
1 / 0.18215 * encoded_video[i : i + self.max_chunk_decode]
)
)
decoded_video = torch.cat(decoded_full, dim=0)
if not is_video:
return decoded_video.clamp(-1.0, 1.0)
return rearrange(decoded_video, "(b t) c h w -> b c t h w", t=T).clamp(
-1.0, 1.0
)
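

# Minimal round-trip usage sketch: assumes a CUDA device and network access to the
# pretrained weights pulled by latent_type="video" (stable-video-diffusion-img2vid).
if __name__ == "__main__":
    vae = VaeWrapper(latent_type="video")
    clip = torch.randn(1, 3, 8, 256, 256, device="cuda", dtype=torch.float16)  # (B, C, T, H, W) in [-1, 1]
    latents = vae.encode_video(clip)    # -> (1, 4, 8, 32, 32)
    frames = vae.decode_video(latents)  # -> (1, 3, 8, 256, 256), clamped to [-1, 1]
    print(latents.shape, frames.shape)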