# Spaces:
# Build error
# Build error
import glob | |
import json | |
import os | |
import cv2 | |
import pickle | |
import random | |
import re | |
import subprocess | |
from functools import partial | |
import librosa.core | |
import numpy as np | |
import torch | |
import torch.distributions | |
import torch.distributed as dist | |
import torch.optim | |
import torch.utils.data | |
from utils.commons.indexed_datasets import IndexedDataset | |
from torch.utils.data import Dataset, DataLoader | |
import torch.nn.functional as F | |
import pandas as pd | |
from tqdm import tqdm | |
import csv | |
from utils.commons.hparams import hparams, set_hparams | |
from utils.commons.meters import Timer | |
from data_util.face3d_helper import Face3DHelper | |
from utils.audio import librosa_wav2mfcc | |
from utils.commons.dataset_utils import collate_xd | |
from utils.commons.tensor_utils import convert_to_tensor | |
from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image | |
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic | |
from utils.commons.image_utils import load_image_as_uint8_tensor | |
from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler | |
def sample_idx(img_dir, num_frames):
    """Randomly pick a frame index whose images exist in all required folders.

    A frame is accepted only when `find_img_name` resolves it in `img_dir`
    and in the sibling folders obtained by replacing "/gt_imgs/" in `img_dir`
    with "/head_imgs/", "/inpaint_torso_imgs/" and "/com_imgs/".

    Args:
        img_dir: directory of ground-truth frames (a path containing "/gt_imgs/").
        num_frames: number of frames in the clip; indices are drawn from
            [0, num_frames - 1].

    Returns:
        The first randomly drawn index for which all four images exist.

    Raises:
        ValueError: if `num_frames` is not positive (there is nothing to sample).

    Note:
        The function keeps retrying until a complete frame is found, so it does
        not return if no frame of the clip is complete. A warning is printed
        once per 1000 further failed attempts instead of on every attempt.
    """
    if num_frames <= 0:
        raise ValueError(f"cannot sample a frame index from {img_dir}: num_frames={num_frames}")

    # The directory names do not depend on the sampled index; build them once.
    required_dirs = [img_dir] + [
        img_dir.replace("/gt_imgs/", f"/{folder}/")
        for folder in ("head_imgs", "inpaint_torso_imgs", "com_imgs")
    ]

    attempts = 0
    while True:
        attempts += 1
        if attempts > 1000 and attempts % 1000 == 1:
            print(f"recycle for more than 1000 times, check this {img_dir}")
        idx = random.randint(0, num_frames - 1)
        # all() short-circuits, so later folders are only probed when earlier ones succeed.
        if all(find_img_name(folder, idx) != 'None' for folder in required_dirs):
            return idx
def find_img_name(img_dir, idx):
    """Resolve the on-disk file name of frame `idx` inside `img_dir`.

    The frame may be stored under several naming schemes. Candidates are
    probed in this fixed order and the first existing one wins:
    5-digit zero-padded ".jpg", unpadded ".jpg", 8-digit zero-padded ".jpg",
    8-digit zero-padded ".png", 5-digit zero-padded ".png", unpadded ".png".

    Args:
        img_dir: directory that should contain the frame image.
        idx: integer frame index.

    Returns:
        The path of the first existing candidate, or the string 'None'
        (not the `None` object) when no candidate exists.
    """
    padded5 = format(idx, "05d")
    plain = str(idx)
    padded8 = format(idx, "08d")
    basenames = (
        padded5 + ".jpg",
        plain + ".jpg",
        padded8 + ".jpg",
        padded8 + ".png",
        padded5 + ".png",
        plain + ".png",
    )
    for basename in basenames:
        candidate = os.path.join(img_dir, basename)
        if os.path.exists(candidate):
            return candidate
    return 'None'
def _zero_rows_like(arr, count):
    """Return `count` all-zero rows matching `arr`'s trailing shape and dtype."""
    shape = (count,) + tuple(arr.shape[1:])
    if isinstance(arr, np.ndarray):
        return np.zeros(shape, dtype=arr.dtype)
    # Non-numpy inputs are treated as torch tensors; new_zeros keeps dtype and device.
    return arr.new_zeros(shape)


