import os
from abc import ABC, abstractmethod

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from yolov5.utils.augmentations import letterbox
from yolov5.utils.general import non_max_suppression, scale_boxes as scale_coords


class BaseModel(ABC):
    @abstractmethod
    def pre_process(self, filename: str):
        """Pre-process the input file and return it as a tensor."""
        pass

    @abstractmethod
    def predict(self, input_data):
        """Run inference on the pre-processed input and return predictions."""
        pass


class MegaDetectorModel(BaseModel):
    """
    MegaDetectorModel loads the MegaDetector checkpoint from a Hugging Face repository,
    preprocesses input images, runs inference, and returns detections (label/confidence).

    The repository ID is the only required input. The model weight file is expected to
    match the repository's base name: for example, if the repository ID is
    "nkarthikeyan/MegaDetectorV5", then the weight file should be "MegaDetectorV5.pt".
    """

    def __init__(self, device='cpu', conf_thres=0.25, iou_thres=0.45, labels_path=None):
        self.device = torch.device(device)
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.labels = None
        if labels_path and os.path.exists(labels_path):
            with open(labels_path, "r") as f:
                self.labels = [line.strip() for line in f]
        self.model = None

    @classmethod
    def from_pretrained(cls, repo_id: str, device: str = 'cpu', **kwargs):
        """
        Loads the model checkpoint from the given Hugging Face repository and returns
        an instance of MegaDetectorModel ready for inference.

        The repository's base name is used to derive the model weight filename. For
        example, if repo_id is "nkarthikeyan/MegaDetectorV5", then the weight file is
        expected to be "MegaDetectorV5.pt".

        Args:
            repo_id (str): The Hugging Face repository ID
                (e.g. "nkarthikeyan/MegaDetectorV5").
            device (str, optional): Device to run the model on ('cpu' or 'cuda').
                Default is 'cpu'.

        Returns:
            MegaDetectorModel: An instance with the model loaded.
        """
        instance = cls(device=device, **kwargs)
        # Use the repository base name as the weight filename.
        model_name = repo_id.split("/")[-1]
        weight_filename = f"{model_name}.pt"
        model_path = hf_hub_download(repo_id=repo_id, filename=weight_filename)

        # The checkpoint stores the full YOLOv5 model object under the 'model' key.
        checkpoint = torch.load(model_path, map_location=instance.device)
        instance.model = checkpoint['model'].float().fuse().eval()
        if instance.device.type != 'cpu':
            instance.model.to(instance.device)
        return instance

    def pre_process(self, filename: str):
        """Load an image from disk and return (input_tensor, original_rgb)."""
        image_bgr = cv2.imread(filename)
        if image_bgr is None:
            raise ValueError(f"Could not load image from path: {filename}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # Letterbox to a padded, stride-aligned 640x640 image, convert HWC -> CHW,
        # and scale pixel values to [0, 1].
        model_stride = int(self.model.stride.max())
        processed = letterbox(image_rgb, new_shape=640, stride=model_stride, auto=False)[0]
        processed = processed.transpose(2, 0, 1)
        processed = np.ascontiguousarray(processed, dtype=np.float32) / 255.0
        input_tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        return input_tensor, image_rgb

    def predict(self, input_data):
        """Run inference and return a list of (label, confidence) tuples."""
        processed_tensor, original_rgb = input_data
        with torch.no_grad():
            prediction = self.model(processed_tensor)[0]
        detections = non_max_suppression(prediction,
                                         conf_thres=self.conf_thres,
                                         iou_thres=self.iou_thres)

        results = []
        if detections and detections[0] is not None:
            det = detections[0]
            # Rescale boxes from the letterboxed size back to the original image size.
            det[:, :4] = scale_coords(processed_tensor.shape[2:], det[:, :4],
                                      original_rgb.shape).round()
            for *xyxy, conf, cls_idx in det.tolist():
                label_idx = int(cls_idx)
                confidence = float(conf)
                if self.labels and 0 <= label_idx < len(self.labels):
                    results.append((self.labels[label_idx], confidence))
                else:
                    results.append((label_idx, confidence))
        return results
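

# --- Example usage (illustrative sketch) --------------------------------------
# A minimal end-to-end run of the class above. The repo ID reuses the example
# from the class docstring, and "camera_trap.jpg" is a placeholder image path;
# substitute your own repository ID and image file.
if __name__ == "__main__":
    detector = MegaDetectorModel.from_pretrained("nkarthikeyan/MegaDetectorV5", device="cpu")
    input_data = detector.pre_process("camera_trap.jpg")
    for label, confidence in detector.predict(input_data):
        print(f"{label}: {confidence:.3f}")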