import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from einops import rearrange, repeat


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
    """
    Merges overlapping segments by averaging overlapping frames.
    Segments have shape (b, t, ...), where 'b' is the number of segments,
    't' is frames per segment, and '...' are other dimensions.

    Args:
        segments: Tensor of shape (b, t, ...)
        overlap: Integer, number of frames that overlap between consecutive segments

    Returns:
        Tensor of the merged video
    """
    # Get the shape details
    b, t, *other_dims = segments.shape
    # Calculate the total number of frames in the merged video
    num_frames = (b - 1) * (t - overlap) + t

    # Initialize the output tensor and a count tensor to keep track of
    # contributions for averaging
    output_shape = [num_frames] + other_dims
    output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
    count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)

    current_index = 0
    for i in range(b):
        end_index = current_index + t
        # Add the segment to the output tensor
        output[current_index:end_index] += segments[i]
        # Increment the count tensor for each frame that's added
        count[current_index:end_index] += 1
        # Update the starting index for the next segment
        current_index += t - overlap

    # Avoid division by zero
    count[count == 0] = 1
    # Average the frames where there's overlap
    output /= count

    return output
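

# Illustrative usage sketch (not called by the pipeline; the tensor sizes below are
# made up): two 16-frame segments that share 4 frames merge into
# (2 - 1) * (16 - 4) + 16 = 28 frames, with the shared frames averaged.
def _example_merge_overlapping_segments() -> None:  # pragma: no cover
    segments = torch.randn(2, 16, 3, 64, 64)  # (b, t, c, h, w)
    merged = merge_overlapping_segments(segments, overlap=4)
    assert merged.shape == (28, 3, 64, 64)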


def get_batch_overlap(
    keys: List[str],
    value_dict: Dict[str, Any],
    N: Tuple[int, ...],
    T: Optional[int],
    device: str,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Create a batch dictionary with overlapping frames for model input.

    Args:
        keys: List of keys to include in the batch
        value_dict: Dictionary containing values for each key
        N: Batch dimensions
        T: Number of frames (optional)
        device: Device to place tensors on

    Returns:
        Tuple of (batch dictionary, unconditional batch dictionary)
    """
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(
                value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0]
            )
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
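

# Illustrative usage sketch (hypothetical keys and shapes): get_batch_overlap repeats
# "cond_frames" along a new time axis ("b ... -> (b t) ..."), whereas get_batch
# broadcasts a single conditioning frame across the batch ("1 ... -> b ...").
def _example_get_batch_overlap() -> None:  # pragma: no cover
    value_dict = {
        "fps_id": 6,
        "cond_aug": 0.02,
        "cond_frames": torch.randn(2, 4, 64, 64),  # two keyframe latents
        "cond_frames_without_noise": torch.randn(2, 4, 64, 64),
    }
    keys = list(value_dict.keys())
    batch, batch_uc = get_batch_overlap(keys, value_dict, N=(1, 14), T=14, device="cpu")
    assert batch["fps_id"].shape == (14,)  # repeated math.prod(N) times
    assert batch["cond_frames"].shape == (2, 4, 64, 64)  # t = N[0] = 1 keeps the shape
    assert batch["num_video_frames"] == 14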


@torch.inference_mode()
def sample_keyframes(
    model_keyframes: Any,
    audio_list: torch.Tensor,
    gt_list: torch.Tensor,
    masks_list: torch.Tensor,
    condition: torch.Tensor,
    num_frames: int,
    fps_id: int,
    cond_aug: float,
    device: str,
    embbedings: Optional[torch.Tensor],
    force_uc_zero_embeddings: List[str],
    n_batch_keyframes: int,
    added_frames: int,
    strength: float,
    scale: Optional[Union[float, List[float]]],
    gt_as_cond: bool = False,
) -> torch.Tensor:
    """
    Sample keyframes using the keyframe generation model.

    Args:
        model_keyframes: The keyframe generation model
        audio_list: List of audio embeddings
        gt_list: List of ground truth frames
        masks_list: List of masks
        condition: Conditioning tensor
        num_frames: Number of frames to generate
        fps_id: FPS ID
        cond_aug: Conditioning augmentation factor
        device: Device to use for computation
        embbedings: Optional embeddings
        force_uc_zero_embeddings: List of embeddings to force to zero in the unconditional case
        n_batch_keyframes: Batch size for keyframe generation
        added_frames: Number of additional frames
        strength: Strength parameter for sampling
        scale: Scale parameter for guidance
        gt_as_cond: Whether to use ground truth as conditioning

    Returns:
        Latent keyframe samples
    """
    if scale is not None:
        model_keyframes.sampler.guider.set_scale(scale)
    # samples_list = []
    samples_z_list = []
    # samples_x_list = []

    for i in range(audio_list.shape[0]):
        H, W = condition.shape[-2:]
        assert condition.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)

        audio_cond = audio_list[i].unsqueeze(0)

        value_dict: Dict[str, Any] = {}
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = condition
        if embbedings is not None:
            value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
                embbedings
            )
        else:
            value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
                condition
            )
        gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
        if gt_as_cond:
            value_dict["cond_frames"] = gt[:, :, 0]

        value_dict["cond_aug"] = cond_aug
        value_dict["audio_emb"] = audio_cond
        value_dict["gt"] = gt
        value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)

        with torch.no_grad():
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
                value_dict,
                [1, 1],
                T=num_frames,
                device=device,
            )

            c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch_keyframes, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model_keyframes.denoiser(
                    model_keyframes.model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs,
                )

            samples_z = model_keyframes.sampler(
                denoiser, video, cond=c, uc=uc, strength=strength
            )
            samples_z_list.append(samples_z)
            # samples_x_list.append(samples_x)
            # samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            # samples_list.append(samples)

            video = None

    # samples = (
    #     torch.concat(samples_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_list)
    # )
    samples_z = (
        torch.concat(samples_z_list)[:-added_frames]
        if added_frames > 0
        else torch.concat(samples_z_list)
    )
    # samples_x = (
    #     torch.concat(samples_x_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_x_list)
    # )

    return samples_z
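

# Illustrative sketch (hypothetical shapes): both samplers broadcast per-batch
# cross-attention conditioning to per-frame conditioning with a repeat followed by a
# rearrange, turning (b, ...) into (b * num_frames, ...).
def _example_broadcast_crossattn() -> None:  # pragma: no cover
    num_frames = 14
    crossattn = torch.randn(1, 1, 1024)  # (batch, tokens, dim)
    per_frame = repeat(crossattn, "b ... -> b t ...", t=num_frames)
    per_frame = rearrange(per_frame, "b t ... -> (b t) ...", t=num_frames)
    assert per_frame.shape == (num_frames, 1, 1024)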


@torch.inference_mode()
def sample_interpolation(
    model: Any,
    samples_z: torch.Tensor,
    # samples_x: torch.Tensor,
    audio_interpolation_list: List[torch.Tensor],
    gt_chunks: List[torch.Tensor],
    masks_chunks: List[torch.Tensor],
    condition: torch.Tensor,
    num_frames: int,
    device: str,
    overlap: int,
    fps_id: int,
    cond_aug: float,
    force_uc_zero_embeddings: List[str],
    n_batch: int,
    chunk_size: Optional[int],
    strength: float,
    scale: Optional[float] = None,
    cut_audio: bool = False,
    to_remove: List[bool] = [],
) -> np.ndarray:
    """
    Sample interpolation frames between keyframes.

    Args:
        model: The interpolation model
        samples_z: Latent samples from keyframe generation
        audio_interpolation_list: List of audio embeddings for interpolation
        gt_chunks: Ground truth video chunks
        masks_chunks: Mask chunks for conditional generation
        condition: Visual conditioning
        num_frames: Number of frames to generate
        device: Device to run inference on
        overlap: Number of frames to overlap between segments
        fps_id: FPS ID for conditioning
        cond_aug: Conditioning augmentation strength
        force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
        n_batch: Batch size for generation
        chunk_size: Size of chunks for processing (to manage memory)
        strength: Strength of the conditioning
        scale: Optional scale for classifier-free guidance
        cut_audio: Whether to cut audio embeddings
        to_remove: List of flags indicating which frames to remove

    Returns:
        Generated video frames as a numpy array
    """
    if scale is not None:
        model.sampler.guider.set_scale(scale)

    # Create the conditioning for the interpolation model: a list of inputs, each
    # holding [first, last], the first and last frames of that interpolation segment.
    # interpolation_cond_list = []
    interpolation_cond_list_emb = []

    # samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
    samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]

    for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
        # interpolation_cond_list.append(
        #     torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
        # )
        interpolation_cond_list_emb.append(
            torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
        )

    # condition = torch.stack(interpolation_cond_list).to(device)
    audio_cond = torch.stack(audio_interpolation_list).to(device)
    embbedings = torch.stack(interpolation_cond_list_emb).to(device)
    gt_chunks = torch.stack(gt_chunks).to(device)
    masks_chunks = torch.stack(masks_chunks).to(device)

    H, W = 512, 512
    F = 8
    C = 4
    shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)

    value_dict: Dict[str, Any] = {}
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    # value_dict["cond_frames_without_noise"] = condition
    value_dict["cond_frames"] = embbedings
    value_dict["cond_aug"] = cond_aug
    if cut_audio:
        value_dict["audio_emb"] = audio_cond[:, :, :, :768]
    else:
        value_dict["audio_emb"] = audio_cond
    value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
    value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch_overlap(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            # Debug information
            print(
                f"Shapes - Embeddings: {embbedings.shape}, "
                f"Audio: {audio_cond.shape}, Video: {shape}, "
                f"Additional inputs: {additional_model_inputs}"
            )

            if chunk_size is not None:
                chunk_size = chunk_size * num_frames

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                    num_overlap_frames=overlap,
                    num_frames=num_frames,
                    n_skips=n_batch,
                    chunk_size=chunk_size,
                    **additional_model_inputs,
                )

            samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
            samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
            # Pin the first and last frame of each segment to the keyframe latents
            samples_z[:, 0] = embbedings[:, :, 0]
            samples_z[:, -1] = embbedings[:, :, 1]
            samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")

            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            # Free up memory
            video = None

    samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
    samples = merge_overlapping_segments(samples, overlap)

    vid = (samples * 255).cpu().numpy().astype(np.uint8)

    return vid
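

# Illustrative sketch (hypothetical latent sizes): sample_interpolation pairs
# consecutive keyframe latents into [first, last] conditions. With 5 keyframes and
# overlap=1 the loop yields 4 pairs: (0, 1), (1, 2), (2, 3), (3, 4).
def _example_keyframe_pairing() -> None:  # pragma: no cover
    overlap = 1
    keyframes = [torch.randn(4, 64, 64) for _ in range(5)]  # (C, H // F, W // F)
    pairs = [
        torch.stack([keyframes[i], keyframes[i + 1]], dim=1)
        for i in range(0, len(keyframes) - 1, overlap if overlap > 0 else 2)
    ]
    assert len(pairs) == 4 and pairs[0].shape == (4, 2, 64, 64)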