kimjy0411 commited on
Commit
bc48784
·
verified ·
1 Parent(s): 5801cb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -213
app.py CHANGED
@@ -1,186 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import shutil
3
- import numpy as np
4
- import gradio as gr
5
- import torchaudio
6
- import soundfile as sf
7
- from pathlib import Path
8
- from datetime import datetime
9
- from scipy.io.wavfile import write as write_wav
10
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
11
- from encodec.utils import convert_audio
12
-
13
- from src.bark.history_to_hash import history_to_hash
14
- from src.bark.npz_tools import save_npz
15
- from src.bark.FullGeneration import FullGeneration
16
- from src.utils.date import get_date_string
17
- from src.bark.get_audio_from_npz import get_audio_from_full_generation
18
- from bark_hubert_quantizer.hubert_manager import HuBERTManager
19
- from bark_hubert_quantizer.pre_kmeans_hubert import CustomHubert
20
- from bark_hubert_quantizer.customtokenizer import CustomTokenizer
21
- from bark import SAMPLE_RATE
22
- from bark.generation import SUPPORTED_LANGS, generate_text_semantic, generate_coarse, generate_fine, codec_decode
23
-
24
  import ffmpeg
 
 
 
25
  import cv2
26
  import torch
27
- from PIL import Image
 
28
  from diffusers import AutoencoderKL, DDIMScheduler
29
  from einops import repeat
30
  from omegaconf import OmegaConf
 
31
  from torchvision import transforms
32
  from transformers import CLIPVisionModelWithProjection
 
33
  from src.models.pose_guider import PoseGuider
34
  from src.models.unet_2d_condition import UNet2DConditionModel
35
  from src.models.unet_3d import UNet3DConditionModel
36
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
37
  from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
 
38
  from src.audio_models.model import Audio2MeshModel
39
  from src.utils.audio_util import prepare_audio_feature
40
- from src.utils.mp_utils import LMKExtractor
41
  from src.utils.draw_util import FaceMeshVisualizer
42
  from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
43
  from src.utils.crop_face_single import crop_face
44
  from src.audio2vid import get_headpose_temp, smooth_pose_seq
45
  from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
46
 
