# Waste-Detector / model.py
from io import BytesIO
from typing import Dict, Tuple, Union
from icevision import *
from icevision.models.checkpoint import model_from_checkpoint
from classifier import transform_image
from icevision.models import ross
import PIL
import torch
import numpy as np
import torchvision

# icevision model family used for detection (EfficientDet via Ross Wightman's effdet)
MODEL_TYPE = ross.efficientdet


def predict(det_model : torch.nn.Module, image : Union[str, BytesIO],
detection_threshold : float) -> Dict:
"""
Make a prediction with the detection model.
Args:
det_model (torch.nn.Module): Detection model
image (Union[str, BytesIO]): Image filepath if the image is one of
the example images and BytesIO if the image is a custom image
uploaded by the user.
detection_threshold (float): Detection threshold
Returns:
Dict: Prediction dictionary.
"""
img = PIL.Image.open(image)
# Class map and transforms
class_map = ClassMap(classes=['Waste'])
transforms = tfms.A.Adapter([
*tfms.A.resize_and_pad(512),
tfms.A.Normalize()
])
# Single prediction
pred_dict = MODEL_TYPE.end2end_detect(img,
transforms,
det_model,
class_map=class_map,
detection_threshold=detection_threshold,
return_as_pil_img=False,
return_img=True,
display_bbox=False,
display_score=False,
display_label=False)
return pred_dict
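
# Usage sketch for `predict` (a minimal example; the checkpoint and image paths
# below are placeholders, not files shipped with this repository, and
# `model_from_checkpoint` is assumed to return the restored model under the
# 'model' key of its result dict, as in icevision's checkpoint utilities):
#
#     checkpoint_and_model = model_from_checkpoint('checkpoints/detector.ckpt')
#     det_model = checkpoint_and_model['model']
#     pred_dict = predict(det_model, 'images/example.jpg', detection_threshold=0.5)
#
# The returned dictionary exposes the detections under pred_dict['detection']
# ('bboxes', 'scores', 'label_ids') and the transformed image under
# pred_dict['img'], which is exactly what `prepare_prediction` below consumes.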


def prepare_prediction(pred_dict : Dict,
                       nms_threshold : float) -> Tuple[torch.Tensor, np.ndarray]:
"""
    Put the predictions into the right format.
Args:
pred_dict (Dict): Prediction dictionary.
nms_threshold (float): Threshold for the NMS postprocess.
Returns:
Tuple: Tuple containing the following:
- (torch.Tensor): Bounding boxes
- (np.ndarray): Image data
"""
    # Convert each box to a tensor and stack them into a single tensor
boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
boxes = torch.stack(boxes)
# Get the scores and labels as tensor
scores = torch.as_tensor(pred_dict['detection']['scores'])
labels = torch.as_tensor(pred_dict['detection']['label_ids'])
image = np.array(pred_dict['img'])
    # Apply NMS to postprocess the bounding boxes; batched_nms returns the
    # indices of the predictions that survive the suppression
    fixed_boxes = torchvision.ops.batched_nms(boxes, scores,
                                              labels, nms_threshold)
    # Keep only the surviving boxes
    boxes = boxes[fixed_boxes, :]
return boxes, image
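
# Toy illustration of the NMS step above (made-up coordinates and scores, not
# real model output). `torchvision.ops.batched_nms` returns the indices of the
# boxes to keep, sorted by decreasing score:
#
#     boxes = torch.tensor([[0., 0., 10., 10.],
#                           [1., 1., 11., 11.],
#                           [50., 50., 60., 60.]])
#     scores = torch.tensor([0.9, 0.8, 0.7])
#     labels = torch.tensor([0, 0, 0])
#     torchvision.ops.batched_nms(boxes, scores, labels, 0.5)
#     # -> tensor([0, 2]); the second box overlaps the first with IoU ~0.68,
#     #    so it is suppressed at the 0.5 threshold.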


def predict_class(classifier : torch.nn.Module, image : np.ndarray,
bboxes : torch.Tensor) -> np.ndarray:
"""
Predict the class of each detected object.
Args:
classifier (torch.nn.Module): Classifier model.
image (np.ndarray): Image data.
bboxes (torch.Tensor): Bounding boxes.
Returns:
np.ndarray: Array containing the predicted class for each object.
"""
preds = []
for bbox in bboxes:
img = image.copy()
bbox = np.array(bbox).astype(int)
# Get the bounding box content
cropped_img = PIL.Image.fromarray(img).crop(bbox)
cropped_img = np.array(cropped_img)
# Apply transformations to the cropped image
tran_image = transform_image(cropped_img, 224)
# Channels first
tran_image = tran_image.transpose(2, 0, 1)
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
# Make prediction
y_preds = classifier(tran_image)
preds.append(y_preds.softmax(1).detach().numpy())
preds = np.concatenate(preds).argmax(1)
return preds
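

if __name__ == '__main__':
    # Minimal end-to-end sketch tying the three steps together. The checkpoint,
    # classifier and image paths are placeholders (this repository does not ship
    # them), and `model_from_checkpoint` is assumed to return the restored
    # detector under the 'model' key of its result dict.
    checkpoint_and_model = model_from_checkpoint('checkpoints/detector.ckpt')
    det_model = checkpoint_and_model['model']
    # The classifier is expected to be a trained torch.nn.Module; loading a
    # pickled model with torch.load is only one possible way to obtain it.
    classifier = torch.load('checkpoints/classifier.pth', map_location='cpu')
    classifier.eval()

    # Detect objects, clean up the boxes with NMS, then classify each crop
    pred_dict = predict(det_model, 'images/example.jpg', detection_threshold=0.5)
    boxes, image = prepare_prediction(pred_dict, nms_threshold=0.5)
    labels = predict_class(classifier, image, boxes)
    print(boxes, labels)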