"""Inference endpoint handler for an object-detection model.

The handler loads model weights from ``model.pt``, decodes an incoming raw
image, runs the model, and returns the detections whose score reaches
``CONFIDENCE_THRESHOLD``.
"""
import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model

logger = logging.getLogger(__name__)

# Constants
# NOTE(review): label_map below names only classes 1-3; presumably the fourth
# class (index 0) is background -- confirm against get_model().
NUM_CLASSES = 4
CONFIDENCE_THRESHOLD = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Callable that turns raw image bytes into filtered detections."""

    def __init__(self, path: str = ""):
        """
        Initialize the handler: load the model.

        Args:
            path: Directory containing ``model.pt``. The checkpoint must be a
                mapping with a ``"model_state_dict"`` entry.
        """
        # Load the model
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # SECURITY: torch.load unpickles the file, so only load checkpoints
        # from a trusted source. Consider weights_only=True where the installed
        # torch version and the checkpoint contents allow it.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

        # Preprocessing function
        self.preprocess = ToTensor()

        # Class labels; ids missing from this map are reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """
        Convert raw binary image data to a tensor.

        Args:
            image_bytes: Encoded image file contents (bytes-like).

        Returns:
            The image as an RGB tensor with a leading batch dimension of 1,
            placed on ``DEVICE``.

        Raises:
            Whatever ``PIL.Image.open`` raises when the data cannot be decoded.
        """
        # Load image from binary data. convert() returns an independent image,
        # so the source image can be closed as soon as the conversion is done.
        with Image.open(io.BytesIO(image_bytes)) as source_image:
            image = source_image.convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(DEVICE)
        return image_tensor

    def __call__(self, data):
        """
        Process incoming raw binary image data.

        Args:
            data: Request mapping; the encoded image is read from
                ``data["body"]``.

        Returns:
            ``{"predictions": [...]}`` on success, where each entry has a
            ``"box"`` (``[x1, y1, x2, y2]`` truncated to ints), a ``"label"``
            string and a ``"score"`` rounded to two decimals; or
            ``{"error": message}`` when the request is invalid or any step
            fails. This method never raises.
        """
        try:
            if "body" not in data:
                return {"error": "No image data provided in request."}

            image_bytes = data["body"]
            # A present-but-empty body is the same caller mistake as a missing
            # one; report it clearly instead of surfacing a decoder error.
            if image_bytes is None or len(image_bytes) == 0:
                return {"error": "No image data provided in request."}

            image_tensor = self.preprocess_frame(image_bytes)

            # Perform inference
            with torch.no_grad():
                predictions = self.model(image_tensor)

            # Extract predictions for the single image in the batch
            detections = predictions[0]
            boxes = detections["boxes"].cpu().tolist()
            labels = detections["labels"].cpu().tolist()
            scores = detections["scores"].cpu().tolist()

            # Filter predictions by confidence threshold
            results = []
            for box, label, score in zip(boxes, labels, scores):
                if score < CONFIDENCE_THRESHOLD:
                    continue
                x1, y1, x2, y2 = map(int, box)
                results.append({
                    "box": [x1, y1, x2, y2],
                    "label": self.label_map.get(label, "unknown"),
                    "score": round(score, 2),
                })

            return {"predictions": results}
        except Exception as e:
            # Top-level request boundary: keep the error-dict contract for the
            # caller, but record the traceback so failures can be diagnosed.
            logger.exception("Inference request failed")
            return {"error": str(e)}