import logging
import os
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
from timm.data.readers.reader import Reader
from tqdm import tqdm

# Tolerance passed to cropout_black_parts when trimming cropped-out regions.
CROP_ROUND_TOL = 0.3
# Minimum side, in pixels, of a usable person crop.
MIN_PERSON_SIZE = 100
# Minimum fraction of a person crop that must survive after associated
# objects have been cropped out of it.
MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4

_logger = logging.getLogger("ReaderAgeGender")


class ReaderAgeGender(Reader):
    """
    Reader for the (almost) original cleaned IMDB-WIKI dataset.

    Two changes relative to the original layout:
    1. Annotations must be in the ./annotation subdir of the dataset root.
    2. Images must be in the ./images subdir.
    """

    def __init__(
        self,
        images_path,
        annotations_path,
        split="validation",
        target_size=224,
        min_size=5,
        seed=1234,
        with_persons=False,
        min_person_size=MIN_PERSON_SIZE,
        disable_faces=False,
        only_age=False,
        min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
        crop_round_tol=CROP_ROUND_TOL,
    ):
        super().__init__()

        self.with_persons = with_persons
        self.disable_faces = disable_faces
        self.only_age = only_age

        self.crop_out_color = (0, 0, 0)

        # Uniform placeholder returned when a face or person crop is unavailable.
        self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
        self.empty_crop = self.empty_crop.astype(np.uint8)

        self.min_person_size = min_person_size
        self.min_person_aftercut_ratio = min_person_aftercut_ratio
        self.crop_round_tol = crop_round_tol

        self.split = split
        self.min_size = min_size
        self.seed = seed
        self.target_size = target_size

        self._ann: Dict[str, List[PictureInfo]] = {}
        self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
        self._faces_list: List[Tuple[str, int]] = []

        self._read_annotations(images_path, annotations_path)
        _logger.info(f"Dataset length: {len(self._faces_list)} crops")
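
    # Each item is ((face_crop, person_crop), [age, gender]); both crops are
    # target_size x target_size x 3 uint8 images in cv2 BGR order, with
    # self.empty_crop standing in when the corresponding bbox is absent.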

    def __getitem__(self, index):
        return self._read_img_and_label(index)

    def __len__(self):
        return len(self._faces_list)

    def _filename(self, index, basename=False, absolute=False):
        # `absolute` is accepted for compatibility with the timm Reader
        # interface and is unused here.
        img_p = self._faces_list[index][0]
        return os.path.basename(img_p) if basename else img_p

    def _read_annotations(self, images_path, csvs_path):
        self._ann = {}
        self._faces_list = []
        self._associated_objects = {}

        csvs = get_all_files(csvs_path, [".csv"])
        # Keep only the annotation files that belong to the requested split.
        csvs = [c for c in csvs if self.split in os.path.basename(c)]

        for csv in csvs:
            db, ann_type = read_csv_annotation_file(csv, images_path)
            if self.with_persons and ann_type != AnnotType.PERSONS:
                raise ValueError(
                    f"Annotation file {csv} contains no persons, "
                    f"but annotations with persons were requested."
                )
            self._ann.update(db)

        if len(self._ann) == 0:
            raise ValueError("Annotations are empty!")

        self._ann, self._associated_objects = self.prepare_annotations()
        images_list = list(self._ann.keys())

        for img_path in images_list:
            for index, image_sample_info in enumerate(self._ann[img_path]):
                assert image_sample_info.has_gt(
                    self.only_age
                ), "Annotations must be checked with self.prepare_annotations() func"
                self._faces_list.append((img_path, index))

    def _read_img_and_label(self, index):
        if not isinstance(index, int):
            raise TypeError("ReaderAgeGender expected index to be an integer")

        img_p, face_index = self._faces_list[index]
        ann: PictureInfo = self._ann[img_p][face_index]
        img = cv2.imread(img_p)

        face_empty = True
        if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
            face_crop, face_empty = self._get_crop(ann.bbox, img)

        if not self.with_persons and face_empty:
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if face_empty:
            face_crop = self.empty_crop

        person_empty = True
        if self.with_persons or self.disable_faces:
            if ann.has_person_bbox:
                # Crop out every associated object overlapping this person bbox.
                objects = self._associated_objects[img_p][face_index]
                person_crop, person_empty = self._get_crop(
                    ann.person_bbox,
                    img,
                    crop_out_color=self.crop_out_color,
                    associated_objects=objects,
                )

        if face_empty and person_empty:
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if person_empty:
            person_crop = self.empty_crop

        return (face_crop, person_crop), [ann.age, ann.gender]

    def _get_crop(
        self,
        bbox,
        img,
        associated_objects=None,
        crop_out_color=(0, 0, 0),
    ) -> Tuple[np.ndarray, bool]:

        empty_bbox = False

        xmin, ymin, xmax, ymax = bbox
        assert not (
            ymax - ymin < self.min_size or xmax - xmin < self.min_size
        ), "Annotations must be checked with self.prepare_annotations() func"

        crop = img[ymin:ymax, xmin:xmax]

        if associated_objects:
            crop, empty_bbox = _cropout_associated_objs(
                associated_objects,
                bbox,
                crop.copy(),
                crop_out_color=crop_out_color,
                min_person_size=self.min_person_size,
                crop_round_tol=self.crop_round_tol,
                min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            )

        if empty_bbox:
            crop = self.empty_crop

        # Letterbox to the square target size, padding with crop_out_color.
        crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
        return crop, empty_bbox

    def prepare_annotations(self):

        good_anns: Dict[str, List[PictureInfo]] = {}
        all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}

        if not self.with_persons:
            # Face-only mode: drop person bboxes from every sample.
            for img_path, bboxes in self._ann.items():
                for sample in bboxes:
                    sample.clear_person_bbox()

        verify_images_func = partial(
            verify_images,
            min_size=self.min_size,
            min_person_size=self.min_person_size,
            with_persons=self.with_persons,
            disable_faces=self.disable_faces,
            crop_round_tol=self.crop_round_tol,
            min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            only_age=self.only_age,
        )
        # os.cpu_count() can return None; fall back to a single thread.
        num_threads = min(8, os.cpu_count() or 1)

        all_msgs = []
        broken = 0
        skipped = 0
        all_skipped_crops = 0
        desc = "Check annotations..."
        with ThreadPool(num_threads) as pool:
            pbar = tqdm(
                pool.imap_unordered(verify_images_func, list(self._ann.items())),
                desc=desc,
                total=len(self._ann),
            )

            for img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops in pbar:
                broken += 1 if is_corrupted else 0
                all_msgs.extend(msgs)
                all_skipped_crops += skipped_crops
                skipped += 1 if is_empty_annotations else 0
                if img_info is not None:
                    img_path, img_samples = img_info
                    good_anns[img_path] = img_samples
                    all_associated_objects.update({img_path: associated_objects})

                pbar.desc = (
                    f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
                    f"{broken} images corrupted"
                )

            pbar.close()

        for msg in all_msgs:
            print(msg)
        print(f"\nImages left: {len(good_anns)}")

        return good_anns, all_associated_objects


def verify_images(
    img_info,
    min_size: int,
    min_person_size: int,
    with_persons: bool,
    disable_faces: bool,
    crop_round_tol: float,
    min_person_aftercut_ratio: float,
    only_age: bool,
):
    """
    Verify a single image and its annotated samples: clamp bboxes to the image,
    drop crops that are too small and collect per-sample statistics.
    """
    # Faces may only be disabled when person crops are in use.
    disable_faces = disable_faces and with_persons
    kwargs = dict(
        min_person_size=min_person_size,
        disable_faces=disable_faces,
        with_persons=with_persons,
        crop_round_tol=crop_round_tol,
        min_person_aftercut_ratio=min_person_aftercut_ratio,
        only_age=only_age,
    )

    def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
        # Clamp the bbox to the image and reject it if either side is too small.
        ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
        crop_h, crop_w = ymax - ymin, xmax - xmin
        if crop_h < min_size or crop_w < min_size:
            return False, [-1, -1, -1, -1]
        bbox = [xmin, ymin, xmax, ymax]
        return True, bbox

    msgs = []
    skipped_crops = 0
    is_corrupted = False
    is_empty_annotations = False

    img_path: str = img_info[0]
    img_samples: List[PictureInfo] = img_info[1]
    try:
        # cv2.imread returns None for unreadable files; the shape access then raises.
        im_cv = cv2.imread(img_path)
        im_h, im_w = im_cv.shape[:2]
    except Exception:
        msgs.append(f"Cannot load image {img_path}")
        is_corrupted = True
        return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops

    out_samples: List[PictureInfo] = []
    for sample in img_samples:
        if sample.has_face_bbox:
            is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
            if not is_correct and sample.has_gt(only_age):
                msgs.append("Small face. Skipping...")
                skipped_crops += 1

        if sample.has_person_bbox:
            is_correct, sample.person_bbox = bbox_correct(
                sample.person_bbox, max(min_person_size, min_size), im_h, im_w
            )
            if not is_correct and sample.has_gt(only_age):
                msgs.append(f"Small person {img_path}. Skipping...")
                skipped_crops += 1

        if sample.has_face_bbox or sample.has_person_bbox:
            out_samples.append(sample)
        elif sample.has_gt(only_age):
            msgs.append("Sample has no face and no body. Skipping...")
            skipped_crops += 1

    # Order samples so that those with ground truth come first.
    out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)

    associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)

    out_samples, associated_objects, skipped_crops = filter_bad_samples(
        out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
    )

    out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
    if len(out_samples) == 0:
        out_img_info = None
        is_empty_annotations = True

    return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops


def filter_bad_samples(
    out_samples: List[PictureInfo],
    associated_objects: dict,
    im_cv: np.ndarray,
    msgs: List[str],
    skipped_crops: int,
    **kwargs,
):
    with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
        kwargs["with_persons"],
        kwargs["disable_faces"],
        kwargs["min_person_size"],
        kwargs["crop_round_tol"],
        kwargs["min_person_aftercut_ratio"],
        kwargs["only_age"],
    )

    # Keep only samples that have ground truth labels.
    inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
    out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if disable_faces:
        # Persons-only mode: drop face bboxes, then drop samples left without a person bbox.
        for sample in out_samples:
            sample.clear_face_bbox()

        inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if with_persons or disable_faces:
        # Drop samples whose person crop becomes empty once associated objects
        # are cropped out, unless a face bbox can still be used.
        inds = []
        for ind, sample in enumerate(out_samples):
            person_empty = True
            if sample.has_person_bbox:
                xmin, ymin, xmax, ymax = sample.person_bbox
                crop = im_cv[ymin:ymax, xmin:xmax]

                _, person_empty = _cropout_associated_objs(
                    associated_objects[ind],
                    sample.person_bbox,
                    crop.copy(),
                    min_person_size=min_person_size,
                    crop_round_tol=crop_round_tol,
                    min_person_aftercut_ratio=min_person_aftercut_ratio,
                )

            if person_empty and not sample.has_face_bbox:
                msgs.append("Small person after preprocessing. Skipping...")
                skipped_crops += 1
            else:
                inds.append(ind)
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    assert len(associated_objects) == len(out_samples)
    return out_samples, associated_objects, skipped_crops


def _filter_by_ind(out_samples, associated_objects, inds):
    """Keep only the samples at the given indices and renumber associated-object keys."""
    _associated_objects = {}
    _out_samples = []
    for ind, sample in enumerate(out_samples):
        if ind in inds:
            _associated_objects[len(_out_samples)] = associated_objects[ind]
            _out_samples.append(sample)

    return _out_samples, _associated_objects


def find_associated_objects(
    image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
) -> Dict[int, List[List[int]]]:
    """
    For each person sample with ground truth age and gender, find the face and
    person bboxes of all other samples that intersect its person bbox.
    """
    associated_objects: Dict[int, List[List[int]]] = {}

    for iindex, image_sample_info in enumerate(image_samples):
        associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []

        if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
            continue

        iperson_box = image_sample_info.person_bbox
        for jindex, other_image_sample in enumerate(image_samples):
            if iindex == jindex:
                continue
            if other_image_sample.has_face_bbox:
                jface_bbox = other_image_sample.bbox
                iou = _get_iou(jface_bbox, iperson_box)
                if iou >= iou_thresh:
                    associated_objects[iindex].append(jface_bbox)
            if other_image_sample.has_person_bbox:
                jperson_bbox = other_image_sample.person_bbox
                iou = _get_iou(jperson_bbox, iperson_box)
                if iou >= iou_thresh:
                    associated_objects[iindex].append(jperson_bbox)

    return associated_objects


def _cropout_associated_objs(
    associated_objects,
    person_bbox,
    crop,
    min_person_size,
    crop_round_tol,
    min_person_aftercut_ratio,
    crop_out_color=(0, 0, 0),
):
    """Crop out associated objects from a person crop, then trim and validate it."""
    empty = False
    xmin, ymin, xmax, ymax = person_bbox
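
    # Worked example (hypothetical numbers): for person_bbox (50, 40, 200, 300)
    # and an associated face at (60, 50, 100, 100), the face maps to the
    # crop-local region crop[10:60, 10:50], which is filled with crop_out_color.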
    for a_obj in associated_objects:
        aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj

        # Translate the object's coordinates into the crop's local frame and
        # clip them to the crop boundaries.
        aobj_ymin = int(max(aobj_ymin - ymin, 0))
        aobj_xmin = int(max(aobj_xmin - xmin, 0))
        aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
        aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))

        crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color

    # Trim the cropped-out borders; reject the crop if it became too small or
    # if too little of the original person area survived.
    crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
    if (
        crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
    ) or cropped_ratio < min_person_aftercut_ratio:
        crop = None
        empty = True

    return crop, empty


def _correct_bbox(bbox, h, w):
    # Clamp bbox coordinates to the image; note the (ymin, ymax, xmin, xmax)
    # return order, which differs from the (xmin, ymin, xmax, ymax) input order.
    xmin, ymin, xmax, ymax = bbox
    ymin = min(max(ymin, 0), h)
    ymax = min(max(ymax, 0), h)
    xmin = min(max(xmin, 0), w)
    xmax = min(max(xmax, 0), w)
    return ymin, ymax, xmin, xmax
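
# For example (hypothetical values): _correct_bbox([-10, -5, 120, 90], h=80, w=100)
# returns (0, 80, 0, 100) after clamping.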


def _get_iou(bbox1, bbox2):
    # Reorder from (xmin, ymin, xmax, ymax) to [ymin, xmin, ymax, xmax],
    # the order the IOU helper expects.
    xmin1, ymin1, xmax1, ymax1 = bbox1
    xmin2, ymin2, xmax2, ymax2 = bbox2
    iou = IOU(
        [ymin1, xmin1, ymax1, xmax1],
        [ymin2, xmin2, ymax2, xmax2],
    )
    return iou
|