import gradio as gr
import torch
import numpy as np
import yaml

from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh
from yolov5.utils.plots import Annotator, colors
from yolov5.utils.augmentations import letterbox
device = 'cpu'
half = False
weights = 'yolov5/joint_all_multi.pt'
model = DetectMultiBackend(weights, device=device, dnn=False, data=None)
stride, names, pt = model.stride, model.names, model.pt
bs = 1  # batch size
imgsz = (640, 640)  # inference size (height, width)
conf_thres = 0.1  # confidence threshold
iou_thres = 0.1  # NMS IoU threshold
hide_labels = False
hide_conf = True
line_thickness = 1  # bounding-box line width

# Class names for the JSON output, read from the dataset YAML
with open('yolov5/joint_all_multi.yaml', 'r') as f:
    LABELS = yaml.safe_load(f)['names']
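
# The dataset YAML is assumed (based only on the load above) to expose the class
# labels under a top-level 'names' key, e.g.:
#   names: ['label_0', 'label_1', ...]   # hypothetical entries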


def joint_detection(img0):
    """Run joint detection on a single image.

    Returns the annotated image and a dict mapping each detected label to a
    list of normalized boxes: {label: [{'x', 'y', 'width', 'height'}, ...]}.
    """
    global imgsz
    # Note: Gradio's "image" input provides an RGB numpy array, while the
    # YOLOv5 preprocessing below was written for BGR frames (hence the channel
    # flip); it is kept as in the original.

    # Padded resize
    img = letterbox(img0, 640, stride=stride, auto=pt)[0]

    # Convert
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)
    im = torch.from_numpy(img).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Inference
    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz), half=half)  # warmup
    pred = model(im, augment=False, visualize=False)

    # NMS
    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)

    # Process predictions (batch size is 1, so pred holds a single detection tensor)
    det = pred[0]
    im0 = img0.copy()
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
    annotator = Annotator(im0, line_width=line_thickness, example=str(names))
    content = {}
    if len(det):
        # Rescale boxes from img_size to im0 size
        det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

        # Draw boxes and collect results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
            annotator.box_label(xyxy, label, color=colors(c, True))

            # Normalized xywh for the JSON output
            xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
            x, y, width, height = xywh
            current_label = LABELS[int(cls.item())]
            if content.get(current_label) is None:
                content[current_label] = []
            content[current_label].append({'x': x, 'y': y, 'width': width, 'height': height})

    im0 = annotator.result()
    return im0, content
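
# Optional local smoke test (an assumption, not part of the original Space): it
# loads a test image with OpenCV and calls joint_detection() directly, which is
# handy for debugging the model outside the Gradio UI. The path 'sample.jpg' is
# hypothetical; set RUN_SMOKE_TEST = True to enable it.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    import cv2
    sample = cv2.imread('sample.jpg')  # hypothetical test image (BGR, as cv2 loads it)
    annotated, detections = joint_detection(sample)
    cv2.imwrite('sample_annotated.jpg', annotated)
    print(detections)
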
iface = gr.Interface(fn=joint_detection, inputs="image", outputs=["image", "json"])
iface.launch(server_name="0.0.0.0", server_port=7860)