fadindashfr
Add markdown description in interface
ed80dbe
raw
history blame
2.22 kB
from pathlib import Path
import torch
from monai.bundle import ConfigParser
import gradio as gr
import numpy as np
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")
inference = parser.get_parsed_content("inferer")
# loader = parser.get_parsed_content("dataloader")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")
state_dict = torch.load("models/model.pt")
network.load_state_dict(state_dict, strict=True)
label2color = {0: (0, 0, 0),
1: (225, 24, 69), # RED
2: (135, 233, 17), # GREEN
3: (0, 87, 233), # BLUE
4: (242, 202, 25), # YELLOW
5: (137, 49, 239),} # PURPLE
example_files = list(Path("sample_data").glob("*.png"))
def visualize_instance_seg_mask(mask):
image = np.zeros((mask.shape[0], mask.shape[1], 3))
labels = np.unique(mask)
for i in range(image.shape[0]):
for j in range(image.shape[1]):
image[i, j, :] = label2color[mask[i, j]]
image = image / 255
return image
def query_image(img):
data = {"image": img}
batch = preprocess(data)
network.eval()
with torch.no_grad():
pred = inference(batch['image'].unsqueeze(dim=0), network)
batch["pred"] = pred
for k,v in batch["pred"].items():
batch["pred"][k] = v.squeeze(dim=0)
batch = postprocess(batch)
result = visualize_instance_seg_mask(batch["type_map"].squeeze())
# Combine image
result = batch["image"].permute(1, 2, 0).cpu().numpy() * 0.5 + result * 0.5
# Solve rotating problem
result = np.fliplr(result)
result = np.rot90(result, k=1)
return result
# load Markdown file
with open('Description.md','r') as file:
markdown_content = file.read()
demo = gr.Interface(
query_image,
inputs=[gr.Image(type="filepath")],
outputs="image",
title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
description = markdown_content,
examples=example_files
)
demo.queue(concurrency_count=20).launch()