File size: 1,907 Bytes
a1d71d0 5a88aec a1d71d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import glob
import numpy as np
import torch
import yolov5
from typing import Union, List, Optional
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Relative path to the YOLOv5 weights used to build the module-level model.
model_path: str = "models/deepsea-detector.pt"
# -----------------------------------------------------------------------------
# YOLOv5 class
# -----------------------------------------------------------------------------
class YOLO:
    """Wrapper class for loading and running YOLO model.

    Loads the weights once via ``yolov5.load`` and exposes the model as a
    single callable that applies per-call inference settings.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load the YOLOv5 weights.

        Args:
            model_path: Path to the weights file, passed to ``yolov5.load``.
            device: Device string forwarded unchanged to ``yolov5.load``
                (this module passes ``"cpu"``).
        """
        self.model = yolov5.load(model_path, device=device)

    def __call__(
        self,
        img: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None) -> torch.Tensor:
        """Run detection on one image.

        The inference settings live as attributes on the wrapped model, so
        they persist between calls. Every setting is therefore written on
        every call, which keeps one call's options from leaking into the next.

        Args:
            img: Image path or array, passed straight to the wrapped model.
            conf_threshold: Stored as ``self.model.conf`` before inference.
            iou_threshold: Stored as ``self.model.iou`` before inference.
            image_size: Forwarded to the model as ``size=``.
            classes: Class indices to keep, stored as ``self.model.classes``.
                ``None`` clears any filter set by an earlier call (yolov5's
                default, meaning all classes).

        Returns:
            Whatever the wrapped model returns. NOTE(review): this is annotated
            as ``torch.Tensor``, but a model from ``yolov5.load`` is presumably
            an AutoShape wrapper that returns a ``Detections`` object; confirm.
        """
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        # Assign unconditionally. Skipping the write when classes is None
        # would leave a previous call's class filter active.
        self.model.classes = classes
        # pylint: disable=not-callable
        detections = self.model(img, size=image_size)
        return detections
def run_inference(image_path):
    """Run the module-level ``model`` on *image_path* and return its output.

    Uses the detector's default thresholds, image size and class filter.
    """
    return model(image_path)
# -----------------------------------------------------------------------------
# Model instantiation
# -----------------------------------------------------------------------------
# Built once at import time and shared by run_inference(); pinned to the CPU.
model = YOLO(model_path=model_path, device="cpu")
if __name__ == "__main__":
    # Demo: run the detector over every PNG under images/ and print one
    # result per image. glob.glob returns matches in file-system order,
    # so sort them to make the processing order deterministic.
    test_images = sorted(glob.glob("images/*.png"))
    if not test_images:
        print("No test images found in images/.")
    for test_image in test_images:
        predictions = run_inference(test_image)
        # Report each result so the demo produces visible output.
        print(f"{test_image}: {predictions}")
    print("Done.")
|