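# Video/audio dataset utilities for audio-driven video generation.
# Loads fixed-length clips (raw frames or precomputed VAE latents), the
# corresponding audio (raw waveform or precomputed embeddings), and optional
# emotion labels and landmark-based face masks.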
import random
import numpy as np
from functools import partial
from torch.utils.data import Dataset, WeightedRandomSampler
import torch.nn.functional as F
import torch
import math
import decord
from einops import rearrange
from more_itertools import sliding_window
from omegaconf import ListConfig
import torchaudio
import soundfile as sf
from torchvision.transforms import RandomHorizontalFlip
from audiomentations import Compose, AddGaussianNoise, PitchShift
from safetensors.torch import load_file
from tqdm import tqdm
import cv2

from sgm.data.data_utils import (
    create_masks_from_landmarks_full_size,
    create_face_mask_from_landmarks,
    create_masks_from_landmarks_box,
    create_masks_from_landmarks_mouth,
)
from sgm.data.mask import face_mask_cheeks_batch

torchaudio.set_audio_backend("sox_io")
decord.bridge.set_bridge("torch")
def exists(x):
    return x is not None
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    """Pad or trim `audio` along its last dimension to a fixed length.

    The target length is `max_len_raw` samples if given, otherwise
    `max_len_sec * sr` samples; with neither, the audio is returned unchanged.
    """
    len_file = audio.shape[-1]
    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
# Similar to regular video dataset but trades flexibility for speed
class VideoDataset(Dataset):
    def __init__(
        self,
        filelist,
        resize_size=None,
        audio_folder="Audio",
        video_folder="CroppedVideos",
        emotions_folder="emotions",
        landmarks_folder=None,
        audio_emb_folder=None,
        video_extension=".avi",
        audio_extension=".wav",
        audio_rate=16000,
        latent_folder=None,
        audio_in_video=False,
        fps=25,
        num_frames=5,
        need_cond=True,
        step=1,
        mode="prediction",
        scale_audio=False,
        augment=False,
        augment_audio=False,
        use_latent=False,
        latent_type="stable",
        latent_scale=1,  # For backwards compatibility
        from_audio_embedding=False,
        load_all_possible_indexes=False,
        audio_emb_type="wavlm",
        cond_noise=[-3.0, 0.5],
        motion_id=255.0,
        data_mean=None,
        data_std=None,
        use_latent_condition=False,
        skip_frames=0,
        get_separate_id=False,
        virtual_increase=1,
        filter_by_length=False,
        select_randomly=False,
        balance_datasets=True,
        use_emotions=False,
        get_original_frames=False,
        add_extra_audio_emb=False,
        expand_box=0.0,
        nose_index=28,
        what_mask="full",
        get_masks=False,
    ):
        self.audio_folder = audio_folder
        self.from_audio_embedding = from_audio_embedding
        self.audio_emb_type = audio_emb_type
        self.cond_noise = cond_noise
        self.latent_condition = use_latent_condition
        precomputed_latent = latent_type
        self.audio_emb_folder = (
            audio_emb_folder if audio_emb_folder is not None else audio_folder
        )
        self.skip_frames = skip_frames
        self.get_separate_id = get_separate_id
        self.fps = fps
        self.virtual_increase = virtual_increase
        self.select_randomly = select_randomly
        self.use_emotions = use_emotions
        self.emotions_folder = emotions_folder
        self.get_original_frames = get_original_frames
        self.add_extra_audio_emb = add_extra_audio_emb
        self.expand_box = expand_box
        self.nose_index = nose_index
        self.landmarks_folder = landmarks_folder
        self.what_mask = what_mask
        self.get_masks = get_masks

        assert not (exists(data_mean) ^ exists(data_std)), (
            "Both data_mean and data_std should be provided"
        )
        if data_mean is not None:
            data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
            data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
        self.data_mean = data_mean
        self.data_std = data_std

        self.motion_id = motion_id
        self.latent_folder = (
            latent_folder if latent_folder is not None else video_folder
        )
        self.audio_in_video = audio_in_video

        self.filelist = []
        self.audio_filelist = []
        self.landmark_filelist = [] if get_masks else None
        with open(filelist, "r") as files:
            for f in files.readlines():
                f = f.rstrip()
                audio_path = f.replace(video_folder, audio_folder).replace(
                    video_extension, audio_extension
                )
                self.filelist += [f]
                self.audio_filelist += [audio_path]
                if self.get_masks:
                    landmark_path = f.replace(video_folder, landmarks_folder).replace(
                        video_extension, ".npy"
                    )
                    self.landmark_filelist += [landmark_path]

        self.resize_size = resize_size
        if use_latent and not precomputed_latent:
            self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
        self.scale_audio = scale_audio
        self.step = step
        self.use_latent = use_latent
        self.precomputed_latent = precomputed_latent
        self.latent_type = latent_type
        self.latent_scale = latent_scale
        self.video_ext = video_extension
        self.video_folder = video_folder

        self.augment = augment
        self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
        self.maybe_augment_audio = (
            Compose(
                [
                    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
                    # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
                    PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
                    # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
                ]
            )
            if augment_audio
            else lambda x, sample_rate: x
        )
        self.maybe_augment_audio = partial(
            self.maybe_augment_audio, sample_rate=audio_rate
        )

        self.mode = mode
        if mode == "interpolation":
            need_cond = False  # Interpolation does not need condition as first and last frame becomes the condition
        self.need_cond = need_cond  # If need cond will extract one more frame than the number of frames
        if get_separate_id:
            self.need_cond = True
        # It is used for the conditional model when the condition is not on the temporal dimension
        num_frames = num_frames if not self.need_cond else num_frames + 1

        vr = decord.VideoReader(self.filelist[0])
        self.video_rate = math.ceil(vr.get_avg_fps())
        print(f"Video rate: {self.video_rate}")
        self.audio_rate = audio_rate
        a2v_ratio = fps / float(self.audio_rate)
        self.samples_per_frame = math.ceil(1 / a2v_ratio)

        if get_separate_id:
            assert mode == "prediction", (
                "Separate identity frame is only supported for prediction mode"
            )
            # No need for extra frame if we are getting a separate identity frame
            self.need_cond = True
            num_frames -= 1
        self.num_frames = num_frames

        self.load_all_possible_indexes = load_all_possible_indexes
        if load_all_possible_indexes:
            self._indexes = self._get_indexes(
                self.filelist, self.audio_filelist, self.landmark_filelist
            )
        else:
            if filter_by_length:
                self._indexes = self.filter_by_length(
                    self.filelist, self.audio_filelist, self.landmark_filelist
                )
            else:
                if self.get_masks:
                    self._indexes = list(
                        zip(self.filelist, self.audio_filelist, self.landmark_filelist)
                    )
                else:
                    self._indexes = list(
                        zip(
                            self.filelist,
                            self.audio_filelist,
                            [None] * len(self.filelist),
                        )
                    )

        self.balance_datasets = balance_datasets
        if self.balance_datasets:
            self.weights = self._calculate_weights()
            self.sampler = WeightedRandomSampler(
                self.weights, num_samples=len(self._indexes), replacement=True
            )
    def __len__(self):
        return len(self._indexes) * self.virtual_increase
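
    # Builds a per-frame occlusion mask from facial landmarks. `what_mask`
    # selects the region ("full", "box", "heart", "mouth", or a generic face
    # mask); the result is resized to `target_size` with nearest-neighbour
    # interpolation so it can be applied to frames or latents.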
    def _load_landmarks(self, filename, original_size, target_size, indexes):
        landmarks = np.load(filename, allow_pickle=True)[indexes, :]
        if self.what_mask == "full":
            mask = create_masks_from_landmarks_full_size(
                landmarks,
                original_size[0],
                original_size[1],
                offset=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "box":
            mask = create_masks_from_landmarks_box(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "heart":
            mask = face_mask_cheeks_batch(
                original_size, landmarks, box_expand=0.0, show_nose=True
            )
        elif self.what_mask == "mouth":
            mask = create_masks_from_landmarks_mouth(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=0.01,
                nose_index=self.nose_index,
            )
        else:
            mask = create_face_mask_from_landmarks(
                landmarks, original_size[0], original_size[1], mask_expand=0.05
            )
        # Interpolate the mask to the target size
        mask = F.interpolate(
            mask.unsqueeze(1).float(), size=target_size, mode="nearest"
        )
        return mask, landmarks
    def get_emotions(self, video_file, video_indexes):
        emotions_path = video_file.replace(
            self.video_folder, self.emotions_folder
        ).replace(self.video_ext, ".pt")
        emotions = torch.load(emotions_path)
        return (
            emotions["valence"][video_indexes],
            emotions["arousal"][video_indexes],
            emotions["labels"][video_indexes],
        )
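
    # Picks `num_frames` frame indices: either uniformly spaced with
    # `skip_frames` gaps from a random start, or sampled at random (and sorted)
    # when `select_randomly` is set. Raises if the clip is too short.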
    def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
        if select_randomly:
            # Randomly select self.num_frames indices from the available range
            available_indices = list(range(start_idx, total_video_frames))
            if len(available_indices) < self.num_frames:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )
            indexes = random.sample(available_indices, self.num_frames)
            return sorted(indexes)  # Sort to maintain temporal order
        else:
            # Calculate the maximum possible start index
            max_start_idx = total_video_frames - (
                (self.num_frames - 1) * (self.skip_frames + 1) + 1
            )
            # Generate a random start index
            if max_start_idx > 0:
                start_idx = np.random.randint(start_idx, max_start_idx)
            else:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )
            # Generate the indices
            indexes = [
                start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
            ]
            return indexes
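
    # Reads `max_len_sec` seconds of mono audio starting at `start` seconds,
    # then pads/trims it to exactly `max_len_sec * audio_rate` samples so it
    # lines up with the selected video frames.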
    def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
        audio, sr = sf.read(
            filename,
            start=math.ceil(start * self.audio_rate),
            frames=math.ceil(self.audio_rate * max_len_sec),
            always_2d=True,
        )  # e.g (16000, 1)
        audio = audio.T  # (1, 16000)
        assert sr == self.audio_rate, (
            f"Audio rate is {sr} but should be {self.audio_rate}"
        )
        audio = audio.mean(0, keepdims=True)
        audio = self.maybe_augment_audio(audio)
        audio = torch.from_numpy(audio).float()
        # audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
        audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
        return audio[0]
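
    # Pads or trims each per-frame audio chunk to exactly `samples_per_frame`
    # samples (off-by-a-few-sample differences are tolerated), then
    # concatenates the chunks into a single tensor.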
    def ensure_shape(self, tensors):
        target_length = self.samples_per_frame
        processed_tensors = []
        for tensor in tensors:
            current_length = tensor.shape[1]
            diff = current_length - target_length
            assert abs(diff) <= 5, (
                f"Expected shape {target_length}, but got {current_length}"
            )
            if diff < 0:
                # Calculate how much padding is needed
                padding_needed = target_length - current_length
                # Pad the tensor
                padded_tensor = F.pad(tensor, (0, padding_needed))
                processed_tensors.append(padded_tensor)
            elif diff > 0:
                # Trim the tensor
                trimmed_tensor = tensor[:, :target_length]
                processed_tensors.append(trimmed_tensor)
            else:
                # If it's already the correct size
                processed_tensors.append(tensor)
        return torch.cat(processed_tensors)
    def normalize_latents(self, latents):
        if self.data_mean is not None:
            # Normalize latents to 0 mean and 0.5 std
            latents = ((latents - self.data_mean) / self.data_std) * 0.5
        return latents
    def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
        ratio = fps_to / fps_from
        indexes_60fps = [int(index * ratio) for index in indexes_25fps]
        return indexes_60fps
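
    # Core loader: decodes (or looks up) the selected frames and audio for one
    # clip, optionally swapping in precomputed latents and audio embeddings,
    # and assembles the clean/noisy conditioning frames, targets, emotions and
    # masks returned by __getitem__.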
    def _get_frames_and_audio(self, idx):
        if self.load_all_possible_indexes:
            indexes, video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)
        else:
            video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)
            indexes = self.get_frame_indices(
                len_video,
                select_randomly=self.select_randomly,
                start_idx=120 if "1000actors_nsv" in video_file else 0,
            )
        if self.get_separate_id:
            id_idx = np.random.randint(0, len_video)
            indexes.insert(0, id_idx)

        if "AA_processed" in video_file or "1000actors_nsv" in video_file:
            video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
            audio_file = audio_file.replace("_output_output", "")
            if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
                audio_path_extra = ".safetensors"
            else:
                audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = (
                ".pt" if "AA_processed" in video_file else "_beats_emb.pt"
            )
        else:
            video_indexes = indexes
            audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = "_beats_emb.pt"

        emotions = None
        if self.use_emotions:
            emotions = self.get_emotions(video_file, video_indexes)
            if self.get_separate_id:
                emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])

        raw_audio = None
        if self.audio_in_video:
            raw_audio, frames_video = vr.get_batch(video_indexes)
            raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")

        if self.use_latent and self.precomputed_latent:
            latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
                self.video_folder, self.latent_folder
            )
            frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
            if frames.shape[-1] != 64:
                print(f"Frames shape: {frames.shape}, video file: {video_file}")
            frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
            frames = self.normalize_latents(frames)
        else:
            if self.audio_in_video:
                frames = frames_video.permute(3, 0, 1, 2).float()
            else:
                frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()

        if raw_audio is None:
            # Audio is not in video
            raw_audio = self._load_audio(
                audio_file,
                max_len_sec=frames.shape[1] / self.fps,
                start=indexes[0] / self.fps,
                # indexes=indexes,
            )
        if not self.from_audio_embedding:
            audio = raw_audio
            audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
        else:
            audio = load_file(
                audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
                    0
                ]
                + audio_path_extra
            )["audio"]
            audio_frames = audio[indexes, :]
            if self.add_extra_audio_emb:
                audio_extra = torch.load(
                    audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
                        "."
                    )[0]
                    + audio_path_extra_extra
                )
                audio_extra = audio_extra[indexes, :]
                audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)

        audio_frames = (
            audio_frames[1:] if self.need_cond else audio_frames
        )  # Remove audio of first frame

        if self.get_original_frames:
            original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
            original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
            original_frames = (
                original_frames[:, 1:] if self.need_cond else original_frames
            )
        else:
            original_frames = None

        if not self.use_latent or (self.use_latent and not self.precomputed_latent):
            frames = self.scale_and_crop((frames / 255.0) * 2 - 1)

        target = frames[:, 1:] if self.need_cond else frames
        if self.mode == "prediction":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = (
                        frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                else:
                    clean_cond = (
                        vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
                    0
                )
                if self.latent_condition:
                    noisy_cond = frames[:, 0]
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = frames[:, 0]
                noisy_cond = clean_cond
        elif self.mode == "interpolation":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
                else:
                    clean_cond = (
                        vr.get_batch([video_indexes[0], video_indexes[-1]])
                        .permute(3, 0, 1, 2)
                        .float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
                if self.latent_condition:
                    noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                noisy_cond = clean_cond
        # Add noise to conditional frame
        if self.cond_noise and isinstance(self.cond_noise, (list, ListConfig)):
            # cond_noise is (mean, std) of a log-normal noise scale
            cond_noise = (
                self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
            ).exp()
            noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
        else:
            noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
            cond_noise = self.cond_noise
        if self.get_masks:
            target_size = (
                (self.resize_size, self.resize_size)
                if not self.use_latent
                else (self.resize_size // 8, self.resize_size // 8)
            )
            masks, landmarks = self._load_landmarks(
                land_file, original_size, target_size, video_indexes
            )
            landmarks = None
            masks = (
                masks.permute(1, 0, 2, 3)[:, 1:]
                if self.need_cond
                else masks.permute(1, 0, 2, 3)
            )
        else:
            masks = None
            landmarks = None

        return (
            original_frames,
            clean_cond,
            noisy_cond,
            target,
            audio_frames,
            raw_audio,
            cond_noise,
            emotions,
            masks,
            landmarks,
        )
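
    # Drops clips that are too short to supply `num_frames` frames at the
    # configured frame skip, using OpenCV to read frame counts quickly.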
    def filter_by_length(self, video_filelist, audio_filelist, landmark_filelist=None):
        def with_opencv(filename):
            video = cv2.VideoCapture(filename)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            return int(frame_count)

        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)

        # Return (video, audio, landmarks) tuples so the result matches the
        # structure of self._indexes used elsewhere.
        filtered = []
        min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
        for vid_file, audio_file, land_file in tqdm(
            zip(video_filelist, audio_filelist, landmark_filelist),
            total=len(video_filelist),
            desc="Filtering",
        ):
            # vr = decord.VideoReader(vid_file)
            len_video = with_opencv(vid_file)
            # Short videos
            if len_video < min_length:
                continue
            filtered.append((vid_file, audio_file, land_file))
        print(f"New number of files: {len(filtered)}")
        return filtered
    def _get_indexes(self, video_filelist, audio_filelist, landmark_filelist=None):
        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)
        indexes = []
        self.og_shape = None
        for vid_file, audio_file, land_file in zip(
            video_filelist, audio_filelist, landmark_filelist
        ):
            vr = decord.VideoReader(vid_file)
            if self.og_shape is None:
                self.og_shape = vr[0].shape[-2]
            len_video = len(vr)
            # Short videos
            if len_video < self.num_frames:
                continue
            else:
                # Each entry is (frame indexes, video, audio, landmarks) to
                # match the unpacking in _get_frames_and_audio.
                possible_indexes = list(
                    sliding_window(range(len_video), self.num_frames)
                )[:: self.step]
                possible_indexes = list(
                    map(
                        lambda x: (x, vid_file, audio_file, land_file),
                        possible_indexes,
                    )
                )
                indexes.extend(possible_indexes)
        print("Indexes", len(indexes), "\n")
        return indexes
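
    # Resizes the shorter side to `resize_size` (antialiased bilinear
    # interpolation), center-crops to a square, and applies the optional
    # horizontal-flip augmentation.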
    def scale_and_crop(self, video):
        h, w = video.shape[-2], video.shape[-1]
        # scale shorter side to resolution
        if self.resize_size is not None:
            scale = self.resize_size / min(h, w)
            if h < w:
                target_size = (self.resize_size, math.ceil(w * scale))
            else:
                target_size = (math.ceil(h * scale), self.resize_size)
            video = F.interpolate(
                video,
                size=target_size,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # center crop
            h, w = video.shape[-2], video.shape[-1]
            w_start = (w - self.resize_size) // 2
            h_start = (h - self.resize_size) // 2
            video = video[
                :,
                :,
                h_start : h_start + self.resize_size,
                w_start : w_start + self.resize_size,
            ]
        return self.maybe_augment(video)
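
    # Computes inverse-frequency sampling weights so the AA_processed,
    # 1000actors_nsv, and remaining subsets are drawn roughly equally often by
    # the WeightedRandomSampler.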
    def _calculate_weights(self):
        # Entries are (video, audio, landmarks) or (indexes, video, audio, landmarks);
        # pick the video path in either case.
        aa_processed_count = sum(
            1
            for item in self._indexes
            if "AA_processed" in (item[1] if len(item) == 4 else item[0])
        )
        nsv_processed_count = sum(
            1
            for item in self._indexes
            if "1000actors_nsv" in (item[1] if len(item) == 4 else item[0])
        )
        other_count = len(self._indexes) - aa_processed_count - nsv_processed_count

        aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
        nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
        other_weight = 1 / other_count if other_count > 0 else 0
        print(
            f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
        )
        print(f"AA processed weight: {aa_processed_weight}")
        print(f"NSV processed weight: {nsv_processed_weight}")
        print(f"Other weight: {other_weight}")

        weights = [
            aa_processed_weight
            if "AA_processed" in (item[1] if len(item) == 4 else item[0])
            else nsv_processed_weight
            if "1000actors_nsv" in (item[1] if len(item) == 4 else item[0])
            else other_weight
            for item in self._indexes
        ]
        return weights
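
    # Returns one training sample as a dict. With dataset balancing enabled,
    # the requested index is replaced by one drawn from the weighted sampler;
    # on a loading error a random other sample is retried.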
    def __getitem__(self, idx):
        if self.balance_datasets:
            idx = next(iter(self.sampler))
        try:
            (
                original_frames,
                clean_cond,
                noisy_cond,
                target,
                audio,
                raw_audio,
                cond_noise,
                emotions,
                masks,
                landmarks,
            ) = self._get_frames_and_audio(idx % len(self._indexes))
        except Exception as e:
            print(f"Error with index {idx}: {e}")
            return self.__getitem__(np.random.randint(0, len(self)))

        out_data = {}
        if original_frames is not None:
            out_data["original_frames"] = original_frames
        if audio is not None:
            out_data["audio_emb"] = audio
            out_data["raw_audio"] = raw_audio
        if self.use_emotions:
            out_data["valence"] = emotions[0]
            out_data["arousal"] = emotions[1]
            out_data["emo_labels"] = emotions[2]
        if self.use_latent:
            input_key = "latents"
        else:
            input_key = "frames"
        out_data[input_key] = target
        if noisy_cond is not None:
            out_data["cond_frames"] = noisy_cond
        out_data["cond_frames_without_noise"] = clean_cond
        if cond_noise is not None:
            out_data["cond_aug"] = cond_noise
        if masks is not None:
            out_data["masks"] = masks
            out_data["gt"] = target
        if landmarks is not None:
            out_data["landmarks"] = landmarks
        out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
        out_data["fps_id"] = torch.tensor([self.fps - 1])
        out_data["num_video_frames"] = self.num_frames
        out_data["image_only_indicator"] = torch.zeros(self.num_frames)
        return out_data
if __name__ == "__main__":
    # Quick smoke test: build the dataset, print a few sample shapes, and dump
    # the conditioning frame plus the target frames to disk as PNGs.
    dataset = VideoDataset(
        "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
        resize_size=256,
        num_frames=25,
    )
    print(len(dataset))

    idx = np.random.randint(0, len(dataset))
    for i in range(10):
        sample = dataset[i]
        print(sample["frames"].shape, sample["cond_frames"].shape)

    sample = dataset[idx]
    image_identity = (
        sample["cond_frames_without_noise"].permute(1, 2, 0).numpy() + 1
    ) / 2 * 255
    image_other = (sample["frames"][:, -1].permute(1, 2, 0).numpy() + 1) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
    for i in range(25):
        image = (sample["frames"][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
        cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])