Spaces:
Sleeping
Sleeping
File size: 1,400 Bytes
14c9181 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table
from mmengine import Config
from mmengine.registry import init_default_scope
from mmocr.registry import MODELS
def parse_args(argv=None):
    """Parse command-line arguments for the FLOPs analysis script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]`` exactly
            as the zero-argument call did.

    Returns:
        argparse.Namespace: Parsed arguments with ``config`` (path to the
        model config file) and ``shape`` (list of one or more ints; the
        caller interprets one value as both height and width, or two values
        as ``h w``).
    """
    parser = argparse.ArgumentParser(
        description='Compute the FLOPs of a model built from a config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[640, 640],
        help='input image size as "h w", or a single value for both')
    return parser.parse_args(argv)
def _input_shape_from_args(shape):
    """Turn the ``--shape`` values into the dummy-input tensor shape.

    Args:
        shape (list[int]): One value (used for both height and width) or
            two values ``h w``.

    Returns:
        tuple[int, int, int, int]: ``(1, 3, h, w)``.

    Raises:
        ValueError: If ``shape`` has neither one nor two entries, or if
            height/width is not positive.
    """
    if len(shape) == 1:
        h = w = shape[0]
    elif len(shape) == 2:
        h, w = shape
    else:
        raise ValueError('invalid input shape, please use --shape h w')
    # argparse's type=int accepts 0 and negatives; torch.ones would then
    # build an empty tensor or fail later with a far less clear error.
    if h <= 0 or w <= 0:
        raise ValueError(f'input shape must be positive, got h={h}, w={w}')
    # Presumably (batch, channels, height, width): one 3-channel image.
    return (1, 3, h, w)


def main():
    """Build the model described by a config file and print its FLOPs.

    Reads the config path and input size from the command line, builds the
    model through the MMOCR ``MODELS`` registry, traces it with fvcore's
    ``FlopCountAnalysis`` on an all-ones dummy input, and prints the
    per-module FLOPs table followed by a caution message.

    Raises:
        ValueError: If ``--shape`` is malformed or not positive.
    """
    args = parse_args()
    input_shape = _input_shape_from_args(args.shape)

    cfg = Config.fromfile(args.config)
    # Register the config's default scope (falling back to 'mmocr') so that
    # registry lookups below resolve to the right codebase.
    init_default_scope(cfg.get('default_scope', 'mmocr'))
    model = MODELS.build(cfg.model)
    # Analyse the inference graph: put dropout / normalisation layers into
    # eval mode before fvcore traces the forward pass.
    model.eval()

    flops = FlopCountAnalysis(model, torch.ones(input_shape))
    print(flop_count_table(flops))
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')
# Script entry point: run the FLOPs analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|