File size: 1,592 Bytes
319d3b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from typing import Any, List, Optional
import torch
from .age_gender_dataset import AgeGenderDataset
class ClassificationDataset(AgeGenderDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.target_dtype = torch.int32
def set_age_classes(self) -> Optional[List[str]]:
raise NotImplementedError
def parse_target(self, age: str, gender: str) -> List[Any]:
assert self.age_classes is not None
if age != "-1":
assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
age_ind = self.age_classes.index(age)
else:
age_ind = -1
target: List[int] = [age_ind, int(self.parse_gender(gender))]
return target
class FairFaceDataset(ClassificationDataset):
def set_age_classes(self) -> Optional[List[str]]:
age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
# a[i-1] <= v < a[i] => age_classes[i-1]
self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
return age_classes
class AdienceDataset(ClassificationDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.target_dtype = torch.int32
def set_age_classes(self) -> Optional[List[str]]:
age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
# a[i-1] <= v < a[i] => age_classes[i-1]
self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
return age_classes
|