zejunyang commited on
Commit
d947e9b
β€’
1 Parent(s): 727741c
Files changed (2) hide show
  1. app.py +8 -4
  2. src/create_modules.py +372 -69
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
2
 
3
- from src.audio2vid import audio2video
4
- from src.vid2vid import video2video
 
 
5
 
6
  title = r"""
7
  <h1>AniPortrait</h1>
@@ -11,6 +13,8 @@ description = r"""
11
  <b>Official πŸ€— Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
12
  """
13
 
 
 
14
  with gr.Blocks() as demo:
15
 
16
  gr.Markdown(title)
@@ -73,13 +77,13 @@ with gr.Blocks() as demo:
73
  )
74
 
75
  a2v_botton.click(
76
- fn=audio2video,
77
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
78
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
79
  outputs=[a2v_output_video, a2v_ref_img]
80
  )
81
  v2v_botton.click(
82
- fn=video2video,
83
  inputs=[v2v_ref_img, v2v_source_video,
84
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
85
  outputs=[v2v_output_video, v2v_ref_img]
 
1
  import gradio as gr
2
 
3
+ # from src.audio2vid import audio2video
4
+ # from src.vid2vid import video2video
5
+
6
+ from src.create_modules import Processer
7
 
8
  title = r"""
9
  <h1>AniPortrait</h1>
 
13
  <b>Official πŸ€— Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
