|
"""
Build a class-incremental continual learning (CL) scenario on top of an
existing domain adaptation (DA) scenario: each target dataset is split into
tasks of num_classes_per_task classes, and the target class indices are
offset so they never collide with the source datasets' global class indices.
"""

from copy import deepcopy
from typing import List

from utils.common.log import logger

from ..build.scenario import Scenario as DAScenario
from ..datasets.registery import static_dataset_registery
from .scenario import _ABDatasetMetaInfo, Scenario

def _check(source_datasets_meta_info: List[_ABDatasetMetaInfo], target_datasets_meta_info: List[_ABDatasetMetaInfo]):
    """Assert that no two source datasets share a class name.

    `target_datasets_meta_info` is accepted for symmetry but is not checked yet.
    """
    source_datasets_class = [i.classes for i in source_datasets_meta_info]
    for ci1, c1 in enumerate(source_datasets_class):
        for ci2, c2 in enumerate(source_datasets_class):
            if ci1 == ci2:
                continue

            c1_name = source_datasets_meta_info[ci1].name
            c2_name = source_datasets_meta_info[ci2].name
            intersection = set(c1).intersection(set(c2))
            assert len(intersection) == 0, f'{c1_name} has intersection with {c2_name}: {intersection}'
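
# A hypothetical illustration of what _check catches (values are made up; only
# the first two _ABDatasetMetaInfo fields, name and classes, matter here):
#
#   a = _ABDatasetMetaInfo('SourceA', ['cat', 'dog'], ...)
#   b = _ABDatasetMetaInfo('SourceB', ['dog', 'ship'], ...)
#   _check([a, b], [])
#   # AssertionError: SourceA has intersection with SourceB: {'dog'}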
|
|
|
|
|
def build_cl_scenario(
    da_scenario: DAScenario,
    target_datasets_name: List[str],
    num_classes_per_task: int,
    max_num_tasks: int,
    data_dirs,
    sanity_check=False
):
    # snapshot the arguments (deep-copied) so the resulting Scenario carries
    # its own build configuration
    config = deepcopy(locals())

    # recover each source dataset's class-index map from the DA scenario and
    # find the largest global class index occupied by the sources
    source_datasets_idx_map = {}
    source_class_idx_max = 0

    for sd in da_scenario.config['source_datasets_name']:
        da_scenario_idx_map = None
        for k, v in da_scenario.all_datasets_idx_map.items():
            if k.startswith(sd):
                da_scenario_idx_map = v
                break
        assert da_scenario_idx_map is not None, f'no class index map found for source dataset {sd}'

        source_datasets_idx_map[sd] = da_scenario_idx_map
        source_class_idx_max = max(source_class_idx_max, max(da_scenario_idx_map.values()))

    # target classes get global indices strictly after the source classes
    target_class_idx_start = source_class_idx_max + 1

    # static meta info (classes, task type, ...) for each target dataset; the
    # per-task ignore_classes and idx_map fields are filled in below
    target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:], None, None) for d in target_datasets_name]
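
    # Worked example (illustrative numbers): if the source datasets occupy
    # global class indices 0..9, then source_class_idx_max == 9 and the first
    # target dataset's classes are mapped starting at global index 10; with
    # num_classes_per_task == 2, its task 0 maps local classes {0, 1} to
    # global {10, 11}, task 1 maps {2, 3} to {12, 13}, and so on.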
|
|
|
    task_datasets_seq = []
    # number of tasks produced by each first-occurrence target dataset,
    # keyed by its position in target_datasets_meta_info
    num_tasks_per_dataset = {}
    # where each first-occurrence dataset's tasks start in task_datasets_seq
    task_seq_start_index = {}

    for td_info_i, td_info in enumerate(target_datasets_meta_info):

        # a dataset that already appeared earlier in the sequence is revisited:
        # reuse the tasks built for its first occurrence (same classes, same
        # global index map) instead of splitting it again
        duplicated = False
        if td_info_i >= 1:
            for _td_info_i, _td_info in enumerate(target_datasets_meta_info[0: td_info_i]):
                if _td_info.name == td_info.name:
                    task_index_offset = task_seq_start_index[_td_info_i]
                    task_datasets_seq += task_datasets_seq[task_index_offset: task_index_offset + num_tasks_per_dataset[_td_info_i]]
                    logger.debug(f'reused tasks of {td_info.name}, {len(task_datasets_seq)} tasks so far')
                    duplicated = True
                    break
        if duplicated:
            continue

        td_classes = td_info.classes
        num_tasks_per_dataset[td_info_i] = 0
        task_seq_start_index[td_info_i] = len(task_datasets_seq)

        for ci in range(0, len(td_classes), num_classes_per_task):
            task_i = ci // num_classes_per_task

            # regular task: the next num_classes_per_task classes; the rest of
            # this dataset's classes are ignored, and the local class indices
            # are offset into the global (source + target) label space
            task_datasets_seq += [_ABDatasetMetaInfo(
                f'{td_info.name}|task-{task_i}|ci-{ci}-{ci + num_classes_per_task - 1}',
                td_classes[ci: ci + num_classes_per_task],
                td_info.task_type,
                td_info.object_type,
                td_info.class_aliases,
                td_info.shift_type,
                td_classes[:ci] + td_classes[ci + num_classes_per_task:],
                {cii: cii + target_class_idx_start for cii in range(ci, ci + num_classes_per_task)}
            )]
            num_tasks_per_dataset[td_info_i] += 1

            # wider variant: while more than one chunk of classes remains, also
            # create a task spanning all classes from ci to the end, which
            # enlarges the pool of candidate tasks
            if ci + num_classes_per_task < len(td_classes) - 1:
                task_datasets_seq += [_ABDatasetMetaInfo(
                    f'{td_info.name}|task-{task_i + 1}|ci-{ci}-{len(td_classes) - 1}',
                    td_classes[ci:],
                    td_info.task_type,
                    td_info.object_type,
                    td_info.class_aliases,
                    td_info.shift_type,
                    td_classes[:ci],
                    {cii: cii + target_class_idx_start for cii in range(ci, len(td_classes))}
                )]
                num_tasks_per_dataset[td_info_i] += 1

        # the next target dataset's classes start after this one's in the
        # global label space
        target_class_idx_start += len(td_classes)
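
    # Illustrative example (made-up numbers): a target dataset with 10 classes
    # and num_classes_per_task == 3 yields the regular tasks over classes
    # [0..2], [3..5], [6..8], [9], plus the wider variants [0..9] and [3..9],
    # so a single dataset can feed a fairly long task sequence.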
|
|
|
    # fail fast when the requested number of tasks cannot be covered
    if len(task_datasets_seq) < max_num_tasks:
        raise RuntimeError(f'only {len(task_datasets_seq)} tasks can be built, but max_num_tasks is {max_num_tasks}')

    task_datasets_seq = task_datasets_seq[0: max_num_tasks]
    # recompute the number of global classes actually used after truncation
    target_class_idx_start = max(max(td.idx_map.values()) + 1 for td in task_datasets_seq)

    scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

    if sanity_check:
        # walk through every task once; a task with too little training data is
        # replaced in task_datasets_seq by a copy of an earlier task
        selected_tasks_index = []
        for task_index, _ in enumerate(scenario.target_tasks_order):
            cur_datasets = scenario.get_cur_task_train_datasets()

            # 300 training samples is the threshold below which a task is
            # considered too small to train on
            if len(cur_datasets) < 300:
                # candidate replacement: the task at half the current index,
                # skipping indices already used as replacements
                replaced_task_index = task_index // 2
                assert replaced_task_index != task_index
                while replaced_task_index in selected_tasks_index:
                    replaced_task_index += 1

                task_datasets_seq[task_index] = deepcopy(task_datasets_seq[replaced_task_index])
                selected_tasks_index += [replaced_task_index]

                logger.warning(f'replace {task_index}-th task with {replaced_task_index}-th task')

            scenario.next_task()
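
        # Illustrative example (made-up numbers): if task 6 has fewer than 300
        # training samples, it is replaced by a copy of task 3 (6 // 2); if
        # index 3 was already used as a replacement, index 4 is tried next, and
        # so on. Task 0 can never be replaced (the assertion above fires,
        # because 0 // 2 == 0).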
|
|
|
|
|
|
|
        # if any task was replaced, rebuild the scenario from the patched task
        # sequence and verify that every task now has training data
        if len(selected_tasks_index) > 0:
            target_class_idx_start = max(max(td.idx_map.values()) + 1 for td in task_datasets_seq)
            scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

            for task_index, _ in enumerate(scenario.target_tasks_order):
                cur_datasets = scenario.get_cur_task_train_datasets()
                logger.info(f'task {task_index}, len {len(cur_datasets)}')
                assert len(cur_datasets) > 0

                scenario.next_task()

        # rebuild once more so the returned scenario starts from the first task
        scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

return scenario |
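

# Usage sketch (hypothetical values): `da_scenario` must come from the DA-side
# scenario builder in ..build.scenario, the dataset names must exist in
# static_dataset_registery, and data_dirs is assumed here to map dataset names
# to data directories (it is passed through to Scenario unchanged):
#
#   cl_scenario = build_cl_scenario(
#       da_scenario=da_scenario,
#       target_datasets_name=['CIFAR10', 'STL10', 'CIFAR10'],
#       num_classes_per_task=2,
#       max_num_tasks=30,
#       data_dirs={'CIFAR10': '/data/cifar10', 'STL10': '/data/stl10'},
#       sanity_check=True,
#   )
#   first_train_sets = cl_scenario.get_cur_task_train_datasets()
#   cl_scenario.next_task()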