from io import BytesIO
from typing import Dict, Tuple, Union

import numpy as np
import PIL
import torch
import torchvision
from icevision import *
from icevision.models import ross
from icevision.models.checkpoint import model_from_checkpoint

from classifier import transform_image

MODEL_TYPE = ross.efficientdet


def predict(det_model: torch.nn.Module, image: Union[str, BytesIO],
            detection_threshold: float) -> Dict:
    """
    Make a prediction with the detection model.

    Args:
        det_model (torch.nn.Module): Detection model.
        image (Union[str, BytesIO]): Image filepath if the image is one of
            the example images, or a BytesIO buffer if the image was
            uploaded by the user.
        detection_threshold (float): Detection threshold.

    Returns:
        Dict: Prediction dictionary.
    """
    img = PIL.Image.open(image)

    # Class map and transforms
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize()
    ])

    # Single prediction
    pred_dict = MODEL_TYPE.end2end_detect(img,
                                          transforms,
                                          det_model,
                                          class_map=class_map,
                                          detection_threshold=detection_threshold,
                                          return_as_pil_img=False,
                                          return_img=True,
                                          display_bbox=False,
                                          display_score=False,
                                          display_label=False)
    return pred_dict
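
# Example (sketch, not exercised by the app): loading an EfficientDet
# checkpoint with icevision's model_from_checkpoint and running a single
# detection. The checkpoint path, image path, and threshold below are
# placeholder assumptions, and the checkpoint dict is assumed to expose the
# nn.Module under its "model" key:
#
#     ckpt = model_from_checkpoint("models/efficientdet_checkpoint.pth")
#     pred_dict = predict(ckpt["model"], "examples/waste.jpg",
#                         detection_threshold=0.5)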


def prepare_prediction(pred_dict: Dict,
                       nms_threshold: float) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Convert the predictions to the format expected downstream.

    Args:
        pred_dict (Dict): Prediction dictionary.
        nms_threshold (float): IoU threshold for the NMS postprocessing.

    Returns:
        Tuple: Tuple containing the following:
            - (torch.Tensor): Bounding boxes
            - (np.ndarray): Image data
    """
    # Convert each box to a tensor and stack them into a single tensor
    boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
    boxes = torch.stack(boxes)

    # Get the scores and labels as tensors
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])
    image = np.array(pred_dict['img'])

    # Apply NMS to postprocess the bounding boxes; batched_nms returns the
    # indices of the boxes that survive the suppression
    keep_idxs = torchvision.ops.batched_nms(boxes, scores,
                                            labels, nms_threshold)
    boxes = boxes[keep_idxs, :]
    return boxes, image
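
# Note: torchvision.ops.batched_nms suppresses boxes independently per label
# and returns the indices of the kept boxes sorted by decreasing score. A
# minimal sketch with made-up values (two heavily overlapping detections of
# the same class):
#
#     keep = torchvision.ops.batched_nms(
#         torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]]),  # boxes
#         torch.tensor([0.9, 0.8]),                                # scores
#         torch.tensor([0, 0]),                                    # labels
#         0.5)                                                     # IoU thr.
#     # keep == tensor([0]): the lower-scoring duplicate is dropped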


def predict_class(classifier: torch.nn.Module, image: np.ndarray,
                  bboxes: torch.Tensor) -> np.ndarray:
    """
    Predict the class of each detected object.

    Args:
        classifier (torch.nn.Module): Classifier model.
        image (np.ndarray): Image data.
        bboxes (torch.Tensor): Bounding boxes.

    Returns:
        np.ndarray: Array containing the predicted class for each object.
    """
    preds = []
    for bbox in bboxes:
        img = image.copy()
        bbox = tuple(np.array(bbox).astype(int))

        # Crop the bounding box content out of the image
        cropped_img = PIL.Image.fromarray(img).crop(bbox)
        cropped_img = np.array(cropped_img)

        # Apply transformations to the cropped image
        tran_image = transform_image(cropped_img, 224)

        # Channels first, then add a batch dimension
        tran_image = tran_image.transpose(2, 0, 1)
        tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)

        # Make the prediction
        y_preds = classifier(tran_image)
        preds.append(y_preds.softmax(1).detach().numpy())

    # Pick the most probable class for each detected object
    preds = np.concatenate(preds).argmax(1)
    return preds
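

if __name__ == "__main__":
    # End-to-end sketch under assumed paths and thresholds; the real app
    # wires these values up through its UI. The classifier is assumed to be
    # a full torch.nn.Module serialized with torch.save, and the detection
    # checkpoint dict is assumed to expose the model under its "model" key.
    ckpt = model_from_checkpoint("models/efficientdet_checkpoint.pth")
    classifier = torch.load("models/classifier.pth", map_location="cpu")
    classifier.eval()

    pred_dict = predict(ckpt["model"], "examples/waste.jpg",
                        detection_threshold=0.5)
    boxes, image = prepare_prediction(pred_dict, nms_threshold=0.5)
    print(predict_class(classifier, image, boxes))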