47
- hubert_model = None
48
-
49
- def _load_hubert_model(device):
50
- hubert_path = HuBERTManager.make_sure_hubert_installed()
51
- global hubert_model
52
- if hubert_model is None:
53
- hubert_model = CustomHubert(
54
- checkpoint_path=hubert_path,
55
- device=device,
56
- )
57
- return hubert_model
58
-
59
- def _get_semantic_vectors(hubert_model: CustomHubert, path_to_wav: str, device):
60
- wav, sr = torchaudio.load(path_to_wav)
61
- if wav.shape[0] == 2:
62
- wav = wav.mean(0, keepdim=True)
63
- wav = wav.to(device)
64
- return hubert_model.forward(wav, input_sample_hz=sr)
65
-
66
- def get_semantic_vectors(path_to_wav: str, device):
67
- hubert_model = _load_hubert_model(device)
68
- return _get_semantic_vectors(hubert_model, path_to_wav, device)
69
-
70
- tokenizer = None
71
-
72
- def _load_tokenizer(
73
- model: str = "quantifier_hubert_base_ls960_14.pth",
74
- repo: str = "GitMylo/bark-voice-cloning",
75
- force_reload: bool = False,
76
- device="cpu",
77
- ) -> CustomTokenizer:
78
- tokenizer_path = HuBERTManager.make_sure_tokenizer_installed(
79
- model=model,
80
- repo=repo,
81
- local_file=model,
82
- )
83
- global tokenizer
84
- if tokenizer is None or force_reload:
85
- tokenizer = CustomTokenizer.load_from_checkpoint(
86
- tokenizer_path,
87
- map_location=device,
88
- )
89
- tokenizer.load_state_dict(torch.load(tokenizer_path, map_location=device))
90
- return tokenizer
91
-
92
- def get_semantic_tokens(semantic_vectors: torch.Tensor, device):
93
- tokenizer = _load_tokenizer(device=device)
94
- return tokenizer.get_token(semantic_vectors)
95
-
96
- def get_semantic_prompt(path_to_wav: str, device):
97
- semantic_vectors = get_semantic_vectors(path_to_wav, device)
98
- return get_semantic_tokens(semantic_vectors, device).cpu().numpy()
99
-
100
- def get_prompts(path_to_wav: str, use_gpu: bool):
101
- device = "cuda" if use_gpu else "cpu"
102
- semantic_prompt = get_semantic_prompt(path_to_wav, device)
103
- fine_prompt, coarse_prompt = get_encodec_prompts(path_to_wav, use_gpu)
104
- return FullGeneration(
105
- semantic_prompt=semantic_prompt,
106
- coarse_prompt=coarse_prompt,
107
- fine_prompt=fine_prompt,
108
- )
109
-
110
- def get_encodec_prompts(path_to_wav: str, use_gpu=True):
111
- device = "cuda" if use_gpu else "cpu"
112
- model = load_codec_model(use_gpu=use_gpu)
113
- wav, sr = torchaudio.load(path_to_wav)
114
- wav = convert_audio(wav, sr, model.sample_rate, model.channels)
115
- wav = wav.unsqueeze(0).to(device)
116
- model.to(device)
117
-
118
- with torch.no_grad():
119
- encoded_frames = model.encode(wav)
120
-
121
- fine_prompt = (
122
- torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
123
- .squeeze()
124
- .cpu()
125
- .numpy()
126
- )
127
- coarse_prompt = fine_prompt[:2, :]
128
- return fine_prompt, coarse_prompt
129
-
130
- def save_cloned_voice(full_generation: FullGeneration):
131
- voice_name = f"voice_from_audio_{history_to_hash(full_generation)}"
132
- filename = f"voices/{voice_name}.npz"
133
- date = get_date_string()
134
- metadata = generate_cloned_voice_metadata(full_generation, date)
135
- save_npz(filename, full_generation, metadata)
136
- return filename
137
-
138
- def generate_cloned_voice_metadata(full_generation, date):
139
- return {
140
- "_version": "0.0.1",
141
- "_hash_version": "0.0.2",
142
- "_type": "bark",
143
- "hash": history_to_hash(full_generation),
144
- "date": date,
145
- }
146
-
147
- def generate_voice(wav_file: str, use_gpu: bool):
148
- full_generation = get_prompts(wav_file, use_gpu)
149
- filename = save_cloned_voice(full_generation)
150
- return filename, get_audio_from_full_generation(full_generation)
151
-
152
- # 음성 합성을 위한 함수
153
- def synthesize_speech(text, input_audio):
154
- semantic_tokens = generate_text_semantic(text)
155
- coarse_tokens = generate_coarse(semantic_tokens)
156
- fine_tokens = generate_fine(coarse_tokens)
157
- synthesized_audio = codec_decode(fine_tokens)
158
- if isinstance(synthesized_audio, torch.Tensor):
159
- synthesized_audio = synthesized_audio.squeeze().cpu().numpy()
160
- else:
161
- synthesized_audio = synthesized_audio.squeeze()
162
-
163
- # 입력 음성의 길이 가져오기
164
- input_wav, input_sr = torchaudio.load(input_audio)
165
- input_length = input_wav.shape[1] / input_sr
166
-
167
- # 출력 음성을 입력 음성의 길이에 맞추기
168
- output_length = synthesized_audio.shape[0] / SAMPLE_RATE
169
- if output_length > input_length:
170
- synthesized_audio = synthesized_audio[:int(input_length * SAMPLE_RATE)]
171
- else:
172
- padding = int((input_length - output_length) * SAMPLE_RATE)
173
- synthesized_audio = np.pad(synthesized_audio, (0, padding), 'constant')
174
 
175
- sf.write("synthesized_audio.wav", synthesized_audio, SAMPLE_RATE)
176
- return "synthesized_audio.wav"
177
-
178
- # TTS 기능 함수
179
- def tts_function(input_audio, input_text):
180
- synthesized_audio_path = synthesize_speech(input_text, input_audio)
181
- return synthesized_audio_path
182
-
183
- # aniportrait 함수 정의
184
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
185
  if config.weight_dtype == "fp16":
186
  weight_dtype = torch.float16
@@ -188,28 +81,49 @@ else:
188
  weight_dtype = torch.float32
189
 
190
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
 
191
  a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
192
  a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
193
  a2m_model.cuda().eval()
194
 
195
- vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to("cuda", dtype=weight_dtype)
 
 
196
 
197
- reference_unet = UNet2DConditionModel.from_pretrained(config.pretrained_base_model_path, subfolder="unet").to(dtype=weight_dtype, device="cuda")
 
 
 
