File size: 2,428 Bytes
782cec7
cd4c90e
 
 
 
 
 
 
b237467
 
 
cd4c90e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7cb0f8
cd4c90e
 
 
 
 
 
 
 
 
b80c100
 
 
284fe0a
9c94b10
 
cd4c90e
 
 
 
 
 
 
 
 
 
b80c100
cd4c90e
 
 
 
 
 
 
 
b80c100
cd4c90e
 
 
 
 
 
 
b80c100
cd4c90e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from io import BytesIO
from icevision import *
import collections
import PIL
import torch
import numpy as np
import torchvision

import icevision.models.ross.efficientdet

# Model family used by get_model() to build the network and by predict() to
# run end-to-end detection; both functions must go through the same module.
MODEL_TYPE = icevision.models.ross.efficientdet

def get_model(checkpoint_path):
    """Build the EfficientDet (d0 backbone) detector and load checkpoint weights.

    Args:
        checkpoint_path: Checkpoint location, forwarded to ``get_checkpoint``.

    Returns:
        The model created by ``MODEL_TYPE.model`` after ``load_state_dict``
        has been applied with the state dict from ``get_checkpoint``.
    """
    # The efficientdet model requires an img_size parameter.
    model_kwargs = dict(img_size=512)

    d0_backbone = MODEL_TYPE.backbones.d0(pretrained=True)
    detector = MODEL_TYPE.model(
        backbone=d0_backbone,
        num_classes=2,
        **model_kwargs,
    )

    state_dict = get_checkpoint(checkpoint_path)
    detector.load_state_dict(state_dict)
    return detector

def get_checkpoint(checkpoint_path):
    """Load a checkpoint from disk and return its state dict with renamed keys.

    Args:
        checkpoint_path: Path (or file-like object) accepted by ``torch.load``
            that points at the checkpoint to read.

    Returns:
        A ``collections.OrderedDict`` with the same values and order as the
        checkpoint's ``'state_dict'`` entry, where every key has had its first
        six characters removed.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    # Tensors are mapped to CPU so the checkpoint loads on machines without
    # the device it was saved from.
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Presumably the six characters are the 'model.' prefix added by the
    # training wrapper that saved the checkpoint -- TODO confirm. The strip is
    # unconditional, exactly as callers have relied on so far.
    prefix_len = 6
    return collections.OrderedDict(
        (key[prefix_len:], value) for key, value in ckpt['state_dict'].items()
    )

def predict(model, image, detection_threshold):
    """Run end-to-end detection on a single image.

    Args:
        model: Detector passed to ``MODEL_TYPE.end2end_detect`` (for example
            the object returned by ``get_model``).
        image: Anything ``PIL.Image.open`` accepts (a filename, a path object
            or a binary file object).
        detection_threshold: Forwarded unchanged as the
            ``detection_threshold`` argument of ``end2end_detect``.

    Returns:
        Whatever ``MODEL_TYPE.end2end_detect`` returns. ``prepare_prediction``
        reads its ``'detection'`` and ``'img'`` entries.
    """
    # Decode inside a context manager so the image file is released as soon as
    # the pixels are read. Converting to RGB gives the pipeline a plain
    # three-channel image even when the input is grayscale, palette-based or
    # carries an alpha channel.
    with PIL.Image.open(image) as opened:
        img = opened.convert('RGB')

    class_map = ClassMap(classes=['Waste'])
    # Resize/pad to 512, matching the img_size the model is built with in
    # get_model.
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize(),
    ])

    # Drawing is disabled; the image is returned (not as a PIL image) so the
    # caller can render the boxes itself.
    pred_dict = MODEL_TYPE.end2end_detect(
        img,
        transforms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False,
    )

    return pred_dict

def prepare_prediction(pred_dict, threshold):
    """Apply per-class non-maximum suppression to a prediction.

    Args:
        pred_dict: Prediction mapping with a ``'detection'`` entry holding
            ``'bboxes'`` (objects exposing ``to_tensor()``), ``'scores'`` and
            ``'label_ids'``, plus an ``'img'`` entry with the image.
        threshold: IoU threshold handed to ``torchvision.ops.batched_nms``.

    Returns:
        A ``(boxes, image)`` tuple: ``boxes`` is the tensor of boxes kept by
        NMS (shape ``(0, 4)`` when there are no detections) and ``image`` is
        ``pred_dict['img']`` converted to a NumPy array.
    """
    detection = pred_dict['detection']
    box_tensors = [box.to_tensor() for box in detection['bboxes']]

    # torch.stack raises on an empty list, so an image without detections is
    # answered with an empty (0, 4) box tensor instead of crashing.
    if not box_tensors:
        return torch.empty((0, 4)), np.array(pred_dict['img'])

    boxes = torch.stack(box_tensors)
    scores = torch.as_tensor(detection['scores'])
    labels = torch.as_tensor(detection['label_ids'])
    image = np.array(pred_dict['img'])

    # batched_nms suppresses overlaps only among boxes sharing a label and
    # returns the indices of the boxes to keep.
    keep = torchvision.ops.batched_nms(boxes, scores, labels, threshold)

    return boxes[keep, :], image