File size: 3,845 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from re import L
from typing import Dict, List
from collections import Counter

    
def grouping(bondlist):
    """Merge lists that share at least one element into connected groups.

    Two input lists land in the same group when they are linked by a chain of
    lists that pairwise share an element (i.e. connected components). Each
    group is the concatenation of its member lists in the order they were
    merged; duplicates are kept on purpose so that callers can count how often
    each element occurred.

    Args:
        bondlist: a list of lists. It is NOT modified; a shallow copy is
            consumed instead.

    Returns:
        A list of groups, ordered by the position of each group's first member
        list in ``bondlist``.
    """
    # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
    remaining = list(bondlist)  # work on a copy so the caller's list survives
    groups = []
    while remaining:
        group = remaining.pop(0)
        grew = True
        # Keep sweeping until a full pass adds nothing (fixed point). This
        # replaces the old magic-number sentinels, which stopped early when a
        # group happened to reach exactly 10000 elements.
        while grew and remaining:
            size_before = len(group)
            # `group` is rebound inside the loop; the for-loop keeps iterating
            # the snapshot taken at the start of this pass, and later passes
            # pick up the newly added elements.
            for atomid in group:
                for i, other in enumerate(remaining):
                    if atomid in other:
                        group = group + other
                        remaining.pop(i)
                        break
                if not remaining:
                    break
            grew = len(group) != size_before
        groups.append(group)
    return groups


def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
    """Describe each class by the synonym list it belongs to.

    For every name in ``classes`` (order preserved), yield the first list in
    ``aliases`` that contains it. That list object itself is used, not a copy.
    A class that no alias list mentions is represented by the singleton
    ``[name]``.

    Args:
        classes: class names of one dataset.
        aliases: synonym lists, e.g. ``[['automobile', 'car']]``.

    Returns:
        One synonym list per entry of ``classes``.
    """
    def _synonyms_of(name):
        # The first alias group that mentions the name wins; otherwise the
        # name stands alone.
        return next((group for group in aliases if name in group), [name])

    return [_synonyms_of(name) for name in classes]
    

def _pick_representative_name(group):
    """Return the unified name for one group of synonymous class names.

    The most frequent name in ``group`` wins. Among equally frequent names the
    shortest wins, and among those the one that appears first in ``group``.
    """
    # most_common() keeps first-encountered order for equal counts, and min()
    # returns the first minimal element, so ties resolve deterministically.
    counts = Counter(group).most_common()
    max_times = counts[0][1]
    candidates = [name for name, times in counts if times == max_times]
    return min(candidates, key=len)


def merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Unify class names that mean the same thing across several datasets.

    Args:
        classes_info_of_all_datasets: mapping
            ``dataset_name -> (classes, aliases)``. ``classes`` is the
            dataset's list of class names. ``aliases`` is a list of synonym
            lists, e.g. ``[['automobile', 'car']]``. The aliases of all
            datasets are pooled and applied to every dataset.

    Returns:
        A tuple ``(res, res_map)``:

        - ``res``: ``dataset_name -> list`` of unified class names, in the
          dataset's original class order with duplicates removed.
        - ``res_map``: ``dataset_name -> {original_name: unified_name}``,
          listing only the classes whose name changed.
    """
    # Pool the aliases of every dataset, so that a synonym declared by one
    # dataset also applies to the others.
    all_aliases = []
    for _, aliases in classes_info_of_all_datasets.values():
        all_aliases += aliases

    semantic_classes_of_all_datasets = []
    for classes, _ in classes_info_of_all_datasets.values():
        semantic_classes_of_all_datasets += build_semantic_class_info(classes, all_aliases)

    # Merge overlapping synonym lists. Each group keeps duplicates so that
    # name frequencies can be counted when picking its representative.
    grouped_classes_of_all_datasets = grouping(semantic_classes_of_all_datasets)

    # Map every name to its group's representative. Groups are disjoint, but
    # setdefault keeps the first group's choice should a name ever repeat.
    final_name_of = {}
    for group in grouped_classes_of_all_datasets:
        representative = _pick_representative_name(group)
        for name in group:
            final_name_of.setdefault(name, representative)

    res = {}
    res_map = {d: {} for d in classes_info_of_all_datasets}
    for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
        final_classes = []
        for c in classes:
            if c not in final_name_of:
                # Not expected: every class is placed in some group above.
                continue
            final_name = final_name_of[c]
            final_classes.append(final_name)
            if final_name != c:
                res_map[dataset_name][c] = final_name
        # Order-preserving dedup (first occurrence wins) in linear time.
        res[dataset_name] = list(dict.fromkeys(final_classes))
    return res, res_map


if __name__ == '__main__':
    # Demo: CIFAR-10 names its vehicle class 'automobile' while STL-10 uses
    # 'car'. The alias declared for CIFAR-10 lets both share one unified name.
    datasets = {
        'CIFAR10': (
            ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
            [['automobile', 'car']],
        ),
        'STL10': (
            ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck'],
            [],
        ),
    }
    merged_classes, rename_map = merge_the_same_meaning_classes(datasets)
    print(merged_classes, rename_map)