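# Model-loading section of a Hugging Face Space for audio-driven portrait
# animation: audio is mapped to a 3D face mesh, rendered as landmark pose
# frames, and fed together with a reference image into a pose-to-video
# diffusion pipeline (this module layout matches the AniPortrait project).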
import os
from datetime import datetime
from pathlib import Path

import cv2
import ffmpeg
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation as R
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.audio_models.model import Audio2MeshModel
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.audio_util import prepare_audio_feature
from src.utils.draw_util import FaceMeshVisualizer
from src.utils.mp_utils import LMKExtractor
from src.utils.pose_util import project_points
from src.utils.util import save_videos_grid
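# Face landmark extractor (MediaPipe-backed, judging by mp_utils) and a
# visualizer that renders the extracted mesh as pose-guidance images.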
lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')

if config.weight_dtype == "fp16":
    weight_dtype = torch.float16
else:
    weight_dtype = torch.float32

audio_infer_config = OmegaConf.load(config.audio_inference_config)
# Audio-to-mesh model: maps audio features to 3D facial mesh motion.
a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
a2m_model.load_state_dict(
    torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']),
    strict=False,  # tolerate extra/missing keys in the checkpoint
)
a2m_model.cuda().eval()
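# VAE from the base Stable Diffusion checkpoint: encodes the reference image
# into latent space and decodes generated latents back to pixels.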
vae = AutoencoderKL.from_pretrained(
    config.pretrained_vae_path,
).to("cuda", dtype=weight_dtype)
reference_unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_base_model_path,
    subfolder="unet",
).to(dtype=weight_dtype, device="cuda")
inference_config_path = config.inference_config
infer_config = OmegaConf.load(inference_config_path)
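# Denoising UNet (3D): the 2D base UNet inflated with a temporal motion
# module so it can denoise a whole clip of latent frames at once.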
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    config.pretrained_base_model_path,
    config.motion_module_path,
    subfolder="unet",
    unet_additional_kwargs=infer_config.unet_additional_kwargs,
).to(dtype=weight_dtype, device="cuda")
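# Pose guider: a small conv encoder that turns the rendered landmark frames
# into features injected into the denoising UNet's noise latents.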
pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(
    device="cuda", dtype=weight_dtype
)  # use_ca=True: cross-attention enabled inside the pose guider
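# CLIP image encoder: produces embeddings of the reference portrait that
# condition the UNets in place of text-prompt embeddings.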
image_enc = CLIPVisionModelWithProjection.from_pretrained(
    config.image_encoder_path
).to(dtype=weight_dtype, device="cuda")
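# Build the DDIM scheduler from the kwargs in the inference config.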
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
scheduler = DDIMScheduler(**sched_kwargs)
# load pretrained weights
denoising_unet.load_state_dict(
    torch.load(config.denoising_unet_path, map_location="cpu"),
    strict=False,
)
reference_unet.load_state_dict(
    torch.load(config.reference_unet_path, map_location="cpu"),
)
pose_guider.load_state_dict(
    torch.load(config.pose_guider_path, map_location="cpu"),
)
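# Assemble the full pose-to-video pipeline from the components above.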
pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_enc,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
)
pipe = pipe.to("cuda", dtype=weight_dtype)
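# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). A minimal view
# of how these pieces are wired together downstream. The call shapes of
# lmk_extractor(...), prepare_audio_feature(...), a2m_model.infer(...),
# project_points(...), vis.draw_landmarks(...), and pipe(...) are assumptions
# inferred from the imports above; the authoritative signatures live in src/.
# ---------------------------------------------------------------------------
def animate(ref_image_path: str, audio_path: str, width=512, height=512,
            steps=25, cfg=3.5, fps=30, seed=42):
    ref_image_pil = Image.open(ref_image_path).convert("RGB")
    ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)

    # Assumed: the extractor returns landmarks plus a head-pose transform.
    face_result = lmk_extractor(ref_image_np)

    # Assumed: wav2vec features and a sequence length for the a2m model.
    sample = prepare_audio_feature(
        audio_path, wav2vec_model_path=audio_infer_config['a2m_model']['model_path']
    )
    audio_feat = torch.from_numpy(sample['audio_feature']).float().cuda().unsqueeze(0)
    pred_mesh = a2m_model.infer(audio_feat, sample['seq_len'])  # assumed method

    # Assumed: project predicted meshes to 2D and render landmark pose frames.
    verts2d = project_points(pred_mesh.squeeze().detach().cpu().numpy(),
                             face_result['trans_mat'], [height, width])
    pose_images = [vis.draw_landmarks((width, height), v) for v in verts2d]

    generator = torch.manual_seed(seed)
    video = pipe(ref_image_pil, pose_images, width, height,
                 len(pose_images), steps, cfg, generator=generator).videos
    save_videos_grid(video, "output.mp4", n_rows=1, fps=fps)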