File size: 2,483 Bytes
59c3137
f53d612
 
59c3137
 
28a32c3
59c3137
f53d612
 
 
 
639e661
 
 
 
f53d612
639e661
f53d612
 
 
 
639e661
 
 
f53d612
 
 
 
 
 
 
 
 
 
 
 
 
 
639e661
 
 
f53d612
639e661
 
f53d612
387dfb8
639e661
f53d612
 
387dfb8
f53d612
639e661
f53d612
639e661
f53d612
387dfb8
 
 
639e661
f53d612
 
 
 
 
 
 
 
 
 
 
639e661
387dfb8
639e661
387dfb8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model

# Inference configuration shared by the handler below.
NUM_CLASSES = 4  # class count handed to get_model when building the network
CONFIDENCE_THRESHOLD = 0.5  # detections scoring below this value are discarded
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler:
    """Serve object-detection predictions for an encoded image.

    The model is loaded once at construction time. Each call receives a
    request payload whose ``"body"`` entry holds the encoded image and returns
    the detections scoring at least ``CONFIDENCE_THRESHOLD`` as plain Python
    values (lists, ints, floats and strings).
    """

    def __init__(self, path: str = ""):
        """Load the model weights found under *path* and set up preprocessing.

        Args:
            path: Directory containing ``model.pt``. The checkpoint must be a
                mapping with a ``"model_state_dict"`` entry.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so model.pt must come
        # from a trusted source. Consider passing weights_only=True if the
        # installed torch version and the checkpoint contents allow it.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Switch to evaluation mode before serving requests.
        self.model.eval()

        # Converts a PIL image into a tensor.
        self.preprocess = ToTensor()

        # Label ids reported by the model mapped to their names; ids missing
        # from this table are reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode encoded image bytes into a batched tensor on ``DEVICE``.

        Args:
            image_bytes: Bytes-like object holding an encoded image in a
                format PIL can open.

        Returns:
            The RGB image passed through ``self.preprocess``, with a leading
            batch dimension of size 1 added, moved to ``DEVICE``.
        """
        # The context manager releases the source image as soon as the RGB
        # version exists; convert() loads the pixel data and returns a new
        # image, so the result stays valid after the source is closed.
        with Image.open(io.BytesIO(image_bytes)) as source:
            rgb_image = source.convert("RGB")
        return self.preprocess(rgb_image).unsqueeze(0).to(DEVICE)

    def _collect_detections(self, prediction):
        """Turn one prediction dict into a list of thresholded detections."""
        boxes = prediction["boxes"].cpu().tolist()
        labels = prediction["labels"].cpu().tolist()
        scores = prediction["scores"].cpu().tolist()

        detections = []
        for box, label, score in zip(boxes, labels, scores):
            if score < CONFIDENCE_THRESHOLD:
                continue
            # Coordinates are truncated to ints; unpacking also enforces that
            # every box has exactly four values.
            x1, y1, x2, y2 = map(int, box)
            detections.append({
                "box": [x1, y1, x2, y2],
                "label": self.label_map.get(label, "unknown"),
                "score": round(score, 2),
            })
        return detections

    def __call__(self, data):
        """Run detection on the image carried in ``data["body"]``.

        Args:
            data: Request payload; a mapping whose ``"body"`` entry holds the
                encoded image bytes.

        Returns:
            ``{"predictions": [...]}`` on success, where every item has
            ``"box"`` (``[x1, y1, x2, y2]`` as ints), ``"label"`` and
            ``"score"`` (rounded to two decimals); or ``{"error": message}``
            when the payload carries no image or processing fails.
        """
        try:
            # A payload that cannot be indexed by key is treated like one
            # without a body, so the caller always gets the same message.
            try:
                image_bytes = data["body"]
            except (KeyError, TypeError):
                image_bytes = None
            # Reject empty bodies up front instead of surfacing an opaque
            # image-decoding failure further down.
            if not image_bytes:
                return {"error": "No image data provided in request."}

            image_tensor = self.preprocess_frame(image_bytes)

            # Gradients are not needed for inference.
            with torch.no_grad():
                predictions = self.model(image_tensor)

            # A single image was submitted, so only the first result is used.
            return {"predictions": self._collect_detections(predictions[0])}
        except Exception as e:
            # Top-level request boundary: keep the error-dict contract for the
            # caller, but record the traceback so failures can be diagnosed.
            logging.getLogger(__name__).exception("Inference request failed")
            return {"error": str(e)}