|
import glob |
|
import numpy as np |
|
import torch |
|
import yolov5 |
|
from typing import Union, List, Optional |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Relative path to the YOLOv5 weights file for the deep-sea detector.
model_path: str = "models/deepsea-detector.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
class YOLO:
    """Wrapper class for loading and running a YOLOv5 model.

    Holds the model returned by ``yolov5.load`` and exposes a single call
    that applies the detection settings (confidence / IoU thresholds and an
    optional class filter) before running inference on one image.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load the model weights.

        Args:
            model_path: Path to the weights file, passed to ``yolov5.load``.
            device: Device string forwarded to ``yolov5.load`` (e.g.
                ``"cpu"``); ``None`` defers to the library's default choice.
        """
        self.model = yolov5.load(model_path, device=device)

    def __call__(
        self,
        img: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run detection on a single image.

        Args:
            img: Image path or array, passed straight to the model.
            conf_threshold: Value assigned to ``self.model.conf``.
            iou_threshold: Value assigned to ``self.model.iou``.
            image_size: Inference size, forwarded to the model as ``size=``.
            classes: Class ids assigned to ``self.model.classes``. ``None``
                clears any filter set by a previous call.

        Returns:
            Whatever the underlying model returns for ``img``.
            NOTE(review): yolov5 AutoShape models usually return a
            ``Detections`` object rather than a ``torch.Tensor``; confirm
            and correct the return annotation.
        """
        # These settings are attributes on the shared model object, so they
        # are (re)assigned on every call rather than only when supplied.
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        # Always assign, including None: previously a filter from an earlier
        # call silently persisted into later calls that omitted ``classes``.
        self.model.classes = classes

        return self.model(img, size=image_size)
|
|
|
|
|
def run_inference(image_path, **kwargs):
    """Run the module-level detector on one image.

    Args:
        image_path: Image passed straight to the detector (the ``YOLO``
            wrapper accepts a path string or a numpy array).
        **kwargs: Optional detection settings forwarded unchanged to
            ``YOLO.__call__`` (``conf_threshold``, ``iou_threshold``,
            ``image_size``, ``classes``). Omitted settings keep the
            wrapper's defaults, so existing one-argument calls are unchanged.

    Returns:
        The detector's raw output for the image.
    """
    return model(image_path, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
# Shared detector instance, loaded once at import time on the CPU;
# run_inference() calls this object.
model = YOLO(model_path=model_path, device="cpu")
|
|
|
if __name__ == "__main__":
    # Smoke-run the detector over every PNG in ./images. glob's result order
    # depends on the filesystem, so sort it for a reproducible run order.
    test_images = sorted(glob.glob("images/*.png"))

    if not test_images:
        print("No images matched 'images/*.png'; nothing to do.")

    for test_image in test_images:
        # Results are discarded; this loop only checks that inference runs.
        run_inference(test_image)

    print("Done.")
|
|