|
from pathlib import Path |
|
import torch |
|
from monai.bundle import ConfigParser |
|
import gradio as gr |
|
import numpy as np |
|
|
|
# Load the MONAI bundle: the inference config declares the components fetched
# below, and the bundle metadata is read alongside it.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")

# Instantiate the pipeline pieces stored under these keys of the config.
inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only host; load_state_dict then copies the tensors into the network's
# own parameters. strict=True fails fast on any missing or unexpected key.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
|
|
|
# RGB color (0-255 per channel) used to paint each integer label of the
# predicted type map; the position in the list is the label it belongs to.
# NOTE(review): label 0 is presumably background - confirm against the bundle.
label2color = dict(
    enumerate(
        [
            (0, 0, 0),        # 0: black
            (225, 24, 69),    # 1: red
            (135, 233, 17),   # 2: green
            (0, 87, 233),     # 3: blue
            (242, 202, 25),   # 4: yellow
            (137, 49, 239),   # 5: purple
        ]
    )
)
|
|
|
# Sample images offered as clickable examples in the UI. glob() yields entries
# in arbitrary, filesystem-dependent order, so sort them for a stable listing.
example_files = sorted(Path("sample_data").glob("*.png"))
|
|
|
def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image using ``label2color``.

    Args:
        mask: 2-D array (or array-like convertible with ``np.asarray``) of
            labels. Every distinct value must be a key of ``label2color``.

    Returns:
        A float ``numpy.ndarray`` of shape ``(H, W, 3)`` whose channel values
        are the mapped colors divided by 255.

    Raises:
        KeyError: If the mask contains a label missing from ``label2color``.
    """
    mask = np.asarray(mask)
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # Paint one label at a time with a boolean mask instead of visiting every
    # pixel in Python; the number of distinct labels is small compared to the
    # number of pixels.
    for label in np.unique(mask):
        # .item() yields a plain Python scalar, so float labels such as 1.0
        # still match the integer keys of the color table.
        image[mask == label] = label2color[label.item()]
    return image / 255
|
|
|
def query_image(img):
    """Run the bundle on one image and return it blended with its type map.

    Args:
        img: The input image as handed over by the Gradio image component,
            which is configured with ``type="filepath"``.

    Returns:
        A ``numpy.ndarray`` holding an equal-weight blend of the preprocessed
        image and the colorized type map, with its first two axes swapped.
    """
    sample = preprocess({"image": img})

    network.eval()
    with torch.no_grad():
        # The inferer is given a batch of one, hence the added leading axis.
        outputs = inference(sample["image"].unsqueeze(dim=0), network)

    # Store the predictions on the sample and drop the batch axis again from
    # every output before post-processing.
    sample["pred"] = outputs
    pred_maps = sample["pred"]
    for key in list(pred_maps):
        pred_maps[key] = pred_maps[key].squeeze(dim=0)

    sample = postprocess(sample)

    overlay = visualize_instance_seg_mask(sample["type_map"].squeeze())

    # assumes the preprocessed image is channel-first (C, H, W) - TODO confirm
    picture = sample["image"].permute(1, 2, 0).cpu().numpy()
    blended = picture * 0.5 + overlay * 0.5

    # A left-right flip followed by a 90-degree counter-clockwise rotation
    # swaps the two spatial axes of the blended image.
    return np.rot90(np.fliplr(blended), k=1)
|
|
|
|
|
# Text used as the interface description. Decode explicitly as UTF-8 so the
# result does not depend on the host's locale default encoding.
with open('Description.md', 'r', encoding='utf-8') as file:
    markdown_content = file.read()
|
|
|
# Web UI: a single image input, passed to query_image as a file path, and a
# single image output, with the sample images offered as examples.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
    description=markdown_content,
    examples=example_files,
)

# Turn on request queuing, then start the server.
# NOTE(review): `concurrency_count` appears to be a Gradio 3.x argument that
# later major versions dropped - confirm against the pinned Gradio version.
queued_demo = demo.queue(concurrency_count=20)
queued_demo.launch()
|
|