import gradio as gr
import torch
import tempfile
import os
from vae_wrapper import VaeWrapper, encode_video_chunk
from landmarks_extractor import LandmarksExtractor
import decord
from utils import (
    get_raw_audio,
    save_audio_video,
    calculate_splits,
    instantiate_from_config,
    create_pipeline_inputs,
)
from transformers import HubertModel
from einops import rearrange
import numpy as np
from WavLM import WavLM_wrapper
from omegaconf import OmegaConf
from inference_functions import (
    sample_keyframes,
    sample_interpolation,
)
from wordle_game import WordleGame
import torch.cuda.amp as amp  # Import amp for mixed precision
# Enable faster CUDA math modes when a GPU is available
if torch.cuda.is_available():
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Enable TF32 precision for better performance on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Cache for video and audio processing (keyed by input file path)
cache = {
    "video": {
        "path": None,
        "embedding": None,
        "frames": None,
        "landmarks": None,
    },
    "audio": {
        "path": None,
        "raw_audio": None,
        "hubert_embedding": None,
        "wavlm_embedding": None,
    },
}
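# NOTE: cache hits are detected by comparing raw file paths (see process_video),
# so re-uploading identical content under a new path recomputes everything.
# Cached tensors are cloned before use so that downstream truncation never
# mutates the cached full-length copies.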
# Mixed precision scaler (unused during inference: GradScaler only matters for
# training backward passes; FP16 inference below relies on autocast alone)
scaler = amp.GradScaler()
def load_model(
    config: str,
    device: str = "cuda",
    ckpt: str = None,
):
    """
    Load a model from a configuration file.

    Args:
        config: Path to the model configuration file
        device: Device to load the model on
        ckpt: Optional checkpoint path

    Returns:
        The instantiated model in eval mode (FP16 on CUDA, with the
        first-stage VAE kept in FP32)
    """
    config = OmegaConf.load(config)
    config["model"]["params"]["input_key"] = "latents"

    if ckpt is not None:
        config.model.params.ckpt_path = ckpt

    with torch.device(device):
        model = instantiate_from_config(config.model).to(device).eval()

    # Convert the model to half precision, keeping the first-stage VAE in FP32
    if torch.cuda.is_available():
        model = model.half()
        model.first_stage_model = model.first_stage_model.float()
        print("Converted model to FP16 precision")

    # Compile the model for faster inference
    if torch.cuda.is_available():
        try:
            model = torch.compile(model)
            print("Successfully compiled model with torch.compile()")
        except Exception as e:
            print(f"Warning: Failed to compile model: {e}")

    return model
# keyframe_model = KeyframeModel(device=device)
# interpolation_model = InterpolationModel(device=device)
vae_model = VaeWrapper("video")
if torch.cuda.is_available():
    vae_model = vae_model.half()  # Convert to half precision
    try:
        vae_model = torch.compile(vae_model)
        print("Successfully compiled vae_model in FP16")
    except Exception as e:
        print(f"Warning: Failed to compile vae_model: {e}")

# NOTE: the .cuda() calls below are unconditional, so a CUDA device is
# required; the torch.cuda.is_available() checks only gate the FP16 conversion.
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
if torch.cuda.is_available():
    hubert_model = hubert_model.half()  # Convert to half precision
    try:
        hubert_model = torch.compile(hubert_model)
        print("Successfully compiled hubert_model in FP16")
    except Exception as e:
        print(f"Warning: Failed to compile hubert_model: {e}")

wavlm_model = WavLM_wrapper(
    model_size="Base+",
    feed_as_frames=False,
    merge_type="None",
    model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
).cuda()
if torch.cuda.is_available():
    wavlm_model = wavlm_model.half()  # Convert to half precision
    try:
        wavlm_model = torch.compile(wavlm_model)
        print("Successfully compiled wavlm_model in FP16")
    except Exception as e:
        print(f"Warning: Failed to compile wavlm_model: {e}")

landmarks_extractor = LandmarksExtractor()
keyframe_model = load_model(
    config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
    ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
)
interpolation_model = load_model(
    config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
    ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
)
keyframe_model.en_and_decode_n_samples_a_time = 2
interpolation_model.en_and_decode_n_samples_a_time = 2
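# en_and_decode_n_samples_a_time presumably caps how many samples the VAE
# encodes/decodes per forward pass (the attribute name suggests it), trading
# throughput for lower peak GPU memory during sampling.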
# Default media paths
DEFAULT_VIDEO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_video.mp4"
)
DEFAULT_AUDIO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_audio.wav"
)
def compute_video_embedding(video_reader, min_len):
    """Encode the first `min_len` frames of a video into VAE latents."""
    total_frames = min_len

    encoded = []
    video_frames = []
    chunk_size = 16
    resolution = 512

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Calculate total chunks for progress tracking
    total_chunks = (total_frames + chunk_size - 1) // chunk_size

    for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
        # Update progress bar
        progress(i / total_chunks, desc="Processing video chunks")

        end_idx = min(start_idx + chunk_size, total_frames)
        video_chunk = video_reader.get_batch(range(start_idx, end_idx))
        # Resize the chunk to the target resolution
        # (bilinear interpolation requires a float tensor; decord yields uint8)
        video_chunk = rearrange(video_chunk, "f h w c -> f c h w").float()
        video_chunk = torch.nn.functional.interpolate(
            video_chunk,
            size=(resolution, resolution),
            mode="bilinear",
            align_corners=False,
        )
        video_chunk = rearrange(video_chunk, "f c h w -> f h w c")
        video_frames.append(video_chunk)

        # Convert chunk to FP16 if using CUDA
        if torch.cuda.is_available():
            video_chunk = video_chunk.half()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))

    encoded = torch.cat(encoded, dim=0)
    video_frames = torch.cat(video_frames, dim=0)
    video_frames = rearrange(video_frames, "f h w c -> f c h w")
    torch.cuda.empty_cache()
    return encoded, video_frames
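# For an N-frame clip, `video_frames` holds the resized RGB frames as
# (N, 3, 512, 512) and `encoded` the per-frame VAE latents. With a typical
# 8x-downsampling video VAE that would be (N, 4, 64, 64), though the exact
# latent shape depends on VaeWrapper's architecture.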
def compute_hubert_embedding(raw_audio):
    """Compute HuBERT embeddings from a flat 16 kHz waveform."""
    print(f"Computing audio embedding from {raw_audio.shape}")

    # Normalize to zero mean / unit variance, as HuBERT expects
    audio = (
        (raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
    ).unsqueeze(0)
    chunks = 16000 * 20  # Process 20 seconds of audio at a time

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert audio chunk to half precision
        if torch.cuda.is_available():
            chunk_cuda = chunk.cuda().half()
        else:
            chunk_cuda = chunk.cuda()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            hidden_states = hubert_model(chunk_cuda)[0]
        audio_embeddings.append(hidden_states)
    audio_embeddings = torch.cat(audio_embeddings, dim=1)
    # audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]

    # Pad to an even number of steps so they can be grouped in pairs below
    if audio_embeddings.shape[1] % 2 != 0:
        audio_embeddings = torch.cat(
            [audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
        )
    audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
    torch.cuda.empty_cache()
    return audio_embeddings
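# HuBERT emits one 768-dim feature every 20 ms (50 per second) while the video
# runs at 25 fps, so the rearrange above pairs consecutive steps (d=2), giving
# one (2, 768) audio block per video frame.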
def compute_wavlm_embedding(raw_audio):
    """Compute WavLM embeddings from a flat 16 kHz waveform."""
    # Split the waveform into 640-sample frames (one per video frame at 25 fps)
    audio = rearrange(raw_audio, "(f s) -> f s", s=640)

    if audio.shape[0] % 2 != 0:
        audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
    chunks = 500  # Process 500 frames (20 seconds) at a time

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert chunk to half precision
        if torch.cuda.is_available():
            chunk_cuda = chunk.unsqueeze(0).cuda().half()
        else:
            chunk_cuda = chunk.unsqueeze(0).cuda()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
        audio_embeddings.append(wavlm_hidden_states)
    audio_embeddings = torch.cat(audio_embeddings, dim=0)
    torch.cuda.empty_cache()
    return audio_embeddings
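# The 640-sample framing matches the video rate: 16000 Hz / 25 fps = 640
# samples per frame, so both audio embeddings align one-to-one with frames.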
def extract_video_landmarks(video_frames):
    """Extract facial landmarks from video frames."""

    # Create a progress bar for Gradio
    progress = gr.Progress()

    landmarks = []
    batch_size = 10

    for i in range(0, len(video_frames), batch_size):
        # Update progress bar
        progress(i / len(video_frames), desc="Extracting facial landmarks")

        batch = video_frames[i : i + batch_size].cpu().float()
        batch_landmarks = landmarks_extractor.extract_landmarks(batch)
        landmarks.extend(batch_landmarks)

    torch.cuda.empty_cache()

    # Convert landmarks to a list of numpy arrays with consistent shape
    processed_landmarks = []

    expected_shape = (68, 2)  # Standard 68-point facial landmark layout

    # Process each landmark to ensure a consistent shape. Invalid detections
    # fall back to, in order: the last valid landmark, the next valid
    # landmark, and finally an all-zeros array.
    last_valid_landmark = None
    for i, lm in enumerate(landmarks):
        if (
            lm is not None
            and isinstance(lm, np.ndarray)
            and lm.shape == expected_shape
        ):
            processed_landmarks.append(lm)
            last_valid_landmark = lm
        else:
            # Print information about inconsistent landmarks
            if lm is None:
                print(f"Warning: Landmark at index {i} is None")
            elif not isinstance(lm, np.ndarray):
                print(
                    f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
                )
            elif lm.shape != expected_shape:
                print(
                    f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
                )

            # Replace invalid landmarks with the closest valid landmark if available
            if last_valid_landmark is not None:
                processed_landmarks.append(last_valid_landmark.copy())
            else:
                # If no valid landmark has been seen yet, look ahead for a valid one
                found_future_valid = False
                for future_lm in landmarks[i + 1 :]:
                    if (
                        future_lm is not None
                        and isinstance(future_lm, np.ndarray)
                        and future_lm.shape == expected_shape
                    ):
                        processed_landmarks.append(future_lm.copy())
                        found_future_valid = True
                        break

                # If no valid landmark is found in the future, use zeros
                if not found_future_valid:
                    processed_landmarks.append(np.zeros(expected_shape))

    return np.array(processed_landmarks)
def sample(
    audio_list,
    gt_keyframes,
    masks_keyframes,
    to_remove,
    test_keyframes_list,
    num_frames,
    device,
    emb,
    force_uc_zero_embeddings,
    n_batch_keyframes,
    n_batch,
    test_interpolation_list,
    audio_interpolation_list,
    masks_interpolation,
    gt_interpolation,
    model_keyframes,
    model,
):
    """Run keyframe generation followed by interpolation, chunk by chunk."""
    # Create a progress bar for Gradio
    progress = gr.Progress()

    condition = torch.zeros(1, 3, 512, 512).to(device)
    if torch.cuda.is_available():
        condition = condition.half()

    audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
    gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
    # Rearrange masks_keyframes and save locally
    masks_keyframes = rearrange(
        masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
    )

    # Convert to_remove into chunks of num_frames
    to_remove_chunks = [
        to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
    ]
    test_keyframes_list = [
        test_keyframes_list[i : i + num_frames]
        for i in range(0, len(test_keyframes_list), num_frames)
    ]

    audio_cond = audio_list
    if emb is not None:
        embeddings = emb.unsqueeze(0).to(device)
        if torch.cuda.is_available():
            embeddings = embeddings.half()
    else:
        embeddings = None

    # One batch of keyframes is approximately 7 seconds
    chunk_size = 2
    complete_video = []
    start_idx = 0
    last_frame_z = None
    last_frame_x = None
    last_keyframe_idx = None
    last_to_remove = None

    total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size
    for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
        # Update progress bar
        progress(chunk_idx / total_chunks, desc="Generating video")

        # Clear GPU cache between chunks
        torch.cuda.empty_cache()

        chunk_end = min(chunk_start + chunk_size, len(audio_cond))

        chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_audio_cond = chunk_audio_cond.half()

        chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
        chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_gt_keyframes = chunk_gt_keyframes.half()
            chunk_masks = chunk_masks.half()

        test_keyframes_list_unwrapped = [
            elem
            for sublist in test_keyframes_list[chunk_start:chunk_end]
            for elem in sublist
        ]
        to_remove_chunks_unwrapped = [
            elem
            for sublist in to_remove_chunks[chunk_start:chunk_end]
            for elem in sublist
        ]

        # Carry the previous chunk's final keyframe over as the first keyframe
        # of this chunk, so consecutive chunks share a boundary frame
        if last_keyframe_idx is not None:
            test_keyframes_list_unwrapped = [
                last_keyframe_idx
            ] + test_keyframes_list_unwrapped
            to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped

        last_keyframe_idx = test_keyframes_list_unwrapped[-1]
        last_to_remove = to_remove_chunks_unwrapped[-1]

        # Find the first non-None keyframe in the chunk
        first_keyframe = next(
            (kf for kf in test_keyframes_list_unwrapped if kf is not None), None
        )

        # Find the last non-None keyframe in the chunk
        last_keyframe = next(
            (kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
            None,
        )

        start_idx = next(
            (
                idx
                for idx, comb in enumerate(test_interpolation_list)
                if comb[0] == first_keyframe
            ),
            None,
        )
        end_idx = next(
            (
                idx
                for idx, comb in enumerate(reversed(test_interpolation_list))
                if comb[1] == last_keyframe
            ),
            None,
        )

        if start_idx is not None and end_idx is not None:
            end_idx = (
                len(test_interpolation_list) - 1 - end_idx
            )  # Adjust for reversed enumeration
            end_idx += 1

        if start_idx is None:
            break

        if end_idx is None or end_idx < start_idx:
            end_idx = len(audio_interpolation_list)

        audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
        chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
        gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]
        if torch.cuda.is_available():
            audio_interpolation_list_chunk = [
                chunk.half() for chunk in audio_interpolation_list_chunk
            ]
            chunk_masks_interpolation = [
                chunk.half() for chunk in chunk_masks_interpolation
            ]
            gt_interpolation_chunks = [
                chunk.half() for chunk in gt_interpolation_chunks
            ]

        progress(chunk_idx / total_chunks, desc="Generating keyframes")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            samples_z = sample_keyframes(
                model_keyframes,
                chunk_audio_cond,
                chunk_gt_keyframes,
                chunk_masks,
                condition.cuda(),
                num_frames,
                24,
                0.0,
                device,
                embeddings.cuda() if embeddings is not None else None,
                force_uc_zero_embeddings,
                n_batch_keyframes,
                0,
                1.0,
                None,
                gt_as_cond=False,
            )

        if last_frame_z is not None:
            # samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
            samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], dim=0)

        # last_frame_x = samples_x[-1]
        last_frame_z = samples_z[-1]

        progress(chunk_idx / total_chunks, desc="Interpolating frames")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            vid = sample_interpolation(
                model,
                samples_z,
                # samples_x,
                audio_interpolation_list_chunk,
                gt_interpolation_chunks,
                chunk_masks_interpolation,
                condition.cuda(),
                num_frames,
                device,
                1,
                24,
                0.0,
                force_uc_zero_embeddings,
                n_batch,
                chunk_size,
                1.0,
                None,
                cut_audio=False,
                to_remove=to_remove_chunks_unwrapped,
            )

        if chunk_start == 0:
            complete_video = vid
        else:
            # Drop the duplicated boundary frame before stitching chunks
            complete_video = np.concatenate([complete_video[:-1], vid], axis=0)

    return complete_video
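# Consecutive chunks overlap by one keyframe: each iteration prepends the
# previous chunk's final latent and the stitching above drops the duplicated
# boundary frame, keeping the output continuous across chunks.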
def process_video(video_input, audio_input, max_num_seconds):
    """Main processing function to generate a synchronized video"""
    # Display a message to the user about the processing time
    gr.Info("Processing video. This may take a while...", duration=10)
    gr.Info(
        "If you're tired of waiting, try playing the Wordle game in the other tab!",
        duration=10,
    )

    # Use default media if none provided
    if video_input is None:
        video_input = DEFAULT_VIDEO_PATH
        print(f"Using default video: {DEFAULT_VIDEO_PATH}")

    if audio_input is None:
        audio_input = DEFAULT_AUDIO_PATH
        print(f"Using default audio: {DEFAULT_AUDIO_PATH}")

    try:
        # Use file paths as cache keys
        video_path_hash = video_input
        audio_path_hash = audio_input

        # Check whether embeddings need to be recomputed
        video_cache_hit = cache["video"]["path"] == video_path_hash
        audio_cache_hit = cache["audio"]["path"] == audio_path_hash

        if video_cache_hit and audio_cache_hit:
            print("Using cached video and audio computations")
            # Make copies of cached data to avoid modifying the cache
            video_embedding = cache["video"]["embedding"].clone()
            video_frames = cache["video"]["frames"].clone()
            video_landmarks = cache["video"]["landmarks"].copy()
            raw_audio = cache["audio"]["raw_audio"].clone()
            raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
            hubert_embedding = cache["audio"]["hubert_embedding"].clone()
            wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

            # Truncate all modalities to a common length
            min_len = min(
                len(video_frames),
                len(raw_audio),
                len(hubert_embedding),
                len(wavlm_embedding),
            )
            video_frames = video_frames[:min_len]
            video_embedding = video_embedding[:min_len]
            video_landmarks = video_landmarks[:min_len]
            raw_audio = raw_audio[:min_len]
            hubert_embedding = hubert_embedding[:min_len]
            wavlm_embedding = wavlm_embedding[:min_len]
            raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
        else:
            # Process video if needed
            if not video_cache_hit:
                print("Computing video embeddings and landmarks")
                video_reader = decord.VideoReader(video_input)
                decord.bridge.set_bridge("torch")

                if not audio_cache_hit:
                    # Need to process audio to determine min_len
                    raw_audio = get_raw_audio(audio_input, 16000)
                    if len(raw_audio) == 0 or len(video_reader) == 0:
                        raise ValueError("Empty audio or video input")

                    min_len = min(len(raw_audio), len(video_reader))

                    # Store full audio in cache
                    cache["audio"]["path"] = audio_path_hash
                    cache["audio"]["raw_audio"] = raw_audio.clone()

                    # Create truncated copy for processing
                    raw_audio = raw_audio[:min_len]
                    raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
                else:
                    # Use cached audio - make a copy
                    if cache["audio"]["raw_audio"] is None:
                        raise ValueError("Cached audio is None")
                    raw_audio = cache["audio"]["raw_audio"].clone()
                    if len(raw_audio) == 0 or len(video_reader) == 0:
                        raise ValueError("Empty cached audio or video input")

                    min_len = min(len(raw_audio), len(video_reader))

                    # Create truncated copy for processing
                    raw_audio = raw_audio[:min_len]
                    raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")

                # Compute video embeddings and landmarks - store full version in cache
                video_embedding, video_frames = compute_video_embedding(
                    video_reader, len(video_reader)
                )
                video_landmarks = extract_video_landmarks(video_frames)

                # Update video cache with full versions
                cache["video"]["path"] = video_path_hash
                cache["video"]["embedding"] = video_embedding
                cache["video"]["frames"] = video_frames
                cache["video"]["landmarks"] = video_landmarks

                # Create truncated copies for processing
                video_embedding = video_embedding[:min_len]
                video_frames = video_frames[:min_len]
                video_landmarks = video_landmarks[:min_len]
            else:
                # Use cached video data - make copies
                print("Using cached video computations")
                if (
                    cache["video"]["embedding"] is None
                    or cache["video"]["frames"] is None
                    or cache["video"]["landmarks"] is None
                ):
                    raise ValueError("One or more video cache entries are None")

                if not audio_cache_hit:
                    # New audio with cached video
                    raw_audio = get_raw_audio(audio_input, 16000)
                    if len(raw_audio) == 0:
                        raise ValueError("Empty audio input")

                    # Store full audio in cache
                    cache["audio"]["path"] = audio_path_hash
                    cache["audio"]["raw_audio"] = raw_audio.clone()

                    # Make copies of video data
                    video_embedding = cache["video"]["embedding"].clone()
                    video_frames = cache["video"]["frames"].clone()
                    video_landmarks = cache["video"]["landmarks"].copy()

                    # Determine truncation length and create truncated copies
                    min_len = min(len(raw_audio), len(video_frames))
                    raw_audio = raw_audio[:min_len]
                    raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
                    video_frames = video_frames[:min_len]
                    video_embedding = video_embedding[:min_len]
                    video_landmarks = video_landmarks[:min_len]
                else:
                    # Both video and audio cached - unreachable, since that
                    # case is handled by the first branch above
                    pass

            # Process audio if needed
            if not audio_cache_hit:
                print("Computing audio embeddings")
                # Compute audio embeddings from the truncated audio
                hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
                wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)

                # Update audio cache with the embeddings
                # Note: raw_audio was already cached above
                cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
                cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
            else:
                # Use cached audio data - make copies
                if (
                    cache["audio"]["hubert_embedding"] is None
                    or cache["audio"]["wavlm_embedding"] is None
                ):
                    raise ValueError(
                        "One or more audio embedding cache entries are None"
                    )
                hubert_embedding = cache["audio"]["hubert_embedding"].clone()
                wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

                # Make sure embeddings match the truncated video length if needed
                if "min_len" in locals() and (
                    min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
                ):
                    hubert_embedding = hubert_embedding[:min_len]
                    wavlm_embedding = wavlm_embedding[:min_len]

        # Apply max_num_seconds limit if specified
        if max_num_seconds > 0:
            # Convert seconds to frames (assuming 25 fps)
            max_frames = int(max_num_seconds * 25)

            # Truncate all data to max_frames
            video_embedding = video_embedding[:max_frames]
            video_frames = video_frames[:max_frames]
            video_landmarks = video_landmarks[:max_frames]
            hubert_embedding = hubert_embedding[:max_frames]
            wavlm_embedding = wavlm_embedding[:max_frames]
            raw_audio = raw_audio[:max_frames]
            raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
        # Validate shapes before proceeding
        assert video_embedding.shape[0] == hubert_embedding.shape[0], (
            f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
        )
        assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
            f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
        )
        assert video_embedding.shape[0] == video_landmarks.shape[0], (
            f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
        )

        print(f"Hubert embedding shape: {hubert_embedding.shape}")
        print(f"WavLM embedding shape: {wavlm_embedding.shape}")
        print(f"Video embedding shape: {video_embedding.shape}")
        print(f"Video landmarks shape: {video_landmarks.shape}")

        # Create pipeline inputs for the models
        (
            interpolation_chunks,
            keyframe_chunks,
            audio_interpolation_chunks,
            audio_keyframe_chunks,
            emb_cond,
            masks_keyframe_chunks,
            masks_interpolation_chunks,
            to_remove,
            audio_interpolation_idx,
            audio_keyframe_idx,
        ) = create_pipeline_inputs(
            hubert_embedding,
            wavlm_embedding,
            14,
            video_embedding,
            video_landmarks,
            overlap=1,
            add_zero_flag=True,
            mask_arms=None,
            nose_index=28,
        )

        complete_video = sample(
            audio_keyframe_chunks,
            keyframe_chunks,
            masks_keyframe_chunks,
            to_remove,
            audio_keyframe_idx,
            14,
            "cuda",
            emb_cond,
            [],
            3,
            3,
            audio_interpolation_idx,
            audio_interpolation_chunks,
            masks_interpolation_chunks,
            interpolation_chunks,
            keyframe_model,
            interpolation_model,
        )

        complete_audio = rearrange(
            raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
        )

        # Convert frames to video and combine with the new audio
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
            output_path = temp_video.name

        print("Saving video to", output_path)
        save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
        torch.cuda.empty_cache()

        return output_path

    except Exception as e:
        print(f"Error processing video: {str(e)}")
        raise
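# Example (hypothetical paths): process_video("assets/sample_video.mp4",
# "assets/sample_audio.wav", max_num_seconds=5) returns the path of a temporary
# .mp4 containing the first five seconds of dubbed video.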
def get_max_duration(video_input, audio_input):
    """Get the maximum duration in seconds for the slider"""
    try:
        # Fall back to the bundled samples if the files don't exist
        if video_input is None or not os.path.exists(video_input):
            video_input = DEFAULT_VIDEO_PATH
        if audio_input is None or not os.path.exists(audio_input):
            audio_input = DEFAULT_AUDIO_PATH

        # Get video duration
        video_reader = decord.VideoReader(video_input)
        video_duration = len(video_reader) / video_reader.get_avg_fps()

        # Get audio duration (raw_audio is framed at 25 fps)
        raw_audio = get_raw_audio(audio_input, 16000)
        audio_duration = len(raw_audio) / 25

        # Return the minimum of the two durations
        return min(video_duration, audio_duration)
    except Exception as e:
        print(f"Error getting max duration: {str(e)}")
        return 60  # Default to 60 seconds
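# len(raw_audio) / 25 works because get_raw_audio appears to return audio
# already grouped into one 640-sample row per 25 fps video frame, matching the
# "f s" layouts used in process_video above.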
def new_game_click(state):
    """Handle the 'New Game' button click."""
    message = state.new_game()
    feedback_history = state.get_feedback_history()
    return state, feedback_history, message


def submit_guess_click(guess, state):
    """Handle the 'Submit Guess' button click."""
    message = state.submit_guess(guess)
    feedback_history = state.get_feedback_history()
    return state, feedback_history, message
# Create Gradio interface
with gr.Blocks(title="Video Synchronization with Diffusion Models") as demo:
    gr.Markdown("# Video Synchronization with Diffusion Models")
    gr.Markdown(
        "Upload a video and an audio clip to generate a new video with the same visuals, lip-synced to the new audio."
    )

    with gr.Tabs():
        with gr.TabItem("Video Synchronization"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(
                        label="Input Video",
                        value=DEFAULT_VIDEO_PATH
                        if os.path.exists(DEFAULT_VIDEO_PATH)
                        else None,
                        width=512,
                        height=512,
                    )
                    audio_input = gr.Audio(
                        label="Input Audio",
                        type="filepath",
                        value=DEFAULT_AUDIO_PATH
                        if os.path.exists(DEFAULT_AUDIO_PATH)
                        else None,
                    )
                    max_duration = gr.State(value=60)  # Default max duration
                    max_seconds_slider = gr.Slider(
                        minimum=0,
                        maximum=60,  # Updated dynamically when inputs change
                        value=0,
                        step=1,
                        label="Max Duration (seconds, 0 = full length)",
                        info="Limit the processing duration (0 means use full length)",
                    )
                    process_button = gr.Button("Generate Synchronized Video")

                with gr.Column():
                    video_output = gr.Video(label="Output Video", width=512, height=512)

            # Update slider max value when inputs change
            def update_slider_max(video, audio):
                max_dur = get_max_duration(video, audio)
                return gr.update(maximum=max_dur)

            video_input.change(
                update_slider_max, [video_input, audio_input], [max_seconds_slider]
            )
            audio_input.change(
                update_slider_max, [video_input, audio_input], [max_seconds_slider]
            )

            process_button.click(
                fn=process_video,
                inputs=[video_input, audio_input, max_seconds_slider],
                outputs=video_output,
            )

        with gr.TabItem("Wordle Game"):
            state = gr.State(WordleGame())  # Persist the WordleGame instance
            guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
            submit_btn = gr.Button("Submit Guess")
            new_game_btn = gr.Button("New Game")
            feedback_display = gr.HTML(label="Guesses")
            message_display = gr.Textbox(
                label="Message", interactive=False, value="Click 'New Game' to start."
            )

            # Connect the 'New Game' button
            new_game_btn.click(
                fn=new_game_click,
                inputs=[state],
                outputs=[state, feedback_display, message_display],
            )
            # Connect the 'Submit Guess' button
            submit_btn.click(
                fn=submit_guess_click,
                inputs=[guess_input, state],
                outputs=[state, feedback_display, message_display],
            )

    gr.Markdown("## How it works")
    gr.Markdown("""
    1. The system extracts VAE embeddings and facial landmarks from the input video
    2. HuBERT and WavLM embeddings are computed from the input audio
    3. A keyframe model generates key visual frames
    4. An interpolation model fills in smooth video between the keyframes
    5. The final video is rendered with the new audio
    """)

if __name__ == "__main__":
    demo.launch()