def get_win_from_arr(arr, index, win_size):
    """Cut a window of `win_size` rows around `index`, zero-padding out-of-range rows.

    The window covers rows [index - win_size // 2, index + win_size - win_size // 2).
    Rows that fall before the start or after the end of `arr` are filled with
    zeros, so for a non-negative `win_size` the result always has exactly
    `win_size` rows, even when `arr` is shorter than the window or the window
    lies entirely outside `arr`.

    Args:
        arr: `np.ndarray` or `torch.Tensor`, windowed along its first dimension.
        index: centre row of the window.
        win_size: number of rows in the window.

    Returns:
        An array/tensor of the same type and dtype as `arr`. When no padding
        is needed this is a plain slice of `arr`, not a copy.
    """
    left = index - win_size // 2
    right = index + (win_size - win_size // 2)
    num_rows = arr.shape[0]

    # Clamp both bounds into [0, num_rows]; a negative `right` must not be
    # interpreted as a Python "count from the end" slice index.
    start = min(max(left, 0), num_rows)
    stop = min(max(right, start), num_rows)
    win = arr[start:stop]

    # Size the padding from the requested window, not from the (possibly
    # shorter) slice, so the output length is always win_size.
    pad_left = min(max(-left, 0), win_size)
    pad_right = win_size - pad_left - (stop - start)

    is_numpy = isinstance(arr, np.ndarray)
    if pad_left > 0:
        zeros = _zero_rows_like(arr, pad_left)
        win = np.concatenate([zeros, win], axis=0) if is_numpy else torch.cat([zeros, win], dim=0)
    if pad_right > 0:
        zeros = _zero_rows_like(arr, pad_right)
        win = np.concatenate([win, zeros], axis=0) if is_numpy else torch.cat([win, zeros], dim=0)
    return win
class Img2Plane_Dataset(Dataset):
    """Dataset that yields three camera vectors per clip (no images).

    Every item carries:
      * `ws_camera`   - always a randomly sampled pose,
      * `real_camera` - reference-view camera,
      * `fake_camera` - second-view camera.
    The last two are either the camera of a randomly chosen frame of the clip
    (flattened c2w of 16 values followed by flattened intrinsics of 9 values)
    or, when hparams["random_sample_pose"] is True, a randomly sampled pose
    with probability 0.5.
    """

    def __init__(self, prefix='train', data_dir=None):
        """Store configuration; the binary dataset itself is opened lazily.

        Args:
            prefix: name of the split, used as the file stem inside the data dir.
            data_dir: directory of the binarized data; defaults to
                hparams['binary_data_dir'] when None.
        """
        self.db_key = prefix
        self.ds = None  # opened on first access, see _get_item
        self.sizes = None
        self.x_maxframes = 200  # 50 video frames
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        """Return the number of clips; (re)opens the indexed dataset on every call."""
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def _sample_random_camera(self, max_pitch_deg, max_yaw_deg, min_distance, max_distance):
        """Sample a camera with uniform pitch, yaw and distance from `self.pose_sampler`.

        Pitch is drawn from [-max_pitch_deg, max_pitch_deg], yaw from
        [-max_yaw_deg, max_yaw_deg] (both converted to radians) and the
        distance from [min_distance, max_distance]. The three random draws
        happen in that order.
        """
        max_pitch = max_pitch_deg / 180 * 3.1415926
        min_pitch = -max_pitch
        pitch = random.random() * (max_pitch - min_pitch) + min_pitch
        max_yaw = max_yaw_deg / 180 * 3.1415926
        min_yaw = -max_yaw
        yaw = random.random() * (max_yaw - min_yaw) + min_yaw
        distance = random.random() * (max_distance - min_distance) + min_distance
        return self.pose_sampler.get_camera_pose(
            pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]

    def _sample_view_camera(self, raw_item, img_dir, num_frames):
        """Return a camera for the ref/mv view.

        With hparams["random_sample_pose"] set to True, half of the calls
        return a wide-range random pose (pitch within 26 deg, yaw within
        38 deg, distance in [2.7, 4.0]); otherwise the camera of a randomly
        chosen frame with complete images is used.
        """
        # `and` short-circuits: no random number is consumed unless the flag is True.
        if self.hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            return self._sample_random_camera(26, 38, 2.7, 4.0)
        frame_idx = sample_idx(img_dir, num_frames)
        c2w = raw_item['c2w'][frame_idx]
        intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([c2w.reshape([16,]), intrinsics.reshape([9,])], axis=0)  # 16 + 9 = 25 values
        return convert_to_tensor(camera)

    def __getitem__(self, idx):
        """Build the camera-only item for clip `idx`.

        Returns:
            A dict with keys 'idx', 'item_name', 'ws_camera', 'real_camera'
            and 'fake_camera', or None when the raw item could not be loaded.
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        # Convert the per-frame euler/translation into camera matrices.
        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        # ws camera: narrow range (pitch within 10 deg, yaw within 16 deg, distance in [2.7, 3.2]).
        ws_camera = self._sample_random_camera(10, 16, 2.7, 3.2)
        real_camera = self._sample_view_camera(raw_item, img_dir, num_frames)
        fake_camera = self._sample_view_camera(raw_item, img_dir, num_frames)

        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
        })
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Wrap this dataset in a DataLoader that uses `self.collater` and pinned memory."""
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-item cameras into a batch dict.

        Samples that are None (items `__getitem__` failed to load) are
        dropped instead of crashing the batch.

        Returns:
            {} when no usable sample is left, otherwise a dict with
            'ffhq_ws_cameras', 'ffhq_ref_cameras' and 'ffhq_mv_cameras',
            each stacked along a new leading batch dimension.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}
        batch = {}
        # Frame-derived cameras hold 25 values (16 c2w + 9 intrinsics);
        # sampler-generated ones presumably match - confirm against the sampler.
        batch['ffhq_ws_cameras'] = torch.stack([s['ws_camera'] for s in samples], dim=0)
        batch['ffhq_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['ffhq_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        return batch
class Motion2Video_Dataset(Dataset): | |
def __init__(self, prefix='train', data_dir=None): | |
self.db_key = prefix | |
self.ds = None | |
self.sizes = None | |
self.x_maxframes = 200 # 50 video frames | |
self.face3d_helper = Face3DHelper('deep_3drecon/BFM') | |
self.x_multiply = 8 | |
self.hparams = hparams | |
self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir | |
def __len__(self): | |
ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}') | |
return len(ds) | |
def _get_item(self, index): | |
""" | |
This func is necessary to open files in multi-threads! | |
""" | |
if self.ds is None: | |
self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}') | |
return self.ds[index] | |
def __getitem__(self, idx): | |
raw_item = self._get_item(idx) | |
if raw_item is None: | |
print("loading from binary data failed!") | |
return None | |
item = { | |
'idx': idx, | |
'item_name': raw_item['img_dir'], | |
} | |
camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':convert_to_tensor(raw_item['euler']).cpu(), 'trans':convert_to_tensor(raw_item['trans']).cpu()}) | |
c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics'] | |
raw_item['c2w'] = c2w | |
raw_item['intrinsics'] = intrinsics | |
img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/') | |
num_frames = len(raw_item['exp']) | |
# src | |
real_idx = sample_idx(img_dir, num_frames) | |
real_c2w = raw_item['c2w'][real_idx] | |
real_intrinsics = raw_item['intrinsics'][real_idx] | |
real_camera = np.concatenate([real_c2w.reshape([16,]) , real_intrinsics.reshape([9,])], axis=0) | |
real_camera = convert_to_tensor(real_camera) | |
item['real_camera'] = real_camera | |
gt_img_fname = find_img_name(img_dir, real_idx) | |
gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3] # ignore alpha channel when png | |
item['real_gt_img'] = gt_img.float() / 127.5 - 1 | |
# for key in ['head', 'torso', 'torso_with_bg', 'person']: | |
for key in ['head', 'com', 'inpaint_torso']: | |
key_img_dir = img_dir.replace("/gt_imgs/",f"/{key}_imgs/") | |
key_img_fname = find_img_name(key_img_dir, real_idx) | |
key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3] # ignore alpha channel when png | |
item[f'real_{key}_img'] = key_img.float() / 127.5 - 1 | |
bg_img_name = img_dir.replace("/gt_imgs/",f"/bg_img/") + '.jpg' | |
bg_img = load_image_as_uint8_tensor(bg_img_name)[..., :3] # ignore alpha channel when png | |
item[f'bg_img'] = bg_img.float() / 127.5 - 1 | |
seg_img_name = gt_img_fname.replace("/gt_imgs/",f"/segmaps/").replace(".jpg", ".png") | |
seg_img = cv2.imread(seg_img_name)[:,:, ::-1] | |
segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W] | |
item[f'real_segmap'] = segmap | |
item[f'real_head_mask'] = segmap[[1,3,5]].sum(dim=0) | |
item[f'real_torso_mask'] = segmap[[2,4]].sum(dim=0) | |
item.update({ | |
# id,exp,euler,trans, used to generate the secc map | |
'real_identity': convert_to_tensor(raw_item['id']).reshape([80,]), | |
# 'real_identity': convert_to_tensor(raw_item['id'][real_idx]).reshape([80,]), | |
'real_expression': convert_to_tensor(raw_item['exp'][real_idx]).reshape([64,]), | |
'real_euler': convert_to_tensor(raw_item['euler'][real_idx]).reshape([3,]), | |
'real_trans': convert_to_tensor(raw_item['trans'][real_idx]).reshape([3,]), | |
}) | |
pertube_idx_candidates = [idx for idx in [real_idx-1, real_idx+1] if (idx>=0 and idx <= num_frames-1 )] # previous frame | |
# pertube_idx_candidates = [idx for idx in [real_idx-2, real_idx-1, real_idx+1, real_idx+2] if (idx>=0 and idx <= num_frames-1 )] # previous frame | |
pertube_idx = random.choice(pertube_idx_candidates) | |
item[f'real_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,]) | |
item[f'real_pertube_expression_2'] = item['real_expression'] * 2 - item[f'real_pertube_expression_1'] | |
# tgt | |
fake_idx = sample_idx(img_dir, num_frames) | |
min_offset = min(50, max((num_frames-1-fake_idx)//2, (fake_idx)//2)) | |
while abs(fake_idx - real_idx) < min_offset: | |
fake_idx = sample_idx(img_dir, num_frames) | |
min_offset = min(50, max((num_frames-1-fake_idx)//2, (fake_idx)//2)) | |
fake_c2w = raw_item['c2w'][fake_idx] | |
fake_intrinsics = raw_item['intrinsics'][fake_idx] | |
fake_camera = np.concatenate([fake_c2w.reshape([16,]) , fake_intrinsics.reshape([9,])], axis=0) | |
fake_camera = convert_to_tensor(fake_camera) | |
item['fake_camera'] = fake_camera | |
gt_img_fname = find_img_name(img_dir, fake_idx) | |
gt_img = load_image_as_uint8_tensor(gt_img_fname)[..., :3] # ignore alpha channel when png | |
item['fake_gt_img'] = gt_img.float() / 127.5 - 1 | |
seg_img_name = gt_img_fname.replace("/gt_imgs/",f"/segmaps/").replace(".jpg", ".png") | |
seg_img = cv2.imread(seg_img_name)[:,:, ::-1] | |
segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img)) # [6, H, W] | |
item[f'fake_segmap'] = segmap | |
item[f'fake_head_mask'] = segmap[[1,3,5]].sum(dim=0) | |
item[f'fake_torso_mask'] = segmap[[2,4]].sum(dim=0) | |
# for key in ['head', 'torso', 'torso_with_bg', 'person']: | |
for key in ['head', 'com', 'inpaint_torso']: | |
key_img_dir = img_dir.replace("/gt_imgs/",f"/{key}_imgs/") | |
key_img_fname = find_img_name(key_img_dir, fake_idx) | |
key_img = load_image_as_uint8_tensor(key_img_fname)[..., :3] # ignore alpha channel when png | |
item[f'fake_{key}_img'] = key_img.float() / 127.5 - 1 | |
item.update({ | |
# id,exp,euler,trans, used to generate the secc map | |
f'fake_identity': convert_to_tensor(raw_item['id']).reshape([80,]), | |
# f'fake_identity': convert_to_tensor(raw_item['id'][fake_idx]).reshape([80,]), | |
f'fake_expression': convert_to_tensor(raw_item['exp'][fake_idx]).reshape([64,]), | |
f'fake_euler': convert_to_tensor(raw_item['euler'][fake_idx]).reshape([3,]), | |
f'fake_trans': convert_to_tensor(raw_item['trans'][fake_idx]).reshape([3,]), | |
}) | |
# pertube_idx_candidates = [idx for idx in [fake_idx-2, fake_idx-1, fake_idx+1, fake_idx+2] if (idx>=0 and idx <= num_frames-1 )] # previous frame | |
pertube_idx_candidates = [idx for idx in [fake_idx-1, fake_idx+1] if (idx>=0 and idx <= num_frames-1 )] # previous frame | |
pertube_idx = random.choice(pertube_idx_candidates) | |
item[f'fake_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,]) | |
item[f'fake_pertube_expression_2'] = item['fake_expression'] * 2 - item[f'fake_pertube_expression_1'] | |
return item | |
def get_dataloader(self, batch_size=1, num_workers=0): | |
loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers) | |
return loader | |
def collater(self, samples): | |
hparams = self.hparams | |
if len(samples) == 0: | |
return {} | |
batch = {} | |
batch['th1kh_item_names'] = [s['item_name'] for s in samples] | |
batch['th1kh_ref_gt_imgs'] = torch.stack([s['real_gt_img'] for s in samples]).permute(0,3,1,2) # [B, H, W, 3]==>[B,3,H,W] | |
batch['th1kh_ref_head_masks'] = torch.stack([s['real_head_mask'] for s in samples]) # [B,6,H,W] | |
batch['th1kh_ref_torso_masks'] = torch.stack([s['real_torso_mask'] for s in samples]) # [B,6,H,W] | |
batch['th1kh_ref_segmaps'] = torch.stack([s['real_segmap'] for s in samples]) # [B,6,H,W] | |
# for key in ['head', 'torso', 'torso_with_bg', 'person']: | |
for key in ['head', 'com', 'inpaint_torso']: | |
batch[f'th1kh_ref_{key}_imgs'] = torch.stack([s[f'real_{key}_img'] for s in samples]).permute(0,3,1,2) # [B, H, W, 3]==>[B,3,H,W] | |
batch[f'th1kh_bg_imgs'] = torch.stack([s[f'bg_img'] for s in samples]).permute(0,3,1,2) # [B, H, W, 3]==>[B,3,H,W] | |
batch['th1kh_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_ids'] = torch.stack([s['real_identity'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_exps'] = torch.stack([s['real_expression'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_eulers'] = torch.stack([s['real_euler'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_trans'] = torch.stack([s['real_trans'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_gt_imgs'] = torch.stack([s['fake_gt_img'] for s in samples]).permute(0,3,1,2) # [B, H, W, 3]==>[B,3,H,W] | |
# for key in ['head', 'torso', 'torso_with_bg', 'person']: | |
for key in ['head', 'com', 'inpaint_torso']: | |
batch[f'th1kh_mv_{key}_imgs'] = torch.stack([s[f'fake_{key}_img'] for s in samples]).permute(0,3,1,2) # [B, H, W, 3]==>[B,3,H,W] | |
batch['th1kh_mv_head_masks'] = torch.stack([s['fake_head_mask'] for s in samples]) # [B,6,H,W] | |
batch['th1kh_mv_torso_masks'] = torch.stack([s['fake_torso_mask'] for s in samples]) # [B,6,H,W] | |
batch['th1kh_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_ids'] = torch.stack([s['fake_identity'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_exps'] = torch.stack([s['fake_expression'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_eulers'] = torch.stack([s['fake_euler'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_trans'] = torch.stack([s['fake_trans'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_pertube_exps_1'] = torch.stack([s['real_pertube_expression_1'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_ref_pertube_exps_2'] = torch.stack([s['real_pertube_expression_2'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_pertube_exps_1'] = torch.stack([s['fake_pertube_expression_1'] for s in samples], dim=0) # [B, 204] | |
batch['th1kh_mv_pertube_exps_2'] = torch.stack([s['fake_pertube_expression_2'] for s in samples], dim=0) # [B, 204] | |
return batch | |
if __name__ == '__main__':
    # Smoke test: iterate once over the training split with the default
    # dataloader settings and discard every batch.
    os.environ["OMP_NUM_THREADS"] = "1"
    dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    # dataset = Motion2Video_Dataset("train", 'data/binary/th1kh')
    loader = dataset.get_dataloader()
    for _batch in tqdm(loader):
        pass