File size: 3,974 Bytes
14c9181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os.path as osp

from mmengine.fileio import load
from tabulate import tabulate


class BaseWeightList:
    """Base class that renders a markdown table of model weights.

    Subclasses supply the class attributes listed below; calling
    ``gen_model_list`` then emits a MyST ``{table}`` directive with one
    row per model found in the project's metafiles.

    Args:
        dataset_list (list[str]): List of dataset names.
        table_header (list[str]): List of table header.
        msg (str): Message to be displayed.
        task_abbr (str): Abbreviation of task name.
        metric_name (str): Metric name.
    """

    base_url: str = 'https://github.com/open-mmlab/mmocr/blob/1.x/'
    table_cfg: dict = dict(
        tablefmt='pipe', floatfmt='.2f', numalign='right', stralign='center')
    dataset_list: list
    table_header: list
    msg: str
    task_abbr: str
    metric_name: str

    def __init__(self):
        # Append the metric name to every dataset column header.
        suffixed = [f'{d} ({self.metric_name})' for d in self.dataset_list]
        self.table_header = ['Model', 'README'] + suffixed

    def _get_model_info(self, task_name: str):
        """Yield ``(name, readme_link, *metrics)`` rows for one task.

        Reads ``../../model-index.yml`` and every metafile it imports;
        models whose config path does not contain ``task_name`` are
        skipped.
        """
        meta_indexes = load('../../model-index.yml')
        for rel_path in meta_indexes['Import']:
            metainfo = load(osp.join('../../', rel_path))
            # Map each collection name to a markdown link to its README.
            readme_links = {
                coll['Name']: f'[link]({self.base_url + coll["README"]})'
                for coll in metainfo['Collections']
            }
            for model in metainfo['Models']:
                if task_name not in model['Config']:
                    continue
                display = f'`{model["Name"]}`'
                alias = model.get('Alias', None)
                if alias:
                    # Normalize a single alias string to a list, then
                    # show aliases first, canonical name last.
                    if isinstance(alias, str):
                        alias = [alias]
                    model['Alias'] = alias
                    display = ' / '.join([f'`{a}`' for a in alias] +
                                         [display])
                readme = readme_links[model['In Collection']]
                yield (display, readme, *self._get_eval_res(model))

    def _get_eval_res(self, item):
        """Return metric values ordered as ``dataset_list``; '-' if absent."""
        results = {name: '-' for name in self.dataset_list}
        for entry in item['Results']:
            if entry['Dataset'] in results:
                results[entry['Dataset']] = entry['Metrics'][self.metric_name]
        return (results[name] for name in self.dataset_list)

    def gen_model_list(self):
        """Render the full markdown section (message + table directive)."""
        table = tabulate(
            self._get_model_info(self.task_abbr), self.table_header,
            **self.table_cfg)
        parts = [
            f'\n{self.msg}\n',
            '```{table}\n:class: model-summary nowrap field-list ',
            'table table-hover\n',
            table,
            '\n```\n',
        ]
        return ''.join(parts)


class TextDetWeightList(BaseWeightList):
    """Markdown weight list for text detection models."""

    metric_name = 'hmean-iou'
    task_abbr = 'textdet'
    msg = '### Text Detection'
    dataset_list = ['ICDAR2015', 'CTW1500', 'Totaltext']


class TextRecWeightList(BaseWeightList):
    """Markdown weight list for text recognition models.

    Adds an ``Avg`` column holding the mean of the metric over the
    benchmark datasets actually reported for each model.
    """

    dataset_list = [
        'Avg', 'IIIT5K', 'SVT', 'ICDAR2013', 'ICDAR2015', 'SVTP', 'CT80'
    ]
    msg = ('### Text Recognition\n'
           '```{note}\n'
           'Avg is the average on IIIT5K, SVT, ICDAR2013, ICDAR2015, SVTP,'
           ' CT80.\n```\n')
    task_abbr = 'textrecog'
    metric_name = 'word_acc'

    def _get_eval_res(self, item):
        """Return per-dataset metrics plus their average for one model.

        Args:
            item (dict): A model entry from a metafile containing a
                ``Results`` list of per-dataset metric dicts.

        Returns:
            Generator: Metric values ordered as ``self.dataset_list``,
            with ``'-'`` marking datasets the model has no result for.
        """
        eval_res = {k: '-' for k in self.dataset_list}
        avg = []
        for res in item['Results']:
            if res['Dataset'] in self.dataset_list:
                eval_res[res['Dataset']] = res['Metrics'][self.metric_name]
                avg.append(res['Metrics'][self.metric_name])
        # Guard against models with no matching results: the unguarded
        # sum(avg) / len(avg) would raise ZeroDivisionError. 'Avg' then
        # keeps its '-' placeholder.
        if avg:
            eval_res['Avg'] = sum(avg) / len(avg)
        return (eval_res[k] for k in self.dataset_list)


class KIEWeightList(BaseWeightList):
    """Markdown weight list for key information extraction models."""

    metric_name = 'macro_f1'
    task_abbr = 'kie'
    msg = '### Key Information Extraction'
    dataset_list = ['wildreceipt']


def gen_weight_list():
    """Concatenate the markdown weight lists for all three tasks."""
    sections = (TextDetWeightList, TextRecWeightList, KIEWeightList)
    return ''.join(cls().gen_model_list() for cls in sections)