import argparse
import os
import functools
import torch
from model.detector import *
from model.backbone import *
from model.loss import *
from model.data import Therin
import datetime
from model.detector.fasterRCNN import FasterRCNN
from model.backbone.densenet import DenseNet
from model.utils.engine import *
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision import transforms as T
from PIL import Image, ImageDraw

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class id -> label text. The English text is drawn on the image; the
# Japanese text is what postprocess() returns to the caller.
_LABELS_EN = {0: "creeping", 1: "crawling", 2: "stooping", 3: "climbing", 4: "other"}
_LABELS_JA = {0: "しのびこんでいる", 1: "這っている", 2: "かがんでいる", 3: "よじ登っている", 4: "その他"}

# Checkpoint and class count used by inference().
_CHECKPOINT_PATH = 'model/model/densenet-model-9-mAp--1.0.pth'
_NUM_CLASSES = 5


def label_to_text_en(l):
    """Return the English label for class id *l* (KeyError if unknown)."""
    return _LABELS_EN[l]


def label_to_text_ja(l):
    """Return the Japanese label for class id *l* (KeyError if unknown)."""
    return _LABELS_JA[l]


def _text_size(draw, text):
    """Return (width, height) of *text* as rendered by *draw* from the origin.

    ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox`` (Pillow >= 8)
    replaces it. The right/bottom edges of the bbox anchored at (0, 0) are the
    extent the label background must cover.
    """
    if hasattr(draw, "textbbox"):
        _left, _top, right, bottom = draw.textbbox((0, 0), text)
        return right, bottom
    return draw.textsize(text)


def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
    """Draw a labelled bounding box onto *img* in place.

    Args:
        img: PIL image to draw on (modified in place).
        x, y: top-left corner of the box, in pixels.
        w, h: width and height of the box, in pixels.
        text: label drawn on a filled background at the box's top edge.
        textcolor: fill colour of the label text.
        bbcolor: colour of the box outline and of the label background.
    """
    draw = ImageDraw.Draw(img)
    text_w, text_h = _text_size(draw, text)
    # Put the label just above the box. If there is no room above it (box near
    # the top of the image), put it inside the box's top edge instead.
    label_y = y if y <= text_h else y - text_h
    # The box itself always starts at y; only the label is shifted.
    draw.rectangle((x, y, x + w, y + h), outline=bbcolor)
    draw.rectangle((x, label_y, x + text_w, label_y + text_h), outline=bbcolor, fill=bbcolor)
    draw.text((x, label_y), text, fill=textcolor)


def postprocess(true_image, o, score_thresh=None):
    """Draw the selected detection(s) on a copy of *true_image*, then show and save it.

    Args:
        true_image: PIL image the detections refer to (not modified).
        o: model output; ``o[0]`` must be a dict with ``"boxes"``, ``"labels"``
            and ``"scores"`` tensors.
        score_thresh: if None (default), only the single highest-scoring
            detection is kept. Otherwise every detection whose score is
            strictly greater than *score_thresh* is kept.

    Returns:
        (labels, scores, indices): Japanese label strings, scores formatted
        with three decimals, and the indices of the kept detections in ``o[0]``.

    Side effects:
        Opens the annotated image in a viewer and writes it to
        ``img/detected.png`` (creating ``img/`` if needed).
    """
    # convert() returns a new image, and RGB guarantees the colour tuples
    # below are valid ink (they are not for e.g. "L" or "P" images).
    copy_im = true_image.convert("RGB")
    data = o[0]
    boxes = data["boxes"].tolist()
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    if score_thresh is None:
        # Single best detection; list.index picks the first one on ties.
        keep = [scores.index(max(scores))] if scores else []
    else:
        keep = [i for i, s in enumerate(scores) if s > score_thresh]

    selected_labels = []
    selected_scores = []
    for i in keep:
        # NOTE(review): assumes boxes are (x1, y1, x2, y2), the torchvision
        # detection-output convention; confirm for model.detector.fasterRCNN.
        # show_bb expects (x, y, width, height).
        x1, y1, x2, y2 = boxes[i]
        show_bb(copy_im, x1, y1, x2 - x1, y2 - y1,
                label_to_text_en(labels[i]), (255, 255, 255), (255, 0, 0))
        selected_labels.append(label_to_text_ja(labels[i]))
        selected_scores.append('{:.3f}'.format(scores[i]))

    copy_im.show()
    os.makedirs("img", exist_ok=True)
    copy_im.save("img/detected.png")
    return selected_labels, selected_scores, keep


@functools.lru_cache(maxsize=1)
def _load_model(checkpoint_path, num_classes):
    """Build the detector, load weights from *checkpoint_path*, return it in eval mode.

    Cached so repeated inference() calls do not rebuild the network and
    re-read the checkpoint from disk each time.
    """
    backbone = resnet_fpn_backbone('resnet18', False)
    model = FasterRCNN(backbone, num_classes)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; the model's parameters are on CPU here anyway.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    model.eval()
    return model


def inference(image_pil):
    """Run the detector on a PIL image and post-process the result.

    Args:
        image_pil: input PIL image (any mode; it is converted to RGB).

    Returns:
        (output, res): the raw model output (postprocess() reads
        ``output[0]["boxes" / "labels" / "scores"]``) and the
        ``(labels, scores, indices)`` tuple returned by postprocess().
    """
    model = _load_model(_CHECKPOINT_PATH, _NUM_CLASSES)
    _transform = T.Compose([T.ToTensor()])
    image = _transform(image_pil.convert("RGB"))
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model([image])
    res = postprocess(image_pil, output)
    return output, res