# Nadine Rueegg: "initial commit with code and data" (commit 753fd9a)
# [file-viewer residue: raw | history blame | 4.21 kB]
# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py
import torch
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg
import matplotlib.pyplot as plt
def dice_loss(score, target):
    """Soft Dice loss from V-Net (Milletari et al., https://arxiv.org/pdf/1606.04797.pdf).

    Computes ``1 - (2*sum(s*t) + eps) / (sum(s*s) + sum(t*t) + eps)`` with the
    sums taken over every element of the inputs and ``eps = 1e-5``.

    Args:
        score: predicted (soft) segmentation tensor.
        target: ground-truth tensor, multiplied elementwise with ``score``;
            it is cast to float first.

    Returns:
        Scalar tensor holding the loss.
    """
    eps = 1e-5
    tgt = target.float()
    overlap = (score * tgt).sum()
    pred_energy = (score * score).sum()
    tgt_energy = (tgt * tgt).sum()
    dice = (2 * overlap + eps) / (pred_energy + tgt_energy + eps)
    return 1 - dice
class tversky_loss(torch.nn.Module):
    """Tversky loss (Salehi et al., https://arxiv.org/pdf/1706.05721.pdf).

    ``loss = 1 - (TP + eps) / (TP + alpha*FP + beta*FN + eps)``, where TP, FP
    and FN are soft counts summed over every element of the inputs and
    ``eps = 1e-5``.
    """

    def __init__(self, alpha, beta):
        """
        Args:
            alpha: coefficient for the (soft) false-positive term.
            beta: coefficient for the (soft) false-negative term.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, score, target):
        """Return the scalar Tversky loss.

        Defined as ``forward`` (not ``__call__``) so that calling the module
        goes through ``nn.Module.__call__`` and registered hooks still run.

        Args:
            score: predicted probabilities (presumably in [0, 1]; not checked).
            target: ground-truth mask multiplied elementwise with ``score``;
                cast to float first.
        """
        target = target.float()
        smooth = 1e-5
        tp = torch.sum(score * target)
        fn = torch.sum(target * (1 - score))
        fp = torch.sum((1 - target) * score)
        loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
        return 1 - loss
def compute_sdf1_1(img_gt, out_shape):
    """Compute the signed distance map (SDM) of a binary mask, normalized to [-1, 1].

    The inside and outside distance transforms are min-max normalized
    separately, giving:
        sdf(x) = 0          on the inner boundary of the mask
        sdf(x) in [-1, 0)   inside the mask (-1 farthest from the background)
        sdf(x) in (0, 1]    outside the mask (+1 farthest from the foreground)

    Args:
        img_gt: segmentation array indexed per batch element as ``img_gt[b]``
            (documented shape: (batch_size, x, y, z)). It is cast to uint8;
            any nonzero value counts as foreground.
        out_shape: shape of the returned array; ``out_shape[0]`` is the batch
            size and ``out_shape[1]`` the channel count. Channel 0 (background)
            is left at zero. Every channel >= 1 receives the same map computed
            from ``img_gt[b]``.

    Returns:
        float64 ``np.ndarray`` of shape ``out_shape``. A batch element whose mask
        is empty or completely foreground has no boundary, so the
        normalization is undefined; it is left at zero.

    Raises:
        AssertionError: if a computed map does not reach exactly -1 and +1.
            NOTE(review): this may happen for very thin masks where every
            foreground pixel lies on the inner boundary and is zeroed.
    """
    img_gt = img_gt.astype(np.uint8)
    normalized_sdf = np.zeros(out_shape)
    channels = range(1, out_shape[1])  # ignore background channel 0
    if len(channels) == 0:
        return normalized_sdf

    def _rescale(dis):
        # Min-max normalize a distance map to [0, 1].
        return (dis - np.min(dis)) / (np.max(dis) - np.min(dis))

    for b in range(out_shape[0]):  # batch size
        # Binarize so label values > 1 cannot underflow in uint8 ``1 - posmask``.
        posmask = (img_gt[b] > 0).astype(np.uint8)
        if not posmask.any() or posmask.all():
            # No boundary: one of the distance maps is constant, so
            # _rescale would divide by zero and produce NaNs.
            continue
        negmask = 1 - posmask
        posdis = distance(posmask)  # inside: distance to nearest background pixel
        negdis = distance(negmask)  # outside: distance to nearest foreground pixel
        boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
        sdf = _rescale(negdis) - _rescale(posdis)
        sdf[boundary == 1] = 0
        # The map does not depend on the channel, so compute it once per batch element.
        for c in channels:
            normalized_sdf[b][c] = sdf
        assert np.min(sdf) == -1.0, (
            f"normalized SDF min is {np.min(sdf)}, expected -1.0 (batch {b}); "
            f"posdis range [{np.min(posdis)}, {np.max(posdis)}], "
            f"negdis range [{np.min(negdis)}, {np.max(negdis)}]"
        )
        assert np.max(sdf) == 1.0, (
            f"normalized SDF max is {np.max(sdf)}, expected 1.0 (batch {b}); "
            f"posdis range [{np.min(posdis)}, {np.max(posdis)}], "
            f"negdis range [{np.min(negdis)}, {np.max(negdis)}]"
        )
    return normalized_sdf
def compute_sdf(img_gt, out_shape, debug=False):
    """Compute the (unnormalized) signed distance map (SDM) of a binary mask.

        sdf(x) = 0                                      on the inner boundary
        sdf(x) = -(distance to nearest background px)   inside the mask
        sdf(x) = +(distance to nearest foreground px)   outside the mask

    Args:
        img_gt: segmentation array indexed per batch element as ``img_gt[b]``.
            It is cast to uint8; any nonzero value counts as foreground.
        out_shape: shape of the returned array; ``out_shape[0]`` is the batch
            size and ``out_shape[1]`` the channel count. Every channel of
            batch element ``b`` receives the same map computed from ``img_gt[b]``.
        debug: if True, show the mask and SDF of each batch element with
            matplotlib. This indexes ``img_gt[b, 0]``, so it assumes img_gt has
            a channel axis (TODO confirm against callers).

    Returns:
        float64 ``np.ndarray`` of shape ``out_shape``. A batch element whose mask
        is empty or completely foreground has no boundary; it is left at zero.
    """
    img_gt = img_gt.astype(np.uint8)
    gt_sdf = np.zeros(out_shape)
    for b in range(out_shape[0]):  # batch size
        # Binarize so label values > 1 cannot underflow in uint8 ``1 - posmask``.
        posmask = (img_gt[b] > 0).astype(np.uint8)
        # With an empty or full mask, one of the two EDT inputs has no
        # background pixel to measure distance to, so skip it.
        if posmask.any() and not posmask.all():
            negmask = 1 - posmask
            posdis = distance(posmask)  # inside: distance to nearest background pixel
            negdis = distance(negmask)  # outside: distance to nearest foreground pixel
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            sdf = negdis - posdis
            sdf[boundary == 1] = 0
            # The map does not depend on the channel, so compute it once per batch element.
            for c in range(out_shape[1]):
                gt_sdf[b][c] = sdf
        if debug:
            plt.figure()
            plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar()
            plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar()
            plt.show()
    return gt_sdf
def boundary_loss(output, gt):
    """Compute the boundary loss for binary segmentation.

    Adapted from Kervadec et al., MIDL 2019:
    http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf

    The loss is the mean, over all elements, of ``output * gt``.

    Args:
        output: network prediction (e.g. softmax probabilities), shape
            (b, c, x, y) or (b, c, x, y, z); any rank is accepted.
        gt: signed distance map of the ground truth (original or normalized),
            with the same shape as ``output`` or broadcastable to it.

    Returns:
        Scalar tensor holding the boundary loss.
    """
    # Elementwise product (rather than a fixed-rank einsum) so that 2-D and
    # 3-D inputs both work.
    return (output * gt).mean()