File size: 1,789 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from collections import OrderedDict

import torch
from mmcv import Config


def parse_config(config_strings):
    """Verify that the given config source describes an SSD model.

    The text is written to a temporary ``.py`` file because the config is
    loaded with ``Config.fromfile``, which takes a file path.

    Args:
        config_strings (str): Python source text of the config to check.

    Raises:
        AssertionError: If ``model.bbox_head.type`` of the loaded config is
            not ``'SSDHead'``.
    """
    import os.path as osp

    # A private temporary directory is removed together with its content when
    # the ``with`` block exits, so no config file is left behind even if
    # loading the config raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        config_path = osp.join(temp_dir, 'ssd_config.py')
        # Written as UTF-8 so that non-ASCII config text does not depend on
        # the platform's default encoding.
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)

    # check whether it is SSD
    if config.model.bbox_head.type != 'SSDHead':
        raise AssertionError('This is not a SSD model.')


def _release_tuple(version, length=2):
    """Return the leading numeric components of a dotted version string.

    Only the first ``length`` dot-separated components are used, and anything
    after the leading digits of a component is ignored, so ``'1.10.0+cu113'``
    yields ``(1, 10)`` and ``'1.6rc1'`` yields ``(1, 6)``.
    """
    numbers = []
    for part in str(version).split('.')[:length]:
        digits = ''
        for char in part:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


def _upgrade_key(key):
    """Return the renamed state-dict key for one parameter name."""
    parts = key.split('.')
    if 'extra' in key:
        # The third dotted component is the flat index of the extra layer; it
        # is split into a (pair index, position inside the pair) address.
        layer_idx = int(parts[2])
        return 'neck.extra_layers.{}.{}.conv.{}'.format(
            layer_idx // 2, layer_idx % 2, parts[-1])
    if 'l2_norm' in key:
        return 'neck.l2_norm.weight'
    if 'bbox_head' in key:
        # Insert a '0' component after the third dotted component, e.g.
        # 'bbox_head.cls_convs.3.weight' -> 'bbox_head.cls_convs.3.0.weight'.
        # Working on components instead of a fixed character offset keeps
        # this correct for indices with more than one digit.
        return '.'.join(parts[:3] + ['0'] + parts[3:])
    return key


def convert(in_file, out_file):
    """Rename the parameters of an SSD checkpoint and save the result.

    The checkpoint must contain a ``'state_dict'`` entry and a ``'meta'``
    entry whose ``'config'`` item holds the config text; that config is
    validated with ``parse_config`` before any key is renamed.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path the converted checkpoint is written to.

    Raises:
        AssertionError: If the stored config is not an SSD config.
        KeyError: If ``'state_dict'``, ``'meta'`` or ``'config'`` is missing.
    """
    # Load onto the CPU so that a checkpoint saved from GPU tensors can also
    # be converted on a machine without CUDA.
    checkpoint = torch.load(in_file, map_location='cpu')
    in_state_dict = checkpoint.pop('state_dict')
    meta_info = checkpoint['meta']
    # The leading '#' turns the first line of the stored text into a comment
    # (presumably that line is the config file path -- TODO confirm).
    parse_config('#' + meta_info['config'])

    out_state_dict = OrderedDict(
        (_upgrade_key(key), value) for key, value in in_state_dict.items())
    checkpoint['state_dict'] = out_state_dict

    # Compare the version numerically: as plain strings '1.10' sorts before
    # '1.6', which would select the wrong branch.
    if _release_tuple(torch.__version__) >= (1, 6):
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)


def main():
    """Read the input and output paths from the command line and convert."""
    cli = argparse.ArgumentParser(description='Upgrade SSD version')
    # Both arguments are positional and are registered in this order.
    positional_args = (
        ('in_file', 'input checkpoint file'),
        ('out_file', 'output checkpoint file'),
    )
    for arg_name, help_text in positional_args:
        cli.add_argument(arg_name, help=help_text)

    options = cli.parse_args()
    convert(options.in_file, options.out_file)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()