import os
import ffmpeg
from datetime import datetime
from pathlib import Path
import numpy as np
import cv2
import spaces
import shutil
import uuid
import torch
from omegaconf import OmegaConf
from PIL import Image
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import interp1d
from torchvision import transforms
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.audio_models.model import Audio2MeshModel
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_util import FaceMeshVisualizer
from src.utils.util import get_fps, read_frames, save_videos_grid
from src.utils.audio_util import prepare_audio_feature
from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix, project_points
from src.utils.crop_face_single import crop_face


# Prompt/animation config shared by model construction and audio2video.
CONFIG_PATH = './configs/prompts/animation_audio.yaml'
# Hard upper bound on the number of generated frames per request.
MAX_FRAMES = 300


class Processer():
    """Load the portrait-animation models once and run audio- or video-driven inference.

    ``audio2video`` animates a reference face from an audio clip; ``video2video``
    re-targets the facial motion of a source video onto the reference face. Both return
    ``(output_mp4_path_or_None, reference_image_pil)``.
    """

    def __init__(self):
        self.create_models()

    def create_models(self):
        """Build all models described by ``CONFIG_PATH`` and move them to CUDA."""
        self.lmk_extractor = LMKExtractor()
        self.vis = FaceMeshVisualizer(forehead_edge=False)

        config = OmegaConf.load(CONFIG_PATH)
        weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

        audio_infer_config = OmegaConf.load(config.audio_inference_config)
        # audio -> 3D landmark model
        self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
        self.a2m_model.load_state_dict(
            torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"),
            strict=False,
        )
        self.a2m_model.cuda().eval()

        self.vae = AutoencoderKL.from_pretrained(
            config.pretrained_vae_path,
        ).to("cuda", dtype=weight_dtype)

        self.reference_unet = UNet2DConditionModel.from_pretrained(
            config.pretrained_base_model_path,
            subfolder="unet",
        ).to(dtype=weight_dtype, device="cuda")

        infer_config = OmegaConf.load(config.inference_config)
        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_base_model_path,
            config.motion_module_path,
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=weight_dtype, device="cuda")

        # NOTE(review): use_ca=True is passed here although an earlier comment said
        # "not use cross attention" — confirm which is intended against PoseGuider.
        self.pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)

        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
            config.image_encoder_path
        ).to(dtype=weight_dtype, device="cuda")

        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        self.scheduler = DDIMScheduler(**sched_kwargs)

        # load pretrained weights
        self.denoising_unet.load_state_dict(
            torch.load(config.denoising_unet_path, map_location="cpu"),
            strict=False,
        )
        self.reference_unet.load_state_dict(
            torch.load(config.reference_unet_path, map_location="cpu"),
        )
        self.pose_guider.load_state_dict(
            torch.load(config.pose_guider_path, map_location="cpu"),
        )

        self.pipe = Pose2VideoPipeline(
            vae=self.vae,
            image_encoder=self.image_enc,
            reference_unet=self.reference_unet,
            denoising_unet=self.denoising_unet,
            pose_guider=self.pose_guider,
            scheduler=self.scheduler,
        )
        self.pipe = self.pipe.to("cuda", dtype=weight_dtype)

    def _prepare_reference(self, ref_img, size):
        """Crop, resize and landmark the reference image.

        Args:
            ref_img: reference image as an RGB numpy array (converted with COLOR_RGB2BGR).
            size: side length the cropped face is resized to.

        Returns:
            ``(image_pil, face_result, ref_pose)``. When no face is found, ``face_result``
            and ``ref_pose`` are None and ``image_pil`` is the image the caller should
            return (the uncropped input if cropping failed, the cropped face otherwise).
        """
        ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
        ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
        if ref_image_np is None:
            return Image.fromarray(ref_img), None, None
        ref_image_np = cv2.resize(ref_image_np, (size, size))
        ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))

        face_result = self.lmk_extractor(ref_image_np)
        if face_result is None:
            return ref_image_pil, None, None
        lmks = face_result['lmks'].astype(np.float32)
        ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
        return ref_image_pil, face_result, ref_pose

    @spaces.GPU
    def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
        """Animate ``ref_img`` driven by ``input_audio``.

        Args:
            input_audio: path to the driving audio (fed to prepare_audio_feature and ffmpeg).
            ref_img: reference image as an RGB numpy array.
            headpose_video: optional video path whose head motion is reused; when None the
                pose template ``config['pose_temp']`` is loaded instead.
            size: output width and height in pixels.
            steps: diffusion sampling steps.
            length: maximum number of frames; 0 means "all". Always capped at MAX_FRAMES.
            seed: RNG seed passed to ``torch.manual_seed``.

        Returns:
            ``(mp4_path, ref_image_pil)``, or ``(None, image)`` when no face is found.
        """
        fps = 30
        cfg = 3.5

        config = OmegaConf.load(CONFIG_PATH)
        audio_infer_config = OmegaConf.load(config.audio_inference_config)
        generator = torch.manual_seed(seed)
        width, height = size, size

        ref_image_pil, face_result, ref_pose = self._prepare_reference(ref_img, size)
        if face_result is None:
            return None, ref_image_pil

        sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
        audio_feature = torch.from_numpy(sample['audio_feature']).float().cuda().unsqueeze(0)
        seq_len = sample['seq_len']

        # Predicted per-frame 3D landmarks are added onto the reference face's 3D landmarks.
        pred = self.a2m_model.infer(audio_feature, seq_len)
        pred = pred.squeeze().detach().cpu().numpy()
        pred = pred.reshape(pred.shape[0], -1, 3)
        pred = pred + face_result['lmks3d']

        if headpose_video is not None:
            pose_seq = get_headpose_temp(headpose_video, self.lmk_extractor)
        else:
            pose_seq = np.load(config['pose_temp'])

        # Ping-pong the pose sequence (forward, then backward without repeating the
        # endpoints) and tile it so it is at least seq_len long.
        mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
        repeats = seq_len // len(mirrored_pose_seq) + 1
        cycled_pose_seq = np.tile(mirrored_pose_seq, (repeats, 1))[:seq_len]

        # project 3D mesh to 2D landmark
        projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])

        total = len(projected_vertices)
        max_frames = total if length == 0 or length > total else length
        max_frames = min(max_frames, MAX_FRAMES)

        # Only draw the frames that will actually be rendered.
        pose_list = np.array([
            cv2.resize(self.vis.draw_landmarks((width, height), verts, normed=False), (width, height))
            for verts in projected_vertices[:max_frames]
        ])
        video_length = len(pose_list)

        video = self.pipe(
            ref_image_pil,
            pose_list,
            ref_pose,
            width,
            height,
            video_length,
            steps,
            cfg,
            generator=generator,
        ).videos

        save_dir, time_str = _make_save_dir(seed, size)
        save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
        final_path = save_path.replace('_noaudio.mp4', '.mp4')
        save_videos_grid(
            video,
            save_path,
            n_rows=1,
            fps=fps,
        )

        stream = ffmpeg.input(save_path)
        audio = ffmpeg.input(input_audio)
        # overwrite_output=True passes -y so ffmpeg never blocks on an overwrite prompt.
        ffmpeg.output(
            stream.video, audio.audio, final_path,
            vcodec='copy', acodec='aac', shortest=None,
        ).run(overwrite_output=True)
        os.remove(save_path)

        return final_path, ref_image_pil

    @spaces.GPU
    def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
        """Re-target the facial motion of ``source_video`` onto ``ref_img``.

        Args:
            ref_img: reference image as an RGB numpy array.
            source_video: path to the driving video (frames read with ``read_frames``).
            size: output width and height in pixels.
            steps: diffusion sampling steps.
            length: maximum number of output frames; 0 means "all". Capped at MAX_FRAMES.
            seed: RNG seed passed to ``torch.manual_seed``.

        Returns:
            ``(mp4_path, ref_image_pil)``, or ``(None, image)`` when no face is found in the
            reference image or in the first source frame. The source audio is muxed in when
            it can be extracted; otherwise the silent video is returned.
        """
        cfg = 3.5
        generator = torch.manual_seed(seed)
        width, height = size, size

        ref_image_pil, face_result, ref_pose = self._prepare_reference(ref_img, size)
        if face_result is None:
            return None, ref_image_pil

        source_images = read_frames(source_video)
        src_fps = get_fps(source_video)

        # 60 fps sources are subsampled to 30 fps by taking every other frame.
        step = 1
        if src_fps == 60:
            src_fps = 30
            step = 2

        if length == 0 or length * step > len(source_images):
            max_frames = len(source_images)
        else:
            max_frames = length * step
        max_frames = min(max_frames, MAX_FRAMES * step)

        pose_trans_list = []
        verts_list = []
        bs_list = []
        frame_height = frame_width = None
        for src_image_pil in source_images[:max_frames:step]:
            src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
            frame_height, frame_width, _ = src_img_np.shape
            src_img_result = self.lmk_extractor(src_img_np)
            if src_img_result is None:
                # Stop at the first frame without a detectable face; output is truncated here.
                break
            pose_trans_list.append(src_img_result['trans_mat'])
            verts_list.append(src_img_result['lmks3d'])
            bs_list.append(src_img_result['bs'])

        if not pose_trans_list:
            # Empty video, or no face in the very first frame: nothing to drive with.
            return None, ref_image_pil

        trans_mat_arr = np.array(pose_trans_list)
        verts_arr = np.array(verts_list)
        bs_arr = np.array(bs_list)
        # Frame whose 'bs' values sum lowest is used as the expression baseline
        # (presumably the most neutral face — confirm against LMKExtractor).
        min_bs_idx = np.argmin(bs_arr.sum(1))

        # compute delta pose
        pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
        for i, trans_mat in enumerate(trans_mat_arr):
            euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat)  # real pose of source
            pose_arr[i, :3] = euler_angles
            pose_arr[i, 3:6] = translation_vector

        init_tran_vec = face_result['trans_mat'][:3, 3]  # init translation of tgt
        # (relative translation of source) + (init translation of tgt)
        pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec

        pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
        pose_mat_smooth = np.array([
            euler_and_translation_to_matrix(pose[:3], pose[3:6]) for pose in pose_arr_smooth
        ])

        # face retarget
        verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
        # project 3D mesh to 2D landmark
        projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])

        pose_list = np.array([
            cv2.resize(self.vis.draw_landmarks((frame_width, frame_height), verts, normed=False), (width, height))
            for verts in projected_vertices
        ])
        video_length = len(pose_list)

        video = self.pipe(
            ref_image_pil,
            pose_list,
            ref_pose,
            width,
            height,
            video_length,
            steps,
            cfg,
            generator=generator,
        ).videos

        save_dir, time_str = _make_save_dir(seed, size)
        save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
        final_path = save_path.replace('_noaudio.mp4', '.mp4')
        save_videos_grid(
            video,
            save_path,
            n_rows=1,
            fps=src_fps,
        )

        audio_output = f'{save_dir}/audio_from_video.aac'
        try:
            # extract audio
            ffmpeg.input(source_video).output(audio_output, acodec='copy').run(overwrite_output=True)
            # merge audio and video
            stream = ffmpeg.input(save_path)
            audio = ffmpeg.input(audio_output)
            ffmpeg.output(
                stream.video, audio.audio, final_path,
                vcodec='copy', acodec='aac', shortest=None,
            ).run(overwrite_output=True)
        except Exception:
            # Best effort: e.g. the source has no (copyable) audio track. Keep the silent video.
            shutil.move(save_path, final_path)
        else:
            os.remove(save_path)
        finally:
            # Never leave the intermediate audio file behind.
            if os.path.exists(audio_output):
                os.remove(audio_output)

        return final_path, ref_image_pil


