"""
|
|
This script contains the main inference pipeline for processing audio and image inputs to generate a video output.
|
|
|
|
The script imports necessary packages and classes, defines a neural network model,
|
|
and contains functions for processing audio embeddings and performing inference.
|
|
|
|
The main inference process is outlined in the following steps:
|
|
1. Initialize the configuration.
|
|
2. Set up runtime variables.
|
|
3. Prepare the input data for inference (source image, face mask, and face embeddings).
|
|
4. Process the audio embeddings.
|
|
5. Build and freeze the model and scheduler.
|
|
6. Run the inference loop and save the result.
|
|
|
|
Usage:
|
|
This script can be run from the command line with the following arguments:
|
|
- audio_path: Path to the audio file.
|
|
- image_path: Path to the source image.
|
|
- face_mask_path: Path to the face mask image.
|
|
- face_emb_path: Path to the face embeddings file.
|
|
- output_path: Path to save the output video.
|
|
|
|
Example:
|
|
python scripts/inference.py --audio_path audio.wav --image_path image.jpg
|
|
--face_mask_path face_mask.png --face_emb_path face_emb.pt --output_path output.mp4
|
|
"""

import argparse
import os

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from torch import nn

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.config import filter_non_none
from hallo.utils.util import tensor_to_video


class Net(nn.Module):
    """
    The Net class combines all the modules required for the inference process.

    Args:
        reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as the reference network during inference.
        denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used to denoise the video latents.
        face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
        imageproj (nn.Module): The ImageProjModel used to project the source image face embedding into conditioning tokens.
        audioproj (nn.Module): The AudioProjModel used to project the audio embeddings into conditioning tokens.
    """
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(self):
        """
        Empty forward pass; this class only bundles sub-modules, so the method
        exists solely to satisfy the nn.Module interface.
        """

    def get_modules(self):
        """
        Simple accessor for the sub-modules; also avoids the pylint
        too-few-public-methods warning.
        """
        return {
            "reference_unet": self.reference_unet,
            "denoising_unet": self.denoising_unet,
            "face_locator": self.face_locator,
            "imageproj": self.imageproj,
            "audioproj": self.audioproj,
        }


def process_audio_emb(audio_emb):
    """
    Expand the audio embeddings with a sliding window of neighbouring frames so
    that each frame carries local temporal context.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        torch.Tensor: The windowed audio embeddings, with an extra dimension of
        size 5 holding the two previous, current, and two next frame embeddings.
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb


def inference_process(args: argparse.Namespace):
    """
    Perform the inference process.

    Args:
        args (argparse.Namespace): Command-line arguments.

    This function loads and merges the configuration, prepares the source image and
    driving audio inputs, builds and freezes the models and the scheduler, runs the
    diffusion inference loop clip by clip, and saves the resulting video.

    Returns:
        str: Path to the saved output video.
    """
    cli_args = filter_non_none(vars(args))
    config = OmegaConf.load(args.config)
    config = OmegaConf.merge(config, cli_args)
    source_image_path = config.source_image
    driving_audio_path = config.driving_audio
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
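
    # 2. runtime variables: pick the device and the weight dtype used for inference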
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        weight_dtype = torch.float32
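
    # 3. prepare the source image inputs: pixels, face region, face embedding, and masks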
    img_size = (config.data.source_image.width,
                config.data.source_image.height)
    clip_length = config.data.n_sample_frames
    face_analysis_model_path = config.face_analysis.model_path
    with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
        source_image_pixels, \
            source_image_face_region, \
            source_image_face_emb, \
            source_image_full_mask, \
            source_image_face_mask, \
            source_image_lip_mask = image_processor.preprocess(
                source_image_path, save_path, config.face_expand_ratio)
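
    # 4. process the driving audio into per-frame embeddings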
    sample_rate = config.data.driving_audio.sample_rate
    assert sample_rate == 16000, "audio sample rate must be 16000"
    fps = config.data.export_video.fps
    wav2vec_model_path = config.wav2vec.model_path
    wav2vec_only_last_features = config.wav2vec.features == "last"
    audio_separator_model_file = config.audio_separator.model_path
    with AudioProcessor(
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_only_last_features,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(save_path, "audio_preprocess")
    ) as audio_processor:
        audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
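
    # 5. build the scheduler and the model components, load the trained weights, and freeze them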
    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    if config.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})

    vae = AutoencoderKL.from_pretrained(config.vae.model_path)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.base_model_path, subfolder="unet")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            config.unet_additional_kwargs),
        use_landmark=False,
    )
    face_locator = FaceLocator(conditioning_embedding_channels=320)
    image_proj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )

    audio_proj = AudioProjModel(
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    audio_ckpt_dir = config.audio_ckpt_dir
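
    # freeze all modules (inference only) and enable gradient checkpointing on both UNets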
    vae.requires_grad_(False)
    image_proj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audio_proj.requires_grad_(False)

    reference_unet.enable_gradient_checkpointing()
    denoising_unet.enable_gradient_checkpointing()

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        audio_proj,
    )

    m, u = net.load_state_dict(
        torch.load(
            os.path.join(audio_ckpt_dir, "net.pth"),
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
    print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=net.reference_unet,
        denoising_unet=net.denoising_unet,
        face_locator=net.face_locator,
        scheduler=val_noise_scheduler,
        image_proj=net.imageproj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)
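
    # repeat each mask clip_length times along the first dimension (one copy per frame of a clip)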
    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]
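
    # 6. run the inference loop clip by clip; each clip reuses the tail of the previous
    # clip as motion frames, then the clips are concatenated and saved as a video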
    times = audio_emb.shape[0] // clip_length

    tensor_result = []

    generator = torch.manual_seed(42)

    for t in range(times):
        print(f"[{t+1}/{times}]")

        if len(tensor_result) == 0:
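            # first clip: there is no previous output yet, so the motion frames are
            # padded with copies of the source image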
            motion_zeros = source_image_pixels.repeat(
                config.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)
        else:
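            # later clips: reuse the trailing n_motion_frames frames of the previous
            # output, rescaled with x * 2.0 - 1.0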
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[-config.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
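
        # slice this clip's windowed audio embeddings and project them with AudioProjModel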
        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
        audio_tensor = net.audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )

        tensor_result.append(pipeline_output.videos)
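
    # concatenate the generated clips along the time axis and trim to the true audio length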
    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]

    output_file = config.output

    tensor_to_video(tensor_result, output_file, driving_audio_path)
    return output_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-c", "--config", default="configs/inference/default.yaml")
    parser.add_argument("--source_image", type=str, required=False,
                        help="path to the source image")
    parser.add_argument("--driving_audio", type=str, required=False,
                        help="path to the driving audio")
    parser.add_argument(
        "--output", type=str, help="output video file name", default=".cache/output.mp4")
    parser.add_argument(
        "--pose_weight", type=float, help="weight of pose", required=False)
    parser.add_argument(
        "--face_weight", type=float, help="weight of face", required=False)
    parser.add_argument(
        "--lip_weight", type=float, help="weight of lip", required=False)
    parser.add_argument(
        "--face_expand_ratio", type=float, help="face region expand ratio", required=False)
    parser.add_argument(
        "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False)

    command_line_args = parser.parse_args()

    inference_process(command_line_args)