# NOTE(review): the lines below are scraped page metadata (build status, file
# size, commit hashes, line-number gutter) left over from extraction — they are
# not part of the program and were commented out so the file parses as Python.
# Spaces: Runtime error / Runtime error
# File size: 2,428 Bytes
# 782cec7 cd4c90e b237467 cd4c90e f7cb0f8 cd4c90e b80c100 284fe0a 9c94b10 cd4c90e b80c100 cd4c90e b80c100 cd4c90e b80c100 cd4c90e
from io import BytesIO
from icevision import *
import collections
import PIL
import torch
import numpy as np
import torchvision
import icevision.models.ross.efficientdet
MODEL_TYPE = icevision.models.ross.efficientdet
def get_model(checkpoint_path):
    """Build an EfficientDet-d0 detector (2 classes) and load checkpoint weights.

    Args:
        checkpoint_path: path forwarded to ``get_checkpoint``.

    Returns:
        The model with the checkpoint's state dict loaded.
    """
    # EfficientDet needs an explicit image size at construction time.
    kwargs = {'img_size': 512}
    model = MODEL_TYPE.model(
        backbone=MODEL_TYPE.backbones.d0(pretrained=True),
        num_classes=2,
        **kwargs,
    )
    state_dict = get_checkpoint(checkpoint_path)
    model.load_state_dict(state_dict)
    return model
def get_checkpoint(checkpoint_path):
    """Load a training checkpoint and return a state dict for the bare model.

    Bug fix: the function previously ignored ``checkpoint_path`` and always
    loaded the hard-coded file ``'checkpoint.ckpt'``.

    Args:
        checkpoint_path: path to a checkpoint file containing a
            ``'state_dict'`` entry whose keys carry a 6-character wrapper
            prefix (presumably ``'model.'`` from a Lightning module —
            TODO confirm against the training code).

    Returns:
        ``collections.OrderedDict`` with the prefix stripped from every key,
        suitable for ``model.load_state_dict``.
    """
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    fixed_state_dict = collections.OrderedDict()
    for k, v in ckpt['state_dict'].items():
        # Drop the first 6 characters of each key (the wrapper prefix).
        fixed_state_dict[k[6:]] = v
    return fixed_state_dict
def predict(model, image, detection_threshold):
    """Run end-to-end detection on *image* and return the icevision pred dict.

    Args:
        model: detector built by ``get_model``.
        image: path or file-like object accepted by ``PIL.Image.open``.
        detection_threshold: minimum score for a detection to be kept.

    Returns:
        The dict produced by ``MODEL_TYPE.end2end_detect`` (includes the
        image because ``return_img=True``).
    """
    pil_img = PIL.Image.open(image)
    # NOTE(review): round-trip through a numpy array — presumably to force an
    # eager decode of the file data; confirm before removing.
    pil_img = PIL.Image.fromarray(np.array(pil_img))
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize(),
    ])
    pred_dict = MODEL_TYPE.end2end_detect(
        pil_img,
        transforms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False,
    )
    return pred_dict
def prepare_prediction(pred_dict, threshold):
    """Apply class-aware NMS to a prediction dict and return (boxes, image).

    Args:
        pred_dict: output of ``predict`` — expects ``pred_dict['detection']``
            with keys 'bboxes' (objects exposing ``to_tensor()``), 'scores'
            and 'label_ids', plus ``pred_dict['img']``.
        threshold: IoU threshold passed to ``torchvision.ops.batched_nms``.

    Returns:
        Tuple ``(boxes, image)`` — the (N, 4) tensor of boxes surviving NMS
        and the image as a numpy array.

    Bug fix: with zero detections the original ``torch.stack([])`` raised a
    RuntimeError; an empty (0, 4) tensor is now returned instead.
    """
    image = np.array(pred_dict['img'])
    raw_boxes = pred_dict['detection']['bboxes']
    if not raw_boxes:
        # No detections: return an empty box tensor instead of crashing.
        return torch.empty((0, 4)), image
    boxes = torch.stack([b.to_tensor() for b in raw_boxes])
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])
    keep = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
    return boxes[keep, :], image