import sys
from argparse import Namespace
from functools import partial
from typing import Callable, Dict, Tuple, Union

# Make the torchvision segmentation reference utilities importable.
sys.path.append("vision/references/segmentation")

import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn

from common import NanSafeConfusionMatrix as ConfusionMatrix
from common import flops_calculation_function, get_coco


def get_dataset(args: Namespace, is_train: bool, transform: Callable = None) -> Tuple[torch.utils.data.Dataset, int]:
    """Build the segmentation dataset and return it with its number of classes.

    Args:
        args: Parsed CLI namespace; reads ``data_path``, ``backend`` and ``use_v2``.
        is_train: Selects the ``"train"`` image set when True, ``"val"`` otherwise.
        transform: Optional transform pipeline; defaults to the reference
            ``SegmentationPresetEval`` with ``base_size=520``.

    Returns:
        A ``(dataset, num_classes)`` tuple.
    """

    def sbd(*args, **kwargs):
        # SBDataset does not understand use_v2; strip it before forwarding.
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    def voc(*args, **kwargs):
        # VOCSegmentation likewise rejects use_v2.
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*args, **kwargs)

    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
        "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81),
    }
    # NOTE(review): the dataset is hard-coded to "coco_orig" even though the
    # table above offers voc/voc_aug/coco as well — confirm this is intentional
    # (there is no --dataset CLI flag to select the others).
    p, ds_fn, num_classes = paths["coco_orig"]

    if transform is None:
        transform = presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)
    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
    return ds, num_classes


def criterion(inputs: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over the model's named outputs.

    Args:
        inputs: Mapping of head name -> logits; expects an ``"out"`` head and
            optionally an ``"aux"`` head.
        target: Ground-truth class-index tensor; label 255 is ignored.

    Returns:
        The ``"out"`` loss, plus ``0.5 *`` the ``"aux"`` loss when present.
    """
    losses = {}
    for name, x in inputs.items():
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]


def evaluate(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device],
    num_classes: int,
    criterion: Callable,
) -> Tuple[ConfusionMatrix, float]:
    """Run inference over ``data_loader`` and accumulate a confusion matrix.

    Args:
        model: Segmentation model returning a dict with an ``"out"`` head.
        data_loader: Evaluation loader.
        device: Device to move batches to.
        num_classes: Size of the confusion matrix.
        criterion: Loss callable taking ``(model_output_dict, target)``.

    Returns:
        ``(confusion_matrix, mean_loss)`` where the loss is the logger's
        global average across all processes.
    """
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            output = output["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]
            metric_logger.update(loss=loss.item())

        confmat.reduce_from_all_processes()

    return confmat, metric_logger.loss.global_avg


def main(args):
    """Load a checkpoint, report its FLOPs, and evaluate it on the val split."""
    # NOTE(review): these two checks are contradictory — the first demands
    # --use-v2 for any non-PIL backend, the second rejects --use-v2 outright,
    # so only the PIL backend without v2 can ever run. Preserved as-is;
    # confirm the intended gating before changing.
    if args.backend.lower() != "pil" and not args.use_v2:
        # TODO: Support tensor backend in V1?
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    if args.use_v2:
        raise ValueError("v2 is only supported for coco dataset for now.")

    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset_test, num_classes = get_dataset(args, is_train=False)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    # map_location lets a CUDA-saved checkpoint load on a CPU-only host
    # (and vice versa) instead of failing with a deserialization error.
    checkpoint = torch.load(args.model_path, map_location=device)
    model = checkpoint["model"]
    model.to(device)

    model_flops = flops_calculation_function(model=model, input_sample=next(iter(data_loader_test))[0].to(device))
    print(f"Model Flops: {model_flops}M")

    # We disable the cudnn benchmarking because it can noticeably affect the accuracy
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    confmat, loss = evaluate(
        model=model,
        data_loader=data_loader_test,
        device=device,
        num_classes=num_classes,
        criterion=criterion,
    )
    print(confmat)


def get_args_parser(add_help=True):
    """Build the CLI argument parser for this evaluation script."""
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # distributed training parameters
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    parser.add_argument("--model-path", default=None, help="Path to model checkpoint.")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)