|
import os |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import torchaudio.functional |
|
import torchvision.io |
|
from PIL import Image |
|
from diffusers import AutoencoderKL, DDIMScheduler |
|
from diffusers.utils.import_utils import is_xformers_available |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from insightface.app import FaceAnalysis |
|
from omegaconf import OmegaConf |
|
from transformers import CLIPVisionModelWithProjection, Wav2Vec2Model, Wav2Vec2Processor |
|
|
|
from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection |
|
from pipelines import VExpressPipeline |
|
from pipelines.utils import draw_kps_image, save_video |
|
from pipelines.utils import retarget_kps |
|
|
|
|
|
def load_reference_net(unet_config_path, reference_net_path, dtype, device):
    """Instantiate the 2D reference UNet and populate it from a checkpoint.

    Args:
        unet_config_path: Config location handed to
            ``UNet2DConditionModel.from_config``.
        reference_net_path: Checkpoint file read with ``torch.load``.
        dtype: Torch dtype the module is cast to.
        device: Torch device the module is moved to.

    Returns:
        The ``UNet2DConditionModel`` with the checkpoint weights applied.
    """
    model = UNet2DConditionModel.from_config(unet_config_path)
    model = model.to(dtype=dtype, device=device)

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    weights = torch.load(reference_net_path, map_location="cpu")
    # strict=False: keys missing from / unexpected in the checkpoint do not raise.
    model.load_state_dict(weights, strict=False)

    print(f'Loaded weights of Reference Net from {reference_net_path}.')
    return model
|
|
|
|
|
def load_denoising_unet(
    unet_config_path,
    denoising_unet_path,
    motion_module_path,
    dtype,
    device,
    inference_config_path='./inference_v2.yaml',
):
    """Build the 3D denoising UNet and load its two checkpoints.

    The module is created from a 2D UNet config plus the
    ``unet_additional_kwargs`` section of the inference config, then the
    denoising weights and the motion-module weights are loaded into the same
    module one after the other.

    Args:
        unet_config_path: Config location handed to
            ``UNet3DConditionModel.from_config_2d``.
        denoising_unet_path: Checkpoint with the denoising UNet weights.
        motion_module_path: Checkpoint with the motion-module weights.
        dtype: Torch dtype the module is cast to.
        device: Torch device the module is moved to.
        inference_config_path: YAML file providing ``unet_additional_kwargs``.
            Defaults to ``'./inference_v2.yaml'``, which is resolved against
            the current working directory.

    Returns:
        The ``UNet3DConditionModel`` with both checkpoints applied.
    """
    inference_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_config_2d(
        unet_config_path,
        unet_additional_kwargs=inference_config.unet_additional_kwargs,
    ).to(dtype=dtype, device=device)

    # NOTE: torch.load unpickles the files; only use trusted checkpoints.
    # strict=False is required because each checkpoint is loaded on its own
    # and neither load is required to match every key of the module.
    denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')

    denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
    print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')

    return denoising_unet
|
|
|
|
|
def load_v_kps_guider(v_kps_guider_path, dtype, device):
    """Create the V-Kps guider and load its checkpoint.

    Args:
        v_kps_guider_path: Checkpoint file read with ``torch.load``.
        dtype: Torch dtype the module is cast to.
        device: Torch device the module is moved to.

    Returns:
        The ``VKpsGuider`` with the checkpoint weights applied.
    """
    guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256))
    guider = guider.to(dtype=dtype, device=device)

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    weights = torch.load(v_kps_guider_path, map_location="cpu")
    # Default strict loading: the checkpoint keys must match the module exactly.
    guider.load_state_dict(weights)

    print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
    return guider
