VISTA3D-HF / scripts /inferer.py
BinLiunls's picture
init version
08efd84
raw
history blame
4.45 kB
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Union
import torch
from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.inferers import Inferer, SlidingWindowInfererAdapt
from torch import Tensor
class Vista3dInferer(Inferer):
    """
    Vista3D Inferer.

    Dispatches either to MONAI's ``point_based_window_inferer`` (used when
    ``use_point_window`` is True and point prompts are given) or to
    ``SlidingWindowInfererAdapt`` (all other cases).

    Args:
        roi_size: the sliding window patch size.
        overlap: sliding window overlap ratio. Forwarded to both the
            point-based window inferer and the sliding window inferer.
        use_point_window: if True and ``point_coords`` is given at call time,
            use ``point_based_window_inferer`` instead of sliding-window
            inference over the whole input.
        sw_batch_size: passed through as ``sw_batch_size`` to the selected
            inferer (the batch size used when running window slices).
    """

    def __init__(
        self, roi_size, overlap, use_point_window=False, sw_batch_size=1
    ) -> None:
        super().__init__()
        self.roi_size = roi_size
        self.overlap = overlap
        self.sw_batch_size = sw_batch_size
        self.use_point_window = use_point_window

    @staticmethod
    def _get_point_head(network):
        """Return ``network.point_head``, also looking inside ``network.module``.

        The ``.module`` fallback presumably covers wrapped models such as
        ``DataParallel`` / ``DistributedDataParallel``.

        Raises:
            AttributeError: if neither location has a ``point_head``.
        """
        if hasattr(network, "point_head"):
            return network.point_head
        if hasattr(network, "module") and hasattr(network.module, "point_head"):
            return network.module.point_head
        raise AttributeError("Network does not have attribute 'point_head'.")

    @staticmethod
    def _get_device(inputs):
        """Return the device of ``inputs`` (of its first element when it is a list)."""
        return inputs[0].device if isinstance(inputs, list) else inputs.device

    def __call__(
        self,
        inputs: Union[List[Tensor], Tensor],
        network,
        point_coords,
        point_labels,
        class_vector,
        labels=None,
        label_set=None,
        prev_mask=None,
    ):
        """
        Unified callable function API of Inferers.

        Notice: The point_based_window_inferer currently only supports SINGLE
        OBJECT INFERENCE with B=1. It is only used in interactive segmentation.

        Args:
            inputs: input tensor images.
            network: vista3d model.
            point_coords: point click coordinates. [B, N, 3].
            point_labels: point click labels (0 for negative, 1 for positive) [B, N].
            class_vector: class vector of length B.
            labels: groundtruth labels. Used for sampling validation points.
            label_set: [0,1,2,3,...,output_classes].
            prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID!

        Returns:
            The output of whichever inferer was selected
            (``point_based_window_inferer`` or ``SlidingWindowInfererAdapt``).

        Raises:
            AttributeError: if ``class_vector`` is given and ``network`` exposes
                no ``point_head`` (directly or via ``network.module``).
        """
        # Keep the original class ids: they are always forwarded as
        # ``prompt_class`` even if ``class_vector`` is dropped below.
        prompt_class = copy.deepcopy(class_vector)
        if class_vector is not None:
            point_head = self._get_point_head(network)
            # Any class id beyond ``last_supported`` disables the class prompt
            # entirely, so the network runs on point prompts only.
            if torch.any(class_vector > point_head.last_supported):
                class_vector = None
        torch.cuda.empty_cache()

        if self.use_point_window and point_coords is not None:
            device = self._get_device(inputs)
            return point_based_window_inferer(
                inputs=inputs,
                roi_size=self.roi_size,
                sw_batch_size=self.sw_batch_size,
                transpose=True,
                with_coord=True,
                predictor=network,
                mode="gaussian",
                sw_device=device,
                device=device,
                overlap=self.overlap,
                point_coords=point_coords,
                point_labels=point_labels,
                class_vector=class_vector,
                prompt_class=prompt_class,
                prev_mask=prev_mask,
                labels=labels,
                label_set=label_set,
            )

        sliding_window = SlidingWindowInfererAdapt(
            roi_size=self.roi_size,
            sw_batch_size=self.sw_batch_size,
            # Honour the configured overlap on this path as well; without it
            # the SlidingWindowInfererAdapt default would silently apply.
            overlap=self.overlap,
            with_coord=True,
            padding_mode="replicate",
        )
        return sliding_window(
            inputs,
            network,
            transpose=True,
            point_coords=point_coords,
            point_labels=point_labels,
            class_vector=class_vector,
            prompt_class=prompt_class,
            prev_mask=prev_mask,
            labels=labels,
            label_set=label_set,
        )