|
""" |
|
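Train a SidewalkSegmentationModel with PyTorch Lightning.

Expects an ``id2label.json`` label mapping in the current working directory.

Minimal command:

    python training_loop.py --hub_dir "segments/sidewalk-semantic"

Full command (every option, shown with its default value):

    python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train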
|
""" |
|
|
|
import json |
|
import torch |
|
|
|
from pytorch_lightning import Trainer, callbacks, seed_everything |
|
from pytorch_lightning.loggers import WandbLogger |
|
|
|
from dataloader import SidewalkSegmentationDataLoader |
|
from model import SidewalkSegmentationModel |
|
|
|
|
|
def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
) -> None:
    """Train a SidewalkSegmentationModel and keep the best checkpoint.

    The label mapping is read from ``id2label.json`` (resolved against the
    current working directory). Metrics are logged to the
    "sidewalk-segmentation" W&B project. The single best checkpoint, ranked
    by ``val_mean_iou`` (higher is better), is written to ``checkpoints/``.
    Training runs for at most 200 epochs. Early stopping also monitors
    ``val_mean_iou`` with a patience of 5.

    Args:
        hub_dir: Dataset identifier forwarded to
            ``SidewalkSegmentationDataLoader``
            (e.g. "segments/sidewalk-semantic").
        batch_size: Batch size forwarded to the data module.
        learning_rate: Learning rate forwarded to the model.
        model_flavor: Model variant index forwarded to the model.
        seed: Global seed passed to ``seed_everything``.
        split: Dataset split name forwarded to the data module.
    """
    seed_everything(seed)

    logger = WandbLogger(project="sidewalk-segmentation")

    # Use one GPU when CUDA is available, otherwise fall back to CPU (0 GPUs).
    gpu_value = 1 if torch.cuda.is_available() else 0

    # Use a context manager so the file handle is closed deterministically.
    # Use an explicit encoding so non-ASCII labels do not depend on the
    # platform's locale encoding.
    with open("id2label.json", "r", encoding="utf-8") as id2label_fh:
        id2label_file = json.load(id2label_fh)

    # JSON object keys are always strings, so convert them back to int ids.
    id2label = {int(k): v for k, v in id2label_file.items()}
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )

    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    # NOTE(review): depending on the Lightning version, Trainer.fit() may call
    # setup() again; confirm SidewalkSegmentationDataLoader.setup is idempotent.
    data_module.setup()

    # Keep only the single best checkpoint by validation mean IoU.
    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )

    # Stop when val_mean_iou has not improved for 5 consecutive checks.
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )

    # NOTE(review): `progress_bar_refresh_rate` was removed from Trainer in
    # Lightning 1.7 and `gpus` in 2.0. This call assumes an older pinned
    # release; confirm the project's pytorch_lightning version.
    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )

    trainer.fit(model, data_module)
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Use the module docstring (with its example commands) as the --help
    # description. RawDescriptionHelpFormatter keeps its line layout intact.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--hub_dir",
        type=str,
        required=True,
        help='Dataset identifier, e.g. "segments/sidewalk-semantic".',
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (default: %(default)s).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=6e-5,
        help="Learning rate (default: %(default)s).",
    )
    parser.add_argument(
        "--model_flavor",
        type=int,
        default=0,
        help="Model variant index passed to the model (default: %(default)s).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Global random seed (default: %(default)s).",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Dataset split to load (default: %(default)s).",
    )
    args = parser.parse_args()

    main(
        hub_dir=args.hub_dir,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        model_flavor=args.model_flavor,
        seed=args.seed,
        split=args.split,
    )
|
|
|
|