Spaces:
Runtime error
Runtime error
import gradio as gr | |
from typing import List | |
import cv2 | |
import torch | |
import numpy as np | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from models import YoloV3Lightning | |
from utils import load_model_from_checkpoint | |
import utils | |
import config as cfg | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
from grad_cam import YoloGradCAM | |
# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Per-channel mean / std constants (not referenced in the visible part of this file).
dataset_mean = (0.4914, 0.4822, 0.4465)
dataset_std = (0.2470, 0.2435, 0.2616)

# Build the network, restore its weights from the checkpoint and switch to eval mode.
ckpt_file = 'ckpt_light2.pth'
model = YoloV3Lightning.YOLOv3LightningModel(
    num_classes=cfg.NUM_CLASSES,
    anchors=cfg.ANCHORS,
    S=cfg.S,
)
checkpoint = load_model_from_checkpoint(device, file_name=ckpt_file)
model.load_state_dict(checkpoint['model'], strict=False)
model.eval()

# Multiply the anchors by cfg.S broadcast to the anchors' layout
# (presumably cfg.S holds the per-scale grid sizes -- confirm in config).
scaled_anchors = torch.tensor(cfg.ANCHORS) * (
    torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
scaled_anchors = scaled_anchors.to(model.device)

# Grad-CAM helper hooked on the second-to-last entry of model.layers.
cam = YoloGradCAM(
    model=model,
    target_layers=[model.layers[-2]],
    scaled_anchors=scaled_anchors,
    use_cuda=False,
)
# Gallery examples: images/000001.jpg .. images/000021.jpg, each wrapped in its
# own single-item list.
sample_images = [[f'images/{index:06d}.jpg'] for index in range(1, 22)]
with gr.Blocks() as app:
    # NOTE(review): the nesting of the layout below is reconstructed from the
    # order of the statements; confirm it against the deployed app.

    # ---- Page header ------------------------------------------------------
    with gr.Row():
        gr.Markdown(
            """
# YoloV3 App!
## Model is trained on PASCAL-VOC data to predict following classes -
""")

    # ---- Table of the supported class names + link to the model code ------
    with gr.Row():
        gr.HTML(
            """
<table>
<tr>
<th>aeroplane</th>
<th>bicycle</th>
<th>bird</th>
<th>boat</th>
<th>bottle</th>
<th>bus</th>
<th>car</th>
<th>cat</th>
</tr>
<tr>
<th>chair</th>
<th>cow</th>
<th>diningtable</th>
<th>dog</th>
<th>horse</th>
<th>motorbike</th>
<th>person</th>
<th>pottedplant</th>
</tr>
<tr>
<th>sheep</th>
<th>sofa</th>
<th>train</th>
<th>tvmonitor</th>
</tr>
</table>
<p>
<a href='https://github.com/piygr/yolov3/blob/main/models/YoloV3Lightning.py'>Click to see the model architecture / code </a>
</p>
"""
        )

    # ---- Inputs: example gallery on one side, upload + controls on the other
    with gr.Row(visible=True) as pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image ',
                                        value=[img[0] for img in sample_images], columns=6, rows=2)
        with gr.Column():
            with gr.Row():
                pred_image = gr.Image(label='Upload Image or Select from the gallery')
            with gr.Row():
                if_show_grad_cam = gr.Checkbox(value=True, label='Show Class Activation Map (What the model sees)?')
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()

    # ---- Outputs: annotated prediction and the CAM overlay -----------------
    # `output_bk` is rebound by each nested context (as in the original), so
    # after this section it refers to the last column.
    with gr.Row(visible=True) as output_bk:
        with gr.Column(visible=True) as output_bk:
            output_img = gr.Image(interactive=False, label='Prediction Output')
        with gr.Column(visible=True) as output_bk:
            grad_cam_out = gr.Image(interactive=False, visible=True, label='CAM Outcome')

    def show_cam_output(input):
        """Make the CAM image visible or hidden according to the checkbox value."""
        return {
            grad_cam_out: gr.update(visible=input)
        }

    if_show_grad_cam.change(
        show_cam_output,
        if_show_grad_cam,
        grad_cam_out
    )

    def clear_data():
        """Reset the input image and both output images."""
        return {
            pred_image: None,
            output_img: None,
            grad_cam_out: None
        }

    # Every component keyed in clear_data's returned dict has to be declared as
    # an output of the event; grad_cam_out was missing from this list.
    clear_btn.click(clear_data, None, [pred_image, output_img, grad_cam_out])

    def on_select(evt: gr.SelectData):
        """Copy the gallery image that was clicked into the input image."""
        return {
            pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, pred_image)

    def plot_image(image, boxes):
        """Draw the predicted bounding boxes on the image and save it as 'output.png'.

        Args:
            image: image data; converted with np.array and expected to have
                three dimensions (height, width, channels).
            boxes: iterable of 6-item boxes laid out as
                [class pred, confidence, x, y, width, height]. x / y are the
                box midpoint; the coordinates are multiplied by the image
                width / height before drawing.

        Raises:
            ValueError: if a box does not have exactly 6 entries.

        NOTE(review): the output path is fixed, so concurrent requests
        overwrite each other's file.
        """
        class_labels = cfg.PASCAL_CLASSES
        cmap = plt.get_cmap("tab20b")
        colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
        im = np.array(image)
        height, width, _ = im.shape

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(im)
            for box in boxes:
                if len(box) != 6:
                    raise ValueError("box should contain class pred, confidence, x, y, width, height")
                class_idx = int(box[0])
                x_mid, y_mid, box_w, box_h = box[2:]
                # Convert the midpoint format to the upper-left corner in pixels.
                left = (x_mid - box_w / 2) * width
                top = (y_mid - box_h / 2) * height
                ax.add_patch(
                    patches.Rectangle(
                        (left, top),
                        box_w * width,
                        box_h * height,
                        linewidth=2,
                        edgecolor=colors[class_idx],
                        facecolor="none",
                    )
                )
                ax.text(
                    left,
                    top,
                    s=class_labels[class_idx],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": colors[class_idx], "pad": 0},
                )
            fig.savefig('output.png')
        finally:
            # Release the figure; otherwise one figure per request stays open
            # in pyplot's registry.
            plt.close(fig)

    def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.6, show_cam: bool = False,
                transparency: float = 0.5) -> list:
        """Run detection on one image and optionally compute a CAM overlay.

        Args:
            image: input image passed to cfg.grad_cam_transforms.
            iou_thresh: IoU threshold forwarded to utils.non_max_suppression.
            thresh: confidence threshold forwarded to utils.non_max_suppression.
            show_cam: when False the CAM step is skipped.
            transparency: forwarded to show_cam_on_image as image_weight.

        Returns:
            A two-item list: the path of the saved plot ('output.png') and the
            CAM overlay image, or None in its place when show_cam is False.
        """
        transformed_image = cfg.grad_cam_transforms(image=image)["image"].unsqueeze(0)

        # Gradients are only disabled for the detection pass; the CAM step
        # below is kept outside this context because Grad-CAM relies on
        # gradients (assumed from the library's method -- confirm for
        # YoloGradCAM).
        with torch.no_grad():
            output = model(transformed_image)
            bboxes = []
            for i in range(3):  # the three prediction scales
                _, _, S, _, _ = output[i].shape
                boxes_scale_i = utils.cells_to_bboxes(
                    output[i], scaled_anchors[i], S=S, is_preds=True
                )
                # The batch holds a single image, so only entry 0 is collected.
                bboxes += boxes_scale_i[0]

        nms_boxes = utils.non_max_suppression(
            bboxes, iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        plot_image(image, nms_boxes)
        plotted_img = 'output.png'
        if not show_cam:
            return [plotted_img, None]

        grayscale_cam = cam(transformed_image)
        # Channel-first tensor -> channel-last array for the overlay helper.
        img = np.array(transformed_image[0], np.float16).transpose(1, 2, 0)
        cam_image = show_cam_on_image(img, grayscale_cam.transpose(1, 2, 0), use_rgb=True, image_weight=transparency)
        return [plotted_img, cam_image]

    def img_upload(input_img, if_cam):
        """Handle the Submit button: run predict and route its results to the outputs.

        When no image is present both outputs are set to None explicitly, so
        the handler always returns a value for each declared output instead of
        falling off the end.
        """
        if input_img is None:
            return {
                output_img: None,
                grad_cam_out: None
            }
        plotted, cam_overlay = predict(input_img, show_cam=if_cam)
        return {
            output_img: plotted,
            grad_cam_out: cam_overlay
        }

    submit_btn.click(
        img_upload,
        [pred_image, if_show_grad_cam],
        [output_img, grad_cam_out]
    )
# Launch the app
app.launch()