import os
from typing import List

import hydra
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_logger(__name__)

def test(config: DictConfig) -> None:
    """Contains minimal example of the testing pipeline.

    Evaluates given checkpoint on a testset.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        None
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # Convert relative ckpt path to absolute path if necessary
    if not os.path.isabs(config.ckpt_path):
        config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)

    # Init lightning datamodule
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

    # Init lightning model
    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)
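
    # Note: `hydra.utils.instantiate` reads the `_target_` key of a config node
    # and instantiates that class with the remaining keys as keyword arguments.
    # A hypothetical model config (illustrative names, not taken from this repo):
    #
    #     _target_: src.models.mnist_module.MNISTLitModule
    #     lr: 0.001
    #
    # would resolve roughly to `MNISTLitModule(lr=0.001)`.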

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=logger)

    # Log hyperparameters
    trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)
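

# A minimal, hypothetical standalone entry point showing how `test` might be
# invoked; the config path and config name below are assumptions, not taken
# from this file or its project layout.
@hydra.main(config_path="configs/", config_name="test.yaml")
def main(config: DictConfig) -> None:
    test(config)


if __name__ == "__main__":
    main()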