File size: 2,995 Bytes
a1f1417 54f43fd a1f1417 54f43fd a1f1417 9acc5c1 a1f1417 54f43fd a1f1417 54f43fd a1f1417 54f43fd a1f1417 54f43fd 8f845d2 54f43fd a1f1417 4adfcec 54f43fd 4adfcec 54f43fd 4adfcec 54f43fd 4adfcec 54f43fd 4adfcec 54f43fd 4adfcec 54f43fd 4adfcec 54f43fd ed80dbe 9acc5c1 ed80dbe 54f43fd a1f1417 4adfcec 54f43fd 9acc5c1 4adfcec 54f43fd 50321bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from monai.bundle import ConfigParser
from utils import page_utils
# Read the bundle's inference configuration from disk.
inference_config = json.loads(Path("configs/inference.json").read_text())

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# * NOTE: the device is set here in code; the value in the config file is overridden.
inference_config["device"] = device

parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

# Build the bundle components the demo needs. The bundle's dataloader is not
# instantiated here; images are fed through `preprocess` one at a time instead.
inference = parser.get_parsed_content("inferer")
network = parser.get_parsed_content("network_def")
preprocess = parser.get_parsed_content("preprocessing")
postprocess = parser.get_parsed_content("postprocessing")
# Half precision is opt-in through the USE_FP16 environment variable.
# Environment values are always strings, so explicit "off" spellings such as
# "0" or "false" must be recognised; any other non-empty value enables fp16.
use_fp16 = os.environ.get('USE_FP16', '').strip().lower() not in ('', '0', 'false', 'no', 'off')

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a GPU still loads when only the CPU is available.
state_dict = torch.load("models/model.pt", map_location=device)
network.load_state_dict(state_dict, strict=True)
network = network.to(device)
network.eval()

# fp16 is only applied when CUDA is available.
if use_fp16 and torch.cuda.is_available():
    network = network.half()
# RGB colour (0-255 per channel) used to paint each integer label of a mask.
label2color = {
    0: (0, 0, 0),        # black
    1: (225, 24, 69),    # red
    2: (135, 233, 17),   # green
    3: (0, 87, 233),     # blue
    4: (242, 202, 25),   # yellow
    5: (137, 49, 239),   # purple
}

# Every PNG found in sample_data/ is passed to the interface as an example.
example_files = [*Path("sample_data").glob("*.png")]
def visualize_instance_seg_mask(mask, palette=None):
    """Render a 2-D label mask as an RGB image with channels scaled to [0, 1].

    Args:
        mask: 2-D array-like of labels (anything ``np.asarray`` accepts).
        palette: Optional mapping from label to an ``(R, G, B)`` tuple in the
            0-255 range. Defaults to the module-level ``label2color``.

    Returns:
        A float ``numpy.ndarray`` of shape ``(H, W, 3)``.

    Raises:
        KeyError: If the mask contains a label that is missing from the palette.
    """
    if palette is None:
        palette = label2color

    mask = np.asarray(mask)
    image = np.zeros((mask.shape[0], mask.shape[1], 3))

    # Paint one label at a time with a boolean mask instead of visiting every
    # pixel in Python; the number of iterations is the number of distinct labels.
    for label in np.unique(mask):
        # .item() yields a plain Python scalar for the palette lookup.
        image[mask == label] = palette[label.item()]

    return image / 255
def query_image(img):
    """Run the network on one image and return the colour overlay.

    Args:
        img: The image handed over by the Gradio input component; it is passed
            to ``preprocess`` under the ``"image"`` key.

    Returns:
        A numpy array: an equal-weight blend of the preprocessed image and the
        coloured type map, with its first two axes re-oriented for display.
    """
    batch = preprocess({"image": img})

    # Move the input to the inference device, optionally in half precision.
    image = batch['image'].to(device)
    if use_fp16 and torch.cuda.is_available():
        image = image.half()
    batch['image'] = image

    # Add a leading batch dimension for the forward pass; no gradients needed.
    with torch.no_grad():
        outputs = inference(image.unsqueeze(dim=0), network)

    # Drop the batch dimension from every prediction before post-processing.
    batch["pred"] = outputs
    for key in outputs:
        outputs[key] = outputs[key].squeeze(dim=0)

    batch = postprocess(batch)

    colored_types = visualize_instance_seg_mask(batch["type_map"].squeeze())

    # Combine image: channels-first tensor -> channels-last array, 50/50 blend.
    channels_last = batch["image"].permute(1, 2, 0).cpu().numpy()
    overlay = channels_last * 0.5 + colored_types * 0.5

    # Solve rotating problem: mirror left-right, then rotate by 90 degrees.
    return np.rot90(np.fliplr(overlay), k=1)
# Load the HTML snippet that is shown as the interface description.
html_content = Path('index.html').read_text(encoding='utf-8')

# Single-image demo: one image in, the overlay produced by query_image out.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(type="filepath")],
    outputs="image",
    theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_500",
        button_primary_text_color="white",
    ),
    description=html_content,
    examples=example_files,
)

# Enable request queuing, then start the server.
demo.queue(concurrency_count=20)
demo.launch()
|