# -*- coding: utf-8 -*-
# CellViT Inference Method for Patch-Wise Inference on MoNuSeg dataset
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import argparse
import inspect
import os
import sys

# make the project packages importable when running this file directly
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
parentdir = os.path.dirname(parentdir)
sys.path.insert(0, parentdir)

from base_ml.base_experiment import BaseExperiment

BaseExperiment.seed_run(1232)
from pathlib import Path
from typing import List, Tuple, Union

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from einops import rearrange
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
from skimage.color import rgba2rgb
from torch.utils.data import DataLoader
from torchmetrics.functional import dice
from torchmetrics.functional.classification import binary_jaccard_index
from torchvision import transforms

from cell_segmentation.datasets.monuseg import MoNuSegDataset
from cell_segmentation.inference.cell_detection import (
    CellPostProcessor,
    get_cell_position,
    get_cell_position_marging,
    get_edge_patch,
)
from cell_segmentation.utils.metrics import (
    cell_detection_scores,
    get_fast_pq,
    remap_label,
)
from cell_segmentation.utils.post_proc_cellvit import calculate_instances
from cell_segmentation.utils.tools import pair_coordinates
from models.segmentation.cell_segmentation.cellvit import CellViT
from utils.logger import Logger
from utils.tools import unflatten_dict

class MoNuSegInference:
    def __init__(
        self,
        model_path: Union[Path, str],
        dataset_path: Union[Path, str],
        outdir: Union[Path, str],
        gpu: int,
        patching: bool = False,
        overlap: int = 0,
        magnification: int = 40,
    ) -> None:
"""Cell Segmentation Inference class for MoNuSeg dataset | |
Args: | |
model_path (Union[Path, str]): Path to model checkpoint | |
dataset_path (Union[Path, str]): Path to dataset | |
outdir (Union[Path, str]): Output directory | |
gpu (int): CUDA GPU id to use | |
patching (bool, optional): If dataset should be pacthed to 256px. Defaults to False. | |
overlap (int, optional): If overlap should be used. Recommed (next to no overlap) is 64 px. Overlap in px. | |
If overlap is used, patching must be True. Defaults to 0. | |
magnification (int, optional): Dataset magnification. Defaults to 40. | |
""" | |
        self.model_path = Path(model_path)
        self.dataset_path = Path(dataset_path)
        self.outdir = Path(outdir)
        self.outdir.mkdir(exist_ok=True, parents=True)
        # use the requested GPU if available, otherwise fall back to CPU
        self.device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
        self.magnification = magnification
        self.overlap = overlap
        self.patching = patching
        if overlap > 0:
            assert patching, "Patching must be activated if overlap is used"

        self.__instantiate_logger()
        self.__load_model()
        self.__load_inference_transforms()
        self.__load_dataset()
        self.__setup_amp()

    def __instantiate_logger(self) -> None:
        """Instantiate logger

        The logger uses a plain message formatter. Logs are stored in the output directory under the filename: inference.log
        """
        logger = Logger(
            level="INFO",
            log_dir=self.outdir,
            comment="inference_monuseg",
            use_timestamp=False,
            formatter="%(message)s",
        )
        self.logger = logger.create_logger()

    def __load_model(self) -> None:
        """Load model and checkpoint and load the state_dict"""
        model_checkpoint = torch.load(self.model_path, map_location="cpu")

        # unpack checkpoint
        self.run_conf = unflatten_dict(model_checkpoint["config"], ".")
        self.model = self.__get_model(model_type=model_checkpoint["arch"])
        self.model.load_state_dict(model_checkpoint["model_state_dict"])
        self.model.eval()
        self.model.to(self.device)

    def __get_model(self, model_type: str) -> CellViT:
        """Return the trained model for inference

        Args:
            model_type (str): Name of the model. Currently just "CellViT" is supported by this script.

        Returns:
            CellViT: Model
        """
        implemented_models = [
            "CellViT",
        ]
        if model_type not in implemented_models:
            raise NotImplementedError(
                f"Unknown model type. Please select one of {implemented_models}"
            )
        if model_type == "CellViT":
            model_class = CellViT
        model = model_class(
            model256_path=self.run_conf["model"].get("pretrained_encoder"),
            num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
            num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
            # embed_dim=self.run_conf["model"]["embed_dim"],
            in_channels=self.run_conf["model"].get("input_channels", 3),
            # depth=self.run_conf["model"]["depth"],
            # num_heads=self.run_conf["model"]["num_heads"],
            # extract_layers=self.run_conf["model"]["extract_layers"],
            # regression_loss=self.run_conf["model"].get("regression_loss", False),
        )
        return model

    def __load_inference_transforms(self) -> None:
        """Load the inference transformations from the run configuration"""
        self.logger.info("Loading inference transformations")

        transform_settings = self.run_conf["transformations"]
        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        self.inference_transforms = A.Compose([A.Normalize(mean=mean, std=std)])
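
    def __load_dataset(self) -> None:
        """Load the MoNuSeg dataset and build the inference dataloader.

        A minimal sketch: MoNuSegDataset is assumed to accept the dataset path,
        the albumentations transforms and the patching/overlap settings (as in
        the upstream CellViT inference code); adjust the keyword names if the
        local dataset class differs.
        """
        inference_dataset = MoNuSegDataset(
            dataset_path=self.dataset_path,
            transforms=self.inference_transforms,
            patching=self.patching,
            overlap=self.overlap,
        )
        self.inference_dataloader = DataLoader(
            inference_dataset,
            batch_size=1,  # metrics are computed per image
            num_workers=8,
            pin_memory=False,
            shuffle=False,
        )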

    def __setup_amp(self) -> None:
        """Setup automatic mixed precision (amp) for inference."""
        self.mixed_precision = self.run_conf["training"].get("mixed_precision", False)

    def run_inference(self, generate_plots: bool = False) -> None:
        """Run inference

        Args:
            generate_plots (bool, optional): If plots should be generated. Defaults to False.
        """
        self.model.eval()

        # setup score trackers
        image_names = []  # image names as str
        binary_dice_scores = []  # binary dice scores per image
        binary_jaccard_scores = []  # binary jaccard scores per image
        pq_scores = []  # pq-scores per image
        dq_scores = []  # dq-scores per image
        sq_scores = []  # sq-scores per image
        f1_ds = []  # f1-scores per image
        prec_ds = []  # precision per image
        rec_ds = []  # recall per image

        inference_loop = tqdm.tqdm(
            enumerate(self.inference_dataloader), total=len(self.inference_dataloader)
        )

        with torch.no_grad():
            for image_idx, batch in inference_loop:
                image_metrics = self.inference_step(
                    model=self.model, batch=batch, generate_plots=generate_plots
                )
                image_names.append(image_metrics["image_name"])
                binary_dice_scores.append(image_metrics["binary_dice_score"])
                binary_jaccard_scores.append(image_metrics["binary_jaccard_score"])
                pq_scores.append(image_metrics["pq_score"])
                dq_scores.append(image_metrics["dq_score"])
                sq_scores.append(image_metrics["sq_score"])
                f1_ds.append(image_metrics["f1_d"])
                prec_ds.append(image_metrics["prec_d"])
                rec_ds.append(image_metrics["rec_d"])

        # average metrics for dataset
        binary_dice_scores = np.array(binary_dice_scores)
        binary_jaccard_scores = np.array(binary_jaccard_scores)
        pq_scores = np.array(pq_scores)
        dq_scores = np.array(dq_scores)
        sq_scores = np.array(sq_scores)
        f1_ds = np.array(f1_ds)
        prec_ds = np.array(prec_ds)
        rec_ds = np.array(rec_ds)

        dataset_metrics = {
            "Binary-Cell-Dice-Mean": float(np.nanmean(binary_dice_scores)),
            "Binary-Cell-Jaccard-Mean": float(np.nanmean(binary_jaccard_scores)),
            "bPQ": float(np.nanmean(pq_scores)),
            "bDQ": float(np.nanmean(dq_scores)),
            "bSQ": float(np.nanmean(sq_scores)),
            "f1_detection": float(np.nanmean(f1_ds)),
            "precision_detection": float(np.nanmean(prec_ds)),
            "recall_detection": float(np.nanmean(rec_ds)),
        }
        self.logger.info(f"{20*'*'} Binary Dataset metrics {20*'*'}")
        for k, v in dataset_metrics.items():
            self.logger.info(f"{f'{k}:': <25} {v}")

    def inference_step(
        self, model: nn.Module, batch: object, generate_plots: bool = False
    ) -> dict:
        """Inference step

        Args:
            model (nn.Module): Trained model, must return "nuclei_binary_map", "nuclei_type_map", "tissue_type" and "hv_map"
            batch (object): Inference batch, consisting of images ([0]), masks ([1]) and image names ([2])
            generate_plots (bool, optional): If plots should be generated. Defaults to False.

        Returns:
            dict: Image metrics with keys: image_name, binary_dice_score, binary_jaccard_score,
                pq_score, dq_score, sq_score, f1_d, prec_d, rec_d
        """
        img = batch[0].to(self.device)
        if len(img.shape) > 4:
            img = img[0]
            img = rearrange(img, "c i j w h -> (i j) c w h")
        mask = batch[1]
        image_name = list(batch[2])
        mask["instance_types"] = calculate_instances(
            torch.unsqueeze(mask["nuclei_binary_map"], dim=0), mask["instance_map"]
        )

        model.zero_grad()
        if self.mixed_precision:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                predictions_ = model.forward(img)
        else:
            predictions_ = model.forward(img)

        if self.overlap == 0:
            if self.patching:
                predictions_ = self.post_process_patching(predictions_)
            predictions = self.get_cell_predictions(predictions_)
            image_metrics = self.calculate_step_metric(
                predictions=predictions, gt=mask, image_name=image_name
            )
        elif self.patching and self.overlap != 0:
            cell_list = self.post_process_patching_overlap(
                predictions_, overlap=self.overlap
            )
            image_metrics, predictions = self.calculate_step_metric_overlap(
                cell_list=cell_list, gt=mask, image_name=image_name
            )

        scores = [
            float(image_metrics["binary_dice_score"].detach().cpu()),
            float(image_metrics["binary_jaccard_score"].detach().cpu()),
            image_metrics["pq_score"],
        ]
        if generate_plots:
            if self.overlap == 0 and self.patching:
                batch_size = img.shape[0]
                num_elems = int(np.sqrt(batch_size))
                img = torch.permute(img, (0, 2, 3, 1))
                img = rearrange(
                    img, "(i j) h w c -> (i h) (j w) c", i=num_elems, j=num_elems
                )
                img = torch.unsqueeze(img, dim=0)
                img = torch.permute(img, (0, 3, 1, 2))
            elif self.overlap != 0 and self.patching:
                h, w = mask["nuclei_binary_map"].shape[1:]
                total_img = torch.zeros((3, h, w))
                decomposed_patch_num = int(np.sqrt(img.shape[0]))
                for i in range(decomposed_patch_num):
                    for j in range(decomposed_patch_num):
                        x_global = i * 256 - i * self.overlap
                        y_global = j * 256 - j * self.overlap
                        total_img[
                            :, x_global : x_global + 256, y_global : y_global + 256
                        ] = img[i * decomposed_patch_num + j]
                img = total_img
                img = img[None, :, :, :]
            self.plot_results(
                img=img,
                predictions=predictions,
                ground_truth=mask,
                img_name=image_name[0],
                scores=scores,
            )

        return image_metrics

    def run_single_image_inference(
        self, model: nn.Module, image: np.ndarray, generate_plots: bool = True
    ) -> Image.Image:
        """Run inference for a single image without ground-truth annotations

        Args:
            model (nn.Module): Trained model, must return "nuclei_binary_map", "nuclei_type_map", "tissue_type" and "hv_map"
            image (np.ndarray): Input image as RGB array with shape (H, W, 3)
            generate_plots (bool, optional): If plots should be generated. Defaults to True.

        Returns:
            Image.Image: Prediction overlay with drawn cell contours
        """
        # set image transforms (local name avoids shadowing the torchvision "transforms" import)
        transform_settings = self.run_conf["transformations"]
        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        img_transforms = A.Compose([A.Normalize(mean=mean, std=std)])
        transformed_img = img_transforms(image=image)["image"]
        image = torch.from_numpy(transformed_img).permute(2, 0, 1).unsqueeze(0).float()
        img = image.to(self.device)

        model.zero_grad()
        with torch.no_grad():
            predictions_ = model.forward(img)
        if self.overlap == 0:
            if self.patching:
                predictions_ = self.post_process_patching(predictions_)
            predictions = self.get_cell_predictions(predictions_)
        else:
            raise NotImplementedError(
                "Single-image inference is only implemented for overlap == 0"
            )

        image_output = self.plot_results(
            img=img,
            predictions=predictions,
        )
        return image_output
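
    # Minimal single-image usage sketch (hypothetical paths; assumes an RGB
    # array, e.g. loaded via cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)):
    #   inf = MoNuSegInference(model_path="./model_best.pth",
    #                          dataset_path="/path/to/monuseg/testing/",
    #                          outdir="./out/", gpu=0)
    #   overlay = inf.run_single_image_inference(inf.model, image)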

    def calculate_step_metric(
        self, predictions: dict, gt: dict, image_name: List[str]
    ) -> dict:
        """Calculate step metric for one MoNuSeg image.

        Args:
            predictions (dict): Necessary keys:
                * instance_map: Pixel-wise nuclear instance segmentation.
                    Each instance has its own integer, starting from 1. Shape: (1, H, W)
                * nuclei_binary_map: Softmax output for binary nuclei branch. Shape: (1, 2, H, W)
                * instance_types: Instance type prediction dictionary.
                    Main key is the nuclei instance number (int), with a dict as value.
                    For each instance, the dictionary contains the keys: bbox (bounding box),
                    centroid (centroid coordinates), contour, type_prob (probability), type (nuclei type).
                    Just one image is covered, as we expect batch-size=1
            gt (dict): Necessary keys:
                * instance_map
                * nuclei_binary_map
                * instance_types
            image_name (List[str]): Name of the image, list with [str]. List is necessary for backward compatibility

        Returns:
            dict: Image metrics for one MoNuSeg image. Keys are:
                * image_name
                * binary_dice_score
                * binary_jaccard_score
                * pq_score
                * dq_score
                * sq_score
                * f1_d
                * prec_d
                * rec_d
        """
predictions["instance_map"] = predictions["instance_map"].detach().cpu() | |
instance_maps_gt = gt["instance_map"].detach().cpu() | |
pred_binary_map = torch.argmax(predictions["nuclei_binary_map"], dim=1) | |
target_binary_map = gt["nuclei_binary_map"].to(self.device) | |
cell_dice = ( | |
dice(preds=pred_binary_map, target=target_binary_map, ignore_index=0) | |
.detach() | |
.cpu() | |
) | |
cell_jaccard = ( | |
binary_jaccard_index( | |
preds=pred_binary_map, | |
target=target_binary_map, | |
) | |
.detach() | |
.cpu() | |
) | |
remapped_instance_pred = remap_label(predictions["instance_map"]) | |
remapped_gt = remap_label(instance_maps_gt) | |
[dq, sq, pq], _ = get_fast_pq(true=remapped_gt, pred=remapped_instance_pred) | |
# detection scores | |
true_centroids = np.array( | |
[v["centroid"] for k, v in gt["instance_types"][0].items()] | |
) | |
pred_centroids = np.array( | |
[v["centroid"] for k, v in predictions["instance_types"].items()] | |
) | |
if true_centroids.shape[0] == 0: | |
true_centroids = np.array([[0, 0]]) | |
if pred_centroids.shape[0] == 0: | |
pred_centroids = np.array([[0, 0]]) | |
if self.magnification == 40: | |
pairing_radius = 12 | |
else: | |
pairing_radius = 6 | |
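        # the radius halves at 20x so it covers roughly the same physical
        # distance at both magnifications (~3 µm, assuming a typical
        # 0.25 µm/px resolution at 40x)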
        paired, unpaired_true, unpaired_pred = pair_coordinates(
            true_centroids, pred_centroids, pairing_radius
        )
        f1_d, prec_d, rec_d = cell_detection_scores(
            paired_true=paired[:, 0],
            paired_pred=paired[:, 1],
            unpaired_true=unpaired_true,
            unpaired_pred=unpaired_pred,
        )

        image_metrics = {
            "image_name": image_name,
            "binary_dice_score": cell_dice,
            "binary_jaccard_score": cell_jaccard,
            "pq_score": pq,
            "dq_score": dq,
            "sq_score": sq,
            "f1_d": f1_d,
            "prec_d": prec_d,
            "rec_d": rec_d,
        }

        return image_metrics

    def convert_binary_type(self, instance_types: dict) -> dict:
        """Clean nuclei detection from type prediction to binary prediction

        Args:
            instance_types (dict): Dictionary with cells

        Returns:
            dict: Cleaned dictionary with just one class
        """
        cleaned_instance_types = {}
        for key, elem in instance_types.items():
            if elem["type"] == 0:  # skip background detections
                continue
            else:
                elem["type"] = 0
                cleaned_instance_types[key] = elem
        return cleaned_instance_types

    def get_cell_predictions(self, predictions: dict) -> dict:
        """Reshape predictions and calculate instance maps and instance types

        Args:
            predictions (dict): Dictionary with the following keys:
                * tissue_types: Logit tissue prediction output. Shape: (B, num_tissue_classes)
                * nuclei_binary_map: Logit output for binary nuclei prediction branch. Shape: (B, 2, H, W)
                * hv_map: Logit output for hv-prediction. Shape: (B, 2, H, W)
                * nuclei_type_map: Logit output for nuclei instance-prediction. Shape: (B, num_nuclei_classes, H, W)

        Returns:
            dict:
                * nuclei_binary_map: Softmax binary prediction. Shape: (B, 2, H, W)
                * nuclei_type_map: Softmax nuclei type map. Shape: (B, num_nuclei_classes, H, W)
                * hv_map: Logit output for hv-prediction. Shape: (B, 2, H, W)
                * tissue_types: Logit tissue prediction output. Shape: (B, num_tissue_classes)
                * instance_map: Instance map, each instance has one integer. Shape: (B, H, W)
                * instance_types: Instance type dict, cleaned. Keys:
                    'bbox', 'centroid', 'contour', 'type_prob', 'type'
        """
        predictions["nuclei_binary_map"] = F.softmax(
            predictions["nuclei_binary_map"], dim=1
        )
        predictions["nuclei_type_map"] = F.softmax(
            predictions["nuclei_type_map"], dim=1
        )
        (
            predictions["instance_map"],
            predictions["instance_types"],
        ) = self.model.calculate_instance_map(
            predictions, magnification=self.magnification
        )
        predictions["instance_types"] = self.convert_binary_type(
            predictions["instance_types"][0]
        )

        return predictions

    def post_process_patching(self, predictions: dict) -> dict:
        """Post-process patching by reassembling (non-overlapping) patch predictions into one large image prediction

        Args:
            predictions (dict): Necessary keys:
                * nuclei_binary_map: Logit binary prediction. Shape: (B, 2, 256, 256)
                * hv_map: Logit output for hv-prediction. Shape: (B, 2, 256, 256)
                * nuclei_type_map: Logit output for nuclei instance-prediction. Shape: (B, num_nuclei_classes, 256, 256)

        Returns:
            dict: Return elements that have been changed:
                * nuclei_binary_map: Shape: (1, 2, H, W)
                * hv_map: Shape: (1, 2, H, W)
                * nuclei_type_map: (1, num_nuclei_classes, H, W)
        """
        batch_size = predictions["nuclei_binary_map"].shape[0]
        num_elems = int(np.sqrt(batch_size))
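        # e.g. with 16 patches of 256 px, num_elems = 4 and each rearrange below
        # maps (16, C, 256, 256) -> (C, 1024, 1024), so patch (i, j) lands at
        # rows i*256:(i+1)*256 and cols j*256:(j+1)*256 of the stitched map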
predictions["nuclei_binary_map"] = rearrange( | |
predictions["nuclei_binary_map"], | |
"(i j) d w h ->d (i w) (j h)", | |
i=num_elems, | |
j=num_elems, | |
) | |
predictions["hv_map"] = rearrange( | |
predictions["hv_map"], | |
"(i j) d w h -> d (i w) (j h)", | |
i=num_elems, | |
j=num_elems, | |
) | |
predictions["nuclei_type_map"] = rearrange( | |
predictions["nuclei_type_map"], | |
"(i j) d w h -> d (i w) (j h)", | |
i=num_elems, | |
j=num_elems, | |
) | |
predictions["nuclei_binary_map"] = torch.unsqueeze( | |
predictions["nuclei_binary_map"], dim=0 | |
) | |
predictions["hv_map"] = torch.unsqueeze(predictions["hv_map"], dim=0) | |
predictions["nuclei_type_map"] = torch.unsqueeze( | |
predictions["nuclei_type_map"], dim=0 | |
) | |
return predictions | |

    def post_process_patching_overlap(self, predictions: dict, overlap: int) -> List:
        """Post-process overlapping patches by merging the overlap. Uses the same merging strategy as the WSI cell-detection pipeline

        Args:
            predictions (dict): Predictions with necessary keys:
                * nuclei_binary_map: Binary nuclei prediction, Shape: (B, 2, H, W)
                * nuclei_type_map: Nuclei type prediction, Shape: (B, num_nuclei_classes, H, W)
                * hv_map: HV map predictions. Shape: (B, 2, H, W)
            overlap (int): Used overlap as integer

        Returns:
            List: Cleaned (merged) cell list with each entry being one detected cell with a dictionary as entry.
        """
        predictions["nuclei_binary_map"] = F.softmax(
            predictions["nuclei_binary_map"], dim=1
        )
        predictions["nuclei_type_map"] = F.softmax(
            predictions["nuclei_type_map"], dim=1
        )
        (
            predictions["instance_map"],
            predictions["instance_types"],
        ) = self.model.calculate_instance_map(
            predictions, magnification=self.magnification
        )
        predictions = self.merge_predictions(predictions, overlap)

        return predictions

    def merge_predictions(self, predictions: dict, overlap: int) -> list:
        """Merge overlapping cell predictions

        Args:
            predictions (dict): Predictions with necessary keys:
                * nuclei_binary_map: Binary nuclei prediction, Shape: (B, 2, H, W)
                * instance_types: Instance type dictionary with cell entries
            overlap (int): Used overlap as integer

        Returns:
            list: Cleaned (merged) cell list with each entry being one detected cell with a dictionary as entry.
        """
        cell_list = []
        decomposed_patch_num = int(np.sqrt(predictions["nuclei_binary_map"].shape[0]))

        for i in range(decomposed_patch_num):
            for j in range(decomposed_patch_num):
                x_global = i * 256 - i * overlap
                y_global = j * 256 - j * overlap
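                # patch origin in global coordinates: with 256 px patches and
                # overlap = 64, patch (i, j) starts at (192 * i, 192 * j)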
                patch_instance_types = predictions["instance_types"][
                    i * decomposed_patch_num + j
                ]
                for cell in patch_instance_types.values():
                    if cell["type"] == 0:
                        continue
                    offset_global = np.array([x_global, y_global])
                    centroid_global = cell["centroid"] + np.flip(offset_global)
                    contour_global = cell["contour"] + np.flip(offset_global)
                    bbox_global = cell["bbox"] + offset_global
                    cell_dict = {
                        "bbox": bbox_global.tolist(),
                        "centroid": centroid_global.tolist(),
                        "contour": contour_global.tolist(),
                        "type_prob": cell["type_prob"],
                        "type": cell["type"],
                        "patch_coordinates": [
                            i,  # row
                            j,  # col
                        ],
                        "cell_status": get_cell_position_marging(cell["bbox"], 256, 64),
                        "offset_global": offset_global.tolist(),
                    }
                    if np.max(cell["bbox"]) == 256 or np.min(cell["bbox"]) == 0:
                        position = get_cell_position(cell["bbox"], 256)
                        cell_dict["edge_position"] = True
                        cell_dict["edge_information"] = {}
                        cell_dict["edge_information"]["position"] = position
                        cell_dict["edge_information"]["edge_patches"] = get_edge_patch(
                            position, i, j  # row, col
                        )
                    else:
                        cell_dict["edge_position"] = False
                    cell_list.append(cell_dict)
        self.logger.info(f"Detected cells before cleaning: {len(cell_list)}")

        cell_processor = CellPostProcessor(cell_list, self.logger)
        cleaned_cells = cell_processor.post_process_cells()
        cell_list = [cell_list[idx_c] for idx_c in cleaned_cells.index.values]
        self.logger.info(f"Detected cells after cleaning: {len(cell_list)}")

        return cell_list

    def calculate_step_metric_overlap(
        self, cell_list: List[dict], gt: dict, image_name: List[str]
    ) -> Tuple[dict, dict]:
        """Calculate step metric and return merged predictions for plotting

        Args:
            cell_list (List[dict]): List with cell dicts
            gt (dict): Ground-Truth dictionary
            image_name (List[str]): Image name as list with just one entry

        Returns:
            Tuple[dict, dict]:
                dict: Image metrics for one MoNuSeg image. Keys are:
                    * image_name
                    * binary_dice_score
                    * binary_jaccard_score
                    * pq_score
                    * dq_score
                    * sq_score
                    * f1_d
                    * prec_d
                    * rec_d
                dict: Predictions, reshaped for one image and for plotting
                    * nuclei_binary_map: Shape (1, 2, 1024, 1024) or (1, 2, 512, 512)
                    * instance_map: Shape (1, 1024, 1024) or (1, 512, 512)
                    * instance_types: Dict with one entry per nucleus
        """
        predictions = {}
        h, w = gt["nuclei_binary_map"].shape[1:]
        instance_type_map = np.zeros((h, w), dtype=np.int32)

        for instance, cell in enumerate(cell_list):
            contour = np.array(cell["contour"], dtype=np.int32)[None, :, :]
            # instance IDs start at 1, 0 is background
            cv2.fillPoly(instance_type_map, contour, instance + 1)
predictions["instance_map"] = torch.Tensor(instance_type_map) | |
instance_maps_gt = gt["instance_map"].detach().cpu() | |
pred_arr = np.clip(instance_type_map, 0, 1) | |
target_binary_map = gt["nuclei_binary_map"].to(self.device).squeeze() | |
predictions["nuclei_binary_map"] = pred_arr | |
predictions["instance_types"] = cell_list | |
cell_dice = ( | |
dice( | |
preds=torch.Tensor(pred_arr).to(self.device), | |
target=target_binary_map, | |
ignore_index=0, | |
) | |
.detach() | |
.cpu() | |
) | |
cell_jaccard = ( | |
binary_jaccard_index( | |
preds=torch.Tensor(pred_arr).to(self.device), | |
target=target_binary_map, | |
) | |
.detach() | |
.cpu() | |
) | |
remapped_instance_pred = remap_label(predictions["instance_map"])[None, :, :] | |
remapped_gt = remap_label(instance_maps_gt) | |
[dq, sq, pq], _ = get_fast_pq(true=remapped_gt, pred=remapped_instance_pred) | |
# detection scores | |
true_centroids = np.array( | |
[v["centroid"] for k, v in gt["instance_types"][0].items()] | |
) | |
pred_centroids = np.array([v["centroid"] for v in cell_list]) | |
if true_centroids.shape[0] == 0: | |
true_centroids = np.array([[0, 0]]) | |
if pred_centroids.shape[0] == 0: | |
pred_centroids = np.array([[0, 0]]) | |
if self.magnification == 40: | |
pairing_radius = 12 | |
else: | |
pairing_radius = 6 | |
paired, unpaired_true, unpaired_pred = pair_coordinates( | |
true_centroids, pred_centroids, pairing_radius | |
) | |
f1_d, prec_d, rec_d = cell_detection_scores( | |
paired_true=paired[:, 0], | |
paired_pred=paired[:, 1], | |
unpaired_true=unpaired_true, | |
unpaired_pred=unpaired_pred, | |
) | |
image_metrics = { | |
"image_name": image_name, | |
"binary_dice_score": cell_dice, | |
"binary_jaccard_score": cell_jaccard, | |
"pq_score": pq, | |
"dq_score": dq, | |
"sq_score": sq, | |
"f1_d": f1_d, | |
"prec_d": prec_d, | |
"rec_d": rec_d, | |
} | |
# align to common shapes | |
cleaned_instance_types = { | |
k + 1: v for k, v in enumerate(predictions["instance_types"]) | |
} | |
for cell, results in cleaned_instance_types.items(): | |
results["contour"] = np.array(results["contour"]) | |
cleaned_instance_types[cell] = results | |
predictions["instance_types"] = cleaned_instance_types | |
predictions["instance_map"] = predictions["instance_map"][None, :, :] | |
predictions["nuclei_binary_map"] = F.one_hot( | |
torch.Tensor(predictions["nuclei_binary_map"]).type(torch.int64), | |
num_classes=2, | |
).permute(2, 0, 1)[None, :, :, :] | |
return image_metrics, predictions | |

    def plot_results(
        self,
        img: torch.Tensor,
        predictions: dict,
        ground_truth: dict = None,
        img_name: str = None,
        scores: List[float] = None,
    ) -> Image.Image:
        """Plot MoNuSeg results

        Args:
            img (torch.Tensor): Image as torch.Tensor, with shape (1, 3, 1024, 1024) or (1, 3, 512, 512)
            predictions (dict): Prediction dictionary. Necessary keys:
                * nuclei_binary_map: Shape (1, 2, 1024, 1024) or (1, 2, 512, 512)
                * instance_map: Shape (1, 1024, 1024) or (1, 512, 512)
                * instance_types: Dict with one entry per nucleus
            ground_truth (dict, optional): Ground-truth dictionary, accepted for API compatibility
                with the batched inference step but currently not plotted. Defaults to None.
            img_name (str, optional): Image name as string. Defaults to None.
            scores (List[float], optional): Scores as list [Dice, Jaccard, bPQ], currently not plotted.
                Defaults to None.

        Returns:
            Image.Image: Input image with drawn cell contours
        """
predictions["nuclei_binary_map"] = predictions["nuclei_binary_map"].permute( | |
0, 2, 3, 1 | |
) | |
h = predictions["instance_map"].shape[1] | |
w = predictions["instance_map"].shape[2] | |
# process image and other maps | |
sample_image = img.permute(0, 2, 3, 1).contiguous().cpu().numpy() | |
pred_sample_binary_map = ( | |
predictions["nuclei_binary_map"][:, :, :, 1].detach().cpu().numpy() | |
)[0] | |
pred_sample_instance_maps = ( | |
predictions["instance_map"].detach().cpu().numpy()[0] | |
) | |
binary_cmap = plt.get_cmap("Greys_r") | |
instance_map = plt.get_cmap("viridis") | |
# invert the normalization of the sample images | |
transform_settings = self.run_conf["transformations"] | |
if "normalize" in transform_settings: | |
mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5)) | |
std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5)) | |
else: | |
mean = (0.5, 0.5, 0.5) | |
std = (0.5, 0.5, 0.5) | |
inv_normalize = transforms.Normalize( | |
mean=[-0.5 / mean[0], -0.5 / mean[1], -0.5 / mean[2]], | |
std=[1 / std[0], 1 / std[1], 1 / std[2]], | |
) | |
inv_samples = inv_normalize(torch.tensor(sample_image).permute(0, 3, 1, 2)) | |
sample_image = inv_samples.permute(0, 2, 3, 1).detach().cpu().numpy()[0] | |

        # start overlaying on image
        placeholder = np.zeros((h, 4 * w, 3))
        # orig image
        placeholder[:h, :w, :3] = sample_image
        # binary prediction
        placeholder[:h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_binary_map * 255)
        )
        # instance predictions, min-max normalized with an epsilon against division by zero
        placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb(
            instance_map(
                (pred_sample_instance_maps - np.min(pred_sample_instance_maps))
                / (
                    np.max(pred_sample_instance_maps)
                    - np.min(pred_sample_instance_maps)
                    + 1e-10
                )
            )
        )
        # predicted contours
        pred_contours_polygon = [
            v["contour"] for v in predictions["instance_types"].values()
        ]
        pred_contours_polygon = [
            list(zip(poly[:, 0], poly[:, 1])) for poly in pred_contours_polygon
        ]
        pred_contour_colors_polygon = [
            "#70c6ff" for i in range(len(pred_contours_polygon))
        ]
        pred_cell_image = Image.fromarray(
            (sample_image * 255).astype(np.uint8)
        ).convert("RGB")
        pred_drawing = ImageDraw.Draw(pred_cell_image)
        for poly, color in zip(pred_contours_polygon, pred_contour_colors_polygon):
            pred_drawing.polygon(poly, outline=color, width=2)
        placeholder[:h, 3 * w : 4 * w, :3] = np.asarray(pred_cell_image) / 255
        pred_cell_image.save("raw_pred.png")

        # plotting
        fig, axs = plt.subplots(figsize=(3, 2), dpi=1200)
        axs.imshow(placeholder)
        axs.set_xticks(np.arange(w / 2, 4 * w, w))
        axs.set_xticklabels(
            [
                "Image",
                "Binary-Cells",
                "Instances",
                "Contours",
            ],
            fontsize=6,
        )
        axs.xaxis.tick_top()

        axs.set_yticks([h / 2])
        axs.set_yticklabels(["Pred."], fontsize=6)
        axs.tick_params(axis="both", which="both", length=0)
        grid_x = np.arange(w, 3 * w, w)
        grid_y = np.arange(h, 2 * h, h)

        for x_seg in grid_x:
            axs.axvline(x_seg, color="black")
        for y_seg in grid_y:
            axs.axhline(y_seg, color="black")

        fig.suptitle("All Predictions for input image", fontsize=6)
        fig.tight_layout()
        fig.savefig("pred_img.png")
        plt.close()

        return pred_cell_image


# CLI
class InferenceCellViTMoNuSegParser:
    def __init__(self) -> None:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description="Perform CellViT inference for MoNuSeg dataset",
        )
        parser.add_argument(
            "--model",
            type=str,
            help="Model checkpoint file that is used for inference",
            default="./model_best.pth",
        )
        parser.add_argument(
            "--dataset",
            type=str,
            help="Path to MoNuSeg dataset.",
            default="/data/lunbinzeng/datasets/monuseg/testing/",
        )
        parser.add_argument(
            "--outdir",
            type=str,
            help="Path to output directory to store results.",
            default="/data/lunbinzeng/results/lkcell/small/2024-04-22T232903_CellViT-unireplknet-fold1-final/monuseg/inference/",
        )
        parser.add_argument(
            "--gpu", type=int, help="Cuda-GPU ID for inference. Default: 0", default=0
        )
        parser.add_argument(
            "--magnification",
            type=int,
            help="Dataset Magnification. Either 20 or 40. Default: 40",
            choices=[20, 40],
            default=40,
        )
        # boolean flags use store_true: "type=bool" parses any non-empty string
        # (including "False") as True
        parser.add_argument(
            "--patching",
            action="store_true",
            help="Patch to 256px images. Default: False",
        )
        parser.add_argument(
            "--overlap",
            type=int,
            help="Patch overlap, just valid for patching",
            default=0,
        )
        parser.add_argument(
            "--plots",
            action="store_true",
            help="Generate result plots. Default: False",
        )
        self.parser = parser

    def parse_arguments(self) -> dict:
        opt = self.parser.parse_args()
        return vars(opt)
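

# Example invocation (placeholder paths, script name may differ locally):
#   python inference_cellvit_experiment_monuseg.py \
#       --model ./model_best.pth \
#       --dataset /path/to/monuseg/testing/ \
#       --outdir ./monuseg_inference/ \
#       --magnification 40 --patching --overlap 64 --plots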
if __name__ == "__main__":
    configuration_parser = InferenceCellViTMoNuSegParser()
    configuration = configuration_parser.parse_arguments()
    print(configuration)

    inf = MoNuSegInference(
        model_path=configuration["model"],
        dataset_path=configuration["dataset"],
        outdir=configuration["outdir"],
        gpu=configuration["gpu"],
        patching=configuration["patching"],
        magnification=configuration["magnification"],
        overlap=configuration["overlap"],
    )
    inf.run_inference(generate_plots=configuration["plots"])