|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Metrics for evaluating the performance of the model.""" |
|
|
|
import torch |
|
|
|
|
|
def IoU(mask1, mask2, threshold=0.5):
    """Calculate Intersection over Union (IoU) between prediction and GT masks.

    Masks are compared pairwise (``mask1[k]`` against ``mask2[k]``) and the
    per-pair IoUs are averaged.

    Args:
        mask1: A torch.Tensor denoting the prediction, shape (N, H, W), where N
            is the number of masks.
        mask2: A torch.Tensor denoting the ground truth, shape (N, H, W), where
            N is the number of masks.
        threshold: The threshold to binarize masks. If ``threshold > 0`` a
            pixel is foreground when its value is strictly greater than
            ``threshold``. Otherwise the masks are taken as already binary and
            every non-zero pixel is foreground.

    Returns:
        Mean IoU of `mask1` and `mask2` over the mask pairs, as a Python number.
        A pair whose union is empty (both masks empty) contributes 0. If every
        pair is empty (or there are no pairs) the result is 0.
    """
    mask1 = mask1.to(mask2.device)
    if threshold > 0:
        mask1, mask2 = mask1 > threshold, mask2 > threshold
    else:
        # Cast already-binary masks to bool so that "|" below is a real union.
        # On 0/1 float masks, arithmetic addition would double-count the overlap.
        mask1, mask2 = mask1.to(torch.bool), mask2.to(torch.bool)

    intersection = torch.sum(mask1 & mask2, dim=[-1, -2])
    union = torch.sum(mask1 | mask2, dim=[-1, -2])

    if union.sum() == 0:
        return 0

    # union is an integer pixel count, so clamp(min=1) only affects empty-union
    # pairs (whose intersection is 0). They score 0 instead of NaN = 0/0, which
    # would poison the mean.
    ious = intersection.to(torch.float) / union.clamp(min=1)
    return ious.mean().item()
|
|
|
|
|
def IoM(pred, target, min_pred_threshold=0.2):
    """Calculate Intersection over the area of gt Mask and pred Mask (IoM).

    For every (prediction, ground-truth) pair, the intersection is divided
    both by the ground-truth area and by the prediction area, and the larger
    ratio is kept. For each ground-truth mask, the best score over all
    predictions is returned.

    Precaution:
        This function works for prediction and target that are binary masks,
        where 1 represents the mask and 0 represents the background.

    Args:
        pred: A torch.Tensor denoting the predictions, shape (M, H, W), where M
            is the number of predicted masks.
        target: A torch.Tensor denoting the ground truth, shape (N, H, W), where
            N is the number of ground-truth masks. M and N may differ.
        min_pred_threshold: If a pair's intersection covers less than this
            fraction of the prediction's area, the pair's
            intersection-over-target ratio is discarded (set to 0). A
            prediction that covers a target but lies mostly outside it then
            cannot score through the target ratio.

    Returns:
        ioms: A torch.Tensor of shape (N,) holding, for each ground-truth mask,
        the best IoM over all predictions. Ratios involving an empty mask count
        as 0, and every entry is 0 when there are no predictions (M == 0).
    """
    # Move both tensors to one device and one floating dtype up front, so the
    # einsums and divisions below never mix devices (the areas included) or
    # bool/integer dtypes.
    dtype = torch.promote_types(pred.dtype, target.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    pred = pred.to(device=target.device, dtype=dtype)
    target = target.to(dtype=dtype)

    intersection = torch.einsum("mij,nij->mn", pred, target)  # (M, N)
    area_pred = torch.einsum("mij->m", pred)  # (M,)
    area_target = torch.einsum("nij->n", target)  # (N,)

    # Dividing by an empty mask's area gives 0 * inf = NaN, which would spread
    # through the max reductions below. Score such ratios 0 instead.
    zeros = torch.zeros_like(intersection)
    iom_target = torch.where(area_target > 0, intersection / area_target, zeros)
    iom_pred = torch.where(
        area_pred[:, None] > 0, intersection / area_pred[:, None], zeros
    )

    iom_target[iom_pred < min_pred_threshold] = 0

    iom = torch.maximum(iom_target, iom_pred)
    if iom.shape[0] == 0:
        # No predictions: there is nothing to reduce over, so every target scores 0.
        return iom.new_zeros(iom.shape[1])
    return iom.max(dim=0).values
|
|