kimjy0411 committed
Commit 028ba1c · verified · 1 parent: d00cb50

Upload src/dataset/dataset_face.py with huggingface_hub

Files changed (1)
  1. src/dataset/dataset_face.py +354 -0
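For reference, an upload like this is typically done with the huggingface_hub client. A minimal sketch (the repo id below is hypothetical):

from huggingface_hub import HfApi

api = HfApi()
# Upload the local file into the repo under the same relative path.
api.upload_file(
    path_or_fileobj="src/dataset/dataset_face.py",
    path_in_repo="src/dataset/dataset_face.py",
    repo_id="kimjy0411/example-repo",  # hypothetical repo id
)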
src/dataset/dataset_face.py ADDED
@@ -0,0 +1,354 @@
+ import random
+ import json
+ 
+ import cv2
+ import numpy as np
+ from PIL import Image
+ 
+ import torch
+ import torchvision.transforms as transforms
+ from torch.utils.data.dataset import Dataset
+ from transformers import CLIPImageProcessor
+ import torch.distributed as dist
+ 
+ from src.utils.draw_util import FaceMeshVisualizer
+ 
+ 
+ def zero_rank_print(s):
+     # Print only on rank 0, or whenever torch.distributed is not initialized.
+     if not dist.is_initialized() or dist.get_rank() == 0:
+         print("### " + s)
+ 
+ 
+ class FaceDatasetValid(Dataset):
+     def __init__(
+         self,
+         json_path,
+         extra_json_path=None,
+         sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
+         is_image=False,
+         sample_stride_aug=False
+     ):
+         zero_rank_print(f"loading annotations from {json_path} ...")
+         self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)
+ 
+         self.length = len(self.data_dic_name_list)
+         zero_rank_print(f"data scale: {self.length}")
+ 
+         self.sample_stride = sample_stride
+         self.sample_n_frames = sample_n_frames
+         self.sample_stride_aug = sample_stride_aug
+ 
+         self.sample_size = sample_size
+         self.resize = transforms.Resize((sample_size[0], sample_size[1]))
+ 
+         # Note: the Resize below receives [sample_size[1], sample_size[0]]; the
+         # two orderings only agree because the default sample size is square.
+         self.pixel_transforms = transforms.Compose([
+             transforms.Resize([sample_size[1], sample_size[0]]),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+         ])
+ 
+         self.visualizer = FaceMeshVisualizer(forehead_edge=False)
+         self.clip_image_processor = CLIPImageProcessor()
+         self.is_image = is_image
+     def get_data(self, json_name, extra_json_name, augment_num=1):
+         zero_rank_print(f"start loading data: {json_name}")
+         with open(json_name, 'r') as f:
+             data_dic = json.load(f)
+ 
+         # Repeat every video name `augment_num` times.
+         data_dic_name_list = []
+         for _ in range(augment_num):
+             for video_name in data_dic.keys():
+                 data_dic_name_list.append(video_name)
+ 
+         # Drop videos that contain no clips.
+         invalid_video_name_list = []
+         for video_name in data_dic_name_list:
+             video_clip_num = len(data_dic[video_name]['clip_data_list'])
+             if video_clip_num < 1:
+                 invalid_video_name_list.append(video_name)
+         for name in invalid_video_name_list:
+             data_dic_name_list.remove(name)
+ 
+         # Optionally merge an extra annotation file; its videos are oversampled 3x.
+         if extra_json_name is not None:
+             zero_rank_print(f"start loading data: {extra_json_name}")
+             with open(extra_json_name, 'r') as f:
+                 extra_data_dic = json.load(f)
+             data_dic.update(extra_data_dic)
+             for _ in range(3 * augment_num):
+                 for video_name in extra_data_dic.keys():
+                     data_dic_name_list.append(video_name)
+         random.shuffle(data_dic_name_list)
+         zero_rank_print("finish loading")
+         return data_dic_name_list, data_dic
+     def __len__(self):
+         return len(self.data_dic_name_list)
+ 
+     def get_batch_wo_pose(self, index):
+         video_name = self.data_dic_name_list[index]
+         video_clip_num = len(self.data_dic[video_name]['clip_data_list'])
+ 
+         # Pick one clip of this video at random.
+         source_anchor = random.sample(range(video_clip_num), 1)[0]
+         source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
+         source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']
+ 
+         video_length = len(source_image_path_list)
+ 
+         # Stride augmentation: use the configured stride half the time, stride 4 otherwise.
+         if self.sample_stride_aug:
+             tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
+         else:
+             tmp_sample_stride = self.sample_stride
+ 
+         # Evenly sample `sample_n_frames` indices from a random window of the clip.
+         if not self.is_image:
+             clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
+             start_idx = random.randint(0, video_length - clip_length)
+             batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+         else:
+             batch_index = [random.randint(0, video_length - 1)]
+ 
+         # Reference frame and its pose drawing.
+         ref_img_idx = random.randint(0, video_length - 1)
+         ref_img = cv2.imread(source_image_path_list[ref_img_idx])
+         ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
+         ref_img = self.contrast_normalization(ref_img)
+ 
+         ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
+         ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)
+ 
+         # Target frames (RGB uint8, shape (T, H, W, C)).
+         images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
+         images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
+         image_np = np.array([self.contrast_normalization(img) for img in images])
+ 
+         pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
+         pixel_values = pixel_values / 255.
+ 
+         # A pose image for every sampled frame.
+         mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])
+ 
+         pixel_values_pose = []
+         for frame_id in range(mesh2d_clip.shape[0]):
+             normed_mesh2d = mesh2d_clip[frame_id]
+             pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)
+             pixel_values_pose.append(pose_image)
+         pixel_values_pose = np.array(pixel_values_pose)
+ 
+         if self.is_image:
+             pixel_values = pixel_values[0]
+             pixel_values_pose = pixel_values_pose[0]
+             image_np = image_np[0]
+ 
+         return ref_img, pixel_values_pose, image_np, ref_pose_image
+     def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
+         # convert input image to float32
+         image = image.astype(np.float32)
+ 
+         # normalize the image
+         normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound
+ 
+         # convert to uint8
+         normalized_image = normalized_image.astype(np.uint8)
+ 
+         return normalized_image
+     def __getitem__(self, idx):
+         ref_img, pixel_values_pose, tar_gt, pixel_values_ref_pose = self.get_batch_wo_pose(idx)
+ 
+         sample = dict(
+             pixel_values_pose=pixel_values_pose,
+             ref_img=ref_img,
+             tar_gt=tar_gt,
+             pixel_values_ref_pose=pixel_values_ref_pose,
+         )
+ 
+         return sample
+ 
+ 
+ class FaceDataset(Dataset):
+     def __init__(
+         self,
+         json_path,
+         extra_json_path=None,
+         sample_size=[512, 512], sample_stride=4, sample_n_frames=16,
+         is_image=False,
+         sample_stride_aug=False
+     ):
+         zero_rank_print(f"loading annotations from {json_path} ...")
+         self.data_dic_name_list, self.data_dic = self.get_data(json_path, extra_json_path)
+ 
+         self.length = len(self.data_dic_name_list)
+         zero_rank_print(f"data scale: {self.length}")
+ 
+         self.sample_stride = sample_stride
+         self.sample_n_frames = sample_n_frames
+         self.sample_stride_aug = sample_stride_aug
+ 
+         self.sample_size = sample_size
+         self.resize = transforms.Resize((sample_size[0], sample_size[1]))
+ 
+         self.pixel_transforms = transforms.Compose([
+             transforms.Resize([sample_size[1], sample_size[0]]),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+         ])
+ 
+         self.visualizer = FaceMeshVisualizer(forehead_edge=False)
+         self.clip_image_processor = CLIPImageProcessor()
+         self.is_image = is_image
+     def get_data(self, json_name, extra_json_name, augment_num=1):
+         zero_rank_print(f"start loading data: {json_name}")
+         with open(json_name, 'r') as f:
+             data_dic = json.load(f)
+ 
+         data_dic_name_list = []
+         for _ in range(augment_num):
+             for video_name in data_dic.keys():
+                 data_dic_name_list.append(video_name)
+ 
+         invalid_video_name_list = []
+         for video_name in data_dic_name_list:
+             video_clip_num = len(data_dic[video_name]['clip_data_list'])
+             if video_clip_num < 1:
+                 invalid_video_name_list.append(video_name)
+         for name in invalid_video_name_list:
+             data_dic_name_list.remove(name)
+ 
+         if extra_json_name is not None:
+             zero_rank_print(f"start loading data: {extra_json_name}")
+             with open(extra_json_name, 'r') as f:
+                 extra_data_dic = json.load(f)
+             data_dic.update(extra_data_dic)
+             for _ in range(3 * augment_num):
+                 for video_name in extra_data_dic.keys():
+                     data_dic_name_list.append(video_name)
+         random.shuffle(data_dic_name_list)
+         zero_rank_print("finish loading")
+         return data_dic_name_list, data_dic
+     def __len__(self):
+         return len(self.data_dic_name_list)
+ 
+     def get_batch_wo_pose(self, index):
+         video_name = self.data_dic_name_list[index]
+         video_clip_num = len(self.data_dic[video_name]['clip_data_list'])
+ 
+         # Pick one clip of this video at random.
+         source_anchor = random.sample(range(video_clip_num), 1)[0]
+         source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list']
+         source_mesh2d_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['lmks_list']
+ 
+         video_length = len(source_image_path_list)
+ 
+         # Stride augmentation: use the configured stride half the time, stride 4 otherwise.
+         if self.sample_stride_aug:
+             tmp_sample_stride = self.sample_stride if random.random() > 0.5 else 4
+         else:
+             tmp_sample_stride = self.sample_stride
+ 
+         # Evenly sample `sample_n_frames` indices from a random window of the clip.
+         if not self.is_image:
+             clip_length = min(video_length, (self.sample_n_frames - 1) * tmp_sample_stride + 1)
+             start_idx = random.randint(0, video_length - clip_length)
+             batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+         else:
+             batch_index = [random.randint(0, video_length - 1)]
+ 
+         # Reference frame: a random frame from the same clip, returned both as
+         # CLIP-processor input and as a normalized image tensor.
+         ref_img_idx = random.randint(0, video_length - 1)
+         ref_img = cv2.imread(source_image_path_list[ref_img_idx])
+         ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
+         ref_img = self.contrast_normalization(ref_img)
+         ref_img_pil = Image.fromarray(ref_img)
+ 
+         clip_ref_image = self.clip_image_processor(images=ref_img_pil, return_tensors="pt").pixel_values
+ 
+         pixel_values_ref_img = torch.from_numpy(ref_img).permute(2, 0, 1).contiguous()
+         pixel_values_ref_img = pixel_values_ref_img / 255.
+ 
+         # Pose image for the reference frame, drawn from its 2D landmarks.
+         ref_mesh2d_clip = np.load(source_mesh2d_path_list[ref_img_idx]).astype(float)
+         ref_pose_image = self.visualizer.draw_landmarks(self.sample_size, ref_mesh2d_clip, normed=True)
+         pixel_values_ref_pose = torch.from_numpy(ref_pose_image).permute(2, 0, 1).contiguous()
+         pixel_values_ref_pose = pixel_values_ref_pose / 255.
+ 
+         # Target frames (RGB, normalized to [0, 1], shape (T, C, H, W)).
+         images = [cv2.imread(source_image_path_list[idx]) for idx in batch_index]
+         images = [cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) for bgr_image in images]
+         image_np = np.array([self.contrast_normalization(img) for img in images])
+ 
+         pixel_values = torch.from_numpy(image_np).permute(0, 3, 1, 2).contiguous()
+         pixel_values = pixel_values / 255.
+ 
+         # A pose image for every target frame.
+         mesh2d_clip = np.array([np.load(source_mesh2d_path_list[idx]).astype(float) for idx in batch_index])
+ 
+         pixel_values_pose = []
+         for frame_id in range(mesh2d_clip.shape[0]):
+             normed_mesh2d = mesh2d_clip[frame_id]
+             pose_image = self.visualizer.draw_landmarks(self.sample_size, normed_mesh2d, normed=True)
+             pixel_values_pose.append(pose_image)
+ 
+         pixel_values_pose = np.array(pixel_values_pose)
+         pixel_values_pose = torch.from_numpy(pixel_values_pose).permute(0, 3, 1, 2).contiguous()
+         pixel_values_pose = pixel_values_pose / 255.
+ 
+         if self.is_image:
+             pixel_values = pixel_values[0]
+             pixel_values_pose = pixel_values_pose[0]
+ 
+         return pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose
+     def contrast_normalization(self, image, lower_bound=0, upper_bound=255):
+         # Linearly rescale pixel values into [lower_bound, upper_bound].
+         # With the default bounds this is an identity mapping on uint8 input.
+         image = image.astype(np.float32)
+         normalized_image = image * (upper_bound - lower_bound) / 255 + lower_bound
+         normalized_image = normalized_image.astype(np.uint8)
+ 
+         return normalized_image
+     def __getitem__(self, idx):
+         pixel_values, pixel_values_pose, clip_ref_image, pixel_values_ref_img, pixel_values_ref_pose = self.get_batch_wo_pose(idx)
+ 
+         pixel_values = self.pixel_transforms(pixel_values)
+         pixel_values_pose = self.pixel_transforms(pixel_values_pose)
+ 
+         # pixel_transforms expects a leading batch/frame dimension, so wrap and
+         # unwrap the single reference frames.
+         pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
+         pixel_values_ref_img = self.pixel_transforms(pixel_values_ref_img)
+         pixel_values_ref_img = pixel_values_ref_img.squeeze(0)
+ 
+         pixel_values_ref_pose = pixel_values_ref_pose.unsqueeze(0)
+         pixel_values_ref_pose = self.pixel_transforms(pixel_values_ref_pose)
+         pixel_values_ref_pose = pixel_values_ref_pose.squeeze(0)
+ 
+         # Flag ~10% of samples to drop the reference image embedding,
+         # presumably to enable classifier-free guidance at inference time.
+         drop_image_embeds = 1 if random.random() < 0.1 else 0
+ 
+         sample = dict(
+             pixel_values=pixel_values,
+             pixel_values_pose=pixel_values_pose,
+             clip_ref_image=clip_ref_image,
+             pixel_values_ref_img=pixel_values_ref_img,
+             drop_image_embeds=drop_image_embeds,
+             pixel_values_ref_pose=pixel_values_ref_pose,
+         )
+ 
+         return sample
+ 
+ 
+ def collate_fn(data):
+     # Collate a list of FaceDataset samples into batched tensors for the DataLoader.
+     pixel_values = torch.stack([example["pixel_values"] for example in data])
+     pixel_values_pose = torch.stack([example["pixel_values_pose"] for example in data])
+     clip_ref_image = torch.cat([example["clip_ref_image"] for example in data])
+     pixel_values_ref_img = torch.stack([example["pixel_values_ref_img"] for example in data])
+     drop_image_embeds = [example["drop_image_embeds"] for example in data]
+     drop_image_embeds = torch.Tensor(drop_image_embeds)
+     pixel_values_ref_pose = torch.stack([example["pixel_values_ref_pose"] for example in data])
+ 
+     return {
+         "pixel_values": pixel_values,
+         "pixel_values_pose": pixel_values_pose,
+         "clip_ref_image": clip_ref_image,
+         "pixel_values_ref_img": pixel_values_ref_img,
+         "drop_image_embeds": drop_image_embeds,
+         "pixel_values_ref_pose": pixel_values_ref_pose,
+     }
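For orientation, a minimal usage sketch of the dataset and collate function above. The annotation path is hypothetical; the JSON must map video names to entries whose clip_data_list holds frame_path_list / lmks_list pairs, as the loader expects.

from torch.utils.data import DataLoader

from src.dataset.dataset_face import FaceDataset, collate_fn

# "data/train.json" is a hypothetical annotation file.
dataset = FaceDataset(json_path="data/train.json", sample_size=[512, 512],
                      sample_stride=4, sample_n_frames=16)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
# Expected shapes with the defaults above:
#   pixel_values:         (2, 16, 3, 512, 512)
#   pixel_values_pose:    (2, 16, 3, 512, 512)
#   pixel_values_ref_img: (2, 3, 512, 512)
print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")})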