def _make_save_dir(seed, size):
    """Create a fresh per-run output directory and return ``(save_dir, time_str)``.

    Date and time come from a single ``datetime.now()`` call so they cannot straddle
    midnight, and a short random suffix keeps concurrent runs with identical seed/size
    from sharing (and clobbering) a directory.
    """
    now = datetime.now()
    date_str = now.strftime("%Y%m%d")
    time_str = now.strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}--{uuid.uuid4().hex[:8]}"
    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)
    return save_dir, time_str


# NOTE: this definition shadows the matrix_to_euler_and_translation imported from
# src.utils.pose_util; the methods above resolve the name at call time and use this one.
def matrix_to_euler_and_translation(matrix):
    """Split a 4x4 transform into ('xyz', degrees) Euler angles and its translation column."""
    rotation_matrix = matrix[:3, :3]
    translation_vector = matrix[:3, 3]
    rotation = R.from_matrix(rotation_matrix)
    euler_angles = rotation.as_euler('xyz', degrees=True)
    return euler_angles, translation_vector


def smooth_pose_seq(pose_seq, window_size=5):
    """Centered moving average along axis 0; the window shrinks at both ends."""
    smoothed_pose_seq = np.zeros_like(pose_seq)
    half = window_size // 2
    for i in range(len(pose_seq)):
        start = max(0, i - half)
        end = min(len(pose_seq), i + half + 1)
        smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
    return smoothed_pose_seq


