from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

from classifier import CustomViT
from model import get_model

# Colour drawn for each class id. The tuple is written as-is into the image
# channels; the values are RGB named colours (yellow, red, blue, green,
# orange, lavender, silver), so they render as intended on an RGB image.
_CLASS_COLORS = {
    0: (255, 255, 0),
    1: (255, 0, 0),
    2: (0, 0, 255),
    3: (0, 128, 0),
    4: (255, 165, 0),
    5: (230, 230, 250),
    6: (192, 192, 192),
}

# Text drawn next to each box, keyed by class id.
_CLASS_NAMES = {
    0: 'plastic',
    1: 'dangerous',
    2: 'carton',
    3: 'glass',
    4: 'organic',
    5: 'rest',
    6: 'other',
}


def plot_img_no_mask(image: np.ndarray, boxes: torch.Tensor, labels,
                     save_path: str = "img.png"):
    """Draw labelled bounding boxes on a copy of *image* and save the figure.

    Args:
        image: Image to annotate. It is converted with ``np.array`` and copied
            before drawing, so the caller's array is never modified.
        boxes: Tensor with one ``[x1, y1, x2, y2]`` row per detection; the two
            corners are passed to ``cv2.rectangle``. Values are truncated to
            integers.
        labels: Class id for each box, indexed as ``labels[i]`` for box ``i``.
            Plain ints, NumPy integers and single-element tensors are all
            accepted. Ids must be keys of ``_CLASS_COLORS`` / ``_CLASS_NAMES``.
        save_path: File the rendered figure is written to.

    Returns:
        matplotlib.figure.Figure: The figure. It is left open so the caller
        can display it or release it with ``plt.close(fig)``.

    Raises:
        KeyError: If a label is not a known class id.
        IndexError: If ``labels`` has fewer entries than ``boxes`` has rows.
    """
    boxes = boxes.cpu().detach().numpy().astype(np.int32)

    # Without this copy cv2.rectangle raises an error (presumably because the
    # incoming array can be read-only or non-contiguous). A single copy before
    # the loop suffices, since every box is drawn onto the same array.
    image = np.array(image).copy()

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    for i, box in enumerate(boxes):
        # int() so tensor / NumPy scalar labels hash like the int dict keys.
        label = int(labels[i])
        color = _CLASS_COLORS[label]
        # Plain Python ints for the OpenCV point arguments.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=5)
        # Class name is drawn 10 px above the box's top-left corner.
        cv2.putText(image, _CLASS_NAMES[label], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)

    ax.axis('off')
    ax.imshow(image)
    fig.savefig(save_path, bbox_inches='tight')
    return fig


def get_models(
    detection_ckpt: str,
    classifier_ckpt: str,
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Load the detection and classifier models and put both in eval mode.

    Args:
        detection_ckpt (str): Detection model checkpoint, passed to
            ``get_model``.
        classifier_ckpt (str): Classifier checkpoint holding a ``state_dict``
            for ``CustomViT``; it is loaded onto the CPU.

    Returns:
        tuple: Tuple containing:
            - (torch.nn.Module): Detection model
            - (torch.nn.Module): Classifier model
    """
    print('Loading the detection model')
    det_model = get_model(detection_ckpt)
    det_model.eval()

    print('Loading the classifier model')
    # 7 outputs; this matches the seven class ids in _CLASS_NAMES, presumably
    # the same label set. TODO confirm against the training code.
    classifier = CustomViT(target_size=7, pretrained=False)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13, weights_only=True would restrict loading
    # to tensors, which is sufficient for a plain state_dict.
    state_dict = torch.load(classifier_ckpt, map_location='cpu')
    classifier.load_state_dict(state_dict)
    classifier.eval()

    return det_model, classifier