# keysync-demo / app.py
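"""Gradio demo for KeySync-style audio-driven video synchronization (dubbing).

Pipeline, as implemented below:
  1. Encode the input video into VAE latents and extract facial landmarks.
  2. Compute HuBERT and WavLM embeddings from the input audio.
  3. A keyframe model generates key frames conditioned on the audio.
  4. An interpolation model fills in the frames between keyframes.
  5. The generated frames are muxed with the new audio into the output video.

A Wordle game is provided in a second tab to pass the time while generation runs.
"""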
import gradio as gr
import torch
import tempfile
import os
from vae_wrapper import VaeWrapper, encode_video_chunk
from landmarks_extractor import LandmarksExtractor
import decord
from utils import (
get_raw_audio,
save_audio_video,
calculate_splits,
instantiate_from_config,
create_pipeline_inputs,
)
from transformers import HubertModel
from einops import rearrange
import numpy as np
from WavLM import WavLM_wrapper
from omegaconf import OmegaConf
from inference_functions import (
sample_keyframes,
sample_interpolation,
)
from wordle_game import WordleGame
import torch.cuda.amp as amp # Import amp for mixed precision
# Enable TF32 precision for faster computation on Ampere+ GPUs
# (models are converted to FP16 individually below)
if torch.cuda.is_available():
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Cache for video and audio processing
cache = {
"video": {
"path": None,
"embedding": None,
"frames": None,
"landmarks": None,
},
"audio": {
"path": None,
"raw_audio": None,
"hubert_embedding": None,
"wavlm_embedding": None,
},
}
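# Cache entries are keyed by file path, so embeddings are only recomputed when a
# different video or audio file is provided (see process_video below).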
# Mixed precision gradient scaler (not actually used during inference below)
scaler = amp.GradScaler()
def load_model(
config: str,
device: str = "cuda",
ckpt: str = None,
):
"""
Load a model from configuration.
Args:
config: Path to model configuration file
device: Device to load the model on
num_frames: Number of frames to process
input_key: Input key for the model
ckpt: Optional checkpoint path
Returns:
Tuple of (model, filter, batch size)
"""
config = OmegaConf.load(config)
config["model"]["params"]["input_key"] = "latents"
if ckpt is not None:
config.model.params.ckpt_path = ckpt
with torch.device(device):
model = instantiate_from_config(config.model).to(device).eval()
# Convert model to half precision
if torch.cuda.is_available():
model = model.half()
model.first_stage_model = model.first_stage_model.float()
print("Converted model to FP16 precision")
# Compile model for faster inference
if torch.cuda.is_available():
try:
model = torch.compile(model)
print(f"Successfully compiled model with torch.compile()")
except Exception as e:
print(f"Warning: Failed to compile model: {e}")
return model
# keyframe_model = KeyframeModel(device=device)
# interpolation_model = InterpolationModel(device=device)
vae_model = VaeWrapper("video")
if torch.cuda.is_available():
vae_model = vae_model.half() # Convert to half precision
try:
vae_model = torch.compile(vae_model)
print("Successfully compiled vae_model in FP16")
except Exception as e:
print(f"Warning: Failed to compile vae_model: {e}")
hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
if torch.cuda.is_available():
hubert_model = hubert_model.half() # Convert to half precision
try:
hubert_model = torch.compile(hubert_model)
print("Successfully compiled hubert_model in FP16")
except Exception as e:
print(f"Warning: Failed to compile hubert_model: {e}")
wavlm_model = WavLM_wrapper(
model_size="Base+",
feed_as_frames=False,
merge_type="None",
model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
).cuda()
if torch.cuda.is_available():
wavlm_model = wavlm_model.half() # Convert to half precision
try:
wavlm_model = torch.compile(wavlm_model)
print("Successfully compiled wavlm_model in FP16")
except Exception as e:
print(f"Warning: Failed to compile wavlm_model: {e}")
landmarks_extractor = LandmarksExtractor()
keyframe_model = load_model(
    config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
    ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
)
interpolation_model = load_model(
    config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
    ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
)
# keyframe_model.en_and_decode_n_samples_a_time = 2
# interpolation_model.en_and_decode_n_samples_a_time = 2
# Default media paths
DEFAULT_VIDEO_PATH = os.path.join(
os.path.dirname(__file__), "assets", "sample_video.mp4"
)
DEFAULT_AUDIO_PATH = os.path.join(
os.path.dirname(__file__), "assets", "sample_audio.wav"
)
@torch.no_grad()
def compute_video_embedding(video_reader, min_len):
"""Compute embeddings from video"""
total_frames = min_len
encoded = []
video_frames = []
chunk_size = 16
resolution = 512
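    # Frames are processed in 16-frame chunks and resized to 512x512 before being
    # encoded into latents by the VAE.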
    # Create a progress bar for Gradio
progress = gr.Progress()
# Calculate total chunks for progress tracking
total_chunks = (total_frames + chunk_size - 1) // chunk_size
for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
# Update progress bar
progress(i / total_chunks, desc="Processing video chunks")
end_idx = min(start_idx + chunk_size, total_frames)
video_chunk = video_reader.get_batch(range(start_idx, end_idx))
# Interpolate video chunk to the target resolution
video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
video_chunk = torch.nn.functional.interpolate(
video_chunk,
size=(resolution, resolution),
mode="bilinear",
align_corners=False,
)
video_chunk = rearrange(video_chunk, "f c h w -> f h w c")
video_frames.append(video_chunk)
# Convert chunk to FP16 if using CUDA
if torch.cuda.is_available():
video_chunk = video_chunk.half()
# Always use autocast for FP16 computation
with amp.autocast(enabled=True):
encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))
encoded = torch.cat(encoded, dim=0)
video_frames = torch.cat(video_frames, dim=0)
video_frames = rearrange(video_frames, "f h w c -> f c h w")
torch.cuda.empty_cache()
return encoded, video_frames
@torch.no_grad()
def compute_hubert_embedding(raw_audio):
"""Compute embeddings from audio"""
print(f"Computing audio embedding from {raw_audio.shape}")
audio = (
(raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
).unsqueeze(0)
chunks = 16000 * 20
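    # 20-second chunks at the 16 kHz sampling rate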
# Create a progress bar for Gradio
progress = gr.Progress()
# Get audio embeddings
audio_embeddings = []
splits = list(calculate_splits(audio, chunks))
total_splits = len(splits)
for i, chunk in enumerate(splits):
# Update progress bar
progress(i / total_splits, desc="Processing audio chunks")
# Convert audio chunk to half precision
if torch.cuda.is_available():
chunk_cuda = chunk.cuda().half()
else:
chunk_cuda = chunk.cuda()
# Always use autocast for FP16 computation
with amp.autocast(enabled=True):
hidden_states = hubert_model(chunk_cuda)[0]
audio_embeddings.append(hidden_states)
audio_embeddings = torch.cat(audio_embeddings, dim=1)
# audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]
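    # HuBERT yields roughly two ~20 ms feature vectors per 25 fps video frame, so pad
    # the sequence to an even length and group the features in pairs, one pair per frame.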
if audio_embeddings.shape[1] % 2 != 0:
audio_embeddings = torch.cat(
[audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
)
audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
torch.cuda.empty_cache()
return audio_embeddings
@torch.no_grad()
def compute_wavlm_embedding(raw_audio):
"""Compute embeddings from audio"""
audio = rearrange(raw_audio, "(f s) -> f s", s=640)
if audio.shape[0] % 2 != 0:
audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
chunks = 500
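    # 500 frames per WavLM forward pass (about 20 s of audio at 25 fps)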
# Create a progress bar for Gradio
progress = gr.Progress()
# Get audio embeddings
audio_embeddings = []
splits = list(calculate_splits(audio, chunks))
total_splits = len(splits)
for i, chunk in enumerate(splits):
# Update progress bar
progress(i / total_splits, desc="Processing audio chunks")
# Convert chunk to half precision
if torch.cuda.is_available():
chunk_cuda = chunk.unsqueeze(0).cuda().half()
else:
chunk_cuda = chunk.unsqueeze(0).cuda()
# Always use autocast for FP16 computation
with amp.autocast(enabled=True):
wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
audio_embeddings.append(wavlm_hidden_states)
audio_embeddings = torch.cat(audio_embeddings, dim=0)
torch.cuda.empty_cache()
return audio_embeddings
@torch.no_grad()
def extract_video_landmarks(video_frames):
"""Extract landmarks from video frames"""
# Create a progress bar for Gradio
progress = gr.Progress()
landmarks = []
batch_size = 10
for i in range(0, len(video_frames), batch_size):
# Update progress bar
progress(i / len(video_frames), desc="Extracting facial landmarks")
batch = video_frames[i : i + batch_size].cpu().float()
batch_landmarks = landmarks_extractor.extract_landmarks(batch)
landmarks.extend(batch_landmarks)
torch.cuda.empty_cache()
# Convert landmarks to a list of numpy arrays with consistent shape
processed_landmarks = []
expected_shape = (68, 2) # Common shape for facial landmarks
# Process each landmark to ensure consistent shape
last_valid_landmark = None
for i, lm in enumerate(landmarks):
if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
processed_landmarks.append(lm)
last_valid_landmark = lm
else:
# Print information about inconsistent landmarks
if lm is None:
print(f"Warning: Landmark at index {i} is None")
elif not isinstance(lm, np.ndarray):
print(
f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
)
elif lm.shape != expected_shape:
print(
f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
)
# Replace invalid landmarks with the closest valid landmark if available
if last_valid_landmark is not None:
processed_landmarks.append(last_valid_landmark.copy())
else:
# If no valid landmark has been seen yet, look ahead for a valid one
found_future_valid = False
for future_lm in landmarks[i + 1 :]:
if (
future_lm is not None
and isinstance(future_lm, np.ndarray)
and future_lm.shape == expected_shape
):
processed_landmarks.append(future_lm.copy())
found_future_valid = True
break
# If no valid landmark found in the future, use zeros
if not found_future_valid:
processed_landmarks.append(np.zeros(expected_shape))
return np.array(processed_landmarks)
@torch.no_grad()
def sample(
audio_list,
gt_keyframes,
masks_keyframes,
to_remove,
test_keyframes_list,
num_frames,
device,
emb,
force_uc_zero_embeddings,
n_batch_keyframes,
n_batch,
test_interpolation_list,
audio_interpolation_list,
masks_interpolation,
gt_interpolation,
model_keyframes,
model,
):
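    """
    Generate the output video chunk by chunk: the keyframe model first produces
    keyframe latents for each chunk, then the interpolation model fills in the
    frames between them. Chunks are stitched by carrying over the last keyframe
    latent of the previous chunk.
    """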
# Create a progress bar for Gradio
progress = gr.Progress()
condition = torch.zeros(1, 3, 512, 512).to(device)
if torch.cuda.is_available():
condition = condition.half()
audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
# Rearrange masks_keyframes and save locally
masks_keyframes = rearrange(
masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
)
# Convert to_remove into chunks of num_frames
to_remove_chunks = [
to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
]
test_keyframes_list = [
test_keyframes_list[i : i + num_frames]
for i in range(0, len(test_keyframes_list), num_frames)
]
audio_cond = audio_list
    if emb is not None:
        embeddings = emb.unsqueeze(0).to(device)
        if torch.cuda.is_available():
            embeddings = embeddings.half()
    else:
        embeddings = None
    # One batch of keyframes is approximately 7 seconds
chunk_size = 2
complete_video = []
start_idx = 0
last_frame_z = None
last_frame_x = None
last_keyframe_idx = None
last_to_remove = None
total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size
for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
# Update progress bar
progress(chunk_idx / total_chunks, desc="Generating video")
# Clear GPU cache between chunks
torch.cuda.empty_cache()
chunk_end = min(chunk_start + chunk_size, len(audio_cond))
chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
if torch.cuda.is_available():
chunk_audio_cond = chunk_audio_cond.half()
chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
if torch.cuda.is_available():
chunk_gt_keyframes = chunk_gt_keyframes.half()
chunk_masks = chunk_masks.half()
test_keyframes_list_unwrapped = [
elem
for sublist in test_keyframes_list[chunk_start:chunk_end]
for elem in sublist
]
to_remove_chunks_unwrapped = [
elem
for sublist in to_remove_chunks[chunk_start:chunk_end]
for elem in sublist
]
if last_keyframe_idx is not None:
test_keyframes_list_unwrapped = [
last_keyframe_idx
] + test_keyframes_list_unwrapped
to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped
last_keyframe_idx = test_keyframes_list_unwrapped[-1]
last_to_remove = to_remove_chunks_unwrapped[-1]
# Find the first non-None keyframe in the chunk
first_keyframe = next(
(kf for kf in test_keyframes_list_unwrapped if kf is not None), None
)
# Find the last non-None keyframe in the chunk
last_keyframe = next(
(kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
None,
)
start_idx = next(
(
idx
for idx, comb in enumerate(test_interpolation_list)
if comb[0] == first_keyframe
),
None,
)
end_idx = next(
(
idx
for idx, comb in enumerate(reversed(test_interpolation_list))
if comb[1] == last_keyframe
),
None,
)
if start_idx is not None and end_idx is not None:
end_idx = (
len(test_interpolation_list) - 1 - end_idx
) # Adjust for reversed enumeration
end_idx += 1
if start_idx is None:
break
if end_idx < start_idx:
end_idx = len(audio_interpolation_list)
audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]
if torch.cuda.is_available():
audio_interpolation_list_chunk = [
chunk.half() for chunk in audio_interpolation_list_chunk
]
chunk_masks_interpolation = [
chunk.half() for chunk in chunk_masks_interpolation
]
gt_interpolation_chunks = [
chunk.half() for chunk in gt_interpolation_chunks
]
progress(chunk_idx / total_chunks, desc="Generating keyframes")
# Always use autocast for FP16 computation
with amp.autocast(enabled=True):
samples_z = sample_keyframes(
model_keyframes,
chunk_audio_cond,
chunk_gt_keyframes,
chunk_masks,
condition.cuda(),
num_frames,
24,
0.0,
device,
                embeddings.cuda() if embeddings is not None else None,
force_uc_zero_embeddings,
n_batch_keyframes,
0,
1.0,
None,
gt_as_cond=False,
)
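        # Prepend the last keyframe latent of the previous chunk so consecutive chunks
        # share a boundary frame; the duplicate is dropped when the chunks are joined
        # below via complete_video[:-1].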
        if last_frame_z is not None:
# samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)
# last_frame_x = samples_x[-1]
last_frame_z = samples_z[-1]
progress(chunk_idx / total_chunks, desc="Interpolating frames")
# Always use autocast for FP16 computation
with amp.autocast(enabled=True):
vid = sample_interpolation(
model,
samples_z,
# samples_x,
audio_interpolation_list_chunk,
gt_interpolation_chunks,
chunk_masks_interpolation,
condition.cuda(),
num_frames,
device,
1,
24,
0.0,
force_uc_zero_embeddings,
n_batch,
chunk_size,
1.0,
None,
cut_audio=False,
to_remove=to_remove_chunks_unwrapped,
)
if chunk_start == 0:
complete_video = vid
else:
complete_video = np.concatenate([complete_video[:-1], vid], axis=0)
return complete_video
def process_video(video_input, audio_input, max_num_seconds):
"""Main processing function to generate synchronized video"""
# Display a message to the user about the processing time
gr.Info("Processing video. This may take a while...", duration=10)
gr.Info(
"If you're tired of waiting, try playing the Wordle game in the other tab!",
duration=10,
)
# Use default media if none provided
if video_input is None:
video_input = DEFAULT_VIDEO_PATH
print(f"Using default video: {DEFAULT_VIDEO_PATH}")
if audio_input is None:
audio_input = DEFAULT_AUDIO_PATH
print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
try:
# Calculate hashes for cache keys
video_path_hash = video_input
audio_path_hash = audio_input
# Check if we need to recompute video embeddings
video_cache_hit = cache["video"]["path"] == video_path_hash
audio_cache_hit = cache["audio"]["path"] == audio_path_hash
if video_cache_hit and audio_cache_hit:
print("Using cached video and audio computations")
# Make copies of cached data to avoid modifying cache
video_embedding = cache["video"]["embedding"].clone()
video_frames = cache["video"]["frames"].clone()
video_landmarks = cache["video"]["landmarks"].copy()
raw_audio = cache["audio"]["raw_audio"].clone()
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
hubert_embedding = cache["audio"]["hubert_embedding"].clone()
wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
# Ensure all data is truncated to the same length if needed
min_len = min(
len(video_frames),
len(raw_audio),
len(hubert_embedding),
len(wavlm_embedding),
)
video_frames = video_frames[:min_len]
video_embedding = video_embedding[:min_len]
video_landmarks = video_landmarks[:min_len]
raw_audio = raw_audio[:min_len]
hubert_embedding = hubert_embedding[:min_len]
wavlm_embedding = wavlm_embedding[:min_len]
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
else:
# Process video if needed
if not video_cache_hit:
print("Computing video embeddings and landmarks")
video_reader = decord.VideoReader(video_input)
decord.bridge.set_bridge("torch")
if not audio_cache_hit:
# Need to process audio to determine min_len
raw_audio = get_raw_audio(audio_input, 16000)
if len(raw_audio) == 0 or len(video_reader) == 0:
raise ValueError("Empty audio or video input")
min_len = min(len(raw_audio), len(video_reader))
# Store full audio in cache
cache["audio"]["path"] = audio_path_hash
cache["audio"]["raw_audio"] = raw_audio.clone()
# Create truncated copy for processing
raw_audio = raw_audio[:min_len]
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
else:
# Use cached audio - make a copy
if cache["audio"]["raw_audio"] is None:
raise ValueError("Cached audio is None")
raw_audio = cache["audio"]["raw_audio"].clone()
if len(raw_audio) == 0 or len(video_reader) == 0:
raise ValueError("Empty cached audio or video input")
min_len = min(len(raw_audio), len(video_reader))
# Create truncated copy for processing
raw_audio = raw_audio[:min_len]
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
# Compute video embeddings and landmarks - store full version in cache
video_embedding, video_frames = compute_video_embedding(
video_reader, len(video_reader)
)
video_landmarks = extract_video_landmarks(video_frames)
# Update video cache with full versions
cache["video"]["path"] = video_path_hash
cache["video"]["embedding"] = video_embedding
cache["video"]["frames"] = video_frames
cache["video"]["landmarks"] = video_landmarks
# Create truncated copies for processing
video_embedding = video_embedding[:min_len]
video_frames = video_frames[:min_len]
video_landmarks = video_landmarks[:min_len]
else:
# Use cached video data - make copies
print("Using cached video computations")
if (
cache["video"]["embedding"] is None
or cache["video"]["frames"] is None
or cache["video"]["landmarks"] is None
):
raise ValueError("One or more video cache entries are None")
if not audio_cache_hit:
# New audio with cached video
raw_audio = get_raw_audio(audio_input, 16000)
if len(raw_audio) == 0:
raise ValueError("Empty audio input")
# Store full audio in cache
cache["audio"]["path"] = audio_path_hash
cache["audio"]["raw_audio"] = raw_audio.clone()
# Make copies of video data
video_embedding = cache["video"]["embedding"].clone()
video_frames = cache["video"]["frames"].clone()
video_landmarks = cache["video"]["landmarks"].copy()
# Determine truncation length and create truncated copies
min_len = min(len(raw_audio), len(video_frames))
raw_audio = raw_audio[:min_len]
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
video_frames = video_frames[:min_len]
video_embedding = video_embedding[:min_len]
video_landmarks = video_landmarks[:min_len]
else:
# Both video and audio are cached - should not reach here
# as it's handled in the first if statement
pass
# Process audio if needed
if not audio_cache_hit:
print("Computing audio embeddings")
# Compute audio embeddings with the truncated audio
hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
# Update audio cache with full embeddings
# Note: raw_audio was already cached above
cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
else:
# Use cached audio data - make copies
if (
cache["audio"]["hubert_embedding"] is None
or cache["audio"]["wavlm_embedding"] is None
):
raise ValueError(
"One or more audio embedding cache entries are None"
)
hubert_embedding = cache["audio"]["hubert_embedding"].clone()
wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
# Make sure embeddings match the truncated video length if needed
if "min_len" in locals() and (
min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
):
hubert_embedding = hubert_embedding[:min_len]
wavlm_embedding = wavlm_embedding[:min_len]
# Apply max_num_seconds limit if specified
if max_num_seconds > 0:
# Convert seconds to frames (assuming 25 fps)
max_frames = int(max_num_seconds * 25)
# Truncate all data to max_frames
video_embedding = video_embedding[:max_frames]
video_frames = video_frames[:max_frames]
video_landmarks = video_landmarks[:max_frames]
hubert_embedding = hubert_embedding[:max_frames]
wavlm_embedding = wavlm_embedding[:max_frames]
raw_audio = raw_audio[:max_frames]
raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
# Validate shapes before proceeding
assert video_embedding.shape[0] == hubert_embedding.shape[0], (
f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
)
assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
)
assert video_embedding.shape[0] == video_landmarks.shape[0], (
f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
)
print(f"Hubert embedding shape: {hubert_embedding.shape}")
print(f"WavLM embedding shape: {wavlm_embedding.shape}")
print(f"Video embedding shape: {video_embedding.shape}")
print(f"Video landmarks shape: {video_landmarks.shape}")
# Create pipeline inputs for models
(
interpolation_chunks,
keyframe_chunks,
audio_interpolation_chunks,
audio_keyframe_chunks,
emb_cond,
masks_keyframe_chunks,
masks_interpolation_chunks,
to_remove,
audio_interpolation_idx,
audio_keyframe_idx,
) = create_pipeline_inputs(
hubert_embedding,
wavlm_embedding,
14,
video_embedding,
video_landmarks,
overlap=1,
add_zero_flag=True,
mask_arms=None,
nose_index=28,
)
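        # The positional constants passed to sample() below map onto its signature:
        # num_frames=14, device="cuda", emb=emb_cond, force_uc_zero_embeddings=[],
        # n_batch_keyframes=3, n_batch=3.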
complete_video = sample(
audio_keyframe_chunks,
keyframe_chunks,
masks_keyframe_chunks,
to_remove,
audio_keyframe_idx,
14,
"cuda",
emb_cond,
[],
3,
3,
audio_interpolation_idx,
audio_interpolation_chunks,
masks_interpolation_chunks,
interpolation_chunks,
keyframe_model,
interpolation_model,
)
complete_audio = rearrange(
raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
)
        # Convert frames to video and combine with audio
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
output_path = temp_video.name
print("Saving video to", output_path)
save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
torch.cuda.empty_cache()
return output_path
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        raise
def get_max_duration(video_input, audio_input):
"""Get the maximum duration in seconds for the slider"""
try:
# Default to 60 seconds if files don't exist
if video_input is None or not os.path.exists(video_input):
video_input = DEFAULT_VIDEO_PATH
if audio_input is None or not os.path.exists(audio_input):
audio_input = DEFAULT_AUDIO_PATH
# Get video duration
video_reader = decord.VideoReader(video_input)
video_duration = len(video_reader) / video_reader.get_avg_fps()
# Get audio duration
raw_audio = get_raw_audio(audio_input, 16000)
audio_duration = len(raw_audio) / 25 # Assuming 25 fps
# Return the minimum of the two durations
return min(video_duration, audio_duration)
except Exception as e:
print(f"Error getting max duration: {str(e)}")
return 60 # Default to 60 seconds
def new_game_click(state):
"""Handle the 'New Game' button click."""
message = state.new_game()
feedback_history = state.get_feedback_history()
return state, feedback_history, message
def submit_guess_click(guess, state):
"""Handle the 'Submit Guess' button click."""
message = state.submit_guess(guess)
feedback_history = state.get_feedback_history()
return state, feedback_history, message
# Create Gradio interface
with gr.Blocks(title="Video Synchronization with Diffusion Models") as demo:
gr.Markdown("# Video Synchronization with Diffusion Models")
gr.Markdown(
"Upload a video and audio to create a synchronized video with the same visuals but synchronized to the new audio."
)
with gr.Tabs():
with gr.TabItem("Video Synchronization"):
with gr.Row():
with gr.Column():
video_input = gr.Video(
label="Input Video",
value=DEFAULT_VIDEO_PATH
if os.path.exists(DEFAULT_VIDEO_PATH)
else None,
width=512,
height=512,
)
audio_input = gr.Audio(
label="Input Audio",
type="filepath",
value=DEFAULT_AUDIO_PATH
if os.path.exists(DEFAULT_AUDIO_PATH)
else None,
)
max_duration = gr.State(value=60) # Default max duration
max_seconds_slider = gr.Slider(
minimum=0,
maximum=60, # Will be updated dynamically
value=0,
step=1,
label="Max Duration (seconds, 0 = full length)",
info="Limit the processing duration (0 means use full length)",
)
process_button = gr.Button("Generate Synchronized Video")
with gr.Column("Output Video"):
video_output = gr.Video(label="Output Video", width=512, height=512)
# Update slider max value when inputs change
def update_slider_max(video, audio):
max_dur = get_max_duration(video, audio)
return {"maximum": max_dur, "__type__": "update"}
video_input.change(
update_slider_max, [video_input, audio_input], [max_seconds_slider]
)
audio_input.change(
update_slider_max, [video_input, audio_input], [max_seconds_slider]
)
# Show Wordle message when processing starts and hide when complete
process_button.click(
fn=process_video,
inputs=[video_input, audio_input, max_seconds_slider],
outputs=video_output,
)
with gr.TabItem("Wordle Game"):
state = gr.State(WordleGame()) # Persist the WordleGame instance
guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
submit_btn = gr.Button("Submit Guess")
new_game_btn = gr.Button("New Game")
feedback_display = gr.HTML(label="Guesses")
message_display = gr.Textbox(
label="Message", interactive=False, value="Click 'New Game' to start."
)
# Connect the 'New Game' button
new_game_btn.click(
fn=new_game_click,
inputs=[state],
outputs=[state, feedback_display, message_display],
)
# Connect the 'Submit Guess' button
submit_btn.click(
fn=submit_guess_click,
inputs=[guess_input, state],
outputs=[state, feedback_display, message_display],
)
gr.Markdown("## How it works")
gr.Markdown("""
1. The system extracts embeddings and landmarks from the input video
2. Audio embeddings are computed from the input audio
3. A keyframe model generates key visual frames
4. An interpolation model creates a smooth video between keyframes
5. The final video is rendered with the new audio
""")
if __name__ == "__main__":
demo.launch()