File size: 2,610 Bytes
c75c928 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""
Minimal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic"
Maximal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train
"""
import json
import torch
from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.loggers import WandbLogger
from dataloader import SidewalkSegmentationDataLoader
from model import SidewalkSegmentationModel
def _load_id2label(path: str) -> dict:
    """Read an id-to-label JSON object and convert its keys to ``int``.

    JSON object keys are always strings, so a file containing
    ``{"0": "road"}`` yields ``{0: "road"}``.

    Args:
        path: Path to a UTF-8 JSON file holding a single object whose keys
            are integer class ids written as strings.

    Returns:
        A dict mapping each integer class id to its label value.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON or a key is not an
            integer string.
    """
    # ``with`` closes the handle deterministically instead of leaving it to
    # garbage collection; JSON is UTF-8 by spec, so don't depend on locale.
    with open(path, "r", encoding="utf-8") as handle:
        raw = json.load(handle)
    return {int(key): value for key, value in raw.items()}


def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
    id2label_path: str = "id2label.json",
):
    """Train a ``SidewalkSegmentationModel`` with PyTorch Lightning.

    Logs to the ``sidewalk-segmentation`` Weights & Biases project. Keeps the
    single best checkpoint (highest ``val_mean_iou``) under ``checkpoints/``.
    Stops early after 5 validation checks without improvement in
    ``val_mean_iou``, with an upper bound of 200 epochs.

    Args:
        hub_dir: Dataset identifier forwarded to
            ``SidewalkSegmentationDataLoader`` (e.g.
            ``"segments/sidewalk-semantic"``).
        batch_size: Forwarded to the data module as ``batch_size``.
        learning_rate: Forwarded to the model as ``learning_rate``.
        model_flavor: Forwarded to the model as ``model_flavor``.
        seed: Global seed passed to ``seed_everything``.
        split: Forwarded to the data module as ``split``.
        id2label_path: JSON file mapping class ids (string keys) to labels.
            Defaults to ``id2label.json`` in the working directory, which was
            the previously hard-coded location.
    """
    seed_everything(seed)
    logger = WandbLogger(project="sidewalk-segmentation")
    # One GPU when CUDA is visible, otherwise CPU (gpus=0).
    gpu_value = 1 if torch.cuda.is_available() else 0

    id2label = _load_id2label(id2label_path)
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    # NOTE(review): Trainer.fit also invokes the data module's setup hooks;
    # this eager call is presumably intentional — confirm before removing.
    data_module.setup()

    # Both callbacks track the same metric so "best" and "stop" agree.
    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )
    trainer = Trainer(
        max_epochs=200,
        # NOTE(review): ``progress_bar_refresh_rate`` and ``gpus`` were
        # deprecated/removed in later pytorch_lightning releases — confirm
        # against the pinned version before upgrading.
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: --hub_dir is mandatory, the rest are optional
    # and default to the same values as main()'s keyword parameters.
    cli = argparse.ArgumentParser()
    cli.add_argument("--hub_dir", type=str, required=True)
    optional_flags = (
        ("--batch_size", int, 32),
        ("--learning_rate", float, 6e-5),
        ("--model_flavor", int, 0),
        ("--seed", int, 42),
        ("--split", str, "train"),
    )
    for flag, flag_type, flag_default in optional_flags:
        cli.add_argument(flag, type=flag_type, default=flag_default)

    # The parsed Namespace holds exactly hub_dir, batch_size, learning_rate,
    # model_flavor, seed and split — the same names main() accepts.
    main(**vars(cli.parse_args()))
|