# EdgeTA / data /build_gen /build.py
# LINC-BIT's picture
# Upload 1912 files
# b84549f verified
# raw
# history blame
# 24.3 kB
from typing import Dict, List, Optional, Type, Union
from ..datasets.ab_dataset import ABDataset
# from benchmark.data.visualize import visualize_classes_in_object_detection
# from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
from ..dataset import get_dataset
import copy
from torchvision.transforms import Compose
from .merge_alias import merge_the_same_meaning_classes
from ..datasets.registery import static_dataset_registery
# some legacy aliases of variables:
# ignore_classes == discarded classes
# private_classes == unknown classes in partial / open-set / universal DA
def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Forward to ``merge_the_same_meaning_classes`` and return its two results.

    The imported helper's result is unpacked into exactly two values, which
    are returned as a tuple in the same order: the final classes of all
    datasets, then the rename map.
    """
    merged = merge_the_same_meaning_classes(classes_info_of_all_datasets)
    final_classes_of_all_datasets, rename_map = merged
    return final_classes_of_all_datasets, rename_map
def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Find the classes to discard on each side for sources ``a`` -> target ``b``.

    The union of all source class lists is compared with the target classes.
    Which side drops the classes the other side lacks depends on ``da_mode``:

    * ``'da'``           -- both sides drop their extra classes
    * ``'partial_da'``   -- only the target drops its extra classes
    * ``'open_set_da'``  -- only the sources drop their extra classes
    * ``'universal_da'`` -- nothing is dropped

    Args:
        as_classes: one class list per source dataset.
        b_classes: class list of the target dataset.
        da_mode: one of the four keys listed above (``KeyError`` otherwise).

    Returns:
        ``None`` when the source union and the target are both non-empty and
        share no class. Otherwise a pair ``(as_ignore_classes,
        b_ignore_classes)``: one ignore list per source (aligned with
        ``as_classes``) and the ignore list of the target. Every list is
        de-duplicated and keeps the order of the class list it was taken from,
        so the result does not depend on set iteration order.
    """
    thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]

    a_set = set().union(*as_classes)
    b_set = set(b_classes)

    # Two non-empty, disjoint label spaces: there is no relationship to describe.
    if a_set and b_set and not (a_set & b_set):
        return None

    # In the equal / superset / subset / partial-overlap cases the extra
    # classes of one side are exactly its set difference with the other side
    # (which is empty whenever that side has nothing extra).
    a_ignore = a_set - b_set if thres in (3, 1) else set()
    b_ignore = b_set - a_set if thres in (3, 2) else set()

    # dict.fromkeys de-duplicates while keeping the original class order.
    as_ignore_classes = [
        list(dict.fromkeys(c for c in src_classes if c in a_ignore))
        for src_classes in as_classes
    ]
    b_ignore_classes = list(dict.fromkeys(c for c in b_classes if c in b_ignore))
    return as_ignore_classes, b_ignore_classes
