import os
import yaml
import munch
import torch
import ignite
import monai
import shutil
import pandas as pd
from typing import Union, List, Callable
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from monai.handlers import (
CheckpointSaver,
StatsHandler,
TensorBoardStatsHandler,
TensorBoardImageHandler,
ValidationHandler,
from_engine,
MeanDice,
EarlyStopHandler,
MetricLogger,
MetricsSaver
)
from .data import segmentation_dataloaders
from .model import get_model
from .optimizer import get_optimizer
from .loss import get_loss
from .transforms import get_val_post_transforms
from .utils import USE_AMP
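# The `config` object passed around in this module is expected to behave like a
# `munch.Munch` loaded from a YAML file. Only the fields read directly in this
# file are sketched below; the exact layout and the example values are
# assumptions, and further fields are consumed by `get_model`, `get_optimizer`,
# `get_loss` and `segmentation_dataloaders`:
#
#   run_id: experiment_001            # used as a directory name and in checkpoint file names
#   device: cuda:0                    # anything accepted by torch.device
#   log_dir: logs/experiment_001      # CSV and TensorBoard logs
#   model_dir: models/experiment_001  # checkpoints
#   out_dir: output/experiment_001    # metric reports and saved predictions
#   data:
#     image_cols: [image]             # first entry is used to look up the meta dict
#   training:
#     max_epochs: 100
#     early_stopping_patience: 10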
def loss_logger(engine):
"write loss and lr of each iteration/epoch to file"
iteration=engine.state.iteration
epoch=engine.state.epoch
loss=[o['loss'] for o in engine.state.output]
loss=sum(loss)/len(loss)
lr=engine.optimizer.param_groups[0]['lr']
log_file=os.path.join(engine.config.log_dir, 'train_logs.csv')
if not os.path.exists(log_file):
with open(log_file, 'w+') as f:
f.write('iteration,epoch,loss,lr\n')
with open(log_file, 'a') as f:
f.write(f'{iteration},{epoch},{loss},{lr}\n')
def metric_logger(engine):
"write `metrics` after each epoch to file"
if engine.state.epoch > 1: # only the key metric is calculated in the 1st epoch, needs fix
metric_names=[k for k in engine.state.metrics.keys()]
metrics=[str(engine.state.metrics[mn]) for mn in metric_names]
log_file=os.path.join(engine.config.log_dir, 'metric_logs.csv')
if not os.path.exists(log_file):
with open(log_file, 'w+') as f:
f.write(','.join(metric_names) + '\n')
with open(log_file, 'a') as f:
f.write(','.join(metrics) + '\n')
def pred_logger(engine):
"save `pred` each time metric improves"
epoch=engine.state.epoch
root = os.path.join(engine.config.out_dir, 'preds')
if not os.path.exists(root):
os.makedirs(root)
torch.save(
engine.state.output[0]['label'],
os.path.join(root, 'label.pt')
)
torch.save(
engine.state.output[0]['image'],
os.path.join(root, 'image.pt')
)
if epoch==engine.state.best_metric_epoch:
torch.save(
engine.state.output[0]['pred'],
os.path.join(root, f'pred_epoch_{epoch}.pt')
)
def get_val_handlers(
network: torch.nn.Module,
config: dict
) -> list:
"""Create default handlers for model validation
Args:
network:
nn.Module subclass, the model to train
config:
global configuration object (munch-style); provides `log_dir`, `model_dir` and `run_id`
Returns:
a list of default handlers for validation: [
StatsHandler:
Write validation metrics and best predictions to disk via `metric_logger` and `pred_logger`
TensorBoardStatsHandler:
Write validation metrics to `config.log_dir` for visualization with TensorBoard
TensorBoardImageHandler:
Write an example image, label and prediction from validation to `config.log_dir`
CheckpointSaver:
Save best model to `config.model_dir`
]
"""
val_handlers=[
StatsHandler(
tag_name="metric_logger",
epoch_print_logger=metric_logger,
output_transform=lambda x: None
),
StatsHandler(
tag_name="pred_logger",
epoch_print_logger=pred_logger,
output_transform=lambda x: None
),
TensorBoardStatsHandler(
log_dir=config.log_dir,
# tag_name="val_mean_dice",
output_transform=lambda x: None
),
TensorBoardImageHandler(
log_dir=config.log_dir,
batch_transform=from_engine(["image", "label"]),
output_transform=from_engine(["pred"]),
),
CheckpointSaver(
save_dir=config.model_dir,
save_dict={f"network_{config.run_id}": network},
save_key_metric=True
),
]
return val_handlers
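# A minimal sketch of how this is consumed (`my_unet` is a placeholder): the
# returned list is passed as `val_handlers` to `get_evaluator` /
# `monai.engines.SupervisedEvaluator`, e.g.
#   val_handlers = get_val_handlers(network=my_unet, config=config)
# `config` must provide `log_dir`, `model_dir` and `run_id`.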
def get_train_handlers(
evaluator: monai.engines.SupervisedEvaluator,
config: dict
) -> list:
"""Create default handlers for model training
Args:
evaluator: an engine of type `monai.engines.SupervisedEvaluator` for evaluations
every epoch
config:
global configuration object (munch-style); provides `log_dir`
Returns:
list of default handlers for training: [
ValidationHandler:
Allows model validation every epoch
StatsHandler:
Log the training loss to stdout and write per-iteration loss/lr to `config.log_dir` via `loss_logger`
TensorBoardStatsHandler:
Write the training loss to `config.log_dir` for visualization with TensorBoard
]
"""
train_handlers=[
ValidationHandler(
validator=evaluator,
interval=1,
epoch_level=True
),
StatsHandler(
tag_name="train_loss",
output_transform=from_engine(
["loss"],
first=True
)
),
StatsHandler(
tag_name='loss_logger',
iteration_print_logger=loss_logger
),
TensorBoardStatsHandler(
log_dir=config.log_dir,
tag_name="train_loss",
output_transform=from_engine(
["loss"],
first=True
),
)
]
return train_handlers
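# A minimal sketch of the wiring order (all names are helpers defined in this
# module): the evaluator has to exist before the training handlers, e.g.
#   evaluator = get_evaluator(config, device, network, val_loader,
#                             val_post_transforms, val_handlers)
#   train_handlers = get_train_handlers(evaluator, config=config)
# `SegmentationTrainer.__init__` below follows exactly this order.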
def get_evaluator(
config: dict,
device: torch.device,
network: torch.nn.Module,
val_data_loader: monai.data.dataloader.DataLoader,
val_post_transforms: monai.transforms.compose.Compose,
val_handlers: Union[Callable, List]=get_val_handlers
) -> monai.engines.SupervisedEvaluator:
"""Create default evaluator for training of a segmentation model
Args:
config:
global configuration object (munch-style); attached to the returned evaluator as `evaluator.config`
device:
torch.device on which model and engine run
network:
nn.Module subclass, the model to train
val_data_loader:
Validation data loader, `monai.data.dataloader.DataLoader` subclass
val_post_transforms:
function to create transforms OR composed transforms
val_handlers:
function to create handlers OR list of handlers
Returns:
default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
"""
if callable(val_handlers): val_handlers=val_handlers(network, config)
evaluator=monai.engines.SupervisedEvaluator(
device=device,
val_data_loader=val_data_loader,
network=network,
inferer=monai.inferers.SlidingWindowInferer(
roi_size=(96, 96, 96),
sw_batch_size=4,
overlap=0.5
),
postprocessing=val_post_transforms,
key_val_metric={
"val_mean_dice": MeanDice(
include_background=False,
output_transform=from_engine(
["pred", "label"]
)
)
},
val_handlers=val_handlers,
# if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
amp=USE_AMP,
)
evaluator.config=config
return evaluator
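# A hedged sketch of a standalone validation run without the full
# `SegmentationTrainer`; `config` is a munch-style config as described near the
# imports, everything else uses helpers from this package:
#
#   train_loader, val_loader = segmentation_dataloaders(
#       config=config, train=True, valid=True, test=False
#   )
#   device = torch.device(config.device)
#   network = get_model(config=config).to(device)
#   evaluator = get_evaluator(
#       config=config,
#       device=device,
#       network=network,
#       val_data_loader=val_loader,
#       val_post_transforms=get_val_post_transforms(config=config),
#       val_handlers=get_val_handlers(network, config=config),
#   )
#   evaluator.run()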
class SegmentationTrainer(monai.engines.SupervisedTrainer):
"Default Trainer für supervised segmentation task"
def __init__(self,
config: dict,
progress_bar: bool=True,
early_stopping: bool=True,
metrics: list=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
save_latest_metrics: bool=True
):
self.config=config
self._prepare_dirs()
self.config.device=torch.device(self.config.device)
train_loader, val_loader=segmentation_dataloaders(
config=config,
train=True,
valid=True,
test=False
)
network=get_model(config=config).to(config.device)
optimizer=get_optimizer(
network,
config=config
)
loss_fn=get_loss(config=config)
val_post_transforms=get_val_post_transforms(config=config)
val_handlers=get_val_handlers(
network,
config=config
)
self.evaluator=get_evaluator(
config=config,
device=config.device,
network=network,
val_data_loader=val_loader,
val_post_transforms=val_post_transforms,
val_handlers=val_handlers,
)
train_handlers=get_train_handlers(
self.evaluator,
config=config
)
super().__init__(
device=config.device,
max_epochs=self.config.training.max_epochs,
train_data_loader=train_loader,
network=network,
optimizer=optimizer,
loss_function=loss_fn,
inferer=monai.inferers.SimpleInferer(),
train_handlers=train_handlers,
amp=USE_AMP,
)
if early_stopping: self._add_early_stopping()
if progress_bar: self._add_progress_bars()
self.schedulers=[]
# add different metrics dynamically
for m in metrics:
getattr(monai.handlers, m)(
include_background=False,
reduction="mean",
output_transform=from_engine(
["pred", "label"]
)
).attach(self.evaluator, m)
self._add_metrics_logger()
# add eval loss to metrics
self._add_eval_loss()
if save_latest_metrics: self._add_metrics_saver()
def _prepare_dirs(self)->None:
# create run_id, copy config file for reproducibility
os.makedirs(self.config.run_id, exist_ok=True)
with open(
os.path.join(
self.config.run_id,
'config.yaml'
), 'w+') as f:
f.write(yaml.safe_dump(self.config))
# delete old log_dir
if os.path.exists(self.config.log_dir):
shutil.rmtree(self.config.log_dir)
def _add_early_stopping(self) -> None:
early_stopping=EarlyStopHandler(
patience=self.config.training.early_stopping_patience,
min_delta=1e-4,
score_function=lambda x: x.state.metrics[x.state.key_metric_name],
trainer=self
)
self.evaluator.add_event_handler(
ignite.engine.Events.COMPLETED,
early_stopping
)
def _add_metrics_logger(self) -> None:
self.metric_logger=MetricLogger(
evaluator=self.evaluator
)
self.metric_logger.attach(self)
def _add_progress_bars(self) -> None:
trainer_pbar=ProgressBar()
evaluator_pbar=ProgressBar(
colour='green'
)
trainer_pbar.attach(
self,
output_transform=lambda output:{
'loss': torch.tensor(
[x['loss'] for x in output]
).mean()
}
)
evaluator_pbar.attach(self.evaluator)
def _add_metrics_saver(self) -> None:
metric_saver=MetricsSaver(
save_dir=self.config.out_dir,
metric_details='*',
batch_transform=self._get_meta_dict,
delimiter=','
)
metric_saver.attach(self.evaluator)
def _add_eval_loss(self)->None:
# TODO improve by adding this to val handlers
eval_loss_handler=ignite.metrics.Loss(
loss_fn=self.loss_function,
output_transform=lambda output: (
output[0]['pred'].unsqueeze(0), # add batch dim
output[0]['label'].argmax(0, keepdim=True).unsqueeze(0) # reverse one-hot, add batch dim
)
)
eval_loss_handler.attach(self.evaluator, 'eval_loss')
def _get_meta_dict(self, batch) -> list:
"Get dict of metadata from engine. Needed as `batch_transform`"
image_cols=self.config.data.image_cols
image_name=image_cols[0] if isinstance(image_cols, list) else image_cols
key=f'{image_name}_meta_dict'
return [item[key] for item in batch]
def load_checkpoint(self, checkpoint=None):
if not checkpoint:
# get name of last checkpoint
checkpoint = os.path.join(
self.config.model_dir,
f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
)
self.network.load_state_dict(
torch.load(checkpoint)
)
def run(self, try_resume_from_checkpoint=True) -> None:
"""Run training, if `try_resume_from_checkpoint` tries to
load previous checkpoint stored at `self.config.model_dir`
"""
if try_resume_from_checkpoint:
checkpoints = [
os.path.join(
self.config.model_dir,
checkpoint_name
) for checkpoint_name in os.listdir(
self.config.model_dir
) if self.config.run_id in checkpoint_name
]
try:
checkpoint = sorted(checkpoints)[-1]
self.load_checkpoint(checkpoint)
print(f"resuming from previous checkpoint at {checkpoint}")
except Exception: pass # no usable checkpoint found, train from scratch
# train the model
super().run()
# make metrics and losses more accessible
self.loss={
"iter": [_iter for _iter, _ in self.metric_logger.loss],
"loss": [_loss for _, _loss in self.metric_logger.loss],
"epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
}
self.metrics={
k: [item[1] for item in self.metric_logger.metrics[k]] for k in
self.evaluator.state.metric_details.keys()
}
# pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
# pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")
def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
"Run training using one-cycle-policy"
assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
fit_one_cycle=monai.handlers.LrScheduleHandler(
torch.optim.lr_scheduler.OneCycleLR(
optimizer=self.optimizer,
max_lr=self.optimizer.param_groups[0]['lr'],
steps_per_epoch=self.state.epoch_length,
epochs=self.state.max_epochs
),
epoch_level=False,
name="FitOneCycle"
)
fit_one_cycle.attach(self)
self.schedulers += ["FitOneCycle"]
def reduce_lr_on_plateau(self,
try_resume_from_checkpoint=True,
factor=0.1,
patience=10,
min_lr=1e-10,
verbose=True) -> None:
"Reduce learning rate by `factor` every `patience` epochs if kex_metric does not improve"
assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
reduce_lr_on_plateau=monai.handlers.LrScheduleHandler(
torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer=self.optimizer,
factor=factor,
patience=patience,
min_lr=min_lr,
verbose=verbose
),
print_lr=True,
name='ReduceLROnPlateau',
epoch_level=True,
step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
)
reduce_lr_on_plateau.attach(self.evaluator)
self.schedulers += ["ReduceLROnPlateau"]
def evaluate(self, checkpoint=None, dataloader=None):
"Run evaluation with best saved checkpoint"
self.load_checkpoint(checkpoint)
if dataloader:
self.evaluator.set_data(dataloader)
self.evaluator.state.epoch_length=len(dataloader)
self.evaluator.run()
print(f"metrics saved to {self.config.out_dir}")