import os
import ffmpeg
from datetime import datetime
from pathlib import Path

import numpy as np
import cv2
# import torch
# import spaces
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import interp1d
# from diffusers import AutoencoderKL, DDIMScheduler
# from einops import repeat
# from omegaconf import OmegaConf
# from PIL import Image
# from torchvision import transforms
# from transformers import CLIPVisionModelWithProjection

# from src.models.pose_guider import PoseGuider
# from src.models.unet_2d_condition import UNet2DConditionModel
# from src.models.unet_3d import UNet3DConditionModel
# from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
# from src.utils.util import save_videos_grid
# from src.audio_models.model import Audio2MeshModel
# from src.utils.audio_util import prepare_audio_feature
from src.utils.mp_utils import LMKExtractor
# from src.utils.draw_util import FaceMeshVisualizer
# from src.utils.pose_util import project_points
# from src.utils.crop_face_single import crop_face


def matrix_to_euler_and_translation(matrix):
    """Split a homogeneous transform into Euler angles and a translation.

    Args:
        matrix: Array whose top-left 3x3 block (``matrix[:3, :3]``) is a
            rotation matrix and whose ``matrix[:3, 3]`` column is the
            translation (e.g. a 4x4 rigid transform).

    Returns:
        ``(euler_angles, translation_vector)``: the rotation as SciPy
        ``'xyz'`` Euler angles in degrees (shape ``(3,)``) and the
        translation column (shape ``(3,)``).
    """
    rotation_matrix = matrix[:3, :3]
    translation_vector = matrix[:3, 3]
    rotation = R.from_matrix(rotation_matrix)
    euler_angles = rotation.as_euler('xyz', degrees=True)
    return euler_angles, translation_vector


def smooth_pose_seq(pose_seq, window_size=5):
    """Smooth a sequence with a centred moving average along its first axis.

    Each output row is the mean of the input rows within ``window_size // 2``
    positions on either side. Near the ends the window is truncated, so edge
    rows average fewer samples. Because the half-width is ``window_size // 2``
    on each side, an even ``window_size`` behaves like ``window_size + 1``.

    Args:
        pose_seq: Array of shape ``(n, ...)``.
        window_size: Nominal width of the averaging window.

    Returns:
        Array with the same shape and dtype as ``pose_seq`` (allocated with
        ``np.zeros_like``, so integer input yields truncated integer means).
    """
    smoothed_pose_seq = np.zeros_like(pose_seq)
    n = len(pose_seq)
    half = window_size // 2
    for i in range(n):
        start = max(0, i - half)
        end = min(n, i + half + 1)
        smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
    return smoothed_pose_seq


def get_headpose_temp(input_video, new_fps=30, window_size=5):
    """Extract a smoothed head-pose trajectory from a video.

    Every frame is passed to ``LMKExtractor``. Frames where the extractor
    returns ``None`` (no face found) are skipped and later bridged by linear
    interpolation instead of aborting the whole run. Each pose is expressed
    relative to the first frame with a detection
    (``inv(T_first) @ T_i``), resampled to ``new_fps`` and smoothed with
    ``smooth_pose_seq``.

    Args:
        input_video: Anything ``cv2.VideoCapture`` accepts (typically a path).
        new_fps: Frame rate of the returned trajectory (default 30).
        window_size: Moving-average window passed to ``smooth_pose_seq``.

    Returns:
        Float array of shape ``(num_output_frames, 6)``. Columns 0-2 are the
        relative rotation as ``'xyz'`` Euler angles in degrees; columns 3-5
        are the relative translation. At least one row is always returned.

    Raises:
        ValueError: If the video cannot be opened or no face is detected in
            any frame.
    """
    lmk_extractor = LMKExtractor()
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {input_video}")

    frame_indices = []  # indices of frames where a face was found
    trans_mat_list = []
    num_frames = 0  # frames actually decoded; CAP_PROP_FRAME_COUNT may disagree
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result = lmk_extractor(frame)
            if result is not None:
                frame_indices.append(num_frames)
                trans_mat_list.append(result['trans_mat'].astype(np.float32))
            num_frames += 1
    finally:
        # Release the capture even if the landmark extractor raises.
        cap.release()

    if not trans_mat_list:
        raise ValueError(f"No face detected in any frame of {input_video}")

    # Some backends report 0 (or NaN) for an unsupported FPS property; fall
    # back to the target rate so each decoded frame maps to one output frame.
    if not (fps > 0):
        fps = new_fps

    # Compute each frame's pose relative to the first tracked frame.
    trans_mat_arr = np.array(trans_mat_list)
    trans_mat_inv_first = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.array(
        [
            np.concatenate(matrix_to_euler_and_translation(trans_mat_inv_first @ mat))
            for mat in trans_mat_arr
        ],
        dtype=np.float64,
    )

    # Resample onto a uniform new_fps grid over the tracked time span. Using
    # real frame timestamps keeps the time axis correct even when frames
    # without a face were skipped.
    old_time = np.asarray(frame_indices, dtype=np.float64) / fps
    tracked_frames = frame_indices[-1] - frame_indices[0] + 1
    num_out = max(1, int(tracked_frames * new_fps / fps))
    new_time = np.linspace(old_time[0], old_time[-1], num_out)
    if len(old_time) > 1:
        pose_arr_interp = interp1d(old_time, pose_arr, axis=0)(new_time)
    else:
        # interp1d needs at least two samples; a single pose is held constant.
        pose_arr_interp = np.repeat(pose_arr, num_out, axis=0)

    return smooth_pose_seq(pose_arr_interp, window_size=window_size)


