File size: 2,483 Bytes
59c3137 f53d612 59c3137 28a32c3 59c3137 f53d612 639e661 f53d612 639e661 f53d612 639e661 f53d612 639e661 f53d612 639e661 f53d612 387dfb8 639e661 f53d612 387dfb8 f53d612 639e661 f53d612 639e661 f53d612 387dfb8 639e661 f53d612 639e661 387dfb8 639e661 387dfb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import torch
from model import get_model
from torchvision.transforms import ToTensor
from PIL import Image
import io
import os
# Inference configuration shared by the handler below.
# Number of classes handed to get_model() when the network is built.
NUM_CLASSES = 4
# Detections scoring below this value are dropped from the response.
CONFIDENCE_THRESHOLD = 0.5
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
class EndpointHandler:
    """Serve object-detection predictions for images sent as raw bytes.

    The model is loaded once when the handler is constructed. Calling the
    instance runs inference on a single image and returns every detection
    whose score reaches ``CONFIDENCE_THRESHOLD``.
    """

    def __init__(self, path: str = ""):
        """Load the detection model from ``<path>/model.pt``.

        Args:
            path: Directory that contains ``model.pt``. The checkpoint must
                be a mapping with a ``"model_state_dict"`` entry.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints may be loaded here. If the checkpoint holds nothing but
        # tensors, consider passing weights_only=True - confirm before
        # changing, since other pickled objects would then fail to load.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Switch to evaluation mode for inference.
        self.model.eval()

        # Converts a PIL image into a tensor.
        self.preprocess = ToTensor()

        # Human-readable names for predicted label ids; ids without an entry
        # are reported as "unknown".
        # NOTE(review): id 0 is presumably the background class - confirm.
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode encoded image bytes into a model-ready tensor.

        Args:
            image_bytes: Encoded image data (any format PIL can open).

        Returns:
            The RGB image as a tensor with a leading batch dimension of
            size 1, moved to ``DEVICE``.

        Raises:
            Whatever PIL raises when the data cannot be decoded as an image.
        """
        # convert() loads the pixel data and returns an independent copy, so
        # the handle opened on the buffer can be closed right away.
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")
        return self.preprocess(image).unsqueeze(0).to(DEVICE)

    def __call__(self, data) -> dict:
        """Run detection on the image carried in ``data["body"]``.

        Args:
            data: Request payload; the encoded image bytes are read from its
                ``"body"`` entry.

        Returns:
            ``{"predictions": [...]}`` on success, where each item has a
            ``"box"`` (``[x1, y1, x2, y2]`` truncated to ints), a ``"label"``
            name and a ``"score"`` rounded to two decimals.
            ``{"error": message}`` when the image is missing or empty, or
            when anything fails while handling the request.
        """
        # This is the request boundary: any failure is reported to the
        # caller as an error payload instead of propagating.
        try:
            if "body" not in data:
                return {"error": "No image data provided in request."}

            image_bytes = data["body"]
            # An empty or null body is a missing image, not a decoding
            # problem; report it as such instead of surfacing a PIL error.
            if not image_bytes:
                return {"error": "No image data provided in request."}

            image_tensor = self.preprocess_frame(image_bytes)

            # Gradients are not needed for inference.
            with torch.no_grad():
                predictions = self.model(image_tensor)

            # A single image was sent, so only the first result is relevant.
            detection = predictions[0]
            boxes = detection["boxes"].cpu().tolist()
            labels = detection["labels"].cpu().tolist()
            scores = detection["scores"].cpu().tolist()

            # Keep only the detections that reach the confidence threshold.
            results = []
            for box, label, score in zip(boxes, labels, scores):
                if score < CONFIDENCE_THRESHOLD:
                    continue
                x1, y1, x2, y2 = map(int, box)
                results.append({
                    "box": [x1, y1, x2, y2],
                    "label": self.label_map.get(label, "unknown"),
                    "score": round(score, 2),
                })
            return {"predictions": results}
        except Exception as e:
            return {"error": str(e)}