import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser, InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser, MoNuSegInference

## local | remote
RUN_MODE = "remote"

if RUN_MODE != "local":
    ## model weights
    os.system("wget https://huggingface.co/xiazhi/LKCell/resolve/main/model_best.pth")
    ## examples
    os.system("wget https://huggingface.co/xiazhi/LKCell/resolve/main/1.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell/resolve/main/2.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell/resolve/main/3.png")
    os.system("wget https://huggingface.co/xiazhi/LKCell/resolve/main/4.png")

## step 1: set up models (CPU only)
device = "cpu"

## pannuke set
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    run_dir=pannuke_configurations["run_dir"],
    checkpoint_name=pannuke_configurations["checkpoint_name"],
    gpu=pannuke_configurations["gpu"],
    magnification=pannuke_configurations["magnification"],
)
pannuke_checkpoint = torch.load(
    pannuke_inf.run_dir / pannuke_inf.checkpoint_name, map_location="cpu"
)
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])
## put model in eval mode
pannuke_model.to(device)
pannuke_model.eval()

## monuseg set
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
monuseg_inf = MoNuSegInference(
    model_path=monuseg_configurations["model"],
    dataset_path=monuseg_configurations["dataset"],
    outdir=monuseg_configurations["outdir"],
    gpu=monuseg_configurations["gpu"],
    patching=monuseg_configurations["patching"],
    magnification=monuseg_configurations["magnification"],
    overlap=monuseg_configurations["overlap"],
)


def click_process(image_input, type_dataset):
    ## resize large inputs so CPU inference stays tractable
    if image_input.shape[0] > 512 and image_input.shape[1] > 512:
        image_input = cv2.resize(image_input, (512, 512))
    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)

    ## both inference paths write their visualizations to disk; read them back as RGB
    image_output = cv2.imread("raw_pred.png")
    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
    image_output2 = cv2.imread("pred_img.png")
    image_output2 = cv2.cvtColor(image_output2, cv2.COLOR_BGR2RGB)
    return image_output, image_output2


demo = gr.Blocks(title="LKCell")
with demo:
    gr.Markdown(value="""
**Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**.
Check out our [GitHub Repo](https://github.com/hustvl/LKCell).
""")
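
    ## UI layout: input image and dataset selector on the left; the two
    ## prediction views (image prediction and all predictions) on the right.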
""") with gr.Row(): with gr.Column(): with gr.Row(): Image_input = gr.Image(type="numpy", label="Input", interactive=True,height=480) with gr.Row(): Type_dataset = gr.Radio(choices=["pannuke", "monuseg"], label=" input image's dataset type",value="pannuke") with gr.Column(): image_output = gr.Image(type="numpy", label="image prediction",height=480,width=480) image_output2 = gr.Image(type="numpy", label="all predictions",height=480) with gr.Row(): Button_run = gr.Button("๐ Submit (ๅ้) ") clear_button = gr.ClearButton(components=[Image_input,Type_dataset,image_output,image_output2],value="๐งน Clear (ๆธ ้ค)") Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset ], outputs=[image_output,image_output2]) ## guiline gr.Markdown(value=""" ๐**Guideline** 1. Upload your image or select one from the examples. 2. Set up the arguments: "Type_dataset" to enjoy two dataset type's inference 3. Due to the limit of CPU , we resize the input image whose size is larger than (512,512) to (512,512) 4. Run the Submit button to get the output. """) # if RUN_MODE != "local": gr.Examples(examples=[ ['1.png', "pannuke"], ['2.png', "pannuke"], ['3.png', "monuseg"], ['4.png', "monuseg"], ], inputs=[Image_input, Type_dataset], outputs=[image_output,image_output2], label="Examples") gr.HTML(value="""
""") gr.Markdown(value=""" Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco) """) if RUN_MODE == "local": demo.launch(server_name='127.0.0.1',server_port=8003) else: demo.launch()