import copy
import os
from typing import Dict, List, Optional, Type, Union

from torchvision.transforms import Compose

from ..dataset import get_dataset
from ..datasets.ab_dataset import ABDataset
from ..datasets.registery import static_dataset_registery
from .merge_alias import merge_the_same_meaning_classes

# Legacy aliases of variable names used throughout this module:
#   ignore_classes  == discarded classes
#   private_classes == unknown classes in partial / open-set / universal DA

# Internal DA mode -> threshold code used by the class-selection helpers below:
#   3 ('da'):           discard source-only and target-only classes
#   2 ('partial_da'):   discard target-only classes; keep source-only classes
#   1 ('open_set_da'):  discard source-only classes; target-only classes become unknown (private)
#   0 ('universal_da'): discard nothing; target-only classes become unknown (private)
_DA_MODE_TO_THRES = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}

# Public DA mode name (accepted by the builders) -> internal DA mode name.
_DA_MODE_ALIASES = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}


def _pair_key(source_name: str, target_name: str) -> str:
    """Return the key identifying a (source, target) pairing, e.g. ``'CIFAR10|STL10'``."""
    return source_name + '|' + target_name


def _ordered_subset(classes: List[str], keep) -> List[str]:
    """Return the distinct members of *classes* that are in *keep*, in first-occurrence order."""
    return [c for c in dict.fromkeys(classes) if c in keep]


def _rename_class(rename_map, dataset_name: str, c):
    """Map class *c* of *dataset_name* through ``rename_map`` (identity when there is no entry).

    ``rename_map`` is assumed to be a dict of per-dataset dicts, as returned by
    ``merge_the_same_meaning_classes`` -- confirm.
    """
    return rename_map.get(dataset_name, {}).get(c, c)


def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Thin wrapper around ``merge_the_same_meaning_classes``.

    Returns its ``(final_classes_of_all_datasets, rename_map)`` result unchanged.
    """
    final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(classes_info_of_all_datasets)
    return final_classes_of_all_datasets, rename_map


def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Find the classes to discard when adapting the sources ``as_classes`` to the target ``b_classes``.

    Args:
        as_classes: the class list of each source dataset.
        b_classes: the class list of the target dataset.
        da_mode: internal DA mode, a key of ``_DA_MODE_TO_THRES``.

    Returns:
        ``(as_ignore_classes, b_ignore_classes)``: ``as_ignore_classes[i]`` lists the discarded
        classes of ``as_classes[i]`` and ``b_ignore_classes`` those of the target, each ordered
        like its input class list. ``None`` if the sources' union and the target are both
        non-empty but share no class.

    Raises:
        KeyError: if ``da_mode`` is unknown.
    """
    thres = _DA_MODE_TO_THRES[da_mode]
    a_set = set().union(*as_classes)
    b_set = set(b_classes)
    if a_set and b_set and not (a_set & b_set):
        return None  # a has no intersection with b

    # Close-set (3) and open-set (1) DA discard source-only classes;
    # close-set (3) and partial (2) DA discard target-only classes.
    a_ignore = a_set - b_set if thres in (3, 1) else set()
    b_ignore = b_set - a_set if thres in (3, 2) else set()

    as_ignore_classes = [_ordered_subset(source_classes, a_ignore) for source_classes in as_classes]
    return as_ignore_classes, _ordered_subset(b_classes, b_ignore)


def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Find the target classes treated as unknown (private) when adapting ``as_classes`` to ``b_classes``.

    Args:
        as_classes: the class list of each source dataset.
        b_classes: the class list of the target dataset.
        da_mode: internal DA mode, a key of ``_DA_MODE_TO_THRES``.

    Returns:
        The target-only classes (ordered like ``b_classes``) in open-set and universal DA, an
        empty list otherwise; ``None`` if the sources' union and the target are both non-empty
        but share no class.

    Raises:
        KeyError: if ``da_mode`` is unknown.
    """
    thres = _DA_MODE_TO_THRES[da_mode]
    a_set = set().union(*as_classes)
    b_set = set(b_classes)
    if a_set and b_set and not (a_set & b_set):
        return None  # a has no intersection with b

    # Open-set (1) and universal (0) DA keep target-only classes as unknown classes.
    b_private = b_set - a_set if thres in (1, 0) else set()
    return _ordered_subset(b_classes, b_private)


class _ABDatasetMetaInfo:
    """Lightweight record of a registered dataset's metadata."""

    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
        self.name = name
        self.classes = classes
        self.class_aliases = class_aliases
        # Either None or a mapping {dataset name: shift description}; see
        # _get_dist_shift_type_when_source_a_to_target_b.
        self.shift_type = shift_type
        self.task_type = task_type
        self.object_type = object_type

    @classmethod
    def from_registry(cls, name):
        """Build the meta info of dataset *name* from ``static_dataset_registery``.

        Registry entries from position 1 onward are passed in the order of ``__init__``'s
        parameters after ``name``.
        """
        return cls(name, *static_dataset_registery[name][1:])


