# segformer-sidewalk / training_loop.py
# chainyo's picture
# create training loop
# c75c928
# raw
# history blame
# 2.61 kB
"""
Minimal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic"
Maximal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train
"""
import json
import torch
from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.loggers import WandbLogger
from dataloader import SidewalkSegmentationDataLoader
from model import SidewalkSegmentationModel
def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
) -> None:
    """Train the sidewalk segmentation model with PyTorch Lightning.

    Seeds the random number generators, reads the label mapping from
    ``id2label.json`` in the current working directory, builds the model and
    the data module, and fits the model with checkpointing and early stopping
    that both track ``val_mean_iou`` (higher is better).

    Args:
        hub_dir: Forwarded to ``SidewalkSegmentationDataLoader`` as
            ``hub_dir`` (e.g. ``"segments/sidewalk-semantic"``).
        batch_size: Forwarded to the data module as ``batch_size``.
        learning_rate: Forwarded to the model as ``learning_rate``.
        model_flavor: Forwarded to the model as ``model_flavor``; its meaning
            is defined by ``SidewalkSegmentationModel``.
        seed: Seed passed to ``seed_everything``.
        split: Forwarded to the data module as ``split``.

    Raises:
        FileNotFoundError: If ``id2label.json`` does not exist in the current
            working directory.
        ValueError: If ``id2label.json`` is not valid JSON or one of its keys
            cannot be converted to ``int``.
    """
    seed_everything(seed)
    logger = WandbLogger(project="sidewalk-segmentation")

    # Use a single GPU when CUDA is available, otherwise run on CPU.
    gpu_value = 1 if torch.cuda.is_available() else 0

    # Close the file deterministically and decode it as UTF-8 (the JSON
    # default) instead of relying on the platform's locale encoding.
    with open("id2label.json", "r", encoding="utf-8") as id2label_fp:
        id2label_file = json.load(id2label_fp)
    # JSON object keys are always strings; convert them to integer ids.
    id2label = {int(k): v for k, v in id2label_file.items()}
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    data_module.setup()

    # Keep only the single best checkpoint according to validation mean IoU.
    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    # Stop once validation mean IoU has not improved for 5 checks.
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )

    # NOTE(review): `progress_bar_refresh_rate` and `gpus` look specific to
    # older pytorch_lightning releases -- confirm against the pinned version
    # before upgrading the dependency.
    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: one required flag plus optional overrides
    # whose defaults mirror the defaults of `main`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--hub_dir", type=str, required=True)
    for flag, converter, fallback in (
        ("--batch_size", int, 32),
        ("--learning_rate", float, 6e-5),
        ("--model_flavor", int, 0),
        ("--seed", int, 42),
        ("--split", str, "train"),
    ):
        cli.add_argument(flag, type=converter, default=fallback)

    # The parsed namespace holds exactly the keyword arguments `main` takes.
    main(**vars(cli.parse_args()))