import os
import yaml
import munch
import torch
import ignite
import monai
import shutil
import pandas as pd
from typing import Union, List, Callable
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from monai.handlers import (
    CheckpointSaver,
    StatsHandler,
    TensorBoardStatsHandler,
    TensorBoardImageHandler,
    ValidationHandler,
    from_engine,
    MeanDice,
    EarlyStopHandler,
    MetricLogger,
    MetricsSaver
)
from .data import segmentation_dataloaders
from .model import get_model
from .optimizer import get_optimizer
from .loss import get_loss
from .transforms import get_val_post_transforms
from .utils import USE_AMP


def loss_logger(engine):
    "write loss and lr of each iteration/epoch to file"
    iteration=engine.state.iteration
    epoch=engine.state.epoch
    loss=[o['loss'] for o in engine.state.output]
    loss=sum(loss)/len(loss)
    lr=engine.optimizer.param_groups[0]['lr']
    log_file=os.path.join(engine.config.log_dir, 'train_logs.csv')
    if not os.path.exists(log_file):
        with open(log_file, 'w+') as f:
            f.write('iteration,epoch,loss,lr\n')
    with open(log_file, 'a') as f:
        f.write(f'{iteration},{epoch},{loss},{lr}\n')


def metric_logger(engine):
    "write `metrics` after each epoch to file"
    if engine.state.epoch > 1:  # only the key metric is calculated in the 1st epoch, needs fix
        metric_names=[k for k in engine.state.metrics.keys()]
        metrics=[str(engine.state.metrics[mn]) for mn in metric_names]
        log_file=os.path.join(engine.config.log_dir, 'metric_logs.csv')
        if not os.path.exists(log_file):
            with open(log_file, 'w+') as f:
                f.write(','.join(metric_names) + '\n')
        with open(log_file, 'a') as f:
            f.write(','.join(metrics) + '\n')


def pred_logger(engine):
    "save `image` and `label` once, and `pred` each time the key metric improves"
    epoch=engine.state.epoch
    root=os.path.join(engine.config.out_dir, 'preds')
    if not os.path.exists(root):
        os.makedirs(root)
        torch.save(
            engine.state.output[0]['label'],
            os.path.join(root, 'label.pt')
        )
        torch.save(
            engine.state.output[0]['image'],
            os.path.join(root, 'image.pt')
        )

    if epoch==engine.state.best_metric_epoch:
        torch.save(
            engine.state.output[0]['pred'],
            os.path.join(root, f'pred_epoch_{epoch}.pt')
        )
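

# Note: the three loggers above are not called directly. They are plugged into MONAI's
# StatsHandler via `epoch_print_logger` / `iteration_print_logger` in `get_val_handlers`
# and `get_train_handlers` below, and they read `engine.config`, which is attached to the
# engines in `get_evaluator` (`evaluator.config=config`) and `SegmentationTrainer.__init__`.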


def get_val_handlers(
    network: torch.nn.Module,
    config: dict
) -> list:
    """Create default handlers for model validation

    Args:
        network:
            nn.Module subclass, the model to train
        config:
            parsed config object

    Returns:
        a list of default handlers for validation: [
            StatsHandler:
                Write validation metrics to `config.log_dir/metric_logs.csv` via `metric_logger`
            StatsHandler:
                Save `image`, `label` and best `pred` to `config.out_dir/preds` via `pred_logger`
            TensorBoardStatsHandler:
                Save validation metrics to `config.log_dir`, allow logging with TensorBoard
            TensorBoardImageHandler:
                Save image, label and prediction to `config.log_dir` for inspection in TensorBoard
            CheckpointSaver:
                Save best model to `config.model_dir`
        ]
    """

    val_handlers=[
        StatsHandler(
            tag_name="metric_logger",
            epoch_print_logger=metric_logger,
            output_transform=lambda x: None
        ),
        StatsHandler(
            tag_name="pred_logger",
            epoch_print_logger=pred_logger,
            output_transform=lambda x: None
        ),
        TensorBoardStatsHandler(
            log_dir=config.log_dir,
            # tag_name="val_mean_dice",
            output_transform=lambda x: None
        ),
        TensorBoardImageHandler(
            log_dir=config.log_dir,
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(
            save_dir=config.model_dir,
            save_dict={f"network_{config.run_id}": network},
            save_key_metric=True
        ),
    ]

    return val_handlers


def get_train_handlers(
    evaluator: monai.engines.SupervisedEvaluator,
    config: dict
) -> list:
    """Create default handlers for model training

    Args:
        evaluator: an engine of type `monai.engines.SupervisedEvaluator` for evaluation
            after every epoch
        config: parsed config object

    Returns:
        list of default handlers for training: [
            ValidationHandler:
                Allows model validation every epoch
            StatsHandler:
                Print train loss to stdout
            StatsHandler:
                Write train loss and learning rate to `config.log_dir/train_logs.csv` via `loss_logger`
            TensorBoardStatsHandler:
                Save train loss to `config.log_dir`, allow logging with TensorBoard
        ]
    """

    train_handlers=[
        ValidationHandler(
            validator=evaluator,
            interval=1,
            epoch_level=True
        ),
        StatsHandler(
            tag_name="train_loss",
            output_transform=from_engine(
                ["loss"],
                first=True
            )
        ),
        StatsHandler(
            tag_name='loss_logger',
            iteration_print_logger=loss_logger
        ),
        TensorBoardStatsHandler(
            log_dir=config.log_dir,
            tag_name="train_loss",
            output_transform=from_engine(
                ["loss"],
                first=True
            ),
        )
    ]

    return train_handlers


def get_evaluator(
    config: dict,
    device: torch.device,
    network: torch.nn.Module,
    val_data_loader: monai.data.dataloader.DataLoader,
    val_post_transforms: monai.transforms.compose.Compose,
    val_handlers: Union[Callable, List]=get_val_handlers
) -> monai.engines.SupervisedEvaluator:
    """Create default evaluator for training of a segmentation model

    Args:
        config:
            parsed config object
        device:
            `torch.device` for model and engine
        network:
            nn.Module subclass, the model to train
        val_data_loader:
            validation data loader, `monai.data.dataloader.DataLoader` subclass
        val_post_transforms:
            function to create transforms OR composed transforms
        val_handlers:
            function to create handlers OR list of handlers

    Returns:
        default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
    """
    if callable(val_handlers): val_handlers=val_handlers(network, config=config)

    evaluator=monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=monai.inferers.SlidingWindowInferer(
            roi_size=(96, 96, 96),
            sw_batch_size=4,
            overlap=0.5
        ),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(
                    ["pred", "label"]
                )
            )
        },
        val_handlers=val_handlers,
        # if the GPU has no FP16 support or PyTorch version < 1.6, AMP evaluation is not enabled
        amp=USE_AMP,
    )
    evaluator.config=config

    return evaluator
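

# A minimal standalone usage sketch (not part of the module): assuming the config is a YAML
# file parsed into an attribute-style object (e.g. with `munch.munchify`) and that the helper
# functions imported above return the usual objects, an evaluator could be built like this:
#
#     with open("config.yaml") as f:
#         config = munch.munchify(yaml.safe_load(f))
#     train_loader, val_loader = segmentation_dataloaders(config=config, train=True, valid=True, test=False)
#     network = get_model(config=config).to(torch.device(config.device))
#     evaluator = get_evaluator(
#         config=config,
#         device=torch.device(config.device),
#         network=network,
#         val_data_loader=val_loader,
#         val_post_transforms=get_val_post_transforms(config=config),
#         val_handlers=get_val_handlers(network, config=config),
#     )
#     evaluator.run()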


class SegmentationTrainer(monai.engines.SupervisedTrainer):
    "Default trainer for supervised segmentation tasks"
    def __init__(self,
                 config: dict,
                 progress_bar: bool=True,
                 early_stopping: bool=True,
                 metrics: list=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
                 save_latest_metrics: bool=True
                 ):
        self.config=config
        self._prepare_dirs()
        self.config.device=torch.device(self.config.device)

        train_loader, val_loader=segmentation_dataloaders(
            config=config,
            train=True,
            valid=True,
            test=False
        )
        network=get_model(config=config).to(config.device)
        optimizer=get_optimizer(
            network,
            config=config
        )
        loss_fn=get_loss(config=config)
        val_post_transforms=get_val_post_transforms(config=config)
        val_handlers=get_val_handlers(
            network,
            config=config
        )
        self.evaluator=get_evaluator(
            config=config,
            device=config.device,
            network=network,
            val_data_loader=val_loader,
            val_post_transforms=val_post_transforms,
            val_handlers=val_handlers,
        )
        train_handlers=get_train_handlers(
            self.evaluator,
            config=config
        )
        super().__init__(
            device=config.device,
            max_epochs=self.config.training.max_epochs,
            train_data_loader=train_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_fn,
            inferer=monai.inferers.SimpleInferer(),
            train_handlers=train_handlers,
            amp=USE_AMP,
        )

        if early_stopping: self._add_early_stopping()
        if progress_bar: self._add_progress_bars()

        self.schedulers=[]

        # add the requested metrics dynamically
        for m in metrics:
            getattr(monai.handlers, m)(
                include_background=False,
                reduction="mean",
                output_transform=from_engine(
                    ["pred", "label"]
                )
            ).attach(self.evaluator, m)

        self._add_metrics_logger()
        # add eval loss to metrics
        self._add_eval_loss()

        if save_latest_metrics: self._add_metrics_saver()

    def _prepare_dirs(self) -> None:
        # create run_id directory, copy config file for reproducibility
        os.makedirs(self.config.run_id, exist_ok=True)
        with open(
            os.path.join(
                self.config.run_id,
                'config.yaml'
            ), 'w+') as f:
            f.write(yaml.safe_dump(self.config))

        # delete old log_dir
        if os.path.exists(self.config.log_dir):
            shutil.rmtree(self.config.log_dir)

    def _add_early_stopping(self) -> None:
        early_stopping=EarlyStopHandler(
            patience=self.config.training.early_stopping_patience,
            min_delta=1e-4,
            score_function=lambda x: x.state.metrics[x.state.key_metric_name],
            trainer=self
        )
        self.evaluator.add_event_handler(
            ignite.engine.Events.COMPLETED,
            early_stopping
        )

    def _add_metrics_logger(self) -> None:
        self.metric_logger=MetricLogger(
            evaluator=self.evaluator
        )
        self.metric_logger.attach(self)

    def _add_progress_bars(self) -> None:
        trainer_pbar=ProgressBar()
        evaluator_pbar=ProgressBar(
            colour='green'
        )
        trainer_pbar.attach(
            self,
            output_transform=lambda output: {
                'loss': torch.tensor(
                    [x['loss'] for x in output]
                ).mean()
            }
        )
        evaluator_pbar.attach(self.evaluator)

    def _add_metrics_saver(self) -> None:
        metric_saver=MetricsSaver(
            save_dir=self.config.out_dir,
            metric_details='*',
            batch_transform=self._get_meta_dict,
            delimiter=','
        )
        metric_saver.attach(self.evaluator)

    def _add_eval_loss(self) -> None:
        # TODO improve by adding this to the val handlers
        eval_loss_handler=ignite.metrics.Loss(
            loss_fn=self.loss_function,
            output_transform=lambda output: (
                output[0]['pred'].unsqueeze(0),  # add batch dim
                output[0]['label'].argmax(0, keepdim=True).unsqueeze(0)  # reverse one-hot, add batch dim
            )
        )
        eval_loss_handler.attach(self.evaluator, 'eval_loss')

    def _get_meta_dict(self, batch) -> list:
        "Get dict of metadata from engine. Needed as `batch_transform` for MetricsSaver"
        image_cols=self.config.data.image_cols
        image_name=image_cols[0] if isinstance(image_cols, list) else image_cols
        key=f'{image_name}_meta_dict'
        return [item[key] for item in batch]

    def load_checkpoint(self, checkpoint=None):
        "Load a checkpoint; defaults to the checkpoint saved with the best key metric"
        if not checkpoint:
            # construct the name of the checkpoint saved with the best key metric
            checkpoint = os.path.join(
                self.config.model_dir,
                f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
            )
        self.network.load_state_dict(
            torch.load(checkpoint)
        )

    def run(self, try_resume_from_checkpoint=True) -> None:
        """Run training; if `try_resume_from_checkpoint` is True, try to
        load a previous checkpoint stored at `self.config.model_dir`
        """
        if try_resume_from_checkpoint:
            checkpoints=[
                os.path.join(
                    self.config.model_dir,
                    checkpoint_name
                ) for checkpoint_name in os.listdir(
                    self.config.model_dir
                ) if self.config.run_id in checkpoint_name
            ]
            try:
                checkpoint=sorted(checkpoints)[-1]
                self.load_checkpoint(checkpoint)
                print(f"resuming from previous checkpoint at {checkpoint}")
            except Exception:
                pass  # no usable checkpoint found, train from scratch

        # train the model
        super().run()

        # make metrics and losses more accessible
        self.loss={
            "iter": [_iter for _iter, _ in self.metric_logger.loss],
            "loss": [_loss for _, _loss in self.metric_logger.loss],
            "epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
        }

        self.metrics={
            k: [item[1] for item in self.metric_logger.metrics[k]] for k in
            self.evaluator.state.metric_details.keys()
        }
        # pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
        # pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")

    def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
        "Add a OneCycleLR scheduler, so that subsequent training follows the one-cycle policy"
        assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
        fit_one_cycle=monai.handlers.LrScheduleHandler(
            torch.optim.lr_scheduler.OneCycleLR(
                optimizer=self.optimizer,
                max_lr=self.optimizer.param_groups[0]['lr'],
                steps_per_epoch=self.state.epoch_length,
                epochs=self.state.max_epochs
            ),
            epoch_level=False,
            name="FitOneCycle"
        )
        fit_one_cycle.attach(self)
        self.schedulers += ["FitOneCycle"]

    def reduce_lr_on_plateau(self,
                             try_resume_from_checkpoint=True,
                             factor=0.1,
                             patience=10,
                             min_lr=1e-10,
                             verbose=True) -> None:
        "Reduce the learning rate by `factor` if the key metric does not improve for `patience` epochs"
        assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
        reduce_lr_on_plateau=monai.handlers.LrScheduleHandler(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=self.optimizer,
                factor=factor,
                patience=patience,
                min_lr=min_lr,
                verbose=verbose
            ),
            print_lr=True,
            name='ReduceLROnPlateau',
            epoch_level=True,
            step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
        )
        reduce_lr_on_plateau.attach(self.evaluator)
        self.schedulers += ["ReduceLROnPlateau"]

    def evaluate(self, checkpoint=None, dataloader=None):
        "Run evaluation with the best saved checkpoint"
        self.load_checkpoint(checkpoint)
        if dataloader:
            self.evaluator.set_data(dataloader)
            self.evaluator.state.epoch_length=len(dataloader)
        self.evaluator.run()
        print(f"metrics saved to {self.config.out_dir}")