|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union |
|
|
|
import colorsys
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
|
|
|
from monai.data.meta_tensor import MetaTensor |
|
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size |
|
from monai.transforms import Resize |
|
from monai.utils import ( |
|
BlendMode, |
|
PytorchPadMode, |
|
convert_data_type, |
|
convert_to_dst_type, |
|
ensure_tuple, |
|
fall_back_tuple, |
|
look_up_option, |
|
optional_import, |
|
) |
|
|
|
from scipy import ndimage
from scipy.ndimage import binary_fill_holes

from skimage import morphology as morph
from skimage.exposure import rescale_intensity
from skimage.filters import gaussian, sobel_h, sobel_v
from skimage.morphology import binary_opening, disk
from skimage.segmentation import watershed
|
|
|
tqdm, _ = optional_import("tqdm", name="tqdm") |
|
|
|
__all__ = ["sliding_window_inference"] |
|
|
|
|
|
def normalize(mask, dtype=np.uint8):
    """Rescale `mask` so that its maximum value maps to 255, then cast to `dtype`."""
    return (255 * mask / np.amax(mask)).astype(dtype)
|
|
|
def fix_mirror_padding(ann): |
|
"""Deal with duplicated instances due to mirroring in interpolation |
|
during shape augmentation (scale, rotation etc.). |
|
|
|
""" |
|
current_max_id = np.amax(ann) |
|
inst_list = list(np.unique(ann)) |
|
if 0 in inst_list: |
|
inst_list.remove(0) |
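    # re-label each instance mask: mirrored duplicates become separate connected components with fresh ids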
|
for inst_id in inst_list: |
|
inst_map = np.array(ann == inst_id, np.uint8) |
|
        remapped_ids = ndimage.label(inst_map)[0]
|
remapped_ids[remapped_ids > 1] += current_max_id |
|
ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] |
|
current_max_id = np.amax(ann) |
|
return ann |
|
|
|
|
|
def get_bounding_box(img): |
|
"""Get bounding box coordinate information.""" |
|
rows = np.any(img, axis=1) |
|
cols = np.any(img, axis=0) |
|
rmin, rmax = np.where(rows)[0][[0, -1]] |
|
cmin, cmax = np.where(cols)[0][[0, -1]] |
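    # make the max coordinates exclusive so the box can be used directly for slicing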
|
|
|
|
|
rmax += 1 |
|
cmax += 1 |
|
return [rmin, rmax, cmin, cmax] |
|
|
|
|
|
|
|
def cropping_center(x, crop_shape, batch=False): |
|
"""Crop an input image at the centre. |
|
|
|
Args: |
|
x: input array |
|
        crop_shape: dimensions of cropped array
        batch: whether the first dimension of `x` is a batch dimension
|
|
|
Returns: |
|
x: cropped array |
|
|
|
""" |
|
orig_shape = x.shape |
|
if not batch: |
|
h0 = int((orig_shape[0] - crop_shape[0]) * 0.5) |
|
w0 = int((orig_shape[1] - crop_shape[1]) * 0.5) |
|
x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] |
|
else: |
|
h0 = int((orig_shape[1] - crop_shape[0]) * 0.5) |
|
w0 = int((orig_shape[2] - crop_shape[1]) * 0.5) |
|
x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] |
|
return x |
|
|
|
def gen_instance_hv_map(ann, crop_shape): |
|
"""Input annotation must be of original shape. |
|
|
|
The map is calculated only for instances within the crop portion |
|
but based on the original shape in original image. |
|
|
|
Perform following operation: |
|
Obtain the horizontal and vertical distance maps for each |
|
nuclear instance. |
|
|
|
""" |
|
orig_ann = ann.copy() |
|
fixed_ann = fix_mirror_padding(orig_ann) |
|
|
|
crop_ann = cropping_center(fixed_ann, crop_shape) |
|
|
|
crop_ann = morph.remove_small_objects(crop_ann, min_size=30) |
|
|
|
x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) |
|
y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) |
|
|
|
inst_list = list(np.unique(crop_ann)) |
|
if 0 in inst_list: |
|
inst_list.remove(0) |
|
for inst_id in inst_list: |
|
inst_map = np.array(fixed_ann == inst_id, np.uint8) |
|
inst_box = get_bounding_box(inst_map) |
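        # expand the bounding box by 2 px to keep gradient context at the instance border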
|
|
|
|
|
|
|
|
|
inst_box[0] -= 2 |
|
inst_box[2] -= 2 |
|
inst_box[1] += 2 |
|
inst_box[3] += 2 |
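        # clamp the expanded box to the image boundary (numpy slicing tolerates overflow on the max side)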
|
|
|
|
|
inst_box[0] = max(inst_box[0], 0) |
|
inst_box[2] = max(inst_box[2], 0) |
|
|
|
|
|
|
|
inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] |
|
|
|
        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            warnings.warn(f"Skipping degenerate instance {inst_id}: cropped shape {inst_map.shape}, box {inst_box}.")
            continue
|
|
|
|
|
        # instance center of mass, rounded to the nearest pixel
        inst_com = list(ndimage.center_of_mass(inst_map))
        if np.isnan(inst_com).any():
            warnings.warn(f"Skipping instance {inst_id}: center of mass is undefined for an empty crop.")
            continue
|
|
|
inst_com[0] = int(inst_com[0] + 0.5) |
|
inst_com[1] = int(inst_com[1] + 0.5) |
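        # build coordinate grids measured from the (rounded) center of mass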
|
|
|
inst_x_range = np.arange(1, inst_map.shape[1] + 1) |
|
inst_y_range = np.arange(1, inst_map.shape[0] + 1) |
|
|
|
inst_x_range -= inst_com[1] |
|
inst_y_range -= inst_com[0] |
|
|
|
inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range) |
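        # zero out background pixels; cast to float32 so the in-place normalization below works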
|
|
|
|
|
inst_x[inst_map == 0] = 0 |
|
inst_y[inst_map == 0] = 0 |
|
inst_x = inst_x.astype("float32") |
|
inst_y = inst_y.astype("float32") |
|
|
|
|
|
if np.min(inst_x) < 0: |
|
inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0]) |
|
if np.min(inst_y) < 0: |
|
inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0]) |
|
|
|
if np.max(inst_x) > 0: |
|
inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0]) |
|
if np.max(inst_y) > 0: |
|
inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0]) |
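        # paste the per-instance horizontal/vertical maps back into the full-size output maps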
|
|
|
|
|
x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] |
|
x_map_box[inst_map > 0] = inst_x[inst_map > 0] |
|
|
|
y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] |
|
y_map_box[inst_map > 0] = inst_y[inst_map > 0] |
|
|
|
hv_map = np.dstack([x_map, y_map]) |
|
return hv_map |
|
|
|
def remove_small_objects(pred, min_size=64, connectivity=1): |
|
"""Remove connected components smaller than the specified size. |
|
|
|
This function is taken from skimage.morphology.remove_small_objects, but the warning |
|
is removed when a single label is provided. |
|
|
|
Args: |
|
pred: input labelled array |
|
min_size: minimum size of instance in output array |
|
connectivity: The connectivity defining the neighborhood of a pixel. |
|
|
|
Returns: |
|
        out: input array with instances under min_size removed (the operation is performed in place)
|
|
|
""" |
|
out = pred |
|
|
|
if min_size == 0: |
|
return out |
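    # boolean input: label the connected components first; an already-labelled input is used as-is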
|
|
|
if out.dtype == bool: |
|
selem = ndimage.generate_binary_structure(pred.ndim, connectivity) |
|
ccs = np.zeros_like(pred, dtype=np.int32) |
|
ndimage.label(pred, selem, output=ccs) |
|
else: |
|
ccs = out |
|
|
|
try: |
|
component_sizes = np.bincount(ccs.ravel()) |
|
except ValueError: |
|
raise ValueError( |
|
"Negative value labels are not supported. Try " |
|
"relabeling the input with `scipy.ndimage.label` or " |
|
"`skimage.morphology.label`." |
|
) |
|
|
|
too_small = component_sizes < min_size |
|
too_small_mask = too_small[ccs] |
|
out[too_small_mask] = 0 |
|
|
|
return out |
|
|
|
|
|
def gen_targets(ann, crop_shape, **kwargs): |
|
"""Generate the targets for the network.""" |
|
hv_map = gen_instance_hv_map(ann, crop_shape) |
|
np_map = ann.copy() |
|
np_map[np_map > 0] = 1 |
|
|
|
hv_map = cropping_center(hv_map, crop_shape) |
|
np_map = cropping_center(np_map, crop_shape) |
|
|
|
target_dict = { |
|
"hv_map": hv_map, |
|
"np_map": np_map, |
|
} |
|
|
|
return target_dict |
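# Example usage (a sketch; `ann` is a hypothetical (H, W) integer instance map):
#
#     targets = gen_targets(ann, crop_shape=(256, 256))
#     hv_map, np_map = targets["hv_map"], targets["np_map"]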
|
|
|
|
|
def xentropy_loss(true, pred, reduction="mean"): |
|
"""Cross entropy loss. Assumes NHWC! |
|
|
|
    Args:
        true: ground truth array
        pred: prediction array
        reduction: ``"mean"`` (default) or ``"sum"``
|
|
|
Returns: |
|
cross entropy loss |
|
|
|
""" |
|
epsilon = 10e-8 |
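    # renormalize the per-pixel scores so they sum to one over the class (last) axis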
|
|
|
pred = pred / torch.sum(pred, -1, keepdim=True) |
|
|
|
pred = torch.clamp(pred, epsilon, 1.0 - epsilon) |
|
loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True) |
|
loss = loss.mean() if reduction == "mean" else loss.sum() |
|
return loss |
|
|
|
|
|
|
|
def dice_loss(true, pred, smooth=1e-3): |
|
"""`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC.""" |
|
inse = torch.sum(pred * true, (0, 1, 2)) |
|
    pred_sum = torch.sum(pred, (0, 1, 2))
    true_sum = torch.sum(true, (0, 1, 2))
    loss = 1.0 - (2.0 * inse + smooth) / (pred_sum + true_sum + smooth)
|
loss = torch.sum(loss) |
|
return loss |
|
|
|
|
|
|
|
def mse_loss(true, pred): |
|
"""Calculate mean squared error loss. |
|
|
|
Args: |
|
true: ground truth of combined horizontal |
|
and vertical maps |
|
pred: prediction of combined horizontal |
|
and vertical maps |
|
|
|
Returns: |
|
loss: mean squared error |
|
|
|
""" |
|
loss = pred - true |
|
loss = (loss * loss).mean() |
|
return loss |
|
|
|
|
|
|
|
def msge_loss(true, pred, focus): |
|
"""Calculate the mean squared error of the gradients of |
|
horizontal and vertical map predictions. Assumes |
|
channel 0 is Vertical and channel 1 is Horizontal. |
|
|
|
Args: |
|
true: ground truth of combined horizontal |
|
and vertical maps |
|
pred: prediction of combined horizontal |
|
and vertical maps |
|
focus: area where to apply loss (we only calculate |
|
the loss within the nuclei) |
|
|
|
Returns: |
|
loss: mean squared error of gradients |
|
|
|
""" |
|
|
|
    def get_sobel_kernel(size, device):
        """Get a Sobel-like kernel with a given (odd) size on the given device."""
        assert size % 2 == 1, f"Must be odd, got size={size}"

        h_range = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32, device=device)
        v_range = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32, device=device)
        # indexing="ij" makes the pre-1.10 default meshgrid behavior explicit (requires torch >= 1.10)
        h, v = torch.meshgrid(h_range, v_range, indexing="ij")
        kernel_h = h / (h * h + v * v + 1.0e-15)
        kernel_v = v / (h * h + v * v + 1.0e-15)
        return kernel_h, kernel_v
|
|
|
|
|
    def get_gradient_hv(hv):
        """Calculate the horizontal and vertical gradients of an NHWC map."""
        kernel_h, kernel_v = get_sobel_kernel(5, hv.device)
        kernel_h = kernel_h.view(1, 1, 5, 5)
        kernel_v = kernel_v.view(1, 1, 5, 5)
|
|
|
h_ch = hv[..., 0].unsqueeze(1) |
|
v_ch = hv[..., 1].unsqueeze(1) |
|
|
|
|
|
h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) |
|
v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) |
|
dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) |
|
dhv = dhv.permute(0, 2, 3, 1).contiguous() |
|
return dhv |
|
|
|
focus = (focus[..., None]).float() |
|
    focus = torch.cat([focus, focus], dim=-1)
|
true_grad = get_gradient_hv(true) |
|
pred_grad = get_gradient_hv(pred) |
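    # penalize gradient differences only inside the nuclei (focus) region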
|
loss = pred_grad - true_grad |
|
loss = focus * (loss * loss) |
|
|
|
loss = loss.sum() / (focus.sum() + 1.0e-8) |
|
return loss |
|
|
|
|
|
def __proc_np_hv(pred, np_thres, ksize, overall_thres, obj_size_thres): |
|
"""Process Nuclei Prediction with XY Coordinate Map. |
|
|
|
    Args:
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei,
            channel 1 contains the regressed X-map,
            channel 2 contains the regressed Y-map.
        np_thres: threshold applied to the nuclei probability map
        ksize: Sobel kernel size
        overall_thres: threshold applied to the combined edge map
        obj_size_thres: minimum size of watershed markers to keep

    """
|
pred = np.array(pred, dtype=np.float32) |
|
|
|
blb_raw = pred[..., 0] |
|
h_dir_raw = pred[..., 1] |
|
v_dir_raw = pred[..., 2] |
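    # threshold the nuclei probability map and drop tiny foreground components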
|
|
|
|
|
blb = np.array(blb_raw >= np_thres, dtype=np.int32) |
|
|
|
    blb = ndimage.label(blb)[0]
|
blb = remove_small_objects(blb, min_size=10) |
|
blb[blb > 0] = 1 |
|
|
|
h_dir = cv2.normalize( |
|
h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F |
|
) |
|
v_dir = cv2.normalize( |
|
v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F |
|
) |
|
|
|
sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize) |
|
sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize) |
|
|
|
sobelh = 1 - ( |
|
cv2.normalize( |
|
sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F |
|
) |
|
) |
|
sobelv = 1 - ( |
|
cv2.normalize( |
|
sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F |
|
) |
|
) |
|
|
|
overall = np.maximum(sobelh, sobelv) |
|
overall = overall - (1 - blb) |
|
overall[overall < 0] = 0 |
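    # build the watershed energy landscape: deep (negative) inside instances, shallow at boundaries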
|
|
|
dist = (1.0 - overall) * blb |
|
|
|
dist = -cv2.GaussianBlur(dist, (3, 3), 0) |
|
|
|
overall = np.array(overall >= overall_thres, dtype=np.int32) |
|
|
|
marker = blb - overall |
|
marker[marker < 0] = 0 |
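    # clean the markers: fill holes, morphologically open, label, and drop tiny components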
|
marker = binary_fill_holes(marker).astype("uint8") |
|
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) |
|
marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) |
|
    marker = ndimage.label(marker)[0]
|
marker = remove_small_objects(marker, min_size=obj_size_thres) |
|
|
|
proced_pred = watershed(dist, markers=marker, mask=blb) |
|
|
|
return proced_pred |
|
|
|
def __proc_np_hv_2(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10): |
|
"""Process Nuclei Prediction with XY Coordinate Map. |
|
|
|
Args: |
|
pred: prediction output, assuming |
|
channel 0 contain probability map of nuclei |
|
channel 1 containing the regressed X-map |
|
channel 2 containing the regressed Y-map |
|
|
|
""" |
|
pred = np.array(pred, dtype=np.float32) |
|
|
|
blb_raw = pred[..., 0] |
|
h_dir_raw = pred[..., 1] |
|
v_dir_raw = pred[..., 2] |
|
|
|
|
|
blb = np.array(blb_raw >= np_thres, dtype=np.int32) |
|
|
|
    blb = ndimage.label(blb)[0]
|
blb = remove_small_objects(blb, min_size=10) |
|
blb[blb > 0] = 1 |
|
|
|
h_dir = rescale_intensity(h_dir_raw, out_range=(0, 1)).astype('float32') |
|
v_dir = rescale_intensity(v_dir_raw, out_range=(0, 1)).astype('float32') |
|
|
|
sobelh = sobel_v(h_dir).astype('float64') |
|
sobelv = sobel_h(v_dir).astype('float64') |
|
|
|
sobelh = 1 - rescale_intensity(sobelh, out_range=(0, 1)).astype('float32') |
|
sobelv = 1 - rescale_intensity(sobelv, out_range=(0, 1)).astype('float32') |
|
|
|
overall = np.maximum(sobelh, sobelv) |
|
overall = overall - (1 - blb) |
|
overall[overall < 0] = 0 |
|
|
|
dist = (1.0 - overall) * blb |
|
|
|
dist = - gaussian(dist, sigma=0.8) |
|
|
|
overall = np.array(overall >= overall_thres, dtype=np.int32) |
|
|
|
marker = blb - overall |
|
marker[marker < 0] = 0 |
|
marker = binary_fill_holes(marker).astype("uint8") |
|
kernel = disk(2) |
|
marker = binary_opening(marker, kernel) |
|
    marker = ndimage.label(marker)[0]
|
marker = remove_small_objects(marker, min_size=obj_size_thres) |
|
|
|
proced_pred = watershed(dist, markers=marker, mask=blb) |
|
|
|
return proced_pred |
|
|
|
|
|
|
|
def colorize(ch, vmin, vmax): |
|
"""Will clamp value value outside the provided range to vmax and vmin.""" |
|
cmap = plt.get_cmap("jet") |
|
ch = np.squeeze(ch.astype("float32")) |
|
vmin = vmin if vmin is not None else ch.min() |
|
vmax = vmax if vmax is not None else ch.max() |
|
ch[ch > vmax] = vmax |
|
ch[ch < vmin] = vmin |
|
ch = (ch - vmin) / (vmax - vmin + 1.0e-16) |
|
|
|
ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") |
|
return ch_cmap |
|
|
|
|
|
|
|
def random_colors(N, bright=True): |
|
"""Generate random colors. |
|
|
|
To get visually distinct colors, generate them in HSV space then |
|
convert to RGB. |
|
""" |
|
brightness = 1.0 if bright else 0.7 |
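    # evenly spaced hues give visually distinct colors; shuffling decorrelates neighboring ids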
|
hsv = [(i / N, 1, brightness) for i in range(N)] |
|
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) |
|
random.shuffle(colors) |
|
return colors |
|
|
|
|
|
|
|
def visualize_instances_map( |
|
input_image, inst_map, type_map=None, type_colour=None, line_thickness=2 |
|
): |
|
"""Overlays segmentation results on image as contours. |
|
|
|
Args: |
|
input_image: input image |
|
inst_map: instance mask with unique value for every object |
|
type_map: type mask with unique value for every class |
|
        type_colour: a dict of {type: colour}, where `type` is an integer in [0, N]
            and `colour` is a tuple of (R, G, B)
|
line_thickness: line thickness of contours |
|
|
|
Returns: |
|
overlay: output image with segmentation overlay as contours |
|
""" |
|
overlay = np.copy((input_image).astype(np.uint8)) |
|
|
|
inst_list = list(np.unique(inst_map)) |
|
    if 0 in inst_list:
        inst_list.remove(0)
|
|
|
inst_rng_colors = random_colors(len(inst_list)) |
|
inst_rng_colors = np.array(inst_rng_colors) * 255 |
|
inst_rng_colors = inst_rng_colors.astype(np.uint8) |
|
|
|
for inst_idx, inst_id in enumerate(inst_list): |
|
inst_map_mask = np.array(inst_map == inst_id, np.uint8) |
|
y1, y2, x1, x2 = get_bounding_box(inst_map_mask) |
|
y1 = y1 - 2 if y1 - 2 >= 0 else y1 |
|
x1 = x1 - 2 if x1 - 2 >= 0 else x1 |
|
x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2 |
|
y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2 |
|
inst_map_crop = inst_map_mask[y1:y2, x1:x2] |
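        # cv2.findContours (OpenCV >= 4) returns (contours, hierarchy); only the first contour of each instance is drawn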
|
contours_crop = cv2.findContours( |
|
inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE |
|
) |
|
|
|
contours_crop = np.squeeze( |
|
contours_crop[0][0].astype("int32") |
|
) |
|
contours_crop += np.asarray([[x1, y1]]) |
|
if type_map is not None: |
|
type_map_crop = type_map[y1:y2, x1:x2] |
|
type_id = np.unique(type_map_crop).max() |
|
inst_colour = type_colour[type_id] |
|
else: |
|
inst_colour = (inst_rng_colors[inst_idx]).tolist() |
|
cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness) |
|
return overlay |
|
|
|
|
|
def sliding_window_inference( |
|
inputs: torch.Tensor, |
|
roi_size: Union[Sequence[int], int], |
|
sw_batch_size: int, |
|
predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]], |
|
overlap: float = 0.25, |
|
mode: Union[BlendMode, str] = BlendMode.CONSTANT, |
|
sigma_scale: Union[Sequence[float], float] = 0.125, |
|
padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, |
|
cval: float = 0.0, |
|
sw_device: Union[torch.device, str, None] = None, |
|
device: Union[torch.device, str, None] = None, |
|
progress: bool = False, |
|
roi_weight_map: Union[torch.Tensor, None] = None, |
|
*args: Any, |
|
**kwargs: Any, |
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]: |
|
""" |
|
Sliding window inference on `inputs` with `predictor`. |
|
|
|
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. |
|
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. |
|
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes |
|
could be ([128,64,256], [64,32,128]). |
|
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still |
|
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters |
|
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). |
|
|
|
    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
|
To maintain the same spatial sizes, the output image will be cropped to the original input size. |
|
|
|
Args: |
|
inputs: input image to be processed (assuming NCHW[D]) |
|
        roi_size: the spatial window size for inferences.
            When its components are None or non-positive, the corresponding dimension of the input image
            is used instead. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second
            spatial dimension size of the image is `64`.
|
sw_batch_size: the batch size to run window slices. |
|
predictor: given input tensor ``patch_data`` in shape NCHW[D], |
|
The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary |
|
with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; |
|
where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, |
|
N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), |
|
the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). |
|
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen |
|
to ensure the scaled output ROI sizes are still integers. |
|
If the `predictor`'s input and output spatial sizes are different, |
|
we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. |
|
overlap: Amount of overlap between scans. |
|
mode: {``"constant"``, ``"gaussian"``} |
|
How to blend output of overlapping windows. Defaults to ``"constant"``. |
|
|
|
            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.
|
|
|
sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. |
|
Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. |
|
When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding |
|
spatial dimensions. |
|
padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} |
|
Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` |
|
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html |
|
cval: fill value for 'constant' padding mode. Default: 0 |
|
sw_device: device for the window data. |
|
By default the device (and accordingly the memory) of the `inputs` is used. |
|
Normally `sw_device` should be consistent with the device where `predictor` is defined. |
|
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Note that in this implementation the stitching buffers are accumulated on the CPU,
            so peak GPU memory stays largely independent of `inputs` and `roi_size`.
|
progress: whether to print a `tqdm` progress bar. |
|
roi_weight_map: pre-computed (non-negative) weight map for each ROI. |
|
If not given, and ``mode`` is not `constant`, this map will be computed on the fly. |
|
args: optional args to be passed to ``predictor``. |
|
kwargs: optional keyword args to be passed to ``predictor``. |
|
|
|
Note: |
|
- input must be channel-first and have a batch dim, supports N-D sliding window. |
|
|
|
""" |
|
compute_dtype = inputs.dtype |
|
num_spatial_dims = len(inputs.shape) - 2 |
|
if overlap < 0 or overlap >= 1: |
|
raise ValueError("overlap must be >= 0 and < 1.") |
|
|
|
|
|
|
|
batch_size, _, *image_size_ = inputs.shape |
|
|
|
if device is None: |
|
device = inputs.device |
|
if sw_device is None: |
|
sw_device = inputs.device |
|
|
|
roi_size = fall_back_tuple(roi_size, image_size_) |
|
|
|
image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) |
|
pad_size = [] |
|
for k in range(len(inputs.shape) - 1, 1, -1): |
|
diff = max(roi_size[k - 2] - inputs.shape[k], 0) |
|
half = diff // 2 |
|
pad_size.extend([half, diff - half]) |
|
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) |
|
|
|
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) |
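    # enumerate the slice positions of every window over the (padded) image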
|
|
|
|
|
slices = dense_patch_slices(image_size, roi_size, scan_interval) |
|
num_win = len(slices) |
|
total_slices = num_win * batch_size |
|
|
|
|
|
valid_patch_size = get_valid_patch_size(image_size, roi_size) |
|
if valid_patch_size == roi_size and (roi_weight_map is not None): |
|
importance_map = roi_weight_map |
|
else: |
|
try: |
|
importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device) |
|
except BaseException as e: |
|
raise RuntimeError( |
|
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." |
|
) from e |
|
importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] |
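    # clamp near-zero blending weights so the final normalization cannot divide by zero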
|
|
|
min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3) |
|
importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype) |
|
|
|
|
|
dict_key, output_image_list, count_map_list = None, [], [] |
|
_initialized_ss = -1 |
|
is_tensor_output = True |
|
|
|
|
|
for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): |
|
slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) |
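        # map each flat window index back to its (batch, spatial) slices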
|
unravel_slice = [ |
|
            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
|
for idx in slice_range |
|
] |
|
window_data = torch.cat( |
|
[convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice] |
|
).to(sw_device) |
|
seg_prob_out = predictor(window_data, *args, **kwargs) |
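        # normalize the predictor output into a tuple of tensors, remembering the original container type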
|
|
|
|
|
seg_prob_tuple: Tuple[torch.Tensor, ...] |
|
if isinstance(seg_prob_out, torch.Tensor): |
|
seg_prob_tuple = (seg_prob_out,) |
|
elif isinstance(seg_prob_out, Mapping): |
|
if dict_key is None: |
|
dict_key = sorted(seg_prob_out.keys()) |
|
seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key) |
|
is_tensor_output = False |
|
else: |
|
seg_prob_tuple = ensure_tuple(seg_prob_out) |
|
is_tensor_output = False |
|
|
|
|
|
for ss, seg_prob in enumerate(seg_prob_tuple): |
|
seg_prob = seg_prob.to(device) |
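            # compute the per-axis zoom between the network output patch and the input window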
|
|
|
|
|
zoom_scale = [] |
|
for axis, (img_s_i, out_w_i, in_w_i) in enumerate( |
|
zip(image_size, seg_prob.shape[2:], window_data.shape[2:]) |
|
): |
|
_scale = out_w_i / float(in_w_i) |
|
if not (img_s_i * _scale).is_integer(): |
|
warnings.warn( |
|
f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial " |
|
f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs." |
|
) |
|
zoom_scale.append(_scale) |
|
|
|
if _initialized_ss < ss: |
|
|
|
output_classes = seg_prob.shape[1] |
|
output_shape = [batch_size, output_classes] + [ |
|
int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) |
|
] |
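                # accumulate the stitched output and weight maps on the CPU to bound GPU memory use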
|
|
|
output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device='cpu')) |
|
count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device='cpu')) |
|
_initialized_ss += 1 |
|
|
|
|
|
resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False) |
|
|
|
|
|
for idx, original_idx in zip(slice_range, unravel_slice): |
|
|
|
original_idx_zoom = list(original_idx) |
|
for axis in range(2, len(original_idx_zoom)): |
|
zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] |
|
zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] |
|
if not zoomed_start.is_integer() or (not zoomed_end.is_integer()): |
|
warnings.warn( |
|
f"For axis-{axis-2} of output[{ss}], the output roi range is not int. " |
|
f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). " |
|
f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. " |
|
f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n" |
|
f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. " |
|
"Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." |
|
) |
|
original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) |
|
importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) |
|
|
|
|
|
output_image_list[ss][original_idx_zoom] += importance_map_zoom.cpu() * seg_prob[idx - slice_g].cpu() |
|
count_map_list[ss][original_idx_zoom] += ( |
|
importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape).cpu() |
|
) |
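    # account for any overlapping sections: divide the accumulated output by the accumulated weights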
|
|
|
|
|
for ss in range(len(output_image_list)): |
|
output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype) |
|
|
|
|
|
for ss, output_i in enumerate(output_image_list): |
|
if torch.isnan(output_i).any() or torch.isinf(output_i).any(): |
|
warnings.warn("Sliding window inference results contain NaN or Inf.") |
|
|
|
zoom_scale = [ |
|
seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) |
|
] |
|
|
|
final_slicing: List[slice] = [] |
|
for sp in range(num_spatial_dims): |
|
slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) |
|
slice_dim = slice( |
|
int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), |
|
int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), |
|
) |
|
final_slicing.insert(0, slice_dim) |
|
while len(final_slicing) < len(output_i.shape): |
|
final_slicing.insert(0, slice(None)) |
|
output_image_list[ss] = output_i[final_slicing] |
|
|
|
if dict_key is not None: |
|
final_output = dict(zip(dict_key, output_image_list)) |
|
else: |
|
final_output = tuple(output_image_list) |
|
final_output = final_output[0] if is_tensor_output else final_output |
|
if isinstance(inputs, MetaTensor): |
|
final_output = convert_to_dst_type(final_output, inputs)[0] |
|
return final_output |
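# Example usage (a sketch; `model` is a hypothetical stand-in, e.g. a HoVer-Net-style network):
#
#     image = torch.rand(1, 3, 1024, 1024)  # NCHW input
#     output = sliding_window_inference(
#         inputs=image,
#         roi_size=(256, 256),
#         sw_batch_size=4,
#         predictor=model,
#         overlap=0.25,
#         mode="gaussian",
#     )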
|
|
|
|
|
def _get_scan_interval( |
|
image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float |
|
) -> Tuple[int, ...]: |
|
""" |
|
    Compute the scan interval from the image size, roi size and overlap.
    The scan interval is `int((1 - overlap) * roi_size)`; if that interval is 0,
    1 is used instead so that the sliding window always advances.
|
|
|
""" |
|
if len(image_size) != num_spatial_dims: |
|
raise ValueError("image coord different from spatial dims.") |
|
if len(roi_size) != num_spatial_dims: |
|
raise ValueError("roi coord different from spatial dims.") |
|
|
|
scan_interval = [] |
|
for i in range(num_spatial_dims): |
|
if roi_size[i] == image_size[i]: |
|
scan_interval.append(int(roi_size[i])) |
|
else: |
|
interval = int(roi_size[i] * (1 - overlap)) |
|
scan_interval.append(interval if interval > 0 else 1) |
|
return tuple(scan_interval) |
|
|