"""
Created on Thu Mar 31 18:10:52 2022

Adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py

Thanks to the authors of StarDist for sharing the great code.
"""
|
|
|
import argparse |
|
import numpy as np |
|
from numba import jit |
|
from scipy.optimize import linear_sum_assignment |
|
from collections import OrderedDict |
|
import pandas as pd |
|
from skimage import segmentation |
|
import tifffile as tif |
|
import os |
|
join = os.path.join |
|
from tqdm import tqdm |
|
|
|
def _intersection_over_union(masks_true, masks_pred):
    """ intersection over union of all mask pairs

    Parameters
    ------------

    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels

    Returns
    ------------

    iou: ND-array, float
        matrix of IOU pairs; rows are indexed by ground-truth label and
        columns by predicted label (row/column 0 is the background).
        Pairs whose union is empty get an IOU of 0.

    Raises
    ------------

    ValueError
        if the two label arrays do not contain the same number of pixels
    """
    # The overlap kernel walks both flattened arrays with a single index
    # taken from the first one, so a size mismatch must be rejected here
    # instead of letting it index past the end of the shorter array.
    if masks_true.size != masks_pred.size:
        raise ValueError(
            'masks_true and masks_pred must have the same number of pixels, '
            'got shapes {} and {}'.format(masks_true.shape, masks_pred.shape)
        )

    overlap = _label_overlap(masks_true, masks_pred)
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    union = n_pixels_pred + n_pixels_true - overlap

    # Labels that never occur produce 0/0 here; that is expected and the
    # resulting NaNs are zeroed right below, so silence the warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = overlap / union
    iou[np.isnan(iou)] = 0.0
    return iou
|
|
|
@jit(nopython=True)
def _label_overlap(x, y):
    """ fast function to get pixel overlaps between masks in x and y

    Parameters
    ------------

    x: ND-array, int
        where 0=NO masks; 1,2... are mask labels
    y: ND-array, int
        where 0=NO masks; 1,2... are mask labels

    Returns
    ------------

    overlap: ND-array, int
        matrix of pixel overlaps of size [x.max()+1, y.max()+1];
        entry [a, b] counts the positions labelled a in x and b in y
    """
    flat_x = x.ravel()
    flat_y = y.ravel()

    # one row per label in x and one column per label in y (0 included)
    n_rows = 1 + flat_x.max()
    n_cols = 1 + flat_y.max()
    overlap = np.zeros((n_rows, n_cols), dtype=np.uint)

    # explicit index loop: this body is compiled by numba in nopython mode
    for idx in range(flat_x.size):
        overlap[flat_x[idx], flat_y[idx]] += 1
    return overlap
|
|
|
def _true_positive(iou, th):
    """ true positive at threshold th

    Parameters
    ------------

    iou: float, ND-array
        array of IOU pairs
    th: float
        threshold on IOU for positive label

    Returns
    ------------

    tp: int
        number of true positives at threshold
    """
    n_true, n_pred = iou.shape[0], iou.shape[1]
    n_min = min(n_true, n_pred)

    # Primary cost term rewards pairs at or above the threshold; the small
    # IOU-proportional term only breaks ties between equally good matchings.
    above_th = iou >= th
    costs = -above_th.astype(float) - iou / (2 * n_min)

    rows, cols = linear_sum_assignment(costs)

    # only assigned pairs that actually reach the threshold count
    return above_th[rows, cols].sum()
|
|
|
def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5):
    """ count true positives, false positives and false negatives

    Parameters
    ------------

    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels
    threshold: float
        IOU threshold at which a predicted instance counts as a match

    Returns
    ------------

    tp, fp, fn: int
        number of matched instances, unmatched predicted instances and
        unmatched ground truth instances
    """
    # The number of instances is taken from the largest label, so both
    # inputs are expected to be labelled sequentially (1..N).
    num_inst_gt = int(np.max(masks_true))
    num_inst_seg = int(np.max(masks_pred))

    if num_inst_seg > 0:
        # drop row/column 0: background is not an instance
        iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
        tp = int(_true_positive(iou, threshold))
        fp = num_inst_seg - tp
        fn = num_inst_gt - tp
    else:
        print('No segmentation results!')
        tp = 0
        fp = 0
        # Nothing was predicted, so every ground truth instance is missed.
        # Reporting 0 here would drop these misses when counts are summed
        # over several tiles.
        fn = num_inst_gt

    return tp, fp, fn
