import os.path as osp

import numpy as np
from PIL import Image
import torchvision.transforms as transforms


class NormalDataset:
|
    def __init__(self, cfg, split='train'):

        self.split = split
        self.root = cfg.root
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.set_splits = self.opt.set_splits
        self.scales = self.opt.scales
        self.pifu = self.opt.pifu

        # input modalities and their channel dims; front/back normal maps
        # are always appended as 3-channel inputs
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['normal_F', 'normal_B']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        # three fixed views for val/test, evenly spaced views for train
        if self.split != 'train':
            self.rotations = range(0, 360, 120)
        else:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            self.rotations = np.arange(0, 360,
                                       360 / self.opt.rotation_num).astype(int)

        # per-dataset metadata: subject names, smplx path, and global scale
        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset, "smplx")
            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(self.root, dataset, "all.txt"),
                                       dtype=str),
                "path": dataset_dir,
                "scale": self.scales[dataset_id],
            }

        self.subject_list = self.get_subject_list(split)

        # RGB images: resize, then normalize to [-1, 1]
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # masks: resize only (identity normalization keeps values in [0, 1])
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, )),
        ])
|
    def get_subject_list(self, split):

        subject_list = []

        for dataset in self.datasets:

            if self.pifu:
                txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
            else:
                txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(txt):
                print(f"load from {txt}")
                subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())

                if self.pifu:
                    # drop subjects PIFu failed on, then prefix the
                    # remaining ones with their dataset folder
                    miss_pifu = sorted(
                        np.loadtxt(osp.join(self.root, dataset,
                                            "miss_pifu.txt"),
                                   dtype=str).tolist())
                    subject_list = [
                        subject for subject in subject_list
                        if subject not in miss_pifu
                    ]
                    subject_list = [
                        "renderpeople/" + subject for subject in subject_list
                    ]

            else:
                # no split file yet: carve all.txt into train/val/test
                # using the configured ratios and write the lists to disk
                train_txt = osp.join(self.root, dataset, 'train.txt')
                val_txt = osp.join(self.root, dataset, 'val.txt')
                test_txt = osp.join(self.root, dataset, 'test.txt')

                print(f"generate lists of [train, val, test]\n"
                      f"{train_txt}\n{val_txt}\n{test_txt}\n")

                split_txt = osp.join(self.root, dataset, f'{split}.txt')

                subjects = self.datasets_dict[dataset]['subjects']
                train_split = int(len(subjects) * self.set_splits[0])
                val_split = int(len(subjects) * self.set_splits[1]) + train_split

                with open(train_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[:train_split]))
                with open(val_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[train_split:val_split]))
                with open(test_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[val_split:]))

                subject_list += sorted(np.loadtxt(split_txt, dtype=str).tolist())

        # filter out known-broken subjects listed in bug.txt
        bug_list = sorted(
            np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())

        subject_list = [
            subject for subject in subject_list if subject not in bug_list
        ]

        return subject_list
|
    def __len__(self):
        # one sample per (subject, rotation) pair
        return len(self.subject_list) * len(self.rotations)
|
    def __getitem__(self, index):

        # in overfit mode, always return the same (first) sample
        if self.overfit:
            index = 0

        # decompose the flat index into (subject, rotation)
        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]
        subject = self.subject_list[mid]

        # rendered views live under "<dataset>_12views/<subject>"
        subject_render = "/".join(
            [subject.split("/")[0] + "_12views",
             subject.split("/")[1]])

        data_dict = {
            'dataset': subject.split("/")[0],
            'subject': subject,
            'rotation': rotation,
            'image_path': osp.join(self.root, subject_render, 'render',
                                   f'{rotation:03d}.png'),
        }

        # load every requested modality as a tensor; back-facing depth maps
        # ('depth_B') are sign-flipped on load
        for name, channel in zip(self.in_total, self.in_total_dim):
            if name != 'image':
                data_dict[f'{name}_path'] = osp.join(self.root, subject_render,
                                                     name, f'{rotation:03d}.png')
            data_dict[name] = self.imagepath2tensor(data_dict[f'{name}_path'],
                                                    channel,
                                                    inv='depth_B' in name)

        # paths are only needed for loading; drop them from the returned batch
        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict
|
    def imagepath2tensor(self, path, channel=3, inv=False):

        # the alpha channel of the RGBA render serves as the foreground mask
        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        # (0.5 - inv) * 2.0 is +1.0 for inv=False and -1.0 for inv=True,
        # so inverted maps come back sign-flipped
        return (image * (0.5 - inv) * 2.0).float()
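

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption-heavy): `cfg` is whatever config object the
# project normally passes in; here it is faked with SimpleNamespace, and every
# value below is an illustrative placeholder, not the project's real
# configuration. Running this requires the data files referenced in __init__
# (all.txt, split lists, bug.txt, rendered PNGs) to exist under cfg.root.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    cfg = SimpleNamespace(
        root="./data",  # hypothetical dataset root
        overfit=False,
        dataset=SimpleNamespace(
            types=["renderpeople"],  # hypothetical dataset folder name
            input_size=512,
            set_splits=[0.8, 0.1],  # train / val ratios; the rest is test
            scales=[1.0],
            pifu=False,
            rotation_num=36,
        ),
        net=SimpleNamespace(in_nml=[("image", 3)]),
    )

    dataset = NormalDataset(cfg, split='train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    # normal maps are 3-channel tensors in [-1, 1]
    print(batch['normal_F'].shape)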
|
|