14
  """
15
 
16
+ main_processer = Processer()
17
+
18
  with gr.Blocks() as demo:
19
 
20
  gr.Markdown(title)
 
77
  )
78
 
79
  a2v_botton.click(
80
+ fn=main_processer.audio2video,
81
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
82
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
83
  outputs=[a2v_output_video, a2v_ref_img]
84
  )
85
  v2v_botton.click(
86
+ fn=main_processer.video2video,
87
  inputs=[v2v_ref_img, v2v_source_video,
88
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
89
  outputs=[v2v_output_video, v2v_ref_img]
src/create_modules.py CHANGED
@@ -4,93 +4,396 @@ from datetime import datetime
4
  from pathlib import Path
5
  import numpy as np
6
  import cv2
 
 
7
  import torch
 
 
8
  from scipy.spatial.transform import Rotation as R
9
  from scipy.interpolate import interp1d
 
10
 
11
  from diffusers import AutoencoderKL, DDIMScheduler
12
- from einops import repeat
13
  from omegaconf import OmegaConf
14
- from PIL import Image
15
- from torchvision import transforms
16
  from transformers import CLIPVisionModelWithProjection
17
 
18
-
19
  from src.models.pose_guider import PoseGuider
20
  from src.models.unet_2d_condition import UNet2DConditionModel
21
  from src.models.unet_3d import UNet3DConditionModel
22
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
23
- from src.utils.util import save_videos_grid
24
 
25
  from src.audio_models.model import Audio2MeshModel
26
- from src.utils.audio_util import prepare_audio_feature
27
  from src.utils.mp_utils import LMKExtractor
28
  from src.utils.draw_util import FaceMeshVisualizer
29
- from src.utils.pose_util import project_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
31
 
32
- lmk_extractor = LMKExtractor()
33
- vis = FaceMeshVisualizer(forehead_edge=False)
 
 
34
 
35
- config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
 
 
 
36
 
37
- if config.weight_dtype == "fp16":
38
- weight_dtype = torch.float16
39
- else:
40
- weight_dtype = torch.float32
41
 
42
- audio_infer_config = OmegaConf.load(config.audio_inference_config)
43
- # prepare model
44
- a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
45
- a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
46
- a2m_model.cuda().eval()
47
-
48
- vae = AutoencoderKL.from_pretrained(
49
- config.pretrained_vae_path,
50
- ).to("cuda", dtype=weight_dtype)
51
-
52
- reference_unet = UNet2DConditionModel.from_pretrained(
53
- config.pretrained_base_model_path,
54
- subfolder="unet",
55
- ).to(dtype=weight_dtype, device="cuda")
56
-
57
- inference_config_path = config.inference_config
58
- infer_config = OmegaConf.load(inference_config_path)
59
- denoising_unet = UNet3DConditionModel.from_pretrained_2d(
60
- config.pretrained_base_model_path,
61
- config.motion_module_path,
62
- subfolder="unet",
63
- unet_additional_kwargs=infer_config.unet_additional_kwargs,
64
- ).to(dtype=weight_dtype, device="cuda")
65
-
66
-
67
- pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
68
-
69
- image_enc = CLIPVisionModelWithProjection.from_pretrained(
70
- config.image_encoder_path
71
- ).to(dtype=weight_dtype, device="cuda")
72
-
73
- sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
74
- scheduler = DDIMScheduler(**sched_kwargs)
75
-
76
- # load pretrained weights
77
- denoising_unet.load_state_dict(
78
- torch.load(config.denoising_unet_path, map_location="cpu"),
79
- strict=False,
80
- )
81
- reference_unet.load_state_dict(
82
- torch.load(config.reference_unet_path, map_location="cpu"),
83
- )
84
- pose_guider.load_state_dict(
85
- torch.load(config.pose_guider_path, map_location="cpu"),
86
- )
87
-
88
- pipe = Pose2VideoPipeline(
89
- vae=vae,
90
- image_encoder=image_enc,
91
- reference_unet=reference_unet,
92
- denoising_unet=denoising_unet,
93
- pose_guider=pose_guider,
94
- scheduler=scheduler,
95
- )
96
- pipe = pipe.to("cuda", dtype=weight_dtype)
 
4
  from pathlib import Path
5
  import numpy as np
6
  import cv2
7
+ import spaces
8
+ import shutil
9
  import torch
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
  from scipy.spatial.transform import Rotation as R
13
  from scipy.interpolate import interp1d
14
+ from torchvision import transforms
15
 
16
  from diffusers import AutoencoderKL, DDIMScheduler
 
17
  from omegaconf import OmegaConf
 
 
18
  from transformers import CLIPVisionModelWithProjection
19
 
 
20
  from src.models.pose_guider import PoseGuider
21
  from src.models.unet_2d_condition import UNet2DConditionModel
22
  from src.models.unet_3d import UNet3DConditionModel
23
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 
24
 
25
  from src.audio_models.model import Audio2MeshModel
 
26
  from src.utils.mp_utils import LMKExtractor
27
  from src.utils.draw_util import FaceMeshVisualizer
28
+ from src.utils.util import get_fps, read_frames, save_videos_grid
29
+
30
+ from src.utils.audio_util import prepare_audio_feature
31
+ from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix, project_points
32
+ from src.utils.crop_face_single import crop_face
33
+
34
+ class Processer():
35
+ def __init__(self):
36
+ self.create_models()
37
+
38
+ def create_models(self):
39
+
40
+ self.lmk_extractor = LMKExtractor()
41
+ self.vis = FaceMeshVisualizer(forehead_edge=False)
42
+
43
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
44
+
45
+ if config.weight_dtype == "fp16":
46
+ weight_dtype = torch.float16
47
+ else:
48
+ weight_dtype = torch.float32
49
+
50
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
51
+ # prepare model
52
+ self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
53
+ self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
54
+ self.a2m_model.cuda().eval()
55
+
56
+ self.vae = AutoencoderKL.from_pretrained(
57
+ config.pretrained_vae_path,
58
+ ).to("cuda", dtype=weight_dtype)
59
+
60
+ self.reference_unet = UNet2DConditionModel.from_pretrained(
61
+ config.pretrained_base_model_path,
62
+ subfolder="unet",
63
+ ).to(dtype=weight_dtype, device="cuda")
64
+
65
+ inference_config_path = config.inference_config
66
+ infer_config = OmegaConf.load(inference_config_path)
67
+ self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
68
+ config.pretrained_base_model_path,
69
+ config.motion_module_path,
70
+ subfolder="unet",
71
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
72
+ ).to(dtype=weight_dtype, device="cuda")
73
+
74
+ self.pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
75
+
76
+ self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
77
+ config.image_encoder_path
78
+ ).to(dtype=weight_dtype, device="cuda")
79
+
80
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
81
+ self.scheduler = DDIMScheduler(**sched_kwargs)
82
+
83
+ # load pretrained weights
84
+ self.denoising_unet.load_state_dict(
85
+ torch.load(config.denoising_unet_path, map_location="cpu"),
86
+ strict=False,
87
+ )
88
+ self.reference_unet.load_state_dict(
89
+ torch.load(config.reference_unet_path, map_location="cpu"),
90
+ )
91
+ self.pose_guider.load_state_dict(
92
+ torch.load(config.pose_guider_path, map_location="cpu"),
93
+ )
94
+
95
+ self.pipe = Pose2VideoPipeline(
96
+ vae=self.vae,
97
+ image_encoder=self.image_enc,
98
+ reference_unet=self.reference_unet,
99
+ denoising_unet=self.denoising_unet,
100
+ pose_guider=self.pose_guider,
101
+ scheduler=self.scheduler,
102
+ )
103
+ self.pipe = self.pipe.to("cuda", dtype=weight_dtype)
104
+
105
+
106
+ @spaces.GPU
107
+ def audio2video(self, input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
108
+ fps = 30
109
+ cfg = 3.5
110
+
111
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
112
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
113
+ generator = torch.manual_seed(seed)
114
+
115
+ width, height = size, size
116
+
117
+ date_str = datetime.now().strftime("%Y%m%d")
118
+ time_str = datetime.now().strftime("%H%M")
119
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
120
+
121
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
122
+ save_dir.mkdir(exist_ok=True, parents=True)
123
+
124
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
125
+ ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
126
+ if ref_image_np is None:
127
+ return None, Image.fromarray(ref_img)
128
+
129
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
130
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
131
+
132
+ face_result = self.lmk_extractor(ref_image_np)
133
+ if face_result is None:
134
+ return None, ref_image_pil
135
+
136
+ lmks = face_result['lmks'].astype(np.float32)
137
+ ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
138
+
139
+ sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
140
+ sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
141
+ sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
142
+
143
+ # inference
144
+ pred = self.a2m_model.infer(sample['audio_feature'], sample['seq_len'])
145
+ pred = pred.squeeze().detach().cpu().numpy()
146
+ pred = pred.reshape(pred.shape[0], -1, 3)
147
+ pred = pred + face_result['lmks3d']
148
+
149
+ if headpose_video is not None:
150
+ pose_seq = get_headpose_temp(headpose_video, self.lmk_extractor)
151
+ else:
152
+ pose_seq = np.load(config['pose_temp'])
153
+ mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
154
+ cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
155
+
156
+ # project 3D mesh to 2D landmark
157
+ projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
158
+
159
+ pose_images = []
160
+ for i, verts in enumerate(projected_vertices):
161
+ lmk_img = self.vis.draw_landmarks((width, height), verts, normed=False)
162
+ pose_images.append(lmk_img)
163
+
164
+ pose_list = []
165
+ pose_tensor_list = []
166
+
167
+ pose_transform = transforms.Compose(
168
+ [transforms.Resize((height, width)), transforms.ToTensor()]
169
+ )
170
+ args_L = len(pose_images) if length==0 or length > len(pose_images) else length
171
+ args_L = min(args_L, 300)
172
+ for pose_image_np in pose_images[: args_L]:
173
+ pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
174
+ pose_tensor_list.append(pose_transform(pose_image_pil))
175
+ pose_image_np = cv2.resize(pose_image_np, (width, height))
176
+ pose_list.append(pose_image_np)
177
+
178
+ pose_list = np.array(pose_list)
179
+
180
+ video_length = len(pose_tensor_list)
181
+
182
+ video = self.pipe(
183
+ ref_image_pil,
184
+ pose_list,
185
+ ref_pose,
186
+ width,
187
+ height,
188
+ video_length,
189
+ steps,
190
+ cfg,
191
+ generator=generator,
192
+ ).videos
193
+
194
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
195
+ save_videos_grid(
196
+ video,
197
+ save_path,
198
+ n_rows=1,
199
+ fps=fps,
200
+ )
201
+
202
+ stream = ffmpeg.input(save_path)
203
+ audio = ffmpeg.input(input_audio)
204
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
205
+ os.remove(save_path)
206
+
207
+ return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
208
+
209
+ @spaces.GPU
210
+ def video2video(self, ref_img, source_video, size=512, steps=25, length=150, seed=42):
211
+ cfg = 3.5
212
+
213
+ generator = torch.manual_seed(seed)
214
+ width, height = size, size
215
+
216
+ date_str = datetime.now().strftime("%Y%m%d")
217
+ time_str = datetime.now().strftime("%H%M")
218
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
219
+
220
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
221
+ save_dir.mkdir(exist_ok=True, parents=True)
222
+
223
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
224
+ ref_image_np = crop_face(ref_image_np, self.lmk_extractor)
225
+ if ref_image_np is None:
226
+ return None, Image.fromarray(ref_img)
227
+
228
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
229
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
230
+
231
+ face_result = self.lmk_extractor(ref_image_np)
232
+ if face_result is None:
233
+ return None, ref_image_pil
234
+
235
+ lmks = face_result['lmks'].astype(np.float32)
236
+ ref_pose = self.vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
237
+
238
+ source_images = read_frames(source_video)
239
+ src_fps = get_fps(source_video)
240
+ pose_transform = transforms.Compose(
241
+ [transforms.Resize((height, width)), transforms.ToTensor()]
242
+ )
243
+
244
+ step = 1
245
+ if src_fps == 60:
246
+ src_fps = 30
247
+ step = 2
248
+
249
+ pose_trans_list = []
250
+ verts_list = []
251
+ bs_list = []
252
+ src_tensor_list = []
253
+ args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
254
+ args_L = min(args_L, 300*step)
255
+ for src_image_pil in source_images[: args_L: step]:
256
+ src_tensor_list.append(pose_transform(src_image_pil))
257
+ src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
258
+ frame_height, frame_width, _ = src_img_np.shape
259
+ src_img_result = self.lmk_extractor(src_img_np)
260
+ if src_img_result is None:
261
+ break
262
+ pose_trans_list.append(src_img_result['trans_mat'])
263
+ verts_list.append(src_img_result['lmks3d'])
264
+ bs_list.append(src_img_result['bs'])
265
+
266
+ trans_mat_arr = np.array(pose_trans_list)
267
+ verts_arr = np.array(verts_list)
268
+ bs_arr = np.array(bs_list)
269
+ min_bs_idx = np.argmin(bs_arr.sum(1))
270
+
271
+ # compute delta pose
272
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
273
+
274
+ for i in range(pose_arr.shape[0]):
275
+ euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source
276
+ pose_arr[i, :3] = euler_angles
277
+ pose_arr[i, 3:6] = translation_vector
278
+
279
+ init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt
280
+ pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt)
281
+
282
+ pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
283
+ pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
284
+ pose_mat_smooth = np.array(pose_mat_smooth)
285
+
286
+ # face retarget
287
+ verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
288
+ # project 3D mesh to 2D landmark
289
+ projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
290
+
291
+ pose_list = []
292
+ for i, verts in enumerate(projected_vertices):
293
+ lmk_img = self.vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
294
+ pose_image_np = cv2.resize(lmk_img, (width, height))
295
+ pose_list.append(pose_image_np)
296
+
297
+ pose_list = np.array(pose_list)
298
+
299
+ video_length = len(pose_list)
300
+
301
+ video = self.pipe(
302
+ ref_image_pil,
303
+ pose_list,
304
+ ref_pose,
305
+ width,
306
+ height,
307
+ video_length,
308
+ steps,
309
+ cfg,
310
+ generator=generator,
311
+ ).videos
312
+
313
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
314
+ save_videos_grid(
315
+ video,
316
+ save_path,
317
+ n_rows=1,
318
+ fps=src_fps,
319
+ )
320
+
321
+ audio_output = f'{save_dir}/audio_from_video.aac'
322
+ # extract audio
323
+ try:
324
+ ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
325
+ # merge audio and video
326
+ stream = ffmpeg.input(save_path)
327
+ audio = ffmpeg.input(audio_output)
328
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
329
+
330
+ os.remove(save_path)
331
+ os.remove(audio_output)
332
+ except:
333
+ shutil.move(
334
+ save_path,
335
+ save_path.replace('_noaudio.mp4', '.mp4')
336
+ )
337
+
338
+ return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
339
+
340
+
341
+ def matrix_to_euler_and_translation(matrix):
342
+ rotation_matrix = matrix[:3, :3]
343
+ translation_vector = matrix[:3, 3]
344
+ rotation = R.from_matrix(rotation_matrix)
345
+ euler_angles = rotation.as_euler('xyz', degrees=True)
346
+ return euler_angles, translation_vector
347
+
348
+
349
+ def smooth_pose_seq(pose_seq, window_size=5):
350
+ smoothed_pose_seq = np.zeros_like(pose_seq)
351
+
352
+ for i in range(len(pose_seq)):
353
+ start = max(0, i - window_size // 2)
354
+ end = min(len(pose_seq), i + window_size // 2 + 1)
355
+ smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0)
356
+
357
+ return smoothed_pose_seq
358
+
359
+ def get_headpose_temp(input_video, lmk_extractor):
360
+ cap = cv2.VideoCapture(input_video)
361
+
362
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
363
+ fps = cap.get(cv2.CAP_PROP_FPS)
364
+
365
+ trans_mat_list = []
366
+ while cap.isOpened():
367
+ ret, frame = cap.read()
368
+ if not ret:
369
+ break
370
+
371
+ result = lmk_extractor(frame)
372
+ trans_mat_list.append(result['trans_mat'].astype(np.float32))
373
+ cap.release()
374
+
375
+ trans_mat_arr = np.array(trans_mat_list)
376
+
377
+ # compute delta pose
378
+ trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
379
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
380
 
381
+ for i in range(pose_arr.shape[0]):
382
+ pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
383
+ euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
384
+ pose_arr[i, :3] = euler_angles
385
+ pose_arr[i, 3:6] = translation_vector
386
 
387
+ # interpolate to 30 fps
388
+ new_fps = 30
389
+ old_time = np.linspace(0, total_frames / fps, total_frames)
390
+ new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps))
391
 
392
+ pose_arr_interp = np.zeros((len(new_time), 6))
393
+ for i in range(6):
394
+ interp_func = interp1d(old_time, pose_arr[:, i])
395
+ pose_arr_interp[:, i] = interp_func(new_time)
396
 
397
+ pose_arr_smooth = smooth_pose_seq(pose_arr_interp)
 
 
 
398
 
399
+ return pose_arr_smooth