File size: 4,349 Bytes
1d48575
 
 
 
 
 
 
 
 
 
 
 
83940f8
1d48575
 
 
 
83940f8
1d48575
 
 
 
83940f8
1d48575
83940f8
 
 
 
1d48575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83940f8
 
 
 
 
1d48575
 
 
 
 
 
 
 
 
83940f8
 
 
 
1d48575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from yolov5.utils.augmentations import letterbox
from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords
from abc import ABC, abstractmethod

class BaseModel(ABC):
    """Abstract interface for models exposing a pre-process / predict pipeline.

    Concrete subclasses must override both methods; the class itself cannot be
    instantiated.
    """

    @abstractmethod
    def pre_process(self, filename: str):
        """Turn the input file into the representation consumed by ``predict``.

        Args:
            filename: Path of the file to pre-process.

        Returns:
            The pre-processed input (a tensor, per the interface contract).
        """

    @abstractmethod
    def predict(self, input_data):
        """Run inference on already pre-processed input.

        Args:
            input_data: The value produced by ``pre_process``.

        Returns:
            The model's predictions.
        """

class MegaDetectorModel(BaseModel):
    """
    MegaDetectorModel loads the MegaDetector checkpoint from a Hugging Face repository,
    preprocesses input images, runs inference, and returns detections (label/confidence).

    The repository ID is the only input required. The model filename, class name, and weight file
    are all expected to match the repository's base name. For example, if the repository ID is
    "nkarthikeyan/MegaDetectorV5", then the model weight file should be "MegaDetectorV5.pt".

    Instances created directly through ``__init__`` have no model attached; use
    ``from_pretrained`` to obtain an instance that is ready for inference.
    """

    def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
        """
        Store the inference settings and optionally load class labels.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. 'cpu', 'cuda').
            conf_thres: Confidence threshold forwarded to non-max suppression.
            iou_thres: IoU threshold forwarded to non-max suppression.
            labels_path: Optional path to a text file with one label per line; the
                line number (0-based) is the class index. A missing or non-existent
                path is ignored and class indices are reported instead of names.
        """
        self.device = torch.device(device)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.labels = None
        if labels_path and os.path.exists(labels_path):
            # Read as UTF-8 explicitly so non-ASCII labels do not depend on the
            # platform's default locale encoding.
            with open(labels_path, "r", encoding="utf-8") as f:
                self.labels = [line.strip() for line in f]
        self.model = None

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
        """
        Loads the model checkpoint from the given Hugging Face repository and returns an instance
        of MegaDetectorModel ready for inference.

        The repository's base name is used to derive the model weight filename. For example, if
        repo_id is "nkarthikeyan/MegaDetectorV5", then the weight file is expected to be "MegaDetectorV5.pt".

        Args:
            repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
            device (str, optional): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'.
            **kwargs: Forwarded to the constructor (conf_thres, iou_thres, labels_path).

        Returns:
            MegaDetectorModel: An instance with the model loaded.
        """
        instance = cls(device=device, **kwargs)
        # Use the repository base name as the weight filename.
        model_name = repo_id.split("/")[-1]
        weight_filename = f"{model_name}.pt"
        model_path = hf_hub_download(repo_id=repo_id, filename=weight_filename)
        # SECURITY NOTE: the checkpoint stores a fully pickled model object under
        # the 'model' key, so it must be unpickled with weights_only=False, which
        # can execute arbitrary code. Only load repositories you trust.
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which rejects
            # full-model checkpoints; request the legacy behaviour explicitly.
            checkpoint = torch.load(
                model_path, map_location=instance.device, weights_only=False
            )
        except TypeError:
            # Older PyTorch releases do not know the weights_only keyword.
            checkpoint = torch.load(model_path, map_location=instance.device)
        instance.model = checkpoint['model'].float().fuse().eval()
        if instance.device.type != 'cpu':
            instance.model.to(instance.device)
        return instance

    def _require_model(self):
        """Return the loaded model, or raise a clear error if none is attached."""
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded; create the instance with "
                "MegaDetectorModel.from_pretrained()."
            )
        return self.model

    def pre_process(self, filename: str):
        """
        Read an image from disk and convert it into a model input tensor.

        Args:
            filename: Path of the image to read with OpenCV.

        Returns:
            tuple: ``(input_tensor, image_rgb)`` where ``input_tensor`` is the
            letterboxed image as a float32 tensor with a leading batch dimension,
            divided by 255 and moved to ``self.device``, and ``image_rgb`` is the
            original image converted from BGR to RGB.

        Raises:
            ValueError: If OpenCV cannot read the image.
            RuntimeError: If no model has been loaded.
        """
        image_bgr = cv2.imread(filename)
        if image_bgr is None:
            raise ValueError(f"Could not load image from path: {filename}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # The letterbox stride is taken from the model's largest stride.
        model_stride = int(self._require_model().stride.max())
        processed = letterbox(image_rgb, new_shape=640, stride=model_stride, auto=False)[0]
        # Move the channel axis first (H, W, C -> C, H, W).
        processed = processed.transpose(2, 0, 1)
        processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
        input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        return input_tensor, image_rgb

    def predict(self, input_data):
        """
        Run the detector on the output of ``pre_process``.

        Args:
            input_data: The ``(input_tensor, image_rgb)`` tuple returned by
                ``pre_process``.

        Returns:
            list: One ``(label, confidence)`` tuple per detection kept by non-max
            suppression. ``label`` is the name from the labels file when one was
            loaded and the class index is in range; otherwise it is the integer
            class index. Box coordinates are rescaled internally but not returned.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        processed_tensor, original_rgb = input_data
        model = self._require_model()
        with torch.no_grad():
            prediction = model(processed_tensor)[0]
        detections = non_max_suppression(prediction, conf_thres=self.conf_thres, iou_thres=self.iou_thres)
        results = []
        if detections and detections[0] is not None:
            # A single image was submitted, so only the first entry is relevant.
            det = detections[0]
            # Map boxes from the letterboxed input back to the original image size.
            det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4], original_rgb.shape).round()
            for *xyxy, conf, cls_idx in det.tolist():
                label_idx = int(cls_idx)
                confidence = float(conf)
                if self.labels and 0 <= label_idx < len(self.labels):
                    results.append((self.labels[label_idx], confidence))
                else:
                    results.append((label_idx, confidence))
        return results