"""Evaluate CaR on segmentation benchmarks.""" |
|
|
|
import argparse |
|
import numpy as np |
|
import torch |
|
from torch.utils import tensorboard |
|
import torch.utils.data |
|
from torch.utils.data import Subset |
|
import torchvision.transforms as T |
|
|
|
|
|
from modeling.model.car import CaR |
|
from sam.utils import build_sam_config |
|
from utils.utils import Config |
|
from utils.utils import load_yaml |
|
from utils.utils import MetricLogger |
|
from utils.utils import SmoothedValue |
|
from utils.inference_pipeline import inference_car |
|
from utils.merge_mask import merge_masks_simple |
|
|
|
|
|
|
|
from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset |
|
from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset |
|
from data.coco import COCO_OBJECT_CLASSES, COCODataset |
|
from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset |
|
from data.gres import GReferDataset |
|
from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset |
|
from data.refcoco import ReferDataset |
|
from data.voc import VOC_CLASSES, VOCDataset |
|
|
|
|
|
IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512 |
|
|
|
|
|
torch.manual_seed(0) |
|
np.random.seed(0) |
|
|
|
|
|
def get_dataset(cfg, ds_name, split, transform, data_root=None):
    """Build the evaluation dataset named by `ds_name`."""
    data_args = dict(root=data_root) if data_root is not None else {}
    if 'refcoco' in ds_name:
        splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
        ds = ReferDataset(
            dataset=ds_name,
            splitBy=splitby,
            split=split,
            image_transforms=transform,
            target_transforms=transform,
            eval_mode=True,
            prompts_augment=cfg.test.prompts_augment,
            **data_args,
        )
    elif ds_name == 'gres':
        ds = GReferDataset(split=split, transform=transform, **data_args)
    elif ds_name == 'voc':
        ds = VOCDataset(
            year='2012',
            split=split,
            transform=transform,
            target_transform=transform,
            **data_args,
        )
    elif ds_name == 'cocostuff':
        ds = COCODataset(transform=transform, **data_args)
    elif ds_name == 'context':
        ds = CONTEXTDataset(
            year='2010', transform=transform, split=split, **data_args
        )
    elif ds_name == 'ade':
        ds = ADEDataset(split=split, transform=transform, **data_args)
    elif ds_name == 'pascal_459':
        ds = Pascal459Dataset(split=split, transform=transform, **data_args)
    elif ds_name == 'ade_847':
        ds = ADE847Dataset(split=split, transform=transform, **data_args)
    else:
        raise ValueError(f'Dataset {ds_name} not implemented')
    return ds
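# Illustrative usage, assuming the datasets yield (index, image_path, target,
# sentences)-style tuples as `evaluate` below expects; the root path is
# hypothetical:
#
#   ds = get_dataset(cfg, 'voc', 'val', get_transform(), data_root='data/voc')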
|
|
|
|
|
def get_transform():
    """Resize to a fixed size and convert to a tensor.

    Note that T.Resize expects (height, width).
    """
    transforms = [
        T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        T.ToTensor(),
    ]
    return T.Compose(transforms)
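# Minimal sketch of the transform in isolation (the image file is hypothetical):
#
#   from PIL import Image
#   img = Image.open('example.jpg').convert('RGB')
#   x = get_transform()(img)  # float tensor, shape (3, 512, 512), values in [0, 1]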
|
|
|
|
|
def assign_label(
    all_masks,
    scores,
    stuff_masks=None,
    stuff_scores=None,
    id_mapping=None,
    stuff_id_mapping=None,
):
    """Paint per-class masks into a single label map.

    Masks are painted in ascending score order, so where masks overlap the
    highest-scoring class wins. Stuff masks are painted first and thing masks
    second, letting things overwrite stuff. Label 0 is reserved for background.
    """
    label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
    if stuff_masks is not None:
        sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
        stuff_masks = stuff_masks[sorted_idxs]
        stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
        for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
            if score > 0:
                mask = mask > 0.5
                if stuff_id_mapping is not None:
                    label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
                else:
                    label_preds[mask] = sorted_idx + 1
    sorted_idxs = np.argsort(scores.detach().cpu().numpy())
    all_masks = all_masks[sorted_idxs]
    scores = scores.detach().cpu().numpy()[sorted_idxs]
    for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
        if score > 0:
            mask = mask > 0.5
            if id_mapping is not None:
                label_preds[mask] = id_mapping[sorted_idx] + 1
            else:
                label_preds[mask] = sorted_idx + 1
    return label_preds
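# Worked example: two classes whose masks overlap at the middle pixel. Class 0
# scores higher, so it is painted last and wins the overlap:
#
#   masks = np.stack([np.array([[1.0, 1.0, 0.0]]), np.array([[0.0, 1.0, 1.0]])])
#   scores = torch.tensor([0.9, 0.4])
#   assign_label(masks, scores)  # -> array([[1, 1, 2]], dtype=int32)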
|
|
|
|
|
def eval_semantic(
    label_space,
    algo,
    cfg,
    model,
    image_path,
    stuff_label_space=None,
    sam_pipeline=None,
):
    """Semantic segmentation evaluation for a single image."""
    if label_space is None:
        raise ValueError(
            'label_space must be provided for semantic segmentation evaluation'
        )
    if algo == 'car':
        # First pass: segment the "thing" classes with the default settings.
        all_masks, scores = inference_car(
            cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
        )
        if stuff_label_space is not None:
            if cfg.test.ds_name == 'context':
                thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
                stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
            elif cfg.test.ds_name == 'ade':
                thing_id_mapping = ADE_THING_CLASS_ID
                stuff_id_mapping = ADE_STUFF_CLASS_ID
            elif cfg.test.ds_name == 'pascal_459':
                thing_id_mapping = PASCAL_459_THING_CLASS_ID
                stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
            elif cfg.test.ds_name == 'ade_847':
                thing_id_mapping = ADE_847_THING_CLASS_ID
                stuff_id_mapping = ADE_847_STUFF_CLASS_ID
            else:
                raise ValueError(f'Dataset {cfg.test.ds_name} not supported')

            # Second pass: segment the "stuff" classes, supplying the thing
            # labels as background classes, then restore the original settings.
            model.mask_generator.set_bg_cls(label_space)
            model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
            model.set_bg_factor(cfg.car.stuff_bg_factor)
            stuff_masks, stuff_scores = inference_car(
                cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
            )
            model.mask_generator.set_bg_cls(cfg.car.bg_cls)
            model.set_visual_prompt_type(cfg.car.visual_prompt_type)
            model.set_bg_factor(cfg.car.bg_factor)

            all_masks = all_masks.detach().cpu().numpy()
            stuff_masks = stuff_masks.detach().cpu().numpy()
            label_preds = assign_label(
                all_masks,
                scores,
                stuff_masks=stuff_masks,
                stuff_scores=stuff_scores,
                id_mapping=thing_id_mapping,
                stuff_id_mapping=stuff_id_mapping,
            )
        else:
            all_masks = all_masks.detach().cpu().numpy()
            label_preds = assign_label(all_masks, scores)
        return label_preds.squeeze()
    else:
        raise NotImplementedError(f'algo {algo} not implemented')
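# Illustrative call for a dataset without a stuff split (e.g. VOC; the image
# path is hypothetical):
#
#   label_map = eval_semantic(VOC_CLASSES, 'car', cfg, model, 'img.jpg')
#   # label_map is an (H, W) int32 array; 0 is background, i + 1 is class i.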
|
|
|
|
|
def _fast_hist(label_true, label_pred, n_class=21):
    """Accumulate a confusion matrix: hist[i, j] counts pixels of true class i
    predicted as class j. Pixels with out-of-range true labels are ignored."""
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class**2,
    ).reshape(n_class, n_class)
    return hist
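# Worked example (n_class=3): with
#
#   label_true = np.array([0, 1, 2, 1])
#   label_pred = np.array([0, 2, 2, 1])
#
# the three correct pixels land on the diagonal (hist[0, 0], hist[1, 1],
# hist[2, 2]) and the one confusion lands at hist[1, 2].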
|
|
|
|
|
def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
    """Semantic segmentation IoU and related pixel metrics."""
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    if ignore_background:
        hist = hist[1:, 1:]
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    valid = hist.sum(axis=1) > 0
    if valid.sum() == 0:
        mean_iu = 0.0
    else:
        mean_iu = np.nanmean(iu[valid])
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    if ignore_background:
        cls_iu = dict(zip(range(1, n_class), iu))
    else:
        cls_iu = dict(zip(range(n_class), iu))
    return {
        'Pixel Accuracy': acc,
        'Mean Accuracy': acc_cls,
        'Frequency Weighted IoU': fwavacc,
        'mIoU': mean_iu,
        'Class IoU': cls_iu,
    }
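# Worked example: one 2x2 image, two classes, one mispredicted pixel:
#
#   gt = np.array([[0, 0], [1, 1]])
#   pred = np.array([[0, 0], [1, 0]])
#   semantic_iou([gt], [pred], n_class=2)['mIoU']  # (2/3 + 1/2) / 2 = 7/12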
|
|
|
|
|
def evaluate(
    data_loader,
    cfg,
    model,
    test_cfg,
    label_space=None,
    stuff_label_space=None,
    sam_pipeline=None,
):
    """Run evaluation."""
    if (
        test_cfg.ds_name
        not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
        and test_cfg.seg_mode == 'semantic'
    ):
        raise ValueError(
            'Semantic segmentation evaluation is only implemented for the '
            'voc, cocostuff, context, ade, pascal_459 and ade_847 datasets'
        )

    metric_logger = MetricLogger(delimiter=' ')
    metric_logger.add_meter(
        'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
    )

    # Referring-segmentation accumulators: cumulative intersection/union and
    # precision@X hit counts.
    cum_i, cum_u = 0, 0
    eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_iou = []
    header = 'Test:'

    label_preds, label_gts = [], []
    print(f'Number of evaluation batches: {len(data_loader)}')
    step = 0
    use_tensorboard = False
    if hasattr(cfg.test, 'use_tensorboard'):
        use_tensorboard = cfg.test.use_tensorboard

    if use_tensorboard:
        writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
    for data in metric_logger.log_every(data_loader, 1, header):
        _, image_paths, target_list, sentences_list = data

        if not isinstance(target_list, list):
            target_list, sentences_list = [target_list], [sentences_list]
        for target, sentences in zip(target_list, sentences_list):
            image_path = image_paths[0]

            if test_cfg.seg_mode == 'refer':
                all_masks, all_scores = inference_car(
                    cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
                )
                final_mask = merge_masks_simple(
                    all_masks, *target.shape[1:], scores=all_scores
                )
                intersection, union, cur_iou = compute_iou(final_mask, target)

                metric_logger.update(mIoU=cur_iou)
                mean_iou.append(cur_iou)
                if use_tensorboard:
                    writer.add_scalar('Mean IoU', cur_iou, step)
                cum_i += intersection
                cum_u += union
                for n_eval_iou in range(len(eval_seg_iou_list)):
                    eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                    seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
                seg_total += 1
            elif test_cfg.seg_mode == 'semantic':
                label_pred = eval_semantic(
                    label_space,
                    test_cfg.algo,
                    cfg,
                    model,
                    image_path,
                    stuff_label_space,
                    sam_pipeline=sam_pipeline,
                )
                label_gt = target.squeeze().cpu().numpy()
                cur_iou = semantic_iou(
                    [label_gt],
                    [label_pred],
                    n_class=cfg.test.n_class,
                    ignore_background=cfg.test.ignore_background,
                )['mIoU']
                metric_logger.update(mIoU=cur_iou)
                label_preds.append(label_pred)
                label_gts.append(label_gt)

            step += 1

    if test_cfg.seg_mode == 'refer':
        mean_iou = np.array(mean_iou)
        m_iou = np.mean(mean_iou)
        if use_tensorboard:
            writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
        print('Final results:')
        print('Mean IoU is %.2f\n' % (m_iou * 100.0))
        results_str = ''
        for n_eval_iou in range(len(eval_seg_iou_list)):
            results_str += ' precision@%s = %.2f\n' % (
                str(eval_seg_iou_list[n_eval_iou]),
                seg_correct[n_eval_iou] * 100.0 / seg_total,
            )
        o_iou = cum_i * 100.0 / cum_u
        results_str += ' overall IoU = %.2f\n' % o_iou
        if use_tensorboard:
            writer.add_scalar('oIoU', o_iou, 0)
        print(results_str)
    elif test_cfg.seg_mode == 'semantic':
        iou_score = semantic_iou(
            label_gts,
            label_preds,
            n_class=cfg.test.n_class,
            ignore_background=cfg.test.ignore_background,
        )
        if use_tensorboard:
            writer.add_scalar('mIoU', float(iou_score['mIoU']), len(data_loader))

        print(iou_score)
    if use_tensorboard:
        writer.close()
|
|
|
|
|
def compute_iou(pred_seg, gd_seg):
    """Compute intersection, union, and IoU of two binary masks."""
    intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
    union = torch.sum(torch.logical_or(pred_seg, gd_seg))
    # Guard against an empty union before dividing.
    if union == 0:
        return intersection, union, 0.0
    iou = intersection * 1.0 / union
    return intersection, union, iou
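# Illustrative check with boolean masks:
#
#   a = torch.tensor([[True, False], [True, True]])
#   b = torch.tensor([[True, True], [False, True]])
#   compute_iou(a, b)  # intersection=2, union=4, iou=0.5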
|
|
|
|
|
def list_of_strings(arg):
    """Parse a comma-separated CLI argument into a list of stripped strings."""
    return [a.strip() for a in arg.split(',')]
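# Used as an argparse `type` below, so a flag value like 'a, b' is parsed to
# ['a', 'b'] (the values here are placeholders, not real prompt types).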
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Evaluation')
    parser.add_argument(
        '--cfg-path',
        default='configs/refcoco_test_prompt.yaml',
        help='path to configuration file.',
    )
    parser.add_argument('--index', default=0, type=int, help='split task')
    parser.add_argument('--mask_threshold', default=0.0, type=float)
    parser.add_argument('--confidence_threshold', default=0.0, type=float)
    parser.add_argument('--clipes_threshold', default=0.0, type=float)
    parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
    parser.add_argument('--bg_factor', default=0.0, type=float)
    parser.add_argument('--output_path', default=None, type=str)
    parser.add_argument(
        '--visual_prompt_type', default=None, type=list_of_strings
    )
    parser.add_argument(
        '--stuff_visual_prompt_type', default=None, type=list_of_strings
    )

    args = parser.parse_args()

    return args
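# Example invocation (script and config file names are illustrative):
#
#   python eval.py --cfg-path configs/voc_test.yaml --bg_factor 0.8
#
# Thresholds and factors left at their 0.0 defaults keep the values from the
# YAML config; see main() below.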
|
|
|
|
|
def main(args):
    """Build the config, dataset, and model, then run evaluation."""
    cfg = Config(**load_yaml(args.cfg_path))
    # CLI values of 0.0 / None mean "keep the value from the YAML config".
    if args.mask_threshold > 0:
        cfg.car.mask_threshold = args.mask_threshold
    if args.confidence_threshold > 0:
        cfg.car.confidence_threshold = args.confidence_threshold
    if args.clipes_threshold > 0:
        cfg.car.clipes_threshold = args.clipes_threshold
    if args.bg_factor > 0:
        cfg.car.bg_factor = args.bg_factor
    if args.stuff_bg_factor > 0:
        cfg.car.stuff_bg_factor = args.stuff_bg_factor
    if args.output_path is not None:
        cfg.test.output_path = args.output_path
    if args.visual_prompt_type is not None:
        cfg.car.visual_prompt_type = args.visual_prompt_type
    if args.stuff_visual_prompt_type is not None:
        cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type

    # Fall back to None when the config does not define a data root.
    data_root = getattr(cfg.test, 'data_root', None)

    dataset_test = get_dataset(
        cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    stuff_label_space = None
    if cfg.test.ds_name == 'voc':
        label_space = VOC_CLASSES
    elif cfg.test.ds_name == 'cocostuff':
        label_space = COCO_OBJECT_CLASSES
    elif cfg.test.ds_name == 'context':
        label_space = PASCAL_CONTEXT_THING_CLASS
        stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
    elif cfg.test.ds_name == 'ade':
        label_space = ADE_THING_CLASS
        stuff_label_space = ADE_STUFF_CLASS
    elif cfg.test.ds_name == 'pascal_459':
        label_space = PASCAL_459_THING_CLASS
        stuff_label_space = PASCAL_459_STUFF_CLASS
    elif cfg.test.ds_name == 'ade_847':
        label_space = ADE_847_THING_CLASS
        stuff_label_space = ADE_847_STUFF_CLASS
    else:
        label_space = None

    num_chunks, chunk_index = 1, 0
    if hasattr(cfg.test, 'num_chunks'):
        num_chunks = cfg.test.num_chunks
    if hasattr(cfg.test, 'chunk_index'):
        chunk_index = cfg.test.chunk_index

    # Ceil division so the last chunk also covers the remainder.
    chunk_size = (len(dataset_test) + num_chunks - 1) // num_chunks

    subset_indices = range(
        chunk_index * chunk_size,
        min((chunk_index + 1) * chunk_size, len(dataset_test)),
    )
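    # Worked example of the chunking above: with 100 samples and num_chunks=3,
    # chunk_size = ceil(100 / 3) = 34, so the chunks cover indices [0, 34),
    # [34, 68) and [68, 100); the min() keeps the last chunk in bounds.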
|
    subset_dataset = Subset(dataset_test, indices=subset_indices)

    data_loader_test = torch.utils.data.DataLoader(
        subset_dataset, batch_size=1, shuffle=False, num_workers=1
    )

    car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
    car_model = car_model.to(device)

    if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
        print('Using SAM online')

    build_sam_config(cfg)

    evaluate(
        data_loader_test,
        cfg,
        car_model,
        test_cfg=cfg.test,
        label_space=label_space,
        stuff_label_space=stuff_label_space,
    )
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()
    main(args)