def get_headpose_temp(input_video, lmk_extractor):
    """Extract a smoothed 30 fps head-pose sequence from ``input_video``.

    Each row is ``[rx, ry, rz, tx, ty, tz]``: Euler angles ('xyz', degrees) and translation
    of each frame's transform relative to the first frame with a detected face.

    Frames without a detected face repeat the previous pose (leading ones are dropped).

    Raises:
        ValueError: if no face is detected in any frame.
    """
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    trans_mat_list = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            result = lmk_extractor(frame)
            if result is None:
                if trans_mat_list:
                    trans_mat_list.append(trans_mat_list[-1])
                continue
            trans_mat_list.append(result['trans_mat'].astype(np.float32))
    finally:
        cap.release()

    if not trans_mat_list:
        raise ValueError(f"No face detected in head pose video: {input_video}")

    trans_mat_arr = np.array(trans_mat_list)

    # compute delta pose
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
    for i, trans_mat in enumerate(trans_mat_arr):
        euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_inv_frame_0 @ trans_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    # Interpolate to 30 fps. The time axis is built from the frames actually read:
    # CAP_PROP_FRAME_COUNT is only an estimate and would mismatch pose_arr's length.
    new_fps = 30
    if not np.isfinite(fps) or fps <= 0:
        fps = new_fps  # unknown frame rate: treat the sequence as already 30 fps
    num_frames = len(pose_arr)
    if num_frames < 2:
        pose_arr_interp = pose_arr  # interp1d needs at least two samples
    else:
        duration = num_frames / fps
        old_time = np.linspace(0, duration, num_frames)
        new_time = np.linspace(0, duration, max(1, int(num_frames * new_fps / fps)))
        pose_arr_interp = interp1d(old_time, pose_arr, axis=0)(new_time)

    return smooth_pose_seq(pose_arr_interp)