File size: 2,421 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import argparse
import os.path as osp
import torch
def main():
    """Extract the latent support sets (LSS) weights from an intermediate training checkpoint.

    An auxiliary script that reads the checkpoint file `<exp>/models/checkpoint.pt` and stores the latent
    support sets weights it contains in a stand-alone file `<exp>/models/latent_support_sets-<iter>.pt`,
    where `<iter>` is the checkpoint's training iteration, zero-padded to 7 digits.

    Options:
    ================================================================================================================
    --exp : set experiment's wip model dir, as created by `train.py`, i.e., it should contain a sub-directory
            `models/` with a checkpoint file (`checkpoint.pt`). The checkpoint is expected to be a dictionary that
            holds (at least) the iteration it was taken at (`iter`) and the latent support sets weights
            (`latent_support_sets`).
    ================================================================================================================

    Raises:
        NotADirectoryError: if the `--exp` directory or its `models/` sub-directory does not exist.
        FileNotFoundError: if `models/checkpoint.pt` does not exist.
        KeyError: if the checkpoint dictionary lacks the `iter` or `latent_support_sets` entry.
    """
    parser = argparse.ArgumentParser(description="Convert a checkpoint file into a latent support sets weights file")
    parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)")

    # Parse given arguments
    args = parser.parse_args()

    # Check structure of `args.exp`
    if not osp.isdir(args.exp):
        raise NotADirectoryError("Invalid given directory: {}".format(args.exp))
    models_dir = osp.join(args.exp, 'models')
    if not osp.isdir(models_dir):
        raise NotADirectoryError("Invalid models directory: {}".format(models_dir))
    checkpoint_file = osp.join(models_dir, 'checkpoint.pt')
    if not osp.isfile(checkpoint_file):
        raise FileNotFoundError("Checkpoint file not found: {}".format(checkpoint_file))

    print("#. Convert checkpoint file into latent support sets weights file...")

    # Load checkpoint file. All tensors are mapped to CPU so that a checkpoint written during a GPU run can also be
    # converted on a machine without (the same) CUDA devices; consequently, the extracted weights are stored on CPU.
    checkpoint_dict = torch.load(checkpoint_file, map_location='cpu')
    print(" \\__Checkpoint dictionary: {}".format(checkpoint_dict.keys()))

    # Fail early with an informative message (same exception type as a plain dict lookup would raise)
    missing_keys = [key for key in ('iter', 'latent_support_sets') if key not in checkpoint_dict]
    if missing_keys:
        raise KeyError("Checkpoint file {} lacks required key(s): {}".format(checkpoint_file,
                                                                           ', '.join(missing_keys)))

    # Get checkpoint iteration (used to tag the output file name)
    checkpoint_iter = checkpoint_dict['iter']
    print(" \\__Checkpoint iteration: {}".format(checkpoint_iter))

    # Save latent support sets (LSS) weights file
    print(" \\__Save checkpoint latent support sets LSS weights file...")
    torch.save(checkpoint_dict['latent_support_sets'],
               osp.join(models_dir, 'latent_support_sets-{:07d}.pt'.format(checkpoint_iter)))
if __name__ == '__main__':
    # Entry point: perform the conversion only when run as a script, never on import.
    main()
|