kimjy0411 committed (verified)
Commit 4e83425 · Parent(s): 1e1148c

Update app.py

Files changed (1):
  1. app.py +494 -0

app.py CHANGED
@@ -0,0 +1,494 @@
import gradio as gr
import os
import shutil
import ffmpeg
from datetime import datetime
from pathlib import Path
import numpy as np
import cv2
import random
import torch

from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
from scipy.interpolate import interp1d

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid

from src.audio_models.model import Audio2MeshModel
from src.audio_models.pose_model import Audio2PoseModel
from src.utils.audio_util import prepare_audio_feature
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_util import FaceMeshVisualizer
from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix, smooth_pose_seq
from src.utils.util import crop_face
from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
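
# ------------------------------------------------------------------
# Global setup: load configs, select the compute precision, and build
# all models once at startup.
# ------------------------------------------------------------------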
config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
if config.weight_dtype == "fp16":
    weight_dtype = torch.float16
else:
    weight_dtype = torch.float32

audio_infer_config = OmegaConf.load(config.audio_inference_config)
# prepare model
a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
a2m_model.cuda().eval()

a2p_model = Audio2PoseModel(audio_infer_config['a2p_model'])
a2p_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2p_ckpt']), strict=False)
a2p_model.cuda().eval()
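
# Diffusion components: VAE, 2D reference UNet, 3D denoising UNet with the
# motion module, pose guider, CLIP image encoder, and DDIM scheduler.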
vae = AutoencoderKL.from_pretrained(
    config.pretrained_vae_path,
).to("cuda", dtype=weight_dtype)

reference_unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_base_model_path,
    subfolder="unet",
).to(dtype=weight_dtype, device="cuda")

inference_config_path = config.inference_config
infer_config = OmegaConf.load(inference_config_path)
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    config.pretrained_base_model_path,
    config.motion_module_path,
    subfolder="unet",
    unet_additional_kwargs=infer_config.unet_additional_kwargs,
).to(dtype=weight_dtype, device="cuda")

pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)  # pose guider with cross-attention enabled

image_enc = CLIPVisionModelWithProjection.from_pretrained(
    config.image_encoder_path
).to(dtype=weight_dtype, device="cuda")

sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
scheduler = DDIMScheduler(**sched_kwargs)

# load pretrained weights
denoising_unet.load_state_dict(
    torch.load(config.denoising_unet_path, map_location="cpu"),
    strict=False,
)
reference_unet.load_state_dict(
    torch.load(config.reference_unet_path, map_location="cpu"),
)
pose_guider.load_state_dict(
    torch.load(config.pose_guider_path, map_location="cpu"),
)
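
# Assemble the Pose2Video pipeline and load the frame-interpolation model
# used by the "Accelerate" option to fill in skipped frames.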
pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_enc,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
)
pipe = pipe.to("cuda", dtype=weight_dtype)

frame_inter_model = init_frame_interpolation_model()


def get_headpose_temp(input_video):
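    """Extract a smoothed head-pose sequence from a reference video.

    Poses are 6-DoF (Euler angles + translation) relative to the first frame,
    linearly resampled to 30 fps and smoothed with smooth_pose_seq.
    """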
    lmk_extractor = LMKExtractor()
    cap = cv2.VideoCapture(input_video)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    trans_mat_list = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        result = lmk_extractor(frame)
        trans_mat_list.append(result['trans_mat'].astype(np.float32))
    cap.release()

    trans_mat_arr = np.array(trans_mat_list)

    # compute delta pose
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])

    for i in range(pose_arr.shape[0]):
        pose_mat = trans_mat_inv_frame_0 @ trans_mat_arr[i]
        euler_angles, translation_vector = matrix_to_euler_and_translation(pose_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    # interpolate to 30 fps
    new_fps = 30
    old_time = np.linspace(0, total_frames / fps, total_frames)
    new_time = np.linspace(0, total_frames / fps, int(total_frames * new_fps / fps))

    pose_arr_interp = np.zeros((len(new_time), 6))
    for i in range(6):
        interp_func = interp1d(old_time, pose_arr[:, i])
        pose_arr_interp[:, i] = interp_func(new_time)

    pose_arr_smooth = smooth_pose_seq(pose_arr_interp)

    return pose_arr_smooth


def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=60, seed=42, acc_flag=True):
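    """Generate a talking-head video from an audio clip and a reference image.

    The audio drives per-frame 3D facial landmarks (a2m_model); head pose comes
    either from the optional reference video or from the audio (a2p_model).
    Projected 2D landmark maps then condition the Pose2Video pipeline.
    """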
    fps = 30
    cfg = 3.5
    fi_step = 3 if acc_flag else 1

    generator = torch.manual_seed(seed)

    lmk_extractor = LMKExtractor()
    vis = FaceMeshVisualizer()

    width, height = size, size

    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"

    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    while os.path.exists(save_dir):
        save_dir = Path(f"output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}")
    save_dir.mkdir(exist_ok=True, parents=True)

    ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
    ref_image_np = crop_face(ref_image_np, lmk_extractor)
    if ref_image_np is None:
        return None, Image.fromarray(ref_img)

    ref_image_np = cv2.resize(ref_image_np, (size, size))
    ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))

    face_result = lmk_extractor(ref_image_np)
    if face_result is None:
        return None, ref_image_pil

    lmks = face_result['lmks'].astype(np.float32)
    ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
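
    # Audio-to-mesh: extract wav2vec-based audio features and predict per-frame
    # 3D landmark offsets, which are added to the reference face's 3D landmarks.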
    sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
    sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)

    # inference
    pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
    pred = pred.squeeze().detach().cpu().numpy()
    pred = pred.reshape(pred.shape[0], -1, 3)
    pred = pred + face_result['lmks3d']
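
    # Head pose: use the optional reference video if provided; otherwise
    # generate a pose sequence from the audio (a2p_model) in 5-second chunks.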
    if headpose_video is not None:
        pose_seq = get_headpose_temp(headpose_video)
        mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
        pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
    else:
        id_seed = random.randint(0, 99)
        id_seed = torch.LongTensor([id_seed]).cuda()

        # Currently, only inference up to a maximum length of 10 seconds is supported.
        chunk_duration = 5  # 5 seconds
        sr = 16000
        fps = 30
        chunk_size = sr * chunk_duration

        audio_chunks = list(sample['audio_feature'].split(chunk_size, dim=1))
        seq_len_list = [chunk_duration*fps] * (len(audio_chunks) - 1) + [sample['seq_len'] % (chunk_duration*fps)]  # 30 fps

        audio_chunks[-2] = torch.cat((audio_chunks[-2], audio_chunks[-1]), dim=1)
        seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1]
        del audio_chunks[-1]
        del seq_len_list[-1]

        pose_seq = []
        for audio, seq_len in zip(audio_chunks, seq_len_list):
            pose_seq_chunk = a2p_model.infer(audio, seq_len, id_seed)
            pose_seq_chunk = pose_seq_chunk.squeeze().detach().cpu().numpy()
            pose_seq_chunk[:, :3] *= 0.5
            pose_seq.append(pose_seq_chunk)

        pose_seq = np.concatenate(pose_seq, 0)
        pose_seq = smooth_pose_seq(pose_seq, 7)

    # project 3D mesh to 2D landmark
    projected_vertices = project_points(pred, face_result['trans_mat'], pose_seq, [height, width])

    pose_images = []
    for i, verts in enumerate(projected_vertices):
        lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
        pose_images.append(lmk_img)

    pose_list = []
    args_L = len(pose_images) if length==0 or length > len(pose_images) else length
    for pose_image_np in pose_images[: args_L : fi_step]:
        pose_image_np = cv2.resize(pose_image_np, (width, height))
        pose_list.append(pose_image_np)

    pose_list = np.array(pose_list)

    video_length = len(pose_list)
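
    # Run the Pose2Video diffusion pipeline on the pose maps. With "Accelerate"
    # enabled, only every fi_step-th pose map is rendered here and the skipped
    # frames are filled in afterwards by the frame-interpolation model.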
    video = pipe(
        ref_image_pil,
        pose_list,
        ref_pose,
        width,
        height,
        video_length,
        steps,
        cfg,
        generator=generator,
    ).videos

    if acc_flag:
        video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)

    save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
    save_videos_grid(
        video,
        save_path,
        n_rows=1,
        fps=fps,
    )
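
    # Mux the generated (silent) video with the input audio track via ffmpeg.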
    stream = ffmpeg.input(save_path)
    audio = ffmpeg.input(input_audio)
    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
    os.remove(save_path)

    return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil


def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42, acc_flag=True):
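    """Reenact a source video on a reference image.

    Per-frame 3D landmarks, head poses, and expression coefficients are
    extracted from the source video, retargeted to the reference face, and
    projected to 2D pose maps that condition the Pose2Video pipeline. The
    source audio track is copied into the result when available.
    """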
    cfg = 3.5
    fi_step = 3 if acc_flag else 1

    generator = torch.manual_seed(seed)

    lmk_extractor = LMKExtractor()
    vis = FaceMeshVisualizer()

    width, height = size, size

    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"

    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    while os.path.exists(save_dir):
        save_dir = Path(f"output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}")
    save_dir.mkdir(exist_ok=True, parents=True)

    ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
    ref_image_np = crop_face(ref_image_np, lmk_extractor)
    if ref_image_np is None:
        return None, Image.fromarray(ref_img)

    ref_image_np = cv2.resize(ref_image_np, (size, size))
    ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))

    face_result = lmk_extractor(ref_image_np)
    if face_result is None:
        return None, ref_image_pil

    lmks = face_result['lmks'].astype(np.float32)
    ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
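
    # Read the source frames (60 fps input is subsampled to 30 fps) and extract
    # per-frame head-pose matrices, 3D landmarks, and expression ('bs') scores.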
    source_images = read_frames(source_video)
    src_fps = get_fps(source_video)

    step = 1
    if src_fps == 60:
        src_fps = 30
        step = 2

    pose_trans_list = []
    verts_list = []
    bs_list = []
    args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
    for src_image_pil in source_images[: args_L : step*fi_step]:
        src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
        frame_height, frame_width, _ = src_img_np.shape
        src_img_result = lmk_extractor(src_img_np)
        if src_img_result is None:
            break
        pose_trans_list.append(src_img_result['trans_mat'])
        verts_list.append(src_img_result['lmks3d'])
        bs_list.append(src_img_result['bs'])

    trans_mat_arr = np.array(pose_trans_list)
    verts_arr = np.array(verts_list)
    bs_arr = np.array(bs_list)
    min_bs_idx = np.argmin(bs_arr.sum(1))
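
    # Convert source head poses to Euler angles + translation, re-anchor the
    # translation at the reference face's initial position, then smooth.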
    # compute delta pose
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])

    for i in range(pose_arr.shape[0]):
        euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i])  # real pose of source
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    init_tran_vec = face_result['trans_mat'][:3, 3]  # init translation of tgt
    pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec  # (relative translation of source) + (init translation of tgt)

    pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
    pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
    pose_mat_smooth = np.array(pose_mat_smooth)

    # face retarget: remove the most neutral source expression (min_bs_idx) and add the reference face's 3D landmarks
    verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
    # project 3D mesh to 2D landmark
    projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])

    pose_list = []
    for i, verts in enumerate(projected_vertices):
        lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
        pose_image_np = cv2.resize(lmk_img, (width, height))
        pose_list.append(pose_image_np)

    pose_list = np.array(pose_list)

    video_length = len(pose_list)
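
    # Run the Pose2Video diffusion pipeline on the retargeted pose maps.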
    video = pipe(
        ref_image_pil,
        pose_list,
        ref_pose,
        width,
        height,
        video_length,
        steps,
        cfg,
        generator=generator,
    ).videos

    if acc_flag:
        video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)

    save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
    save_videos_grid(
        video,
        save_path,
        n_rows=1,
        fps=src_fps,
    )
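
    # Try to copy the source video's audio track into the result; if the source
    # has no audio (or extraction fails), keep the silent video.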
    audio_output = f'{save_dir}/audio_from_video.aac'
    # extract audio
    try:
        ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
        # merge audio and video
        stream = ffmpeg.input(save_path)
        audio = ffmpeg.input(audio_output)
        ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()

        os.remove(save_path)
        os.remove(audio_output)
    except Exception:
        shutil.move(
            save_path,
            save_path.replace('_noaudio.mp4', '.mp4')
        )

    return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil

################# GUI ################

title = r"""
<h1>AniPortrait</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
"""
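
# Gradio UI: an "Audio2video" tab and a "Video2video" tab, each with its own
# size/steps/length/seed controls, an Accelerate toggle, and example inputs.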
with gr.Blocks() as demo:

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tab("Audio2video"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
                    a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
                    a2v_headpose_video = gr.Video(label="Optional: upload head pose reference video", sources="upload")

                with gr.Row():
                    a2v_size_slider = gr.Slider(minimum=256, maximum=768, step=8, value=512, label="Video size (-W & -H)")
                    a2v_step_slider = gr.Slider(minimum=5, maximum=30, step=1, value=25, label="Steps (--steps)")

                with gr.Row():
                    a2v_length = gr.Slider(minimum=0, maximum=9999, step=1, value=60, label="Length (-L) (Set to 0 to automatically calculate length)")
                    a2v_seed = gr.Number(value=42, label="Seed (--seed)")

                with gr.Row():
                    a2v_acc_flag = gr.Checkbox(value=True, label="Accelerate")
                    a2v_botton = gr.Button("Generate", variant="primary")
            a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)

        gr.Examples(
            examples=[
                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
            ],
            inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
        )
    with gr.Tab("Video2video"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
                    v2v_source_video = gr.Video(label="Upload source video", sources="upload")

                with gr.Row():
                    v2v_size_slider = gr.Slider(minimum=256, maximum=768, step=8, value=512, label="Video size (-W & -H)")
                    v2v_step_slider = gr.Slider(minimum=5, maximum=30, step=1, value=25, label="Steps (--steps)")

                with gr.Row():
                    v2v_length = gr.Slider(minimum=0, maximum=9999, step=1, value=60, label="Length (-L) (Set to 0 to automatically calculate length)")
                    v2v_seed = gr.Number(value=42, label="Seed (--seed)")

                with gr.Row():
                    v2v_acc_flag = gr.Checkbox(value=True, label="Accelerate")
                    v2v_botton = gr.Button("Generate", variant="primary")
            v2v_output_video = gr.PlayableVideo(label="Result", interactive=False)

        gr.Examples(
            examples=[
                ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
                ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
                ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
            ],
            inputs=[v2v_ref_img, v2v_source_video],
        )
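
    # Wire the Generate buttons to the inference functions; each returns the
    # result video plus the cropped reference image shown back in the image box.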
    a2v_botton.click(
        fn=audio2video,
        inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
                a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed, a2v_acc_flag],
        outputs=[a2v_output_video, a2v_ref_img]
    )
    v2v_botton.click(
        fn=video2video,
        inputs=[v2v_ref_img, v2v_source_video,
                v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed, v2v_acc_flag],
        outputs=[v2v_output_video, v2v_ref_img]
    )

demo.launch()