198
 
199
  inference_config_path = config.inference_config
200
  infer_config = OmegaConf.load(inference_config_path)
201
- denoising_unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_base_model_path, config.motion_module_path, subfolder="unet", unet_additional_kwargs=infer_config.unet_additional_kwargs).to(dtype=weight_dtype, device="cuda")
 
 
 
 
 
202
 
203
- pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)
204
 
205
- image_enc = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
 
 
206
 
207
  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
208
  scheduler = DDIMScheduler(**sched_kwargs)
209
 
210
- denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
211
- reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
212
- pose_guider.load_state_dict(torch.load(config.pose_guider_path, map_location="cpu"))
 
 
 
 
 
 
 
 
213
 
214
  pipe = Pose2VideoPipeline(
215
  vae=vae,
@@ -221,6 +135,9 @@ pipe = Pose2VideoPipeline(
221
  )
222
  pipe = pipe.to("cuda", dtype=weight_dtype)
223
 
 
 
 
224
  frame_inter_model = init_frame_interpolation_model()
225
 
226
  @spaces.GPU
@@ -264,6 +181,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
264
  sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
265
  sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
266
 
 
267
  pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
268
  pred = pred.squeeze().detach().cpu().numpy()
269
  pred = pred.reshape(pred.shape[0], -1, 3)
@@ -276,6 +194,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
276
  mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
277
  cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
278
 
 
279
  projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
280
 
281
  pose_images = []
@@ -284,10 +203,17 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
284
  pose_images.append(lmk_img)
285
 
286
  pose_list = []
287
- args_L = len(pose_images) if length == 0 or length > len(pose_images) else length
 
 
 
 
 
288
  args_L = min(args_L, 90)
289
- for pose_image_np in pose_images[:args_L:fi_step]:
290
- pose_image_np = cv2.resize(pose_image_np, (width, height))
 
 
291
  pose_list.append(pose_image_np)
292
 
293
  pose_list = np.array(pose_list)
@@ -306,7 +232,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
306
  generator=generator,
307
  ).videos
308
 
309
- video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step - 1)
310
 
311
  save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
312
  save_videos_grid(
@@ -316,6 +242,11 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
316
  fps=fps,
317
  )
318
 
 
 
 
 
 
319
  stream = ffmpeg.input(save_path)
320
  audio = ffmpeg.input(input_audio)
321
  ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
@@ -361,7 +292,9 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
361
 
362
  source_images = read_frames(source_video)
363
  src_fps = get_fps(source_video)
364
- pose_transform = transforms.Compose([transforms.Resize((height, width)), transforms.ToTensor()])
 
 
365
 
366
  step = 1
367
  if src_fps == 60:
@@ -371,9 +304,9 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
371
  pose_trans_list = []
372
  verts_list = []
373
  bs_list = []
374
- args_L = len(source_images) if length == 0 or length * step > len(source_images) else length * step
375
- args_L = min(args_L, 90 * step)
376
- for src_image_pil in source_images[:args_L:step*fi_step]:
377
  src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
378
  frame_height, frame_width, _ = src_img_np.shape
379
  src_img_result = lmk_extractor(src_img_np)
@@ -388,25 +321,30 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
388
  bs_arr = np.array(bs_list)
389
  min_bs_idx = np.argmin(bs_arr.sum(1))
390
 
 
391
  pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
 
392
  for i in range(pose_arr.shape[0]):
393
- euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i])
394
- pose_arr[i, :3] = euler_angles
395
- pose_arr[i, 3:6] = translation_vector
396
 
397
- init_tran_vec = face_result['trans_mat'][:3, 3]
398
- pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec
 
399
  pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
400
  pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
401
- pose_mat_smooth = np.array(pose_mat_smooth)
402
 
 
403
  verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
 
404
  projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
405
 
406
  pose_list = []
407
  for i, verts in enumerate(projected_vertices):
408
  lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
409
- pose_image_np = cv2.resize(lmk_img, (width, height))
410
  pose_list.append(pose_image_np)
411
 
412
  pose_list = np.array(pose_list)
@@ -425,7 +363,7 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
425
  generator=generator,
426
  ).videos
427
 
428
- video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step - 1)
429
 
430
  save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
431
  save_videos_grid(
@@ -435,9 +373,16 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
435
  fps=src_fps,
436
  )
