# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate CaR on segmentation benchmarks."""
# pylint: disable=g-importing-member
import argparse
import numpy as np
import torch
from torch.utils import tensorboard
import torch.utils.data
from torch.utils.data import Subset
import torchvision.transforms as T
# pylint: disable=g-bad-import-order
from modeling.model.car import CaR
from sam.utils import build_sam_config
from utils.utils import Config
from utils.utils import load_yaml
from utils.utils import MetricLogger
from utils.utils import SmoothedValue
from utils.inference_pipeline import inference_car
from utils.merge_mask import merge_masks_simple
# Datasets
# pylint: disable=g-multiple-import
from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset
from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset
from data.coco import COCO_OBJECT_CLASSES, COCODataset
from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset
from data.gres import GReferDataset
from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset
from data.refcoco import ReferDataset
from data.voc import VOC_CLASSES, VOCDataset
IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512
# set random seed
torch.manual_seed(0)
np.random.seed(0)
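# Dataset factory: dispatches on `ds_name`. The refcoco*/gres datasets are
# referring-expression benchmarks evaluated per sentence; the remaining
# datasets are evaluated as semantic segmentation (see `evaluate`).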
def get_dataset(cfg, ds_name, split, transform, data_root=None):
"""Get dataset."""
data_args = dict(root=data_root) if data_root is not None else {}
if 'refcoco' in ds_name:
splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
ds = ReferDataset(
dataset=ds_name,
splitBy=splitby,
split=split,
image_transforms=transform,
target_transforms=transform,
eval_mode=True,
prompts_augment=cfg.test.prompts_augment,
**data_args,
)
elif ds_name == 'gres':
ds = GReferDataset(split=split, transform=transform, **data_args)
elif ds_name == 'voc':
ds = VOCDataset(
year='2012',
split=split,
transform=transform,
target_transform=transform,
**data_args,
)
elif ds_name == 'cocostuff':
ds = COCODataset(transform=transform, **data_args)
elif ds_name == 'context':
ds = CONTEXTDataset(
year='2010', transform=transform, split=split, **data_args
)
elif ds_name == 'ade':
ds = ADEDataset(split=split, transform=transform, **data_args)
elif ds_name == 'pascal_459':
ds = Pascal459Dataset(split=split, transform=transform, **data_args)
elif ds_name == 'ade_847':
ds = ADE847Dataset(split=split, transform=transform, **data_args)
else:
raise ValueError(f'Dataset {ds_name} not implemented')
return ds
def get_transform():
  """Resize to a fixed square resolution and convert to a tensor."""
  transforms = [
      # T.Resize expects (height, width); both are 512 here.
      T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
      T.ToTensor(),
  ]
  return T.Compose(transforms)
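# `assign_label` paints per-class masks onto a single label map in ascending
# score order, so higher-scoring masks overwrite lower-scoring ones wherever
# they overlap. Stuff masks (if given) are painted first, letting thing masks
# overwrite stuff. Class ids are shifted by +1 so 0 stays the background.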
def assign_label(
all_masks,
scores,
stuff_masks=None,
stuff_scores=None,
id_mapping=None,
stuff_id_mapping=None,
):
"""Assign labels."""
label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
if stuff_masks is not None:
sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
stuff_masks = stuff_masks[sorted_idxs]
stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
if score > 0:
# convert mask to boolean
mask = mask > 0.5
# assign label
if stuff_id_mapping is not None:
label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
else:
label_preds[mask] = sorted_idx + 1
sorted_idxs = np.argsort(scores.detach().cpu().numpy())
all_masks = all_masks[sorted_idxs]
scores = scores.detach().cpu().numpy()[sorted_idxs]
for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
if score > 0:
# convert mask to boolean
mask = mask > 0.5
# assign label
if id_mapping is not None:
label_preds[mask] = id_mapping[sorted_idx] + 1
else:
label_preds[mask] = sorted_idx + 1
return label_preds
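# For datasets with separate thing/stuff label spaces, CaR runs twice: once
# over the thing classes and once over the stuff classes, with stuff-specific
# visual prompt type and background factor temporarily swapped in and restored
# afterwards. The two mask sets are merged by `assign_label` via the dataset's
# thing/stuff class-id mappings.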
def eval_semantic(
label_space,
algo,
cfg,
model,
image_path,
stuff_label_space=None,
sam_pipeline=None,
):
"""Semantic segmentation evaluation."""
if label_space is None:
raise ValueError(
'label_space must be provided for semantic segmentation evaluation'
)
if algo == 'car':
all_masks, scores = inference_car(
cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
)
if stuff_label_space is not None:
if cfg.test.ds_name == 'context':
thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
elif cfg.test.ds_name == 'ade':
thing_id_mapping = ADE_THING_CLASS_ID
stuff_id_mapping = ADE_STUFF_CLASS_ID
elif cfg.test.ds_name == 'pascal_459':
thing_id_mapping = PASCAL_459_THING_CLASS_ID
stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
elif cfg.test.ds_name == 'ade_847':
thing_id_mapping = ADE_847_THING_CLASS_ID
stuff_id_mapping = ADE_847_STUFF_CLASS_ID
else:
raise ValueError(f'Dataset {cfg.test.ds_name} not supported')
model.mask_generator.set_bg_cls(label_space)
model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
model.set_bg_factor(cfg.car.stuff_bg_factor)
stuff_masks, stuff_scores = inference_car(
cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
)
model.mask_generator.set_bg_cls(cfg.car.bg_cls)
model.set_visual_prompt_type(cfg.car.visual_prompt_type)
model.set_bg_factor(cfg.car.bg_factor)
all_masks = all_masks.detach().cpu().numpy()
stuff_masks = stuff_masks.detach().cpu().numpy()
label_preds = assign_label(
all_masks,
scores,
stuff_masks=stuff_masks,
stuff_scores=stuff_scores,
id_mapping=thing_id_mapping,
stuff_id_mapping=stuff_id_mapping,
)
else:
all_masks = all_masks.detach().cpu().numpy()
label_preds = assign_label(all_masks, scores)
return label_preds.squeeze()
else:
raise NotImplementedError(f'algo {algo} not implemented')
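# `_fast_hist` builds an n_class x n_class confusion matrix in one pass: each
# pixel falls into bin (true_class * n_class + pred_class), so the reshaped
# bincount gives hist[i, j] = #pixels with GT class i predicted as class j.
# GT pixels outside [0, n_class), e.g. an ignore label, are masked out.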
def _fast_hist(label_true, label_pred, n_class=21):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) + label_pred[mask],
minlength=n_class**2,
).reshape(n_class, n_class)
return hist
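# Metrics derived from the confusion matrix `hist`:
#   pixel accuracy   = trace(hist) / hist.sum()
#   per-class IoU_i  = hist[i, i] / (row_i + col_i - hist[i, i])
#   mIoU             = mean IoU over classes that appear in the ground truth
#   frequency-weighted IoU weights each class IoU by its GT pixel frequency.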
def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
"""Semantic segmentation IOU."""
hist = np.zeros((n_class, n_class))
for lt, lp in zip(label_trues, label_preds):
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
if ignore_background:
hist = hist[1:, 1:]
acc = np.diag(hist).sum() / hist.sum()
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
  valid = hist.sum(axis=1) > 0  # only average IoU over classes present in GT
if valid.sum() == 0:
mean_iu = 0
else:
mean_iu = np.nanmean(iu[valid])
freq = hist.sum(axis=1) / hist.sum()
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
if ignore_background:
cls_iu = dict(zip(range(1, n_class), iu))
else:
cls_iu = dict(zip(range(n_class), iu))
return {
'Pixel Accuracy': acc,
'Mean Accuracy': acc_cls,
'Frequency Weighted IoU': fwavacc,
'mIoU': mean_iu,
'Class IoU': cls_iu,
}
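# `evaluate` supports two modes. In 'refer' mode each referring expression is
# scored individually: per-expression IoU is averaged into mIoU, cumulative
# intersection/union gives the overall IoU (oIoU), and precision@{0.5..0.9}
# counts expressions whose IoU clears each threshold. In 'semantic' mode
# per-image label maps are accumulated and scored with `semantic_iou`.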
def evaluate(
data_loader,
cfg,
model,
test_cfg,
label_space=None,
stuff_label_space=None,
sam_pipeline=None,
):
"""Run evaluation."""
if (
test_cfg.ds_name
not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
and test_cfg.seg_mode == 'semantic'
):
    raise ValueError((
        'Semantic segmentation evaluation is only implemented for the voc, '
        'cocostuff, context, ade, pascal_459, and ade_847 datasets.'
    ))
metric_logger = MetricLogger(delimiter=' ')
metric_logger.add_meter(
'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
)
# evaluation variables
cum_i, cum_u = 0, 0
eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
seg_total = 0
mean_iou = []
header = 'Test:'
  label_preds, label_gts = [], []
  print(f'Number of evaluation batches: {len(data_loader)}')
  cc = 0  # step counter for tensorboard logging
use_tensorboard = False
if hasattr(cfg.test, 'use_tensorboard'):
use_tensorboard = cfg.test.use_tensorboard
if use_tensorboard:
writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
for data in metric_logger.log_every(data_loader, 1, header):
_, image_paths, target_list, sentences_list = data
if not isinstance(target_list, list):
target_list, sentences_list = [target_list], [sentences_list]
for target, sentences in zip(target_list, sentences_list):
image_path = image_paths[0]
if test_cfg.seg_mode == 'refer':
all_masks, all_scores = inference_car(
cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
)
final_mask = merge_masks_simple(
all_masks, *target.shape[1:], scores=all_scores
)
intersection, union, cur_iou = compute_iou(final_mask, target)
metric_logger.update(mIoU=cur_iou)
mean_iou.append(cur_iou)
if use_tensorboard:
writer.add_scalar('Mean IoU', cur_iou, cc)
cum_i += intersection
cum_u += union
for n_eval_iou in range(len(eval_seg_iou_list)):
eval_seg_iou = eval_seg_iou_list[n_eval_iou]
seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
seg_total += 1
elif test_cfg.seg_mode == 'semantic':
# torch.cuda.empty_cache()
label_pred = eval_semantic(
label_space,
test_cfg.algo,
cfg,
model,
image_path,
stuff_label_space,
)
label_gt = target.squeeze().cpu().numpy()
cur_iou = semantic_iou(
[label_gt],
[label_pred],
n_class=cfg.test.n_class,
ignore_background=cfg.test.ignore_background,
)['mIoU']
metric_logger.update(mIoU=cur_iou)
label_preds.append(label_pred)
label_gts.append(label_gt)
cc += 1
if test_cfg.seg_mode == 'refer':
mean_iou = np.array(mean_iou)
m_iou = np.mean(mean_iou)
if use_tensorboard:
writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
print('Final results:')
print('Mean IoU is %.2f\n' % (m_iou * 100.0))
results_str = ''
for n_eval_iou in range(len(eval_seg_iou_list)):
results_str += ' precision@%s = %.2f\n' % (
str(eval_seg_iou_list[n_eval_iou]),
seg_correct[n_eval_iou] * 100.0 / seg_total,
)
o_iou = cum_i * 100.0 / cum_u
results_str += ' overall IoU = %.2f\n' % o_iou
if use_tensorboard:
writer.add_scalar('oIoU', o_iou, 0)
print(results_str)
elif test_cfg.seg_mode == 'semantic':
iou_score = semantic_iou(
label_gts,
label_preds,
n_class=cfg.test.n_class,
ignore_background=cfg.test.ignore_background,
)
if use_tensorboard:
writer.add_scalar('mIoU', iou_score['mIoU'].item(), len(data_loader))
print(iou_score)
if use_tensorboard:
writer.close()
def compute_iou(pred_seg, gd_seg):
  """Compute intersection, union, and IoU between two binary masks."""
  intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
  union = torch.sum(torch.logical_or(pred_seg, gd_seg))
  # Guard against an empty union before dividing.
  if union == 0:
    return intersection, union, 0
  iou = intersection * 1.0 / union
  return intersection, union, iou
def list_of_strings(arg):
  """Argparse type: parse a comma-separated string into a list of strings."""
  return [a.strip() for a in arg.split(',')]
# pylint: disable=redefined-outer-name
def parse_args():
"""Parse arguments."""
  parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument(
'--cfg-path',
default='configs/refcoco_test_prompt.yaml',
help='path to configuration file.',
)
  parser.add_argument(
      '--index', default=0, type=int, help='task split index (unused here)'
  )
parser.add_argument('--mask_threshold', default=0.0, type=float)
parser.add_argument('--confidence_threshold', default=0.0, type=float)
parser.add_argument('--clipes_threshold', default=0.0, type=float)
parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
parser.add_argument('--bg_factor', default=0.0, type=float)
parser.add_argument('--output_path', default=None, type=str)
parser.add_argument(
'--visual_prompt_type', default=None, type=list_of_strings
)
parser.add_argument(
'--stuff_visual_prompt_type', default=None, type=list_of_strings
)
args = parser.parse_args()
return args
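# Command-line overrides: positive threshold/factor flags replace the matching
# config values, so parameter sweeps can be scripted without editing the YAML.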
def main(args):
cfg = Config(**load_yaml(args.cfg_path))
if args.mask_threshold > 0:
cfg.car.mask_threshold = args.mask_threshold
if args.confidence_threshold > 0:
cfg.car.confidence_threshold = args.confidence_threshold
if args.clipes_threshold > 0:
cfg.car.clipes_threshold = args.clipes_threshold
if args.bg_factor > 0:
cfg.car.bg_factor = args.bg_factor
if args.stuff_bg_factor > 0:
cfg.car.stuff_bg_factor = args.stuff_bg_factor
if args.output_path is not None:
cfg.test.output_path = args.output_path
if args.visual_prompt_type is not None:
cfg.car.visual_prompt_type = args.visual_prompt_type
if args.stuff_visual_prompt_type is not None:
cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type
  # `data_root` is optional in the config; fall back to None when absent.
  try:
    data_root = cfg.test.data_root
  except (AttributeError, ValueError):
    data_root = None
dataset_test = get_dataset(
cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
stuff_label_space = None
if cfg.test.ds_name == 'voc':
label_space = VOC_CLASSES
elif cfg.test.ds_name == 'cocostuff':
label_space = COCO_OBJECT_CLASSES
elif cfg.test.ds_name == 'context':
label_space = PASCAL_CONTEXT_THING_CLASS
stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
elif cfg.test.ds_name == 'ade':
label_space = ADE_THING_CLASS
stuff_label_space = ADE_STUFF_CLASS
elif cfg.test.ds_name == 'pascal_459':
label_space = PASCAL_459_THING_CLASS
stuff_label_space = PASCAL_459_STUFF_CLASS
elif cfg.test.ds_name == 'ade_847':
label_space = ADE_847_THING_CLASS
stuff_label_space = ADE_847_STUFF_CLASS
else:
label_space = None
num_chunks, chunk_index = 1, 0
if hasattr(cfg.test, 'num_chunks'):
num_chunks = cfg.test.num_chunks
if hasattr(cfg.test, 'chunk_index'):
chunk_index = cfg.test.chunk_index
  # Shard the dataset so evaluation can be split across parallel jobs.
  # Each job handles one contiguous chunk (0-indexed by chunk_index); the
  # last chunk absorbs the remainder when the split is uneven, so no sample
  # is silently dropped.
  chunk_size = len(dataset_test) // num_chunks
  chunk_start = chunk_index * chunk_size
  chunk_end = (
      len(dataset_test)
      if chunk_index == num_chunks - 1
      else chunk_start + chunk_size
  )
  subset_dataset = Subset(dataset_test, indices=range(chunk_start, chunk_end))
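  # Batch size is fixed at 1: `inference_car` consumes a single image path,
  # and `evaluate` reads `image_paths[0]` from each batch.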
data_loader_test = torch.utils.data.DataLoader(
subset_dataset, batch_size=1, shuffle=False, num_workers=1
)
car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
car_model = car_model.to(device)
  if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
    print('Using SAM online')
    # Resolve the SAM checkpoint/model-type settings from the config; the
    # return value is unused here.
    build_sam_config(cfg)
evaluate(
data_loader_test,
cfg,
car_model,
test_cfg=cfg.test,
label_space=label_space,
stuff_label_space=stuff_label_space,
)
if __name__ == '__main__':
args = parse_args()
main(args)