# cell-seg-sribd / compute_metric.py
# Uploaded by Lewislou ("Upload 24 files", commit 0ca2a11, 6.52 kB)
"""
Created on Thu Mar 31 18:10:52 2022
Adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py
Thanks to the authors of StarDist for sharing their great code.
"""
import argparse
import numpy as np
from numba import jit
from scipy.optimize import linear_sum_assignment
from collections import OrderedDict
import pandas as pd
from skimage import segmentation
import tifffile as tif
import os
join = os.path.join
from tqdm import tqdm
def _intersection_over_union(masks_true, masks_pred):
    """Compute the IoU of every (ground-truth label, predicted label) pair.

    Parameters
    ------------
    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels

    Returns
    ------------
    iou: ND-array, float
        matrix of size [masks_true.max()+1, masks_pred.max()+1]; entry [i, j]
        is |true==i AND pred==j| / |true==i OR pred==j|. Index 0 is the
        background label. Entries whose union is empty are set to 0.
    """
    inter = _label_overlap(masks_true, masks_pred)
    # Per-label pixel areas, kept 2-D so they broadcast against `inter`.
    area_pred = inter.sum(axis=0, keepdims=True)  # row vector: one entry per predicted label
    area_true = inter.sum(axis=1, keepdims=True)  # column vector: one entry per true label
    union = area_pred + area_true - inter
    iou = inter / union
    # A label absent from both masks gives 0/0 = NaN; treat that as no overlap.
    iou[np.isnan(iou)] = 0.0
    return iou
@jit(nopython=True)
def _label_overlap(x, y):
    """Count, for every label pair, how many pixels carry label a in x and label b in y.

    Compiled with numba in nopython mode, so the per-pixel loop below runs
    as native code.

    Parameters
    ------------
    x: ND-array, int
        where 0=NO masks; 1,2... are mask labels
    y: ND-array, int
        where 0=NO masks; 1,2... are mask labels; expected to have the same
        number of pixels as x (each pixel of x is paired with the pixel of y
        at the same flat position)

    Returns
    ------------
    overlap: ND-array, int
        matrix of pixel overlaps of size [x.max()+1, y.max()+1];
        overlap[a, b] is the number of pixels labelled a in x and b in y
    """
    flat_x = x.ravel()
    flat_y = y.ravel()
    # Contact map: one cell per (x-label, y-label) combination, background included.
    n_rows = flat_x.max() + 1
    n_cols = flat_y.max() + 1
    counts = np.zeros((n_rows, n_cols), dtype=np.uint)
    # A single pass over all pixels: each pixel increments its (x-label, y-label) cell.
    for k in range(flat_x.shape[0]):
        counts[flat_x[k], flat_y[k]] += 1
    return counts
def _true_positive(iou, th):
    """Count one-to-one matches whose IoU reaches a threshold.

    Ground-truth and predicted labels are paired with
    scipy.optimize.linear_sum_assignment on a cost that first rewards pairs
    with iou >= th (weight 1 each) and then, as a tie-breaker, larger IoU.
    The tie-break term is scaled by 1/(2*n_min), so over at most n_min
    assigned pairs it adds at most 1/2 (assuming IoU values lie in [0, 1]).
    It therefore never outweighs one extra above-threshold match.

    Parameters
    ------------
    iou: float, ND-array
        2-D array of IoU values, rows = true labels, columns = predicted labels
    th: float
        threshold on IOU for positive label

    Returns
    ------------
    tp: int-like
        number of assigned pairs with iou >= th
    """
    n_true, n_pred = iou.shape[0], iou.shape[1]
    n_min = min(n_true, n_pred)
    above = (iou >= th).astype(float)
    costs = -above - iou / (2 * n_min)
    gt_idx, pred_idx = linear_sum_assignment(costs)
    return (iou[gt_idx, pred_idx] >= th).sum()
def eval_tp_fp_fn(masks_true, masks_pred, threshold=0.5):
    """Count true positives, false positives and false negatives for one mask pair.

    Labels are assumed to be sequential (1..K), since the instance counts are
    taken as the maximum label. main() ensures this by passing masks through
    remove_boundary_cells, which relabels them.

    Parameters
    ------------
    masks_true: ND-array, int
        ground truth masks, where 0=NO masks; 1,2... are mask labels
    masks_pred: ND-array, int
        predicted masks, where 0=NO masks; 1,2... are mask labels
    threshold: float
        minimum IoU for a matched pair to count as a true positive

    Returns
    ------------
    (tp, fp, fn): tuple of int
    """
    num_inst_gt = int(np.max(masks_true))
    num_inst_seg = int(np.max(masks_pred))
    if num_inst_seg > 0:
        # Drop the background row/column before matching.
        iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
        tp = int(_true_positive(iou, threshold))
        fp = num_inst_seg - tp
        fn = num_inst_gt - tp
    else:
        print('No segmentation results!')
        # With no predictions, every ground-truth cell is missed. Returning
        # fn=0 here would hide these misses when counts are summed over
        # patches, inflating recall/F1.
        tp = 0
        fp = 0
        fn = num_inst_gt
    return tp, fp, fn