def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
    """Describe the distribution shift from source ``a`` to target ``b``.

    If ``b`` has no ``shift_type`` the shift is ``'Dataset Shifts'``. If ``a`` is listed in
    ``b.shift_type`` its description is returned directly. Otherwise the first dataset listed
    in ``b.shift_type`` is used as an intermediate hop and the descriptions are joined with
    ``' + '`` (recursively).
    """
    if b.shift_type is None:
        return 'Dataset Shifts'
    if a.name in b.shift_type:
        return b.shift_type[a.name]

    mid_dataset_name, mid_shift = list(b.shift_type.items())[0]
    mid_dataset_meta_info = _ABDatasetMetaInfo.from_registry(mid_dataset_name)
    return _get_dist_shift_type_when_source_a_to_target_b(a, mid_dataset_meta_info) + ' + ' + mid_shift


def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo],
                            da_mode):
    """Compute the class bookkeeping of a domain-adaptation scenario.

    Keys of the per-dataset maps below are ``'source|target'`` for a source paired with a target,
    and the bare target name for a target.

    Args:
        source_datasets: meta info of the source datasets.
        target_datasets: meta info of the (distinct) target datasets.
        da_mode: internal DA mode, a key of ``_DA_MODE_TO_THRES``.

    Returns:
        An 8-tuple:
        - all_datasets_ignore_classes: key -> discarded classes.
        - target_datasets_private_classes: target -> unknown classes.
        - all_datasets_e2e_idx_map: key -> {position in the dataset's merged class list -> global index}.
        - all_datasets_e2e_class_to_idx_map: key -> {class -> global index}.
        - target_datasets_private_class_idx: target -> index shared by its unknown classes, or None.
        - target_source_relationship_map: target -> {related source -> shift description}.
        - rename_map: as returned by the class-merging step.
        - num_classes: number of distinct global indices appearing in the class-to-index maps.
    """
    # 1. Merge classes with the same meaning across all datasets.
    classes_info_of_all_datasets = {
        d.name: (d.classes, d.class_aliases) for d in source_datasets + target_datasets
    }
    final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
    all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)

    # 2.1 Relate each target to the sources sharing at least one class with it
    #     (sd / td are _ABDatasetMetaInfo instances).
    target_source_relationship_map = {td.name: {} for td in target_datasets}
    for sd in source_datasets:
        for td in target_datasets:
            if not set(all_datasets_classes[sd.name]) & set(all_datasets_classes[td.name]):
                continue  # only keep source/target pairs that share classes
            target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)

    # 2.2 Find discarded / unknown classes per target so that DA on every target follows the
    #     requested label shift. The same source may therefore discard different classes
    #     depending on the target it is paired with, hence the 'source|target' keys.
    source_datasets_ignore_classes = {}
    target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
    target_datasets_private_classes = {d.name: [] for d in target_datasets}
    for td_name, related_sources in target_source_relationship_map.items():
        sd_names = list(related_sources)
        sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
        td_classes = all_datasets_classes[td_name]

        ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(
            sds_classes, td_classes, da_mode)
        for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
            source_datasets_ignore_classes[_pair_key(sd_name, td_name)] = s_ignore_classes
        target_datasets_ignore_classes[td_name] = t_ignore_classes
        target_datasets_private_classes[td_name] = _find_private_classes_when_sources_as_to_target_b(
            sds_classes, td_classes, da_mode)

    # 3.1 Give a global index to every known class, in dataset order then class order.
    all_used_classes_idx_map = {}
    for dataset_name, classes in all_datasets_classes.items():
        if dataset_name in target_datasets_ignore_classes:
            ignore_classes = set(target_datasets_ignore_classes[dataset_name])
            private_classes = set(target_datasets_private_classes[dataset_name])
        else:
            # A source may be paired with several targets; index it using the pairing that
            # discards the fewest classes (the first such pairing wins ties).
            # NOTE(review): a class discarded by that pairing but kept by another one only gets
            # an index if some other dataset contributes it -- confirm this is intended.
            pairings_ignore_classes = [
                source_datasets_ignore_classes[_pair_key(dataset_name, td_name)]
                for td_name, related_sources in target_source_relationship_map.items()
                if dataset_name in related_sources
            ]
            ignore_classes = set(min(pairings_ignore_classes, key=len, default=[]))
            private_classes = set()

        for c in classes:
            if c not in ignore_classes and c not in private_classes and c not in all_used_classes_idx_map:
                all_used_classes_idx_map[c] = len(all_used_classes_idx_map)

    # Each target with unknown classes gets one extra index, placed after all known classes.
    target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes}
    next_private_idx = len(all_used_classes_idx_map)
    for dataset_name in final_classes_of_all_datasets:
        if target_datasets_private_classes.get(dataset_name):
            target_datasets_private_class_idx[dataset_name] = next_private_idx
            next_private_idx += 1

    # 3.2 class -> rename_map -> global index, per source|target pairing and per target.
    # NOTE(review): the class lists here already come from the merge step, so the rename lookup
    # is probably a no-op for most classes -- confirm.
    all_datasets_e2e_idx_map = {}
    all_datasets_e2e_class_to_idx_map = {}
    for td_name, related_sources in target_source_relationship_map.items():
        for sd_name in related_sources:
            key = _pair_key(sd_name, td_name)
            ignore_classes = set(source_datasets_ignore_classes[key])
            idx_map, class_to_idx_map = {}, {}
            for raw_ci, raw_c in enumerate(all_datasets_classes[sd_name]):
                renamed_c = _rename_class(rename_map, sd_name, raw_c)
                if renamed_c in ignore_classes:
                    continue
                idx = all_used_classes_idx_map[renamed_c]
                idx_map[raw_ci] = idx
                class_to_idx_map[raw_c] = idx
            all_datasets_e2e_idx_map[key] = idx_map
            all_datasets_e2e_class_to_idx_map[key] = class_to_idx_map

        ignore_classes = set(target_datasets_ignore_classes[td_name])
        private_classes = set(target_datasets_private_classes[td_name])
        idx_map, class_to_idx_map = {}, {}
        for raw_ci, raw_c in enumerate(all_datasets_classes[td_name]):
            renamed_c = _rename_class(rename_map, td_name, raw_c)
            if renamed_c in ignore_classes:
                continue
            if renamed_c in private_classes:
                idx = target_datasets_private_class_idx[td_name]
            else:
                idx = all_used_classes_idx_map[renamed_c]
            idx_map[raw_ci] = idx
            class_to_idx_map[raw_c] = idx
        all_datasets_e2e_idx_map[td_name] = idx_map
        all_datasets_e2e_class_to_idx_map[td_name] = class_to_idx_map

    all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
    num_classes = len({idx for m in all_datasets_e2e_class_to_idx_map.values() for idx in m.values()})

    return all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes


