kimjy0411 committed
Commit 5801cb3 · verified · 1 Parent(s): 1ef2055

Update app.py

Files changed (1)
  1. app.py +343 -69
app.py CHANGED
@@ -1,36 +1,186 @@
- import gradio as gr
  import os
  import shutil
- import ffmpeg
- from datetime import datetime
- from pathlib import Path
  import numpy as np
  import cv2
  import torch
- import spaces
-
  from diffusers import AutoencoderKL, DDIMScheduler
  from einops import repeat
  from omegaconf import OmegaConf
- from PIL import Image
  from torchvision import transforms
  from transformers import CLIPVisionModelWithProjection
-
  from src.models.pose_guider import PoseGuider
  from src.models.unet_2d_condition import UNet2DConditionModel
  from src.models.unet_3d import UNet3DConditionModel
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
  from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
-
  from src.audio_models.model import Audio2MeshModel
  from src.utils.audio_util import prepare_audio_feature
- from src.utils.mp_utils import LMKExtractor
  from src.utils.draw_util import FaceMeshVisualizer
  from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
  from src.utils.crop_face_single import crop_face
  from src.audio2vid import get_headpose_temp, smooth_pose_seq
  from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool

  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
  if config.weight_dtype == "fp16":
      weight_dtype = torch.float16
@@ -38,49 +188,28 @@ else:
      weight_dtype = torch.float32

  audio_infer_config = OmegaConf.load(config.audio_inference_config)
- # prepare model
  a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
  a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
  a2m_model.cuda().eval()

- vae = AutoencoderKL.from_pretrained(
-     config.pretrained_vae_path,
- ).to("cuda", dtype=weight_dtype)

- reference_unet = UNet2DConditionModel.from_pretrained(
-     config.pretrained_base_model_path,
-     subfolder="unet",
- ).to(dtype=weight_dtype, device="cuda")

  inference_config_path = config.inference_config
  infer_config = OmegaConf.load(inference_config_path)
- denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-     config.pretrained_base_model_path,
-     config.motion_module_path,
-     subfolder="unet",
-     unet_additional_kwargs=infer_config.unet_additional_kwargs,
- ).to(dtype=weight_dtype, device="cuda")

- pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention

- image_enc = CLIPVisionModelWithProjection.from_pretrained(
-     config.image_encoder_path
- ).to(dtype=weight_dtype, device="cuda")

  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
  scheduler = DDIMScheduler(**sched_kwargs)

- # load pretrained weights
- denoising_unet.load_state_dict(
-     torch.load(config.denoising_unet_path, map_location="cpu"),
-     strict=False,
- )
- reference_unet.load_state_dict(
-     torch.load(config.reference_unet_path, map_location="cpu"),
- )
- pose_guider.load_state_dict(
-     torch.load(config.pose_guider_path, map_location="cpu"),
- )

  pipe = Pose2VideoPipeline(
      vae=vae,
@@ -92,13 +221,10 @@ pipe = Pose2VideoPipeline(
  )
  pipe = pipe.to("cuda", dtype=weight_dtype)

- # lmk_extractor = LMKExtractor()
- # vis = FaceMeshVisualizer()
-
  frame_inter_model = init_frame_interpolation_model()

  @spaces.GPU
- def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=60, seed=42):
      fps = 30
      cfg = 3.5
      fi_step = 3
@@ -138,7 +264,6 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
      sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
      sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)

-     # inference
      pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
      pred = pred.squeeze().detach().cpu().numpy()
      pred = pred.reshape(pred.shape[0], -1, 3)
@@ -151,7 +276,6 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
      mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
      cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]

-     # project 3D mesh to 2D landmark
      projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])

      pose_images = []
@@ -160,10 +284,10 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
          pose_images.append(lmk_img)

      pose_list = []
-     args_L = len(pose_images) if length==0 or length > len(pose_images) else length
      args_L = min(args_L, 90)
-     for pose_image_np in pose_images[: args_L : fi_step]:
-         pose_image_np = cv2.resize(pose_image_np, (width, height))
          pose_list.append(pose_image_np)

      pose_list = np.array(pose_list)
@@ -182,7 +306,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l
          generator=generator,
      ).videos

-     video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)

      save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
      save_videos_grid(
@@ -199,6 +323,131 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, l

      return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil


  ################# GUI ################

@@ -225,8 +474,8 @@ with gr.Blocks() as demo:
              with gr.Column():
                  with gr.Row():
                      a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
-                     a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
-                     a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", sources="upload")

                  with gr.Row():
                      a2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
@@ -236,48 +485,73 @@ with gr.Blocks() as demo:
                      a2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
                      a2v_seed = gr.Number(value=42, label="Seed (--seed)")

-                 a2v_botton = gr.Button("Generate", variant="primary")
-             a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)

          gr.Examples(
              examples=[
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
-             ],
              inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
          )


-     with gr.Tab("TTS"):
          with gr.Row():
              with gr.Column():
                  with gr.Row():
-                     tts_text_input = gr.Textbox(lines=5, label="Input text", placeholder="Enter text to synthesize...")
-                     tts_ref_img = gr.Image(label="Upload reference image", sources="upload")

                  with gr.Row():
-                     tts_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
-                     tts_step_slider = gr.Slider(minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)")

                  with gr.Row():
-                     tts_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
-                     tts_seed = gr.Number(value=42, label="Seed (--seed)")

-                 tts_button = gr.Button("Generate", variant="primary")
-             tts_output_video = gr.PlayableVideo(label="Result", interactive=False)

-     a2v_botton.click(
          fn=audio2video,
          inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
                  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
          outputs=[a2v_output_video, a2v_ref_img]
      )
      tts_button.click(
-         fn=audio2video,  # needs to be replaced with a proper TTS function later
-         inputs=[tts_text_input, tts_ref_img, None,
-                 tts_size_slider, tts_step_slider, tts_length, tts_seed],
-         outputs=[tts_output_video, tts_ref_img]
      )

- demo.launch()
 
 
  import os
  import shutil
  import numpy as np
+ import gradio as gr
+ import torchaudio
+ import soundfile as sf
+ from pathlib import Path
+ from datetime import datetime
+ from scipy.io.wavfile import write as write_wav
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
+ from encodec.utils import convert_audio
+
+ from src.bark.history_to_hash import history_to_hash
+ from src.bark.npz_tools import save_npz
+ from src.bark.FullGeneration import FullGeneration
+ from src.utils.date import get_date_string
+ from src.bark.get_audio_from_npz import get_audio_from_full_generation
+ from bark_hubert_quantizer.hubert_manager import HuBERTManager
+ from bark_hubert_quantizer.pre_kmeans_hubert import CustomHubert
+ from bark_hubert_quantizer.customtokenizer import CustomTokenizer
+ from bark import SAMPLE_RATE
+ from bark.generation import SUPPORTED_LANGS, generate_text_semantic, generate_coarse, generate_fine, codec_decode
+
+ import ffmpeg
  import cv2
  import torch
+ from PIL import Image
  from diffusers import AutoencoderKL, DDIMScheduler
  from einops import repeat
  from omegaconf import OmegaConf
  from torchvision import transforms
  from transformers import CLIPVisionModelWithProjection
  from src.models.pose_guider import PoseGuider
  from src.models.unet_2d_condition import UNet2DConditionModel
  from src.models.unet_3d import UNet3DConditionModel
  from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
  from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
  from src.audio_models.model import Audio2MeshModel
  from src.utils.audio_util import prepare_audio_feature
+ from src.utils.mp_utils import LMKExtractor
  from src.utils.draw_util import FaceMeshVisualizer
  from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
  from src.utils.crop_face_single import crop_face
  from src.audio2vid import get_headpose_temp, smooth_pose_seq
  from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool

+ hubert_model = None
+
+ def _load_hubert_model(device):
+     hubert_path = HuBERTManager.make_sure_hubert_installed()
+     global hubert_model
+     if hubert_model is None:
+         hubert_model = CustomHubert(
+             checkpoint_path=hubert_path,
+             device=device,
+         )
+     return hubert_model
+
+ def _get_semantic_vectors(hubert_model: CustomHubert, path_to_wav: str, device):
+     wav, sr = torchaudio.load(path_to_wav)
+     if wav.shape[0] == 2:
+         wav = wav.mean(0, keepdim=True)
+     wav = wav.to(device)
+     return hubert_model.forward(wav, input_sample_hz=sr)
+
+ def get_semantic_vectors(path_to_wav: str, device):
+     hubert_model = _load_hubert_model(device)
+     return _get_semantic_vectors(hubert_model, path_to_wav, device)
+
+ tokenizer = None
+
+ def _load_tokenizer(
+     model: str = "quantifier_hubert_base_ls960_14.pth",
+     repo: str = "GitMylo/bark-voice-cloning",
+     force_reload: bool = False,
+     device="cpu",
+ ) -> CustomTokenizer:
+     tokenizer_path = HuBERTManager.make_sure_tokenizer_installed(
+         model=model,
+         repo=repo,
+         local_file=model,
+     )
+     global tokenizer
+     if tokenizer is None or force_reload:
+         tokenizer = CustomTokenizer.load_from_checkpoint(
+             tokenizer_path,
+             map_location=device,
+         )
+         tokenizer.load_state_dict(torch.load(tokenizer_path, map_location=device))
+     return tokenizer
+
+ def get_semantic_tokens(semantic_vectors: torch.Tensor, device):
+     tokenizer = _load_tokenizer(device=device)
+     return tokenizer.get_token(semantic_vectors)
+
+ def get_semantic_prompt(path_to_wav: str, device):
+     semantic_vectors = get_semantic_vectors(path_to_wav, device)
+     return get_semantic_tokens(semantic_vectors, device).cpu().numpy()
+
+ def get_prompts(path_to_wav: str, use_gpu: bool):
+     device = "cuda" if use_gpu else "cpu"
+     semantic_prompt = get_semantic_prompt(path_to_wav, device)
+     fine_prompt, coarse_prompt = get_encodec_prompts(path_to_wav, use_gpu)
+     return FullGeneration(
+         semantic_prompt=semantic_prompt,
+         coarse_prompt=coarse_prompt,
+         fine_prompt=fine_prompt,
+     )
+
+ def get_encodec_prompts(path_to_wav: str, use_gpu=True):
+     device = "cuda" if use_gpu else "cpu"
+     model = load_codec_model(use_gpu=use_gpu)
+     wav, sr = torchaudio.load(path_to_wav)
+     wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+     wav = wav.unsqueeze(0).to(device)
+     model.to(device)
+
+     with torch.no_grad():
+         encoded_frames = model.encode(wav)
+
+     fine_prompt = (
+         torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
+         .squeeze()
+         .cpu()
+         .numpy()
+     )
+     coarse_prompt = fine_prompt[:2, :]
+     return fine_prompt, coarse_prompt
+
+ def save_cloned_voice(full_generation: FullGeneration):
+     voice_name = f"voice_from_audio_{history_to_hash(full_generation)}"
+     filename = f"voices/{voice_name}.npz"
+     date = get_date_string()
+     metadata = generate_cloned_voice_metadata(full_generation, date)
+     save_npz(filename, full_generation, metadata)
+     return filename
+
+ def generate_cloned_voice_metadata(full_generation, date):
+     return {
+         "_version": "0.0.1",
+         "_hash_version": "0.0.2",
+         "_type": "bark",
+         "hash": history_to_hash(full_generation),
+         "date": date,
+     }
+
+ def generate_voice(wav_file: str, use_gpu: bool):
+     full_generation = get_prompts(wav_file, use_gpu)
+     filename = save_cloned_voice(full_generation)
+     return filename, get_audio_from_full_generation(full_generation)
+
+ # Function for speech synthesis
+ def synthesize_speech(text, input_audio):
+     semantic_tokens = generate_text_semantic(text)
+     coarse_tokens = generate_coarse(semantic_tokens)
+     fine_tokens = generate_fine(coarse_tokens)
+     synthesized_audio = codec_decode(fine_tokens)
+     if isinstance(synthesized_audio, torch.Tensor):
+         synthesized_audio = synthesized_audio.squeeze().cpu().numpy()
+     else:
+         synthesized_audio = synthesized_audio.squeeze()
+
+     # Get the length of the input audio
+     input_wav, input_sr = torchaudio.load(input_audio)
+     input_length = input_wav.shape[1] / input_sr
+
+     # Match the output audio to the length of the input audio
+     output_length = synthesized_audio.shape[0] / SAMPLE_RATE
+     if output_length > input_length:
+         synthesized_audio = synthesized_audio[:int(input_length * SAMPLE_RATE)]
+     else:
+         padding = int((input_length - output_length) * SAMPLE_RATE)
+         synthesized_audio = np.pad(synthesized_audio, (0, padding), 'constant')
+
+     sf.write("synthesized_audio.wav", synthesized_audio, SAMPLE_RATE)
+     return "synthesized_audio.wav"
+
+ # TTS wrapper function
+ def tts_function(input_audio, input_text):
+     synthesized_audio_path = synthesize_speech(input_text, input_audio)
+     return synthesized_audio_path
+
+ # aniportrait function definitions
  config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
  if config.weight_dtype == "fp16":
      weight_dtype = torch.float16
  else:
      weight_dtype = torch.float32

  audio_infer_config = OmegaConf.load(config.audio_inference_config)
  a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
  a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
  a2m_model.cuda().eval()

+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to("cuda", dtype=weight_dtype)

+ reference_unet = UNet2DConditionModel.from_pretrained(config.pretrained_base_model_path, subfolder="unet").to(dtype=weight_dtype, device="cuda")

  inference_config_path = config.inference_config
  infer_config = OmegaConf.load(inference_config_path)
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_base_model_path, config.motion_module_path, subfolder="unet", unet_additional_kwargs=infer_config.unet_additional_kwargs).to(dtype=weight_dtype, device="cuda")

+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)

+ image_enc = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")

  sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
  scheduler = DDIMScheduler(**sched_kwargs)

+ denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
+ reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
+ pose_guider.load_state_dict(torch.load(config.pose_guider_path, map_location="cpu"))

  pipe = Pose2VideoPipeline(
      vae=vae,

  )
  pipe = pipe.to("cuda", dtype=weight_dtype)

  frame_inter_model = init_frame_interpolation_model()

  @spaces.GPU
+ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=60, seed=42):
      fps = 30
      cfg = 3.5
      fi_step = 3

      sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
      sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)

      pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
      pred = pred.squeeze().detach().cpu().numpy()
      pred = pred.reshape(pred.shape[0], -1, 3)

      mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
      cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]

      projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])

      pose_images = []

          pose_images.append(lmk_img)

      pose_list = []
+     args_L = len(pose_images) if length == 0 or length > len(pose_images) else length
      args_L = min(args_L, 90)
+     for pose_image_np in pose_images[:args_L:fi_step]:
+         pose_image_np = cv2.resize(pose_image_np, (width, height))
          pose_list.append(pose_image_np)

      pose_list = np.array(pose_list)

          generator=generator,
      ).videos

+     video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step - 1)

      save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
      save_videos_grid(


      return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil

+ @spaces.GPU
+ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
+     cfg = 3.5
+     fi_step = 3
+
+     generator = torch.manual_seed(seed)
+
+     lmk_extractor = LMKExtractor()
+     vis = FaceMeshVisualizer()
+
+     width, height = size, size
+
+     date_str = datetime.now().strftime("%Y%m%d")
+     time_str = datetime.now().strftime("%H%M")
+     save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
+
+     save_dir = Path(f"v2v_output/{date_str}/{save_dir_name}")
+     while os.path.exists(save_dir):
+         save_dir = Path(f"v2v_output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}")
+     save_dir.mkdir(exist_ok=True, parents=True)
+
+     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
+     ref_image_np = crop_face(ref_image_np, lmk_extractor)
+     if ref_image_np is None:
+         return None, Image.fromarray(ref_img)
+
+     ref_image_np = cv2.resize(ref_image_np, (size, size))
+     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
+
+     face_result = lmk_extractor(ref_image_np)
+     if face_result is None:
+         return None, ref_image_pil
+
+     lmks = face_result['lmks'].astype(np.float32)
+     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
+
+     source_images = read_frames(source_video)
+     src_fps = get_fps(source_video)
+     pose_transform = transforms.Compose([transforms.Resize((height, width)), transforms.ToTensor()])
+
+     step = 1
+     if src_fps == 60:
+         src_fps = 30
+         step = 2
+
+     pose_trans_list = []
+     verts_list = []
+     bs_list = []
+     args_L = len(source_images) if length == 0 or length * step > len(source_images) else length * step
+     args_L = min(args_L, 90 * step)
+     for src_image_pil in source_images[:args_L:step*fi_step]:
+         src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
+         frame_height, frame_width, _ = src_img_np.shape
+         src_img_result = lmk_extractor(src_img_np)
+         if src_img_result is None:
+             break
+         pose_trans_list.append(src_img_result['trans_mat'])
+         verts_list.append(src_img_result['lmks3d'])
+         bs_list.append(src_img_result['bs'])
+
+     trans_mat_arr = np.array(pose_trans_list)
+     verts_arr = np.array(verts_list)
+     bs_arr = np.array(bs_list)
+     min_bs_idx = np.argmin(bs_arr.sum(1))
+
+     pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
+     for i in range(pose_arr.shape[0]):
+         euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i])
+         pose_arr[i, :3] = euler_angles
+         pose_arr[i, 3:6] = translation_vector
+
+     init_tran_vec = face_result['trans_mat'][:3, 3]
+     pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec
+     pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
+     pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
+     pose_mat_smooth = np.array(pose_mat_smooth)
+
+     verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
+     projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
+
+     pose_list = []
+     for i, verts in enumerate(projected_vertices):
+         lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
+         pose_image_np = cv2.resize(lmk_img, (width, height))
+         pose_list.append(pose_image_np)
+
+     pose_list = np.array(pose_list)
+
+     video_length = len(pose_list)
+
+     video = pipe(
+         ref_image_pil,
+         pose_list,
+         ref_pose,
+         width,
+         height,
+         video_length,
+         steps,
+         cfg,
+         generator=generator,
+     ).videos
+
+     video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step - 1)
+
+     save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
+     save_videos_grid(
+         video,
+         save_path,
+         n_rows=1,
+         fps=src_fps,
+     )
+
+     audio_output = f'{save_dir}/audio_from_video.aac'
+     try:
+         ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
+         stream = ffmpeg.input(save_path)
+         audio = ffmpeg.input(audio_output)
+         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
+
+         os.remove(save_path)
+         os.remove(audio_output)
+     except:
+         shutil.move(save_path, save_path.replace('_noaudio.mp4', '.mp4'))
+
+     return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil

  ################# GUI ################

              with gr.Column():
                  with gr.Row():
                      a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
+                     a2v_ref_img = gr.Image(label="Upload reference image", source="upload")
+                     a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", source="upload")

                  with gr.Row():
                      a2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")

                      a2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
                      a2v_seed = gr.Number(value=42, label="Seed (--seed)")

+                 a2v_button = gr.Button("Generate", variant="primary")
+             a2v_output_video = gr.Video(label="Result", interactive=False)

          gr.Examples(
              examples=[
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
                  ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
+             ],
              inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
          )


+     with gr.Tab("Video2video"):
          with gr.Row():
              with gr.Column():
                  with gr.Row():
+                     v2v_ref_img = gr.Image(label="Upload reference image", source="upload")
+                     v2v_source_video = gr.Video(label="Upload source video", source="upload")

                  with gr.Row():
+                     v2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
+                     v2v_step_slider = gr.Slider(minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)")

                  with gr.Row():
+                     v2v_length = gr.Slider(minimum=0, maximum=90, step=1, value=30, label="Length (-L)")
+                     v2v_seed = gr.Number(value=42, label="Seed (--seed)")

+                 v2v_button = gr.Button("Generate", variant="primary")
+             v2v_output_video = gr.Video(label="Result", interactive=False)

+         gr.Examples(
+             examples=[
+                 ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
+                 ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
+                 ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
+             ],
+             inputs=[v2v_ref_img, v2v_source_video, a2v_headpose_video],
+         )
+
+     a2v_button.click(
          fn=audio2video,
          inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
                  a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
          outputs=[a2v_output_video, a2v_ref_img]
      )
+     v2v_button.click(
+         fn=video2video,
+         inputs=[v2v_ref_img, v2v_source_video,
+                 v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
+         outputs=[v2v_output_video, v2v_ref_img]
+     )
+
+     with gr.Tab("TTS"):
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     tts_input_audio = gr.Audio(type="filepath", label="Input audio for feature extraction")
+                     tts_text_input = gr.Textbox(lines=5, label="Input text", placeholder="Enter text to synthesize...")
+
+                 tts_button = gr.Button("Synthesize", variant="primary")
+             tts_output_audio = gr.Audio(label="Synthesized Audio", interactive=False)
+
      tts_button.click(
+         fn=tts_function,
+         inputs=[tts_input_audio, tts_text_input],
+         outputs=[tts_output_audio]
      )

+ demo.launch(debug=True)
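
Note on the new TTS path: synthesize_speech above drives Bark's low-level chain (text to semantic tokens, semantic to coarse codebooks, coarse to fine codebooks, EnCodec decode), and the uploaded reference audio is currently used only to trim or pad the result to the same duration. The sketch below is illustrative only and not part of commit 5801cb3; it assumes the suno-ai bark package is installed, and the optional history_prompt shows how the cloned-voice .npz written by save_cloned_voice could be fed into the same calls.

# Illustrative sketch (not part of the commit): the Bark chain behind synthesize_speech.
# Assumes the suno-ai `bark` package is installed and its checkpoints can be downloaded.
import soundfile as sf
from bark import SAMPLE_RATE
from bark.generation import (
    preload_models,
    generate_text_semantic,
    generate_coarse,
    generate_fine,
    codec_decode,
)

preload_models()  # fetch/load the Bark checkpoints once

text = "Hello from AniPortrait."
# Optional cloned-voice prompt: a .npz path such as the one returned by save_cloned_voice().
voice_npz = None

semantic = generate_text_semantic(text, history_prompt=voice_npz)  # text -> semantic tokens
coarse = generate_coarse(semantic, history_prompt=voice_npz)       # semantic -> coarse codebooks
fine = generate_fine(coarse, history_prompt=voice_npz)             # coarse -> fine codebooks
audio = codec_decode(fine)                                         # EnCodec decode to a float waveform

sf.write("bark_demo.wav", audio, SAMPLE_RATE)                      # Bark outputs 24 kHz audio

Passing the saved voice prompt this way would make the TTS tab speak with the cloned voice rather than Bark's default speaker; as committed, the voice-cloning helpers are defined but not yet wired into the UI.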