Upload mega_detector.py
mega_detector.py  ADDED  (+91 -0)
import os
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from yolov5.utils.augmentations import letterbox
from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords
from abc import ABC, abstractmethod

class BaseModel(ABC):
    @abstractmethod
    def pre_process(self, filename: str):
        """
        Pre-process the input file and return it as a tensor.
        """
        pass

    @abstractmethod
    def predict(self, input_data):
        """
        Run inference on the pre-processed input and return predictions.
        """
        pass

class MegaDetectorModel(BaseModel):
    """
    MegaDetectorModel loads the MegaDetector V5 checkpoint from Hugging Face,
    preprocesses input images, runs inference, and returns detections (label/confidence).
    """

    def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
        self.device = torch.device(device)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.labels = None
        if labels_path and os.path.exists(labels_path):
            with open(labels_path, "r") as f:
                self.labels = [line.strip() for line in f.readlines()]
        self.model = None

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
        """
        Loads the model checkpoint from the given Hugging Face repository and returns
        an instance of MegaDetectorModel ready for inference.

        Args:
            repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
            device (str, optional): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'.

        Returns:
            MegaDetectorModel: An instance with the model loaded.
        """
        instance = cls(device=device, **kwargs)
        # Download the model checkpoint (assumes the file is named 'model.pt')
        model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
        checkpoint = torch.load(model_path, map_location=instance.device)
        instance.model = checkpoint['model'].float().fuse().eval()
        if instance.device.type != 'cpu':
            instance.model.to(instance.device)
        return instance

    def pre_process(self, filename: str):
        image_bgr = cv2.imread(filename)
        if image_bgr is None:
            raise ValueError(f"Could not load image from path: {filename}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        model_stride = int(self.model.stride.max())
        processed = letterbox(image_rgb, new_shape=640, stride=model_stride, auto=False)[0]
        processed = processed.transpose(2, 0, 1)
        processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
        input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        return input_tensor, image_rgb

    def predict(self, input_data):
        processed_tensor, original_rgb = input_data
        with torch.no_grad():
            prediction = self.model(processed_tensor)[0]
        detections = non_max_suppression(prediction, conf_thres=self.conf_thres, iou_thres=self.iou_thres)
        results = []
        if detections and detections[0] is not None:
            det = detections[0]
            det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4], original_rgb.shape).round()
            for *xyxy, conf, cls_idx in det.tolist():
                label_idx = int(cls_idx)
                confidence = float(conf)
                if self.labels and 0 <= label_idx < len(self.labels):
                    results.append((self.labels[label_idx], confidence))
                else:
                    results.append((label_idx, confidence))
        return results
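A minimal usage sketch, not part of the uploaded file: it assumes the target repository hosts the checkpoint as "model.pt" (as the code expects), reuses the repo id from the docstring example, and uses hypothetical file names ("md_labels.txt", "camera_trap_frame.jpg") for illustration.

# Usage sketch. Assumptions: the repo serves the checkpoint as "model.pt";
# "md_labels.txt" (one class name per line) and "camera_trap_frame.jpg" are placeholders.
if __name__ == "__main__":
    detector = MegaDetectorModel.from_pretrained(
        "nkarthikeyan/MegaDetectorV5",                         # repo id from the docstring example
        device="cuda" if torch.cuda.is_available() else "cpu",
        conf_thres=0.25,
        labels_path="md_labels.txt",                           # optional; omit to get raw class indices
    )
    inputs = detector.pre_process("camera_trap_frame.jpg")     # returns (input tensor, original RGB image)
    for label, confidence in detector.predict(inputs):
        print(f"{label}: {confidence:.2f}")

Note that pre_process reads the model stride from self.model, so the instance must be created via from_pretrained (or otherwise have a model assigned) before preprocessing images.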