|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import argparse |
|
import glob |
|
|
|
import yaml |
|
import torch |
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of the model-averaging script.

    Options:
        --dst_model: required; path the averaged model is written to.
        --src_path:  required; directory holding the source checkpoints.
        --val_best:  flag; select checkpoints by lowest validation loss.
        --num:       int, default 5; number of checkpoints to average.

    Returns:
        argparse.Namespace: the parsed arguments, which are also printed
        to stdout.
    """
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    # The help text used to be a copy of the --dst_model help
    # ('averaged model'), which did not describe this flag at all.
    parser.add_argument('--val_best',
                        action="store_true",
                        help='select the checkpoints with the lowest '
                             'validation loss')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')

    args = parser.parse_args()
    print(args)
    return args
|
|
|
|
|
def _best_checkpoints_by_val_loss(src_path, num):
    """Return the paths of the ``num`` checkpoints with the lowest loss.

    Every ``*.yaml`` file in ``src_path`` whose basename does not start with
    ``train`` or ``init`` is read, and its ``loss_dict.loss``, ``epoch``,
    ``step`` and ``tag`` entries are collected. The records are sorted by
    ascending loss and the first ``num`` are mapped to
    ``<src_path>/epoch_<epoch>_whole.pt``.
    """
    val_scores = []
    for yaml_path in glob.glob('{}/*.yaml'.format(src_path)):
        # train*/init* YAML files are excluded from the ranking.
        if os.path.basename(yaml_path).startswith(('train', 'init')):
            continue
        with open(yaml_path, 'r') as f:
            # BaseLoader keeps every scalar as a string, hence the casts.
            dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
        loss = float(dic_yaml['loss_dict']['loss'])
        epoch = int(dic_yaml['epoch'])
        step = int(dic_yaml['step'])
        tag = dic_yaml['tag']
        val_scores.append([epoch, step, loss, tag])

    sorted_val_scores = sorted(val_scores, key=lambda x: x[2])
    print("best val (epoch, step, loss, tag) = " +
          str(sorted_val_scores[:num]))
    # NOTE(review): the path depends on the epoch only, so two records that
    # share an epoch map to the same file - confirm this is intended.
    return [
        src_path + '/epoch_{}_whole.pt'.format(score[0])
        for score in sorted_val_scores[:num]
    ]


def _latest_checkpoints(src_path, num):
    """Return the ``num`` most recently modified ``epoch_*_whole.pt`` files."""
    candidates = sorted(glob.glob('{}/epoch_*_whole.pt'.format(src_path)),
                        key=os.path.getmtime)
    return candidates[-num:]


def main():
    """Average several checkpoints and save the result to ``--dst_model``.

    With ``--val_best`` the ``--num`` checkpoints with the lowest validation
    loss are used. Without it the ``--num`` most recently modified
    ``epoch_*_whole.pt`` files in ``--src_path`` are used; previously this
    case crashed with a ``NameError`` because ``path_list`` was never set.

    Each selected checkpoint is loaded on CPU, summed key by key, divided by
    ``--num`` and written with ``torch.save``.

    Raises:
        ValueError: if ``--num`` is smaller than 1, or if the number of
            selected checkpoints differs from ``--num``.
    """
    args = get_args()
    num = args.num
    if num < 1:
        raise ValueError(
            '--num must be a positive integer, got {}'.format(num))

    if args.val_best:
        path_list = _best_checkpoints_by_val_loss(args.src_path, num)
    else:
        path_list = _latest_checkpoints(args.src_path, num)
    print(path_list)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(path_list) != num:
        raise ValueError(
            'expected {} checkpoints to average but found {} in {}'.format(
                num, len(path_list), args.src_path))

    avg = {}
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        # Assumes every value in the checkpoint supports .clone() and +=
        # (i.e. is a tensor) - TODO confirm against the checkpoint writer.
        for k, v in states.items():
            if k not in avg:
                # Clone so the in-place accumulation below never aliases
                # the loaded checkpoint.
                avg[k] = v.clone()
            else:
                avg[k] += v

    # true_divide gives a floating-point result even for integer tensors.
    for k in avg:
        avg[k] = torch.true_divide(avg[k], num)

    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)
|
|
|
|
|
# Run the averaging only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|