# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate CaR on segmentation benchmarks."""

# pylint: disable=g-importing-member
import argparse

import numpy as np
import torch
from torch.utils import tensorboard
import torch.utils.data
from torch.utils.data import Subset
import torchvision.transforms as T

# pylint: disable=g-bad-import-order
from modeling.model.car import CaR
from sam.utils import build_sam_config
from utils.utils import Config
from utils.utils import load_yaml
from utils.utils import MetricLogger
from utils.utils import SmoothedValue
from utils.inference_pipeline import inference_car
from utils.merge_mask import merge_masks_simple

# Datasets
# pylint: disable=g-multiple-import
from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset
from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset
from data.coco import COCO_OBJECT_CLASSES, COCODataset
from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset
from data.gres import GReferDataset
from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset
from data.refcoco import ReferDataset
from data.voc import VOC_CLASSES, VOCDataset

IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512

# set random seed
torch.manual_seed(0)
np.random.seed(0)


def get_dataset(cfg, ds_name, split, transform, data_root=None):
  """Get dataset."""
  data_args = dict(root=data_root) if data_root is not None else {}
  if 'refcoco' in ds_name:
    splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
    ds = ReferDataset(
        dataset=ds_name,
        splitBy=splitby,
        split=split,
        image_transforms=transform,
        target_transforms=transform,
        eval_mode=True,
        prompts_augment=cfg.test.prompts_augment,
        **data_args,
    )
  elif ds_name == 'gres':
    ds = GReferDataset(split=split, transform=transform, **data_args)
  elif ds_name == 'voc':
    ds = VOCDataset(
        year='2012',
        split=split,
        transform=transform,
        target_transform=transform,
        **data_args,
    )
  elif ds_name == 'cocostuff':
    ds = COCODataset(transform=transform, **data_args)
  elif ds_name == 'context':
    ds = CONTEXTDataset(
        year='2010', transform=transform, split=split, **data_args
    )
  elif ds_name == 'ade':
    ds = ADEDataset(split=split, transform=transform, **data_args)
  elif ds_name == 'pascal_459':
    ds = Pascal459Dataset(split=split, transform=transform, **data_args)
  elif ds_name == 'ade_847':
    ds = ADE847Dataset(split=split, transform=transform, **data_args)
  else:
    raise ValueError(f'Dataset {ds_name} not implemented')
  return ds


def get_transform():
  transforms = [
      T.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),
      T.ToTensor(),
  ]
  return T.Compose(transforms)
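
# Illustrative usage sketch for the two helpers above (not executed here;
# '/path/to/VOCdevkit' and the 'val' split are hypothetical example values,
# and the per-item layout is defined by each dataset class under data/):
#   transform = get_transform()
#   ds = get_dataset(cfg, 'voc', 'val', transform, data_root='/path/to/VOCdevkit')
#   loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False)
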
def assign_label(
    all_masks,
    scores,
    stuff_masks=None,
    stuff_scores=None,
    id_mapping=None,
    stuff_id_mapping=None,
):
  """Assign labels."""
  label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
  if stuff_masks is not None:
    sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
    stuff_masks = stuff_masks[sorted_idxs]
    stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
    for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
      if score > 0:
        # convert mask to boolean
        mask = mask > 0.5
        # assign label
        if stuff_id_mapping is not None:
          label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
        else:
          label_preds[mask] = sorted_idx + 1
  sorted_idxs = np.argsort(scores.detach().cpu().numpy())
  all_masks = all_masks[sorted_idxs]
  scores = scores.detach().cpu().numpy()[sorted_idxs]
  for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
    if score > 0:
      # convert mask to boolean
      mask = mask > 0.5
      # assign label
      if id_mapping is not None:
        label_preds[mask] = id_mapping[sorted_idx] + 1
      else:
        label_preds[mask] = sorted_idx + 1
  return label_preds


def eval_semantic(
    label_space,
    algo,
    cfg,
    model,
    image_path,
    stuff_label_space=None,
    sam_pipeline=None,
):
  """Semantic segmentation evaluation."""
  if label_space is None:
    raise ValueError(
        'label_space must be provided for semantic segmentation evaluation'
    )
  if algo == 'car':
    all_masks, scores = inference_car(
        cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
    )
    if stuff_label_space is not None:
      if cfg.test.ds_name == 'context':
        thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
        stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'ade':
        thing_id_mapping = ADE_THING_CLASS_ID
        stuff_id_mapping = ADE_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'pascal_459':
        thing_id_mapping = PASCAL_459_THING_CLASS_ID
        stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
      elif cfg.test.ds_name == 'ade_847':
        thing_id_mapping = ADE_847_THING_CLASS_ID
        stuff_id_mapping = ADE_847_STUFF_CLASS_ID
      else:
        raise ValueError(f'Dataset {cfg.test.ds_name} not supported')
      model.mask_generator.set_bg_cls(label_space)
      model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
      model.set_bg_factor(cfg.car.stuff_bg_factor)
      stuff_masks, stuff_scores = inference_car(
          cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
      )
      model.mask_generator.set_bg_cls(cfg.car.bg_cls)
      model.set_visual_prompt_type(cfg.car.visual_prompt_type)
      model.set_bg_factor(cfg.car.bg_factor)
      all_masks = all_masks.detach().cpu().numpy()
      stuff_masks = stuff_masks.detach().cpu().numpy()
      label_preds = assign_label(
          all_masks,
          scores,
          stuff_masks=stuff_masks,
          stuff_scores=stuff_scores,
          id_mapping=thing_id_mapping,
          stuff_id_mapping=stuff_id_mapping,
      )
    else:
      all_masks = all_masks.detach().cpu().numpy()
      label_preds = assign_label(all_masks, scores)
    return label_preds.squeeze()
  else:
    raise NotImplementedError(f'algo {algo} not implemented')


def _fast_hist(label_true, label_pred, n_class=21):
  mask = (label_true >= 0) & (label_true < n_class)
  hist = np.bincount(
      n_class * label_true[mask].astype(int) + label_pred[mask],
      minlength=n_class**2,
  ).reshape(n_class, n_class)
  return hist
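
# Worked example for _fast_hist (illustrative only): with n_class=3,
# label_true=[0, 1, 2, 1] and label_pred=[0, 2, 2, 1], the flat index
# n_class * label_true + label_pred is [0, 5, 8, 4]; np.bincount with
# minlength=9, reshaped to (3, 3), then yields a confusion matrix whose
# diagonal counts the three correct pixels and whose entry [1, 2] counts
# the pixel labelled class 1 but predicted as class 2.
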
def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
  """Semantic segmentation IOU."""
  hist = np.zeros((n_class, n_class))
  for lt, lp in zip(label_trues, label_preds):
    hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
  if ignore_background:
    hist = hist[1:, 1:]
  acc = np.diag(hist).sum() / hist.sum()
  acc_cls = np.diag(hist) / hist.sum(axis=1)
  acc_cls = np.nanmean(acc_cls)
  iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
  valid = hist.sum(axis=1) > 0  # added
  if valid.sum() == 0:
    mean_iu = 0
  else:
    mean_iu = np.nanmean(iu[valid])
  freq = hist.sum(axis=1) / hist.sum()
  fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
  if ignore_background:
    cls_iu = dict(zip(range(1, n_class), iu))
  else:
    cls_iu = dict(zip(range(n_class), iu))
  return {
      'Pixel Accuracy': acc,
      'Mean Accuracy': acc_cls,
      'Frequency Weighted IoU': fwavacc,
      'mIoU': mean_iu,
      'Class IoU': cls_iu,
  }
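
# Sanity-check sketch for semantic_iou (illustrative only): a prediction that
# matches its ground truth gives semantic_iou([gt], [gt], n_class=21)['mIoU']
# == 1.0, since classes absent from the ground truth (rows with
# hist.sum(axis=1) == 0) are excluded from the mean by the `valid` mask above.
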
def evaluate(
    data_loader,
    cfg,
    model,
    test_cfg,
    label_space=None,
    stuff_label_space=None,
    sam_pipeline=None,
):
  """Run evaluation."""
  if (
      test_cfg.ds_name
      not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
      and test_cfg.seg_mode == 'semantic'
  ):
    raise ValueError((
        'Semantic segmentation evaluation is only implemented for voc, '
        'context, coco object, ade, pascal459, ade847 dataset'
    ))
  metric_logger = MetricLogger(delimiter=' ')
  metric_logger.add_meter(
      'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
  )
  # evaluation variables
  cum_i, cum_u = 0, 0
  eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
  seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
  seg_total = 0
  mean_iou = []
  header = 'Test:'
  # all_masks = []
  label_preds, label_gts = [], []
  print(len(data_loader))
  cc = 0
  use_tensorboard = False
  if hasattr(cfg.test, 'use_tensorboard'):
    use_tensorboard = cfg.test.use_tensorboard
  if use_tensorboard:
    writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
  for data in metric_logger.log_every(data_loader, 1, header):
    _, image_paths, target_list, sentences_list = data
    # print(type(target_list))
    if not isinstance(target_list, list):
      target_list, sentences_list = [target_list], [sentences_list]
    for target, sentences in zip(target_list, sentences_list):
      image_path = image_paths[0]
      # print(image_path)
      if test_cfg.seg_mode == 'refer':
        all_masks, all_scores = inference_car(
            cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
        )
        # final_mask = merge_masks(all_masks, *target.shape[1:])
        final_mask = merge_masks_simple(
            all_masks, *target.shape[1:], scores=all_scores
        )
        intersection, union, cur_iou = compute_iou(final_mask, target)
        # cur_iou = IoU(final_mask, target, 0)
        metric_logger.update(mIoU=cur_iou)
        mean_iou.append(cur_iou)
        if use_tensorboard:
          writer.add_scalar('Mean IoU', cur_iou, cc)
        cum_i += intersection
        cum_u += union
        for n_eval_iou in range(len(eval_seg_iou_list)):
          eval_seg_iou = eval_seg_iou_list[n_eval_iou]
          seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
        seg_total += 1
      elif test_cfg.seg_mode == 'semantic':
        # torch.cuda.empty_cache()
        label_pred = eval_semantic(
            label_space,
            test_cfg.algo,
            cfg,
            model,
            image_path,
            stuff_label_space,
        )
        label_gt = target.squeeze().cpu().numpy()
        cur_iou = semantic_iou(
            [label_gt],
            [label_pred],
            n_class=cfg.test.n_class,
            ignore_background=cfg.test.ignore_background,
        )['mIoU']
        metric_logger.update(mIoU=cur_iou)
        label_preds.append(label_pred)
        label_gts.append(label_gt)
      cc += 1
  if test_cfg.seg_mode == 'refer':
    mean_iou = np.array(mean_iou)
    m_iou = np.mean(mean_iou)
    if use_tensorboard:
      writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
    print('Final results:')
    print('Mean IoU is %.2f\n' % (m_iou * 100.0))
    results_str = ''
    for n_eval_iou in range(len(eval_seg_iou_list)):
      results_str += ' precision@%s = %.2f\n' % (
          str(eval_seg_iou_list[n_eval_iou]),
          seg_correct[n_eval_iou] * 100.0 / seg_total,
      )
    o_iou = cum_i * 100.0 / cum_u
    results_str += ' overall IoU = %.2f\n' % o_iou
    if use_tensorboard:
      writer.add_scalar('oIoU', o_iou, 0)
    print(results_str)
  elif test_cfg.seg_mode == 'semantic':
    iou_score = semantic_iou(
        label_gts,
        label_preds,
        n_class=cfg.test.n_class,
        ignore_background=cfg.test.ignore_background,
    )
    if use_tensorboard:
      writer.add_scalar('mIoU', iou_score['mIoU'].item(), len(data_loader))
    print(iou_score)
  if use_tensorboard:
    writer.close()


def compute_iou(pred_seg, gd_seg):
  """Compute IoU."""
  intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
  union = torch.sum(torch.logical_or(pred_seg, gd_seg))
  # check union first to avoid a divide-by-zero when both masks are empty
  if union == 0:
    iou = 0
  else:
    iou = intersection * 1.0 / union
  return intersection, union, iou


def list_of_strings(arg):
  return [a.strip() for a in arg.split(',')]


# pylint: disable=redefined-outer-name
def parse_args():
  """Parse arguments."""
  parser = argparse.ArgumentParser(description='Training')
  parser.add_argument(
      '--cfg-path',
      default='configs/refcoco_test_prompt.yaml',
      help='path to configuration file.',
  )
  parser.add_argument('--index', default=0, type=int, help='split task')
  parser.add_argument('--mask_threshold', default=0.0, type=float)
  parser.add_argument('--confidence_threshold', default=0.0, type=float)
  parser.add_argument('--clipes_threshold', default=0.0, type=float)
  parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
  parser.add_argument('--bg_factor', default=0.0, type=float)
  parser.add_argument('--output_path', default=None, type=str)
  parser.add_argument(
      '--visual_prompt_type', default=None, type=list_of_strings
  )
  parser.add_argument(
      '--stuff_visual_prompt_type', default=None, type=list_of_strings
  )
  args = parser.parse_args()
  return args


def main(args):
  cfg = Config(**load_yaml(args.cfg_path))
  if args.mask_threshold > 0:
    cfg.car.mask_threshold = args.mask_threshold
  if args.confidence_threshold > 0:
    cfg.car.confidence_threshold = args.confidence_threshold
  if args.clipes_threshold > 0:
    cfg.car.clipes_threshold = args.clipes_threshold
  if args.bg_factor > 0:
    cfg.car.bg_factor = args.bg_factor
  if args.stuff_bg_factor > 0:
    cfg.car.stuff_bg_factor = args.stuff_bg_factor
  if args.output_path is not None:
    cfg.test.output_path = args.output_path
  if args.visual_prompt_type is not None:
    cfg.car.visual_prompt_type = args.visual_prompt_type
  if args.stuff_visual_prompt_type is not None:
    cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type

  try:
    data_root = cfg.test.data_root
  except ValueError:
    data_root = None

  dataset_test = get_dataset(
      cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
  )
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  stuff_label_space = None
  if cfg.test.ds_name == 'voc':
    label_space = VOC_CLASSES
  elif cfg.test.ds_name == 'cocostuff':
    label_space = COCO_OBJECT_CLASSES
  elif cfg.test.ds_name == 'context':
    # label_space = PASCAL_CONTEXT_CLASSES
    label_space = PASCAL_CONTEXT_THING_CLASS
    stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
  elif cfg.test.ds_name == 'ade':
    label_space = ADE_THING_CLASS
    stuff_label_space = ADE_STUFF_CLASS
  elif cfg.test.ds_name == 'pascal_459':
    label_space = PASCAL_459_THING_CLASS
    stuff_label_space = PASCAL_459_STUFF_CLASS
  elif cfg.test.ds_name == 'ade_847':
    label_space = ADE_847_THING_CLASS
    stuff_label_space = ADE_847_STUFF_CLASS
  else:
    label_space = None

  num_chunks, chunk_index = 1, 0
  if hasattr(cfg.test, 'num_chunks'):
    num_chunks = cfg.test.num_chunks
  if hasattr(cfg.test, 'chunk_index'):
    chunk_index = cfg.test.chunk_index
  # Size of each chunk
  chunk_size = len(dataset_test) // num_chunks
  # Choose which chunk to load (0-indexed)
  # Define a subset of the dataset
  subset_indices = range(
      chunk_index * chunk_size, (chunk_index + 1) * chunk_size
  )
  subset_dataset = Subset(dataset_test, indices=subset_indices)
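  # For example, with 1000 test images, num_chunks=4 and chunk_index=1, each
  # chunk holds 250 images and the subset above covers indices 250-499; any
  # remainder of len(dataset_test) % num_chunks images is dropped.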
  data_loader_test = torch.utils.data.DataLoader(
      subset_dataset, batch_size=1, shuffle=False, num_workers=1
  )
  car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
  car_model = car_model.to(device)
  if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
    print('Using sam online')
    # sam_checkpoint, model_type = build_sam_config(cfg)
    build_sam_config(cfg)
  evaluate(
      data_loader_test,
      cfg,
      car_model,
      test_cfg=cfg.test,
      label_space=label_space,
      stuff_label_space=stuff_label_space,
  )


if __name__ == '__main__':
  args = parse_args()
  main(args)
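
# Example invocation (a sketch; the script filename below is a placeholder for
# this file, and the config path is the argparse default above):
#   python eval_car.py --cfg-path configs/refcoco_test_prompt.yaml \
#       --output_path ./outputs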