# EdgeTA/data/build_cl/scenario.py
import copy
from typing import List

from utils.common.log import logger

from data import get_dataset, MergedDataset
from ..dataloader import build_dataloader
class _ABDatasetMetaInfo:
    """Metadata for one target dataset (i.e. one task) in the scenario."""

    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
        self.name = name
        self.classes = classes
        self.task_type = task_type
        self.object_type = object_type
        self.class_aliases = class_aliases
        self.shift_type = shift_type
        self.ignore_classes = ignore_classes
        self.idx_map = idx_map

    def __repr__(self) -> str:
        return f'({self.name}, {self.classes}, {self.idx_map})'
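
# An illustrative (hypothetical) instance, showing the expected field shapes;
# the concrete values below are made up, not taken from the original code:
# _ABDatasetMetaInfo(name='CIFAR100|0', classes=['apple', 'bear'],
#                    task_type='Image Classification', object_type='generic',
#                    class_aliases=[], shift_type=None,
#                    ignore_classes=[], idx_map={0: 10, 1: 11})
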
class Scenario:
    """A class-incremental scenario: target tasks are learned one at a time,
    with target class indices offset to come after the source classes."""

    def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
        self.config = config  # config['da_scenario'] holds the underlying DA scenario object
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.num_source_classes = num_source_classes
        self.data_dirs = data_dirs

        self.cur_task_index = 0
        self.cur_class_offset = num_source_classes  # target classes start after the source classes

        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # NOTE: despite its name, this is the total number of *classes* over all target tasks
        self.num_tasks_to_be_learn = sum([len(i.classes) for i in target_datasets_info])

        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')
    def to_json(self):
        config = copy.deepcopy(self.config)
        config['da_scenario'] = config['da_scenario'].to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'
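
    # Illustrative shape of to_json() output (values are made up):
    # {'config': {..., 'da_scenario': {...}},
    #  'target_datasets_info': ["(CIFAR100|0, ['apple', 'bear'], {0: 10, 1: 11})"],
    #  'num_classes': 100}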
    def get_cur_class_offset(self):
        return self.cur_class_offset

    def get_cur_num_class(self):
        return len(self.target_datasets_info[self.cur_task_index].classes)

    def get_nc_per_task(self):
        # all target tasks are assumed to have the same number of classes
        return len(self.target_datasets_info[0].classes)

    def next_task(self):
        # advance the class offset past the classes of the task just finished
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1
        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
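
    # Worked example (illustrative numbers, not from the original code): with
    # num_source_classes=10 and two target tasks of 5 classes each,
    # cur_class_offset advances 10 -> 15 -> 20 across successive next_task() calls.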
    def get_cur_task_datasets(self):
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]

        # train: the current task only;
        # val/test: all tasks seen so far, merged, so evaluation also measures forgetting.
        # NOTE: name/root for val/test are resolved per task (di), so tasks drawn
        # from different underlying datasets also merge correctly.
        res = {
            'train': get_dataset(dataset_name=dataset_name,
                                 root_dir=self.data_dirs[dataset_name],
                                 split='train',
                                 transform=None,
                                 ignore_classes=dataset_info.ignore_classes,
                                 idx_map=dataset_info.idx_map),
            **{split: MergedDataset([get_dataset(dataset_name=di.name.split('|')[0],
                                                 root_dir=self.data_dirs[di.name.split('|')[0]],
                                                 split=split,
                                                 transform=None,
                                                 ignore_classes=di.ignore_classes,
                                                 idx_map=di.idx_map)
                                     for di in self.target_datasets_info[0: self.cur_task_index + 1]])
               for split in ['val', 'test']}
        }

        # replicate small splits 5x so downstream loaders see enough samples
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('aug train dataset')
        if len(res['val']) < 1000:
            res['val'] = MergedDataset(res['val'].datasets * 5)
            logger.info('aug val dataset')
        if len(res['test']) < 1000:
            res['test'] = MergedDataset(res['test'].datasets * 5)
            logger.info('aug test dataset')

        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')
        return res
    def get_cur_task_train_datasets(self):
        # train split of the current task only (no merging with previous tasks)
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]
        return get_dataset(dataset_name=dataset_name,
                           root_dir=self.data_dirs[dataset_name],
                           split='train',
                           transform=None,
                           ignore_classes=dataset_info.ignore_classes,
                           idx_map=dataset_info.idx_map)
    def get_online_cur_task_samples_for_training(self, num_samples):
        # draw one batch of `num_samples` training inputs from the current task;
        # next(iter(...)) yields a single (x, y) batch and [0] keeps only the inputs
        dataset = self.get_cur_task_datasets()['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]
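

# --- Usage sketch (illustrative only, not part of the original module) ---
# Assumes `config` contains 'num_classes_per_task' and a 'da_scenario' entry,
# and that metadata/data_dirs are prepared by the surrounding build_cl pipeline;
# the concrete names and numbers below are hypothetical.
#
# scenario = Scenario(config, target_datasets_info, num_classes=100,
#                     num_source_classes=10, data_dirs={'CIFAR100': '/data/cifar100'})
# for _ in range(len(scenario.target_datasets_info)):
#     datasets = scenario.get_cur_task_datasets()
#     xs = scenario.get_online_cur_task_samples_for_training(64)
#     ...  # adapt the model on xs, evaluate on datasets['val'] / datasets['test']
#     scenario.next_task()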