File size: 5,750 Bytes
e26e560 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import argparse
import glob
import json
import os.path as osp
import shutil
import subprocess
import mmcv
import torch
def process_checkpoint(in_file, out_file):
    """Strip a training checkpoint for publishing and tag it with its hash.

    The checkpoint is loaded onto CPU, its ``'optimizer'`` entry (if present)
    is removed, and the result is saved to ``out_file``. The saved file is
    then renamed so that its name carries the first 8 hex characters of its
    own SHA-256 digest.

    Args:
        in_file (str): Path of the checkpoint to read.
        out_file (str): Path to write to. A trailing ``.pth`` is optional and
            is not duplicated in the final name.

    Returns:
        str: Path of the published file,
            ``<out_file without .pth>-<first 8 sha256 hex chars>.pth``.
    """
    # Local import: hashlib is not among the file-level imports.
    import hashlib

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash the file that was actually written, reading it in chunks so that
    # large checkpoints are not held in memory a second time.
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()

    # str.rstrip('.pth') would strip any trailing '.', 'p', 't' or 'h'
    # characters, so remove the suffix explicitly instead.
    suffix = '.pth'
    if out_file.endswith(suffix):
        stem = out_file[:-len(suffix)]
    else:
        stem = out_file
    final_file = stem + '-{}.pth'.format(sha[:8])

    # Rename synchronously so the returned path exists when we return.
    shutil.move(out_file, final_file)
    return final_file
def get_final_epoch(config):
    """Return the ``total_epochs`` value declared by a config file.

    Args:
        config (str): Config path; it is appended to ``'./configs/'`` to
            locate the file to load.

    Returns:
        The ``total_epochs`` attribute of the loaded config.
    """
    config_file = './configs/' + config
    loaded_cfg = mmcv.Config.fromfile(config_file)
    return loaded_cfg.total_epochs
def get_final_results(log_json_path, epoch, results_lut):
    """Collect the metrics logged for one epoch from a JSON-lines log.

    The log is scanned top to bottom. For records belonging to ``epoch``,
    the ``memory`` value of ``train`` records is remembered (the last one
    seen wins), and the first ``val`` record ends the scan: the requested
    metric keys present in it are copied and the result is returned.

    Args:
        log_json_path (str): Path of a log file with one JSON object per
            line.
        epoch (int): Epoch whose records are of interest.
        results_lut (list[str]): Metric keys to copy from the ``val`` record.

    Returns:
        dict | None: ``{'memory': ..., <metric>: ...}`` once a ``val`` record
            for ``epoch`` is found; ``None`` if the log holds no such record
            (callers use this to skip unfinished or unevaluated runs).
    """
    result_dict = dict()
    with open(log_json_path, 'r') as f:
        # Iterate lazily instead of materialising the whole log.
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline) are not valid JSON.
            if not line:
                continue
            log_line = json.loads(line)
            if 'mode' not in log_line:
                continue
            # .get() tolerates records that carry a mode but no epoch.
            if log_line.get('epoch') != epoch:
                continue

            mode = log_line['mode']
            if mode == 'train':
                if 'memory' in log_line:
                    result_dict['memory'] = log_line['memory']
            elif mode == 'val':
                result_dict.update({
                    key: log_line[key]
                    for key in results_lut if key in log_line
                })
                return result_dict
    # No validation record for the requested epoch.
    return None
def parse_args():
    """Parse the command-line arguments of the gathering script.

    Two positional string arguments are accepted: ``root`` (where the
    benchmarked models live) and ``out`` (where gathered models are stored).

    Returns:
        argparse.Namespace: Namespace with ``root`` and ``out`` attributes.
    """
    arg_parser = argparse.ArgumentParser(
        description='Gather benchmarked models')
    positional_args = (
        ('root', 'root path of benchmarked models to be gathered'),
        ('out', 'output path of gathered models to be stored'),
    )
    for arg_name, help_text in positional_args:
        arg_parser.add_argument(arg_name, type=str, help=help_text)
    return arg_parser.parse_args()
def _strip_suffix(text, suffix):
    """Return ``text`` without a trailing ``suffix`` (unchanged if absent)."""
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text


def _latest_file(pattern):
    """Return the last path (in sorted order) matching ``pattern``, or None."""
    matches = sorted(glob.glob(pattern))
    return matches[-1] if matches else None


