# Spaces: Sleeping (web-page banner residue captured with the file; not part of the program)
import argparse

from models.launch import train_model
from models.spectrogram_cnn import get_model as get_spectrogram
def standard_run_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for a model training run.

    The parsed namespace exposes these attributes: ``model_name``,
    ``dataset_name``, ``epochs``, ``dataset_dir``, ``output_dir``,
    ``dataset_file``, ``parameters_file``, ``data_format``, ``run_name``
    and ``resume``.

    Returns:
        A configured ``argparse.ArgumentParser``. Nothing is parsed here;
        the caller invokes ``parse_args`` itself.
    """
    parser = argparse.ArgumentParser(
        description="Setup and train a model, storing the output"
    )

    # Which architecture to train; exposed as ``model_name`` rather than ``model``.
    parser.add_argument(
        "--model",
        dest="model_name",
        type=str,
        choices=["C1", "C2", "C3", "C4", "C5", "C6", "C6XL", "e2e"],
        default="e2e",
        help="Model architecture to run",
    )

    # Dataset selection and training length.
    parser.add_argument(
        "--dataset_name",
        default="InverSynth",
        help='Name of the dataset to use - other filenames are generated from this. If you have a file "modelname_data.hdf5", put in "modelname"',
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="How many epochs to run"
    )

    # Input and output locations.
    parser.add_argument(
        "--dataset_dir",
        default="test_datasets",
        help="Directory full of datasets to use",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        help="Directory to store the final model and history",
    )

    # Optional explicit file overrides; both default to None when omitted.
    parser.add_argument(
        "--dataset_file", default=None, help="Specify an exact dataset file to use"
    )
    parser.add_argument(
        "--parameters_file",
        default=None,
        help="Specify an exact parameters file to use",
    )

    parser.add_argument(
        "--data_format",
        type=str,
        choices=["channels_last", "channels_first"],
        default="channels_last",
        help="Image data format for Keras. If CPU only, has to be channels_last",
    )

    # No default is set, so ``run_name`` is None unless given on the command line.
    parser.add_argument(
        "--run_name",
        type=str,
        help="Name to save the output under. Defaults to dataset_name + model",
    )

    # Boolean flag: False unless ``--resume`` is present.
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Look for a checkpoint file to resume from",
    )
    return parser
if __name__ == "__main__":
    print("Starting model runner")

    # Parse the command line and flatten the namespace into a dict so every
    # option can be forwarded to train_model as a keyword argument.
    parser = standard_run_parser()
    args = parser.parse_args()
    setup = vars(args)
    print("Parsed arguments")

    # NOTE(review): the spectrogram CNN factory is used for every --model
    # choice, including "e2e"; presumably it picks the architecture from the
    # forwarded model_name - confirm against models.spectrogram_cnn.
    model_callback = get_spectrogram

    # Run training with the chosen factory and all parsed options.
    train_model(model_callback=model_callback, **setup)