EdgeTA / data /build_cl /build.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
raw
history blame
6.99 kB
from typing import Dict, List, Optional, Type, Union
from ..datasets.ab_dataset import ABDataset
# from benchmark.data.visualize import visualize_classes_in_object_detection
# from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
from ..dataset import get_dataset
import copy
from torchvision.transforms import Compose
from ..datasets.registery import static_dataset_registery
from ..build.scenario import Scenario as DAScenario
from copy import deepcopy
from utils.common.log import logger
import random
from .scenario import _ABDatasetMetaInfo, Scenario
def _check(source_datasets_meta_info: List[_ABDatasetMetaInfo], target_datasets_meta_info: List[_ABDatasetMetaInfo]):
    """Verify that no class is shared between two different source datasets.

    Args:
        source_datasets_meta_info: Meta info of the source datasets. Only the
            ``name`` and ``classes`` attributes of each entry are read.
        target_datasets_meta_info: Accepted for interface compatibility; it is
            not inspected by this check.

    Raises:
        AssertionError: If two different source datasets have at least one
            class in common. The message names both datasets and lists the
            shared classes.
    """
    # Requirement (kept for simplicity): no class may appear in more than one
    # source dataset.
    # Build each class set once instead of once per compared pair.
    class_sets = [set(info.classes) for info in source_datasets_meta_info]

    for first_i, first_classes in enumerate(class_sets):
        # Set intersection is symmetric, so every unordered pair only needs to
        # be compared once (first_i < second_i).
        for second_i in range(first_i + 1, len(class_sets)):
            intersection = first_classes.intersection(class_sets[second_i])
            if intersection:
                c1_name = source_datasets_meta_info[first_i].name
                c2_name = source_datasets_meta_info[second_i].name
                # Raised explicitly (rather than via `assert`) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError(f'{c1_name} has intersection with {c2_name}: {intersection}')
