|
"""run bash scripts/download_models.sh first to prepare the weights file""" |
|
import os |
|
import shutil |
|
from argparse import Namespace |
|
from src.utils.preprocess import CropAndExtract |
|
from src.test_audio2coeff import Audio2Coeff |
|
from src.facerender.animate import AnimateFromCoeff |
|
from src.generate_batch import get_data |
|
from src.generate_facerender_batch import get_facerender_data |
|
from src.utils.init_path import init_path |
|
from cog import BasePredictor, Input, Path |
|
|
|
checkpoints = "checkpoints" |
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor wrapping the SadTalker talking-head pipeline.

    Pipeline: crop + 3DMM coefficient extraction from the source image,
    audio-to-coefficient generation, then face rendering (optionally with a
    face enhancer) into an mp4.
    """

    def setup(self):
        """Load the models into memory to make running multiple predictions efficient"""
        device = "cuda"

        # Resolve checkpoint / config paths for every sub-model.
        sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))

        # Crops the face region and extracts 3DMM coefficients.
        self.preprocess_model = CropAndExtract(sadtalker_paths, device)

        # Maps driving audio to expression/pose coefficients.
        self.audio_to_coeff = Audio2Coeff(sadtalker_paths, device)

        # One renderer per preprocess mode. NOTE(review): both are constructed
        # with identical arguments here; two instances are kept to preserve
        # the original behavior (they may hold independent internal state).
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(sadtalker_paths, device),
            "others": AnimateFromCoeff(sadtalker_paths, device),
        }

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body animation when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Returns the path to the generated (enhanced) mp4.

        Raises:
            ValueError: if 3DMM coefficients cannot be extracted from the input.
            RuntimeError: if the renderer produced no enhanced mp4.
        """
        device = "cuda"

        # Pick the renderer matching the preprocess mode.
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        # Merge the user-facing inputs into the default argument set.
        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # Start each prediction from a clean results directory.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            # Raise rather than returning None: predict() declares `-> Path`,
            # and a silent None would hide the failure from the caller.
            raise ValueError("Can't get the coeffs of the input")

        # Optional reference video supplying eye-blink coefficients.
        if ref_eyeblink is not None:
            ref_eyeblink_videoname = os.path.splitext(
                os.path.split(ref_eyeblink)[-1]
            )[0]
            ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
            os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
            print("3DMM Extraction for the reference video providing eye blinking")
            ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
                ref_eyeblink, ref_eyeblink_frame_dir
            )
        else:
            ref_eyeblink_coeff_path = None

        # Optional reference video supplying pose coefficients; reuse the
        # eye-blink extraction when both references are the same file.
        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
                ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
                os.makedirs(ref_pose_frame_dir, exist_ok=True)
                print("3DMM Extraction for the reference video providing pose")
                ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
                    ref_pose, ref_pose_frame_dir
                )
        else:
            ref_pose_coeff_path = None

        # Audio -> motion coefficients.
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # Coefficients -> rendered video.
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

        # Copy the enhanced output to a stable location for cog to return.
        output = "/tmp/out.mp4"
        enhanced_files = [
            f for f in os.listdir(results_dir) if "enhanced.mp4" in f
        ]
        if not enhanced_files:
            # Guard the previous bare `[0]` index, which raised an opaque
            # IndexError when the renderer produced no enhanced video.
            raise RuntimeError("No enhanced.mp4 output found in results directory")
        mp4_path = os.path.join(results_dir, enhanced_files[0])
        shutil.copy(mp4_path, output)

        return Path(output)
|
|
|
|
|
def load_default():
    """Build the default SadTalker inference arguments.

    Returns an ``argparse.Namespace`` carrying the same defaults the CLI
    would provide: rendering knobs (pose style, batch size, expression
    scale, head-rotation sequences) and 3DMM / camera reconstruction
    settings.
    """
    defaults = {
        # Rendering / animation knobs.
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        "face3dvis": False,
        # 3DMM reconstruction network settings.
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./src/config/",
        "bfm_model": "BFM_model_front.mat",
        # Camera parameters for the face model.
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**defaults)
|
|