nkarthikeyan commited on
Commit
1d48575
·
verified ·
1 Parent(s): 70565aa

Upload mega_detector.py

Browse files
Files changed (1) hide show
  1. mega_detector.py +91 -0
mega_detector.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ from yolov5.utils.augmentations import letterbox
7
+ from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords
8
+ from abc import ABC, abstractmethod
9
+
10
+ class BaseModel(ABC):
11
+ @abstractmethod
12
+ def pre_process(self, filename: str):
13
+ """
14
+ Pre-process the input file and return it as a tensor.
15
+ """
16
+ pass
17
+
18
+ @abstractmethod
19
+ def predict(self, input_data):
20
+ """
21
+ Run inference on the pre-processed input and return predictions.
22
+ """
23
+ pass
24
+
25
+ class MegaDetectorModel(BaseModel):
26
+ """
27
+ MegaDetectorModel loads the MegaDetector V5 checkpoint from Hugging Face,
28
+ preprocesses input images, runs inference, and returns detections (label/confidence).
29
+ """
30
+
31
+ def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
32
+ self.device = torch.device(device)
33
+ self.conf_thres = conf_thres
34
+ self.iou_thres = iou_thres
35
+ self.labels = None
36
+ if labels_path and os.path.exists(labels_path):
37
+ with open(labels_path, "r") as f:
38
+ self.labels = [line.strip() for line in f.readlines()]
39
+ self.model = None
40
+
41
+ @classmethod
42
+ def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
43
+ """
44
+ Loads the model checkpoint from the given Hugging Face repository and returns
45
+ an instance of MegaDetectorModel ready for inference.
46
+
47
+ Args:
48
+ repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
49
+ device (str, optional): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'.
50
+
51
+ Returns:
52
+ MegaDetectorModel: An instance with the model loaded.
53
+ """
54
+ instance = cls(device=device, **kwargs)
55
+ # Download the model checkpoint (assumes the file is named 'model.pt')
56
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
57
+ checkpoint = torch.load(model_path, map_location=instance.device)
58
+ instance.model = checkpoint['model'].float().fuse().eval()
59
+ if instance.device.type != 'cpu':
60
+ instance.model.to(instance.device)
61
+ return instance
62
+
63
+ def pre_process(self, filename: str):
64
+ image_bgr = cv2.imread(filename)
65
+ if image_bgr is None:
66
+ raise ValueError(f"Could not load image from path: {filename}")
67
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
68
+ model_stride = int(self.model.stride.max())
69
+ processed = letterbox(image_rgb, new_shape=640, stride=model_stride, auto=False)[0]
70
+ processed = processed.transpose(2, 0, 1)
71
+ processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
72
+ input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
73
+ return input_tensor, image_rgb
74
+
75
+ def predict(self, input_data):
76
+ processed_tensor, original_rgb = input_data
77
+ with torch.no_grad():
78
+ prediction = self.model(processed_tensor)[0]
79
+ detections = non_max_suppression(prediction, conf_thres=self.conf_thres, iou_thres=self.iou_thres)
80
+ results = []
81
+ if detections and detections[0] is not None:
82
+ det = detections[0]
83
+ det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4], original_rgb.shape).round()
84
+ for *xyxy, conf, cls_idx in det.tolist():
85
+ label_idx = int(cls_idx)
86
+ confidence = float(conf)
87
+ if self.labels and 0 <= label_idx < len(self.labels):
88
+ results.append((self.labels[label_idx], confidence))
89
+ else:
90
+ results.append((label_idx, confidence))
91
+ return results