zejunyang committed
Commit f1d69c2
Parent: 8749423

add frame interpolation function

Files changed (1): app.py (+344, -7)
app.py CHANGED
@@ -1,9 +1,348 @@
 import gradio as gr
+ import os
+ import shutil
+ import ffmpeg
+ from datetime import datetime
+ from pathlib import Path
+ import numpy as np
+ import cv2
+ import torch
+ import spaces
 
- from src.audio2vid import audio2video
- from src.vid2vid import video2video
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from einops import repeat
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPVisionModelWithProjection
 
- # from src.create_modules import Processer
+ from src.models.pose_guider import PoseGuider
+ from src.models.unet_2d_condition import UNet2DConditionModel
+ from src.models.unet_3d import UNet3DConditionModel
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
+ from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
+
+ from src.audio_models.model import Audio2MeshModel
+ from src.utils.audio_util import prepare_audio_feature
+ from src.utils.mp_utils import LMKExtractor
+ from src.utils.draw_util import FaceMeshVisualizer
+ from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
+ from src.utils.crop_face_single import crop_face
+ from src.audio2vid import get_headpose_temp, smooth_pose_seq
+ from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
+
+
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
+ if config.weight_dtype == "fp16":
+     weight_dtype = torch.float16
+ else:
+     weight_dtype = torch.float32
+
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
+ # prepare model
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
+ a2m_model.cuda().eval()
+
+ vae = AutoencoderKL.from_pretrained(
+     config.pretrained_vae_path,
+ ).to("cuda", dtype=weight_dtype)
+
+ reference_unet = UNet2DConditionModel.from_pretrained(
+     config.pretrained_base_model_path,
+     subfolder="unet",
+ ).to(dtype=weight_dtype, device="cuda")
+
+ inference_config_path = config.inference_config
+ infer_config = OmegaConf.load(inference_config_path)
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+     config.pretrained_base_model_path,
+     config.motion_module_path,
+     subfolder="unet",
+     unet_additional_kwargs=infer_config.unet_additional_kwargs,
+ ).to(dtype=weight_dtype, device="cuda")
+
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
+     config.image_encoder_path
+ ).to(dtype=weight_dtype, device="cuda")
+
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+ scheduler = DDIMScheduler(**sched_kwargs)
+
+ # load pretrained weights
+ denoising_unet.load_state_dict(
+     torch.load(config.denoising_unet_path, map_location="cpu"),
+     strict=False,
+ )
+ reference_unet.load_state_dict(
+     torch.load(config.reference_unet_path, map_location="cpu"),
+ )
+ pose_guider.load_state_dict(
+     torch.load(config.pose_guider_path, map_location="cpu"),
+ )
+
+ pipe = Pose2VideoPipeline(
+     vae=vae,
+     image_encoder=image_enc,
+     reference_unet=reference_unet,
+     denoising_unet=denoising_unet,
+     pose_guider=pose_guider,
+     scheduler=scheduler,
+ )
+ pipe = pipe.to("cuda", dtype=weight_dtype)
+
+ lmk_extractor = LMKExtractor()
+ vis = FaceMeshVisualizer()
+
+ frame_inter_model = init_frame_interpolation_model()
+
+ @spaces.GPU(duration=200)
+ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
+     fps = 30
+     cfg = 3.5
+
+     generator = torch.manual_seed(seed)
+
+     width, height = size, size
+
+     date_str = datetime.now().strftime("%Y%m%d")
+     time_str = datetime.now().strftime("%H%M")
+     save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
+
+     save_dir = Path(f"output/{date_str}/{save_dir_name}")
+     save_dir.mkdir(exist_ok=True, parents=True)
+
+     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
+     ref_image_np = crop_face(ref_image_np, lmk_extractor)
+     if ref_image_np is None:
+         return None, Image.fromarray(ref_img)
+
+     ref_image_np = cv2.resize(ref_image_np, (size, size))
+     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
+
+     face_result = lmk_extractor(ref_image_np)
+     if face_result is None:
+         return None, ref_image_pil
+
+     lmks = face_result['lmks'].astype(np.float32)
+     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
+
+     sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
+     sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
+     sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
+
+     # inference
+     pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
+     pred = pred.squeeze().detach().cpu().numpy()
+     pred = pred.reshape(pred.shape[0], -1, 3)
+     pred = pred + face_result['lmks3d']
+
+     if headpose_video is not None:
+         pose_seq = get_headpose_temp(headpose_video)
+     else:
+         pose_seq = np.load(config['pose_temp'])
+     mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
+     cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
+
+     # project 3D mesh to 2D landmark
+     projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
+
+     pose_images = []
+     for i, verts in enumerate(projected_vertices):
+         lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
+         pose_images.append(lmk_img)
+
+     pose_list = []
+     # pose_tensor_list = []
+
+     # pose_transform = transforms.Compose(
+     #     [transforms.Resize((height, width)), transforms.ToTensor()]
+     # )
+     args_L = len(pose_images) if length==0 or length > len(pose_images) else length
+     args_L = min(args_L, 180)
+     for pose_image_np in pose_images[: args_L : 2]:
+         # pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
+         # pose_tensor_list.append(pose_transform(pose_image_pil))
+         pose_image_np = cv2.resize(pose_image_np, (width, height))
+         pose_list.append(pose_image_np)
+
+     pose_list = np.array(pose_list)
+
+     video_length = len(pose_list)
+
+     video = pipe(
+         ref_image_pil,
+         pose_list,
+         ref_pose,
+         width,
+         height,
+         video_length,
+         steps,
+         cfg,
+         generator=generator,
+     ).videos
+
+     # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
+     # save_videos_grid(
+     #     video,
+     #     save_path,
+     #     n_rows=1,
+     #     fps=fps,
+     # )
+
+     save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
+     save_pil_imgs(video, save_path)
+
+     save_path = batch_images_interpolation_tool(save_path, frame_inter_model, fps)
+
+     stream = ffmpeg.input(save_path)
+     audio = ffmpeg.input(input_audio)
+     ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
+     os.remove(save_path)
+
+     return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
+
+ @spaces.GPU(duration=200)
+ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
+     cfg = 3.5
+
+     generator = torch.manual_seed(seed)
+
+     width, height = size, size
+
+     date_str = datetime.now().strftime("%Y%m%d")
+     time_str = datetime.now().strftime("%H%M")
+     save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
+
+     save_dir = Path(f"output/{date_str}/{save_dir_name}")
+     save_dir.mkdir(exist_ok=True, parents=True)
+
+     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
+     ref_image_np = crop_face(ref_image_np, lmk_extractor)
+     if ref_image_np is None:
+         return None, Image.fromarray(ref_img)
+
+     ref_image_np = cv2.resize(ref_image_np, (size, size))
+     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
+
+     face_result = lmk_extractor(ref_image_np)
+     if face_result is None:
+         return None, ref_image_pil
+
+     lmks = face_result['lmks'].astype(np.float32)
+     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
+
+     source_images = read_frames(source_video)
+     src_fps = get_fps(source_video)
+     pose_transform = transforms.Compose(
+         [transforms.Resize((height, width)), transforms.ToTensor()]
+     )
+
+     step = 1
+     if src_fps == 60:
+         src_fps = 30
+         step = 2
+
+     pose_trans_list = []
+     verts_list = []
+     bs_list = []
+     src_tensor_list = []
+     args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
+     args_L = min(args_L, 180*step)
+     for src_image_pil in source_images[: args_L : step*2]:
+         src_tensor_list.append(pose_transform(src_image_pil))
+         src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
+         frame_height, frame_width, _ = src_img_np.shape
+         src_img_result = lmk_extractor(src_img_np)
+         if src_img_result is None:
+             break
+         pose_trans_list.append(src_img_result['trans_mat'])
+         verts_list.append(src_img_result['lmks3d'])
+         bs_list.append(src_img_result['bs'])
+
+     trans_mat_arr = np.array(pose_trans_list)
+     verts_arr = np.array(verts_list)
+     bs_arr = np.array(bs_list)
+     min_bs_idx = np.argmin(bs_arr.sum(1))
+
+     # compute delta pose
+     pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
+
+     for i in range(pose_arr.shape[0]):
+         euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source
+         pose_arr[i, :3] = euler_angles
+         pose_arr[i, 3:6] = translation_vector
+
+     init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt
+     pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt)
+
+     pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
+     pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
+     pose_mat_smooth = np.array(pose_mat_smooth)
+
+     # face retarget
+     verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
+     # project 3D mesh to 2D landmark
+     projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
+
+     pose_list = []
+     for i, verts in enumerate(projected_vertices):
+         lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
+         pose_image_np = cv2.resize(lmk_img, (width, height))
+         pose_list.append(pose_image_np)
+
+     pose_list = np.array(pose_list)
+
+     video_length = len(pose_list)
+
+     video = pipe(
+         ref_image_pil,
+         pose_list,
+         ref_pose,
+         width,
+         height,
+         video_length,
+         steps,
+         cfg,
+         generator=generator,
+     ).videos
+
+     # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
+     # save_videos_grid(
+     #     video,
+     #     save_path,
+     #     n_rows=1,
+     #     fps=src_fps,
+     # )
+
+     save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
+     save_pil_imgs(video, save_path)
+
+     save_path = batch_images_interpolation_tool(save_path, frame_inter_model, src_fps)
+
+     audio_output = f'{save_dir}/audio_from_video.aac'
+     # extract audio
+     try:
+         ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
+         # merge audio and video
+         stream = ffmpeg.input(save_path)
+         audio = ffmpeg.input(audio_output)
+         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
+
+         os.remove(save_path)
+         os.remove(audio_output)
+     except:
+         shutil.move(
+             save_path,
+             save_path.replace('_noaudio.mp4', '.mp4')
+         )
+
+     return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
+
+
+ ################# GUI ################
 
 title = r"""
 <h1>AniPortrait</h1>
@@ -13,8 +352,6 @@ description = r"""
 <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
 """
 
- # main_processer = Processer()
-
 with gr.Blocks() as demo:
 
 gr.Markdown(title)
@@ -33,7 +370,7 @@ with gr.Blocks() as demo:
 a2v_step_slider = gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Steps (--steps)")
 
 with gr.Row():
- a2v_length = gr.Slider(minimum=0, maximum=300, step=1, value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
+ a2v_length = gr.Slider(minimum=0, maximum=180, step=1, value=60, label="Length (-L) (Set 0 to automatically calculate video length.)")
 a2v_seed = gr.Number(value=42, label="Seed (--seed)")
 
 a2v_botton = gr.Button("Generate", variant="primary")
@@ -61,7 +398,7 @@ with gr.Blocks() as demo:
 v2v_step_slider = gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Steps (--steps)")
 
 with gr.Row():
- v2v_length = gr.Slider(minimum=0, maximum=300, step=1, value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
+ v2v_length = gr.Slider(minimum=0, maximum=180, step=1, value=60, label="Length (-L) (Set 0 to automatically calculate video length.)")
 v2v_seed = gr.Number(value=42, label="Seed (--seed)")
 
 v2v_botton = gr.Button("Generate", variant="primary")
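
The frame-interpolation step added by this commit works the same way in both generation paths: the pose sequence is subsampled with a stride of 2, the generated frames are written out with save_pil_imgs, batch_images_interpolation_tool fills in the skipped frames using the model returned by init_frame_interpolation_model, and only then is the audio muxed back in with ffmpeg. A minimal sketch of that call pattern, assuming batch_images_interpolation_tool returns the path of an intermediate '_noaudio.mp4' (as its usage above implies) and that save_pil_imgs writes a frame folder at the given path, is:

import ffmpeg

from src.utils.util import save_pil_imgs
from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool

frame_inter_model = init_frame_interpolation_model()  # loaded once at module level in app.py

def interpolate_and_mux(video, save_dir, input_audio, fps=30):
    # Dump the generated frames to an image folder (layout inferred from app.py's usage).
    frames_path = f"{save_dir}/512x512_noaudio"
    save_pil_imgs(video, frames_path)

    # Interpolate the in-between frames; since the pose sequence was sampled
    # with stride 2, this restores the target frame rate.
    video_path = batch_images_interpolation_tool(frames_path, frame_inter_model, fps)

    # Mux the driving audio back in, as the commit does with ffmpeg-python.
    out_path = video_path.replace('_noaudio.mp4', '.mp4')
    ffmpeg.output(
        ffmpeg.input(video_path).video,
        ffmpeg.input(input_audio).audio,
        out_path,
        vcodec='copy', acodec='aac', shortest=None,
    ).run()
    return out_path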