fadindashfr
update app.py device = 'cpu'
d9d6b8b
raw
history blame
3.3 kB
import torch
from monai.bundle import ConfigParser
import gradio as gr
# Assemble the MONAI bundle pieces (inferer, network, preprocessing transforms)
# from the bundle's JSON files, then load the trained weights onto the CPU.
parser = ConfigParser()
parser.read_config(f="configs/inference.json")  # inference-time workflow configuration
parser.read_meta(f="configs/metadata.json")  # bundle metadata

inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")

# Map every stored tensor to the CPU so the app runs on hosts without a GPU.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if the
# installed torch version supports it and model.pt holds only tensors.
state_dict = torch.load("models/model.pt", map_location="cpu")
# strict=True: the checkpoint keys must match the network's parameters exactly.
network.load_state_dict(state_dict, strict=True)
# Output index -> human-readable nucleus class. classify_image looks up each
# predicted probability by these integer keys, so list position == class index.
class_names = dict(
    enumerate(
        [
            "Other",
            "Inflammatory",
            "Epithelial",
            "Spindle-Shaped",
        ]
    )
)
def classify_image(image_file, label_file):
    """Classify a nucleus from an RGB image and its label mask.

    Args:
        image_file: Path to the RGB image, as delivered by a ``gr.Image``
            component with ``type="filepath"``; ``None`` when no image is set.
        label_file: Path to the matching label mask; ``None`` when no mask is set.

    Returns:
        dict mapping each name in ``class_names`` to its softmax probability
        (a Python float), the format ``gr.Label`` displays.

    Raises:
        gr.Error: If either input is missing, so the UI shows a readable
            message instead of a traceback from the preprocessing transforms.
    """
    if image_file is None or label_file is None:
        raise gr.Error("Please provide both an image and its label mask.")

    batch = preprocess({"image": image_file, "label": label_file})

    network.eval()
    with torch.no_grad():
        # unsqueeze adds the batch dimension; the network expects 4 input
        # channels (3 RGB, 1 label mask).
        logits = inference(batch["image"].unsqueeze(dim=0), network)
        # NOTE(review): assumes logits are (1, num_classes); [0] takes the
        # single batch row. Computed under no_grad, so no detach() is needed.
        probs = logits.softmax(-1).cpu().numpy()[0]

    return {name: float(probs[idx]) for idx, name in class_names.items()}
# Example [image, label-mask] pairs offered below the app. Every image under
# sample_data/Images has a mask with the same file name under sample_data/Labels.
example_files1 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_11_2_0628",
        "test_9_4_0149",
        "test_12_3_0292",
        "test_9_4_0019",
    )
]
example_files2 = [
    [f"sample_data/Images/{stem}.png", f"sample_data/Labels/{stem}.png"]
    for stem in (
        "test_14_3_0433",
        "test_14_4_0544",
        "train_1_1_0095",
        "train_1_3_0020",
    )
]
# Long-form description rendered at the top of the UI. Read with an explicit
# encoding so non-ASCII Markdown loads the same way regardless of host locale.
with open("Description.md", encoding="utf-8") as description_file:
    markdown_content = description_file.read()
# Gradio UI: an RGB image plus its label mask go in; per-class probabilities
# from classify_image come out.
with gr.Blocks() as app:
    gr.Markdown("# Pathology Nuclei Classification")
    gr.Markdown(markdown_content)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                inp_img = gr.Image(type="filepath", image_mode="RGB")
                label_img = gr.Image(type="filepath", image_mode="L")
            with gr.Row():
                process_btn = gr.Button(value="Process")
                clear_btn = gr.Button(value="Clear")
        # Show every class; derived from class_names so the two cannot drift apart.
        out_txt = gr.Label(label="Probabilities", num_top_classes=len(class_names))

    process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
    # Reset both input images and the probability display.
    clear_btn.click(
        lambda: (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
        ),
        inputs=None,
        outputs=[inp_img, label_img, out_txt],
    )

    gr.Markdown("## Image Examples")
    # One row per example group; each [image, mask] pair gets its own
    # gr.Examples component within that row.
    for example_group in (example_files1, example_files2):
        with gr.Row():
            for example_pair in example_group:
                gr.Examples([example_pair], inputs=[inp_img, label_img])

app.launch()