# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr

# Bound before the guarded block so that every helper below can safely test `if WRITER:` and format
# messages with PREFIX even when the TensorBoard integration is disabled or its imports fail.
WRITER = None  # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")

try:
    # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
    from torch.utils.tensorboard import SummaryWriter

    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled

    # Imports below only required if TensorBoard enabled
    import warnings
    from copy import deepcopy

    from ultralytics.utils.torch_utils import de_parallel, torch

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    SummaryWriter = None


def _log_scalars(scalars: dict, step: int = 0) -> None:
    """
    Log scalar values to TensorBoard.

    Does nothing when no SummaryWriter has been created yet.

    Args:
        scalars (dict): Mapping of tag name to scalar value; each pair is passed to `WRITER.add_scalar`.
        step (int): Global step value recorded with every scalar.
    """
    if not WRITER:
        return
    for tag, value in scalars.items():
        WRITER.add_scalar(tag, value, step)


def _log_tensorboard_graph(trainer) -> None:
    """
    Log the model graph to TensorBoard.

    First traces the (de-parallelized) training model directly. If that fails, falls back to tracing a fused deep
    copy whose modules exposing an `export` attribute are switched to TorchScript export mode. If the fallback fails
    as well, a warning is logged and no graph is written.

    Args:
        trainer: Trainer object; this function reads `trainer.args.imgsz` and `trainer.model`.
    """
    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    p = next(trainer.model.parameters())  # for device, type
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try simple method first (YOLO)
        try:
            trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            return

        # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit still propagate instead of
        # silently triggering the fallback path.
        except Exception:
            # Fallback to TorchScript export steps (RTDETR)
            try:
                model = deepcopy(de_parallel(trainer.model))
                model.eval()
                model = model.fuse(verbose=False)
                for m in model.modules():
                    if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                        m.export = True
                        m.format = "torchscript"
                model(im)  # dry run
                WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
                LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            except Exception as e:
                LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")


def on_pretrain_routine_start(trainer) -> None:
    """
    Initialize TensorBoard logging with SummaryWriter.

    Creates the module-level `WRITER` pointing at `trainer.save_dir`. Any failure during creation is logged as a
    warning and leaves `WRITER` unchanged, so the run continues without TensorBoard logging.

    Args:
        trainer: Trainer object; this function reads `trainer.save_dir`.
    """
    global WRITER
    if not SummaryWriter:
        return
    try:
        WRITER = SummaryWriter(str(trainer.save_dir))
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")


def on_train_start(trainer) -> None:
    """Log the TensorBoard model graph, but only if a SummaryWriter was successfully created."""
    if WRITER:
        _log_tensorboard_graph(trainer)


def on_train_epoch_end(trainer) -> None:
    """Log labelled training loss items and learning rates at step `trainer.epoch + 1`."""
    step = trainer.epoch + 1
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), step)
    _log_scalars(trainer.lr, step)


def on_fit_epoch_end(trainer) -> None:
    """Log `trainer.metrics` at step `trainer.epoch + 1`."""
    _log_scalars(trainer.metrics, trainer.epoch + 1)


# Callbacks are only registered when the TensorBoard integration imported and is enabled.
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
    }
    if SummaryWriter
    else {}
)