Live2Diff / live2diff /pipeline_stream_animation_depth.py
leoxing1996
add demo
d16b52d
import time
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers import LCMScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
retrieve_latents,
)
from einops import rearrange
from live2diff.image_filter import SimilarImageFilter
from .animatediff.pipeline import AnimationDepthPipeline
# Frame dimension of the initial noise tensor built in `prepare`, and the number
# of leading temporal slots that start out unmasked in the attention bias.
WARMUP_FRAMES = 8
# Total number of temporal slots (columns) in the attention bias and in the
# positional-encoding index table.
WINDOW_SIZE = 16
class StreamAnimateDiffusionDepth:
def __init__(
    self,
    pipe: AnimationDepthPipeline,
    num_inference_steps: int,
    t_index_list: Optional[List[int]] = None,
    strength: Optional[float] = None,
    torch_dtype: torch.dtype = torch.float16,
    width: int = 512,
    height: int = 512,
    do_add_noise: bool = True,
    use_denoising_batch: bool = True,
    frame_buffer_size: int = 1,
    clip_skip: int = 1,
    cfg_type: Literal["none", "full", "self", "initialize"] = "none",
) -> None:
    """Set up the streaming wrapper around ``pipe``.

    Args:
        pipe: Pipeline supplying the device, scheduler config, VAE, UNet,
            text encoder and depth model.
        num_inference_steps: Number of steps given to the LCM scheduler.
        t_index_list: Indices into the scheduler timesteps at which to
            denoise. Only used when ``strength`` is None.
        strength: When given, the step indices are derived from it via
            ``get_timesteps`` and any passed ``t_index_list`` is ignored.
        torch_dtype: dtype used for latents and buffers.
        width: Frame width in pixels.
        height: Frame height in pixels.
        do_add_noise: Whether intermediate predictions are re-noised
            between denoising steps.
        use_denoising_batch: Whether all denoising steps are batched into
            a single UNet call.
        frame_buffer_size: Number of frames handled per call.
        clip_skip: Forwarded to prompt encoding.
        cfg_type: Guidance mode; only ``"none"`` is accepted for now.

    Raises:
        ValueError: If neither ``strength`` nor ``t_index_list`` is given,
            or if the resulting list of denoising steps is empty.
        AssertionError: If ``cfg_type`` is not ``"none"``.
    """
    # Fail fast with a readable message instead of a TypeError on
    # `t_index_list[0]` further down.
    if strength is None and t_index_list is None:
        raise ValueError("Either `strength` or `t_index_list` must be provided.")

    self.device = pipe.device
    self.dtype = torch_dtype
    self.generator = None

    self.height = height
    self.width = width
    self.pipe = pipe

    self.latent_height = int(height // pipe.vae_scale_factor)
    self.latent_width = int(width // pipe.vae_scale_factor)

    self.clip_skip = clip_skip

    self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
    self.scheduler.set_timesteps(num_inference_steps, self.device)
    if strength is not None:
        t_index_list, timesteps = self.get_timesteps(num_inference_steps, strength, self.device)
        if not t_index_list:
            raise ValueError(
                f"strength={strength} with num_inference_steps={num_inference_steps} "
                "selects no denoising steps."
            )
        print(
            f"Generate t_index_list: {t_index_list} via "
            f"num_inference_steps: {num_inference_steps}, strength: {strength}"
        )
        self.timesteps = timesteps
    else:
        if not t_index_list:
            raise ValueError("`t_index_list` must contain at least one index.")
        print(
            f"t_index_list is passed: {t_index_list}. "
            f"Number Inference Steps: {num_inference_steps}, "
            f"equivalents to strength {1 - t_index_list[0] / num_inference_steps}."
        )
        self.timesteps = self.scheduler.timesteps.to(self.device)

    self.frame_bff_size = frame_buffer_size
    self.denoising_steps_num = len(t_index_list)
    self.strength = strength

    assert cfg_type == "none", f'cfg_type must be "none" for now, but got {cfg_type}.'
    self.cfg_type = cfg_type

    if use_denoising_batch:
        self.batch_size = self.denoising_steps_num * frame_buffer_size
        # The "initialize"/"full" sizes are kept for when the assertion
        # above is relaxed; with cfg_type == "none" only the last branch runs.
        if self.cfg_type == "initialize":
            self.trt_unet_batch_size = (self.denoising_steps_num + 1) * self.frame_bff_size
        elif self.cfg_type == "full":
            self.trt_unet_batch_size = 2 * self.denoising_steps_num * self.frame_bff_size
        else:
            self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
    else:
        self.trt_unet_batch_size = self.frame_bff_size
        self.batch_size = frame_buffer_size

    self.t_list = t_index_list

    self.do_add_noise = do_add_noise
    self.use_denoising_batch = use_denoising_batch

    # Similar-frame skipping is off until enable_similar_image_filter() is called.
    self.similar_image_filter = False
    self.similar_filter = SimilarImageFilter()
    self.prev_image_result = None

    self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

    self.text_encoder = pipe.text_encoder
    self.unet = pipe.unet
    self.vae = pipe.vae

    self.depth_detector = pipe.depth_model

    # Timing statistics updated by __call__ (seconds).
    self.inference_time_ema = 0
    self.depth_time_ema = 0
    self.inference_time_list = []
    self.depth_time_list = []

    self.mask_shift = 1

    self.is_tensorrt = False
def prepare_cache(self, height, width, denoising_steps_num):
    """Fetch the pipeline's KV cache list and prepare the warm-up UNet.

    The cache returned by ``self.pipe.prepare_cache`` is stored on
    ``self.kv_cache_list``. ``self.unet_warmup`` must already be set
    (``load_warmup_unet`` assigns it).
    """
    resolution = {"height": height, "width": width}
    fresh_cache = self.pipe.prepare_cache(denoising_steps_num=denoising_steps_num, **resolution)
    self.pipe.prepare_warmup_unet(unet=self.unet_warmup, **resolution)
    # Stored only after both pipeline calls have succeeded.
    self.kv_cache_list = fresh_cache
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler timesteps implied by ``strength``.

    Returns:
        A pair ``(t_index, timesteps)``: ``timesteps`` is the trailing slice
        of ``self.scheduler.timesteps`` moved to ``device``, and ``t_index``
        is ``[0, 1, ..., len(timesteps) - 1]``.
    """
    # Number of steps to actually run, never more than the full schedule.
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_kept = max(num_inference_steps - steps_to_run, 0)
    kept = self.scheduler.timesteps[first_kept:].to(device)
    return list(range(len(kept))), kept
def load_lora(
    self,
    pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name: Optional[Any] = None,
    **kwargs,
) -> None:
    """Forward LoRA loading to ``self.pipe.load_lora_weights``.

    The weights source and adapter name are passed positionally; any extra
    keyword arguments are passed through unchanged.
    """
    load_weights = self.pipe.load_lora_weights
    load_weights(pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs)
def fuse_lora(
    self,
    fuse_unet: bool = True,
    fuse_text_encoder: bool = True,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
) -> None:
    """Forward LoRA fusing to ``self.pipe.fuse_lora`` with the given options."""
    fuse_options = {
        "fuse_unet": fuse_unet,
        "fuse_text_encoder": fuse_text_encoder,
        "lora_scale": lora_scale,
        "safe_fusing": safe_fusing,
    }
    self.pipe.fuse_lora(**fuse_options)
def enable_similar_image_filter(
    self,
    threshold: float = 0.98,
    max_skip_frame: float = 10,
) -> None:
    """Switch on similar-frame skipping and configure the filter.

    Args:
        threshold: Value handed to the filter's ``set_threshold``.
        max_skip_frame: Value handed to the filter's ``set_max_skip_frame``.
    """
    self.similar_image_filter = True
    frame_filter = self.similar_filter
    frame_filter.set_threshold(threshold)
    frame_filter.set_max_skip_frame(max_skip_frame)
def disable_similar_image_filter(self) -> None:
    """Switch off similar-frame skipping; `__call__` then bypasses the filter."""
    self.similar_image_filter = False
@torch.no_grad()
def prepare(
    self,
    warmup_frames: torch.Tensor,
    prompt: str,
    negative_prompt: str = "",
    guidance_scale: float = 1.2,
    delta: float = 1.0,
    generator: Optional[torch.Generator] = None,
    seed: int = 2,
) -> torch.Tensor:
    """
    Forward the warm-up frames and fill the streaming buffers.

    Args:
        warmup_frames: Warm-up images, iterated frame by frame; the original
            note gives the layout as [warmup_size, 3, h, w] in [0, 1].
        prompt: Text prompt to encode.
        negative_prompt: Negative prompt passed to the prompt encoder.
        guidance_scale: Guidance scale; forced to 1.0 when ``cfg_type`` is "none".
        delta: Scale applied to the stock noise in ``unet_step``.
        generator: Random generator to use. When None, a generator on
            ``self.device`` is created and seeded with ``seed``.
        seed: Seed used only when ``generator`` is None.

    Returns:
        The decoded warm-up frames as produced by ``decode_image``.

    Requires ``load_warmup_unet`` and ``prepare_cache`` to have been called,
    since ``self.unet_warmup`` and ``self.kv_cache_list`` are read here.
    """
    if generator is None:
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(seed)
    else:
        self.generator = generator

    # initialize x_t_latent (it can be any random tensor)
    if self.denoising_steps_num > 1:
        self.x_t_latent_buffer = torch.zeros(
            (
                (self.denoising_steps_num - 1) * self.frame_bff_size,
                4,
                1,  # for video
                self.latent_height,
                self.latent_width,
            ),
            dtype=self.dtype,
            device=self.device,
        )
        self.depth_latent_buffer = torch.zeros_like(self.x_t_latent_buffer)
    else:
        self.x_t_latent_buffer = None
        self.depth_latent_buffer = None

    self.attn_bias, self.pe_idx, self.update_idx = self.initialize_attn_bias_pe_and_update_idx()

    if self.cfg_type == "none":
        self.guidance_scale = 1.0
    else:
        self.guidance_scale = guidance_scale
    self.delta = delta

    do_classifier_free_guidance = self.guidance_scale > 1.0

    encoder_output = self.pipe._encode_prompt(
        prompt=prompt,
        device=self.device,
        num_videos_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        clip_skip=self.clip_skip,
    )
    self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    if self.use_denoising_batch and self.cfg_type == "full":
        uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
    elif self.cfg_type == "initialize":
        uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

    if self.guidance_scale > 1.0 and (self.cfg_type == "initialize" or self.cfg_type == "full"):
        self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0)

    # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
    self.sub_timesteps = [self.timesteps[t] for t in self.t_list]

    # Per-step tensors are repeated once per buffered frame in batch mode.
    repeats = self.frame_bff_size if self.use_denoising_batch else 1

    sub_timesteps_tensor = torch.tensor(self.sub_timesteps, dtype=torch.long, device=self.device)
    self.sub_timesteps_tensor = torch.repeat_interleave(sub_timesteps_tensor, repeats=repeats, dim=0)

    # Draw from self.generator so that `seed` (or the caller's generator)
    # actually controls this noise. Sampling happens on the generator's own
    # device because torch.randn requires generator and output device to match.
    self.init_noise = torch.randn(
        (self.batch_size, 4, WARMUP_FRAMES, self.latent_height, self.latent_width),
        generator=self.generator,
        device=self.generator.device,
    ).to(device=self.device, dtype=self.dtype)

    self.stock_noise = torch.zeros_like(self.init_noise)

    num_steps = len(self.t_list)

    c_skip_list = []
    c_out_list = []
    for timestep in self.sub_timesteps:
        c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
        c_skip_list.append(c_skip)
        c_out_list.append(c_out)

    self.c_skip = torch.stack(c_skip_list).view(num_steps, 1, 1, 1, 1).to(dtype=self.dtype, device=self.device)
    self.c_out = torch.stack(c_out_list).view(num_steps, 1, 1, 1, 1).to(dtype=self.dtype, device=self.device)

    alpha_prod_t_sqrt_list = []
    beta_prod_t_sqrt_list = []
    for timestep in self.sub_timesteps:
        alpha_cumprod = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_sqrt_list.append(alpha_cumprod.sqrt())
        beta_prod_t_sqrt_list.append((1 - alpha_cumprod).sqrt())
    alpha_prod_t_sqrt = (
        torch.stack(alpha_prod_t_sqrt_list).view(num_steps, 1, 1, 1, 1).to(dtype=self.dtype, device=self.device)
    )
    beta_prod_t_sqrt = (
        torch.stack(beta_prod_t_sqrt_list).view(num_steps, 1, 1, 1, 1).to(dtype=self.dtype, device=self.device)
    )
    self.alpha_prod_t_sqrt = torch.repeat_interleave(alpha_prod_t_sqrt, repeats=repeats, dim=0)
    self.beta_prod_t_sqrt = torch.repeat_interleave(beta_prod_t_sqrt, repeats=repeats, dim=0)

    # do warmup
    # 1. encode images
    warmup_x_list = []
    for f in warmup_frames:
        x = self.image_processor.preprocess(f, self.height, self.width)
        warmup_x_list.append(x.to(device=self.device, dtype=self.dtype))
    warmup_x = torch.cat(warmup_x_list, dim=0)  # [warmup_size, c, h, w]
    warmup_x_t = self.encode_image(warmup_x)
    x_t_latent = rearrange(warmup_x_t, "f c h w -> c f h w")[None, ...]
    depth_latent = self.encode_depth(warmup_x)
    depth_latent = rearrange(depth_latent, "f c h w -> c f h w")[None, ...]

    # 2. run warmup denoising
    # Use the wrapper's device rather than a hard-coded "cuda" so the
    # warm-up UNet lands on the same device as every other tensor here.
    self.unet_warmup = self.unet_warmup.to(device=self.device, dtype=self.dtype)
    warmup_prompt = self.prompt_embeds[0:1]
    # NOTE(review): idx runs over sub_timesteps_tensor, which is repeated
    # frame_bff_size times in batch mode, while c_skip/c_out have one row
    # per entry of t_list; this looks consistent only for
    # frame_buffer_size == 1 - confirm.
    for idx, t in enumerate(self.sub_timesteps_tensor):
        t = t.view(1).repeat(x_t_latent.shape[0])
        output_t = self.unet_warmup(
            x_t_latent,
            t,
            temporal_attention_mask=None,
            depth_sample=depth_latent,
            encoder_hidden_states=warmup_prompt,
            kv_cache=[cache[idx] for cache in self.kv_cache_list],
            return_dict=True,
        )
        model_pred = output_t["sample"]
        x_0_pred = self.scheduler_step_batch(model_pred, x_t_latent, idx)
        if idx < len(self.sub_timesteps_tensor) - 1:
            # Re-noise the prediction to the level of the next step.
            x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + self.beta_prod_t_sqrt[
                idx + 1
            ] * torch.randn_like(x_0_pred, device=self.device, dtype=self.dtype)

    # The warm-up UNet is only needed here; park it on the CPU afterwards.
    self.unet_warmup = self.unet_warmup.to(device="cpu")
    x_0_pred = rearrange(x_0_pred, "b c f h w -> b f c h w")[0]  # [f, c, h, w]
    denoised_frames = self.decode_image(x_0_pred)

    self.warmup_engine()

    return denoised_frames
def warmup_engine(self):
    """Warmup tensorrt engine.

    Does nothing unless ``self.is_tensorrt`` is set. Otherwise the UNet is
    called ``self.batch_size`` times on the first frame slice of
    ``self.init_noise``; the outputs are discarded.
    """
    if self.is_tensorrt:
        print("Warmup TensorRT engine.")
        # The same single-frame slice serves as both latent and depth input.
        dummy_latent = self.init_noise[:, :, 0:1, ...]
        unet_kwargs = {
            "depth_sample": dummy_latent,
            "encoder_hidden_states": self.prompt_embeds,
            "temporal_attention_mask": self.attn_bias,
            "kv_cache": self.kv_cache_list,
            "pe_idx": self.pe_idx,
            "update_idx": self.update_idx,
            "return_dict": True,
        }
        remaining = self.batch_size
        while remaining > 0:
            self.unet(dummy_latent, self.sub_timesteps_tensor, **unet_kwargs)
            remaining -= 1
        print("Warmup TensorRT engine finished.")
@torch.no_grad()
def update_prompt(self, prompt: str) -> None:
    """Re-encode ``prompt`` and replace ``self.prompt_embeds``.

    The encoder is called with the same keyword set that ``prepare`` uses
    (``num_videos_per_prompt``, ``negative_prompt``, ``clip_skip``), so a
    prompt swapped in mid-stream is encoded the same way as the initial one.
    Classifier-free guidance is disabled, so the empty negative prompt only
    satisfies the encoder's signature.
    """
    encoder_output = self.pipe._encode_prompt(
        prompt=prompt,
        device=self.device,
        num_videos_per_prompt=1,
        do_classifier_free_guidance=False,
        negative_prompt="",
        clip_skip=self.clip_skip,
    )
    # One copy of the embedding per row of the denoising batch.
    self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    t_index: int,
) -> torch.Tensor:
    """Blend ``original_samples`` with ``noise`` using the coefficients at ``t_index``.

    Returns ``alpha_prod_t_sqrt[t_index] * original_samples
    + beta_prod_t_sqrt[t_index] * noise``.
    """
    signal_scale = self.alpha_prod_t_sqrt[t_index]
    noise_scale = self.beta_prod_t_sqrt[t_index]
    return signal_scale * original_samples + noise_scale * noise
def scheduler_step_batch(
    self,
    model_pred_batch: torch.Tensor,
    x_t_latent_batch: torch.Tensor,
    idx: Optional[int] = None,
) -> torch.Tensor:
    """Turn a model prediction into a denoised latent.

    With ``idx`` None the full per-step coefficient tensors are used, so each
    batch row gets its own step's coefficients via broadcasting; otherwise
    only the coefficients at position ``idx`` are used.
    """
    # TODO: use t_list to select beta_prod_t_sqrt
    if idx is None:
        alpha, beta = self.alpha_prod_t_sqrt, self.beta_prod_t_sqrt
        out_scale, skip_scale = self.c_out, self.c_skip
    else:
        alpha, beta = self.alpha_prod_t_sqrt[idx], self.beta_prod_t_sqrt[idx]
        out_scale, skip_scale = self.c_out[idx], self.c_skip[idx]

    x0_estimate = (x_t_latent_batch - beta * model_pred_batch) / alpha
    return out_scale * x0_estimate + skip_scale * x_t_latent_batch
def initialize_attn_bias_pe_and_update_idx(self):
    """Create the initial temporal attention bias, PE index table and update indices.

    Returns:
        attn_bias: (denoising_steps_num, WINDOW_SIZE) tensor of ``self.dtype``,
            0 for visible slots and -inf for masked ones. The first
            WARMUP_FRAMES slots are visible in every row; row 0 additionally
            has slot WARMUP_FRAMES visible.
        pe_idx: (denoising_steps_num, WINDOW_SIZE) tensor, each row
            ``0 .. WINDOW_SIZE - 1``.
        update_idx: (denoising_steps_num,) int64 tensor filled with
            WARMUP_FRAMES, with entry 1 set to WARMUP_FRAMES + 1 when there
            is more than one denoising step.

    All tensors are created on ``self.device``.
    """
    steps = self.denoising_steps_num

    attn_mask = torch.zeros((steps, WINDOW_SIZE), dtype=torch.bool, device=self.device)
    attn_mask[:, :WARMUP_FRAMES] = True
    attn_mask[0, WARMUP_FRAMES] = True
    attn_bias = torch.zeros_like(attn_mask, dtype=self.dtype)
    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))

    # Built directly on self.device (instead of a hard-coded .cuda()) so it
    # always lives with attn_bias and update_idx.
    pe_idx = torch.arange(WINDOW_SIZE, device=self.device).unsqueeze(0).repeat(steps, 1)

    update_idx = torch.full((steps,), WARMUP_FRAMES, dtype=torch.int64, device=self.device)
    # Row 1 only exists with more than one denoising step; indexing it
    # unconditionally raised IndexError in single-step configurations.
    if steps > 1:
        update_idx[1] = WARMUP_FRAMES + 1

    return attn_bias, pe_idx, update_idx
def update_attn_bias(self, attn_bias, pe_idx, update_idx):
    """
    Advance the per-step attention bias, PE indices and update indices by one frame.

    All three tensors are modified in place and also returned.

    attn_bias: (timesteps, prev_len), 0 = visible, -inf = masked,
        init value (illustrative): [[0, 0, 0, -inf], [0, 0, -inf, -inf]]
    pe_idx: (timesteps, prev_len), init value (illustrative): [[0, 1, 2, 3], [0, 1, 2, 3]]
    update_idx: (timesteps, ), init value (illustrative): [2, 1]
    """
    for idx in range(self.denoising_steps_num):
        # update pe_idx and update_idx based on attn_bias from last iteration
        if torch.isinf(attn_bias[idx]).any():
            # some position not filled, do not change pe
            # some position not filled: the next write goes to the first
            # still-masked slot, i.e. the current number of visible slots
            update_idx[idx] = (attn_bias[idx] == 0).sum()
        else:
            # all positions are filled: rotate the PE indices of the slots
            # after the warm-up region by one
            pe_idx[idx, WARMUP_FRAMES:] = pe_idx[idx, WARMUP_FRAMES:].roll(shifts=1, dims=0)
            # all positions are filled: write to the slot holding the largest PE
            update_idx[idx] = pe_idx[idx].argmax()

        # Unmask one more slot for this row, capped at the window size.
        # This must run after update_idx was derived from the old bias above.
        num_unmask = (attn_bias[idx] == 0).sum()
        attn_bias[idx, : min(num_unmask + 1, WINDOW_SIZE)] = 0

    return attn_bias, pe_idx, update_idx
def unet_step(
    self,
    x_t_latent: torch.Tensor,
    depth_latent: torch.Tensor,
    t_list: Union[torch.Tensor, list[int]],
    idx: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the UNet once and convert its prediction into a denoised latent.

    Side effects: ``self.kv_cache_list`` is replaced with the cache returned
    by the UNet, and ``self.stock_noise`` is updated in the "self" and
    "initialize" guidance modes.

    Returns:
        ``(denoised_batch, model_pred)`` where ``model_pred`` is the
        (possibly guidance-combined) noise prediction.
    """
    guided = self.guidance_scale > 1.0
    guided_initialize = guided and self.cfg_type == "initialize"
    guided_full = guided and self.cfg_type == "full"
    uses_stock_noise = self.cfg_type == "self" or self.cfg_type == "initialize"

    # Assemble the UNet input, prepending unconditional rows when guided.
    if guided_initialize:
        unet_input = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
        t_list = torch.concat([t_list[0:1], t_list], dim=0)
    elif guided_full:
        unet_input = torch.concat([x_t_latent, x_t_latent], dim=0)
        t_list = torch.concat([t_list, t_list], dim=0)
    else:
        unet_input = x_t_latent

    output = self.unet(
        unet_input,
        t_list,
        depth_sample=depth_latent,
        encoder_hidden_states=self.prompt_embeds,
        temporal_attention_mask=self.attn_bias,
        kv_cache=self.kv_cache_list,
        pe_idx=self.pe_idx,
        update_idx=self.update_idx,
        return_dict=True,
    )
    model_pred = output["sample"]
    self.kv_cache_list = output["kv_cache"]

    # Split the prediction into its text-conditioned / unconditional parts.
    if guided_initialize:
        noise_pred_text = model_pred[1:]
        # comment this assignment out for self-out cfg
        self.stock_noise = torch.concat([model_pred[0:1], self.stock_noise[1:]], dim=0)
    elif guided_full:
        noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
    else:
        noise_pred_text = model_pred

    if guided and uses_stock_noise:
        noise_pred_uncond = self.stock_noise * self.delta

    if guided and self.cfg_type != "none":
        model_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
    else:
        model_pred = noise_pred_text

    # compute the previous noisy sample x_t -> x_t-1
    denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

    if self.use_denoising_batch and uses_stock_noise:
        scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
        delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
        # Coefficients of the following step, with ones appended for the last row.
        alpha_next = torch.concat(
            [self.alpha_prod_t_sqrt[1:], torch.ones_like(self.alpha_prod_t_sqrt[0:1])],
            dim=0,
        )
        beta_next = torch.concat(
            [self.beta_prod_t_sqrt[1:], torch.ones_like(self.beta_prod_t_sqrt[0:1])],
            dim=0,
        )
        delta_x = (alpha_next * delta_x) / beta_next
        shifted_init_noise = torch.concat([self.init_noise[1:], self.init_noise[0:1]], dim=0)
        self.stock_noise = shifted_init_noise + delta_x

    return denoised_batch, model_pred
def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
    """
    Encode frames with the VAE and noise them to the first denoising step.

    image_tensors: [f, c, h, w]
    """
    vae_input = image_tensors.to(device=self.device, dtype=self.vae.dtype)
    posterior = self.vae.encode(vae_input)
    clean_latent = retrieve_latents(posterior, self.generator)
    clean_latent = clean_latent * self.vae.config.scaling_factor
    fresh_noise = torch.randn(
        clean_latent.shape,
        generator=self.generator,
        dtype=clean_latent.dtype,
        device=clean_latent.device,
    )
    # Index 0 selects the coefficients of the first entry in the step list.
    return self.add_noise(clean_latent, fresh_noise, 0)
def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
    """
    Decode latents with the VAE and limit the result to [-1, 1].

    x_0_pred: [f, c, h, w]
    """
    # Undo the latent scaling applied at encode time before decoding.
    unscaled_latent = x_0_pred_out / self.vae.config.scaling_factor
    decoded = self.vae.decode(unscaled_latent, return_dict=False)[0]
    return decoded.clamp(-1, 1)
def encode_depth(self, image_tensors: torch.Tensor) -> torch.Tensor:
    """
    Estimate depth for the frames and encode the depth maps as VAE latents.

    image_tensor: [f, c, h, w], [-1, 1]

    The frames are resized to 384x384 for the depth detector, the predicted
    depth is min-max normalised over the whole input batch to [-1, 1],
    replicated to three channels, resized back to the input resolution and
    encoded with the VAE.

    Returns:
        The depth latents, scaled by the VAE scaling factor.
    """
    image_tensors = image_tensors.to(
        device=self.device,
        dtype=self.depth_detector.dtype,
    )

    # preprocess
    h, w = image_tensors.shape[2], image_tensors.shape[3]
    images_input = F.interpolate(image_tensors, (384, 384), mode="bilinear", align_corners=False)

    # forward
    depth_map = self.depth_detector(images_input)

    # postprocess
    depth_min = depth_map.min()
    # A constant depth map has zero range; dividing by it would give NaN
    # latents, which would then be written into the persistent KV cache.
    # Clamping to the smallest normal value leaves every non-degenerate
    # input unchanged and maps a flat depth map to all -1.
    depth_range = (depth_map.max() - depth_min).clamp_min(torch.finfo(depth_map.dtype).tiny)
    depth_map_norm = (depth_map - depth_min) / depth_range
    depth_map_norm = depth_map_norm[:, None].repeat(1, 3, 1, 1) * 2 - 1
    depth_map_norm = F.interpolate(depth_map_norm, (h, w), mode="bilinear", align_corners=False)

    # encode
    depth_latent = retrieve_latents(self.vae.encode(depth_map_norm.to(dtype=self.vae.dtype)), self.generator)
    depth_latent = depth_latent * self.vae.config.scaling_factor
    return depth_latent
def predict_x0_batch(self, x_t_latent: torch.Tensor, depth_latent: torch.Tensor) -> torch.Tensor:
    """Advance the denoising pipeline by one input frame and return an x0 prediction.

    In batch mode the new frame is stacked in front of the buffered
    in-flight latents, all rows are denoised in a single UNet call, the last
    row is returned and the remaining rows are stored as the buffer for the
    next call. In non-batch mode the frame is denoised through every step
    sequentially within this call.

    Side effects: updates ``x_t_latent_buffer``, ``depth_latent_buffer``,
    ``stock_noise``, the attention bias / PE index / update index (batch
    mode) and ``init_noise`` (non-batch mode).
    """
    prev_latent_batch = self.x_t_latent_buffer
    prev_depth_latent_batch = self.depth_latent_buffer

    if self.use_denoising_batch:
        t_list = self.sub_timesteps_tensor
        if self.denoising_steps_num > 1:
            # New frame goes first; buffered rows follow.
            x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
            depth_latent = torch.cat((depth_latent, prev_depth_latent_batch), dim=0)
            # Shift the stock noise by one row, inserting init noise at the front.
            self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0)
        x_0_pred_batch, model_pred = self.unet_step(x_t_latent, depth_latent, t_list)
        # The masks/indices are advanced once per UNet call.
        self.attn_bias, self.pe_idx, self.update_idx = self.update_attn_bias(
            self.attn_bias, self.pe_idx, self.update_idx
        )

        if self.denoising_steps_num > 1:
            # The last row is the output of this call.
            x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
            if self.do_add_noise:
                # Alternative kept from the original: re-noise with the fixed init noise.
                # self.x_t_latent_buffer = (
                #     self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                #     + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                # )
                # Re-noise every remaining row with fresh noise, using the
                # coefficients of the row it occupies on the next call.
                self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + self.beta_prod_t_sqrt[
                    1:
                ] * torch.randn_like(x_0_pred_batch[:-1])
            else:
                self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
            # Depth latents move through the buffer together with their frames.
            self.depth_latent_buffer = depth_latent[:-1]
        else:
            x_0_pred_out = x_0_pred_batch
            self.x_t_latent_buffer = None
    else:
        self.init_noise = x_t_latent
        for idx, t in enumerate(self.sub_timesteps_tensor):
            t = t.view(
                1,
            ).repeat(
                self.frame_bff_size,
            )
            x_0_pred, model_pred = self.unet_step(x_t_latent, depth_latent, t, idx)
            if idx < len(self.sub_timesteps_tensor) - 1:
                # Bring the prediction to the noise level of the next step.
                if self.do_add_noise:
                    x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + self.beta_prod_t_sqrt[
                        idx + 1
                    ] * torch.randn_like(x_0_pred, device=self.device, dtype=self.dtype)
                else:
                    x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
        x_0_pred_out = x_0_pred

    return x_0_pred_out
@torch.no_grad()
def __call__(self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> Optional[torch.Tensor]:
    """Process one input frame and return the decoded output frame.

    When the similar-image filter is enabled and rejects the frame, this
    sleeps for the current average inference time and returns the previous
    result (which is None if no frame has been processed yet).

    Timing statistics (seconds) are recorded in ``inference_time_ema``,
    ``depth_time_ema``, ``inference_time_list`` and ``depth_time_list``.
    NOTE(review): the two lists gain one entry per processed frame and are
    never trimmed here.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    x = self.image_processor.preprocess(x, self.height, self.width).to(device=self.device, dtype=self.dtype)
    if self.similar_image_filter:
        x = self.similar_filter(x)
        if x is None:
            time.sleep(self.inference_time_ema)
            return self.prev_image_result

    x_t_latent = self.encode_image(x)

    start_depth = torch.cuda.Event(enable_timing=True)
    end_depth = torch.cuda.Event(enable_timing=True)
    start_depth.record()
    depth_latent = self.encode_depth(x)
    end_depth.record()

    # Add the frame axis expected by the video UNet.
    x_t_latent = x_t_latent.unsqueeze(2)
    depth_latent = depth_latent.unsqueeze(2)
    x_0_pred_out = self.predict_x0_batch(x_t_latent, depth_latent)  # [1, c, 1, h, w]
    x_0_pred_out = rearrange(x_0_pred_out, "b c f h w -> (b f) c h w")
    x_output = self.decode_image(x_0_pred_out).detach().clone()

    self.prev_image_result = x_output
    end.record()

    # A single synchronisation at the end of the frame covers both timers;
    # reading the depth timer here avoids stalling the GPU queue mid-frame.
    torch.cuda.synchronize()
    inference_time = start.elapsed_time(end) / 1000
    depth_time = start_depth.elapsed_time(end_depth) / 1000

    self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
    self.depth_time_ema = 0.9 * self.depth_time_ema + 0.1 * depth_time
    self.inference_time_list.append(inference_time)
    self.depth_time_list.append(depth_time)

    return x_output
def load_warmup_unet(self, config):
    """Build the warm-up UNet from ``config`` and attach it to this object and the pipeline."""
    warm_unet = self.pipe.build_warmup_unet(config)
    # Same object on both holders: this wrapper first, then the pipeline.
    self.unet_warmup = self.pipe.unet_warmup = warm_unet
    print("Load Warmup UNet.")