def _build_scenario_info_v2(
        source_datasets_name: List[str],
        target_datasets_order: List[str],
        da_mode: str
):
    """Look up the registered datasets and compute the scenario's class bookkeeping.

    To support a new dataset, register it in ``static_dataset_registery``.

    Args:
        source_datasets_name: names of registered source datasets.
        target_datasets_order: names of target datasets in visiting order; may contain duplicates.
        da_mode: 'close_set', 'partial', 'open_set' or 'universal'.

    Returns:
        The 8-tuple returned by ``_handle_all_datasets_v2``.
    """
    assert da_mode in _DA_MODE_ALIASES, f'unknown da_mode {da_mode!r}, expected one of {list(_DA_MODE_ALIASES)}'
    internal_da_mode = _DA_MODE_ALIASES[da_mode]

    source_datasets_meta_info = [_ABDatasetMetaInfo.from_registry(d) for d in source_datasets_name]
    # Deduplicate targets keeping first-occurrence order. A set would make the processing order --
    # and therefore the global class indices -- depend on string hash randomization.
    target_datasets_meta_info = [_ABDatasetMetaInfo.from_registry(d) for d in dict.fromkeys(target_datasets_order)]

    return _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, internal_da_mode)


def build_scenario_manually_v2(
        source_datasets_name: List[str],
        target_datasets_order: List[str],
        da_mode: str,
        data_dirs: Dict[str, str],
):
    """Build a ``Scenario`` for the given source datasets, target sequence and DA mode.

    Args:
        source_datasets_name: names of registered source datasets.
        target_datasets_order: order in which target domains are visited (may repeat).
        da_mode: 'close_set', 'partial', 'open_set' or 'universal'.
        data_dirs: dataset name -> root directory; here it is only stored in the scenario config.

    Returns:
        The constructed ``Scenario``.

    Side effects:
        Sets the ``_ZQL_NUMC`` environment variable to the number of classes.
    """
    configs = copy.deepcopy(locals())  # snapshot of the call arguments; must run before any other local exists

    (all_datasets_ignore_classes, _target_datasets_private_classes,
     all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, _target_datasets_private_class_idx,
     target_source_relationship_map, _rename_map, num_classes) = _build_scenario_info_v2(
        source_datasets_name, target_datasets_order, da_mode)

    # Imported lazily, presumably to avoid a circular import with .scenario -- confirm.
    from .scenario import Scenario

    os.environ['_ZQL_NUMC'] = str(num_classes)
    return Scenario(config=configs,
                    all_datasets_ignore_classes_map=all_datasets_ignore_classes,
                    all_datasets_idx_map=all_datasets_e2e_idx_map,
                    target_domains_order=target_datasets_order,
                    target_source_map=target_source_relationship_map,
                    all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
                    num_classes=num_classes)


if __name__ == '__main__':
    # TODO(review): fill in the root directory of every dataset used below.
    example_data_dirs = {}
    test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
                                               ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
                                               'close_set',
                                               example_data_dirs)
    print(test_scenario.num_classes)