# Detectron2 is not published on PyPI; install it from GitHub if it is missing.
try:
    import detectron2
except ImportError:
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

import os
import glob

import numpy as np
import detectron2
import torchvision
import cv2
import torch

from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.config import get_cfg
from detectron2.modeling import build_model
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer

import gradio as gr
from PIL import Image
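# Model configuration file, trained weights, and detection thresholds.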
CONFIG_FILE = "fathomnet_config_v2_1280.yaml"
WEIGHTS_FILE = "model_final.pth"
NMS_THRESH = 0.45
SCORE_THRESH = 0.3
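# Class labels for the 20 benthic supercategories; indices match the model's
# predicted class ids.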
fathomnet_metadata = Metadata(
    name='fathomnet_val',
    thing_classes=[
        'Anemone',
        'Fish',
        'Eel',
        'Gastropod',
        'Sea star',
        'Feather star',
        'Sea cucumber',
        'Urchin',
        'Glass sponge',
        'Sea fan',
        'Soft coral',
        'Sea pen',
        'Stony coral',
        'Ray',
        'Crab',
        'Shrimp',
        'Squat lobster',
        'Flatfish',
        'Sea spider',
        'Worm']
)
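# Build the Detectron2 config: start from the COCO RetinaNet baseline, then
# overlay the FathomNet-specific config, score threshold, and weights.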
base_model_path = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
cfg.merge_from_file(model_zoo.get_config_file(base_model_path))
cfg.merge_from_file(CONFIG_FILE)
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH
cfg.MODEL.WEIGHTS = WEIGHTS_FILE
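# Instantiate the model on CPU, load the fine-tuned weights, and switch to eval mode.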
model = build_model(cfg)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
model.eval()
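# Two test-time resize augmentations: the config's native test resolution and a
# larger 1080-pixel short edge. Each image is run through the model at both scales.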
aug1 = T.ResizeShortestEdge(short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
                            max_size=cfg.INPUT.MAX_SIZE_TEST,
                            sample_style="choice")

aug2 = T.ResizeShortestEdge(short_edge_length=[1080],
                            max_size=1980,
                            sample_style="choice")

augmentations = [aug1, aug2]
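# torchvision's NMS merges the detections produced at the two scales.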
post_process_nms = torchvision.ops.nms
def run_inference(test_image):
    """Run the inference pipeline on a single image. The image is opened,
    resized by each augmentation, and run through the model, which outputs
    bounding boxes and class categories for each detected object. The merged
    detections are drawn onto the image and returned."""

    img = cv2.imread(test_image)
    im_height, im_width, _ = img.shape
    v_inf = Visualizer(img[:, :, ::-1],
                       metadata=fathomnet_metadata,
                       scale=1.0,
                       instance_mode=ColorMode.IMAGE_BW)

    insts = []

    # Run the model once per augmentation and collect all predicted instances.
    for augmentation in augmentations:
        im = augmentation.get_transform(img).apply_image(img)

        with torch.no_grad():
            im = torch.as_tensor(im.astype("float32").transpose(2, 0, 1))
            model_outputs = model([{"image": im,
                                    "height": im_height,
                                    "width": im_width}])[0]

        for idx in range(len(model_outputs['instances'])):
            insts.append(model_outputs['instances'][idx])

    # Merge the per-scale detections and suppress duplicates with NMS.
    model_inst = detectron2.structures.instances.Instances([im_height,
                                                             im_width])
    merged = model_inst.cat(insts)
    keep = post_process_nms(merged.pred_boxes.tensor,
                            merged.scores,
                            NMS_THRESH).to("cpu").tolist()
    xx = merged[keep]

    out_inf_raw = v_inf.draw_instance_predictions(xx.to("cpu"))
    out_pil = Image.fromarray(out_inf_raw.get_image()).convert('RGB')

    return out_pil
def convert_predictions(xx, thing_classes):
    """Helper function to post-process Detectron2 predictions into the
    format expected by TATOR."""

    predictions = []

    for idx in range(len(xx)):
        instance = xx[idx]

        # Convert the box from (x1, y1, x2, y2) corners to x, y, width, height.
        x, y, x2, y2 = map(float, instance.pred_boxes.tensor[0])
        w, h = x2 - x, y2 - y

        class_category = thing_classes[int(instance.pred_classes[0])]
        confidence_score = float(instance.scores[0])

        prediction = {'x': x,
                      'y': y,
                      'width': w,
                      'height': h,
                      'class_category': class_category,
                      'confidence': confidence_score}

        predictions.append(prediction)

    return predictions
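# Note: convert_predictions is not wired into the Gradio demo below; it is a
# helper for exporting detections to TATOR. A hypothetical usage sketch
# (variable names are illustrative only):
#
#     nms_instances = ...  # Instances kept after the NMS step in run_inference
#     spec_list = convert_predictions(nms_instances,
#                                     fathomnet_metadata.thing_classes)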
title = "MBARI Monterey Bay Benthic Supercategory"

description = "Gradio demo for MBARI Monterey Bay Benthic Supercategory: This " \
              "is a RetinaNet model with a ResNet-50 backbone, fine-tuned with " \
              "the Detectron2 object detection platform to identify 20 benthic " \
              "supercategories in MBARI's remotely operated vehicle image data " \
              "collected in Monterey Bay off the coast of Central California. " \
              "The data is drawn from FathomNet and consists of 32779 images " \
              "that contain a total of 80683 localizations. The model was " \
              "trained on an 85/15 train/validation split at the image level. " \
              "DOI: 10.5281/zenodo.5571043."
examples = glob.glob("images/*.png")
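# Build and launch the Gradio interface: an uploaded image is passed to
# run_inference and the annotated image is returned.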
gr.Interface(run_inference,
             inputs=gr.inputs.Image(type="filepath"),
             outputs=gr.outputs.Image(type="pil"),
             enable_queue=True,
             title=title,
             description=description,
             examples=examples).launch()