File size: 3,900 Bytes
4e18454 13ee8c5 4e18454 13ee8c5 4e18454 13ee8c5 4e18454 acb6152 4e18454 13ee8c5 4e18454 a222e79 4e18454 acb6152 4e18454 acb6152 13ee8c5 4e18454 13ee8c5 acb6152 13ee8c5 4e18454 acb6152 4e18454 acb6152 4e18454 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
from .utils import load_ct_to_numpy, load_pred_volume_to_numpy
from .compute import run_model
from .convert import nifti_to_glb
class WebUI:
    """Gradio web interface that runs a model on an uploaded volume.

    The UI offers a file upload, a "Run analysis" button, a stack of
    annotated 2D image boxes (one per slider position, only one visible at
    a time), a 3D model viewer and a slider selecting the 2D slice to show.
    """

    # File extensions accepted by the upload widget (plain and gzipped NIfTI).
    ACCEPTED_FILE_TYPES = (".nii", ".nii.gz")

    def __init__(self, model_name: str = None, class_name: str = None, cwd: str = None):
        """Store configuration and create the widgets that are rendered later.

        Args:
            model_name: Appended to ``cwd`` to form the model argument that
                ``run()`` hands to ``load_mesh``.
            class_name: Label attached to the prediction overlay and used as
                the key of the overlay colour map.
            cwd: Prefix that is string-concatenated with ``model_name`` and
                with ``"test-volume.nii"``; no separator is inserted, so it
                needs to end with a path separator.
        """
        # global states
        self.images = []
        self.pred_images = []

        # @TODO: This should be dynamically set based on chosen volume size
        self.nb_slider_items = 100

        self.model_name = model_name
        self.class_name = class_name
        self.cwd = cwd

        # define widgets not to be rendered immediately, but later on
        self.slider = gr.Slider(
            1,
            self.nb_slider_items,
            value=1,
            step=1,
            label="Which 2D slice to show",
        )
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            visible=True,
            elem_id="model-3d",
        ).style(height=512)

    def combine_ct_and_seg(self, img, pred):
        """Pair an image with its prediction as an AnnotatedImage value.

        Returns:
            ``(img, [(pred, self.class_name)])``.
        """
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
        """Return the ``name`` attribute of the uploaded file object."""
        return file.name

    def load_mesh(self, mesh_file_name, model_name):
        """Process the uploaded volume and refresh the cached 2D slices.

        Args:
            mesh_file_name: Uploaded file object; only its ``name`` attribute
                is read and used as the input path.
            model_name: Passed through unchanged to ``run_model``.

        Returns:
            The fixed path ``"./prediction.obj"``, which ``run()`` routes to
            the 3D viewer.
        """
        path = mesh_file_name.name
        # NOTE(review): presumably run_model writes prediction-livermask.nii
        # and nifti_to_glb produces ./prediction.obj from it -- confirm
        # against those helpers, which are defined in other modules.
        run_model(path, model_name)
        nifti_to_glb("prediction-livermask.nii")

        self.images = load_ct_to_numpy(path)
        self.pred_images = load_pred_volume_to_numpy("./prediction-livermask.nii")

        # Do not rebind self.slider to the result of Slider.update(): that
        # call only builds an update payload, which has no effect unless it
        # is returned through an event's outputs, and rebinding replaces the
        # widget object so that later calls fail on a non-widget value.
        return "./prediction.obj"

    def get_img_pred_pair(self, k):
        """Build the per-box updates that show only the slice chosen by ``k``.

        Args:
            k: 1-based slider value.

        Returns:
            A list of ``self.nb_slider_items`` AnnotatedImage updates. All
            are hidden except the one at index ``k - 1``, which carries the
            image/prediction pair. If no slice exists at that index (for
            example before any analysis has run), every box is hidden.
        """
        idx = int(k) - 1  # slider values start at 1, sequences at 0
        out = [gr.AnnotatedImage.update(visible=False)] * self.nb_slider_items

        # The slider range is fixed, so it can point past the loaded slices.
        n_available = min(
            len(self.images), len(self.pred_images), self.nb_slider_items
        )
        if 0 <= idx < n_available:
            out[idx] = gr.AnnotatedImage.update(
                self.combine_ct_and_seg(self.images[idx], self.pred_images[idx]),
                visible=True,
            )
        return out

    def run(self):
        """Assemble the Blocks layout, wire the events and launch the server.

        The launch call listens on 0.0.0.0:7860 with the request queue
        enabled and ``share=True``.
        """
        css = """
        #model-3d {
        height: 512px;
        }
        #model-2d {
        height: 512px;
        margin: auto;
        }
        """
        with gr.Blocks(css=css) as demo:

            with gr.Row():
                file_output = gr.File(
                    file_types=list(self.ACCEPTED_FILE_TYPES),
                    file_count="single",
                ).style(full_width=False, size="sm")
                file_output.upload(self.upload_file, file_output, file_output)

                run_btn = gr.Button("Run analysis").style(full_width=False, size="sm")
                run_btn.click(
                    fn=lambda x: self.load_mesh(x, model_name=self.cwd + self.model_name),
                    inputs=file_output,
                    outputs=self.volume_renderer,
                )

            with gr.Row():
                gr.Examples(
                    examples=[self.cwd + "test-volume.nii"],
                    inputs=file_output,
                    outputs=file_output,
                    fn=self.upload_file,
                    cache_examples=True,
                )

            with gr.Row():
                with gr.Box():
                    image_boxes = []
                    for i in range(self.nb_slider_items):
                        # NOTE(review): the initially visible box is index 1
                        # (slider value 2) while the slider starts at 1 --
                        # confirm whether index 0 was intended.
                        t = gr.AnnotatedImage(visible=(i == 1), elem_id="model-2d")\
                            .style(color_map={self.class_name: "#ffae00"}, height=512, width=512)
                        image_boxes.append(t)

                    self.slider.change(self.get_img_pred_pair, self.slider, image_boxes)

                with gr.Box():
                    self.volume_renderer.render()

            with gr.Row():
                self.slider.render()

        # sharing app publicly -> share=True: https://gradio.app/sharing-your-app/
        # inference times > 60 seconds -> need queue(): https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
        demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
|