437
 
 
 
 
 
 
438
  audio_output = f'{save_dir}/audio_from_video.aac'
 
439
  try:
440
  ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
 
441
  stream = ffmpeg.input(save_path)
442
  audio = ffmpeg.input(audio_output)
443
  ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
@@ -445,10 +390,14 @@ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
445
  os.remove(save_path)
446
  os.remove(audio_output)
447
  except:
448
- shutil.move(save_path, save_path.replace('_noaudio.mp4', '.mp4'))
 
 
 
449
 
450
  return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
451
 
 
452
  ################# GUI ################
453
 
454
  title = r"""
@@ -474,8 +423,8 @@ with gr.Blocks() as demo:
474
  with gr.Column():
475
  with gr.Row():
476
  a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
477
- a2v_ref_img = gr.Image(label="Upload reference image", source="upload")
478
- a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", source="upload")
479
 
480
  with gr.Row():
481
  a2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
@@ -485,15 +434,15 @@ with gr.Blocks() as demo:
485
  a2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
486
  a2v_seed = gr.Number(value=42, label="Seed (--seed)")
487
 
488
- a2v_button = gr.Button("Generate", variant="primary")
489
- a2v_output_video = gr.Video(label="Result", interactive=False)
490
 
491
  gr.Examples(
492
  examples=[
493
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
494
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
495
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
496
- ],
497
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
498
  )
499
 
@@ -502,8 +451,8 @@ with gr.Blocks() as demo:
502
  with gr.Row():
503
  with gr.Column():
504
  with gr.Row():
505
- v2v_ref_img = gr.Image(label="Upload reference image", source="upload")
506
- v2v_source_video = gr.Video(label="Upload source video", source="upload")
507
 
508
  with gr.Row():
509
  v2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
@@ -513,45 +462,30 @@ with gr.Blocks() as demo:
513
  v2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
514
  v2v_seed = gr.Number(value=42, label="Seed (--seed)")
515
 
516
- v2v_button = gr.Button("Generate", variant="primary")
517
- v2v_output_video = gr.Video(label="Result", interactive=False)
518
 
519
  gr.Examples(
520
  examples=[
521
  ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
522
  ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
523
  ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
524
- ],
525
  inputs=[v2v_ref_img, v2v_source_video, a2v_headpose_video],
526
  )
527
 
528
- a2v_button.click(
529
  fn=audio2video,
530
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
531
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
532
  outputs=[a2v_output_video, a2v_ref_img]
533
  )
534
- v2v_button.click(
535
  fn=video2video,
536
  inputs=[v2v_ref_img, v2v_source_video,
537
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
538
  outputs=[v2v_output_video, v2v_ref_img]
539
  )
540
 
541
- with gr.Tab("TTS"):
542
- with gr.Row():
543
- with gr.Column():
544
- with gr.Row():
545
- tts_input_audio = gr.Audio(type="filepath", label="Input audio for feature extraction")
546
- tts_text_input = gr.Textbox(lines=5, label="Input text", placeholder="Enter text to synthesize...")
547
-
548
- tts_button = gr.Button("Synthesize", variant="primary")
549
- tts_output_audio = gr.Audio(label="Synthesized Audio", interactive=False)
550
-
551
- tts_button.click(
552
- fn=tts_function,
553
- inputs=[tts_input_audio, tts_text_input],
554
- outputs=[tts_output_audio]
555
- )
556
-
557
- demo.launch(debug=True)
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Solutions
10
+ Pricing
11
+
12
+
13
+
14
+ Hugging Face is way more fun with friends and colleagues! 🤗 Join an organization
15
+ Spaces:
16
+
17
+ ZJYang
18
+ /
19
+ AniPortrait_official
20
+
21
+
22
+ like
23
+ 149
24
+ App
25
+ Files
26
+ Community
27
+ 2
28
+ AniPortrait_official
29
+ /
30
+ app.py
31
+
32
+ zejunyang
33
+ last push
34
+ 9600e7d
35
+ about 2 months ago
36
+ raw
37
+ history
38
+ blame
39
+ contribute
40
+ delete
41
+ No virus
42
+ 17 kB
43
+ import gradio as gr
44
  import os
45
  import shutil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  import ffmpeg
47
+ from datetime import datetime
48
+ from pathlib import Path
49
+ import numpy as np
50
  import cv2
51
  import torch
52
+ import spaces
53
+
54
  from diffusers import AutoencoderKL, DDIMScheduler
55
  from einops import repeat
56
  from omegaconf import OmegaConf
57
+ from PIL import Image
58
  from torchvision import transforms
59
  from transformers import CLIPVisionModelWithProjection
60
+
61
  from src.models.pose_guider import PoseGuider
62
  from src.models.unet_2d_condition import UNet2DConditionModel
63
  from src.models.unet_3d import UNet3DConditionModel
64
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
65
  from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
66
+
67
  from src.audio_models.model import Audio2MeshModel
68
  from src.utils.audio_util import prepare_audio_feature
69
+ from src.utils.mp_utils import LMKExtractor
70
  from src.utils.draw_util import FaceMeshVisualizer
71
  from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
72
  from src.utils.crop_face_single import crop_face
73
  from src.audio2vid import get_headpose_temp, smooth_pose_seq
74
  from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
 
 
 
 
 
 
 
 
 
77
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
78
  if config.weight_dtype == "fp16":
79
  weight_dtype = torch.float16
 
81
  weight_dtype = torch.float32
82
 
83
  audio_infer_config = OmegaConf.load(config.audio_inference_config)
84
+ # prepare model
85
  a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
86
  a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
87
  a2m_model.cuda().eval()
88
 
89
+ vae = AutoencoderKL.from_pretrained(
90
+ config.pretrained_vae_path,
91
+ ).to("cuda", dtype=weight_dtype)
92
 
93
+ reference_unet = UNet2DConditionModel.from_pretrained(
94
+ config.pretrained_base_model_path,
95
+ subfolder="unet",
96
+ ).to(dtype=weight_dtype, device="cuda")
97
 
98
  inference_config_path = config.inference_config
99
  infer_config = OmegaConf.load(inference_config_path)
100
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
101
+ config.pretrained_base_model_path,
102
+ config.motion_module_path,
103
+ subfolder="unet",
104
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
105
+ ).to(dtype=weight_dtype, device="cuda")
106
 
107
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
108
 
109
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
110
+ config.image_encoder_path
111
+ ).to(dtype=weight_dtype, device="cuda")
112
 
113
  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
114
  scheduler = DDIMScheduler(**sched_kwargs)
115
 
116
+ # load pretrained weights
117
+ denoising_unet.load_state_dict(
118
+ torch.load(config.denoising_unet_path, map_location="cpu"),
119
+ strict=False,
120
+ )
121
+ reference_unet.load_state_dict(
122
+ torch.load(config.reference_unet_path, map_location="cpu"),
123
+ )
124
+ pose_guider.load_state_dict(
125
+ torch.load(config.pose_guider_path, map_location="cpu"),
126
+ )
127
 
128
  pipe = Pose2VideoPipeline(
129
  vae=vae,
 
135
  )
136
  pipe = pipe.to("cuda", dtype=weight_dtype)
137
 
138
+ # lmk_extractor = LMKExtractor()
139
+ # vis = FaceMeshVisualizer()
140
+
141
  frame_inter_model = init_frame_interpolation_model()
142
 
143
  @spaces.GPU
 
181
  sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
182
  sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
183
 
184
+ # inference
185
  pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
186
  pred = pred.squeeze().detach().cpu().numpy()
187
  pred = pred.reshape(pred.shape[0], -1, 3)
 
194
  mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
195
  cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
196
 
197
+ # project 3D mesh to 2D landmark
198
  projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
199
 
200
  pose_images = []
 
203
  pose_images.append(lmk_img)
204
 
205
  pose_list = []
206
+ # pose_tensor_list = []
207
+
208
+ # pose_transform = transforms.Compose(
209
+ # [transforms.Resize((height, width)), transforms.ToTensor()]
210
+ # )
211
+ args_L = len(pose_images) if length==0 or length > len(pose_images) else length
212
  args_L = min(args_L, 90)
213
+ for pose_image_np in pose_images[: args_L : fi_step]:
214
+ # pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
215
+ # pose_tensor_list.append(pose_transform(pose_image_pil))
216
+ pose_image_np = cv2.resize(pose_image_np, (width, height))
217
  pose_list.append(pose_image_np)
218
 
219
  pose_list = np.array(pose_list)
 
232
  generator=generator,
233
  ).videos
234
 
235
+ video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)
236
 
237
  save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
238
  save_videos_grid(
 
242
  fps=fps,
243
  )
244
 
245
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
246
+ # save_pil_imgs(video, save_path)
247
+
248
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(fps))
249
+
250
  stream = ffmpeg.input(save_path)
251
  audio = ffmpeg.input(input_audio)
252
  ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
 
292
 
293
  source_images = read_frames(source_video)
294
  src_fps = get_fps(source_video)
295
+ pose_transform = transforms.Compose(
296
+ [transforms.Resize((height, width)), transforms.ToTensor()]
297
+ )
298
 
299
  step = 1
300
  if src_fps == 60:
 
304
  pose_trans_list = []
305
  verts_list = []
306
  bs_list = []
307
+ args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
308
+ args_L = min(args_L, 90*step)
309
+ for src_image_pil in source_images[: args_L : step*fi_step]:
310
  src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
311
  frame_height, frame_width, _ = src_img_np.shape
312
  src_img_result = lmk_extractor(src_img_np)
 
321
  bs_arr = np.array(bs_list)
322
  min_bs_idx = np.argmin(bs_arr.sum(1))
323
 
324
+ # compute delta pose
325
  pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
326
+
327
  for i in range(pose_arr.shape[0]):
328
+ euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source
329
+ pose_arr[i, :3] = euler_angles
330
+ pose_arr[i, 3:6] = translation_vector
331
 
332
+ init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt
333
+ pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt)
334
+
335
  pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
336
  pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
337
+ pose_mat_smooth = np.array(pose_mat_smooth)
338
 
339
+ # face retarget
340
  verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
341
+ # project 3D mesh to 2D landmark
342
  projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
343
 
344
  pose_list = []
345
  for i, verts in enumerate(projected_vertices):
346
  lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
347
+ pose_image_np = cv2.resize(lmk_img, (width, height))
348
  pose_list.append(pose_image_np)
349
 
350
  pose_list = np.array(pose_list)
 
363
  generator=generator,
364
  ).videos
365
 
366
+ video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)
367
 
368
  save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
369
  save_videos_grid(
 
373
  fps=src_fps,
374
  )
375
 
376
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
377
+ # save_pil_imgs(video, save_path)
378
+
379
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(src_fps))
380
+
381
  audio_output = f'{save_dir}/audio_from_video.aac'
382
+ # extract audio
383
  try:
384
  ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
385
+ # merge audio and video
386
  stream = ffmpeg.input(save_path)
387
  audio = ffmpeg.input(audio_output)
388
  ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
 
390
  os.remove(save_path)
391
  os.remove(audio_output)
392
  except:
393
+ shutil.move(
394
+ save_path,
395
+ save_path.replace('_noaudio.mp4', '.mp4')
396
+ )
397
 
398
  return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
399
 
400
+
401
  ################# GUI ################
402
 
403
  title = r"""
 
