File size: 3,559 Bytes
f977726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

import functools
import random

import cv2
import torch
from PIL import ImageDraw
import torchvision.transforms as T

# COCO class names, looked up by position: add_bboxes indexes this list with
# the argmax class index of each detection, so CLASSES[i] is the name of
# class i (e.g. CLASSES[1] == 'person'). The 'N/A' entries are placeholders
# that keep every later name at its index -- presumably the ids unused in the
# COCO category numbering; confirm against the dataset annotations before
# editing. Do not reorder or remove entries. 91 entries (indices 0-90).
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# Detector input pipeline: resize the PIL image (an int size matches the
# shorter side to 500 px), convert it to a float tensor, then normalize each
# RGB channel with the ImageNet mean/std statistics.
transform = T.Compose([
    T.Resize(size=500),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Output post-processing: box format conversion
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center form to corner form.

    Args:
        x: Tensor whose dimension 1 has size 4 and holds ``(cx, cy, w, h)``
            (in this file it is presumably an ``(N, 4)`` tensor of predicted
            boxes).

    Returns:
        Tensor of the same shape holding ``(x0, y0, x1, y1)`` along dim 1.
    """
    cx, cy, w, h = x.unbind(dim=1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=1)


def rescale_bboxes(out_bbox, size):
    """Convert relative ``(cx, cy, w, h)`` boxes to pixel ``(x0, y0, x1, y1)``.

    Args:
        out_bbox: Tensor of boxes with ``(cx, cy, w, h)`` along dimension 1.
            The values are presumably relative to the image size (0-1), as
            DETR predicts them; confirm before reusing with another model.
        size: ``(width, height)`` of the target image, e.g. a PIL
            ``Image.size`` tuple.

    Returns:
        Floating-point tensor of corner boxes multiplied by
        ``(width, height, width, height)``, on the same device as
        ``out_bbox``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale on out_bbox's device: multiplying a CUDA tensor by a
    # 4-element CPU tensor raises a device-mismatch RuntimeError.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=out_bbox.device)
    return b * scale


# Run the detector on one image and keep its confident detections
def image_processing(im, model, transform, confidence=0.9):
    """Run ``model`` on ``im`` and return the detections above ``confidence``.

    Args:
        im: PIL image; ``im.size`` supplies the ``(width, height)`` used to
            scale the predicted boxes back to pixels.
        model: Detection model whose output is indexed with the keys
            ``'pred_logits'`` and ``'pred_boxes'`` (DETR-style).
        transform: Callable turning ``im`` into an image tensor; a batch
            dimension is prepended here with ``unsqueeze(0)``.
        confidence: Probability threshold; a query is kept only when its best
            class probability is strictly greater than this value.

    Returns:
        ``(probas, boxes)``: the class-probability rows of the kept queries
        (last logit column excluded) and their pixel ``(x0, y0, x1, y1)``
        boxes.
    """
    img = transform(im).unsqueeze(0)  # add a batch dimension of 1

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(img)

    # Softmax over classes for the single image in the batch. The last column
    # is dropped (in DETR it is the "no object" slot), so it never counts
    # toward the confidence test.
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > confidence
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

    return probas[keep], bboxes_scaled


# Helper Functions for Plotting BBoxes
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, and optionally a text label, onto ``img``.

    The image is modified in place; nothing is returned.

    Args:
        x: Box corners ``(x0, y0, x1, y1)`` in pixels; each value is
            truncated with ``int``.
        img: PIL image to draw on.
        color: Outline and label-background color. A random RGB tuple is
            used when it is falsy.
        label: Optional text drawn in white at the box's top-left corner.
        line_thickness: Outline width in pixels. When falsy it is derived
            from the image dimensions.
    """
    width, height = img.size
    tl = line_thickness or round(0.002 * (width + height) / 2) + 1  # line thickness
    color = color or (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    x0, y0, x1, y1 = int(x[0]), int(x[1]), int(x[2]), int(x[3])
    img_draw = ImageDraw.Draw(img)
    img_draw.rectangle((x0, y0, x1, y1), outline=color, width=tl)
    if label:
        text_xy = (x0, y0 - 2)
        # Measure the label with the same PIL font that renders it, so the
        # filled background covers exactly the drawn text (OpenCV font
        # metrics do not correspond to PIL's default font).
        try:
            bg = img_draw.textbbox(text_xy, label)
        except (AttributeError, ValueError):
            # Pillow < 8.0 has no textbbox; Pillow < 9.2 rejects the default
            # bitmap font in textbbox. Both still provide textsize.
            t_w, t_h = img_draw.textsize(label)
            bg = (text_xy[0], text_xy[1], text_xy[0] + t_w, text_xy[1] + t_h)
        img_draw.rectangle(bg, fill=color)
        img_draw.text(text_xy, label, fill=(255, 255, 255))


# Plotting bounding boxes on an image
def add_bboxes(pil_img, prob, bboxes):
    """Draw one labelled box per detection onto ``pil_img``.

    Args:
        pil_img: PIL image to annotate; it is modified in place.
        prob: Iterable of per-detection class-probability rows. The argmax of
            each row picks the name from ``CLASSES`` and the score shown.
        bboxes: Tensor of pixel boxes ``(x0, y0, x1, y1)``, one row per entry
            of ``prob``.

    Returns:
        The same ``pil_img`` object.
    """
    coords = bboxes.tolist()
    for scores, box in zip(prob, coords):
        top = scores.argmax()
        caption = f'{CLASSES[top]}: {scores[top]: 0.2f}'
        plot_one_box(x=box, img=pil_img, label=caption)
    return pil_img


@functools.lru_cache(maxsize=1)
def _load_detr_model():
    """Load the pretrained DETR ResNet-101 model from torch.hub once, in eval mode."""
    model = torch.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
    model.eval()
    return model


def detect(im, confidence):
    """Detect objects in ``im`` and draw their labelled boxes onto it.

    Args:
        im: PIL image; it is annotated in place and also returned.
        confidence: Detection threshold expressed as a percentage; it is
            divided by 100 before being compared with class probabilities.

    Returns:
        ``im`` with the detected boxes and labels drawn on it.
    """
    # The hub model is built on first use and cached, instead of being
    # reloaded and reconstructed on every call.
    model = _load_detr_model()
    scores, boxes = image_processing(im, model, transform, confidence / 100)
    return add_bboxes(im, scores, boxes)