File size: 2,432 Bytes
86694c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse

from models.launch import train_model
from models.spectrogram_cnn import get_model as get_spectrogram


def standard_run_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for a training run.

    The script's ``__main__`` block expands the parsed namespace with
    ``vars(...)`` and passes it as keyword arguments to ``train_model``.
    Each option's destination name therefore becomes a keyword argument
    (note that ``--model`` is stored as ``model_name``).

    Returns:
        A fresh ``argparse.ArgumentParser`` with every run option
        registered; nothing has been parsed yet.
    """
    parser = argparse.ArgumentParser(
        description="Setup and train a model, storing the output"
    )
    parser.add_argument(
        "--model",
        dest="model_name",
        type=str,
        choices=["C1", "C2", "C3", "C4", "C5", "C6", "C6XL", "e2e"],
        default="e2e",
        help="Model architecture to run",
    )
    parser.add_argument(
        "--dataset_name",
        default="InverSynth",
        help='Name of the dataset to use - other filenames are generated from this. If you have a file "modelname_data.hdf5", put in "modelname"',
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="How many epochs to run"
    )
    parser.add_argument(
        "--dataset_dir",
        default="test_datasets",
        help="Directory full of datasets to use",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to store the final model and history",
    )
    # Both default to None; presumably None means "derive from dataset_name"
    # downstream — TODO confirm in train_model.
    parser.add_argument(
        "--dataset_file", default=None, help="Specify an exact dataset file to use"
    )
    parser.add_argument(
        "--parameters_file",
        default=None,
        help="Specify an exact parameters file to use",
    )
    parser.add_argument(
        "--data_format",
        type=str,
        choices=["channels_last", "channels_first"],
        default="channels_last",
        help="Image data format for Keras. If CPU only, has to be channels_last",
    )
    # No default is set here, so run_name is None unless given; the
    # "dataset_name + model" fallback in the help text is presumably
    # applied downstream — TODO confirm.
    parser.add_argument(
        "--run_name",
        type=str,
        help="Name to save the output under. Defaults to dataset_name + model",
    )
    # store_true is argparse's built-in for store_const(const=True, default=False).
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Look for a checkpoint file to resume from",
    )
    return parser


if __name__ == "__main__":
    print("Starting model runner")

    # Parse the command line and convert the resulting namespace into a plain
    # dict so its entries can be forwarded as keyword arguments.
    run_config = vars(standard_run_parser().parse_args())
    print("Parsed arguments")

    # The spectrogram CNN builder is always used as the model factory here;
    # the chosen architecture travels inside run_config as "model_name"
    # (presumably consumed by train_model / get_model — TODO confirm).
    train_model(model_callback=get_spectrogram, **run_config)