import glob
import json
import os
import pickle
import random
import re
import subprocess
from functools import partial

import librosa.core
import numpy as np
import torch
import torch.distributions
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import tqdm
import csv

from utils.commons.indexed_datasets import IndexedDataset
from utils.commons.hparams import hparams, set_hparams
from utils.commons.meters import Timer
from data_util.face3d_helper import Face3DHelper
from utils.audio import librosa_wav2mfcc
from utils.commons.dataset_utils import collate_xd
from utils.commons.tensor_utils import convert_to_tensor

# Built lazily (and per process), so each DataLoader worker constructs its own helper.
face3d_helper = None

def erosion_1d(arr):
    result = arr.copy()
    start_index = None
    continuous_length = 0
    for i, num in enumerate(arr):
        if num == 1:
            if continuous_length == 0:
                start_index = i
            continuous_length += 1
        else:
            if continuous_length > 0:
                # Collapse the run of 1s to a single 1 at its center
                for j in range(start_index, start_index + continuous_length):
                    result[j] = 0
                result[start_index + continuous_length // 2] = 1
                continuous_length = 0
    if continuous_length > 0:
        # A run that reaches the end of the array is removed entirely
        for j in range(start_index, start_index + continuous_length):
            result[j] = 0
    return result

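# A quick illustration of erosion_1d above (toy input, not from the dataset):
# each interior run of 1s collapses to a single 1 at its center, while a run
# that touches the end of the array is dropped entirely:
#   erosion_1d(np.array([0, 1, 1, 1, 0, 1, 1]))
#   # -> array([0, 0, 1, 0, 0, 0, 0])
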
def get_mouth_amp(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        mouth_d = (ldm[:, 0] - ldm[:, 17]).abs().sum(-1)  # upper/lower lip midpoints (MediaPipe)
    else:
        mouth_d = (ldm[:, 51] - ldm[:, 57]).abs().sum(-1)  # upper/lower lip midpoints (68-point)
    mouth_amp = torch.quantile(mouth_d, 0.9, dim=0)
    return mouth_amp

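# Sanity-check sketch for get_mouth_amp (hypothetical toy landmarks): with
# 68-point landmarks, the mouth opening is the L1 distance between the upper
# and lower lip midpoints (indices 51 / 57), and mouth_amp is its 0.9 quantile
# over time:
#   ldm = torch.zeros([100, 68, 3]); ldm[:, 57, 1] = torch.linspace(0, 1, 100)
#   get_mouth_amp(ldm)  # ~0.90, the 90th-percentile mouth opening
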
def get_eye_amp(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        eye_d = (ldm[:, 159] - ldm[:, 145]).abs().sum(-1) + (ldm[:, 386] - ldm[:, 374]).abs().sum(-1)
    else:
        eye_d = ((ldm[:, 41] - ldm[:, 37]).abs().sum(-1) + (ldm[:, 40] - ldm[:, 38]).abs().sum(-1)
                 + (ldm[:, 47] - ldm[:, 43]).abs().sum(-1) + (ldm[:, 46] - ldm[:, 44]).abs().sum(-1))
    eye_amp = torch.quantile(eye_d, 0.9, dim=0)
    return eye_amp

def get_blink(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        eye_d = (ldm[:, 159] - ldm[:, 145]).abs().sum(-1) + (ldm[:, 386] - ldm[:, 374]).abs().sum(-1)
    else:
        eye_d = ((ldm[:, 41] - ldm[:, 37]).abs().sum(-1) + (ldm[:, 40] - ldm[:, 38]).abs().sum(-1)
                 + (ldm[:, 47] - ldm[:, 43]).abs().sum(-1) + (ldm[:, 46] - ldm[:, 44]).abs().sum(-1))
    eye_d_qtl = torch.quantile(eye_d, 0.75, dim=0)
    # a frame counts as blinking when the eye opening drops below 85% of its 0.75 quantile
    blink = eye_d / eye_d_qtl
    blink = (blink < 0.85).long().numpy()
    blink = erosion_1d(blink)
    if is_torch:
        blink = torch.LongTensor(blink)
    return blink

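# Illustrative blink detection (synthetic landmarks, hypothetical values):
# only one eyelid pair is given a nonzero distance, frames 4-6 are near-closed,
# and erosion_1d keeps only the center frame of the blink run:
#   ldm = torch.zeros([10, 68, 3])
#   ldm[:, 41, 1] = 1.0    # eyes open: eye_d == 1.0 on every frame
#   ldm[4:7, 41, 1] = 0.1  # near-closed on frames 4-6: eye_d == 0.1
#   get_blink(ldm)         # -> tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
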
class Audio2Motion_Dataset(Dataset):
    def __init__(self, prefix='train', data_dir=None):
        self.hparams = hparams
        self.db_key = prefix
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir
        self.ds = None
        self.sizes = None
        self.x_maxframes = 200  # 200 HuBERT frames at 50 fps, i.e. 100 video frames at 25 fps
        self.x_multiply = 8

    def __len__(self):
        # Use a throwaway handle here: self.ds must stay None so that each
        # DataLoader worker opens its own file handle lazily in _get_item.
        return len(IndexedDataset(f'{self.ds_path}/{self.db_key}'))

    def _get_item(self, index):
        """
        Open the dataset lazily, so each DataLoader worker process
        gets its own file handle.
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_id': raw_item['img_dir'],
            'id': torch.from_numpy(raw_item['id']).float(),  # [T_y, c=80], or [1, c=80] for a global id
            'exp': torch.from_numpy(raw_item['exp']).float(),  # [T_y, c=80]
        }
        if item['id'].shape[0] == 1:  # a single global id; broadcast it over all frames
            item['id'] = item['id'].repeat([item['exp'].shape[0], 1])
        item['hubert'] = torch.from_numpy(raw_item['hubert']).float()  # [T_x, 1024]
        item['f0'] = torch.from_numpy(raw_item['f0']).float()  # [T_x,]
        global face3d_helper
        if face3d_helper is None:
            face3d_helper = Face3DHelper(use_gpu=False)
        cano_lm3d = face3d_helper.reconstruct_cano_lm3d(item['id'], item['exp'])
        item['blink_unit'] = get_blink(cano_lm3d)
        item['eye_amp'] = get_eye_amp(cano_lm3d)
        item['mouth_amp'] = get_mouth_amp(cano_lm3d)
        x_len = len(item['hubert'])
        x_len = x_len // self.x_multiply * self.x_multiply  # make it divisible by our CNN's stride
        y_len = x_len // 2  # video is 25 fps, HuBERT features are 50 fps
        item['hubert'] = item['hubert'][:x_len]  # [T_x, 1024]
        item['f0'] = item['f0'][:x_len]
        item['id'] = item['id'][:y_len]
        item['exp'] = item['exp'][:y_len]
        item['euler'] = convert_to_tensor(raw_item['euler'][:y_len])
        item['trans'] = convert_to_tensor(raw_item['trans'][:y_len])
        item['blink_unit'] = item['blink_unit'][:y_len].reshape([-1, 1])
        item['eye_amp'] = item['eye_amp'].reshape([1, ])
        item['mouth_amp'] = item['mouth_amp'].reshape([1, ])
        return item

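    # Frame-rate bookkeeping for __getitem__ (worked example, illustrative
    # numbers): HuBERT features run at 50 fps and video at 25 fps, so a clip
    # with 203 audio frames is first floored to a multiple of x_multiply=8
    # (203 // 8 * 8 = 200) and paired with y_len = 200 // 2 = 100 video frames.
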
    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        sizes_fname = os.path.join(self.ds_path, f"sizes_{self.db_key}.npy")
        if os.path.exists(sizes_fname):
            sizes = np.load(sizes_fname, allow_pickle=True)
            self.sizes = sizes
        if self.sizes is None:
            self.sizes = []
            print("Counting the size of each item in dataset...")
            ds = IndexedDataset(f"{self.ds_path}/{self.db_key}")
            for i_sample in tqdm.trange(len(ds)):
                sample = ds[i_sample]
                if sample is None:
                    size = 0
                else:
                    x = sample['mel']  # the raw item is assumed to also store a mel spectrogram [80, T]
                    size = x.shape[-1]  # number of time steps in the audio
                self.sizes.append(size)
            np.save(sizes_fname, self.sizes)
        indices = np.arange(len(self))
        indices = indices[np.argsort(np.array(self.sizes)[indices], kind='mergesort')]
        return indices

    def batch_by_size(self, indices, max_tokens=None, max_sentences=None,
                      required_batch_size_multiple=1):
        """
        Yield mini-batches of indices bucketed by size. Batches may contain
        sequences of different lengths.

        Args:
            indices (List[int]): ordered list of dataset indices
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
        """
        def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            if len(batch) == 0:
                return False
            if len(batch) == max_sentences:
                return True
            if num_tokens > max_tokens:
                return True
            return False

        num_tokens_fn = lambda x: self.sizes[x]
        max_tokens = max_tokens if max_tokens is not None else 60000
        max_sentences = max_sentences if max_sentences is not None else 512
        bsz_mult = required_batch_size_multiple

        sample_len = 0
        sample_lens = []
        batch = []
        batches = []
        for i in range(len(indices)):
            idx = indices[i]
            num_tokens = num_tokens_fn(idx)
            sample_lens.append(num_tokens)
            sample_len = max(sample_len, num_tokens)
            assert sample_len <= max_tokens, (
                "sentence at index {} of size {} exceeds max_tokens "
                "limit of {}!".format(idx, sample_len, max_tokens)
            )
            # cost if this sample joined the batch: every sample is padded
            # to the longest one, hence (len(batch) + 1) * sample_len
            num_tokens = (len(batch) + 1) * sample_len
            if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
                mod_len = max(
                    bsz_mult * (len(batch) // bsz_mult),
                    len(batch) % bsz_mult,
                )
                batches.append(batch[:mod_len])
                batch = batch[mod_len:]
                sample_lens = sample_lens[mod_len:]
                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
            batch.append(idx)
        if len(batch) > 0:
            batches.append(batch)
        return batches

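    # Worked example of the bucketing above (hypothetical sizes): with
    # self.sizes giving [10, 10, 40] for indices [0, 1, 2] and max_tokens=60,
    # items 0 and 1 share a batch (padded cost 2 * 10 = 20 tokens), but adding
    # item 2 would cost 3 * 40 = 120 > 60, so it opens a new batch:
    # batches == [[0, 1], [2]].
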
    def get_dataloader(self, batch_size=1, num_workers=0):
        # batch_size is unused here: batch_sampler below already yields whole batches
        batches_idx = self.batch_by_size(self.ordered_indices(),
                                         max_tokens=hparams['max_tokens_per_batch'],
                                         max_sentences=hparams['max_sentences_per_batch'])
        batches_idx = batches_idx * 50  # replicate the batch list so one epoch spans ~50 passes over the data
        random.shuffle(batches_idx)
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_sampler=batches_idx, num_workers=num_workers)
        return loader

    def collater(self, samples):
        if len(samples) == 0:
            return {}
        batch = {}
        item_names = [s['item_id'] for s in samples]
        x_len = max(s['hubert'].size(0) for s in samples)
        assert x_len % self.x_multiply == 0
        y_len = x_len // 2
        batch['hubert'] = collate_xd([s["hubert"] for s in samples], max_len=x_len, pad_idx=0)  # [b, t_max_x, 1024]
        # a frame is valid iff its (zero-padded) feature vector is nonzero
        batch['x_mask'] = (batch['hubert'].abs().sum(dim=-1) > 0).float()  # [b, t_max_x]
        batch['f0'] = collate_xd([s["f0"].reshape([-1, 1]) for s in samples], max_len=x_len, pad_idx=0).squeeze(-1)  # [b, t_max_x]
        batch.update({
            'item_id': item_names,
        })
        batch['id'] = collate_xd([s["id"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, 80]
        batch['exp'] = collate_xd([s["exp"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, 80]
        batch['euler'] = collate_xd([s["euler"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, 3]
        batch['trans'] = collate_xd([s["trans"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, 3]
        batch['blink_unit'] = collate_xd([s["blink_unit"] for s in samples], max_len=y_len, pad_idx=0)  # [b, t_max_y, 1]
        batch['eye_amp'] = collate_xd([s["eye_amp"] for s in samples], max_len=1, pad_idx=0)  # [b, 1]
        batch['mouth_amp'] = collate_xd([s["mouth_amp"] for s in samples], max_len=1, pad_idx=0)  # [b, 1]
        batch['y_mask'] = (batch['id'].abs().sum(dim=-1) > 0).float()  # [b, t_max_y]
        return batch

if __name__ == '__main__':
    os.environ["OMP_NUM_THREADS"] = "1"
    set_hparams('egs/os_avatar/audio2secc_vae.yaml')
    ds = Audio2Motion_Dataset("train", 'data/binary/th1kh')
    dl = ds.get_dataloader()
    for b in tqdm.tqdm(dl):
        pass