|
|
|
|
|
def load_audio_projection(
    audio_projection_path,
    dtype,
    device,
    inp_dim: int,
    mid_dim: int,
    out_dim: int,
    inp_seq_len: int,
    out_seq_len: int,
):
    """Create the audio projection module and load its checkpoint.

    Args:
        audio_projection_path: Checkpoint file read with ``torch.load``.
        dtype: Torch dtype the module is cast to.
        device: Torch device the module is moved to.
        inp_dim: Forwarded to ``AudioProjection`` as ``embedding_dim``.
        mid_dim: Forwarded to ``AudioProjection`` as ``dim``.
        out_dim: Forwarded to ``AudioProjection`` as ``output_dim``.
        inp_seq_len: Forwarded to ``AudioProjection`` as ``max_seq_len``.
        out_seq_len: Forwarded to ``AudioProjection`` as ``num_queries``.

    Returns:
        The ``AudioProjection`` with the checkpoint weights applied.
    """
    # Caller-supplied sizes plus the fixed architecture settings.
    projection_kwargs = {
        'dim': mid_dim,
        'depth': 4,
        'dim_head': 64,
        'heads': 12,
        'num_queries': out_seq_len,
        'embedding_dim': inp_dim,
        'output_dim': out_dim,
        'ff_mult': 4,
        'max_seq_len': inp_seq_len,
    }
    projection = AudioProjection(**projection_kwargs)
    projection = projection.to(dtype=dtype, device=device)

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    weights = torch.load(audio_projection_path, map_location='cpu')
    projection.load_state_dict(weights)

    print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
    return projection
|
|
|
|
|
def get_scheduler(inference_config_path='./inference_v2.yaml'):
    """Create the DDIM noise scheduler described by the inference config.

    Args:
        inference_config_path: YAML file whose ``noise_scheduler_kwargs``
            section supplies the ``DDIMScheduler`` constructor arguments.
            Defaults to ``'./inference_v2.yaml'``, which is resolved against
            the current working directory.

    Returns:
        A ``DDIMScheduler`` built from those keyword arguments.
    """
    inference_config = OmegaConf.load(inference_config_path)
    # Convert the OmegaConf node to plain Python containers before unpacking.
    scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
    return DDIMScheduler(**scheduler_kwargs)
