# keysync-demo / utils.py
import importlib
import math

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from einops import rearrange

from data_utils import create_masks_from_landmarks_box


def save_audio_video(
video,
audio=None,
frame_rate=25,
sample_rate=16000,
save_path="temp.mp4",
):
"""Save audio and video to a single file.
video: (t, c, h, w)
audio: (channels t)
"""
save_path = str(save_path)
if isinstance(video, torch.Tensor):
video = video.cpu().numpy()
video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
print("video_tensor shape", video_tensor.shape)
print("audio shape", audio.shape)
if audio is not None:
# Assuming audio is a tensor of shape (channels, samples)
audio_tensor = audio
torchvision.io.write_video(
save_path,
video_tensor,
fps=frame_rate,
audio_array=audio_tensor,
audio_fps=sample_rate,
video_codec="h264", # Specify a codec to address the error
audio_codec="aac",
)
    else:
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            video_codec="h264",  # Explicitly pick an mp4-compatible codec
        )
return save_path
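
# Illustrative usage of save_audio_video (a sketch; the shapes and file name below are assumptions):
#   frames = torch.randint(0, 255, (25, 3, 128, 128), dtype=torch.uint8)  # (t, c, h, w), 1 s at 25 fps
#   waveform = torch.randn(1, 16000)                                      # (channels, samples), 1 s at 16 kHz
#   save_audio_video(frames, audio=waveform, frame_rate=25, sample_rate=16000, save_path="demo.mp4")
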
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
len_file = audio.shape[-1]
if max_len_sec or max_len_raw:
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
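
# Illustrative usage of trim_pad_audio (a sketch; the waveform below is an assumption):
#   wav = torch.randn(1, 12000)
#   trim_pad_audio(wav, sr=16000, max_len_sec=1.0).shape   # (1, 16000): zero-padded up to 1 s
#   trim_pad_audio(wav, sr=16000, max_len_raw=8000).shape  # (1, 8000): trimmed to 8000 samples
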
def get_raw_audio(audio_path, audio_rate, fps=25):
    """Load an audio file, resample it to ``audio_rate``, and reshape it into per-video-frame chunks."""
    audio, sr = torchaudio.load(audio_path, channels_first=True)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
samples_per_frame = math.ceil(audio_rate / fps)
n_frames = audio.shape[-1] / samples_per_frame
if not n_frames.is_integer():
audio = trim_pad_audio(
audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
)
audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
return audio
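
# Illustrative usage of get_raw_audio (a sketch; "speech.wav" is a hypothetical file):
#   audio = get_raw_audio("speech.wav", audio_rate=16000, fps=25)
#   # audio has shape (n_video_frames, 640), since ceil(16000 / 25) = 640 samples per frame
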
def calculate_splits(tensor, min_last_size):
    """Split ``tensor`` along dim 1 into roughly even chunks, ensuring the last chunk has at least ``min_last_size`` elements."""
    # Check the total number of elements in the tensor
    total_size = tensor.size(1)  # size along the second dimension
# If total size is less than the minimum size for the last split, return the tensor as a single split
if total_size <= min_last_size:
return [tensor]
# Calculate number of splits and size of each split
num_splits = (total_size - min_last_size) // min_last_size + 1
base_size = (total_size - min_last_size) // num_splits
# Create split sizes list
split_sizes = [base_size] * (num_splits - 1)
split_sizes.append(
total_size - sum(split_sizes)
) # Ensure the last split has at least min_last_size
# Adjust sizes to ensure they sum exactly to total_size
sum_sizes = sum(split_sizes)
while sum_sizes != total_size:
for i in range(num_splits):
if sum_sizes < total_size:
split_sizes[i] += 1
sum_sizes += 1
if sum_sizes >= total_size:
break
# Split the tensor
splits = torch.split(tensor, split_sizes, dim=1)
return splits
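
# Illustrative usage of calculate_splits (a sketch; the tensor below is an assumption):
#   x = torch.randn(1, 10)
#   [s.shape[1] for s in calculate_splits(x, min_last_size=4)]  # [3, 7]; the last split is >= 4
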
def make_into_multiple_of(x, multiple, dim=0):
"""Make the torch tensor into a multiple of the given number."""
if x.shape[dim] % multiple != 0:
x = torch.cat(
[
x,
                torch.zeros(
                    *x.shape[:dim],
                    multiple - (x.shape[dim] % multiple),
                    *x.shape[dim + 1 :],
                    dtype=x.dtype,  # match x so torch.cat also works for non-float tensors
                    device=x.device,
                ),
],
dim=dim,
)
return x
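
# Illustrative usage of make_into_multiple_of (a sketch; the tensor below is an assumption):
#   x = torch.randn(10, 4)
#   make_into_multiple_of(x, multiple=8, dim=0).shape  # (16, 4): zero-padded along dim 0
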
def default(value, default_value):
return default_value if value is None else value
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit(".", 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
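
# Illustrative usage of instantiate_from_config (a sketch; the config below is an assumption):
#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
#   layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 2)
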
def load_landmarks(
landmarks: np.ndarray,
original_size,
target_size=(64, 64),
nose_index=28,
):
"""
Load and process facial landmarks to create masks.
Args:
landmarks: Facial landmarks array
original_size: Original size of the video frames
index: Index for non-dub mode
target_size: Target size for the output mask
is_dub: Whether this is for dubbing mode
what_mask: Type of mask to create ("full", "box", "heart", "mouth")
nose_index: Index of the nose landmark
Returns:
Processed landmarks mask
"""
expand_box = 0.0
if len(landmarks.shape) == 2:
landmarks = landmarks[None, ...]
mask = create_masks_from_landmarks_box(
landmarks,
(original_size[0], original_size[1]),
box_expand=expand_box,
nose_index=nose_index,
)
mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
return mask
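
# Illustrative usage of load_landmarks (a sketch; a 68-point layout in a 512x512 frame is an
# assumption, and the exact output shape depends on create_masks_from_landmarks_box):
#   lm = np.random.rand(68, 2) * 512
#   mask = load_landmarks(lm, original_size=(512, 512), target_size=(64, 64))  # likely (1, 1, 64, 64)
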
def create_pipeline_inputs(
audio: torch.Tensor,
audio_interpolation: torch.Tensor,
num_frames: int,
video_emb: torch.Tensor,
landmarks: np.ndarray,
overlap: int = 1,
add_zero_flag: bool = False,
    mask_arms: np.ndarray = None,
nose_index: int = 28,
):
"""
Create inputs for the keyframe generation and interpolation pipeline.
Args:
video: Input video tensor
audio: Audio embeddings for keyframe generation
audio_interpolation: Audio embeddings for interpolation
num_frames: Number of frames per segment
video_emb: Optional video embeddings
landmarks: Facial landmarks for mask generation
overlap: Number of frames to overlap between segments
add_zero_flag: Whether to add zero flag every num_frames
what_mask: Type of mask to generate ("box" or other options)
mask_arms: Optional mask for arms region
nose_index: Index of the nose landmark point
Returns:
Tuple containing all necessary inputs for the pipeline
"""
audio_interpolation_chunks = []
audio_image_preds = []
gt_chunks = []
gt_keyframes_chunks = []
# Adjustment for overlap to ensure segments are created properly
step = num_frames - overlap
# Ensure there's at least one step forward on each iteration
if step < 1:
step = 1
audio_image_preds_idx = []
audio_interp_preds_idx = []
masks_chunks = []
masks_interpolation_chunks = []
for i in range(0, audio.shape[0] - num_frames + 1, step):
try:
audio[i + num_frames - 1]
except IndexError:
break # Last chunk is smaller than num_frames
segment_end = i + num_frames
gt_chunks.append(video_emb[i:segment_end])
masks = load_landmarks(
landmarks[i:segment_end],
(512, 512),
target_size=(64, 64),
nose_index=nose_index,
)
if mask_arms is not None:
masks = np.logical_and(
masks, np.logical_not(mask_arms[i:segment_end, None, ...])
)
masks_interpolation_chunks.append(masks)
if i not in audio_image_preds_idx:
audio_image_preds.append(audio[i])
masks_chunks.append(masks[0])
gt_keyframes_chunks.append(video_emb[i])
audio_image_preds_idx.append(i)
if segment_end - 1 not in audio_image_preds_idx:
audio_image_preds_idx.append(segment_end - 1)
audio_image_preds.append(audio[segment_end - 1])
masks_chunks.append(masks[-1])
gt_keyframes_chunks.append(video_emb[segment_end - 1])
audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
audio_interp_preds_idx.append([i, segment_end - 1])
    # If the flag is on, re-insert the first element every num_frames audio elements
if add_zero_flag:
first_element = audio_image_preds[0]
len_audio_image_preds = (
len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
)
for i in range(0, len_audio_image_preds, num_frames):
audio_image_preds.insert(i, first_element)
audio_image_preds_idx.insert(i, None)
masks_chunks.insert(i, masks_chunks[0])
gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])
to_remove = [idx is None for idx in audio_image_preds_idx]
audio_image_preds_idx_clone = [idx for idx in audio_image_preds_idx]
if add_zero_flag:
# Remove the added elements from the list
audio_image_preds_idx = [
sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
]
interpolation_cond_list = []
for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
interpolation_cond_list.append(
[audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
)
# Since we generate num_frames at a time, we need to ensure that the last chunk is of size num_frames
# Calculate the number of frames needed to make audio_image_preds a multiple of num_frames
frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames
    # Pad the end of audio_image_preds by repeating its last element
audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
gt_keyframes_chunks = (
gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
)
to_remove = to_remove + [True] * frames_needed
audio_image_preds_idx_clone = (
audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
)
    print(
        f"Padded {frames_needed} frames at the end to make audio_image_preds a multiple of {num_frames}"
    )
# random_cond_idx = np.random.randint(0, len(video_emb))
random_cond_idx = 0
assert len(to_remove) == len(audio_image_preds), (
"to_remove and audio_image_preds must have the same length"
)
return (
gt_chunks,
gt_keyframes_chunks,
audio_interpolation_chunks,
audio_image_preds,
video_emb[random_cond_idx],
masks_chunks,
masks_interpolation_chunks,
to_remove,
audio_interp_preds_idx,
audio_image_preds_idx_clone,
)
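
# Illustrative call to create_pipeline_inputs (a sketch; every shape below is an assumption and the
# landmarks must be compatible with create_masks_from_landmarks_box):
#   T, num_frames = 100, 14
#   audio = torch.randn(T, 768)                 # per-frame audio embeddings for keyframes
#   audio_interpolation = torch.randn(T, 768)   # per-frame audio embeddings for interpolation
#   video_emb = torch.randn(T, 4, 64, 64)       # per-frame video embeddings
#   landmarks = np.random.rand(T, 68, 2) * 512  # per-frame facial landmarks in a 512x512 frame
#   outputs = create_pipeline_inputs(audio, audio_interpolation, num_frames, video_emb, landmarks)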