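# Gradio demo: speed-limit sign detection with a Detectron2 Faster R-CNN model.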
# Install Detectron2 from GitHub if it is not already available.
try:
    import detectron2
except ImportError:
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
import gradio as gr
import torch
import cv2
import pandas as pd
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine import DefaultPredictor
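
# Model setup: build the Detectron2 config and wrap it in a DefaultPredictor.
# Note: "resnet.yaml" is assumed to be the Faster R-CNN config saved from training,
# with MODEL.WEIGHTS pointing at the trained checkpoint shipped alongside this script.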
cfg = get_cfg()
cfg.merge_from_file("resnet.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.DEVICE = 'cpu'
predictor = DefaultPredictor(cfg)
def inference(image_path):
    # Load the image in BGR format, which DefaultPredictor expects by default
    image = read_image(image_path, format="BGR")
    # Run the detector on the image
    outputs = predictor(image)

    # Register the class names so the visualizer can label the boxes
    thing_classes = ['100km/h', '120km/h', '20km/h', '30km/h', '40km/h',
                     '50km/h', '60km/h', '70km/h', '80km/h']
    metadata = MetadataCatalog.get("custom_dataset_train")
    metadata.set(thing_classes=thing_classes)

    # Draw the predicted boxes and labels on the image
    v = Visualizer(image, metadata, scale=1)
    out = v.draw_instance_predictions(outputs['instances'])

    # Detection summary table: one row per detection with its label and score
    cls_idxs = outputs['instances'].pred_classes.numpy()
    labels = [thing_classes[i] for i in cls_idxs]
    scores = outputs['instances'].scores.numpy()
    df = pd.DataFrame({'Detected speed limit': labels, 'Confidence score': scores})

    # Return the visualization as an RGB image together with the summary table
    return out.get_image()[:, :, ::-1], df
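
# Example images bundled with the demo; with cache_examples=True, Gradio precomputes
# their outputs so selecting an example loads results without clicking "Infer".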
examples = ["examples/1.jpg", "examples/2.jpg", "examples/3.jpg"]
with gr.Blocks(theme='gradio/monochrome') as demo:
    gr.Markdown("# Speed Limit Detection demo")
    gr.Markdown("**Author**: *Lu CHEN*")
    gr.Markdown(
        """This interactive demo is based on a Faster R-CNN model for object detection. The model was
trained with the [Detectron2](https://github.com/facebookresearch/detectron2) library on a custom
dataset that I created by combining images from [Tsinghua-Tencent100K](https://cg.cs.tsinghua.edu.cn/traffic-sign/) and [GTSDB](https://benchmark.ini.rub.de/), both of which provide real-world traffic signs captured in the autonomous-driving domain.

To use the demo, simply upload an image and click *"Infer"* to view the following results:

- **Detection**: outputs of the object detector
- **Detection summary**: a summary of the detection outputs

You can also select one of the cached **Examples** to try the demo quickly; its outputs are loaded automatically without clicking *"Infer"*.

If the output image appears too small, right-click on it and choose “Open image in new tab” to view it at full size.
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="filepath")
            button = gr.Button("Infer")
        with gr.Column():
            detection = gr.Image(label="Output")
            detection_summary = gr.DataFrame(label="Detection summary")
    examples_block = gr.Examples(
        inputs=image,
        examples=examples,
        fn=inference,
        outputs=[detection, detection_summary],
        cache_examples=True,
    )
    button.click(fn=inference, inputs=image, outputs=[detection, detection_summary])
demo.launch()