# DeepLabCutModelZoo-SuperAnimals / detection_utils.py
# (Header captured from the Hugging Face file viewer: author shaokaiye,
#  commit "Update detection_utils.py" e554d61, 3.13 kB.)
from tkinter import W
import gradio as gr
from matplotlib import cm
import torch
import torchvision
import matplotlib
import PIL
from PIL import Image, ImageColor, ImageFont, ImageDraw
import numpy as np
import math
import yaml
import pdb
############################################
# Predict detections with MegaDetector v5a model
def predict_md(im,
               megadetector_model,  # Megadet_Models[mega_model_input]
               size=640,
               *,
               force_reload=True):
    """Run a MegaDetector (YOLOv5 ``custom``) model on a PIL image.

    The image is first resized so that its longest side equals ``size``
    (aspect ratio preserved, LANCZOS resampling). The model is then loaded
    with ``torch.hub.load('ultralytics/yolov5', 'custom', ...)`` and called
    on the resized image.

    Args:
        im: PIL image to run detection on.
        megadetector_model: Forwarded as the model argument of the
            ``'custom'`` hub entry point. It is presumably a path to the
            MegaDetector weights; confirm with the caller.
        size: Target length, in pixels, of the longest image side.
        force_reload: Forwarded to ``torch.hub.load``. The default of True
            keeps the historical behaviour of re-fetching the hub repo on
            every call. Pass False to reuse the cached checkout.

    Returns:
        The YOLOv5 results object returned by calling the model on the
        resized image.
    """
    # Rescale so the longest side equals `size`. round() is used instead of
    # int() because size / max_side * max_side can come out as e.g.
    # 639.999..., which int() would truncate one pixel short. max(1, ...)
    # stops very elongated images from collapsing to a 0-pixel dimension.
    scale = size / max(im.size)
    new_size = tuple(max(1, round(side * scale)) for side in im.size)
    im = im.resize(new_size, PIL.Image.Resampling.LANCZOS)

    md_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): the model is loaded again on every call; callers that
    # process many images may want to cache it.
    MD_model = torch.hub.load('ultralytics/yolov5',  # repo_or_dir
                              'custom',  # model
                              megadetector_model,  # args for callable model
                              force_reload=force_reload,
                              device=md_device)
    if md_device.type == 'cuda':
        print('Sending model to GPU')
        MD_model.to(md_device)

    # Inference. The results object exposes attributes such as
    # 'ims'/'imgs', 'pred', 'names', 'xyxy', 'xywh', ... (the exact set
    # depends on the yolov5 version).
    results = MD_model(im)
    return results
##########################################
def crop_animal_detections(img_in,
                           yolo_results,
                           likelihood_th):
    """Crop every 'animal' detection whose confidence reaches a threshold.

    Args:
        img_in: Image object exposing ``resize((w, h))`` and ``crop(box)``
            (PIL-style). It is first resized to the size of the first image
            stored in the results, so that detection coordinates line up
            with its pixels.
        yolo_results: YOLOv5 results object, e.g. the return value of
            ``predict_md``. The following attributes are read:
            - ``names``: a dict ``{class_id: label}``, or a list of labels
              indexed by class id (older YOLOv5).
            - ``ims``, or ``imgs`` on older YOLOv5.
            - ``xyxy``: one array/tensor per image, whose rows start with
              ``xmin, ymin, xmax, ymax, confidence, class``.
        likelihood_th: Minimum confidence (inclusive) for a crop to be kept.

    Returns:
        A list of numpy arrays, one per kept detection, in detection order.

    Raises:
        ValueError: If no detector class is labelled 'animal'.
    """
    # Resolve which class ids mean 'animal' once, from the actual id->label
    # mapping. Newer YOLOv5 exposes `names` as a dict; older releases use a
    # list whose position is the class id.
    names = yolo_results.names
    id_label_pairs = names.items() if isinstance(names, dict) else enumerate(names)
    animal_ids = {int(class_id) for class_id, label in id_label_pairs
                  if label == 'animal'}
    if not animal_ids:
        raise ValueError(f"'animal' is not among the detector class names {names!r}")

    # Resize the input to the exact (width, height) the detector saw, so the
    # bbox coordinates index the same pixel grid. Older YOLOv5 calls the
    # image list `imgs` instead of `ims`.
    ims = yolo_results.ims if hasattr(yolo_results, 'ims') else yolo_results.imgs
    det_h, det_w = ims[0].shape[:2]
    img_in = img_in.resize((det_w, det_h))

    crops = []
    for det_array in yolo_results.xyxy:  # one set of detections per image
        # tolist() works for numpy arrays and torch tensors (CPU or GPU).
        for row in det_array.tolist():
            xmin, ymin, xmax, ymax, conf, cls = row[:6]
            if int(cls) in animal_ids and conf >= likelihood_th:
                # Round outward to whole pixels so the crop contains the box.
                area = (math.floor(xmin), math.floor(ymin),
                        math.ceil(xmax), math.ceil(ymax))
                crops.append(np.asarray(img_in.crop(area)))
    return crops