|
from typing import Tuple |
|
|
|
import torch |
|
from mivolo.model.mi_volo import MiVOLO |
|
|
|
from .age_gender_dataset import AgeGenderDataset |
|
from .age_gender_loader import create_loader |
|
from .classification_dataset import AdienceDataset, FairFaceDataset |
|
|
|
DATASET_CLASS_MAP = { |
|
"utk": AgeGenderDataset, |
|
"lagenda": AgeGenderDataset, |
|
"imdb": AgeGenderDataset, |
|
"adience": AdienceDataset, |
|
"fairface": FairFaceDataset, |
|
} |
|
|
|
|
|
def build( |
|
name: str, |
|
images_path: str, |
|
annotations_path: str, |
|
split: str, |
|
mivolo_model: MiVOLO, |
|
workers: int, |
|
batch_size: int, |
|
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]: |
|
|
|
dataset_class = DATASET_CLASS_MAP[name] |
|
|
|
dataset: torch.utils.data.Dataset = dataset_class( |
|
images_path=images_path, |
|
annotations_path=annotations_path, |
|
name=name, |
|
split=split, |
|
target_size=mivolo_model.input_size, |
|
max_age=mivolo_model.meta.max_age, |
|
min_age=mivolo_model.meta.min_age, |
|
model_with_persons=mivolo_model.meta.with_persons_model, |
|
use_persons=mivolo_model.meta.use_persons, |
|
disable_faces=mivolo_model.meta.disable_faces, |
|
only_age=mivolo_model.meta.only_age, |
|
) |
|
|
|
data_config = mivolo_model.data_config |
|
|
|
in_chans = 3 if not mivolo_model.meta.with_persons_model else 6 |
|
input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size) |
|
|
|
dataset_loader: torch.utils.data.DataLoader = create_loader( |
|
dataset, |
|
input_size=input_size, |
|
batch_size=batch_size, |
|
mean=data_config["mean"], |
|
std=data_config["std"], |
|
num_workers=workers, |
|
crop_pct=data_config["crop_pct"], |
|
crop_mode=data_config["crop_mode"], |
|
pin_memory=False, |
|
device=mivolo_model.device, |
|
target_type=dataset.target_dtype, |
|
) |
|
|
|
return dataset, dataset_loader |
|
|