curt-park's picture
Refactor code
1615d09
raw
history blame
5.07 kB
from datetime import timedelta
from pathlib import Path
import numpy as np
import torch
from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset,
PascalVocDataset, SBDEvaluationDataset)
from isegm.utils.serialization import load_model
def get_time_metrics(all_ious, elapsed_time):
    """Average the total evaluation time per click and per image.

    Args:
        all_ious: Sequence with one entry per image; each entry is a sized
            sequence whose length is counted as that image's number of clicks
            (presumably one IoU value per click).
        elapsed_time: Total elapsed time for the whole evaluation.

    Returns:
        Tuple ``(seconds_per_click, seconds_per_image)``, i.e.
        ``elapsed_time / total_clicks`` and ``elapsed_time / number_of_images``.
        Raises ZeroDivisionError when there are no clicks (or no images).
    """
    image_count = len(all_ious)
    click_count = sum(len(per_image) for per_image in all_ious)
    per_click = elapsed_time / click_count
    per_image = elapsed_time / image_count
    return per_click, per_image
def load_is_model(checkpoint, device, **kwargs):
    """Load one interactive-segmentation model, or an ensemble, from a checkpoint.

    Args:
        checkpoint: Path (``str`` or ``Path``) to a checkpoint file, which is
            read with ``torch.load`` onto the CPU, or an already-loaded
            checkpoint object.
        device: Target device forwarded to ``load_single_is_model``.
        **kwargs: Extra options forwarded to ``load_single_is_model``.

    Returns:
        If the checkpoint is a list of per-model state dicts: a tuple
        ``(first_model, all_models)``. The first entry is built twice, so
        ``first_model`` and ``all_models[0]`` are separate instances.
        Otherwise: the single loaded model.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted checkpoint files.
    loaded = (
        torch.load(checkpoint, map_location="cpu")
        if isinstance(checkpoint, (str, Path))
        else checkpoint
    )

    if not isinstance(loaded, list):
        return load_single_is_model(loaded, device, **kwargs)

    primary = load_single_is_model(loaded[0], device, **kwargs)
    ensemble = [load_single_is_model(entry, device, **kwargs) for entry in loaded]
    return primary, ensemble
def load_single_is_model(state_dict, device, **kwargs):
    """Build a frozen, inference-ready model from one checkpoint dict.

    Args:
        state_dict: Mapping with a ``"config"`` entry (passed to
            ``load_model`` together with ``**kwargs``) and a ``"state_dict"``
            entry holding the weights.
        device: Device the model is moved to.
        **kwargs: Extra options forwarded to ``load_model``.

    Returns:
        The model with all parameters frozen, moved to ``device``, and in eval mode.
    """
    net = load_model(state_dict["config"], **kwargs)
    # strict=False: missing or unexpected weight keys are tolerated rather than raising.
    net.load_state_dict(state_dict["state_dict"], strict=False)

    # Inference only: no gradients are needed for any parameter.
    for weight in net.parameters():
        weight.requires_grad = False

    net.to(device)
    net.eval()
    return net
def get_dataset(dataset_name, cfg):
    """Instantiate the evaluation dataset registered under ``dataset_name``.

    Args:
        dataset_name: One of ``"GrabCut"``, ``"Berkeley"``, ``"DAVIS"``,
            ``"SBD"``, ``"SBD_Train"``, ``"PascalVOC"`` or ``"COCO_MVal"``
            (case-sensitive).
        cfg: Config object; only the ``*_PATH`` attribute needed by the
            selected dataset is read.

    Returns:
        The dataset instance, or ``None`` for an unrecognised name.
    """
    # Factories are lambdas so that only the selected dataset's path is read
    # and only that dataset is constructed.
    builders = (
        ("GrabCut", lambda: GrabCutDataset(cfg.GRABCUT_PATH)),
        ("Berkeley", lambda: BerkeleyDataset(cfg.BERKELEY_PATH)),
        ("DAVIS", lambda: DavisDataset(cfg.DAVIS_PATH)),
        ("SBD", lambda: SBDEvaluationDataset(cfg.SBD_PATH)),
        ("SBD_Train", lambda: SBDEvaluationDataset(cfg.SBD_PATH, split="train")),
        ("PascalVOC", lambda: PascalVocDataset(cfg.PASCALVOC_PATH, split="test")),
        # COCO_MVal is stored in the same layout as DAVIS, so it reuses that loader.
        ("COCO_MVal", lambda: DavisDataset(cfg.COCO_MVAL_PATH)),
    )
    for name, build in builders:
        if dataset_name == name:
            return build()
    return None
def get_iou(gt_mask, pred_mask, ignore_label=-1):
    """Intersection-over-union of a predicted mask against the object in a GT mask.

    Pixels where ``gt_mask == 1`` are the object; pixels equal to
    ``ignore_label`` are excluded from both the intersection and the union.
    ``pred_mask`` is interpreted by truthiness (non-zero = foreground).

    Returns:
        ``intersection / union`` as computed by NumPy (an empty union yields
        NaN with a NumPy warning rather than raising).
    """
    valid = gt_mask != ignore_label
    foreground = gt_mask == 1

    overlap = np.logical_and(pred_mask, foreground)
    covered = np.logical_or(pred_mask, foreground)

    inter_px = np.logical_and(overlap, valid).sum()
    union_px = np.logical_and(covered, valid).sum()
    return inter_px / union_px
def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    """Compute Number-of-Clicks (NoC) statistics for each IoU threshold.

    Args:
        all_ious: One entry per image; each entry is an array-like of IoU
            values, where position ``i`` is treated as the IoU after click ``i + 1``.
        iou_thrs: Iterable of IoU thresholds to evaluate.
        max_clicks: Value assigned to an image that never reaches the threshold.

    Returns:
        Tuple ``(noc_list, over_max_list)`` with one entry per threshold:
        the mean number of clicks over all images, and how many images
        scored exactly ``max_clicks``.
    """

    def _get_noc(iou_arr, iou_thr):
        # asarray lets plain Python lists be compared element-wise as well.
        reached = np.asarray(iou_arr) >= iou_thr
        # argmax gives the index of the first True; +1 turns it into a click count.
        return np.argmax(reached) + 1 if np.any(reached) else max_clicks

    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        # `np.int` was only an alias for the builtin `int` and was removed in NumPy 1.24.
        scores_arr = np.array(
            [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=int
        )
        noc_list.append(scores_arr.mean())
        over_max_list.append((scores_arr == max_clicks).sum())
    return noc_list, over_max_list
def find_checkpoint(weights_folder, checkpoint_name):
    """Resolve a checkpoint specifier to a checkpoint path string.

    ``checkpoint_name`` may be prefixed as ``"model_prefix:checkpoint"``; the
    prefix selects the single sub-directory of ``weights_folder`` whose name
    starts with it, and that directory becomes the search folder. Without a
    prefix, ``weights_folder`` itself is the search folder.

    - A name ending in ``.pth`` is returned as-is if it exists relative to the
      current working directory, otherwise it is joined to the search folder
      (without checking existence).
    - Any other name is treated as a file-name prefix: exactly one
      ``<name>*.pth`` must exist anywhere below the search folder.

    Returns:
        The resolved path as ``str``.

    Raises:
        AssertionError: if the model prefix or checkpoint prefix does not match
            exactly one candidate.
    """
    weights_folder = Path(weights_folder)
    if ":" in checkpoint_name:
        # maxsplit=1: extra colons stay in the checkpoint part instead of
        # breaking the two-name unpacking.
        model_name, checkpoint_name = checkpoint_name.split(":", 1)
        models_candidates = [
            x for x in weights_folder.glob(f"{model_name}*") if x.is_dir()
        ]
        assert len(models_candidates) == 1, (
            f"Expected exactly one model folder matching '{model_name}*' in "
            f"{weights_folder}, found {len(models_candidates)}"
        )
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith(".pth"):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            # Use the resolved model folder so a "model:file.pth" prefix is honoured.
            checkpoint_path = model_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth"))
        assert len(model_checkpoints) == 1, (
            f"Expected exactly one checkpoint matching '{checkpoint_name}*.pth' "
            f"under {model_folder}, found {len(model_checkpoints)}"
        )
        checkpoint_path = model_checkpoints[0]
    return str(checkpoint_path)
def get_results_table(
    noc_list,
    over_max_list,
    brs_type,
    dataset_name,
    mean_spc,
    elapsed_time,
    n_clicks=20,
    model_name=None,
):
    """Format one evaluation result as a fixed-width text table.

    The header labels assume ``noc_list`` / ``over_max_list`` are ordered by
    the IoU thresholds 0.80, 0.85, 0.90 (``over_max_list[0]`` is not shown).
    Any missing entry is rendered as ``"?"``.

    Args:
        noc_list: Mean number of clicks per threshold.
        over_max_list: Count of images that hit ``n_clicks`` per threshold.
        brs_type: Label for the "BRS Type" column.
        dataset_name: Label for the "Dataset" column.
        mean_spc: Mean seconds per click.
        elapsed_time: Total evaluation time in seconds (truncated to whole seconds).
        n_clicks: Click budget shown in the over-max column headers.
        model_name: If given, a leading "Eval results for model: ..." line is added.

    Returns:
        Tuple ``(header, table_row)`` of strings.
    """

    def _cell(values, idx, spec):
        # Check each list against its OWN length so mismatched lists cannot raise IndexError.
        if len(values) > idx:
            return f"{values[idx]:{spec}}|"
        return f'{"?":^9}|'

    table_header = (
        f'|{"BRS Type":^13}|{"Dataset":^11}|'
        f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
        f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
        f'{"SPC,s":^7}|{"Time":^9}|'
    )
    row_width = len(table_header)
    header = f"Eval results for model: {model_name}\n" if model_name is not None else ""
    header += "-" * row_width + "\n"
    header += table_header + "\n" + "-" * row_width

    eval_time = str(timedelta(seconds=int(elapsed_time)))
    table_row = f"|{brs_type:^13}|{dataset_name:^11}|"
    table_row += _cell(noc_list, 0, "^9.2f")
    table_row += _cell(noc_list, 1, "^9.2f")
    table_row += _cell(noc_list, 2, "^9.2f")
    table_row += _cell(over_max_list, 1, "^9")
    table_row += _cell(over_max_list, 2, "^9")
    table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|"
    return header, table_row