# muchlisinadi's picture
# try to remove progress bar
# 8f845d2
# raw
# history blame
# 2.18 kB
from pathlib import Path
import torch
from monai.bundle import ConfigParser
import gradio as gr
import numpy as np
# Build the MONAI bundle components (inferer, network, pre/post-processing)
# from the bundle's JSON configuration files.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")
inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; load_state_dict then copies the weights into the network's
# parameters wherever they live.
# NOTE(review): the network is used exactly as built by "network_def" (no
# explicit .to(device) here) — confirm the preprocessed inputs end up on the
# same device.
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)
# Inference-only app: switch to eval mode once, right after loading weights.
network.eval()
# Label id -> RGB colour (0-255 components) used to paint the predicted type
# map. Label 0 is drawn black — NOTE(review): presumably background; confirm
# against the bundle's class list.
label2color = {
    0: (0, 0, 0),
    1: (225, 24, 69),  # RED
    2: (135, 233, 17),  # GREEN
    3: (0, 87, 233),  # BLUE
    4: (242, 202, 25),  # YELLOW
    5: (137, 49, 239),  # PURPLE
}

# Example inputs offered under the Gradio interface. Sorted so the gallery
# order does not depend on the filesystem's directory enumeration order, and
# stored as plain path strings (the form Gradio documents for file examples).
example_files = sorted(str(path) for path in Path("sample_data").glob("*.png"))
def visualize_instance_seg_mask(mask, color_map=None):
    """Render a 2-D label mask as an RGB float image scaled by 1/255.

    Every pixel's label is looked up in ``color_map`` and the resulting
    0-255 RGB triple is divided by 255.

    Args:
        mask: 2-D array-like of labels, shape (H, W). Anything
            ``np.asarray`` accepts works (e.g. a NumPy array or a CPU tensor).
            Float-valued labels such as ``1.0`` resolve to integer keys.
        color_map: optional mapping ``label -> (r, g, b)``. Defaults to the
            module-level ``label2color``.

    Returns:
        ``float64`` array of shape (H, W, 3).

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label missing from ``color_map``.
    """
    if color_map is None:
        color_map = label2color
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D label mask, got shape {mask.shape}")

    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # Paint one label at a time with a boolean mask instead of visiting every
    # pixel in Python: there are only a handful of distinct labels, versus
    # H * W pixels.
    for label in np.unique(mask).tolist():
        image[mask == label] = color_map[label]
    return image / 255
def query_image(img):
    """Run the bundle on one uploaded image and return a coloured overlay.

    Args:
        img: value produced by the ``gr.Image(type="filepath")`` input;
            it is handed to ``preprocess`` under the ``"image"`` key.

    Returns:
        NumPy array: the image tensor left in the batch after postprocessing
        and the colour-coded ``type_map`` blended 50/50, with the first two
        (spatial) axes swapped.
    """
    sample = preprocess({"image": img})
    network.eval()
    with torch.no_grad():
        # The inferer takes a batched input, so add a leading batch axis ...
        predictions = inference(sample["image"].unsqueeze(dim=0), network)
        # ... and strip it again from every prediction head, in place.
        for head, output in predictions.items():
            predictions[head] = output.squeeze(dim=0)
        sample["pred"] = predictions

    sample = postprocess(sample)
    overlay = visualize_instance_seg_mask(sample["type_map"].squeeze())
    background = sample["image"].permute(1, 2, 0).cpu().numpy()
    blended = background * 0.5 + overlay * 0.5

    # Swapping the two spatial axes corrects the display orientation.
    # (np.rot90(np.fliplr(x), k=1) is exactly this axis swap.)
    # NOTE(review): presumably undoes the loader's reversed spatial indexing
    # — confirm against the bundle's LoadImage settings.
    return np.swapaxes(blended, 0, 1)
# Gradio front end: a single image upload in, a single overlay image out,
# with the sample_data PNGs offered as clickable examples.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
    description="Please upload an image to see segmentation capabilities of this model",
    examples=example_files,
)

# Enable request queueing (concurrency_count is the Gradio 3.x setting for
# how many queued requests are worked on at once), then start the server.
demo.queue(concurrency_count=20).launch()