423
  with gr.Column():
424
  with gr.Row():
425
  a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
426
+ a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
427
+ a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", sources="upload")
428
 
429
  with gr.Row():
430
  a2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
 
434
  a2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
435
  a2v_seed = gr.Number(value=42, label="Seed (--seed)")
436
 
437
+ a2v_botton = gr.Button("Generate", variant="primary")
438
+ a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
439
 
440
  gr.Examples(
441
  examples=[
442
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
443
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
444
  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
445
+ ],
446
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
447
  )
448
 
 
451
  with gr.Row():
452
  with gr.Column():
453
  with gr.Row():
454
+ v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
455
+ v2v_source_video = gr.Video(label="Upload source video", sources="upload")
456
 
457
  with gr.Row():
458
  v2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
 
462
  v2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
463
  v2v_seed = gr.Number(value=42, label="Seed (--seed)")
464
 
465
+ v2v_botton = gr.Button("Generate", variant="primary")
466
+ v2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
467
 
468
  gr.Examples(
469
  examples=[
470
  ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
471
  ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
472
  ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
473
+ ],
474
  inputs=[v2v_ref_img, v2v_source_video, a2v_headpose_video],
475
  )
476
 
477
+ a2v_botton.click(
478
  fn=audio2video,
479
  inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
480
  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
481
  outputs=[a2v_output_video, a2v_ref_img]
482
  )
483
+ v2v_botton.click(
484
  fn=video2video,
485
  inputs=[v2v_ref_img, v2v_source_video,
486
  v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
487
  outputs=[v2v_output_video, v2v_ref_img]
488
  )
489
 
490
+ demo.launch()
491
+