def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Find the target classes treated as private (unknown) for sources ``a`` -> target ``b``.

    Only ``'open_set_da'`` and ``'universal_da'`` yield private classes: the
    target classes that appear in none of the source class lists. ``'da'`` and
    ``'partial_da'`` always yield an empty list.

    Args:
        as_classes: one class list per source dataset.
        b_classes: class list of the target dataset.
        da_mode: ``'da'``, ``'partial_da'``, ``'open_set_da'`` or
            ``'universal_da'`` (``KeyError`` otherwise).

    Returns:
        ``None`` when the source union and the target are both non-empty and
        share no class; otherwise the list of private target classes,
        de-duplicated and in ``b_classes`` order so the result does not depend
        on set iteration order.
    """
    thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]

    a_set = set().union(*as_classes)
    b_set = set(b_classes)

    # Two non-empty, disjoint label spaces: there is no relationship to describe.
    if a_set and b_set and not (a_set & b_set):
        return None

    if thres not in (1, 0):
        return []

    # Target-only classes; empty whenever the sources cover the whole target.
    return list(dict.fromkeys(c for c in b_classes if c not in a_set))
class _ABDatasetMetaInfo:
    """Plain attribute holder for the description of one dataset.

    The constructor stores its six arguments unchanged under attributes of
    the same names.
    """

    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
        # Label-space information.
        self.name, self.classes, self.class_aliases = name, classes, class_aliases
        # Shift description, then task / object kind.
        self.shift_type, self.task_type, self.object_type = shift_type, task_type, object_type
def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
    """Describe the distribution shift from source ``a`` to target ``b``.

    * ``b.shift_type`` is ``None``: the generic label ``'Dataset Shifts'``.
    * ``a.name`` is a key of ``b.shift_type``: the value stored for it.
    * otherwise: the first key of ``b.shift_type`` names an intermediate
      dataset, whose meta info is rebuilt from ``static_dataset_registery``;
      the result is the shift from ``a`` to that dataset, followed by
      ``' + '`` and the first value of ``b.shift_type``.
    """
    shifts_of_b = b.shift_type
    if shifts_of_b is None:
        return 'Dataset Shifts'

    if a.name in shifts_of_b.keys():
        return shifts_of_b[a.name]

    # Hop through the dataset that b declares its shift against.
    via_name = list(shifts_of_b.keys())[0]
    via_registry_fields = static_dataset_registery[via_name][1:]
    via_meta_info = _ABDatasetMetaInfo(via_name, *via_registry_fields)

    shift_up_to_via = _get_dist_shift_type_when_source_a_to_target_b(a, via_meta_info)
    last_hop_shift = list(b.shift_type.values())[0]
    return shift_up_to_via + ' + ' + last_hop_shift
def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
    """Compute ignore/private classes and class-index maps for all datasets.

    Steps:
      1. Merge classes with the same meaning across all datasets.
      2. Relate every target to the sources that share at least one merged
         class with it, then ask the ``_find_*`` helpers which classes are
         ignored (discarded) or private (unknown) under ``da_mode``.
      3. Give every used class one global index, give every target that has
         private classes one extra index after those, and build the
         per-dataset index maps.

    Source-side results are keyed by ``'<source>|<target>'`` because the
    ignored classes of a source are computed separately for each target it
    is related to; target-side results are keyed by the target name.

    Args:
        source_datasets: meta info of the source datasets.
        target_datasets: meta info of the target datasets.
        da_mode: ``'da'``, ``'partial_da'``, ``'open_set_da'`` or ``'universal_da'``.

    Returns:
        An 8-tuple:
        ``(all_datasets_ignore_classes, target_datasets_private_classes,
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map,
        target_datasets_private_class_idx, target_source_relationship_map,
        rename_map, num_classes)`` where ``num_classes`` is the number of
        distinct indices appearing in the class-to-index maps.
    """
    # 1. merge the same meaning classes
    classes_info_of_all_datasets = {
        d.name: (d.classes, d.class_aliases)
        for d in source_datasets + target_datasets
    }
    final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
    all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)

    # 2. find ignored / private classes according to DA mode
    # 2.1. construct target_source_relationship_map: target -> {source: shift type}
    target_source_relationship_map = {td.name: {} for td in target_datasets}
    for sd in source_datasets:  # sd / td are the meta-info instances held in the two lists
        for td in target_datasets:
            sc = all_datasets_classes[sd.name]
            tc = all_datasets_classes[td.name]
            # keep only source/target pairs that share at least one class
            if not (set(sc) & set(tc)):
                continue
            target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)

    source_datasets_ignore_classes = {
        sd_name + '|' + td_name: []
        for td_name, related_sources in target_source_relationship_map.items()
        for sd_name in related_sources
    }
    target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
    target_datasets_private_classes = {d.name: [] for d in target_datasets}

    # 2.2. make the DA on every target obey the requested label shift; as a
    # consequence the same source may end up with different ignored classes
    # for different targets.
    for td_name, related_sources in target_source_relationship_map.items():
        sd_names = list(related_sources.keys())
        sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
        td_classes = all_datasets_classes[td_name]
        # the ignore classes produced here depend on the DA mode
        ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
        t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)

        for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
            source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
        target_datasets_ignore_classes[td_name] = t_ignore_classes
        target_datasets_private_classes[td_name] = t_private_classes

    # de-duplicate every list while keeping first-seen order
    source_datasets_ignore_classes = {k: list(dict.fromkeys(v)) for k, v in source_datasets_ignore_classes.items()}
    target_datasets_ignore_classes = {k: list(dict.fromkeys(v)) for k, v in target_datasets_ignore_classes.items()}
    target_datasets_private_classes = {k: list(dict.fromkeys(v)) for k, v in target_datasets_private_classes.items()}

    # 3. reparse classes idx
    # 3.1. assign a global index to every used (not ignored, not private) class
    all_used_classes_idx_map = {}
    for dataset_name, classes in all_datasets_classes.items():
        if dataset_name in target_datasets_ignore_classes:
            ignore_classes = target_datasets_ignore_classes[dataset_name]
            private_classes = target_datasets_private_classes[dataset_name]
        else:
            # A source has one ignore list per related target; use the one
            # with the fewest entries (the first such list on ties). The
            # trailing '|' makes the key match this dataset only, not another
            # dataset whose name merely starts with the same characters.
            candidates = [
                sic for sn, sic in source_datasets_ignore_classes.items()
                if sn.startswith(dataset_name + '|')
            ]
            ignore_classes = min(candidates, key=len) if candidates else []
            private_classes = []

        for c in classes:
            if c in ignore_classes or c in private_classes or c in all_used_classes_idx_map:
                continue
            all_used_classes_idx_map[c] = len(all_used_classes_idx_map)
    global_idx = len(all_used_classes_idx_map)

    # Each target with private classes gets one index of its own, allocated
    # consecutively after the indices of the used classes.
    target_private_class_idx = global_idx
    target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes}
    for dataset_name in final_classes_of_all_datasets:
        if dataset_name not in target_datasets_private_classes:
            continue
        if target_datasets_private_classes[dataset_name]:
            target_datasets_private_class_idx[dataset_name] = target_private_class_idx
            target_private_class_idx += 1

    # 3.2 raw_class -> rename_map[raw_classes] -> all_used_classes_idx_map
    all_datasets_e2e_idx_map = {}
    all_datasets_e2e_class_to_idx_map = {}
    for td_name, related_sources in target_source_relationship_map.items():
        # sources related to this target, keyed by '<source>|<target>'
        for sd_name in related_sources:
            pair_key = sd_name + '|' + td_name
            # renames of the dataset being processed (not of whichever
            # dataset an earlier loop happened to finish on)
            renames = rename_map.get(sd_name, {})
            ignore_classes = source_datasets_ignore_classes[pair_key]

            cur_e2e_idx_map = {}
            cur_e2e_class_to_idx_map = {}
            for raw_ci, raw_c in enumerate(all_datasets_classes[sd_name]):
                renamed_c = renames[raw_c] if raw_c in renames else raw_c
                if renamed_c in ignore_classes:
                    continue
                idx = all_used_classes_idx_map[renamed_c]
                cur_e2e_idx_map[raw_ci] = idx
                cur_e2e_class_to_idx_map[raw_c] = idx
            all_datasets_e2e_idx_map[pair_key] = cur_e2e_idx_map
            all_datasets_e2e_class_to_idx_map[pair_key] = cur_e2e_class_to_idx_map

        # the target itself
        renames = rename_map.get(td_name, {})
        ignore_classes = target_datasets_ignore_classes[td_name]
        private_classes = target_datasets_private_classes[td_name]

        cur_e2e_idx_map = {}
        cur_e2e_class_to_idx_map = {}
        for raw_ci, raw_c in enumerate(all_datasets_classes[td_name]):
            renamed_c = renames[raw_c] if raw_c in renames else raw_c
            if renamed_c in ignore_classes:
                continue
            if renamed_c in private_classes:
                # all private classes of one target share that target's index
                idx = target_datasets_private_class_idx[td_name]
            else:
                idx = all_used_classes_idx_map[renamed_c]
            cur_e2e_idx_map[raw_ci] = idx
            cur_e2e_class_to_idx_map[raw_c] = idx
        all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
        all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map

    all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}

    # number of distinct indices that are actually mapped to
    num_classes = len({
        idx
        for class_to_idx in all_datasets_e2e_class_to_idx_map.values()
        for idx in class_to_idx.values()
    })

    return all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes
def _build_scenario_info_v2(
    source_datasets_name: List[str],
    target_datasets_order: List[str],
    da_mode: str
):
    """Look up the datasets in the registry and compute the scenario's label info.

    Args:
        source_datasets_name: names of the source datasets; each must be a key
            of ``static_dataset_registery``.
        target_datasets_order: names of the target datasets in the order they
            are visited; repeated names are allowed and are de-duplicated here.
        da_mode: ``'close_set'``, ``'partial'``, ``'open_set'`` or ``'universal'``.

    Returns:
        The 8-tuple produced by ``_handle_all_datasets_v2``.
    """
    assert da_mode in ['close_set', 'partial', 'open_set', 'universal'], da_mode
    da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]

    # The registry entry supplies every attribute after the name, so a newly
    # added dataset only needs to be registered.
    source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]

    # De-duplicate while keeping first-seen order: iterating a set of strings
    # can give a different order on each interpreter run, and the order of the
    # targets decides which private-class index each target receives.
    unique_target_names = list(dict.fromkeys(target_datasets_order))
    target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in unique_target_names]

    all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes \
        = _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)

    return all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes
def build_scenario_manually_v2(
    source_datasets_name: List[str],
    target_datasets_order: List[str],
    da_mode: str,
    data_dirs: Dict[str, str],
    # transforms: Optional[Dict[str, Compose]] = None
):
    """Build a ``Scenario`` from source/target dataset names and a DA mode.

    Side effect: the environment variable ``_ZQL_NUMC`` is set to the number
    of classes of the scenario (as a string).

    Args:
        source_datasets_name: names of the source datasets.
        target_datasets_order: names of the target datasets in visiting order
            (repeats allowed); passed to ``Scenario`` unchanged.
        da_mode: ``'close_set'``, ``'partial'``, ``'open_set'`` or ``'universal'``.
        data_dirs: only stored in the config snapshot by this function.
            Presumably maps dataset name to its data directory -- verify
            against ``Scenario``.

    Returns:
        The constructed ``Scenario``.
    """
    # Snapshot of the call arguments. This must remain the first statement so
    # that locals() contains nothing but the parameters.
    configs = copy.deepcopy(locals())

    all_datasets_ignore_classes, _target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, _target_datasets_private_class_idx, \
        target_source_relationship_map, _rename_map, num_classes \
        = _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)

    # Function-scope imports, as in the original code.
    from .scenario import Scenario
    import os

    os.environ['_ZQL_NUMC'] = str(num_classes)

    test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
                             all_datasets_idx_map=all_datasets_e2e_idx_map,
                             target_domains_order=target_datasets_order,
                             target_source_map=target_source_relationship_map,
                             all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
                             num_classes=num_classes)
    return test_scenario
if __name__ == '__main__':
    # Smoke test: build a small close-set scenario and print its class count.
    # `data_dirs` is a required parameter of build_scenario_manually_v2, so it
    # has to be passed for the call to succeed.
    test_scenario = build_scenario_manually_v2(
        ['CIFAR10', 'SVHN'],
        ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
        'close_set',
        # NOTE(review): placeholder; supply real dataset directories if the
        # scenario is used for more than reading num_classes.
        data_dirs={},
    )
    print(test_scenario.num_classes)