yolov3 / app.py
piyushgrover's picture
Update app.py
26ba27b verified
import gradio as gr
from typing import List
import cv2
import torch
import numpy as np
from pytorch_grad_cam.utils.image import show_cam_on_image
from models import YoloV3Lightning
from utils import load_model_from_checkpoint
import utils
import config as cfg
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from grad_cam import YoloGradCAM
# All inference in this app runs on CPU.
device = torch.device('cpu')

# Per-channel normalisation statistics; not referenced by the code visible in this file.
dataset_mean = (0.4914, 0.4822, 0.4465)
dataset_std = (0.2470, 0.2435, 0.2616)

# Build the network from the project configuration and load the trained weights.
model = YoloV3Lightning.YOLOv3LightningModel(num_classes=cfg.NUM_CLASSES, anchors=cfg.ANCHORS, S=cfg.S)
ckpt_file = 'ckpt_light2.pth'
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
# strict=False lets loading succeed even if checkpoint and model keys do not match exactly.
model.load_state_dict(checkpoint['model'], strict=False)

# Anchors multiplied by the grid size of each prediction scale.
# NOTE(review): the left operand was missing from this copy of the file; it is
# restored as torch.tensor(cfg.ANCHORS) because the model above is built from the
# same cfg.ANCHORS / cfg.S pair - confirm against the upstream source.
scaled_anchors = (
    torch.tensor(cfg.ANCHORS)
    * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)

