File size: 2,421 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import os.path as osp
import torch


def main():
    """An auxiliary script for converting a checkpoint file (`checkpoint.pt`) into a support sets (`support_sets.pt`)
    and a reconstructor (`reconstructor.pt`) weights files.

    Options:
        ================================================================================================================
        --exp : set experiment's wip model dir, as created by `train.py`, i.e., it should contain a sub-directory
                `models/` with a checkpoint file (`checkpoint.pt`). Checkpoint file contains the weights of the
                 support sets and the reconstructor at an intermediate stage of training (after a given iteration).
        ================================================================================================================
    """
    parser = argparse.ArgumentParser(description="Convert a checkpoint file into a support sets and a reconstructor "
                                                 "weights files")
    parser.add_argument('--exp', type=str, required=True, help="set experiment's model dir (created by `train.py`)")

    # Parse given arguments
    args = parser.parse_args()

    # Check structure of `args.exp`
    if not osp.isdir(args.exp):
        raise NotADirectoryError("Invalid given directory: {}".format(args.exp))
    models_dir = osp.join(args.exp, 'models')
    if not osp.isdir(models_dir):
        raise NotADirectoryError("Invalid models directory: {}".format(models_dir))
    checkpoint_file = osp.join(models_dir, 'checkpoint.pt')
    if not osp.isfile(checkpoint_file):
        raise FileNotFoundError("Checkpoint file not found: {}".format(checkpoint_file))

    print("#. Convert checkpoint file into support sets and reconstructor weight files...")

    # Load checkpoint file
    checkpoint_dict = torch.load(checkpoint_file)
    print("  \\__Checkpoint dictionary: {}".format(checkpoint_dict.keys()))

    # Get checkpoint iteration
    checkpoint_iter = checkpoint_dict['iter']
    print("  \\__Checkpoint iteration: {}".format(checkpoint_iter))

    # Save latent support sets (LSS) weights file
    print("  \\__Save checkpoint latent support sets LSS weights file...")
    torch.save(checkpoint_dict['latent_support_sets'],
               osp.join(models_dir, 'latent_support_sets-{:07d}.pt'.format(checkpoint_iter)))


if __name__ == '__main__':
    main()