Spaces:
Build error
Build error
import argparse | |
from collections import OrderedDict | |
import torch | |
def convert_stem(model_key, model_weight, state_dict, converted_names):
    """Rename a stem parameter key and store its weight under the new name.

    ``stem.conv`` becomes ``conv1`` and ``stem.bn`` becomes ``bn1``; the rest
    of the key is left untouched.

    Args:
        model_key: Parameter name from the source checkpoint.
        model_weight: Value stored under ``model_key``.
        state_dict: Mapping that receives the renamed entry (mutated in place).
        converted_names: Set that collects every source key already handled
            (mutated in place).
    """
    target = model_key
    # Apply the renames one after another, conv first, then batch norm.
    for old_prefix, new_prefix in (('stem.conv', 'conv1'), ('stem.bn', 'bn1')):
        target = target.replace(old_prefix, new_prefix)

    state_dict[target] = model_weight
    converted_names.add(model_key)
    print(f'Convert {model_key} to {target}')
def convert_head(model_key, model_weight, state_dict, converted_names):
    """Rename a classification-head key (``head.fc`` -> ``fc``) and store it.

    Args:
        model_key: Parameter name from the source checkpoint.
        model_weight: Value stored under ``model_key``.
        state_dict: Mapping that receives the renamed entry (mutated in place).
        converted_names: Set that collects every source key already handled
            (mutated in place).
    """
    renamed = model_key.replace('head.fc', 'fc')
    state_dict.update({renamed: model_weight})
    converted_names.add(model_key)
    print(f'Convert {model_key} to {renamed}')
def convert_reslayer(model_key, model_weight, state_dict, converted_names):
    """Rename one residual-stage key and store its weight under the new name.

    Source keys are split on ``.`` and read as
    ``s<stage>.b<block>.<module>[.<submodule>]...<param>``. The stage number
    is kept (``s2`` -> ``layer2``) while the block index becomes zero-based
    (``b1`` -> ``0``). The mapping is:

    * ``proj`` / ``bn`` of the first block -> ``downsample.0`` /
      ``downsample.1``
    * ``f.a`` / ``f.b`` / ``f.c`` -> ``conv1`` / ``conv2`` / ``conv3``
    * ``f.a_bn`` / ``f.b_bn`` / ``f.c_bn`` -> ``bn1`` / ``bn2`` / ``bn3``

    Args:
        model_key: Parameter name from the source checkpoint.
        model_weight: Value stored under ``model_key``.
        state_dict: Mapping that receives the renamed entry (mutated in place).
        converted_names: Set that collects every source key already handled
            (mutated in place).

    Raises:
        ValueError: If the key does not match any supported pattern. Nothing
            is written to ``state_dict`` or ``converted_names`` in that case.
    """
    # Sub-modules of the bottleneck transform ``f`` and their target names.
    submodule_names = {
        'a': 'conv1',
        'b': 'conv2',
        'c': 'conv3',
        'a_bn': 'bn1',
        'b_bn': 'bn2',
        'c_bn': 'bn3',
    }

    split_keys = model_key.split('.')
    if len(split_keys) < 3:
        raise ValueError(f'Unsupported conversion of key {model_key}')

    layer, block, module = split_keys[:3]
    # int() raises ValueError by itself when the numeric suffix is malformed.
    block_id = int(block[1:])
    layer_name = f'layer{int(layer[1:])}'
    block_name = f'{block_id - 1}'
    param_name = split_keys[-1]

    # Only the first block of a stage carries the projection shortcut.
    if block_id == 1 and module == 'bn':
        new_key = f'{layer_name}.{block_name}.downsample.1.{param_name}'
    elif block_id == 1 and module == 'proj':
        new_key = f'{layer_name}.{block_name}.downsample.0.{param_name}'
    elif module == 'f':
        submodule = split_keys[3] if len(split_keys) > 3 else None
        module_name = submodule_names.get(submodule)
        if module_name is None:
            # Previously this path crashed with UnboundLocalError/IndexError;
            # report it the same way as every other unsupported key.
            raise ValueError(f'Unsupported conversion of key {model_key}')
        new_key = f'{layer_name}.{block_name}.{module_name}.{param_name}'
    else:
        raise ValueError(f'Unsupported conversion of key {model_key}')

    print(f'Convert {model_key} to {new_key}')
    state_dict[new_key] = model_weight
    converted_names.add(model_key)
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Loads the checkpoint at ``src``, renames every entry of its
    ``'model_state'`` mapping through the stem / head / residual-layer
    converters, prints any key that was not handled, and saves the result to
    ``dst`` as ``{'state_dict': <renamed weights>}``.

    Args:
        src: Path of the source checkpoint; it must contain a
            ``'model_state'`` entry.
        dst: Path the converted checkpoint is written to.
    """
    # Load onto CPU so that a checkpoint saved from GPU can still be converted
    # on a machine without CUDA.
    # NOTE: torch.load unpickles the file; only convert trusted checkpoints.
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']

    # Rename every parameter into the target layout.
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)

    # Report keys that matched none of the converters above.
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')

    # Save the checkpoint.
    checkpoint = {'state_dict': state_dict}
    torch.save(checkpoint, dst)
def main():
    """Read the source and destination paths from the command line and convert."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    # Two positional arguments, in this order: input checkpoint, output path.
    for arg_name, arg_help in (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    ):
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()
    convert(options.src, options.dst)


if __name__ == '__main__':
    main()