"""Gradio demo that counts microalgae in an image with an IceVision RetinaNet model.

The model, its type, class map and image size are restored from a checkpoint
file; each uploaded image is run through ``end2end_detect`` and the annotated
image plus the number of detected boxes are shown in the UI.
"""
from gradio.outputs import Label
from icevision.all import *
from icevision.models.checkpoint import *
import PIL
# ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule is
# loaded; import it explicitly instead of relying on another library's import.
import PIL.Image
import gradio as gr
import os

# Load model
checkpoint_path = "models/model_checkpoint.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model = checkpoint_and_model["model"]
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]

# Transforms
img_size = checkpoint_and_model["img_size"]
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()])

# Value passed to ``end2end_detect`` as ``detection_threshold``.
DETECTION_THRESHOLD = 0.5

examples = [
    ['sample_images/IMG_20191212_151351.jpg'],
    ['sample_images/IMG_20191212_153420.jpg'],
    ['sample_images/IMG_20191212_154100.jpg'],
]


def show_preds(input_image):
    """Run detection on one image and return the annotated image and box count.

    Args:
        input_image: Image array supplied by the Gradio ``"image"`` input
            (presumably an HxWx3 RGB array -- confirm against the installed
            Gradio version), or ``None`` when submitted without an image.

    Returns:
        A ``(image, count)`` tuple: the image returned by ``end2end_detect``
        (``return_img=True``, boxes drawn) and the number of predicted
        bounding boxes.
    """
    if input_image is None:
        # Nothing uploaded: avoid a crash inside ``Image.fromarray``.
        return None, 0
    img = PIL.Image.fromarray(input_image, "RGB")
    pred_dict = model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=DETECTION_THRESHOLD,
        display_label=False,
        display_bbox=True,
        return_img=True,
        font_size=16,
        label_color="#FF59D6",
    )
    return pred_dict["img"], len(pred_dict["detection"]["bboxes"])


gr_interface = gr.Interface(
    fn=show_preds,
    inputs=["image"],
    outputs=[
        gr.outputs.Image(type="pil", label="RetinaNet Inference"),
        gr.outputs.Textbox(type="number", label="Microalgae Count"),
    ],
    title="Microalgae Detector with RetinaNet",
    description="This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use.",
    # The original literal contained a raw line break (a SyntaxError in a
    # single-quoted string); the escaped form keeps the same value.
    article="\n",
    examples=examples,
    theme="dark-grass",
    enable_queue=True,
)
gr_interface.launch()