|
|
|
class InferenceEngine(object):
    """Driver that turns a reference image plus an audio track into a video.

    Construction loads every model, builds the ``VExpressPipeline`` and the
    insightface ``FaceAnalysis`` app. ``infer`` then runs one generation and
    writes the result to disk.

    All run parameters are supplied as a mapping and stored as attributes.
    The methods read: ``gpu_id``, ``dtype``, ``seed``, ``fps``,
    ``image_width``, ``image_height``, ``standard_audio_sampling_rate``,
    ``num_pad_audio_frames``, ``num_inference_steps``, ``guidance_scale``,
    ``context_frames``, ``context_stride``, ``context_overlap`` and the model
    locations ``vae_path``, ``audio_encoder_path``, ``unet_config_path``,
    ``reference_net_path``, ``denoising_unet_path``, ``motion_module_path``,
    ``v_kps_guider_path``, ``audio_projection_path``,
    ``insightface_model_path``.
    """

    def __init__(self, args):
        """Store the parameters and build all components.

        Args:
            args: Mapping of parameter name to value (see the class docstring
                for the names that are read).
        """
        self.init_params(args)
        self.load_models()
        self.set_generator()
        self.set_vexpress_pipeline()
        self.set_face_analysis_app()

    def init_params(self, args):
        """Copy every ``args`` entry onto the instance as an attribute.

        Args:
            args: Mapping with an ``items()`` method; it must contain at least
                ``image_width`` and ``image_height``, which are printed.
        """
        for key, value in args.items():
            setattr(self, key, value)

        print("Image width: ", self.image_width)
        print("Image height: ", self.image_height)

    def load_models(self):
        """Load all networks onto ``cuda:{gpu_id}`` in the requested precision.

        ``self.dtype`` is replaced by the resolved ``torch.dtype``
        (``'fp16'`` selects ``torch.float16``, anything else
        ``torch.float32``).

        Raises:
            ValueError: If xformers is not available.
        """
        self.device = torch.device(f'cuda:{self.gpu_id}')
        # Only convert when dtype is still the configuration value, so calling
        # this method again does not silently fall back to float32.
        if not isinstance(self.dtype, torch.dtype):
            self.dtype = torch.float16 if self.dtype == 'fp16' else torch.float32

        self.vae = AutoencoderKL.from_pretrained(self.vae_path).to(dtype=self.dtype, device=self.device)
        self.audio_encoder = Wav2Vec2Model.from_pretrained(self.audio_encoder_path).to(
            dtype=self.dtype, device=self.device
        )
        self.audio_processor = Wav2Vec2Processor.from_pretrained(self.audio_encoder_path)

        self.scheduler = get_scheduler()
        self.reference_net = load_reference_net(
            self.unet_config_path, self.reference_net_path, self.dtype, self.device
        )
        self.denoising_unet = load_denoising_unet(
            self.unet_config_path,
            self.denoising_unet_path,
            self.motion_module_path,
            self.dtype,
            self.device,
        )
        self.v_kps_guider = load_v_kps_guider(self.v_kps_guider_path, self.dtype, self.device)

        # The projection works entirely in the UNet's cross-attention width;
        # its sequence lengths follow from the audio padding window.
        cross_attention_dim = self.denoising_unet.config.cross_attention_dim
        audio_window = 2 * self.num_pad_audio_frames + 1
        self.audio_projection = load_audio_projection(
            self.audio_projection_path,
            self.dtype,
            self.device,
            inp_dim=cross_attention_dim,
            mid_dim=cross_attention_dim,
            out_dim=cross_attention_dim,
            inp_seq_len=2 * audio_window,
            out_seq_len=audio_window,
        )

        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        self.reference_net.enable_xformers_memory_efficient_attention()
        self.denoising_unet.enable_xformers_memory_efficient_attention()

    def set_generator(self):
        """Seed torch's global RNG with ``self.seed`` and keep the generator.

        ``torch.manual_seed`` returns the default generator, which is stored
        as ``self.generator`` and reused for latent sampling.
        """
        self.generator = torch.manual_seed(self.seed)

    def set_vexpress_pipeline(self):
        """Assemble ``self.pipeline`` from the loaded components."""
        self.pipeline = VExpressPipeline(
            vae=self.vae,
            reference_net=self.reference_net,
            denoising_unet=self.denoising_unet,
            v_kps_guider=self.v_kps_guider,
            audio_processor=self.audio_processor,
            audio_encoder=self.audio_encoder,
            audio_projection=self.audio_projection,
            scheduler=self.scheduler,
        ).to(dtype=self.dtype, device=self.device)

    def set_face_analysis_app(self):
        """Create and prepare the insightface ``FaceAnalysis`` app on the GPU."""
        self.app = FaceAnalysis(
            providers=['CUDAExecutionProvider'],
            provider_options=[{'device_id': self.gpu_id}],
            root=self.insightface_model_path,
        )
        # NOTE(review): det_size is passed as (height, width); confirm the
        # order insightface expects before relying on non-square sizes.
        self.app.prepare(ctx_id=0, det_size=(self.image_height, self.image_width))

    def get_reference_image_for_kps(self, reference_image_path):
        """Load the reference image and detect its face keypoints.

        Args:
            reference_image_path: Path of the reference image file.

        Returns:
            A tuple ``(reference_image, reference_image_for_kps,
            reference_kps)``: the resized RGB ``PIL.Image``, the resized array
            read by ``cv2.imread``, and the first three keypoints of the first
            detected face.

        Raises:
            IndexError: If no face is detected in the image.
        """
        # Both PIL's Image.resize and cv2.resize take the size as
        # (width, height).
        target_size = (self.image_width, self.image_height)

        reference_image = Image.open(reference_image_path).convert('RGB')
        reference_image = reference_image.resize(target_size)

        reference_image_for_kps = cv2.imread(reference_image_path)
        reference_image_for_kps = cv2.resize(reference_image_for_kps, target_size)

        faces = self.app.get(reference_image_for_kps)
        if len(faces) == 0:
            raise IndexError(f'No face was detected in {reference_image_path}.')
        reference_kps = faces[0].kps[:3]
        return reference_image, reference_image_for_kps, reference_kps

    def get_waveform_video_length(self, audio_path):
        """Read the audio track and derive the number of video frames.

        The waveform is resampled to ``self.standard_audio_sampling_rate``
        when needed and averaged over its first dimension.

        Args:
            audio_path: Path of a media file readable by
                ``torchvision.io.read_video``.

        Returns:
            A tuple ``(audio_waveform, video_length)`` where ``video_length``
            is ``int(duration_in_seconds * self.fps)``.
        """
        _, audio_waveform, meta_info = torchvision.io.read_video(audio_path, pts_unit='sec')
        audio_sampling_rate = meta_info['audio_fps']
        print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')

        if audio_sampling_rate != self.standard_audio_sampling_rate:
            audio_waveform = torchaudio.functional.resample(
                audio_waveform,
                orig_freq=audio_sampling_rate,
                new_freq=self.standard_audio_sampling_rate,
            )
        # Collapse the leading dimension (presumably channels - verify against
        # the read_video output layout) into a single waveform.
        audio_waveform = audio_waveform.mean(dim=0)

        duration = audio_waveform.shape[0] / self.standard_audio_sampling_rate
        video_length = int(duration * self.fps)
        print(f'The corresponding video length is {video_length}.')
        return audio_waveform, video_length

    def get_kps_sequence(self, kps_path, reference_kps, video_length, retarget_strategy):
        """Produce the per-frame keypoint sequence for the chosen strategy.

        Args:
            kps_path: Path of a saved keypoint sequence, or ``""`` for none.
            reference_kps: Keypoints of the reference image.
            video_length: Number of frames to produce.
            retarget_strategy: ``'fix_face'`` repeats ``reference_kps`` for
                every frame; ``'no_retarget'`` returns the loaded sequence;
                ``'offset_retarget'`` and ``'naive_retarget'`` pass it through
                ``retarget_kps`` with ``only_offset`` set to ``True`` and
                ``False`` respectively.

        Returns:
            The keypoint sequence for ``video_length`` frames.

        Raises:
            FileNotFoundError: If ``kps_path`` is given but does not exist.
            ValueError: If the strategy is unknown, or if it needs a loaded
                sequence while ``kps_path`` is empty.
        """
        kps_sequence = None
        if kps_path != "":
            if not os.path.exists(kps_path):
                raise FileNotFoundError(f'{kps_path} does not exist')
            kps_sequence = torch.tensor(torch.load(kps_path))
            print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
            # Interpolate along the frame axis: move it last, resample to
            # video_length, then move it back to the front.
            kps_sequence = torch.nn.functional.interpolate(
                kps_sequence.permute(1, 2, 0), size=video_length, mode='linear'
            )
            kps_sequence = kps_sequence.permute(2, 0, 1)
            print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')

        if retarget_strategy == 'fix_face':
            return torch.tensor([reference_kps] * video_length)

        if retarget_strategy not in ('no_retarget', 'offset_retarget', 'naive_retarget'):
            raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')

        # Every remaining strategy consumes the loaded sequence.
        if kps_sequence is None:
            raise ValueError(
                f'The retarget strategy {retarget_strategy} requires a kps sequence, '
                f'but kps_path is empty.'
            )

        if retarget_strategy == 'no_retarget':
            return kps_sequence
        return retarget_kps(
            reference_kps,
            kps_sequence,
            only_offset=(retarget_strategy == 'offset_retarget'),
        )

    def get_kps_images(self, kps_sequence, reference_image_for_kps, video_length):
        """Render one keypoint image per frame.

        Args:
            kps_sequence: Indexable sequence with at least ``video_length``
                keypoint entries.
            reference_image_for_kps: Array whose shape and dtype define the
                blank canvas for every frame.
            video_length: Number of frames to render.

        Returns:
            A list of ``PIL.Image`` objects, one per frame.
        """
        # A fresh zero canvas per frame, since draw_kps_image receives the
        # canvas and may draw on it in place.
        return [
            Image.fromarray(draw_kps_image(np.zeros_like(reference_image_for_kps), kps_sequence[i]))
            for i in range(video_length)
        ]

    def get_video_latents(self, reference_image, kps_images, audio_waveform, video_length,
                          reference_attention_weight, audio_attention_weight):
        """Sample initial noise and run the pipeline to get video latents.

        Args:
            reference_image: Reference ``PIL.Image``.
            kps_images: Per-frame keypoint images.
            audio_waveform: Waveform returned by ``get_waveform_video_length``.
            video_length: Number of frames to generate.
            reference_attention_weight: Forwarded to the pipeline.
            audio_attention_weight: Forwarded to the pipeline.

        Returns:
            The ``video_latents`` attribute of the pipeline output.
        """
        # assumes the VAE downsamples by 8 in each spatial dimension -
        # TODO confirm against the loaded VAE config.
        vae_scale_factor = 8
        latent_height = self.image_height // vae_scale_factor
        latent_width = self.image_width // vae_scale_factor

        latent_shape = (1, 4, video_length, latent_height, latent_width)
        vae_latents = randn_tensor(
            latent_shape, generator=self.generator, device=self.device, dtype=self.dtype
        )

        pipeline_output = self.pipeline(
            vae_latents=vae_latents,
            reference_image=reference_image,
            kps_images=kps_images,
            audio_waveform=audio_waveform,
            width=self.image_width,
            height=self.image_height,
            video_length=video_length,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=self.guidance_scale,
            context_frames=self.context_frames,
            context_stride=self.context_stride,
            context_overlap=self.context_overlap,
            reference_attention_weight=reference_attention_weight,
            audio_attention_weight=audio_attention_weight,
            num_pad_audio_frames=self.num_pad_audio_frames,
            generator=self.generator,
        )
        return pipeline_output.video_latents

    def get_video_tensor(self, video_latents):
        """Decode latents to frames, returning a ``torch.Tensor``.

        Args:
            video_latents: Latents produced by ``get_video_latents``.

        Returns:
            The decoded video; a NumPy result from the pipeline is converted
            with ``torch.from_numpy``.
        """
        video_tensor = self.pipeline.decode_latents(video_latents)
        if isinstance(video_tensor, np.ndarray):
            video_tensor = torch.from_numpy(video_tensor)
        return video_tensor

    def save_video_tensor(self, video_tensor, audio_path, output_path):
        """Write the video via ``save_video`` at ``self.fps`` and report the path.

        Args:
            video_tensor: Decoded video frames.
            audio_path: Audio source passed on to ``save_video``.
            output_path: Destination file.
        """
        save_video(video_tensor, audio_path, output_path, self.fps)
        print(f'The generated video has been saved at {output_path}.')

    def infer(
        self,
        reference_image_path, audio_path, kps_path,
        output_path,
        retarget_strategy,
        reference_attention_weight, audio_attention_weight,
    ):
        """Run one full generation and save the resulting video.

        Args:
            reference_image_path: Image providing the face to animate.
            audio_path: Media file providing the driving audio.
            kps_path: Saved keypoint sequence, or ``""`` for none.
            output_path: Destination of the generated video.
            retarget_strategy: Strategy name accepted by ``get_kps_sequence``.
            reference_attention_weight: Forwarded to the pipeline.
            audio_attention_weight: Forwarded to the pipeline.
        """
        reference_image, reference_image_for_kps, reference_kps = self.get_reference_image_for_kps(
            reference_image_path
        )
        audio_waveform, video_length = self.get_waveform_video_length(audio_path)
        kps_sequence = self.get_kps_sequence(kps_path, reference_kps, video_length, retarget_strategy)
        kps_images = self.get_kps_images(kps_sequence, reference_image_for_kps, video_length)

        video_latents = self.get_video_latents(
            reference_image, kps_images, audio_waveform,
            video_length,
            reference_attention_weight, audio_attention_weight,
        )
        video_tensor = self.get_video_tensor(video_latents)

        self.save_video_tensor(video_tensor, audio_path, output_path)
|
|
|
|