|
|
|
def remove_boundary_cells(mask):
    """ remove the cells that touch the image border and relabel the rest

    Parameters
    ------------

    mask: 2D-array, int
        label image, where 0=NO masks; 1,2... are mask labels

    Returns
    ------------

    new_label: 2D-array, int
        copy of mask in which every label present in the 2-pixel wide
        frame around the image is set to 0 and the remaining labels are
        renumbered sequentially; the input array is left untouched
    """
    n_rows, n_cols = mask.shape

    # True on the 2-pixel wide frame around the image, False inside
    border = np.ones((n_rows, n_cols), dtype=bool)
    border[2:n_rows - 2, 2:n_cols - 2] = False

    # Labels seen on the frame; background is excluded explicitly rather
    # than assumed to be the first unique value.
    border_labels = np.unique(mask[border])
    border_labels = border_labels[border_labels != 0]

    # Work on a copy so the caller's array (possibly a view into a larger
    # image) is not modified, and clear all border labels in one pass.
    cleaned = mask.copy()
    cleaned[np.isin(cleaned, border_labels)] = 0

    new_label, _, _ = segmentation.relabel_sequential(cleaned)
    return new_label
|
|
|
def _f1_score(tp, fp, fn):
    """Return the F1 score for the given counts; 0 when there is no true positive."""
    if tp == 0:
        return 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)


def _tiled_tp_fp_fn(gt, seg, roi_size=2000, threshold=0.5):
    """Sum TP/FP/FN over non-overlapping roi_size x roi_size tiles of a 2D label pair."""
    height, width = gt.shape
    n_rows = -(-height // roi_size)  # ceiling division
    n_cols = -(-width // roi_size)

    # Zero-pad both images up to a whole number of tiles. Both buffers are
    # int32, like the whole-image path, so prediction labels are never
    # squeezed into the ground truth's (possibly narrower) dtype.
    padded_shape = (n_rows * roi_size, n_cols * roi_size)
    gt_pad = np.zeros(padded_shape, dtype=np.int32)
    seg_pad = np.zeros(padded_shape, dtype=np.int32)
    gt_pad[:height, :width] = gt
    seg_pad[:height, :width] = seg

    tp = 0
    fp = 0
    fn = 0
    for i in range(n_rows):
        rows = slice(roi_size * i, roi_size * (i + 1))
        for j in range(n_cols):
            cols = slice(roi_size * j, roi_size * (j + 1))
            # cells touching a tile border are dropped from that tile
            gt_roi = remove_boundary_cells(gt_pad[rows, cols])
            seg_roi = remove_boundary_cells(seg_pad[rows, cols])
            tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=threshold)
            tp += tp_i
            fp += fp_i
            fn += fn_i
    return tp, fp, fn


def main():
    """Compute the F1 score (IOU threshold 0.5) of every segmentation result.

    Reads each file listed in --seg_path, loads the file of the same name
    from --gt_path, scores the pair and writes one row per file to
    seg_metric.csv inside --save_path. Images with fewer than 25,000,000
    pixels are scored as a whole; larger ones are scored tile by tile.

    Raises
    ------------

    ValueError
        if a file name does not end with _label.tiff, or if a ground truth
        image and its segmentation do not have the same shape
    """
    # The text is the description; passed positionally it would be taken
    # as the program name. Help (-h) is left enabled.
    parser = argparse.ArgumentParser(description='Compute F1 score for cell segmentation results')

    parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True)
    parser.add_argument('--seg_path', type=str, help='path to segmentation results; file names are the same as ground truth', required=True)
    parser.add_argument('--save_path', default='./', help='path where to save metrics')
    args = parser.parse_args()

    gt_path = args.gt_path
    seg_path = args.seg_path
    names = sorted(os.listdir(seg_path))

    seg_metric = OrderedDict()
    seg_metric['Names'] = []
    seg_metric['F1_Score'] = []

    for name in tqdm(names):
        # explicit raise: an assert would be stripped when running with -O
        if not name.endswith('_label.tiff'):
            raise ValueError('The suffix of label name should be _label.tiff')

        gt = tif.imread(join(gt_path, name))
        seg = tif.imread(join(seg_path, name))
        if gt.shape != seg.shape:
            raise ValueError(
                'Shape mismatch for {}: ground truth {} vs segmentation {}'.format(
                    name, gt.shape, seg.shape
                )
            )

        if np.prod(gt.shape) < 25000000:
            gt = remove_boundary_cells(gt.astype(np.int32))
            seg = remove_boundary_cells(seg.astype(np.int32))
            tp, fp, fn = eval_tp_fp_fn(gt, seg, threshold=0.5)
        else:
            # large image: score it in 2000 x 2000 tiles and sum the counts
            tp, fp, fn = _tiled_tp_fp_fn(gt, seg, roi_size=2000, threshold=0.5)

        f1 = _f1_score(tp, fp, fn)
        seg_metric['Names'].append(name)
        seg_metric['F1_Score'].append(np.round(f1, 4))

    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    seg_metric_df = pd.DataFrame(seg_metric)
    seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False)

    # np.mean of an empty list warns and yields nan; report nan quietly
    scores = seg_metric['F1_Score']
    print('mean F1 Score:', np.mean(scores) if scores else float('nan'))


if __name__ == '__main__':
    main()
|
|