File size: 4,003 Bytes
782cec7
3fa54be
cd4c90e
bc3d4e9
 
 
 
cd4c90e
 
 
 
 
bc3d4e9
cd4c90e
3fa54be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f890c24
3fa54be
 
cd4c90e
 
 
 
 
3fa54be
 
cd4c90e
 
3fa54be
cd4c90e
b80c100
cd4c90e
 
 
 
 
 
 
 
3fa54be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd4c90e
 
 
3fa54be
cd4c90e
 
3fa54be
cd4c90e
 
3fa54be
 
 
cd4c90e
 
9fbf078
 
3fa54be
 
 
 
 
 
 
 
 
 
 
 
 
9fbf078
 
 
 
 
3fa54be
 
9fbf078
e5bb367
 
3fa54be
e5bb367
3fa54be
e5bb367
 
3fa54be
 
bc3d4e9
9fbf078
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from io import BytesIO
from typing import Dict, Tuple, Union
from icevision import *
from icevision.models.checkpoint import model_from_checkpoint
from classifier import transform_image
from icevision.models import ross

import PIL
import torch
import numpy as np
import torchvision

MODEL_TYPE = ross.efficientdet

def predict(det_model : torch.nn.Module, image : Union[str, BytesIO],
            detection_threshold : float) -> Dict:
    """
    Make a prediction with the detection model.

    Args:
        det_model (torch.nn.Module): Detection model.
        image (Union[str, BytesIO]): Image filepath if the image is one of
            the example images and BytesIO if the image is a custom image
            uploaded by the user.
        detection_threshold (float): Minimum score for a detection to be
            kept in the output.

    Returns:
        Dict: Prediction dictionary produced by ``end2end_detect``; since
        ``return_img=True`` it also carries the input image under 'img'.
    """
    # PIL.Image.open is lazy and keeps the underlying file handle open;
    # the context manager closes it once the prediction is done.
    with PIL.Image.open(image) as img:
        # Single-class detector ("Waste"); class map must match training.
        class_map = ClassMap(classes=['Waste'])

        # Same preprocessing as training: letterbox-resize to 512 + normalize.
        transforms = tfms.A.Adapter([
                        *tfms.A.resize_and_pad(512),
                        tfms.A.Normalize()
                    ])

        # Single prediction; draw nothing on the returned image so the
        # caller can render its own boxes/labels.
        pred_dict = MODEL_TYPE.end2end_detect(img,
                                              transforms,
                                              det_model,
                                              class_map=class_map,
                                              detection_threshold=detection_threshold,
                                              return_as_pil_img=False,
                                              return_img=True,
                                              display_bbox=False,
                                              display_score=False,
                                              display_label=False)

    return pred_dict

def prepare_prediction(pred_dict : Dict,
                       nms_threshold : str) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Get the predictions in a right format.

    Args:
        pred_dict (Dict): Prediction dictionary.
        nms_threshold (float): Threshold for the NMS postprocess.

    Returns:
        Tuple: Tuple containing the following:
            - (torch.Tensor): Bounding boxes
            - (np.ndarray): Image data
    """
    # Convert each box to a tensor and stack them into an unique tensor
    boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
    boxes = torch.stack(boxes)

    # Get the scores and labels as tensor
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])

    image = np.array(pred_dict['img'])

    # Apply NMS to postprocess the bounding boxes
    fixed_boxes = torchvision.ops.batched_nms(boxes, scores,
                                              labels,nms_threshold)
    boxes = boxes[fixed_boxes, :]

    return boxes, image

def predict_class(classifier : torch.nn.Module, image : np.ndarray,
                  bboxes : torch.Tensor) -> np.ndarray:
    """
    Predict the class of each detected object.

    Args:
        classifier (torch.nn.Module): Classifier model.
        image (np.ndarray): Image data (H x W x C, uint8 expected by
            PIL.Image.fromarray).
        bboxes (torch.Tensor): Bounding boxes, shape (n, 4) as
            (left, upper, right, lower).

    Returns:
        np.ndarray: Array containing the predicted class id for each object;
        empty when no boxes are given.
    """
    # No detections: nothing to classify (np.concatenate would raise
    # on an empty list below).
    if len(bboxes) == 0:
        return np.array([], dtype=np.int64)

    # crop() does not mutate the source image, so one PIL image can be
    # shared by every iteration (avoids a full copy per box).
    pil_img = PIL.Image.fromarray(image)

    preds = []
    for bbox in bboxes:
        # Get the bounding box content
        box = np.array(bbox).astype(int)
        cropped_img = np.array(pil_img.crop(box))

        # Apply transformations to the cropped image
        tran_image = transform_image(cropped_img, 224)
        # Channels first: HWC -> CHW, then add the batch dimension
        tran_image = tran_image.transpose(2, 0, 1)
        tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)

        # Make prediction; no_grad avoids building an autograd graph
        # during pure inference.
        with torch.no_grad():
            y_preds = classifier(tran_image)
        preds.append(y_preds.softmax(1).detach().numpy())

    # Highest-probability class per box
    preds = np.concatenate(preds).argmax(1)

    return preds