# @spaces.GPU(duration=150)
# def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
#     fps = 30
#     cfg = 3.5
#     config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
#     if config.weight_dtype == "fp16":
#         weight_dtype = torch.float16
#     else:
#         weight_dtype = torch.float32
#     audio_infer_config = OmegaConf.load(config.audio_inference_config)
#     # prepare model
#     a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
#     a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
#     a2m_model.cuda().eval()
#     vae = AutoencoderKL.from_pretrained(
#         config.pretrained_vae_path,
#     ).to("cuda", dtype=weight_dtype)
#     reference_unet = UNet2DConditionModel.from_pretrained(
#         config.pretrained_base_model_path,
#         subfolder="unet",
#     ).to(dtype=weight_dtype, device="cuda")
#     inference_config_path = config.inference_config
#     infer_config = OmegaConf.load(inference_config_path)
#     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
#         config.pretrained_base_model_path,
#         config.motion_module_path,
#         subfolder="unet",
#         unet_additional_kwargs=infer_config.unet_additional_kwargs,
#     ).to(dtype=weight_dtype, device="cuda")
#     pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
#     image_enc = CLIPVisionModelWithProjection.from_pretrained(
#         config.image_encoder_path
#     ).to(dtype=weight_dtype, device="cuda")
#     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
#     scheduler = DDIMScheduler(**sched_kwargs)
#     generator = torch.manual_seed(seed)
#     width, height = size, size
#     # load pretrained weights
#     denoising_unet.load_state_dict(
#         torch.load(config.denoising_unet_path, map_location="cpu"),
#         strict=False,
#     )
#     reference_unet.load_state_dict(
#         torch.load(config.reference_unet_path, map_location="cpu"),
#     )
#     pose_guider.load_state_dict(
#         torch.load(config.pose_guider_path, map_location="cpu"),
#     )
#     pipe = Pose2VideoPipeline(
#         vae=vae,
#         image_encoder=image_enc,
#         reference_unet=reference_unet,
#         denoising_unet=denoising_unet,
#         pose_guider=pose_guider,
#         scheduler=scheduler,
#     )
#     pipe = pipe.to("cuda", dtype=weight_dtype)
#     date_str = datetime.now().strftime("%Y%m%d")
#     time_str = datetime.now().strftime("%H%M")
#     save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
#     save_dir = Path(f"output/{date_str}/{save_dir_name}")
#     save_dir.mkdir(exist_ok=True, parents=True)
#     lmk_extractor = LMKExtractor()
#     vis = FaceMeshVisualizer(forehead_edge=False)
#     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
#     ref_image_np = crop_face(ref_image_np, lmk_extractor)
#     if ref_image_np is None:
#         return None, Image.fromarray(ref_img)
#     ref_image_np = cv2.resize(ref_image_np, (size, size))
#     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
#     face_result = lmk_extractor(ref_image_np)
#     if face_result is None:
#         return None, ref_image_pil
#     lmks = face_result['lmks'].astype(np.float32)
#     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
#     sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
#     sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
#     sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
#     # inference
#     pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
#     pred = pred.squeeze().detach().cpu().numpy()
#     pred = pred.reshape(pred.shape[0], -1, 3)
#     pred = pred + face_result['lmks3d']
#     if headpose_video is not None:
#         pose_seq = get_headpose_temp(headpose_video)
#     else:
#         pose_seq = np.load(config['pose_temp'])
#     mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
#     cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
#     # project 3D mesh to 2D landmark
#     projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
#     pose_images = []
#     for i, verts in enumerate(projected_vertices):
#         lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
#         pose_images.append(lmk_img)
#     pose_list = []
#     pose_tensor_list = []
#     pose_transform = transforms.Compose(
#         [transforms.Resize((height, width)), transforms.ToTensor()]
#     )
#     args_L = len(pose_images) if length==0 or length > len(pose_images) else length
#     args_L = min(args_L, 300)
#     for pose_image_np in pose_images[: args_L]:
#         pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
#         pose_tensor_list.append(pose_transform(pose_image_pil))
#         pose_image_np = cv2.resize(pose_image_np, (width, height))
#         pose_list.append(pose_image_np)
#     pose_list = np.array(pose_list)
#     video_length = len(pose_tensor_list)
#     video = pipe(
#         ref_image_pil,
#         pose_list,
#         ref_pose,
#         width,
#         height,
#         video_length,
#         steps,
#         cfg,
#         generator=generator,
#     ).videos
#     save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
#     save_videos_grid(
#         video,
#         save_path,
#         n_rows=1,
#         fps=fps,
#     )
#     stream = ffmpeg.input(save_path)
#     audio = ffmpeg.input(input_audio)
#     ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
#     os.remove(save_path)
#     return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil