import os
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from yolov5.utils.augmentations import letterbox
from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords
from abc import ABC, abstractmethod


class BaseModel(ABC):
    @abstractmethod
    def pre_process(self, filename: str):
        """Pre-process the input file and return it as a tensor."""
        pass

    @abstractmethod
    def predict(self, input_data):
        """Run inference on the pre-processed input and return predictions."""
        pass


class MegaDetectorModel(BaseModel):
    """
    MegaDetectorModel loads the MegaDetector checkpoint from a Hugging Face repository,
    preprocesses input images, runs inference, and returns detections (label/confidence).

    The repository ID is the only input required; the model weight filename is expected to
    match the repository's base name. For example, if the repository ID is
    "nkarthikeyan/MegaDetectorV5", then the model weight file should be "MegaDetectorV5.pt".
    """

    def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
        self.device = torch.device(device)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        # Optional class-name list, one label per line, indexed by class id.
        self.labels = None
        if labels_path and os.path.exists(labels_path):
            with open(labels_path, "r") as f:
                self.labels = [line.strip() for line in f.readlines()]
        self.model = None

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
        """
        Loads the model checkpoint from the given Hugging Face repository and returns an
        instance of MegaDetectorModel ready for inference.

        The repository's base name is used to derive the model weight filename. For example,
        if repo_id is "nkarthikeyan/MegaDetectorV5", then the weight file is expected to be
        "MegaDetectorV5.pt".

        Args:
            repo_id (str): The Hugging Face repository ID (e.g. "nkarthikeyan/MegaDetectorV5").
            device (str, optional): Device to run the model on ('cpu' or 'cuda'). Default is 'cpu'.

        Returns:
            MegaDetectorModel: An instance with the model loaded.
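
        Example (illustrative; assumes the example repository above hosts "MegaDetectorV5.pt"):
            model = MegaDetectorModel.from_pretrained("nkarthikeyan/MegaDetectorV5", device="cpu")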
"""
instance = cls(device=device, **kwargs)
# Use the repository base name as the weight filename.
model_name = repo_id.split("/")[-1]
weight_filename = f"{model_name}.pt"
model_path = hf_hub_download(repo_id=repo_id, filename=weight_filename)
checkpoint = torch.load(model_path, map_location=instance.device)
instance.model = checkpoint['model'].float().fuse().eval()
if instance.device.type != 'cpu':
instance.model.to(instance.device)
return instance

    def pre_process(self, filename: str):
        """Load an image, letterbox it to the model input size, and return (tensor, original RGB)."""
        image_bgr = cv2.imread(filename)
        if image_bgr is None:
            raise ValueError(f"Could not load image from path: {filename}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # Letterbox-resize to a 640x640 input (auto=False pads to the exact shape).
        model_stride = int(self.model.stride.max())
        processed = letterbox(image_rgb, new_shape=640, stride=model_stride, auto=False)[0]
        # HWC -> CHW, then scale pixel values to [0, 1].
        processed = processed.transpose(2, 0, 1)
        processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
        input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        return input_tensor, image_rgb

    def predict(self, input_data):
        """Run inference and return a list of (label, confidence) tuples."""
        processed_tensor, original_rgb = input_data
        with torch.no_grad():
            prediction = self.model(processed_tensor)[0]
        detections = non_max_suppression(prediction, conf_thres=self.conf_thres, iou_thres=self.iou_thres)
        results = []
        if detections and detections[0] is not None:
            det = detections[0]
            # Rescale boxes from the letterboxed input back to the original image size.
            det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4], original_rgb.shape).round()
            for *xyxy, conf, cls_idx in det.tolist():
                label_idx = int(cls_idx)
                confidence = float(conf)
                # Report the class name when labels are available, otherwise the raw index.
                if self.labels and 0 <= label_idx < len(self.labels):
                    results.append((self.labels[label_idx], confidence))
                else:
                    results.append((label_idx, confidence))
        return results
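

# Illustrative usage sketch: assumes the example repository from the docstrings above is
# reachable; "labels.txt" and "sample.jpg" are hypothetical local files.
if __name__ == "__main__":
    detector = MegaDetectorModel.from_pretrained(
        "nkarthikeyan/MegaDetectorV5",
        device="cpu",
        conf_thres=0.25,
        iou_thres=0.45,
        labels_path="labels.txt",  # hypothetical: one class name per line
    )
    input_data = detector.pre_process("sample.jpg")  # hypothetical image path
    for label, confidence in detector.predict(input_data):
        print(f"{label}: {confidence:.2f}")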