import importlib
import math

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
from einops import rearrange

from data_utils import create_masks_from_landmarks_box

def save_audio_video(
    video,
    audio=None,
    frame_rate=25,
    sample_rate=16000,
    save_path="temp.mp4",
):
    """Save video (and optionally audio) to a single file.

    video: (t, c, h, w)
    audio: (channels, samples)
    """
    save_path = str(save_path)
    if isinstance(video, torch.Tensor):
        video = video.cpu().numpy()
    video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
    print("video_tensor shape", video_tensor.shape)
    if audio is not None:
        # audio is expected to be a tensor of shape (channels, samples)
        print("audio shape", audio.shape)
        audio_tensor = audio
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            audio_array=audio_tensor,
            audio_fps=sample_rate,
            video_codec="h264",  # explicit codecs keep the output widely playable
            audio_codec="aac",
        )
    else:
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            video_codec="h264",
        )
    return save_path

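# Illustrative usage sketch (not part of the pipeline): write one second of
# random frames with a silent mono track. The shapes below are assumptions
# that match the docstring, not values taken from the actual pipeline.
#
#   frames = torch.randint(0, 256, (25, 3, 256, 256), dtype=torch.uint8)
#   silence = torch.zeros(1, 16000)
#   save_audio_video(frames, audio=silence, frame_rate=25,
#                    sample_rate=16000, save_path="demo.mp4")
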
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    """Trim or zero-pad audio to max_len_raw samples (or max_len_sec seconds)."""
    len_file = audio.shape[-1]
    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav

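# Illustrative usage sketch: a 0.75 s clip at 16 kHz is zero-padded up to one
# second. The waveform is random and purely for shape checking.
#
#   clip = trim_pad_audio(torch.randn(1, 12000), sr=16000, max_len_sec=1)
#   # clip.shape -> torch.Size([1, 16000])
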
def get_raw_audio(audio_path, audio_rate, fps=25):
    """Load audio, resample to audio_rate, and reshape to (frames, samples_per_frame)."""
    audio, sr = torchaudio.load(audio_path, channels_first=True)
    if audio.shape[0] > 1:
        # Down-mix to mono
        audio = audio.mean(0, keepdim=True)
    audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
    samples_per_frame = math.ceil(audio_rate / fps)
    n_frames = audio.shape[-1] / samples_per_frame
    if not n_frames.is_integer():
        # Pad so the waveform splits evenly into whole video frames
        audio = trim_pad_audio(
            audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
        )
    audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
    return audio

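# Illustrative usage sketch: at 16 kHz and 25 fps each video frame owns
# ceil(16000 / 25) = 640 audio samples, so one second of audio comes back as
# a (25, 640) tensor. The file name is a placeholder, not a real asset.
#
#   per_frame_audio = get_raw_audio("speech.wav", audio_rate=16000, fps=25)
#   # per_frame_audio.shape -> (n_video_frames, 640)
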
def calculate_splits(tensor, min_last_size):
    """Split tensor along dim 1 so the final chunk has at least min_last_size elements."""
    # Total number of elements along the second dimension
    total_size = tensor.size(1)

    # If the total size is no larger than the minimum last-split size,
    # return the tensor as a single split
    if total_size <= min_last_size:
        return [tensor]

    # Calculate the number of splits and the base size of each split
    num_splits = (total_size - min_last_size) // min_last_size + 1
    base_size = (total_size - min_last_size) // num_splits

    # Build the list of split sizes; the final entry absorbs the remainder so
    # the sizes sum to total_size and the last split is at least min_last_size
    split_sizes = [base_size] * (num_splits - 1)
    split_sizes.append(total_size - sum(split_sizes))

    # Safety net: nudge sizes upward until they sum exactly to total_size
    sum_sizes = sum(split_sizes)
    while sum_sizes != total_size:
        for i in range(num_splits):
            if sum_sizes < total_size:
                split_sizes[i] += 1
                sum_sizes += 1
            if sum_sizes >= total_size:
                break

    # Split the tensor
    splits = torch.split(tensor, split_sizes, dim=1)
    return splits

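# Illustrative usage sketch: a (1, 10) tensor with min_last_size=4 is split
# into a width-3 chunk and a width-7 final chunk. The values are made up
# purely to show the shapes.
#
#   chunks = calculate_splits(torch.arange(10).unsqueeze(0), min_last_size=4)
#   # [c.shape[1] for c in chunks] -> [3, 7]
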
def make_into_multiple_of(x, multiple, dim=0):
    """Zero-pad the tensor along dim so its size becomes a multiple of the given number."""
    if x.shape[dim] % multiple != 0:
        x = torch.cat(
            [
                x,
                torch.zeros(
                    *x.shape[:dim],
                    multiple - (x.shape[dim] % multiple),
                    *x.shape[dim + 1 :],
                    dtype=x.dtype,
                    device=x.device,
                ),
            ],
            dim=dim,
        )
    return x

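# Illustrative usage sketch: padding a length-10 sequence up to the next
# multiple of 4 appends two zero rows along dim 0.
#
#   x = torch.randn(10, 3)
#   padded = make_into_multiple_of(x, multiple=4, dim=0)
#   # padded.shape -> torch.Size([12, 3])
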
def default(value, default_value):
    return default_value if value is None else value

def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

def get_obj_from_str(string, reload=False, invalidate_cache=True):
    module, cls = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

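# Illustrative usage sketch: instantiate_from_config resolves a dotted import
# path and forwards the optional "params" dict as keyword arguments. The
# config dict below is a made-up example, not a real model config.
#
#   config = {"target": "torch.nn.Linear",
#             "params": {"in_features": 4, "out_features": 2}}
#   layer = instantiate_from_config(config)
#   # layer -> Linear(in_features=4, out_features=2, bias=True)
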
def load_landmarks(
    landmarks: np.ndarray,
    original_size,
    target_size=(64, 64),
    nose_index=28,
):
    """
    Load and process facial landmarks to create masks.

    Args:
        landmarks: Facial landmarks array
        original_size: Original size of the video frames
        target_size: Target size for the output mask
        nose_index: Index of the nose landmark

    Returns:
        Processed landmarks mask
    """
    expand_box = 0.0
    if len(landmarks.shape) == 2:
        landmarks = landmarks[None, ...]
    mask = create_masks_from_landmarks_box(
        landmarks,
        (original_size[0], original_size[1]),
        box_expand=expand_box,
        nose_index=nose_index,
    )
    mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
    return mask

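# Illustrative usage sketch, assuming create_masks_from_landmarks_box returns
# one (H, W) mask per frame and a 68-point landmark layout with the nose at
# index 28 (both assumptions about the upstream detector):
#
#   landmarks = np.random.rand(68, 2) * 512
#   mask = load_landmarks(landmarks, original_size=(512, 512), target_size=(64, 64))
#   # mask.shape -> torch.Size([1, 1, 64, 64])
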
def create_pipeline_inputs(
    audio: torch.Tensor,
    audio_interpolation: torch.Tensor,
    num_frames: int,
    video_emb: torch.Tensor,
    landmarks: np.ndarray,
    overlap: int = 1,
    add_zero_flag: bool = False,
    mask_arms: np.ndarray = None,
    nose_index: int = 28,
):
    """
    Create inputs for the keyframe generation and interpolation pipeline.

    Args:
        audio: Audio embeddings for keyframe generation
        audio_interpolation: Audio embeddings for interpolation
        num_frames: Number of frames per segment
        video_emb: Video embeddings
        landmarks: Facial landmarks for mask generation
        overlap: Number of frames to overlap between segments
        add_zero_flag: Whether to re-insert the first keyframe every num_frames elements
        mask_arms: Optional mask for the arms region
        nose_index: Index of the nose landmark point

    Returns:
        Tuple containing all necessary inputs for the pipeline
    """
    audio_interpolation_chunks = []
    audio_image_preds = []
    gt_chunks = []
    gt_keyframes_chunks = []

    # Adjust the step for overlap so segments are created properly
    step = num_frames - overlap

    # Ensure there's at least one step forward on each iteration
    if step < 1:
        step = 1

    audio_image_preds_idx = []
    audio_interp_preds_idx = []
    masks_chunks = []
    masks_interpolation_chunks = []
    for i in range(0, audio.shape[0] - num_frames + 1, step):
        try:
            audio[i + num_frames - 1]
        except IndexError:
            break  # Last chunk is smaller than num_frames
        segment_end = i + num_frames
        gt_chunks.append(video_emb[i:segment_end])
        masks = load_landmarks(
            landmarks[i:segment_end],
            (512, 512),
            target_size=(64, 64),
            nose_index=nose_index,
        )
        if mask_arms is not None:
            masks = np.logical_and(
                masks, np.logical_not(mask_arms[i:segment_end, None, ...])
            )
        masks_interpolation_chunks.append(masks)

        if i not in audio_image_preds_idx:
            audio_image_preds.append(audio[i])
            masks_chunks.append(masks[0])
            gt_keyframes_chunks.append(video_emb[i])
            audio_image_preds_idx.append(i)

        if segment_end - 1 not in audio_image_preds_idx:
            audio_image_preds_idx.append(segment_end - 1)
            audio_image_preds.append(audio[segment_end - 1])
            masks_chunks.append(masks[-1])
            gt_keyframes_chunks.append(video_emb[segment_end - 1])

        audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
        audio_interp_preds_idx.append([i, segment_end - 1])
    # If the flag is on, re-insert the first keyframe every num_frames elements
    if add_zero_flag:
        first_element = audio_image_preds[0]
        len_audio_image_preds = (
            len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
        )
        for i in range(0, len_audio_image_preds, num_frames):
            audio_image_preds.insert(i, first_element)
            audio_image_preds_idx.insert(i, None)
            masks_chunks.insert(i, masks_chunks[0])
            gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])

    to_remove = [idx is None for idx in audio_image_preds_idx]
    audio_image_preds_idx_clone = [idx for idx in audio_image_preds_idx]
    if add_zero_flag:
        # Drop the placeholder indices that were inserted above
        audio_image_preds_idx = [
            sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
        ]
    interpolation_cond_list = []
    for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
        interpolation_cond_list.append(
            [audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
        )

    # Since we generate num_frames at a time, the last chunk must also have size
    # num_frames: compute how many frames are needed to make audio_image_preds a
    # multiple of num_frames
    frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames

    # Pad by repeating the last element at the end of each list
    audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
    masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
    gt_keyframes_chunks = (
        gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
    )
    to_remove = to_remove + [True] * frames_needed
    audio_image_preds_idx_clone = (
        audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
    )

    print(
        f"Added {frames_needed} frames at the end to make audio_image_preds a multiple of {num_frames}"
    )

    # random_cond_idx = np.random.randint(0, len(video_emb))
    random_cond_idx = 0

    assert len(to_remove) == len(audio_image_preds), (
        "to_remove and audio_image_preds must have the same length"
    )

    return (
        gt_chunks,
        gt_keyframes_chunks,
        audio_interpolation_chunks,
        audio_image_preds,
        video_emb[random_cond_idx],
        masks_chunks,
        masks_interpolation_chunks,
        to_remove,
        audio_interp_preds_idx,
        audio_image_preds_idx_clone,
    )

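# Illustrative usage sketch: the tensors below are random stand-ins whose
# shapes are assumptions (14-frame segments, 512x512 source frames, one audio
# embedding per frame); only the unpacking pattern mirrors the function above.
#
#   (
#       gt_chunks,
#       gt_keyframes,
#       audio_interp_chunks,
#       audio_keyframe_preds,
#       cond_emb,
#       masks_keyframes,
#       masks_interp,
#       to_remove,
#       audio_interp_idx,
#       keyframe_idx,
#   ) = create_pipeline_inputs(
#       audio=torch.randn(50, 768),
#       audio_interpolation=torch.randn(50, 768),
#       num_frames=14,
#       video_emb=torch.randn(50, 4, 64, 64),
#       landmarks=np.random.rand(50, 68, 2) * 512,
#       overlap=1,
#   )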