Upload src/dataset/dataset_face.py with huggingface_hub

Files changed: src/dataset/dataset_face.py (+354, -0)

src/dataset/dataset_face.py  ADDED
@@ -0,0 +1,354 @@
import os, io, csv, math, random, pdb
import cv2
import numpy as np
import json
from PIL import Image
from einops import rearrange

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from transformers import CLIPImageProcessor
import torch.distributed as dist

from src.utils.draw_util import FaceMeshVisualizer


def zero_rank_print(s):
    # Print only on rank 0, or whenever torch.distributed is not initialized.
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
        print("### " + s)


class FaceDatasetValid(Dataset):
    def __init__(
        self,
        json_path,
        extra_json_path=None,
        sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
        is_image=False,
        sample_stride_aug=False
    ):
        zero_rank_print(f"loading annotations from {json_path} ...")
        self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)

        self.length = len(self.data_dic_name_list)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames

        self.sample_stride_aug = sample_stride_aug

        self.sample_size = sample_size
        self.resize = transforms.Resize((sample_size[0], sample_size[1]))

        self.pixel_transforms = transforms.Compose([
            transforms.Resize([sample_size[1], sample_size[0]]),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.visualizer = FaceMeshVisualizer(forehead_edge=False)
        self.clip_image_processor = CLIPImageProcessor()
        self.is_image = is_image

    def get_data(self, json_name, extra_json_name, augment_num=1):
        zero_rank_print(f"start loading data: {json_name}")
        with open(json_name, 'r') as f:
            data_dic = json.load(f)

        data_dic_name_list = []
        for augment_index in range(augment_num):
            for video_name in data_dic.keys():
                data_dic_name_list.append(video_name)

        invalid_video_name_list = []
        for video_name in data_dic_name_list:
            video_clip_num = len(data_dic[video_name]['clip_data_list'])
            if video_clip_num < 1:
                invalid_video_name_list.append(video_name)
        for name in invalid_video_name_list:
            data_dic_name_list.remove(name)

        if extra_json_name is not None:
            zero_rank_print(f"start loading data: {extra_json_name}")
            with open(extra_json_name, 'r') as f:
                extra_data_dic = json.load(f)
            data_dic.update(extra_data_dic)
            for augment_index in range(3 * augment_num):
                for video_name in extra_data_dic.keys():
                    data_dic_name_list.append(video_name)
        random.shuffle(data_dic_name_list)
        zero_rank_print("finish loading")
        return data_dic_name_list, data_dic

    def __len__(self):
        return len(self.data_dic_name_list)

    def get_batch_wo_pose(self, index):
        video_name = self.data_dic_name_list[index]
        video_clip_num = len(self.data_dic[video_name]['clip_data_list'])

        source_anchor = random.sample(range(video_clip_num), 1)[0]
        source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
        source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']

        video_length = len(source_image_path_list)

        if self.sample_stride_aug:
            tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
        else:
            tmp_sample_stride = self.sample_stride

        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        ref_img_idx = random.randint(0, video_length - 1)

        ref_img = cv2.imread(source_image_path_list[ref_img_idx])
        ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
        ref_img = self.contrast_normalization(ref_img)

        ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
        ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)

        images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
        images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
        image_np = np.array([self.contrast_normalization(img) for img in images])

        pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])

        pixel_values_pose = []
        for frame_id in range(mesh2d_clip.shape[0]):
            normed_mesh2d = mesh2d_clip[frame_id]

            pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)
            pixel_values_pose.append(pose_image)
        pixel_values_pose = np.array(pixel_values_pose)

        if self.is_image:
            pixel_values = pixel_values[0]
            pixel_values_pose = pixel_values_pose[0]
            image_np = image_np[0]

        return ref_img, pixel_values_pose, image_np, ref_pose_image

    def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
        # convert input image to float32
        image = image.astype(np.float32)

        # normalize the image
        normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound

        # convert to uint8
        normalized_image = normalized_image.astype(np.uint8)

        return normalized_image

    def __getitem__(self, idx):
        ref_img, pixel_values_pose, tar_gt, pixel_values_ref_pose = self.get_batch_wo_pose(idx)

        sample = dict(
            pixel_values_pose=pixel_values_pose,
            ref_img=ref_img,
            tar_gt=tar_gt,
            pixel_values_ref_pose=pixel_values_ref_pose,
        )

        return sample


class FaceDataset(Dataset):
    def __init__(
        self,
        json_path,
        extra_json_path=None,
        sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
        is_image=False,
        sample_stride_aug=False
    ):
        zero_rank_print(f"loading annotations from {json_path} ...")
        self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)

        self.length = len(self.data_dic_name_list)
        zero_rank_print(f"data scale: {self.length}")

        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames

        self.sample_stride_aug = sample_stride_aug

        self.sample_size = sample_size
        self.resize = transforms.Resize((sample_size[0], sample_size[1]))

        self.pixel_transforms = transforms.Compose([
            transforms.Resize([sample_size[1], sample_size[0]]),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.visualizer = FaceMeshVisualizer(forehead_edge=False)
        self.clip_image_processor = CLIPImageProcessor()
        self.is_image = is_image

    def get_data(self, json_name, extra_json_name, augment_num=1):
        zero_rank_print(f"start loading data: {json_name}")
        with open(json_name, 'r') as f:
            data_dic = json.load(f)

        data_dic_name_list = []
        for augment_index in range(augment_num):
            for video_name in data_dic.keys():
                data_dic_name_list.append(video_name)

        invalid_video_name_list = []
        for video_name in data_dic_name_list:
            video_clip_num = len(data_dic[video_name]['clip_data_list'])
            if video_clip_num < 1:
                invalid_video_name_list.append(video_name)
        for name in invalid_video_name_list:
            data_dic_name_list.remove(name)

        if extra_json_name is not None:
            zero_rank_print(f"start loading data: {extra_json_name}")
            with open(extra_json_name, 'r') as f:
                extra_data_dic = json.load(f)
            data_dic.update(extra_data_dic)
            for augment_index in range(3 * augment_num):
                for video_name in extra_data_dic.keys():
                    data_dic_name_list.append(video_name)
        random.shuffle(data_dic_name_list)
        zero_rank_print("finish loading")
        return data_dic_name_list, data_dic

    def __len__(self):
        return len(self.data_dic_name_list)

    def get_batch_wo_pose(self, index):
        video_name = self.data_dic_name_list[index]
        video_clip_num = len(self.data_dic[video_name]['clip_data_list'])

        source_anchor = random.sample(range(video_clip_num), 1)[0]
        source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
        source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']

        video_length = len(source_image_path_list)

        if self.sample_stride_aug:
            tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
        else:
            tmp_sample_stride = self.sample_stride

        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        ref_img_idx = random.randint(0, video_length - 1)
        ref_img = cv2.imread(source_image_path_list[ref_img_idx])
        ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
        ref_img = self.contrast_normalization(ref_img)
        ref_img_pil = Image.fromarray(ref_img)

        clip_ref_image = self.clip_image_processor(images=ref_img_pil, return_tensors="pt").pixel_values

        pixel_values_ref_img = torch.from_numpy(ref_img).permute(2, 0, 1).contiguous()
        pixel_values_ref_img = pixel_values_ref_img / 255.

        ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
        ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)
        pixel_values_ref_pose = torch.from_numpy(ref_pose_image).permute(2, 0, 1).contiguous()
        pixel_values_ref_pose = pixel_values_ref_pose / 255.

        images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
        images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
        image_np = np.array([self.contrast_normalization(img) for img in images])

        pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])

        pixel_values_pose = []
        for frame_id in range(mesh2d_clip.shape[0]):
            normed_mesh2d = mesh2d_clip[frame_id]

            pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)

            pixel_values_pose.append(pose_image)

        pixel_values_pose = np.array(pixel_values_pose)
        pixel_values_pose = torch.from_numpy(pixel_values_pose).permute(0, 3, 1, 2).contiguous()
        pixel_values_pose = pixel_values_pose / 255.

        if self.is_image:
            pixel_values = pixel_values[0]
            pixel_values_pose = pixel_values_pose[0]

        return pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose

    def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
        image = image.astype(np.float32)
        normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound
        normalized_image = normalized_image.astype(np.uint8)

        return normalized_image

    def __getitem__(self, idx):
        pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose = self.get_batch_wo_pose(idx)

        pixel_values = self.pixel_transforms(pixel_values)
        pixel_values_pose = self.pixel_transforms(pixel_values_pose)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
        pixel_values_ref_img = self.pixel_transforms(pixel_values_ref_img)
        pixel_values_ref_img = pixel_values_ref_img.squeeze(0)

        pixel_values_ref_pose = pixel_values_ref_pose.unsqueeze(0)
        pixel_values_ref_pose = self.pixel_transforms(pixel_values_ref_pose)
        pixel_values_ref_pose = pixel_values_ref_pose.squeeze(0)

        drop_image_embeds = 1 if random.random() < 0.1 else 0

        sample = dict(
            pixel_values=pixel_values,
            pixel_values_pose=pixel_values_pose,
            clip_ref_image=clip_ref_image,
            pixel_values_ref_img=pixel_values_ref_img,
            drop_image_embeds=drop_image_embeds,
            pixel_values_ref_pose=pixel_values_ref_pose,
        )

        return sample


def collate_fn(data):
    pixel_values = torch.stack([example["pixel_values"] for example in data])
    pixel_values_pose = torch.stack([example["pixel_values_pose"] for example in data])
    clip_ref_image = torch.cat([example["clip_ref_image"] for example in data])
    pixel_values_ref_img = torch.stack([example["pixel_values_ref_img"] for example in data])
    drop_image_embeds = [example["drop_image_embeds"] for example in data]
    drop_image_embeds = torch.Tensor(drop_image_embeds)
    pixel_values_ref_pose = torch.stack([example["pixel_values_ref_pose"] for example in data])

    return {
        "pixel_values": pixel_values,
        "pixel_values_pose": pixel_values_pose,
        "clip_ref_image": clip_ref_image,
        "pixel_values_ref_img": pixel_values_ref_img,
        "drop_image_embeds": drop_image_embeds,
        "pixel_values_ref_pose": pixel_values_ref_pose,
    }
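For context, a minimal usage sketch of the training dataset and its collate function is shown below. The annotation path, batch size, and worker count are illustrative assumptions, not values from this commit; the JSON layout (a dict of videos, each with a "clip_data_list" of "frame_path_list" and "lmks_list" entries) is inferred from get_data and get_batch_wo_pose above.

# Hypothetical usage sketch; "data/face_train.json" and the loader settings are assumptions.
from torch.utils.data import DataLoader

from src.dataset.dataset_face import FaceDataset, collate_fn

dataset = FaceDataset(
    json_path="data/face_train.json",  # assumed annotation file: {video: {"clip_data_list": [...]}}
    sample_size=[512, 512],
    sample_stride=4,
    sample_n_frames=16,
)

loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,  # stacks frame/pose tensors and concatenates CLIP reference images
)

batch = next(iter(loader))
# pixel_values:          (B, F, 3, H, W) target frames, normalized to [-1, 1]
# pixel_values_pose:     (B, F, 3, H, W) rendered face-mesh pose frames
# clip_ref_image:        (B, 3, 224, 224) CLIPImageProcessor output for the reference frame
# pixel_values_ref_img:  (B, 3, H, W) reference frame
# drop_image_embeds:     (B,) set to 1 with ~10% probability (image-condition dropout)
# pixel_values_ref_pose: (B, 3, H, W) reference pose image
print({k: tuple(v.shape) for k, v in batch.items()})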