def main():
    """Gather finished benchmark runs and publish them under the output dir.

    Steps:
      1. List the ``.py`` configs under ``./configs`` and keep those that have
         a matching experiment directory ``<root>/<config path>``.
      2. For each kept config, require the final ``epoch_<N>.pth`` checkpoint
         and log files, then read the final-epoch results from the latest
         ``*.log.json``.
      3. For each run with results, publish the processed checkpoint and copy
         its logs and config into ``<out>/<config path without .py>``.
      4. Dump a summary of all published models to ``<out>/model_info.json``.
    """
    args = parse_args()
    models_root = args.root
    models_out = args.out
    mmcv.mkdir_or_exist(models_out)

    # find all models in the root directory to be gathered
    raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True))

    # keep only the configs that have an experiment dir under models_root
    used_configs = [
        raw_config for raw_config in raw_configs
        if osp.exists(osp.join(models_root, raw_config))
    ]
    print(f'Find {len(used_configs)} models to be gathered')

    # find the final checkpoint and log files of each trained config
    # and parse the final performance
    model_infos = []
    for used_config in used_configs:
        exp_dir = osp.join(models_root, used_config)
        # check whether the experiment is finished
        final_epoch = get_final_epoch(used_config)
        final_model = 'epoch_{}.pth'.format(final_epoch)
        model_path = osp.join(exp_dir, final_model)

        # skip if the model is still training
        if not osp.exists(model_path):
            continue

        # get the latest logs; a run without logs cannot be reported, so it
        # is skipped instead of crashing on an empty glob result
        log_json_path = _latest_file(osp.join(exp_dir, '*.log.json'))
        log_txt_path = _latest_file(osp.join(exp_dir, '*.log'))
        if log_json_path is None or log_txt_path is None:
            print(f'Skip {used_config}: no log files found in {exp_dir}')
            continue

        cfg = mmcv.Config.fromfile('./configs/' + used_config)
        results_lut = cfg.evaluation.metric
        if not isinstance(results_lut, list):
            results_lut = [results_lut]
        # case when using VOC, the evaluation key is only 'mAP': keep such
        # keys as they are and append '_mAP' to all the others (filtering
        # them out would leave nothing to look up in the log)
        results_lut = [
            key if 'mAP' in key else key + '_mAP' for key in results_lut
        ]
        model_performance = get_final_results(log_json_path, final_epoch,
                                              results_lut)

        if model_performance is None:
            continue

        # the log file name up to its first '.' identifies the run
        model_time = osp.split(log_txt_path)[-1].split('.')[0]
        model_infos.append(
            dict(
                config=used_config,
                results=model_performance,
                epochs=final_epoch,
                model_time=model_time,
                log_json_path=osp.split(log_json_path)[-1]))

    # publish model for each checkpoint
    publish_model_infos = []
    for model in model_infos:
        # str.rstrip('.py') strips a *set* of characters and would also eat
        # trailing 'p' / 'y' letters of the config name; drop the suffix only
        model_publish_dir = osp.join(models_out,
                                     _strip_suffix(model['config'], '.py'))
        mmcv.mkdir_or_exist(model_publish_dir)

        model_name = osp.split(model['config'])[-1].split('.')[0]
        model_name += '_' + model['model_time']
        publish_model_path = osp.join(model_publish_dir, model_name)
        trained_model_path = osp.join(models_root, model['config'],
                                      'epoch_{}.pth'.format(model['epochs']))

        # convert model
        final_model_path = process_checkpoint(trained_model_path,
                                              publish_model_path)

        # copy log; the text log shares the json log's name minus '.json'
        shutil.copy(
            osp.join(models_root, model['config'], model['log_json_path']),
            osp.join(model_publish_dir, f'{model_name}.log.json'))
        shutil.copy(
            osp.join(models_root, model['config'],
                     _strip_suffix(model['log_json_path'], '.json')),
            osp.join(model_publish_dir, f'{model_name}.log'))

        # copy config to guarantee reproducibility
        config_path = model['config']
        if 'configs' not in config_path:
            config_path = osp.join('configs', config_path)
        target_config_name = osp.split(config_path)[-1]
        shutil.copy(config_path,
                    osp.join(model_publish_dir, target_config_name))

        model['model_path'] = final_model_path
        publish_model_infos.append(model)

    models = dict(models=publish_model_infos)
    print(f'Totally gathered {len(publish_model_infos)} models')
    mmcv.dump(models, osp.join(models_out, 'model_info.json'))
# Run the gathering pipeline only when this file is executed as a script,
# so that importing it has no side effects.
if __name__ == '__main__':
    main()
|