def remove_boundary_cells(mask, border=2):
    """Remove every labelled cell that touches the image frame, then relabel.

    Parameters
    ------------
    mask: 2D-array, int
        where 0=NO masks; 1,2... are mask labels. The input array is not
        modified (the original version zeroed cells in place).
    border: int
        width in pixels of the frame along each edge; any label with at least
        one pixel inside this frame is removed. Default 2 (previous behaviour).

    Returns
    ------------
    new_label: 2D-array, int
        copy of `mask` with frame-touching cells set to 0, relabelled with
        skimage.segmentation.relabel_sequential so labels are consecutive
    """
    n_rows, n_cols = mask.shape
    frame = np.ones((n_rows, n_cols), dtype=bool)
    # For very small masks (side <= 2*border) this slice is empty and the
    # whole image counts as frame, matching the original behaviour.
    frame[border:n_rows - border, border:n_cols - border] = False
    touching = np.unique(mask[frame])
    # Background is never a cell; filter it explicitly instead of assuming it
    # is the first unique value (it is absent when no frame pixel is 0).
    touching = touching[touching != 0]
    cleaned = mask.copy()
    # One vectorised pass instead of one full-image scan per boundary cell.
    cleaned[np.isin(mask, touching)] = 0
    new_label, _, _ = segmentation.relabel_sequential(cleaned)
    return new_label
def _f1_score(tp, fp, fn):
    """Return the F1 score for the given counts (0 when there are no true positives)."""
    if tp == 0:
        return 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)


def _pad_to_multiple(img, size):
    """Zero-pad a 2-D array at the bottom/right so each side is a multiple of `size`."""
    rows, cols = img.shape
    new_rows = -(-rows // size) * size  # ceil(rows / size) * size
    new_cols = -(-cols // size) * size
    # Keep the array's own dtype so labels are never truncated.
    padded = np.zeros((new_rows, new_cols), dtype=img.dtype)
    padded[:rows, :cols] = img
    return padded


def _score_case(gt, seg, threshold=0.5, roi_size=2000, max_pixels=25000000):
    """Return (tp, fp, fn) for one image pair, ignoring cells on (patch) boundaries.

    Images with fewer than `max_pixels` pixels are scored whole. Larger
    images (e.g. >5000x5000) are zero-padded, tiled into roi_size x roi_size
    patches, and the per-patch counts are summed.
    """
    if np.prod(gt.shape) < max_pixels:
        gt = remove_boundary_cells(gt.astype(np.int32))
        seg = remove_boundary_cells(seg.astype(np.int32))
        return eval_tp_fp_fn(gt, seg, threshold=threshold)
    gt_pad = _pad_to_multiple(gt, roi_size)
    # Pad seg with its own dtype: the original used gt.dtype, which could
    # truncate segmentation labels that exceed the ground truth's range.
    seg_pad = _pad_to_multiple(seg, roi_size)
    tp = fp = fn = 0
    for r0 in range(0, gt_pad.shape[0], roi_size):
        for c0 in range(0, gt_pad.shape[1], roi_size):
            window = (slice(r0, r0 + roi_size), slice(c0, c0 + roi_size))
            gt_roi = remove_boundary_cells(gt_pad[window])
            seg_roi = remove_boundary_cells(seg_pad[window])
            tp_i, fp_i, fn_i = eval_tp_fp_fn(gt_roi, seg_roi, threshold=threshold)
            tp += tp_i
            fp += fp_i
            fn += fn_i
    return tp, fp, fn


def main():
    """Score every segmentation in --seg_path against --gt_path.

    Writes the per-image F1 scores (IoU threshold 0.5) to
    <save_path>/seg_metric.csv and prints their mean.
    """
    parser = argparse.ArgumentParser('Compute F1 score for cell segmentation results', add_help=False)
    # Dataset parameters
    parser.add_argument('--gt_path', type=str, help='path to ground truth; file names end with _label.tiff', required=True)
    parser.add_argument('--seg_path', type=str, help='path to segmentation results; file names are the same as ground truth', required=True)
    parser.add_argument('--save_path', default='./', help='path where to save metrics')
    args = parser.parse_args()
    gt_path = args.gt_path
    seg_path = args.seg_path
    names = sorted(os.listdir(seg_path))
    seg_metric = OrderedDict()
    seg_metric['Names'] = []
    seg_metric['F1_Score'] = []
    for name in tqdm(names):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not name.endswith('_label.tiff'):
            raise ValueError('The suffix of label name should be _label.tiff')
        # Load the images for this case
        gt = tif.imread(join(gt_path, name))
        seg = tif.imread(join(seg_path, name))
        # The overlap kernel pairs pixels by flat index, so shapes must agree.
        if gt.shape != seg.shape:
            raise ValueError(f'Shape mismatch for {name}: ground truth {gt.shape} vs segmentation {seg.shape}')
        tp, fp, fn = _score_case(gt, seg, threshold=0.5)
        f1 = _f1_score(tp, fp, fn)
        seg_metric['Names'].append(name)
        seg_metric['F1_Score'].append(np.round(f1, 4))
    seg_metric_df = pd.DataFrame(seg_metric)
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    seg_metric_df.to_csv(join(args.save_path, 'seg_metric.csv'), index=False)
    print('mean F1 Score:', np.mean(seg_metric['F1_Score']))


if __name__ == '__main__':
    main()