import os
from typing import List, Optional

import hydra
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_logger(__name__)


def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline.

    Can additionally evaluate the model on a test set, using the best weights
    achieved during training.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
    """
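    # Set seed for random number generators in pytorch, numpy and python.random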
    if config.get("seed"):
        seed_everything(config.seed, workers=True)
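    # Convert relative ckpt path to absolute path, since Hydra changes the
    # working directory at runtime and a relative path would otherwise break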
    ckpt_path = config.trainer.get("resume_from_checkpoint")
    if ckpt_path and not os.path.isabs(ckpt_path):
        config.trainer.resume_from_checkpoint = os.path.join(
            hydra.utils.get_original_cwd(), ckpt_path
        )
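    # Init lightning datamodule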
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
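    # Init lightning model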
    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)
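    # Init lightning callbacks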
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))
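    # Init lightning loggers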
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))
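    # Init lightning trainer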
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )
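    # Send some parameters from config to all lightning loggers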
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )
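    # Train the model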
    if config.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)
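    # Get metric score for hyperparameter optimization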
    optimized_metric = config.get("optimized_metric")
    if optimized_metric and optimized_metric not in trainer.callback_metrics:
        raise Exception(
            "Metric for hyperparameter optimization not found! "
            "Make sure the `optimized_metric` in `hparams_search` config is correct!"
        )
    score = trainer.callback_metrics.get(optimized_metric)
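    # Evaluate model on test set, using the best weights achieved during training
    # (falls back to current weights when there was no training or in fast_dev_run)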
    if config.get("test"):
        ckpt_path = "best"
        if not config.get("train") or config.trainer.get("fast_dev_run"):
            ckpt_path = None
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
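    # Make sure everything closed properly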
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )
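    # Print path to best checkpoint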
    if not config.trainer.get("fast_dev_run") and config.get("train"):
        log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")
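    # Return metric score for hyperparameter optimization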
    return score
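
# A minimal sketch of how this pipeline is typically launched from a Hydra
# entry point. Kept as a comment so importing this module stays side-effect
# free; the config_path/config_name values are assumptions and may differ
# per project:
#
#     @hydra.main(config_path="configs/", config_name="config.yaml")
#     def main(config: DictConfig) -> Optional[float]:
#         return train(config)
#
#     if __name__ == "__main__":
#         main()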