import copy
from typing import List, Union

import torch
from torch import Tensor

from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.inferers import Inferer, SlidingWindowInfererAdapt


class Vista3dInferer(Inferer):
    """
    VISTA3D inferer.

    Args:
        roi_size: the sliding window patch size.
        overlap: sliding window overlap ratio.
        use_point_window: if True and point prompts are provided, use
            ``point_based_window_inferer``, which only runs the network on windows
            around the clicked points; otherwise use ``SlidingWindowInfererAdapt``
            over the whole volume.
        sw_batch_size: the sliding window batch size.
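
    Example (illustrative sketch; ``image`` and ``vista3d_model`` stand in for a
    preprocessed image tensor and a VISTA3D network prepared elsewhere):

        .. code-block:: python

            inferer = Vista3dInferer(roi_size=[128, 128, 128], overlap=0.25)
            seg_logits = inferer(
                inputs=image,                       # [1, 1, H, W, D] image tensor
                network=vista3d_model,              # VISTA3D model exposing ``point_head``
                point_coords=None,                  # no interactive clicks
                point_labels=None,
                class_vector=torch.tensor([1, 2]),  # class prompts, e.g. classes 1 and 2
            )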
    """

    def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -> None:
        super().__init__()
        self.roi_size = roi_size
        self.overlap = overlap
        self.sw_batch_size = sw_batch_size
        self.use_point_window = use_point_window

    def __call__(
        self,
        inputs: Union[List[Tensor], Tensor],
        network,
        point_coords,
        point_labels,
        class_vector,
        labels=None,
        label_set=None,
        prev_mask=None,
    ):
""" |
|
Unified callable function API of Inferers. |
|
Notice: The point_based_window_inferer currently only supports SINGLE OBJECT INFERENCE with B=1. |
|
It only used in interactive segmentation. |
|
|
|
Args: |
|
inputs: input tensor images. |
|
network: vista3d model. |
|
point_coords: point click coordinates. [B, N, 3]. |
|
point_labels: point click labels (0 for negative, 1 for positive) [B, N]. |
|
class_vector: class vector of length B. |
|
labels: groundtruth labels. Used for sampling validation points. |
|
label_set: [0,1,2,3,...,output_classes]. |
|
prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! |
|
|
|
""" |
|
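        # Keep a copy of the original class prompts: ``class_vector`` may be dropped below
        # when a requested class is outside the point head's supported range, but the
        # original prompt is still passed to the network as ``prompt_class``.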
        prompt_class = copy.deepcopy(class_vector)
        if class_vector is not None:
            # The network may be wrapped (e.g. DataParallel/DistributedDataParallel), in
            # which case ``point_head`` lives on the underlying module.
            if hasattr(network, "point_head"):
                point_head = network.point_head
            elif hasattr(network, "module") and hasattr(network.module, "point_head"):
                point_head = network.module.point_head
            else:
                raise AttributeError("Network does not have attribute 'point_head'.")

            # Class indices beyond ``point_head.last_supported`` cannot be prompted with a
            # class vector; drop it (the original prompt is still passed as ``prompt_class``).
            if torch.any(class_vector > point_head.last_supported):
                class_vector = None
        val_outputs = None
        torch.cuda.empty_cache()

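        # Interactive path: only run the network on windows cropped around the clicked points.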
        if self.use_point_window and point_coords is not None:
            device = inputs[0].device if isinstance(inputs, list) else inputs.device
            val_outputs = point_based_window_inferer(
                inputs=inputs,
                roi_size=self.roi_size,
                sw_batch_size=self.sw_batch_size,
                transpose=True,
                with_coord=True,
                predictor=network,
                mode="gaussian",
                sw_device=device,
                device=device,
                overlap=self.overlap,
                point_coords=point_coords,
                point_labels=point_labels,
                class_vector=class_vector,
                prompt_class=prompt_class,
                prev_mask=prev_mask,
                labels=labels,
                label_set=label_set,
            )
        else:
            # Automatic path: standard adaptive sliding-window inference over the full volume.
            val_outputs = SlidingWindowInfererAdapt(
                roi_size=self.roi_size,
                sw_batch_size=self.sw_batch_size,
                with_coord=True,
                padding_mode="replicate",
            )(
                inputs,
                network,
                transpose=True,
                point_coords=point_coords,
                point_labels=point_labels,
                class_vector=class_vector,
                prompt_class=prompt_class,
                prev_mask=prev_mask,
                labels=labels,
                label_set=label_set,
            )
        return val_outputs