Spaces:
Runtime error
Runtime error
| import unittest | |
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| matplotlib.use('agg') | |
| import torch | |
| import cv2 | |
| from pytorch_grad_cam import EigenCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import config | |
| from model import YOLOv3 | |
| from utils import ( | |
| cells_to_bboxes, | |
| non_max_suppression, | |
| plot_image | |
| ) | |
def yolov3_reshape_transform(tensor):
    """Return only the leading element of *tensor* (index 0).

    Intended for use as the ``reshape_transform`` callback of a
    pytorch_grad_cam CAM object. Presumably *tensor* is the model's
    per-scale output sequence, so this keeps the first scale. TODO confirm.
    """
    first_output = tensor[0]
    return first_output
# 'yolov3.pth' is a plain state dict that was exported once from the
# Lightning checkpoint 'epoch=38-step=20202.ckpt': load the checkpoint, take
# its 'state_dict', load it into YOLOv3, then torch.save(model.state_dict()).
model = YOLOv3(num_classes=20)
fname = 'yolov3.pth'
# map_location='cpu' lets the weights load on a CPU-only host even if they
# were saved from GPU tensors.
model.load_state_dict(torch.load(fname, map_location=torch.device('cpu')))
# The model is only used for inference. eval() switches mode-dependent layers
# (e.g. BatchNorm/Dropout, if YOLOv3 uses them) to inference behaviour. It also
# stops BatchNorm running statistics from being updated by every request.
model.eval()

IMAGE_SIZE = config.IMAGE_SIZE
# Grid sizes for output strides 32, 16 and 8.
# NOTE(review): not used below; the anchors are scaled with config.S instead.
# Presumably both hold the same values; confirm in config.py.
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
# Scale each detection scale's anchors by that scale's grid size.
# A flat config.S becomes shape (len, 1, 1) after the two unsqueezes, then is
# tiled to (len, 3, 2) to line up with config.ANCHORS (presumably
# scales x 3 anchors x (w, h), normalised to the image; TODO confirm).
anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
def object_detector(input_image, thresh=0.8, iou_thresh=0.5):
    """Detect objects in one image with YOLOv3 and return Gradio output updates.

    Args:
        input_image: Value of the ``gr.Image`` input, or ``None`` when the
            user submits without an image. Presumably an HxWxC numpy array
            that ``config.test_transforms`` turns into a CxHxW tensor.
            Confirm in config.py.
        thresh: Passed to ``non_max_suppression`` as ``threshold``
            (presumably the minimum box confidence to keep).
        iou_thresh: Passed to ``non_max_suppression`` as ``iou_threshold``.

    Returns:
        A tuple of two ``gr.update`` values, one per output component. When
        an image is given, both show the saved detection plot. When none is
        given, both outputs are hidden.
    """
    if input_image is None:
        # Nothing was uploaded: hide the outputs instead of failing inside
        # the transform.
        return gr.update(visible=False), gr.update(visible=False)

    # Preprocess, then add a batch dimension of 1.
    input_img = config.test_transforms(image=input_image)['image']
    input_img = input_img.unsqueeze(0)

    with torch.no_grad():
        out = model(input_img)

    # Decode every detection scale with its own anchors and pool the candidate
    # boxes. [0] presumably selects the only image in the batch.
    bboxes = []
    for scale_out, scale_anchors in zip(out, anchors):
        _, _, grid_size, _, _ = scale_out.shape
        bboxes += cells_to_bboxes(
            scale_out, scale_anchors, S=grid_size, is_preds=True
        )[0]

    nms_boxes = non_max_suppression(
        bboxes, iou_threshold=iou_thresh,
        threshold=thresh, box_format="midpoint",
    )

    # Draw the kept boxes on the preprocessed image (CxHxW -> HxWxC).
    fig = plot_image(input_img.squeeze(0).permute(1, 2, 0).detach().cpu(),
                     nms_boxes,
                     return_fig=True)
    # Operate on the returned figure explicitly rather than relying on
    # pyplot's global "current" figure/axes.
    ax = fig.gca()
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
    ax.axis('off')
    # NOTE(review): a fixed file name is shared by all requests, so
    # concurrent calls can overwrite each other's plot.
    image_path = "plot.png"
    fig.savefig(image_path)
    plt.close(fig)

    # TODO: the second output is meant for an EigenCAM heat map
    # (pytorch_grad_cam.EigenCAM on model.layers[21] with
    # yolov3_reshape_transform). Until that is wired up, it mirrors the plot.
    return (gr.update(value=image_path, visible=True),
            gr.update(value=image_path, visible=True))
# Define the input and output components for Gradio.
input_image = gr.Image(label="Input image")
confidence_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                             label="confidence level")
iou_level = gr.Slider(0.5, 1, value=0.6, step=0.01,
                      label="Intersection over union level")
# NOTE(review): Component.style() exists only in Gradio 3.x (removed in
# Gradio 4). On Gradio >= 4, pass width/height to gr.Image directly instead.
output_box = gr.Image(label="Output image", visible=False).style(
    width=428, height=428)
cam_output = gr.Image(label="cam output", visible=False).style(
    width=428, height=428)

images_path = "examples/"
# File stems of the example images shipped in images_path.
example_ids = ["000015", "000017", "000030", "000069", "000071",
               "000084", "000086", "000088", "000100"]

gr_interface = gr.Interface(
    fn=object_detector,
    inputs=[input_image, confidence_level, iou_level],
    outputs=[output_box, cam_output],
    # One row per example; each row supplies only the image input.
    examples=[[images_path + image_id + ".jpg"] for image_id in example_ids],
    cache_examples=False,
)
gr_interface.launch()
| # class TestGradioInterfaceInput(unittest.TestCase): | |
| # def test_valid_image_input(self): | |
| # # Create a valid input image | |
| # input_image = images_path + "000015.jpg" | |
| # | |
| # # Pass the image through the interface | |
| # output = gr_interface(input_image) | |
| # | |
| # # Assert the output matches the expected result | |
| # self.assertEqual(output[0].shape, (3, 416, 416)) | |
| # | |
| # if __name__ == '__main__': | |
| # unittest.main() | |
| # | |
| # Create the Gradio interface | |
| # gr.Interface(fn=object_detector, inputs=input_image, | |
| # outputs=[output_box, cam_output], | |
| # examples=[[images_path + "000015.jpg"], | |
| # [images_path + "000017.jpg"], | |
| # [images_path + "000030.jpg"], | |
| # [images_path + "000069.jpg"], | |
| # [images_path + "000071.jpg"], | |
| # [images_path + "000084.jpg"], | |
| # [images_path + "000086.jpg"], | |
| # [images_path + "000088.jpg"], | |
| # [images_path + "000095.jpg"], | |
| # [images_path + "000100.jpg"], | |
| # ], | |
| # ).launch() | |
| # | |