# Hosting-page residue (not part of the program): "Suprath's picture" / "Upload 54 files" / "9f4b9c7 verified"
import argparse
import platform
import shlex
import subprocess
import time
from pathlib import Path
from typing import Dict, Iterator, List, Literal, Optional, Union

import cv2
import numpy as np

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl
from nota_wav2lip.util import FFMPEG_LOGGING_MODE
from nota_wav2lip.video import AudioSlicer, VideoSlicer
class Wav2LipModelComparisonDemo:
    """Run one or more configured inference models on registered inputs.

    The demo keeps three name-keyed registries: audio inputs
    (``audio_dict``), video inputs (``video_dict``) and loaded models
    (``model_zoo``). ``save_as_video`` combines one entry from each and
    writes the generated frames, muxed with the audio, into ``result_dir``.
    """

    def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]] = None):
        """Load the requested models and prepare the output directory.

        Args:
            device: Forwarded unchanged to ``Wav2LipInferenceImpl``.
            result_dir: Directory that receives the generated videos. It is
                created, together with any missing parents, if absent.
            model_list: A single model name, a collection of model names, or
                ``None`` to load both ``'wav2lip'`` and ``'nota_wav2lip'``.
                Every name must be a key of ``hp.inference.model``.

        Raises:
            AssertionError: If a requested name is not in
                ``hp.inference.model``.
        """
        if model_list is None:
            model_names: List[str] = ['wav2lip', 'nota_wav2lip']
        elif isinstance(model_list, str):
            # A bare string is one model name, not a sequence of characters.
            # An empty string selects no model at all, which is what
            # iterating over it would have produced.
            model_names = [model_list] if model_list else []
        else:
            model_names = model_list
        super().__init__()

        self.video_dict: Dict[str, VideoSlicer] = {}
        self.audio_dict: Dict[str, AudioSlicer] = {}

        self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
        for model_name in model_names:
            assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}"
            self.model_zoo[model_name] = Wav2LipInferenceImpl(
                model_name, hp_inference_model=hp.inference.model[model_name], device=device
            )

        # Snapshot of each loaded model's ``params`` attribute, keyed by name.
        self._params_zoo: Dict[str, str] = {
            model_name: model.params for model_name, model in self.model_zoo.items()
        }

        self.result_dir: Path = Path(result_dir)
        # parents=True so that a nested result_dir under a not-yet-existing
        # parent is created instead of raising FileNotFoundError.
        self.result_dir.mkdir(parents=True, exist_ok=True)

    @property
    def params(self) -> Dict[str, str]:
        """Mapping of model name to the ``params`` value of that model."""
        return self._params_zoo

    def _infer(
        self,
        audio_name: str,
        video_name: str,
        model_type: Literal['wav2lip', 'nota_wav2lip']
    ) -> Iterator[np.ndarray]:
        """Run the selected model on a registered audio/video pair.

        Args:
            audio_name: Key into ``audio_dict``.
            video_name: Key into ``video_dict``.
            model_type: Key into ``model_zoo``.

        Returns:
            Whatever the model's ``inference_with_iterator`` returns; the
            callers in this class iterate over it and treat each item as a
            frame.

        Raises:
            KeyError: If any of the three names is not registered.
        """
        audio_iterable: AudioSlicer = self.audio_dict[audio_name]
        video_iterable: VideoSlicer = self.video_dict[video_name]
        target_model = self.model_zoo[model_type]
        return target_model.inference_with_iterator(audio_iterable, video_iterable)

    def update_audio(self, audio_path, name=None):
        """Register an audio input, replacing any entry with the same name.

        Args:
            audio_path: Path handed to ``AudioSlicer``.
            name: Registry key; defaults to the stem of ``audio_path``.
        """
        _name = name if name is not None else Path(audio_path).stem
        self.audio_dict[_name] = AudioSlicer(audio_path)

    def update_video(self, frame_dir_path, bbox_path, name=None):
        """Register a video input, replacing any entry with the same name.

        Args:
            frame_dir_path: First argument handed to ``VideoSlicer``.
            bbox_path: Second argument handed to ``VideoSlicer``.
            name: Registry key; defaults to the stem of ``frame_dir_path``.
        """
        _name = name if name is not None else Path(frame_dir_path).stem
        self.video_dict[_name] = VideoSlicer(frame_dir_path, bbox_path)

    def save_as_video(self, audio_name, video_name, model_type):
        """Generate frames, write them to disk and mux them with the audio.

        The frames are first written to ``generated.mp4`` inside
        ``result_dir``; ffmpeg then combines that file with the source audio
        into ``generated_with_audio.mp4``. Both file names are fixed, so
        every call overwrites the previous result.

        Args:
            audio_name: Key of a registered audio input.
            video_name: Key of a registered video input.
            model_type: Key of a loaded model.

        Returns:
            A tuple ``(output_video_path, inference_time, inference_fps)``:
            the path of the muxed video, the seconds spent producing and
            writing the frames, and the frame count divided by that time
            (``0.0`` if no time could be measured).

        Raises:
            KeyError: If any of the three names is not registered.
        """
        output_video_path = self.result_dir / 'generated_with_audio.mp4'
        frame_only_video_path = self.result_dir / 'generated.mp4'
        audio_slicer = self.audio_dict[audio_name]
        audio_path = audio_slicer.audio_path

        out = cv2.VideoWriter(str(frame_only_video_path),
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              hp.face.video_fps,
                              (hp.inference.frame.w, hp.inference.frame.h))
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        try:
            for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
                out.write(frame)
            inference_time = time.perf_counter() - start
        finally:
            # Release the writer even when inference fails part-way, so the
            # file handle is not leaked.
            out.release()

        # Pass ffmpeg an argument list and run it without a shell: paths
        # containing spaces or shell metacharacters are then handled
        # verbatim on every platform.
        # NOTE(review): FFMPEG_LOGGING_MODE['ERROR'] is assumed to be a
        # whitespace-separated string of ffmpeg flags - confirm.
        command = [
            'ffmpeg',
            *shlex.split(FFMPEG_LOGGING_MODE['ERROR']),
            '-y',
            '-i', str(audio_path),
            '-i', str(frame_only_video_path),
            '-strict', '-2',
            '-q:v', '1',
            str(output_video_path),
        ]
        # The exit status is intentionally not checked, as before.
        subprocess.call(command)

        # len() of the audio slicer is used as the number of generated frames.
        video_frames_num = len(audio_slicer)
        inference_fps = video_frames_num / inference_time if inference_time > 0 else 0.0

        return output_video_path, inference_time, inference_fps