Spaces:
Runtime error
Runtime error
# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py
import torch | |
import numpy as np | |
from scipy.ndimage import distance_transform_edt as distance | |
from skimage import segmentation as skimage_seg | |
import matplotlib.pyplot as plt | |
def dice_loss(score, target, smooth=1e-5):
    """Return the soft Dice loss between a prediction and a target mask.

    Implements the squared-denominator Dice formulation from the V-Net paper
    (https://arxiv.org/pdf/1606.04797.pdf)::

        loss = 1 - (2 * sum(score * target) + smooth)
                   / (sum(score ** 2) + sum(target ** 2) + smooth)

    Every sum runs over all elements of the tensors, so a single scalar is
    produced for the whole input (no per-sample or per-class reduction).

    Args:
        score: predicted tensor (presumably probabilities in [0, 1]); must be
            multiplicable element-wise with ``target``.
        target: ground-truth tensor; it is cast to float before use.
        smooth: constant added to numerator and denominator so the ratio stays
            defined when both tensors are all zero. Defaults to ``1e-5``.

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    target = target.float()
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)
class tversky_loss(torch.nn.Module):
    """Tversky loss for segmentation (https://arxiv.org/pdf/1706.05721.pdf).

    The loss is ``1 - (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)``
    where ``tp``, ``fp`` and ``fn`` are soft true-positive, false-positive and
    false-negative counts summed over every element of the inputs.
    """

    def __init__(self, alpha, beta, smooth=1e-5):
        """Store the weighting coefficients.

        Args:
            alpha: coefficient for false positive predictions.
            beta: coefficient for false negative predictions.
            smooth: constant added to numerator and denominator so the ratio
                stays defined when all counts are zero. Defaults to ``1e-5``.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, score, target):
        """Compute the Tversky loss.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``torch.nn.Module.__call__`` and registered hooks
        are honoured; ``loss(score, target)`` keeps working as before.

        Args:
            score: predicted tensor (presumably probabilities in [0, 1]).
            target: ground-truth tensor; it is cast to float before use.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        target = target.float()
        tp = torch.sum(score * target)
        fn = torch.sum(target * (1 - score))
        fp = torch.sum((1 - target) * score)
        index = (tp + self.smooth) / (
            tp + self.alpha * fp + self.beta * fn + self.smooth
        )
        return 1 - index
def compute_sdf1_1(img_gt, out_shape):
    """Compute the normalized signed distance map (SDM) of binary masks.

    For every sample ``b`` the mask ``img_gt[b]`` is turned into a signed
    distance field that is

    * ``0`` on the inner boundary of the segmentation,
    * negative inside the segmentation,
    * positive outside the segmentation,

    with the inside and outside distances each min-max normalized, so the
    result lies in ``[-1, 1]``. The same field is written to every channel
    ``c >= 1`` of the output; channel 0 (background) is left at zero.

    Samples whose mask is completely empty or completely full have no
    boundary, which would make the normalization divide by zero; their output
    is left at zero instead.

    Args:
        img_gt: NumPy array of masks indexed by batch on the first axis. It is
            cast to ``uint8`` and any non-zero value counts as foreground.
        out_shape: shape of the returned array, ``(batch, channels, ...)``;
            the trailing dimensions must accept the per-sample field.

    Returns:
        ``numpy.ndarray`` of shape ``out_shape`` holding the normalized SDM.

    Raises:
        AssertionError: if a computed field does not reach exactly ``-1`` and
            ``1`` (for example when every foreground pixel lies on the
            boundary).
    """
    img_gt = img_gt.astype(np.uint8)
    normalized_sdf = np.zeros(out_shape)
    if out_shape[1] <= 1:
        # Only the (ignored) background channel exists: nothing to fill.
        return normalized_sdf

    for b in range(out_shape[0]):  # batch size
        # Boolean mask avoids uint8 wrap-around of ``1 - mask`` for masks
        # that are not strictly 0/1 (e.g. 0/255).
        posmask = img_gt[b].astype(bool)
        if not posmask.any() or posmask.all():
            # No boundary exists; normalizing would be 0/0.
            continue
        negmask = ~posmask
        posdis = distance(posmask)
        negdis = distance(negmask)
        boundary = skimage_seg.find_boundaries(posmask, mode='inner')
        sdf = (
            (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis))
            - (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
        )
        sdf[boundary] = 0
        stats = 'posdis range [%s, %s], negdis range [%s, %s]' % (
            np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
        assert np.min(sdf) == -1.0, stats
        assert np.max(sdf) == 1.0, stats
        # The mask does not depend on the channel, so compute the field once
        # and broadcast it over every non-background channel.
        normalized_sdf[b, 1:] = sdf
    return normalized_sdf
def compute_sdf(img_gt, out_shape):
    """Compute the (unnormalized) signed distance map (SDM) of binary masks.

    For every sample ``b`` the mask ``img_gt[b]`` is turned into a signed
    distance field that is

    * ``0`` on the inner boundary of the segmentation,
    * minus the Euclidean distance to the background inside the segmentation,
    * plus the Euclidean distance to the foreground outside the segmentation.

    The same field is written to every channel of the output.

    Args:
        img_gt: NumPy array of masks indexed by batch on the first axis. It is
            cast to ``uint8`` and any non-zero value counts as foreground.
        out_shape: shape of the returned array, ``(batch, channels, ...)``;
            the trailing dimensions must accept the per-sample field.

    Returns:
        ``numpy.ndarray`` of shape ``out_shape`` holding the SDM.
    """
    img_gt = img_gt.astype(np.uint8)
    gt_sdf = np.zeros(out_shape)
    if out_shape[1] == 0:
        # No channels to fill, so no distance transform is needed.
        return gt_sdf

    for b in range(out_shape[0]):  # batch size
        # Boolean mask avoids uint8 wrap-around of ``1 - mask`` for masks
        # that are not strictly 0/1 (e.g. 0/255).
        posmask = img_gt[b].astype(bool)
        negmask = ~posmask
        posdis = distance(posmask)
        negdis = distance(negmask)
        boundary = skimage_seg.find_boundaries(posmask, mode='inner')
        sdf = negdis - posdis
        sdf[boundary] = 0
        # The mask does not depend on the channel, so compute the field once
        # and broadcast it over all channels.
        gt_sdf[b] = sdf
    return gt_sdf
def boundary_loss(output, gt):
    """Compute the boundary loss for binary segmentation.

    Adopted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf:
    the prediction is multiplied element-wise with the signed distance map of
    the ground truth and the product is averaged over every element.

    A plain element-wise product is used, so 2-D inputs ``(b, c, x, y)`` and
    3-D inputs ``(b, c, x, y, z)`` are both accepted.

    Args:
        output: network prediction (softmax/sigmoid results per the original
            notes), e.g. shape ``(b, 2, x, y)`` or ``(b, 2, x, y, z)``.
        gt: signed distance map of the ground truth (original or normalized),
            with the same shape as ``output``.

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    multiplied = output * gt
    return multiplied.mean()