from pathlib import Path
import torch
from monai.bundle import ConfigParser
import gradio as gr
import numpy as np

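# Parse the MONAI bundle configuration: inference.json describes the inference
# pipeline and metadata.json describes the bundle/model itself.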
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")

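# Instantiate the components defined in the bundle config: the inferer
# (e.g. sliding-window inference), the network architecture, and the
# pre-/post-processing transform chains.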
inference = parser.get_parsed_content("inferer")
# loader = parser.get_parsed_content("dataloader")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")

# Load the trained weights; map to CPU so the demo also runs without a GPU
state_dict = torch.load("models/model.pt", map_location="cpu")
network.load_state_dict(state_dict, strict=True)

# RGB colour assigned to each predicted nucleus type (0 = background)
label2color = {0: (0, 0, 0),
               1: (225, 24, 69),   # RED
               2: (135, 233, 17),  # GREEN
               3: (0, 87, 233),    # BLUE
               4: (242, 202, 25),  # YELLOW
               5: (137, 49, 239)}  # PURPLE

example_files = list(Path("sample_data").glob("*.png"))

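# Convert an integer type map of shape (H, W) into an RGB image in [0, 1]
# by looking up each label's colour in label2color.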
def visualize_instance_seg_mask(mask):
    mask = np.asarray(mask)
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    for label, color in label2color.items():
        image[mask == label] = color
    return image / 255

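# Gradio callback: receives the uploaded image's file path, applies the bundle's
# preprocessing, runs the network through the configured inferer, postprocesses
# the predictions into a nucleus type map, and returns a colour-coded overlay.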
def query_image(img, progress=gr.Progress(track_tqdm=True)):
    data = {"image": img}
    batch = preprocess(data)
    
    network.eval()
    with torch.no_grad():
        pred = inference(batch["image"].unsqueeze(dim=0), network)

    # The inferer returns a dict of outputs; drop the batch dimension from each
    # so postprocessing sees a single sample
    batch["pred"] = pred
    for k, v in batch["pred"].items():
        batch["pred"][k] = v.squeeze(dim=0)
    
    batch = postprocess(batch)
    
    result = visualize_instance_seg_mask(batch["type_map"].squeeze())
    
    # Blend the input image and the colour-coded prediction 50/50 for display
    result = batch["image"].permute(1, 2, 0).cpu().numpy() * 0.5 + result * 0.5
    
    # The processed image comes out rotated relative to the input, so flip and rotate it back
    result = np.fliplr(result)
    result = np.rot90(result, k=1)
    
    return result
    
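# Wire the callback into a simple Gradio interface; gr.Image(type="filepath")
# hands query_image a path string rather than a decoded array, so loading the
# file is left to the bundle's preprocessing.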
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    title="Medical Image Classification with MONAI - Pathology Nuclei Segmentation Classification",
    description="Please upload an image to see the segmentation capabilities of this model",
    examples=example_files
)

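# Note: queue(concurrency_count=...) is the Gradio 3.x API; newer Gradio releases
# replaced it (e.g. with default_concurrency_limit), so adjust this call if the
# Gradio dependency is upgraded.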
demo.queue(concurrency_count=20).launch()