File size: 3,573 Bytes
4c5329d
 
 
 
befdaa8
 
4c5329d
 
 
 
 
 
 
fbe0f24
4c5329d
fbe0f24
4c5329d
 
 
 
 
 
 
 
 
 
fbe0f24
4c5329d
 
 
 
 
 
 
 
fbe0f24
 
 
 
 
 
 
 
4c5329d
 
 
fbe0f24
 
 
 
 
 
 
 
4c5329d
 
befdaa8
 
 
 
 
 
 
 
 
4c5329d
 
 
 
 
 
 
befdaa8
4c5329d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from monai.bundle import ConfigParser
import gradio as gr

from utils import page_utils

# Build the MONAI bundle parser from the inference config and its metadata.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")
parser.read_meta(f="configs/metadata.json")

# Instantiate the bundle components used at prediction time, in the order
# inferer -> network definition -> preprocessing pipeline.
inference, network, preprocess = (
    parser.get_parsed_content(section)
    for section in ("inferer", "network_def", "preprocessing")
)

# Restore the trained weights onto the CPU. strict=True makes the load fail
# if the checkpoint keys do not exactly match the network's parameters.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be placed at this path.
state_dict = torch.load("models/model.pt", map_location=torch.device('cpu'))
network.load_state_dict(state_dict, strict=True)

# Maps a class index to its display name; classify_image uses the index to
# pick the matching entry of the probability vector.
class_names = dict(
    enumerate(
        (
            "Other",
            "Inflammatory",
            "Epithelial",
            "Spindle-Shaped",
        )
    )
)

def classify_image(image_file, label_file):
    """Run the model on an image/label pair and return class probabilities.

    Args:
        image_file: Path of the input image (the UI supplies a filepath).
        label_file: Path of the matching label-mask image.

    Returns:
        A dict mapping every name in ``class_names`` to its softmax
        probability as a Python float.

    Raises:
        gr.Error: If either input is missing, so the UI shows a readable
            message instead of an opaque preprocessing failure.
    """
    # Gradio passes None for an empty image slot; fail early and clearly.
    if not image_file or not label_file:
        raise gr.Error("Please provide both an image and a label mask before processing.")

    batch = preprocess({"image": image_file, "label": label_file})

    network.eval()  # make sure inference-mode layers are active
    with torch.no_grad():
        # unsqueeze adds the batch dimension; per the original author's note
        # the network expects 4 input channels (3 RGB + 1 label mask).
        pred = inference(batch['image'].unsqueeze(dim=0), network)

    # Softmax over the last axis, then take the single item of the batch.
    prob = pred.softmax(-1).detach().cpu().numpy()[0]
    return {name: float(prob[index]) for index, name in class_names.items()}

# Each example is an [image path, label path] pair sharing one file stem
# under sample_data/Images and sample_data/Labels.
example_files1 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]

# Second group of example pairs, built the same way.
example_files2 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]

# Read the static HTML fragment that the UI renders through gr.HTML.
with open('index.html', encoding='utf-8') as html_file:
    html_content = html_file.read()

# Default Gradio theme recoloured with the project hue, plus explicit
# primary-button colours.
kalbe_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill='*primary_600',
    button_primary_background_fill_hover='*primary_500',
    button_primary_text_color='white',
)

with gr.Blocks(theme=kalbe_theme) as app:
    gr.HTML(html_content)

    # Left column: the two image inputs and the action buttons.
    # Right side: the probability label.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB")
                label_img = gr.Image(type="filepath", image_mode="L")
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        out_txt = gr.Label(label="Probabilities", num_top_classes=4)

    # "Process" runs the classifier on the two uploaded files.
    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)

    # "Clear" resets both inputs and the output with one update per component.
    clear_btn.click(
        lambda: tuple(gr.update(value=None) for _ in range(3)),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # One row per example group; every pair gets its own Examples widget.
    for example_group in (example_files1, example_files2):
        with gr.Row():
            for example_pair in example_group:
                gr.Examples([example_pair], inputs=[inp_img, label_img])
app.launch()