# Class-activation-map helper attached to the second-to-last layer of the model.
cam = YoloGradCAM(model=model, target_layers=[model.layers[-2]], scaled_anchors=scaled_anchors, use_cuda=False)
# Gallery samples followed by the Gradio layout of the app.
# NOTE(review): this copy has lost its indentation and several lines; the notes
# below mark the gaps that are visible. Restore them from the upstream file.
# NOTE(review): the list literal below is opened but no entries or closing bracket
# are visible. Later code reads element [0] of each entry as the image to show.
sample_images = [
with gr.Blocks() as app:
with gr.Row():
# NOTE(review): the next two lines read like Markdown heading text; presumably they
# were wrapped in a Markdown component call that is not visible here - confirm.
# YoloV3 App!
## Model is trained on PASCAL-VOC data to predict following classes -
with gr.Row():
# NOTE(review): raw HTML link with no enclosing call visible; presumably also
# Markdown/HTML component content - confirm.
<a href='https://github.com/piygr/yolov3/blob/main/models/YoloV3Lightning.py'>Click to see the model architecture / code </a>
with gr.Row(visible=True) as pred_cls_col:
with gr.Column():
# Gallery of sample images; selecting one is handled by on_select further down.
example_images = gr.Gallery(allow_preview=False, label='Select image ',
value=[img[0] for img in sample_images], columns=6, rows=2)
with gr.Column():
with gr.Row():
# Input image: uploaded by the user or filled in from the gallery.
pred_image = gr.Image(label='Upload Image or Select from the gallery')
with gr.Row():
# Checkbox deciding whether the CAM overlay is requested; ticked by default.
if_show_grad_cam = gr.Checkbox(value=True, label='Show Class Activation Map (What the model sees)?')
with gr.Row():
submit_btn = gr.Button("Submit", variant='primary')
clear_btn = gr.ClearButton()
# NOTE(review): the name output_bk is bound three times below (one Row, two
# Columns); only the last binding survives.
with gr.Row(visible=True) as output_bk:
with gr.Column(visible=True) as output_bk:
# Read-only image showing the detection result.
output_img = gr.Image(interactive=False, label='Prediction Output')
with gr.Column(visible=True) as output_bk:
# Read-only image showing the class-activation-map overlay.
grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')
def show_cam_output(input):
    """Show or hide the CAM output image.

    Args:
        input: Value forwarded to ``gr.update(visible=...)``; truthy shows the
            ``grad_cam_out`` component, falsy hides it.

    Returns:
        A component-to-update mapping containing only ``grad_cam_out``.
    """
    return {
        grad_cam_out: gr.update(visible=input)
    }
def clear_data():
    """Reset the input image and both result images.

    Returns:
        A component-to-value mapping that sets ``pred_image``, ``output_img``
        and ``grad_cam_out`` to ``None``.
    """
    return {
        pred_image: None,
        output_img: None,
        grad_cam_out: None
    }


# Every component used as a key in the returned dict must also be listed as an
# output of the event, otherwise Gradio rejects the return value; grad_cam_out
# was missing from this list.
clear_btn.click(clear_data, None, [pred_image, output_img, grad_cam_out])
def on_select(evt: gr.SelectData):
    """Copy the gallery sample that was clicked into the input image.

    Args:
        evt: Gradio selection event; ``evt.index`` is used as the position of
            the chosen entry in ``sample_images``.

    Returns:
        A component-to-value mapping that sets ``pred_image`` to element ``[0]``
        of the selected ``sample_images`` entry.
    """
    return {
        pred_image: sample_images[evt.index][0]
    }


example_images.select(on_select, None, pred_image)
# Draws one rectangle per predicted box on top of the image using matplotlib.
# NOTE(review): this copy of the function has lost its indentation and several
# lines (see the notes below); restore them from the upstream file before running.
def plot_image(image, boxes):
"""Plots predicted bounding boxes on the image"""
cmap = plt.get_cmap("tab20b")
class_labels = cfg.PASCAL_CLASSES
# One colour per class label, sampled evenly across the colormap.
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
im = np.array(image)
# Unpacking three values means the image array must be 3-dimensional.
height, width, _ = im.shape
# Create figure and axes
fig, ax = plt.subplots(1)
# Display the image
# NOTE(review): no call that actually draws `im` on `ax` is visible after the
# comment above - presumably a line is missing here.
# box[0] is x midpoint, box[2] is width
# box[1] is y midpoint, box[3] is height
# Create a Rectangle patch
for box in boxes:
assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
class_pred = box[0]
# Drop class and confidence so the indices match the x/y/width/height comments above.
box = box[2:]
upper_left_x = box[0] - box[2] / 2
upper_left_y = box[1] - box[3] / 2
# Coordinates are multiplied by the image size below, so they are presumably
# normalised fractions of the image - TODO confirm.
rect = patches.Rectangle(
(upper_left_x * width, upper_left_y * height),
box[2] * width,
box[3] * height,
# NOTE(review): the Rectangle(...) call above is never closed and no call that
# adds `rect` to `ax` is visible. The three arguments below (x, y, bbox=...) look
# like they belong to a separate label-drawing call whose opening line and
# text argument are missing - confirm against upstream.
# Add the patch to the Axes
upper_left_x * width,
upper_left_y * height,
bbox={"color": colors[int(class_pred)], "pad": 0},
# NOTE(review): `x` is not used afterwards in the visible code. predict() reads
# 'output.png' right after calling this function, but no statement that saves
# the figure is visible here.
x = plt.show()
def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.6, show_cam: bool = False,
            transparency: float = 0.5) -> List[np.ndarray]:
    """Run detection on one image and optionally build a CAM overlay.

    Args:
        image: Input image, passed to ``cfg.grad_cam_transforms`` as ``image=``.
        iou_thresh: IoU threshold forwarded to ``utils.non_max_suppression``.
        thresh: Confidence threshold forwarded to ``utils.non_max_suppression``.
        show_cam: When false, the CAM step is skipped.
        transparency: Forwarded to ``show_cam_on_image`` as ``image_weight``.

    Returns:
        A two-element list. The first element is the string ``'output.png'``
        (the file ``plot_image`` is expected to have produced - confirm that it
        saves the figure). The second is the CAM overlay returned by
        ``show_cam_on_image``, or ``None`` when ``show_cam`` is false.
    """
    # The forward pass and box decoding need no gradients.
    with torch.no_grad():
        transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)
        output = model(transformed_image)

        # A single image is processed, so there is exactly one per-image box list.
        bboxes = [[] for _ in range(1)]
        for i in range(3):
            # Third dimension of each prediction tensor is used as the grid size.
            S = output[i].shape[2]
            anchor = scaled_anchors[i]
            boxes_scale_i = utils.cells_to_bboxes(
                output[i], anchor, S=S, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

    nms_boxes = utils.non_max_suppression(
        bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
    )
    plot_image(image, nms_boxes)
    plotted_img = 'output.png'

    if not show_cam:
        return [plotted_img, None]

    # The CAM call is kept outside the no_grad block - presumably it needs
    # gradients; confirm against YoloGradCAM.
    grayscale_cam = cam(transformed_image)
    # Move channels last for the overlay helper.
    img = np.array(transformed_image[0], np.float16).transpose(1, 2, 0)
    cam_image = show_cam_on_image(img, grayscale_cam.transpose(1, 2, 0), use_rgb=True, image_weight=transparency)
    return [plotted_img, cam_image]
def img_upload(input_img, if_cam):
    """Run prediction for the current input image and fill both result images.

    Args:
        input_img: Image from the ``pred_image`` component, or ``None`` when no
            image has been provided.
        if_cam: Whether the CAM overlay should be computed; forwarded to
            ``predict`` as ``show_cam``.

    Returns:
        A component-to-value mapping for ``output_img`` and ``grad_cam_out``.
        Both are set to ``None`` when there is no input image, so the handler
        always supplies a value for every output instead of returning ``None``.
    """
    if input_img is None:
        return {
            output_img: None,
            grad_cam_out: None
        }

    imgs = predict(input_img, show_cam=if_cam)
    return {
        output_img: imgs[0],
        grad_cam_out: imgs[1]
    }
# NOTE(review): these two lists look like the inputs and outputs arguments of an
# event registration for img_upload (they match its two parameters and the two
# components it returns). The opening call line and closing parenthesis are not
# visible in this copy - restore them from the upstream file.
[pred_image, if_show_grad_cam],
[output_img, grad_cam_out]
# Launch the app