import torch
import torch.nn as nn
from einops import rearrange
from diffusers import (
    AutoencoderKL,
    AutoencoderKLTemporalDecoder,
    StableDiffusionPipeline,
)

def default(value, default_value):
    return default_value if value is None else value

def load_stable_model(model_path):
    # Load a full Stable Diffusion pipeline and keep only its VAE.
    vae_model = StableDiffusionPipeline.from_pretrained(model_path)
    vae_model.set_use_memory_efficient_attention_xformers(True)
    return vae_model.vae
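
# Usage sketch (illustrative, not part of the original module; assumes an
# SD-style checkpoint whose pipeline exposes a `.vae` attribute):
def _demo_load_stable_model():
    vae = load_stable_model("runwayml/stable-diffusion-v1-5")
    print(type(vae).__name__)  # AutoencoderKL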

def process_image(image: torch.Tensor, resolution=None) -> torch.Tensor:
    """
    Process an image tensor by resizing and normalizing.

    Args:
        image: Input image tensor with pixel values in [0, 255]
        resolution: Target resolution for resizing

    Returns:
        Processed image tensor, normalized to [-1, 1]
    """
    if resolution is not None:
        image = torch.nn.functional.interpolate(
            image.float(), size=resolution, mode="bilinear", align_corners=False
        )
    # Map [0, 255] pixel values to the [-1, 1] range the VAEs expect.
    return image / 127.5 - 1.0
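
# Minimal sketch of the input convention for process_image: pixel values in
# [0, 255] (e.g. uint8 frames cast to float) in NCHW layout. The demo function
# is illustrative, not part of the original module.
def _demo_process_image():
    frame = torch.randint(0, 256, (1, 3, 360, 640)).float()  # one fake RGB frame
    out = process_image(frame, resolution=(256, 256))
    assert out.shape == (1, 3, 256, 256)
    assert out.min() >= -1.0 and out.max() <= 1.0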

def encode_video_chunk(
    model,
    video,
    target_resolution,
) -> torch.Tensor:
    """
    Encode a chunk of video frames into latent space.

    Args:
        model: VAE model for encoding (a VaeWrapper, see below)
        video: Video tensor of shape (T, H, W, C) with values in [0, 255]
        target_resolution: Target resolution for processing

    Returns:
        Encoded latent tensor of shape (T, C, H, W)
    """
    video = rearrange(video, "t h w c -> c t h w")
    vid_rez = min(video.shape[-1], video.shape[-2])
    to_rez = default(target_resolution, vid_rez)
    video = process_image(video, to_rez)
    encoded = model.encode_video(video.cuda().unsqueeze(0)).squeeze(0)
    return rearrange(encoded, "c t h w -> t c h w")
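
# Usage sketch for encode_video_chunk (assumes a CUDA device and that `model`
# is the VaeWrapper defined below, which provides `.encode_video`). Frames are
# (T, H, W, C) with [0, 255] values, matching the rearrange pattern above.
def _demo_encode_video_chunk():
    model = VaeWrapper("video")  # defined later in this file
    frames = torch.randint(0, 256, (8, 256, 256, 3)).float()  # 8 fake frames
    latents = encode_video_chunk(model, frames, target_resolution=(256, 256))
    print(latents.shape)  # (8, 4, 32, 32) with the SVD VAE's 8x downsampling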

class VaeWrapper(nn.Module):
    def __init__(self, latent_type, max_chunk_decode=16, variant="fp16"):
        super().__init__()
        self.vae_model = self.get_vae(latent_type, variant)
        # self.latent_scale = latent_scale
        self.latent_type = latent_type
        self.max_chunk_decode = max_chunk_decode
    def get_vae(self, latent_type, variant="fp16"):
        if latent_type == "stable":
            vae_model = load_stable_model("stabilityai/stable-diffusion-x4-upscaler")
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 4
        elif latent_type == "video":
            vae_model = AutoencoderKLTemporalDecoder.from_pretrained(
                "stabilityai/stable-video-diffusion-img2vid",
                subfolder="vae",
                torch_dtype=torch.float16 if variant == "fp16" else torch.float32,
                variant="fp16" if variant == "fp16" else None,
            )
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        elif latent_type == "refiner":
            vae_model = AutoencoderKL.from_pretrained(
                "stabilityai/stable-diffusion-xl-refiner-1.0",
                subfolder="vae",
                revision=None,
            )
            vae_model.enable_slicing()
            vae_model.set_use_memory_efficient_attention_xformers(True)
            self.down_factor = 8
        else:
            raise ValueError(f"Unknown latent_type: {latent_type}")
        vae_model.eval()
        vae_model.requires_grad_(False)
        vae_model.cuda()
        vae_model = torch.compile(vae_model)
        return vae_model
    # def accelerate_model(self, example_shape):
    #     self.vae_model = torch.jit.trace(self.vae_model, torch.randn(example_shape).cuda())
    #     self.vae_model = torch.compile(self.vae_model)
    #     self.is_accelerated = True
    def disable_slicing(self):
        self.vae_model.disable_slicing()
    def encode_video(self, video):
        """
        video: (B, C, T, H, W)
        """
        is_video = False
        if len(video.shape) == 5:
            is_video = True
            T = video.shape[2]
            # Fold time into the batch dimension so the image VAE sees 4D input.
            video = rearrange(video, "b c t h w -> (b t) c h w")
        or_dtype = video.dtype
        # if not self.is_accelerated:
        #     self.accelerate_model(video.shape)
        if self.latent_type in ["stable", "refiner", "video"]:
            encoded_video = (
                self.vae_model.encode(video.to(dtype=self.vae_model.dtype))
                .latent_dist.sample()
                .to(dtype=or_dtype)
                * self.vae_model.config.scaling_factor
            )
        elif self.latent_type == "ldm":
            # CompVis/SGM-style first-stage model; 0.18215 is the usual LDM latent scale.
            encoded_video = self.vae_model.encode_first_stage(video) * 0.18215
        if not is_video:
            return encoded_video
        return rearrange(encoded_video, "(b t) c h w -> b c t h w", t=T)
    def decode_video(self, encoded_video):
        """
        encoded_video: (B, C, T, H, W)
        """
        is_video = False
        B, T = encoded_video.shape[0], 1
        if len(encoded_video.shape) == 5:
            is_video = True
            T = encoded_video.shape[2]
            encoded_video = rearrange(encoded_video, "b c t h w -> (b t) c h w")
        decoded_full = []
        or_dtype = encoded_video.dtype
        # Decode in chunks: slower, but avoids out-of-memory errors on long clips.
        for i in range(0, T * B, self.max_chunk_decode):
            if self.latent_type in ["stable", "refiner"]:
                decoded_full.append(
                    self.vae_model.decode(
                        (1 / self.vae_model.config.scaling_factor)
                        * encoded_video[i : i + self.max_chunk_decode]
                    ).sample
                )
            elif self.latent_type == "video":
                chunk = encoded_video[i : i + self.max_chunk_decode].to(
                    dtype=self.vae_model.dtype
                )
                num_frames_in = chunk.shape[0]
                # The temporal decoder needs to know how many frames are in the chunk.
                decode_kwargs = {"num_frames": num_frames_in}
                decoded_full.append(
                    self.vae_model.decode(
                        1 / self.vae_model.config.scaling_factor * chunk,
                        **decode_kwargs,
                    ).sample.to(or_dtype)
                )
            elif self.latent_type == "ldm":
                decoded_full.append(
                    self.vae_model.decode_first_stage(
                        1 / 0.18215 * encoded_video[i : i + self.max_chunk_decode]
                    )
                )
        decoded_video = torch.cat(decoded_full, dim=0)
        if not is_video:
            return decoded_video.clamp(-1.0, 1.0)
        return rearrange(decoded_video, "(b t) c h w -> b c t h w", t=T).clamp(
            -1.0, 1.0
        )
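
# End-to-end roundtrip sketch (assumes a CUDA device, xformers, and enough
# VRAM for the SVD VAE; shapes are illustrative):
if __name__ == "__main__":
    vae = VaeWrapper("video", max_chunk_decode=4)
    clip = torch.randn(1, 3, 8, 128, 128).clamp(-1, 1).cuda()  # (B, C, T, H, W) in [-1, 1]
    latents = vae.encode_video(clip)  # -> (1, 4, 8, 16, 16) after 8x downsampling
    recon = vae.decode_video(latents)  # -> (1, 3, 8, 128, 128), clamped to [-1, 1]
    print(latents.shape, recon.shape)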