import os
import torch
import torch.nn as nn
from einops import rearrange
from diffusers import (
    AutoencoderKL,
    AutoencoderKLTemporalDecoder,
    StableDiffusionPipeline,
)


def default(value, default_value):
    return default_value if value is None else value


def load_stable_model(model_path):
    # Load the full Stable Diffusion pipeline and return only its VAE component.
    vae_model = StableDiffusionPipeline.from_pretrained(model_path)
    vae_model.set_use_memory_efficient_attention_xformers(True)
    return vae_model.vae


def process_image(image: torch.Tensor, resolution=None) -> torch.Tensor:
    """
    Process an image tensor by optionally resizing it and normalizing it to [-1, 1].

    Args:
        image: Input image tensor with pixel values in [0, 255]
        resolution: Target resolution for resizing; skipped when None

    Returns:
        Processed image tensor
    """
    if resolution is not None:
        image = torch.nn.functional.interpolate(
            image.float(), size=resolution, mode="bilinear", align_corners=False
        )
    return image / 127.5 - 1.0
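

# Example (illustrative, not from the original module): a uint8 frame batch in
# [0, 255] becomes a float tensor in [-1, 1], resized to 256x256:
#   frames = torch.randint(0, 256, (8, 3, 512, 512), dtype=torch.uint8)
#   normed = process_image(frames, resolution=256)  # -> (8, 3, 256, 256), float32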


def encode_video_chunk(
    model,
    video,
    target_resolution,
) -> torch.Tensor:
    """
    Encode a chunk of video frames into latent space.

    Args:
        model: VAE wrapper used for encoding
        video: Video tensor of shape (T, H, W, C)
        target_resolution: Target resolution for processing; falls back to the
            smaller of the video's spatial dimensions when None

    Returns:
        Encoded latent tensor of shape (T, C, H, W)
    """
    video = rearrange(video, "t h w c -> c t h w")
    vid_rez = min(video.shape[-1], video.shape[-2])
    to_rez = default(target_resolution, vid_rez)
    video = process_image(video, to_rez)
    encoded = model.encode_video(video.cuda().unsqueeze(0)).squeeze(0)
    return rearrange(encoded, "c t h w -> t c h w")
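

# Example (illustrative, assumes a CUDA GPU and downloaded weights): encoding a
# 16-frame RGB clip with a VaeWrapper instance using the SVD VAE (down factor 8):
#   wrapper = VaeWrapper("video")
#   clip = torch.randint(0, 256, (16, 512, 512, 3), dtype=torch.uint8)  # (T, H, W, C)
#   latents = encode_video_chunk(wrapper, clip, target_resolution=512)  # (16, 4, 64, 64)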


class VaeWrapper(nn.Module):
    def __init__(self, latent_type, max_chunk_decode=16, variant="fp16"):
        super().__init__()
        self.vae_model = self.get_vae(latent_type, variant)
        # self.latent_scale = latent_scale
        self.latent_type = latent_type
        self.max_chunk_decode = max_chunk_decode

    def get_vae(self, latent_type, variant="fp16"):
        if latent_type == "stable":
            vae_model = load_stable_model("stabilityai/stable-diffusion-x4-upscaler")
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 4
        elif latent_type == "video":
            vae_model = AutoencoderKLTemporalDecoder.from_pretrained(
                "stabilityai/stable-video-diffusion-img2vid",
                subfolder="vae",
                torch_dtype=torch.float16 if variant == "fp16" else torch.float32,
                variant="fp16" if variant == "fp16" else None,
            )
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        elif latent_type == "refiner":
            vae_model = AutoencoderKL.from_pretrained(
                "stabilityai/stable-diffusion-xl-refiner-1.0",
                subfolder="vae",
                revision=None,
            )
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        else:
            raise ValueError(f"Unknown latent_type: {latent_type}")
        vae_model.eval()
        vae_model.requires_grad_(False)
        vae_model.cuda()
        vae_model = torch.compile(vae_model)
        return vae_model

    # def accelerate_model(self, example_shape):
    #     self.vae_model = torch.jit.trace(self.vae_model, torch.randn(example_shape).cuda())
    #     self.vae_model = torch.compile(self.vae_model)
    #     self.is_accelerated = True

    def disable_slicing(self):
        self.vae_model.disable_slicing()

    @torch.no_grad()
    def encode_video(self, video):
        """
        Encode a batch of videos (or images) into latent space.

        video: (B, C, T, H, W) for video input, or (B, C, H, W) for image input
        """
        is_video = False
        if len(video.shape) == 5:
            is_video = True
            T = video.shape[2]
            video = rearrange(video, "b c t h w -> (b t) c h w")
        or_dtype = video.dtype
        # if not self.is_accelerated:
        #     self.accelerate_model(video.shape)
        if self.latent_type in ["stable", "refiner", "video"]:
            encoded_video = (
                self.vae_model.encode(video.to(dtype=self.vae_model.dtype))
                .latent_dist.sample()
                .to(dtype=or_dtype)
                * self.vae_model.config.scaling_factor
            )
        elif self.latent_type == "ldm":
            encoded_video = self.vae_model.encode_first_stage(video) * 0.18215
        if not is_video:
            return encoded_video
        return rearrange(encoded_video, "(b t) c h w -> b c t h w", t=T)
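
    # Example (illustrative): with the "video" VAE (down factor 8, 4 latent
    # channels), a (2, 3, 8, 256, 256) pixel batch yields latents of shape
    # (2, 4, 8, 32, 32).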

    @torch.no_grad()
    def decode_video(self, encoded_video):
        """
        Decode latents back to pixel space in chunks of at most `max_chunk_decode` frames.

        encoded_video: (B, C, T, H, W) for video latents, or (B, C, H, W) for image latents
        """
        is_video = False
        B, T = encoded_video.shape[0], 1
        if len(encoded_video.shape) == 5:
            is_video = True
            T = encoded_video.shape[2]
            encoded_video = rearrange(encoded_video, "b c t h w -> (b t) c h w")
        decoded_full = []
        or_dtype = encoded_video.dtype
        for i in range(0, T * B, self.max_chunk_decode):  # Slow but no memory issues
            if self.latent_type in ["stable", "refiner"]:
                decoded_full.append(
                    self.vae_model.decode(
                        (1 / self.vae_model.config.scaling_factor)
                        * encoded_video[i : i + self.max_chunk_decode]
                    ).sample
                )
            elif self.latent_type == "video":
                chunk = encoded_video[i : i + self.max_chunk_decode].to(
                    dtype=self.vae_model.dtype
                )
                num_frames_in = chunk.shape[0]
                decode_kwargs = {"num_frames": num_frames_in}
                decoded_full.append(
                    self.vae_model.decode(
                        1 / self.vae_model.config.scaling_factor * chunk,
                        **decode_kwargs,
                    ).sample.to(or_dtype)
                )
            elif self.latent_type == "ldm":
                decoded_full.append(
                    self.vae_model.decode_first_stage(
                        1 / 0.18215 * encoded_video[i : i + self.max_chunk_decode]
                    )
                )
        decoded_video = torch.cat(decoded_full, dim=0)
        if not is_video:
            return decoded_video.clamp(-1.0, 1.0)
        return rearrange(decoded_video, "(b t) c h w -> b c t h w", t=T).clamp(
            -1.0, 1.0
        )
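

# Minimal usage sketch: assumes a CUDA GPU, the xformers package, and network
# access to download the pretrained weights. Shapes below apply to the "video"
# VAE (down factor 8, 4 latent channels).
if __name__ == "__main__":
    vae = VaeWrapper("video")

    # Random pixel batch in [-1, 1]: 1 video, 3 channels, 8 frames, 256x256.
    video = torch.randn(1, 3, 8, 256, 256).cuda()

    latents = vae.encode_video(video)  # (1, 4, 8, 32, 32)
    recon = vae.decode_video(latents)  # (1, 3, 8, 256, 256), clamped to [-1, 1]
    print(latents.shape, recon.shape)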