import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from einops import rearrange, repeat

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))

def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

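# Illustrative usage sketch (not part of the original file): shows how `get_batch` tiles
# scalar conditioning values across the batch dimensions N = (batch, num_frames). The key
# list, tensor shapes, and device below are assumptions chosen for demonstration only.
def _example_get_batch() -> None:
    keys = ["fps_id", "cond_aug", "cond_frames", "cond_frames_without_noise"]
    value_dict = {
        "fps_id": 6,
        "cond_aug": 0.02,
        "cond_frames": torch.randn(1, 4, 64, 64),
        "cond_frames_without_noise": torch.randn(1, 4, 64, 64),
    }
    batch, batch_uc = get_batch(keys, value_dict, N=(1, 14), T=14, device="cpu")
    assert batch["fps_id"].shape == (14,)  # scalar repeated math.prod(N) = 14 times
    assert batch["cond_frames"].shape == (1, 4, 64, 64)  # broadcast to N[0] batch entries
    assert batch["num_video_frames"] == 14
    assert torch.equal(batch_uc["cond_aug"], batch["cond_aug"])  # uc batch clones tensors
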
def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
    """
    Merges overlapping segments by averaging the overlapping frames.

    Segments have shape (b, t, ...), where 'b' is the number of segments,
    't' is the number of frames per segment, and '...' are the remaining dimensions.

    Args:
        segments: Tensor of shape (b, t, ...)
        overlap: Number of frames that overlap between consecutive segments

    Returns:
        Tensor containing the merged video
    """
    # Get the shape details
    b, t, *other_dims = segments.shape
    # Total number of frames in the merged video
    num_frames = (b - 1) * (t - overlap) + t

    # Initialize the output tensor and a count tensor to keep track of contributions for averaging
    output_shape = [num_frames] + other_dims
    output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
    count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)

    current_index = 0
    for i in range(b):
        end_index = current_index + t
        # Add the segment to the output tensor
        output[current_index:end_index] += segments[i]
        # Increment the count for each frame that received a contribution
        count[current_index:end_index] += 1
        # Update the starting index for the next segment
        current_index += t - overlap

    # Avoid division by zero
    count[count == 0] = 1
    # Average the frames where segments overlap
    output /= count
    return output

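# Illustrative example (not part of the original file): two 3-frame segments that share one
# frame merge into a 5-frame video, with the shared frame averaged. Shapes are assumptions.
def _example_merge_overlapping_segments() -> None:
    seg_a = torch.zeros(3, 1, 2, 2)  # frames 0-2
    seg_b = torch.ones(3, 1, 2, 2)  # frames 2-4 (frame 2 overlaps)
    segments = torch.stack([seg_a, seg_b])  # shape (b=2, t=3, c=1, h=2, w=2)
    merged = merge_overlapping_segments(segments, overlap=1)
    assert merged.shape == (5, 1, 2, 2)  # (2 - 1) * (3 - 1) + 3 = 5 frames
    assert merged[2].mean().item() == 0.5  # overlapping frame is the average of 0 and 1
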
def get_batch_overlap(
    keys: List[str],
    value_dict: Dict[str, Any],
    N: Tuple[int, ...],
    T: Optional[int],
    device: str,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Create a batch dictionary with overlapping frames for model input.

    Args:
        keys: List of keys to include in the batch
        value_dict: Dictionary containing values for each key
        N: Batch dimensions
        T: Number of frames (optional)
        device: Device to place tensors on

    Returns:
        Tuple of (batch dictionary, unconditional batch dictionary)
    """
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

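# Illustrative comparison (not part of the original file): `get_batch_overlap` differs from
# `get_batch` only in how the conditioning frames are tiled, repeating each existing batch
# entry t times ("b ... -> (b t) ...") instead of broadcasting a single frame to the batch.
# The shapes below are assumptions for demonstration only.
def _example_get_batch_overlap() -> None:
    keys = ["cond_aug", "cond_frames"]
    value_dict = {
        "cond_aug": 0.02,
        # e.g. 3 interpolation segments, each conditioned on a [first, last] latent pair
        "cond_frames": torch.randn(3, 4, 2, 64, 64),
    }
    batch, _ = get_batch_overlap(keys, value_dict, N=(1, 14), T=14, device="cpu")
    # With N[0] == 1 the "(b t)" expansion is an identity copy of the batch axis.
    assert batch["cond_frames"].shape == (3, 4, 2, 64, 64)
    assert batch["cond_aug"].shape == (14,)
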
def sample_keyframes(
    model_keyframes: Any,
    audio_list: torch.Tensor,
    gt_list: torch.Tensor,
    masks_list: torch.Tensor,
    condition: torch.Tensor,
    num_frames: int,
    fps_id: int,
    cond_aug: float,
    device: str,
    embbedings: Optional[torch.Tensor],
    force_uc_zero_embeddings: List[str],
    n_batch_keyframes: int,
    added_frames: int,
    strength: float,
    scale: Optional[Union[float, List[float]]],
    gt_as_cond: bool = False,
) -> torch.Tensor:
    """
    Sample keyframes using the keyframe generation model.

    Args:
        model_keyframes: The keyframe generation model
        audio_list: Batched audio embeddings, one entry per keyframe chunk
        gt_list: Batched ground-truth frames
        masks_list: Batched masks
        condition: Conditioning tensor
        num_frames: Number of frames to generate per chunk
        fps_id: FPS ID
        cond_aug: Conditioning augmentation factor
        device: Device to use for computation
        embbedings: Optional embeddings used as noisy conditioning frames instead of `condition`
        force_uc_zero_embeddings: Embedding keys to zero out in the unconditional branch
        n_batch_keyframes: Batch size for keyframe generation
        added_frames: Number of padding frames to trim from the end of the result
        strength: Strength parameter for sampling
        scale: Classifier-free guidance scale (scalar or per-frame list)
        gt_as_cond: Whether to use the first ground-truth frame as conditioning

    Returns:
        Latent samples for the generated keyframes
    """
    if scale is not None:
        model_keyframes.sampler.guider.set_scale(scale)

    # samples_list = []
    samples_z_list = []
    # samples_x_list = []

    for i in range(audio_list.shape[0]):
        H, W = condition.shape[-2:]
        assert condition.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)

        audio_cond = audio_list[i].unsqueeze(0)

        value_dict: Dict[str, Any] = {}
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = condition
        if embbedings is not None:
            value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
                embbedings
            )
        else:
            value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
                condition
            )

        gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
        if gt_as_cond:
            value_dict["cond_frames"] = gt[:, :, 0]

        value_dict["cond_aug"] = cond_aug
        value_dict["audio_emb"] = audio_cond
        value_dict["gt"] = gt
        value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)

        with torch.no_grad():
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
                value_dict,
                [1, 1],
                T=num_frames,
                device=device,
            )

            c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch_keyframes, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model_keyframes.denoiser(
                    model_keyframes.model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs,
                )

            samples_z = model_keyframes.sampler(
                denoiser, video, cond=c, uc=uc, strength=strength
            )
            samples_z_list.append(samples_z)
            # samples_x_list.append(samples_x)
            # samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            # samples_list.append(samples)

            video = None

    # samples = (
    #     torch.concat(samples_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_list)
    # )
    samples_z = (
        torch.concat(samples_z_list)[:-added_frames]
        if added_frames > 0
        else torch.concat(samples_z_list)
    )
    # samples_x = (
    #     torch.concat(samples_x_list)[:-added_frames]
    #     if added_frames > 0
    #     else torch.concat(samples_x_list)
    # )

    return samples_z

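# Illustrative sketch (not part of the original file) of the cross-attention expansion used
# above: a per-batch conditioning tensor is repeated to one entry per frame and flattened to
# a (b t) leading axis, matching the layout of the latent video fed to the denoiser. The
# token and embedding dimensions below are assumptions chosen for demonstration only.
def _example_expand_crossattn(num_frames: int = 14) -> torch.Tensor:
    crossattn = torch.randn(1, 1, 1024)  # (batch, tokens, dim)
    expanded = repeat(crossattn, "b ... -> b t ...", t=num_frames)
    expanded = rearrange(expanded, "b t ... -> (b t) ...", t=num_frames)
    assert expanded.shape == (num_frames, 1, 1024)
    return expanded
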
def sample_interpolation(
    model: Any,
    samples_z: torch.Tensor,
    # samples_x: torch.Tensor,
    audio_interpolation_list: List[torch.Tensor],
    gt_chunks: List[torch.Tensor],
    masks_chunks: List[torch.Tensor],
    condition: torch.Tensor,
    num_frames: int,
    device: str,
    overlap: int,
    fps_id: int,
    cond_aug: float,
    force_uc_zero_embeddings: List[str],
    n_batch: int,
    chunk_size: Optional[int],
    strength: float,
    scale: Optional[float] = None,
    cut_audio: bool = False,
    to_remove: List[bool] = [],
) -> np.ndarray:
    """
    Sample interpolation frames between keyframes.

    Args:
        model: The interpolation model
        samples_z: Latent keyframe samples from the keyframe generation stage
        audio_interpolation_list: List of audio embeddings for interpolation
        gt_chunks: Ground-truth video chunks
        masks_chunks: Mask chunks for conditional generation
        condition: Visual conditioning
        num_frames: Number of frames to generate per segment
        device: Device to run inference on
        overlap: Number of frames to overlap between segments
        fps_id: FPS ID for conditioning
        cond_aug: Conditioning augmentation strength
        force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
        n_batch: Batch size for generation
        chunk_size: Size of chunks for processing (to manage memory)
        strength: Strength of the conditioning
        scale: Optional scale for classifier-free guidance
        cut_audio: Whether to truncate audio embeddings to the first 768 channels
        to_remove: Flags indicating which keyframes to drop before interpolation

    Returns:
        Generated video frames as a uint8 numpy array
    """
    if scale is not None:
        model.sampler.guider.set_scale(scale)

    # Create the conditioning for the interpolation model: a list of [first, last] keyframe
    # pairs, where each pair bounds one interpolated segment.
    # interpolation_cond_list = []
    interpolation_cond_list_emb = []

    # samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
    samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]

    for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
        # interpolation_cond_list.append(
        #     torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
        # )
        interpolation_cond_list_emb.append(
            torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
        )

    # condition = torch.stack(interpolation_cond_list).to(device)
    audio_cond = torch.stack(audio_interpolation_list).to(device)
    embbedings = torch.stack(interpolation_cond_list_emb).to(device)
    gt_chunks = torch.stack(gt_chunks).to(device)
    masks_chunks = torch.stack(masks_chunks).to(device)

    H, W = 512, 512
    F = 8
    C = 4
    shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)

    value_dict: Dict[str, Any] = {}
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    # value_dict["cond_frames_without_noise"] = condition
    value_dict["cond_frames"] = embbedings
    value_dict["cond_aug"] = cond_aug
    if cut_audio:
        value_dict["audio_emb"] = audio_cond[:, :, :, :768]
    else:
        value_dict["audio_emb"] = audio_cond
    value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
    value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)

    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch_overlap(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )

            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in ["crossattn"]:
                if c[k].shape[1] != num_frames:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            video = torch.randn(shape, device=device)

            additional_model_inputs: Dict[str, torch.Tensor] = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                n_batch, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            # Debug information
            print(
                f"Shapes - Embeddings: {embbedings.shape}, "
                f"Audio: {audio_cond.shape}, Video: {shape}, "
                f"Additional inputs: {additional_model_inputs}"
            )

            if chunk_size is not None:
                chunk_size = chunk_size * num_frames

            def denoiser(
                input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
            ) -> torch.Tensor:
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                    num_overlap_frames=overlap,
                    num_frames=num_frames,
                    n_skips=n_batch,
                    chunk_size=chunk_size,
                    **additional_model_inputs,
                )

            samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
            samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
            # Pin the first and last frame of each segment to the keyframe latents
            samples_z[:, 0] = embbedings[:, :, 0]
            samples_z[:, -1] = embbedings[:, :, 1]
            samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")

            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            # Free up memory
            video = None

    samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
    samples = merge_overlapping_segments(samples, overlap)

    vid = (samples * 255).cpu().numpy().astype(np.uint8)

    return vid
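

# Illustrative sketch (not part of the original file) of the interpolation conditioning built
# above: consecutive keyframe latents are paired into [first, last] anchors, one pair per
# interpolated segment, and the decoded segments are later fused by
# `merge_overlapping_segments`. Keyframe count, latent shape, and overlap are assumptions.
def _example_keyframe_pairs(num_keyframes: int = 5, overlap: int = 1) -> torch.Tensor:
    keyframes = [torch.randn(4, 64, 64) for _ in range(num_keyframes)]  # per-keyframe latents
    pairs = []
    for i in range(0, len(keyframes) - 1, overlap if overlap > 0 else 2):
        pairs.append(torch.stack([keyframes[i], keyframes[i + 1]], dim=1))
    embbedings = torch.stack(pairs)  # (segments, c, 2, h, w)
    assert embbedings.shape == (num_keyframes - 1, 4, 2, 64, 64)
    return embbedings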