EdgeTA / data /build_gen /scenario.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
import enum
from functools import reduce
from typing import Dict, List, Tuple
import numpy as np
import copy
from utils.common.log import logger
from ..datasets.ab_dataset import ABDataset
from ..datasets.dataset_split import train_val_split
from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
from data import get_dataset
class DatasetMetaInfo:
    """Class-level meta information of one dataset.

    Attributes:
        name: name of the dataset.
        known_classes_name_idx_map: mapping from known class name to class index.
        unknown_class_idx: index used for the "unknown" class.
    """

    def __init__(self, name,
                 known_classes_name_idx_map, unknown_class_idx):
        # NOTE(review): this compares an *index* against the map's keys (class
        # names); the intent may have been to check the map's values. Kept
        # as-is so existing callers behave the same - confirm.
        assert unknown_class_idx not in known_classes_name_idx_map.keys()

        self.name = name
        self.unknown_class_idx = unknown_class_idx
        self.known_classes_name_idx_map = known_classes_name_idx_map

    @property
    def num_classes(self):
        """Number of known classes plus one slot for the unknown class."""
        # The original read `self.known_classes_idx`, an attribute that is never
        # set, so this property always raised AttributeError.
        return len(self.known_classes_name_idx_map) + 1
class MergedDataset:
    """Read-only concatenation of several datasets behind one index space.

    Index ``i`` addresses the ``i``-th sample when the wrapped datasets are
    laid out one after another in the order they were passed in.

    Attributes:
        datasets: the wrapped datasets.
        datasets_len: length of each wrapped dataset.
        datasets_cum_len: cumulative sum of ``datasets_len``.
    """

    def __init__(self, datasets: 'List[ABDataset]'):
        self.datasets = datasets
        self.datasets_len = [len(i) for i in self.datasets]
        logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
        self.datasets_cum_len = np.cumsum(self.datasets_len)

    def __getitem__(self, idx):
        """Return the sample at global position ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        # Raising IndexError (instead of silently returning None) is what lets
        # `for x in merged_dataset` / `list(merged_dataset)` terminate.
        if idx < 0 or idx >= total:
            raise IndexError(f'index {idx} is out of range for MergedDataset of length {total}')

        # Walk the datasets while tracking the global offset of the current one.
        offset = 0
        for dataset, length in zip(self.datasets, self.datasets_len):
            if idx < offset + length:
                return dataset[idx - offset]
            offset += length

    def __len__(self):
        return sum(self.datasets_len)
class IndexReturnedDataset:
    """Dataset wrapper whose samples additionally carry their own index.

    A tuple/list sample ``(a, b, ...)`` becomes ``(a, b, ..., idx)``; any other
    sample ``x`` becomes ``(x, idx)``.
    """

    def __init__(self, dataset: ABDataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Non-sequence samples are paired with the index as-is.
        if not isinstance(sample, (tuple, list)):
            return sample, idx
        # Sequence samples are flattened and the index is appended last.
        return tuple(sample) + (idx,)
# class Scenario:
# def __init__(self, config,
# source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
# target_source_map: Dict[str, Dict[str, str]],
# target_domains_order: List[str],
# source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
# self.__config = config
# self.__source_datasets_meta_info = source_datasets_meta_info
# self.__target_datasets_meta_info = target_datasets_meta_info
# self.__target_source_map = target_source_map
# self.__target_domains_order = target_domains_order
# self.__source_datasets = source_datasets
# self.__target_datasets = target_datasets
# # 1. basic
# def get_config(self):
# return copy.deepcopy(self.__config)
# def get_task_type(self):
# return list(self.__source_datasets.values())[0]['train'].task_type
# def get_num_classes(self):
# known_classes_idx = []
# unknown_classes_idx = []
# for v in self.__source_datasets_meta_info.values():
# known_classes_idx += list(v.known_classes_name_idx_map.values())
# unknown_classes_idx += [v.unknown_class_idx]
# for v in self.__target_datasets_meta_info.values():
# known_classes_idx += list(v.known_classes_name_idx_map.values())
# unknown_classes_idx += [v.unknown_class_idx]
# unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
# # print(known_classes_idx, unknown_classes_idx)
# res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
# # print(res)
# assert res[0] + res[1] == res[2]
# return res
# def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
# if infinite:
# dataloader = InfiniteDataLoader(
# dataset, None, batch_size, num_workers=num_workers)
# else:
# dataloader = FastDataLoader(
# dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
# return dataloader
# def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
# from ..data.datasets.dataset_split import _SplitDataset
# dataset.dataset = _SplitDataset(dataset.dataset, indexes)
# return dataset
# def build_index_returned_dataset(self, dataset: ABDataset):
# return IndexReturnedDataset(dataset)
# # 2. source
# def get_source_datasets_meta_info(self):
# return self.__source_datasets_meta_info
# def get_source_datasets_name(self):
# return list(self.__source_datasets.keys())
# def get_merged_source_dataset(self, split):
# source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
# return MergedDataset(list(source_train_datasets.values()))
# def get_source_datasets(self, split):
# source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
# return source_train_datasets
# # 3. target **domain**
# # (do we need such API `get_ith_target_domain()`?)
# def get_target_domains_meta_info(self):
# return self.__source_datasets_meta_info
# def get_target_domains_order(self):
# return self.__target_domains_order
# def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
# return self.__target_source_map[target_domain_name]
# def get_limited_target_train_dataset(self):
# if len(self.__target_domains_order) > 1:
# raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
# return list(self.__target_datasets.values())[0]['train']
# def get_target_domains_iterator(self, split):
# for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
# target_dataset = self.__target_datasets[target_domain_name]
# target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
# yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
# # 4. permission management
# def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
# def get_split(dataset, splits):
# res = {}
# for s, d in dataset.items():
# if s in splits:
# res[s] = d
# return res
# return Scenario(
# config=self.__config,
# source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
# target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
# target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
# target_domains_order=target_domains_order,
# source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
# target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
# )
# def get_only_source_sub_scenario_for_exp_tracker(self):
# return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
# def get_only_source_sub_scenario_for_alg(self):
# return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
# def get_one_da_sub_scenario_for_alg(self, target_domain_name):
# return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
# ['train', 'val'], [target_domain_name], ['train'])
# class Scenario:
# def __init__(self, config,
# offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
# offline_source_datasets: Dict[str, ABDataset],
# online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
# online_datasets: Dict[str, ABDataset],
# target_domains_order: List[str],
# target_source_map: Dict[str, Dict[str, str]],
# num_classes: int):
# self.config = config
# self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
# self.offline_source_datasets = offline_source_datasets
# self.online_datasets_meta_info = online_datasets_meta_info
# self.online_datasets = online_datasets
# self.target_domains_order = target_domains_order
# self.target_source_map = target_source_map
# self.num_classes = num_classes
# def get_offline_source_datasets(self, split):
# return {n: d[split] for n, d in self.offline_source_datasets.items()}
# def get_offline_source_merged_dataset(self, split):
# return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
# def get_online_current_corresponding_source_datasets(self, domain_index, split):
# cur_target_domain_name = self.target_domains_order[domain_index]
# cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
# cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
# return cur_source_datasets
# def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
# cur_target_domain_name = self.target_domains_order[domain_index]
# cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
# cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
# return MergedDataset([d for d in cur_source_datasets.values()])
# def get_online_current_target_dataset(self, domain_index, split):
# cur_target_domain_name = self.target_domains_order[domain_index]
# return self.online_datasets[cur_target_domain_name][split]
# def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
# infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
# if infinite:
# dataloader = InfiniteDataLoader(
# dataset, None, batch_size, num_workers=num_workers)
# else:
# dataloader = FastDataLoader(
# dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
# if to_iterator:
# dataloader = iter(dataloader)
# return dataloader
# def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
# from data.datasets.dataset_split import _SplitDataset
# dataset.dataset = _SplitDataset(dataset.dataset, indexes)
# return dataset
# def build_index_returned_dataset(self, dataset: ABDataset):
# return IndexReturnedDataset(dataset)
# def get_config(self):
# return copy.deepcopy(self.config)
# def get_task_type(self):
# return list(self.online_datasets.values())[0]['train'].task_type
# def get_num_classes(self):
# return self.num_classes
class Scenario:
    """Offline source datasets plus an ordered sequence of online target domains.

    The scenario itself only stores *descriptions* (class-ignore lists, index
    maps, domain order); datasets are materialised on demand via ``get_dataset``.

    Attributes:
        config: scenario config; this class reads ``config['data_dirs']``
            (dataset name -> root dir) and ``config['source_datasets_name']``.
        all_datasets_ignore_classes_map: ignore-class collection per dataset,
            keyed by ``'<dataset>'`` or ``'<source>|<target>'``.
        all_datasets_idx_map: index map per dataset, keyed like
            ``all_datasets_ignore_classes_map``.
        target_domains_order: target dataset names in the order they are visited.
        target_source_map: for each target dataset name, a mapping whose keys
            are the names of its corresponding source datasets.
        all_datasets_e2e_class_to_idx_map: stored and exported by ``to_json``;
            not otherwise used in this class.
        num_classes: total number of classes of the scenario.
        cur_domain_index: index into ``target_domains_order`` of the current
            online domain; starts at 0 and is advanced by ``next_domain``.
    """

    def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
                 all_datasets_e2e_class_to_idx_map,
                 num_classes):
        self.config = config
        self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
        self.all_datasets_idx_map = all_datasets_idx_map
        self.target_domains_order = target_domains_order
        self.target_source_map = target_source_map
        self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
        self.num_classes = num_classes

        self.cur_domain_index = 0

        logger.info(f'[scenario build] # classes: {num_classes}')
        logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')

    def to_json(self):
        """Return the constructor arguments of this scenario as a dict."""
        return dict(
            config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
            all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
            target_source_map=self.target_source_map,
            all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'

    def _make_dataset_args(self, dataset_name, split, map_key=None):
        """Build the keyword arguments of ``get_dataset`` for one dataset split.

        ``map_key`` selects the entry of the ignore-classes / idx maps and
        defaults to ``dataset_name`` itself.
        """
        if map_key is None:
            map_key = dataset_name
        return dict(dataset_name=dataset_name,
                    root_dir=self.config['data_dirs'][dataset_name],
                    split=split,
                    transform=None,
                    ignore_classes=self.all_datasets_ignore_classes_map[map_key],
                    idx_map=self.all_datasets_idx_map[map_key])

    def get_offline_datasets(self, transform=None):
        """Build the train/val/test datasets of every offline source dataset.

        Only entries of ``all_datasets_ignore_classes_map`` whose key is listed in
        ``config['source_datasets_name']`` are built.

        Returns:
            ``{dataset_name: {'train': ds, 'val': ds, 'test': ds}}``.
        """
        from .. import get_dataset

        data_dirs = self.config['data_dirs']
        source_datasets_name = self.config['source_datasets_name']

        return {
            d: {split: get_dataset(d, data_dirs[d], split,
                                   transform,
                                   self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                for split in ['train', 'val', 'test']}
            for d in self.all_datasets_ignore_classes_map
            if d in source_datasets_name
        }

    def get_offline_datasets_args(self):
        """Build ``get_dataset`` keyword arguments for offline training.

        For every source dataset, the idx maps of all of its entries
        (``'<source>'`` / ``'<source>|<target>'``) are merged so that the
        resulting arguments cover the union of their classes.

        Returns:
            ``{source_dataset_name: {split: kwargs}}`` for the splits
            ``'train'``, ``'val'`` and ``'test'``.

        Raises:
            ValueError: if a source dataset has no entry in
                ``all_datasets_ignore_classes_map``.
        """
        from .. import get_dataset

        data_dirs = self.config['data_dirs']
        source_datasets_name = self.config['source_datasets_name']
        splits = ['train', 'val', 'test']

        # Keys may look like '<source>|<target>'; the part before '|' is the
        # real dataset name used to locate the data.
        res_source_datasets_map = {
            d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
                                   None,
                                   self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                for split in splits}
            for d in self.all_datasets_ignore_classes_map
            if d.split('|')[0] in source_datasets_name
        }

        res_offline_train_source_datasets_map = {}
        for source_dataset_name in source_datasets_name:
            # Match on the exact base name. A substring test would also pick up
            # '<other>|<source_dataset_name>' keys or datasets whose name merely
            # contains this one.
            source_datasets = [v for k, v in res_source_datasets_map.items()
                               if k.split('|')[0] == source_dataset_name]
            if not source_datasets:
                raise ValueError(f'no dataset entry found for source dataset {source_dataset_name}')

            idx_maps = [d['train'].idx_map for d in source_datasets]
            ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]

            # Merge the idx maps; the same key must map to the same index everywhere.
            union_idx_map = {}
            for idx_map in idx_maps:
                for k, v in idx_map.items():
                    if k not in union_idx_map:
                        union_idx_map[k] = v
                    else:
                        assert union_idx_map[k] == v

            # A class stays ignored only if every entry ignores it, hence the
            # intersection of the ignore lists.
            union_ignore_classes = set.intersection(*(set(c) for c in ignore_classes_list))
            assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)

            logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')

            d = source_dataset_name
            # The original called dict() with six positional arguments, which
            # always raises TypeError; use the get_dataset keyword names instead.
            res_offline_train_source_datasets_map[d] = {
                split: dict(dataset_name=d,
                            root_dir=data_dirs[d],
                            split=split,
                            transform=None,
                            ignore_classes=union_ignore_classes,
                            idx_map=union_idx_map)
                for split in splits
            }

        return res_offline_train_source_datasets_map

    def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
        """Build ``get_dataset`` keyword arguments for inference on the i-th target domain.

        The ``'test'`` split is used when any detection dataset appears in
        ``target_domains_order``; otherwise the ``'train'`` split is used.
        """
        target_dataset_name = self.target_domains_order[domain_index]

        det_datasets_name = ('MM-CityscapesDet', 'CityscapesDet', 'BaiduPersonDet')
        if any(name in self.target_domains_order for name in det_datasets_name):
            # Message fixed to name the split that is actually selected.
            logger.info('use test split for inference test (only Det workload)')
            split = 'test'
        else:
            split = 'train'

        return self._make_dataset_args(target_dataset_name, split)

    def get_online_ith_domain_datasets_args_for_training(self, domain_index):
        """Build ``get_dataset`` keyword arguments for training on the i-th target domain.

        Returns:
            ``{dataset_name: {'train': kwargs, 'val': kwargs}}`` containing the
            target dataset and each of its corresponding source datasets. Source
            datasets use the maps stored under ``'<source>|<target>'``.
        """
        target_dataset_name = self.target_domains_order[domain_index]
        source_datasets_name = list(self.target_source_map[target_dataset_name].keys())
        splits = ['train', 'val']

        res = {
            target_dataset_name: {split: self._make_dataset_args(target_dataset_name, split)
                                  for split in splits}
        }
        for d in source_datasets_name:
            res[d] = {split: self._make_dataset_args(d, split, map_key=d + '|' + target_dataset_name)
                      for split in splits}

        return res

    def get_online_cur_domain_datasets_args_for_inference(self):
        """Same as the ``ith`` variant, for the current domain."""
        return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)

    def get_online_cur_domain_datasets_args_for_training(self):
        """Same as the ``ith`` variant, for the current domain."""
        return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)

    def get_online_cur_domain_datasets_for_training(self, transform=None):
        """Instantiate the training datasets of the current domain.

        Returns:
            ``{dataset_name: {split: dataset}}``; ``transform`` (if given)
            overrides the default ``None`` transform of every dataset.
        """
        res = {}
        datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
        for dataset_name, dataset_args in datasets_args.items():
            res[dataset_name] = {}
            for split, args in dataset_args.items():
                if transform is not None:
                    args['transform'] = transform
                res[dataset_name][split] = get_dataset(**args)
        return res

    def get_online_cur_domain_datasets_for_inference(self, transform=None):
        """Instantiate the inference dataset of the current domain."""
        datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
        if transform is not None:
            datasets_args['transform'] = transform
        return get_dataset(**datasets_args)

    def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
        """Return the first element of one batch drawn from the current target train split.

        NOTE(review): the positional arguments of ``build_dataloader`` are
        presumably (dataset, batch_size, num_workers, infinite, shuffle) -
        confirm against its definition.
        """
        dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
        dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]

    def next_domain(self):
        """Advance to the next target domain (no bounds check is performed)."""
        self.cur_domain_index += 1