def build_cl_scenario(
        da_scenario: DAScenario,
        target_datasets_name: List[str],
        num_classes_per_task: int,
        max_num_tasks: int,
        data_dirs,
        sanity_check=False
):
    """Build a ``Scenario`` whose tasks are chunks of the target datasets' classes.

    The class list of a target dataset is cut into consecutive chunks of
    ``num_classes_per_task`` classes; each chunk becomes one
    ``_ABDatasetMetaInfo`` task entry whose index map starts right after the
    largest class index used by the source datasets of ``da_scenario``.

    Args:
        da_scenario: Existing scenario. Its ``config['source_datasets_name']``
            and ``all_datasets_idx_map`` are read to find the largest class
            index already taken by the source datasets.
        target_datasets_name: Dataset names, looked up in
            ``static_dataset_registery``. A name that repeats an earlier entry
            re-appends the tasks already generated for that earlier entry.
        num_classes_per_task: Number of classes per task chunk.
        max_num_tasks: The generated task sequence is truncated to this length.
        data_dirs: Forwarded unchanged to ``Scenario``.
        sanity_check: If true, walk through all tasks and replace every task
            whose current training datasets object has ``len() < 300`` with a
            deep copy of another task, then verify that no task is empty.

    Returns:
        The constructed ``Scenario``.

    Raises:
        RuntimeError: If fewer than ``max_num_tasks`` tasks were generated.
        AssertionError: From the sanity check (see the inline notes below).
    """
    # Snapshot of the call arguments (the only locals at this point); it is
    # handed to Scenario as its config.
    config = deepcopy(locals())

    source_datasets_idx_map = {}
    source_class_idx_max = 0

    for sd in da_scenario.config['source_datasets_name']:
        # Take the first index map whose key starts with the source dataset name.
        da_scenario_idx_map = None
        for k, v in da_scenario.all_datasets_idx_map.items():
            if k.startswith(sd):
                da_scenario_idx_map = v
                break

        # NOTE(review): if no key matched, da_scenario_idx_map is still None
        # and the `.values()` call below raises AttributeError.
        source_datasets_idx_map[sd] = da_scenario_idx_map
        source_class_idx_max = max(source_class_idx_max, max(list(da_scenario_idx_map.values())))

    # Target class indices start right after the largest source class index.
    target_class_idx_start = source_class_idx_max + 1

    # Meta info per target dataset: the name, the registry entry minus its
    # first element, and None for the last two fields (filled per task below).
    target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:], None, None) for d in target_datasets_name]

    task_datasets_seq = []

    # dataset index -> number of tasks generated for that dataset
    num_tasks_per_dataset = {}

    for td_info_i, td_info in enumerate(target_datasets_meta_info):

        if td_info_i >= 1:
            # Look for an earlier target dataset with the same name.
            for _td_info_i, _td_info in enumerate(target_datasets_meta_info[0: td_info_i]):
                if _td_info.name == td_info.name:
                    # print(111)
                    # class_idx_offset = sum([len(t.classes) for t in target_datasets_meta_info[0: td_info_i]])
                    print(len(task_datasets_seq))

                    # Number of tasks that precede the earlier dataset's tasks.
                    # NOTE(review): this compares the position inside
                    # num_tasks_per_dataset with a dataset index; the two only
                    # coincide if every earlier dataset has an entry.
                    task_index_offset = sum([v if __i < _td_info_i else 0 for __i, v in enumerate(num_tasks_per_dataset.values())])

                    # Re-append the earlier dataset's tasks. The slice yields the
                    # same meta-info objects, not copies.
                    task_datasets_seq += task_datasets_seq[task_index_offset: task_index_offset + num_tasks_per_dataset[_td_info_i]]
                    print(len(task_datasets_seq))
                    break
            # NOTE(review): the original nesting of this `continue` could not be
            # recovered with certainty. As indented here it skips the splitting
            # below for EVERY dataset after the first, whether or not a
            # duplicate was found -- confirm against the upstream file.
            continue

        td_classes = td_info.classes
        num_tasks_per_dataset[td_info_i] = 0

        for ci in range(0, len(td_classes), num_classes_per_task):
            task_i = ci // num_classes_per_task
            # Positional fields, as used here: name, classes of this task, four
            # fields copied from the dataset, the dataset's remaining classes,
            # and a dict from class position to global class index (presumably
            # exposed as `idx_map`, which is read further below -- confirm).
            # For a final short chunk the class slice is clipped by slicing,
            # but the name and the index dict still span a full chunk.
            task_datasets_seq += [_ABDatasetMetaInfo(
                f'{td_info.name}|task-{task_i}|ci-{ci}-{ci + num_classes_per_task - 1}',
                td_classes[ci: ci + num_classes_per_task],
                td_info.task_type,
                td_info.object_type,
                td_info.class_aliases,
                td_info.shift_type,
                td_classes[:ci] + td_classes[ci + num_classes_per_task: ],
                {cii: cii + target_class_idx_start for cii in range(ci, ci + num_classes_per_task)}
            )]
            num_tasks_per_dataset[td_info_i] += 1

        # NOTE(review): as indented here (after the loop), `ci` is the start of
        # the last chunk, so ci + num_classes_per_task >= len(td_classes) and
        # this branch looks unreachable; `ci` is also unbound when td_classes is
        # empty. Confirm the intended nesting.
        if ci + num_classes_per_task < len(td_classes) - 1:
            task_datasets_seq += [_ABDatasetMetaInfo(
                f'{td_info.name}-task-{task_i + 1}|ci-{ci}-{ci + num_classes_per_task - 1}',
                td_classes[ci: len(td_classes)],
                td_info.task_type,
                td_info.object_type,
                td_info.class_aliases,
                td_info.shift_type,
                td_classes[:ci],
                {cii: cii + target_class_idx_start for cii in range(ci, len(td_classes))}
            )]
            num_tasks_per_dataset[td_info_i] += 1

        # The next dataset's class indices start after this dataset's classes.
        target_class_idx_start += len(td_classes)

    # Fewer tasks than requested cannot be satisfied.
    if len(task_datasets_seq) < max_num_tasks:
        print(len(task_datasets_seq), max_num_tasks)
        raise RuntimeError()

    task_datasets_seq = task_datasets_seq[0: max_num_tasks]
    # One past the largest mapped class index among the tasks that were kept.
    target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])

    scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

    if sanity_check:
        # Indices of tasks that have already been used as a replacement.
        selected_tasks_index = []
        for task_index, _ in enumerate(scenario.target_tasks_order):
            cur_datasets = scenario.get_cur_task_train_datasets()

            if len(cur_datasets) < 300:
                # empty_tasks_index += [task_index]
                # while True:
                # replaced_task_index = random.randint(0, task_index - 1) # ensure no random
                # Deterministic replacement candidate: the task at half the index.
                replaced_task_index = task_index // 2
                # NOTE(review): this fails for task_index == 0, since 0 // 2 == 0.
                assert replaced_task_index != task_index
                # Skip candidates already used as a replacement.
                # NOTE(review): nothing here bounds the incremented index by
                # task_index or by the length of task_datasets_seq.
                while replaced_task_index in selected_tasks_index:
                    replaced_task_index += 1

                task_datasets_seq[task_index] = deepcopy(task_datasets_seq[replaced_task_index])
                selected_tasks_index += [replaced_task_index]

                logger.warning(f'replace {task_index}-th task with {replaced_task_index}-th task')

            # print(task_index, [t.name for t in task_datasets_seq])
            scenario.next_task()

        # print([t.name for t in task_datasets_seq])

        if len(selected_tasks_index) > 0:
            # Tasks were replaced: recompute the index bound, rebuild the
            # scenario, and check that every task now has training data.
            target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
            scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

            for task_index, _ in enumerate(scenario.target_tasks_order):
                cur_datasets = scenario.get_cur_task_train_datasets()
                logger.info(f'task {task_index}, len {len(cur_datasets)}')
                assert len(cur_datasets) > 0
                scenario.next_task()

        # Build a fresh Scenario after the loops above called next_task() on the
        # previous one (presumably to reset its task position -- confirm).
        # NOTE(review): nesting inferred; this may belong at function level.
        scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

    return scenario