Spaces:
Running
on
L40S
Running
on
L40S
import argparse
import logging
import os
import subprocess

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from packaging import version
from tqdm import tqdm

from memo.models.audio_proj import AudioProjModel
from memo.models.image_proj import ImageProjModel
from memo.models.unet_2d_condition import UNet2DConditionModel
from memo.models.unet_3d import UNet3DConditionModel
from memo.pipelines.video_pipeline import VideoPipeline
from memo.utils.audio_utils import extract_audio_emotion_labels, preprocess_audio, resample_audio
from memo.utils.vision_utils import preprocess_image, tensor_to_video
# Script-wide logger registered under the fixed name "memo"; its threshold is
# INFO so the progress messages emitted by this script pass the logger's level.
logger = logging.getLogger("memo")
logger.setLevel(logging.INFO)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the MEMO inference script.

    ``--input_image``, ``--input_audio`` and ``--output_dir`` are mandatory:
    the script uses each of them as a filesystem path right away, so a missing
    value is rejected here with a usage error instead of surfacing later as a
    ``TypeError`` on ``None``.

    Returns:
        Namespace with the attributes ``config``, ``input_image``,
        ``input_audio``, ``output_dir`` and ``seed``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or a
            value cannot be converted (e.g. a non-integer ``--seed``).
    """
    parser = argparse.ArgumentParser(description="Inference script for MEMO")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/inference.yaml",
        help="Path to the inference configuration file.",
    )
    parser.add_argument("--input_image", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("--input_audio", type=str, required=True, help="Path to the driving audio file.")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory that receives the generated video (created if missing).",
    )
    parser.add_argument("--seed", type=int, default=42, help="Seed passed to torch.manual_seed.")
    return parser.parse_args()
# Base URL of the auxiliary model files in the MEMO Hugging Face repository.
# The "resolve" endpoint serves the real file content; "raw" is avoided because
# it is understood to return only the Git LFS pointer stub for large binaries.
_MISC_MODEL_BASE_URL = "https://huggingface.co/memoavatar/memo/resolve/main/misc"

# Files expected inside the face-analysis directory.
_FACE_ANALYSIS_MODELS = (
    "1k3d68.onnx",
    "2d106det.onnx",
    "face_landmarker_v2_with_blendshapes.task",
    "genderage.onnx",
    "glintr100.onnx",
    "scrfd_10g_bnkps.onnx",
)

# Mapping from the config's ``weight_dtype`` string to the torch dtype.
_WEIGHT_DTYPES = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def _file_stem(path):
    """Return the file name of ``path`` without directory and final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _download_file(url, target_dir):
    """Download ``url`` into ``target_dir`` with wget, raising RuntimeError on failure."""
    target_path = os.path.join(target_dir, os.path.basename(url))
    try:
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        subprocess.run(["wget", "-P", target_dir, url], check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        # Callers only download when the target is absent, so anything present
        # now is a partial file; remove it so the next run retries.
        if os.path.exists(target_path):
            os.remove(target_path)
        raise RuntimeError(f"Failed to download {url} to {target_dir}") from err


def _prepare_misc_models(config):
    """Return ``(face_analysis_dir, vocal_separator_path)``, downloading missing files if needed."""
    if config.model_name_or_path != "memoavatar/memo":
        logger.info(f"Loading manually specified model path: {config.model_name_or_path}")
        face_analysis = os.path.join(config.model_name_or_path, "misc/face_analysis")
        vocal_separator = os.path.join(config.model_name_or_path, "misc/vocal_separator/Kim_Vocal_2.onnx")
        return face_analysis, vocal_separator

    logger.info(
        f"The MEMO model will be downloaded from Hugging Face to the default cache directory. The models for face analysis and vocal separation will be downloaded to {config.misc_model_dir}."
    )

    face_analysis = os.path.join(config.misc_model_dir, "misc/face_analysis")
    os.makedirs(face_analysis, exist_ok=True)
    for model in _FACE_ANALYSIS_MODELS:
        if not os.path.exists(os.path.join(face_analysis, model)):
            logger.info(f"Downloading {model} to {face_analysis}")
            _download_file(f"{_MISC_MODEL_BASE_URL}/face_analysis/models/{model}", face_analysis)
    logger.info(f"Use face analysis models from {face_analysis}")

    vocal_separator = os.path.join(config.misc_model_dir, "misc/vocal_separator/Kim_Vocal_2.onnx")
    if os.path.exists(vocal_separator):
        logger.info(f"Vocal separator {vocal_separator} already exists. Skipping download.")
    else:
        logger.info(f"Downloading vocal separator to {vocal_separator}")
        vocal_separator_dir = os.path.dirname(vocal_separator)
        os.makedirs(vocal_separator_dir, exist_ok=True)
        _download_file(f"{_MISC_MODEL_BASE_URL}/vocal_separator/Kim_Vocal_2.onnx", vocal_separator_dir)
    return face_analysis, vocal_separator


def _resolve_weight_dtype(name):
    """Translate the config's dtype string into a torch dtype (float32 when unknown)."""
    if name not in _WEIGHT_DTYPES:
        logger.warning(f"Unknown weight_dtype {name!r}; falling back to fp32.")
    return _WEIGHT_DTYPES.get(name, torch.float32)


def _enable_xformers(reference_net, diffusion_net):
    """Switch both UNets to xFormers memory-efficient attention or raise if unavailable."""
    if not is_xformers_available():
        raise ValueError("xformers is not available. Make sure it is installed correctly")

    import xformers

    xformers_version = version.parse(xformers.__version__)
    if xformers_version == version.parse("0.0.16"):
        logger.info(
            "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
        )
    reference_net.enable_xformers_memory_efficient_attention()
    diffusion_net.enable_xformers_memory_efficient_attention()


def main() -> None:
    """Generate a talking video from one reference image and one audio file.

    Steps: parse the CLI arguments, skip if the output video already exists,
    load the config, locate (or download) the auxiliary face-analysis and
    vocal-separation models, preprocess image and audio, load the networks,
    generate the video clip by clip, and write the result next to the audio
    track via ``tensor_to_video``.

    Raises:
        RuntimeError: If an auxiliary model download fails, or if the audio
            embedding is too short to form a single clip.
        ValueError: If xFormers attention is requested in the config but
            xFormers is not installed.
    """
    # Parse arguments
    args = parse_args()
    input_image_path = args.input_image
    input_audio_path = args.input_audio
    # Compare the real extension; a substring test would also accept names
    # such as "wave_song.mp3".
    if os.path.splitext(input_audio_path)[1].lower() != ".wav":
        logger.warning("MEMO might not generate full-length video for non-wav audio file.")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # Only the final extension is stripped, so dotted names such as
    # "face.v1.png" and "face.v2.png" no longer collide on the same output.
    audio_stem = _file_stem(input_audio_path)
    output_video_path = os.path.join(output_dir, f"{_file_stem(input_image_path)}_{audio_stem}.mp4")

    if os.path.exists(output_video_path):
        logger.info(f"Output file {output_video_path} already exists. Skipping inference.")
        return

    # torch.manual_seed seeds and returns the default generator; it is passed
    # to every pipeline call below.
    generator = torch.manual_seed(args.seed)

    logger.info(f"Loading config from {args.config}")
    config = OmegaConf.load(args.config)

    # Determine model paths
    face_analysis, vocal_separator = _prepare_misc_models(config)

    # Set up device and weight dtype
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    weight_dtype = _resolve_weight_dtype(config.weight_dtype)
    logger.info(f"Inference dtype: {weight_dtype}")

    logger.info(f"Processing image {input_image_path}")
    img_size = (config.resolution, config.resolution)
    pixel_values, face_emb = preprocess_image(
        face_analysis_model=face_analysis,
        image_path=input_image_path,
        image_size=config.resolution,
    )

    logger.info(f"Processing audio {input_audio_path}")
    cache_dir = os.path.join(output_dir, "audio_preprocess")
    os.makedirs(cache_dir, exist_ok=True)
    # From here on the resampled copy replaces the original audio path,
    # including as the sound track of the final video.
    input_audio_path = resample_audio(
        input_audio_path,
        os.path.join(cache_dir, f"{audio_stem}-16k.wav"),
    )
    audio_emb, audio_length = preprocess_audio(
        wav_path=input_audio_path,
        num_generated_frames_per_clip=config.num_generated_frames_per_clip,
        fps=config.fps,
        wav2vec_model=config.wav2vec,
        vocal_separator_model=vocal_separator,
        cache_dir=cache_dir,
        device=device,
    )

    logger.info("Processing audio emotion")
    audio_emotion, num_emotion_classes = extract_audio_emotion_labels(
        model=config.model_name_or_path,
        wav_path=input_audio_path,
        emotion2vec_model=config.emotion2vec,
        audio_length=audio_length,
        device=device,
    )

    logger.info("Loading models")
    vae = AutoencoderKL.from_pretrained(config.vae).to(device=device, dtype=weight_dtype)
    reference_net = UNet2DConditionModel.from_pretrained(
        config.model_name_or_path, subfolder="reference_net", use_safetensors=True
    )
    diffusion_net = UNet3DConditionModel.from_pretrained(
        config.model_name_or_path, subfolder="diffusion_net", use_safetensors=True
    )
    image_proj = ImageProjModel.from_pretrained(
        config.model_name_or_path, subfolder="image_proj", use_safetensors=True
    )
    audio_proj = AudioProjModel.from_pretrained(
        config.model_name_or_path, subfolder="audio_proj", use_safetensors=True
    )

    # Inference only: freeze every network and switch it to eval mode.
    for module in (vae, reference_net, diffusion_net, image_proj, audio_proj):
        module.requires_grad_(False).eval()

    # Enable memory-efficient attention for xFormers
    if config.enable_xformers_memory_efficient_attention:
        _enable_xformers(reference_net, diffusion_net)

    # Create inference pipeline
    noise_scheduler = FlowMatchEulerDiscreteScheduler()
    pipeline = VideoPipeline(
        vae=vae,
        reference_net=reference_net,
        diffusion_net=diffusion_net,
        scheduler=noise_scheduler,
        image_proj=image_proj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    frames_per_clip = config.num_generated_frames_per_clip
    num_clips = audio_emb.shape[0] // frames_per_clip
    if num_clips == 0:
        # Without this guard the failure would be an opaque torch.cat error on
        # an empty list after the loop.
        raise RuntimeError(
            f"Audio embedding has {audio_emb.shape[0]} frames, fewer than one clip of "
            f"{frames_per_clip} frames; nothing to generate."
        )

    video_frames = []
    for t in tqdm(range(num_clips), desc="Generating video clips"):
        if not video_frames:
            # Initialize the first past frames with reference image
            past_frames = pixel_values.repeat(config.num_init_past_frames, 1, 1, 1)
        else:
            # Take the last frames of the previous clip as motion context.
            # Dim 2 of ``videos`` is the frame axis (clips are concatenated
            # along it below), so after dropping the batch dim the permute
            # moves frames to the front.
            past_frames = video_frames[-1][0].permute(1, 0, 2, 3)
            past_frames = past_frames[-config.num_past_frames :]
            # NOTE(review): presumably maps pipeline output from [0, 1] back to
            # the [-1, 1] range of ``pixel_values`` - confirm against the pipeline.
            past_frames = past_frames * 2.0 - 1.0
        past_frames = past_frames.to(dtype=pixel_values.dtype, device=pixel_values.device)
        pixel_values_ref_img = torch.cat([pixel_values, past_frames], dim=0).unsqueeze(0)

        # ``num_clips`` is a floor division, so ``clip_end`` never exceeds the
        # embedding length and no clamping is needed.
        clip_start = t * frames_per_clip
        clip_end = clip_start + frames_per_clip

        # audio_proj is not a pipeline component, so it is not moved by
        # ``pipeline.to``; its input is moved to wherever audio_proj lives.
        audio_tensor = (
            audio_emb[clip_start:clip_end].unsqueeze(0).to(device=audio_proj.device, dtype=audio_proj.dtype)
        )
        audio_tensor = audio_proj(audio_tensor)
        audio_emotion_tensor = audio_emotion[clip_start:clip_end]

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            audio_emotion=audio_emotion_tensor,
            emotion_class_num=num_emotion_classes,
            face_emb=face_emb,
            width=img_size[0],
            height=img_size[1],
            video_length=frames_per_clip,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
        )
        video_frames.append(pipeline_output.videos)

    # Join all clips along the frame axis, drop the batch dim, and trim the
    # frame axis to ``audio_length`` as returned by preprocess_audio.
    video_frames = torch.cat(video_frames, dim=2)
    video_frames = video_frames.squeeze(0)
    video_frames = video_frames[:, :audio_length]

    tensor_to_video(video_frames, output_video_path, input_audio_path, fps=config.fps)
# Entry point: run inference only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()