import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from stardist.big import _grid_divisible, BlockND, OBJECT_KEYS
from stardist.matching import relabel_sequential
from stardist import dist_to_coord, non_maximum_suppression, polygons_to_label
from stardist import random_label_cmap, ray_angles
from stardist import star_dist, edt_prob

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]


def sliding_window_inference_large(
    inputs, block_size, min_overlap, context, roi_size, sw_batch_size, predictor, device
):
    """StarDist-style instance segmentation for arbitrarily large 2D images.

    Images with either side below 5000 px are predicted in one pass with
    `sliding_window_inference`; otherwise the image is split into overlapping blocks
    (via `stardist.big.BlockND`), each block is predicted and post-processed
    independently, and the per-block labels are stitched into a single label map.

    Args:
        inputs: H x W x C image as a numpy array.
        block_size: block size (scalar or per-axis) used to tile large images.
        min_overlap: minimum overlap (scalar or per-axis) between neighbouring blocks.
        context: extra context (scalar or per-axis) read around each block.
        roi_size: spatial window size passed to `sliding_window_inference`.
        sw_batch_size: batch size for the sliding-window patches.
        predictor: network returning `(dist, prob)` tensors for an NCHW patch.
        device: torch device on which `predictor` runs.

    Returns:
        An H x W numpy array of instance labels.
    """
    h, w = inputs.shape[0], inputs.shape[1]
    if h < 5000 or w < 5000:
        # Direct path: run the sliding-window network over the whole image, then
        # StarDist non-maximum suppression and polygon rendering on the full output.
        test_tensor = (
            torch.from_numpy(np.expand_dims(inputs, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        )
        output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
        prob = output_prob[0][0].cpu().numpy()
        dist = output_dist[0].cpu().numpy()
        dist = np.transpose(dist, (1, 2, 0))
        dist = np.maximum(1e-3, dist)
        points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

        coord = dist_to_coord(disti, points)

        labels_out = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
    else:
        # Block-wise path: tile the image into overlapping blocks and stitch the results.
        n = inputs.ndim
        axes = 'YXC'
        grid = (1, 1, 1)
        if np.isscalar(block_size):
            block_size = n * [block_size]
        if np.isscalar(min_overlap):
            min_overlap = n * [min_overlap]
        if np.isscalar(context):
            context = n * [context]
        shape_out = (inputs.shape[0], inputs.shape[1])
        labels_out = np.zeros(shape_out, dtype=np.uint64)

        # The channel axis is never tiled.
        block_size[2] = inputs.shape[2]
        min_overlap[2] = context[2] = 0
        block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v, g, a in zip(block_size, grid, axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v, g, a in zip(min_overlap, grid, axes))
        context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v, g, a in zip(context, grid, axes))
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
        blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
        label_offset = 1
        blocks = tqdm(blocks)
        for block in blocks:
            image = block.read(inputs, axes=axes)
            test_tensor = (
                torch.from_numpy(np.expand_dims(image, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            )
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()
            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

            coord = dist_to_coord(disti, points)
            polys = dict(coord=coord, points=points, prob=probi)
            labels = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            # Drop the context margin, keep only objects owned by this block, and relabel
            # them so ids stay unique across blocks before writing into the full label map.
            labels = block.crop_context(labels, axes='YX')
            labels, polys = block.filter_objects(labels, polys, axes='YX')
            labels = relabel_sequential(labels, label_offset)[0]
            if labels_out is not None:
                block.write(labels_out, labels, axes='YX')
            label_offset += len(polys['prob'])
            del labels

    return labels_out


def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
""" |
|
Sliding window inference on `inputs` with `predictor`. |
|
|
|
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. |
|
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. |
|
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes |
|
could be ([128,64,256], [64,32,128]). |
|
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still |
|
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters |
|
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). |
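    For example, `roi_size=64` with `overlap=0.25` and an output/input size ratio of 0.5 gives
    `0.25*64*0.5 = 8`, an integer, so the scaled output ROIs still tile the output exactly.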

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D]).
        roi_size: the spatial window size for inferences.
            When its components are None or non-positive, the corresponding dimensions of the inputs
            will be used. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second
            spatial dimension size of the image is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given an input tensor ``patch_data`` in shape NCHW[D],
            the outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'],
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            and N is `sw_batch_size`. E.g., if the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors with shapes ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameters `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend the output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for ``"constant"`` padding mode. Default: 0.
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu'), the GPU memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.
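
    Example:
        A minimal usage sketch (the toy one-layer ``predictor`` below only stands in for a real
        network and is not part of this module)::

            import torch

            net = torch.nn.Conv2d(3, 2, kernel_size=1)  # toy "segmentation" network
            image = torch.rand(1, 3, 300, 300)          # NCHW input
            with torch.no_grad():
                pred = sliding_window_inference(
                    image, roi_size=(128, 128), sw_batch_size=4, predictor=net
                )
            print(pred.shape)  # torch.Size([1, 2, 300, 300])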

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)

    # In case the image size is smaller than the roi size, pad the image symmetrically.
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all window slices in a list.
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create the window-level importance map.
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]

    # Handle non-positive weights: clamp to the smallest positive weight to avoid division by zero.
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor returns a single tensor (instead of dict/tuple)

    # For each batch of windows: extract the window data and run the predictor.
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch inference

        # Normalize the model output (tensor, dict, or sequence) into a tuple of tensors.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track the predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # For each output of the predictor, accumulate the (optionally rescaled) window results.
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # batched patch prediction, moved to the output device

            # Compute the zoom scale: output patch size / input patch size per spatial axis.
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:
                # Construct the multi-resolution outputs on the first pass.
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # Allocate memory for the stitched output and the count map for overlapping parts.
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # Resizer used to match the importance map to this output's patch size.
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)

            # Store each window result in the proper location of the full output, weighted by the importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # Zoom the roi slices to the output resolution.
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # Accumulate the weighted prediction and the corresponding weights.
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # Account for overlapping sections by dividing by the accumulated weights.
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # Remove padding if the image size was smaller than the roi size.
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # the predictor returned a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)
    final_output = final_output[0] if is_tensor_output else final_output

    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]
    return final_output


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute the scan interval according to the image size, roi size and overlap.
    The scan interval is `int((1 - overlap) * roi_size)`; if that interval is 0,
    1 is used instead so that the sliding window still advances.
    """
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
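

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module's API):
    # ``_toy_predictor`` mimics a StarDist-style network by returning random
    # ``(dist, prob)`` tensors for each NCHW patch; in practice ``predictor`` would be
    # a trained model moved to ``device``. The probabilities stay below the 0.5
    # threshold used above, so the returned label map is expected to be empty here.
    def _toy_predictor(patch: torch.Tensor):
        n, _, h, w = patch.shape
        n_rays = 32  # assumed number of radial distances predicted per pixel
        dist = torch.rand(n, n_rays, h, w, device=patch.device) * 10.0 + 1.0
        prob = torch.full((n, 1, h, w), 0.01, device=patch.device)
        return dist, prob

    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _image = np.random.rand(1024, 1024, 3).astype(np.float32)  # H x W x C input image
    _labels = sliding_window_inference_large(
        _image,
        block_size=512,
        min_overlap=64,
        context=64,
        roi_size=(256, 256),
        sw_batch_size=4,
        predictor=_toy_predictor,
        device=_device,
    )
    print(_labels.shape, _labels.dtype)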