# MiVOLO / mivolo/data/dataset/__init__.py
# Web file-viewer header residue (not part of the module):
#   admin · sync · 319d3b5 · raw · history blame · 1.88 kB
from typing import Tuple
import torch
from mivolo.model.mi_volo import MiVOLO
from .age_gender_dataset import AgeGenderDataset
from .age_gender_loader import create_loader
from .classification_dataset import AdienceDataset, FairFaceDataset
# Registry of supported dataset names -> dataset class used by ``build``.
# The "utk", "lagenda" and "imdb" datasets share the generic
# AgeGenderDataset; "adience" and "fairface" use dedicated classes
# from ``classification_dataset``.
DATASET_CLASS_MAP = dict(
    utk=AgeGenderDataset,
    lagenda=AgeGenderDataset,
    imdb=AgeGenderDataset,
    adience=AdienceDataset,
    fairface=FairFaceDataset,
)
def build(
    name: str,
    images_path: str,
    annotations_path: str,
    split: str,
    mivolo_model: MiVOLO,
    workers: int,
    batch_size: int,
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
    """Create a dataset and a matching data loader for a MiVOLO model.

    The dataset class is looked up in ``DATASET_CLASS_MAP`` by ``name`` and
    configured from the model's ``input_size`` and ``meta`` flags. The loader
    is configured from the model's ``data_config`` (mean/std/crop settings)
    and ``device``.

    Args:
        name: Dataset key; must be one of the keys of ``DATASET_CLASS_MAP``.
            It is also forwarded to the dataset constructor as ``name``.
        images_path: Forwarded to the dataset constructor.
        annotations_path: Forwarded to the dataset constructor.
        split: Forwarded to the dataset constructor.
        mivolo_model: Model whose ``input_size``, ``meta``, ``data_config``
            and ``device`` drive the dataset and loader configuration.
        workers: Forwarded to ``create_loader`` as ``num_workers``.
        batch_size: Forwarded to ``create_loader`` as ``batch_size``.

    Returns:
        A ``(dataset, dataset_loader)`` tuple.

    Raises:
        KeyError: If ``name`` is not a key of ``DATASET_CLASS_MAP``. The
            message lists the accepted names.
    """
    try:
        dataset_class = DATASET_CLASS_MAP[name]
    except KeyError:
        # A bare KeyError only echoes the bad key; list the valid choices
        # while keeping the same exception type for existing callers.
        raise KeyError(
            f"Unknown dataset {name!r}; expected one of {sorted(DATASET_CLASS_MAP)}"
        ) from None

    meta = mivolo_model.meta
    model_input_size = mivolo_model.input_size

    dataset: torch.utils.data.Dataset = dataset_class(
        images_path=images_path,
        annotations_path=annotations_path,
        name=name,
        split=split,
        target_size=model_input_size,
        max_age=meta.max_age,
        min_age=meta.min_age,
        model_with_persons=meta.with_persons_model,
        use_persons=meta.use_persons,
        disable_faces=meta.disable_faces,
        only_age=meta.only_age,
    )

    data_config = mivolo_model.data_config
    # 6 input channels when the model is built with persons, otherwise 3.
    # NOTE(review): presumably face and person crops are stacked channel-wise;
    # confirm against the model definition.
    in_chans = 6 if meta.with_persons_model else 3
    input_size = (in_chans, model_input_size, model_input_size)

    dataset_loader: torch.utils.data.DataLoader = create_loader(
        dataset,
        input_size=input_size,
        batch_size=batch_size,
        mean=data_config["mean"],
        std=data_config["std"],
        num_workers=workers,
        crop_pct=data_config["crop_pct"],
        crop_mode=data_config["crop_mode"],
        pin_memory=False,
        device=mivolo_model.device,
        target_type=dataset.target_dtype,
    )
    return dataset, dataset_loader