|
import enum |
|
from functools import reduce |
|
from typing import Dict, List, Tuple |
|
import numpy as np |
|
import copy |
|
from utils.common.log import logger |
|
from ..datasets.ab_dataset import ABDataset |
|
from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader |
|
from data import get_dataset, MergedDataset, Scenario as DAScenario |
|
|
|
|
|
class _ABDatasetMetaInfo: |
|
def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map): |
|
self.name = name |
|
self.classes = classes |
|
self.class_aliases = class_aliases |
|
self.shift_type = shift_type |
|
self.task_type = task_type |
|
self.object_type = object_type |
|
|
|
self.ignore_classes = ignore_classes |
|
self.idx_map = idx_map |
|
|
|
def __repr__(self) -> str: |
|
return f'({self.name}, {self.classes}, {self.idx_map})' |
|
|
|
|
|
class Scenario:
    """A class-incremental learning scenario over a sequence of target tasks.

    Each task corresponds to one entry of ``target_datasets_info``. The scenario
    keeps a cursor (``cur_task_index``) and a running class-label offset
    (``cur_class_offset``) that starts right after the source-domain classes
    and is advanced by :meth:`next_task`.
    """

    def __init__(self, config, target_datasets_info: 'List[_ABDatasetMetaInfo]', num_classes: int, num_source_classes: int, data_dirs):
        """
        Args:
            config: scenario config dict; must contain ``'num_classes_per_task'``
                and a ``'da_scenario'`` object exposing ``to_json()``.
            target_datasets_info: per-task dataset metadata, in learning order.
            num_classes: total number of classes in the whole scenario.
            num_source_classes: classes already covered by the source domain;
                target-task labels are offset past them.
            data_dirs: mapping from physical dataset name to its root directory.
        """
        self.config = config
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.cur_task_index = 0
        self.num_source_classes = num_source_classes
        # Label offset of the current task: target labels start after the
        # source-domain classes and grow as tasks are consumed.
        self.cur_class_offset = num_source_classes
        self.data_dirs = data_dirs

        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # NOTE(review): despite the name, this sums the number of CLASSES over
        # all tasks, not the number of tasks — confirm which was intended.
        self.num_tasks_to_be_learn = sum(len(i.classes) for i in target_datasets_info)

        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')

    def to_json(self):
        """Return a JSON-serializable summary of the scenario (config, task infos, # classes)."""
        config = copy.deepcopy(self.config)
        # The DA scenario object is not JSON-serializable itself; delegate.
        config['da_scenario'] = config['da_scenario'].to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'

    def get_cur_class_offset(self):
        """Return the label offset at which the current task's classes start."""
        return self.cur_class_offset

    def get_cur_num_class(self):
        """Return the number of classes in the current task."""
        return len(self.target_datasets_info[self.cur_task_index].classes)

    def get_nc_per_task(self):
        """Return the per-task class count, taken from the first task.

        Assumes every task has the same number of classes — TODO confirm.
        """
        return len(self.target_datasets_info[0].classes)

    def next_task(self):
        """Advance to the next task, shifting the class offset past the current task's classes."""
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1

        # Was a bare print(); routed through logger for consistency with the
        # rest of the file.
        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')

    def get_cur_task_datasets(self):
        """Build the train/val/test datasets for the current task.

        - ``'train'``: only the current task's data.
        - ``'val'`` / ``'test'``: merged data of every task seen so far
          (indices 0..cur_task_index inclusive), so evaluation also covers
          previously learnt classes.

        Datasets with fewer than 1000 samples are repeated 5x as a cheap
        size augmentation.

        Returns:
            dict mapping ``'train'``/``'val'``/``'test'`` to dataset objects.
        """
        dataset_info = self.target_datasets_info[self.cur_task_index]
        # Names may carry a '|' suffix (e.g. 'CIFAR10|0'); the part before '|'
        # is the physical dataset name used for the data_dirs lookup.
        dataset_name = dataset_info.name.split('|')[0]

        # NOTE(review): for val/test, prior tasks are loaded with the CURRENT
        # task's dataset_name/root_dir but each task's own
        # ignore_classes/idx_map — this assumes all tasks share one physical
        # dataset. Confirm if tasks can span different datasets.
        res = {
            'train': get_dataset(dataset_name=dataset_name,
                                 root_dir=self.data_dirs[dataset_name],
                                 split='train',
                                 transform=None,
                                 ignore_classes=dataset_info.ignore_classes,
                                 idx_map=dataset_info.idx_map),
            **{split: MergedDataset([get_dataset(dataset_name=dataset_name,
                                                 root_dir=self.data_dirs[dataset_name],
                                                 split=split,
                                                 transform=None,
                                                 ignore_classes=di.ignore_classes,
                                                 idx_map=di.idx_map)
                                     for di in self.target_datasets_info[0: self.cur_task_index + 1]])
               for split in ['val', 'test']}
        }

        # Repeat small datasets 5x. 'train' is a plain dataset, so wrap it in a
        # list; 'val'/'test' are already MergedDatasets, so repeat their parts.
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('aug train dataset')
        for split in ['val', 'test']:
            if len(res[split]) < 1000:
                res[split] = MergedDataset(res[split].datasets * 5)
                logger.info(f'aug {split} dataset')

        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')

        return res

    def get_cur_task_train_datasets(self):
        """Return only the current task's train dataset (no merging, no size augmentation)."""
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]

        res = get_dataset(dataset_name=dataset_name,
                          root_dir=self.data_dirs[dataset_name],
                          split='train',
                          transform=None,
                          ignore_classes=dataset_info.ignore_classes,
                          idx_map=dataset_info.idx_map)

        return res

    def get_online_cur_task_samples_for_training(self, num_samples):
        """Draw one shuffled batch of ``num_samples`` inputs from the current task's train set.

        Returns only the inputs (element 0 of the batch), not the labels.
        """
        dataset = self.get_cur_task_datasets()
        dataset = dataset['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]