import logging
from typing import List, Optional, Set
import cv2
import numpy as np
import torch
from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
from PIL import Image
from torchvision import transforms

_logger = logging.getLogger("AgeGenderDataset")


class AgeGenderDataset(torch.utils.data.Dataset):
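    """Age and gender dataset built on top of ReaderAgeGender.

    Yields (image, target) pairs where the image is a preprocessed face crop
    (optionally stacked with a person crop) and the target is
    [normalized_age, gender], with -1 marking unknown labels.
    """
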
def __init__(
self,
images_path,
annotations_path,
name=None,
split="train",
load_bytes=False,
img_mode="RGB",
transform=None,
is_training=False,
seed=1234,
target_size=224,
min_age=None,
max_age=None,
model_with_persons=False,
use_persons=False,
disable_faces=False,
only_age=False,
):
reader = ReaderAgeGender(
images_path,
annotations_path,
split=split,
seed=seed,
target_size=target_size,
with_persons=use_persons,
disable_faces=disable_faces,
only_age=only_age,
)
self.name = name
self.model_with_persons = model_with_persons
self.reader = reader
self.load_bytes = load_bytes
self.img_mode = img_mode
self.transform = transform
self._consecutive_errors = 0
self.is_training = is_training
self.random_flip = 0.0
        # Set up target classes.
        # If min_age and max_age are passed, reuse them so validation gets the
        # same preprocessing as training.
        self.max_age: Optional[float] = None
        self.min_age: Optional[float] = None
        self.avg_age: Optional[float] = None
self.set_ages_min_max(min_age, max_age)
self.genders = ["M", "F"]
self.num_classes_gender = len(self.genders)
self.age_classes: Optional[List[str]] = self.set_age_classes()
self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
self.num_classes: int = self.num_classes_age + self.num_classes_gender
self.target_dtype = torch.float32

    def set_age_classes(self) -> Optional[List[str]]:
        return None  # regression dataset: no discrete age classes

    def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
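        """Fix the age normalization range from given bounds or dataset statistics.

        Both bounds must be provided together; this lets validation reuse the
        exact range computed during training.
        """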
        assert all(age is None for age in [min_age, max_age]) or all(
            age is not None for age in [min_age, max_age]
        ), "Either both min_age and max_age must be provided, or neither"
if max_age is not None and min_age is not None:
_logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
self.max_age = max_age
self.min_age = min_age
else:
# collect statistics from loaded dataset
all_ages_set: Set[int] = set()
            for image_samples in self.reader._ann.values():
for image_sample_info in image_samples:
if image_sample_info.age == "-1":
continue
age = round(float(image_sample_info.age))
all_ages_set.add(age)
self.max_age = max(all_ages_set)
self.min_age = min(all_ages_set)
self.avg_age = (self.max_age + self.min_age) / 2.0

    def _norm_age(self, age):
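        # Center on the dataset mean age and scale by the full age range, so
        # regression targets land roughly in [-0.5, 0.5].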
return (age - self.avg_age) / (self.max_age - self.min_age)

    def parse_gender(self, _gender: str) -> float:
        """Encode gender: 0.0 for "M"/"0", 1.0 otherwise, -1.0 for unknown ("-1")."""
        if _gender != "-1":
            gender = float(0 if _gender == "M" or _gender == "0" else 1)
        else:
            gender = -1.0
        return gender

    def parse_target(self, _age: str, gender: str) -> List[float]:
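        """Build the regression target [normalized_age, gender]; -1 marks missing labels."""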
if _age != "-1":
age = round(float(_age))
age = self._norm_age(float(age))
else:
age = -1
target: List[float] = [age, self.parse_gender(gender)]
return target

    @property
    def transform(self):
        return self._transform

    @transform.setter
    def transform(self, transform):
        # Disable pretrained monkey-patched transforms: drop Resize/Crop steps,
        # since the reader already resizes and pads the crops.
        if not transform:
            self._transform = None
            return
_trans = []
for trans in transform.transforms:
if "Resize" in str(trans):
continue
if "Crop" in str(trans):
continue
_trans.append(trans)
self._transform = transforms.Compose(_trans)

    def apply_transforms(self, image: Optional[np.ndarray]) -> Optional[np.ndarray]:
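        """Convert an OpenCV crop to PIL and apply the filtered transform chain."""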
if image is None:
return None
if self.transform is None:
return image
image = convert_to_pil(image, self.img_mode)
for trans in self.transform.transforms:
image = trans(image)
return image

    def __getitem__(self, index):
# get preprocessed face and person crops (np.ndarray)
# resize + pad, for person crops: cut off other bboxes
images, target = self.reader[index]
target = self.parse_target(*target)
if self.model_with_persons:
face_image, person_image = images
            person_image: np.ndarray = self.apply_transforms(person_image)
else:
face_image = images[0]
person_image = None
        face_image: np.ndarray = self.apply_transforms(face_image)
if person_image is not None:
img = np.concatenate([face_image, person_image], axis=0)
else:
img = face_image
return img, target

    def __len__(self):
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.reader.filenames(basename, absolute)


def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> Optional[Image.Image]:
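    """Convert a BGR OpenCV image to a contiguous RGB PIL image."""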
if cv_im is None:
return None
if img_mode == "RGB":
cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
    else:
        raise ValueError(f"Unsupported image mode: {img_mode}")
cv_im = np.ascontiguousarray(cv_im)
pil_image = Image.fromarray(cv_im)
return pil_image
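

if __name__ == "__main__":
    # Minimal usage sketch, assuming hypothetical data paths: ReaderAgeGender
    # must be able to parse the annotation file for this to actually run.
    dataset = AgeGenderDataset(
        images_path="data/images",  # hypothetical placeholder
        annotations_path="data/annotations.csv",  # hypothetical placeholder
        split="train",
        use_persons=False,
    )
    print(f"{len(dataset)} samples, {dataset.num_classes} output classes")
    img, (age, gender) = dataset[0]
    print(f"normalized age: {age}, gender: {gender}, image shape: {img.shape}")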