import os
import ffmpeg
from datetime import datetime
from pathlib import Path

import numpy as np
import cv2
import torch
import spaces
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import interp1d
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import save_videos_grid
from src.audio_models.model import Audio2MeshModel
from src.utils.audio_util import prepare_audio_feature
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_util import FaceMeshVisualizer
from src.utils.pose_util import project_points


def matrix_to_euler_and_translation(matrix):
    # Split a 4x4 transform into XYZ euler angles (degrees) and a translation vector.
    rotation_matrix = matrix[:3, :3]
    translation_vector = matrix[:3, 3]
    rotation = R.from_matrix(rotation_matrix)
    euler_angles = rotation.as_euler('xyz', degrees=True)
    return euler_angles, translation_vector


def smooth_pose_seq(pose_seq, window_size=5):
    # Moving-average smoothing of a pose sequence over a sliding window.
    smoothed_pose_seq = np.zeros_like(pose_seq)
    for i in range(len(pose_seq)):
        start = max(0, i - window_size // 2)
        end = min(len(pose_seq), i + window_size // 2 + 1)
        smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
    return smoothed_pose_seq


def get_headpose_temp(input_video):
    # Extract a per-frame head-pose template (euler angles + translation, relative
    # to the first frame) from a driving video, resampled to 30 fps and smoothed.
    lmk_extractor = LMKExtractor()
    cap = cv2.VideoCapture(input_video)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    trans_mat_list = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        result = lmk_extractor(frame)
        trans_mat_list.append(result['trans_mat'].astype(np.float32))
    cap.release()

    trans_mat_arr = np.array(trans_mat_list)

    # compute delta pose
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])

    for i in range(pose_arr.shape[0]):
        pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
        euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    # interpolate to 30 fps
    new_fps = 30
    old_time = np.linspace(0, total_frames / fps, total_frames)
    new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps))

    pose_arr_interp = np.zeros((len(new_time), 6))
    for i in range(6):
        interp_func = interp1d(old_time, pose_arr[:, i])
        pose_arr_interp[:, i] = interp_func(new_time)

    pose_arr_smooth = smooth_pose_seq(pose_arr_interp)

    return pose_arr_smooth


@spaces.GPU
def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
    # Generate a talking-head video from an audio clip and a reference image.
    # An optional head-pose driving video controls head motion; otherwise a
    # precomputed pose template from the config is used.
    fps = 30
    cfg = 3.5

    config = OmegaConf.load('./configs/prompts/animation_audio.yaml')

    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    audio_infer_config = OmegaConf.load(config.audio_inference_config)

    # prepare model
    a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
    a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
    a2m_model.cuda().eval()

    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)

    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")

    inference_config_path = config.inference_config
    infer_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")

    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)  # cross-attention enabled

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")

    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)

    generator = torch.manual_seed(seed)

    width, height = size, size

    # load pretrained weights
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)

    lmk_extractor = LMKExtractor()
    vis = FaceMeshVisualizer(forehead_edge=False)

    ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
    # TODO: face detection + cropping
    ref_image_np = cv2.resize(ref_image_np, (size, size))
    ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))

    face_result = lmk_extractor(ref_image_np)
    if face_result is None:
        return None

    lmks = face_result['lmks'].astype(np.float32)
    ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)

    sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
    sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)

    # inference: predict per-frame 3D landmark displacements from audio and add
    # them to the reference face's 3D landmarks
    pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
    pred = pred.squeeze().detach().cpu().numpy()
    pred = pred.reshape(pred.shape[0], -1, 3)
    pred = pred + face_result['lmks3d']

    if headpose_video is not None:
        pose_seq = get_headpose_temp(headpose_video)
    else:
        pose_seq = np.load(config['pose_temp'])
    # mirror the pose sequence and tile it to cover the full audio length
    mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
    cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]

    # project 3D mesh to 2D landmarks
    projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])

    pose_images = []
    for i, verts in enumerate(projected_vertices):
        lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
        pose_images.append(lmk_img)

    pose_list = []
    pose_tensor_list = []
    pose_transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )
    args_L = len(pose_images) if length == 0 or length > len(pose_images) else length
    for pose_image_np in pose_images[:args_L]:
        pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
        pose_tensor_list.append(pose_transform(pose_image_pil))
        pose_image_np = cv2.resize(pose_image_np, (width, height))
        pose_list.append(pose_image_np)
    pose_list = np.array(pose_list)

    video_length = len(pose_tensor_list)

    video = pipe(
        ref_image_pil,
        pose_list,
        ref_pose,
        width,
        height,
        video_length,
        steps,
        cfg,
        generator=generator,
    ).videos

    save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
    save_videos_grid(
        video,
        save_path,
        n_rows=1,
        fps=fps,
    )

    stream = ffmpeg.input(save_path)
    audio = ffmpeg.input(input_audio)
    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
    os.remove(save_path)

    return save_path.replace('_noaudio.mp4', '.mp4')
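

# Minimal usage sketch (an assumption, not part of the original script): call
# audio2video directly with example inputs. The file paths below are
# placeholders; running this requires the checkpoints and configs referenced in
# animation_audio.yaml plus a CUDA GPU. In the Space this function is normally
# driven from a UI rather than invoked like this.
if __name__ == "__main__":
    example_audio = "examples/audio/speech.wav"  # hypothetical audio path
    # audio2video expects an RGB image array; cv2.imread returns BGR, so convert.
    example_ref = cv2.cvtColor(cv2.imread("examples/ref_images/face.png"), cv2.COLOR_BGR2RGB)
    result_path = audio2video(example_audio, example_ref, headpose_video=None,
                              size=512, steps=25, length=150, seed=42)
    print(f"Saved talking-head video to: {result_path}")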