"""Save figures that show a few example images for each class of a dataset."""

import math
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid

from .datasets.ab_dataset import ABDataset


_NUM_COLS = 3


def _select_classes(dataset, class_to_idx_map, max_num_classes, unknown_class_idx):
    """Pick the classes to show and build the per-class bookkeeping dicts.

    Args:
        dataset: Object exposing ``raw_classes`` (a list of class names).
        class_to_idx_map: Mapping from class name to its current index.
        max_num_classes: Maximum number of tiles, counting the unknown-class
            tile when ``unknown_class_idx`` is given.
        unknown_class_idx: Index shared by all unknown classes, or ``None``.

    Returns:
        ``(idx_to_images, idx_to_class, idx_to_original_idx, truncated)``.
        ``idx_to_images`` maps every shown index to an empty list, and
        ``truncated`` is True only when at least one class was left out.
    """
    has_unknown = unknown_class_idx is not None
    # One slot is reserved for the unknown-class tile when there is one.
    limit = max_num_classes - 1 if has_unknown else max_num_classes

    idx_to_images = {}
    idx_to_class = {}
    idx_to_original_idx = {}
    truncated = False

    for class_name, idx in class_to_idx_map.items():
        if has_unknown and idx == unknown_class_idx:
            continue
        # Check *before* adding, so the "Show up to ..." tile is only drawn
        # when a class really had to be dropped.
        if len(idx_to_images) >= limit:
            truncated = True
            break
        idx_to_images[idx] = []
        idx_to_class[idx] = class_name
        idx_to_original_idx[idx] = dataset.raw_classes.index(class_name)

    if has_unknown:
        idx_to_images[unknown_class_idx] = []
        idx_to_class[unknown_class_idx] = '(unknown classes)'

    return idx_to_images, idx_to_class, idx_to_original_idx, truncated


def _iter_samples(dataset):
    """Yield ``dataset[0]``, ``dataset[1]``, ... until the dataset runs out.

    Stops at ``len(dataset)`` when the dataset has a length, and also on the
    first ``IndexError``, so callers never index past the end.
    """
    try:
        num_samples = len(dataset)
    except TypeError:
        num_samples = None

    sample_i = 0
    while num_samples is None or sample_i < num_samples:
        try:
            sample = dataset[sample_i]
        except IndexError:
            return
        yield sample
        sample_i += 1


def _class_title(class_idx, idx_to_class, idx_to_original_idx, rename_map,
                 unknown_class_idx):
    """Return the subplot title for the tile of ``class_idx``."""
    if unknown_class_idx is not None and class_idx == unknown_class_idx:
        return (f'(unknown classes)\n'
                f'current index: {class_idx}')

    class_name = idx_to_class[class_idx]
    class_i = idx_to_original_idx[class_idx]
    if class_name in rename_map:
        renamed_class = rename_map[class_name]
        return (f'{class_i}-th original class\n'
                f'"{class_name}" (→ "{renamed_class}")\n'
                f'current index: {class_idx}')
    return (f'{class_i}-th original class\n'
            f'"{class_name}"\n'
            f'current index: {class_idx}')


def _plot_and_save(idx_to_tensors, idx_to_class, idx_to_original_idx, rename_map,
                   fig_save_path, max_num_classes, unknown_class_idx, truncated):
    """Draw one image grid per class and save the figure to ``fig_save_path``.

    Args:
        idx_to_tensors: Mapping from class index to a list of image tensors
            accepted by ``make_grid``. A class with an empty list keeps its
            slot in the layout but nothing is drawn there.
        idx_to_class: Mapping from class index to class name.
        idx_to_original_idx: Mapping from class index to the position of the
            class in ``dataset.raw_classes``.
        rename_map: Mapping from original class name to its new name.
        fig_save_path: Destination passed to ``plt.savefig``.
        max_num_classes: Only used in the text of the truncation tile.
        unknown_class_idx: Index shared by all unknown classes, or ``None``.
        truncated: Whether to append the "(Show up to ...)" tile.
    """
    shown_num_classes = len(idx_to_tensors) + (1 if truncated else 0)
    # At least one row so the figure never gets a zero height.
    num_rows = max(1, math.ceil(shown_num_classes / _NUM_COLS))

    fig = plt.figure(figsize=(6.4, 4.8 * num_rows // 2))
    try:
        last_grid = None
        for draw_i, (class_idx, imgs) in enumerate(idx_to_tensors.items(), start=1):
            if not imgs:
                # No sample was found for this class; leave its slot blank.
                continue

            grid = make_grid(imgs, normalize=True)
            last_grid = grid

            plt.subplot(num_rows, _NUM_COLS, draw_i)
            plt.axis('off')
            plt.imshow(grid.permute(1, 2, 0).numpy())
            plt.title(_class_title(class_idx, idx_to_class, idx_to_original_idx,
                                   rename_map, unknown_class_idx))

        if truncated:
            plt.subplot(num_rows, _NUM_COLS, len(idx_to_tensors) + 1)
            plt.axis('off')
            # White tile the size of the last grid; fall back to a single white
            # pixel when no grid was drawn at all.
            if last_grid is not None:
                blank = torch.ones_like(last_grid)
            else:
                blank = torch.ones(3, 1, 1)
            plt.imshow(blank.permute(1, 2, 0).numpy())
            plt.title(f'(Show up to {max_num_classes} classes...)')

        plt.tight_layout()
        plt.savefig(fig_save_path, dpi=300)
    finally:
        # close() instead of clf(): clf() keeps the figure registered in
        # pyplot, so every call would leave another open figure behind.
        plt.close(fig)


def _draw_bbox(img, bbox):
    """Return ``img`` as an HWC uint8 array with ``bbox`` outlined in red.

    ``img`` is transposed with axes ``(1, 2, 0)`` and cast to uint8, so it is
    presumably a channel-first image with values in 0-255 -- TODO confirm
    against the dataset.
    """
    pil_img = Image.fromarray(np.uint8(np.asarray(img).transpose(1, 2, 0)))
    draw = ImageDraw.Draw(pil_img)
    draw.rectangle(bbox, outline=(255, 0, 0), width=6)
    return np.array(pil_img)


def visualize_classes_image_classification(dataset: ABDataset, class_to_idx_map, rename_map, fig_save_path: str,
                                           num_imgs_per_class=2, max_num_classes=20,
                                           unknown_class_idx=None):
    """Save a figure with example images for each class of a classification dataset.

    The dataset is scanned from index 0 and the first ``num_imgs_per_class``
    images of every shown class are collected. Scanning stops as soon as every
    shown class is complete or the dataset is exhausted; samples whose label is
    not among the shown classes are skipped.

    Args:
        dataset: Dataset whose items unpack to ``(image, label)`` and which
            exposes ``raw_classes``.
        class_to_idx_map: Mapping from class name to its current index.
        rename_map: Mapping from original class name to its new name.
        fig_save_path: Where the figure is written.
        num_imgs_per_class: Number of example images per class.
        max_num_classes: Maximum number of class tiles (the unknown-class tile
            counts towards it).
        unknown_class_idx: Index shared by all unknown classes, or ``None``.
    """
    idx_to_images, idx_to_class, idx_to_original_idx, truncated = _select_classes(
        dataset, class_to_idx_map, max_num_classes, unknown_class_idx)

    # Classes that still need images.
    pending = set(idx_to_images) if num_imgs_per_class > 0 else set()
    if pending:
        for x, y in _iter_samples(dataset):
            y = int(y)
            if y not in pending:
                # Either the class is already complete or it is not shown.
                continue
            idx_to_images[y].append(x)
            if len(idx_to_images[y]) >= num_imgs_per_class:
                pending.discard(y)
                if not pending:
                    break

    _plot_and_save(idx_to_images, idx_to_class, idx_to_original_idx, rename_map,
                   fig_save_path, max_num_classes, unknown_class_idx, truncated)


def visualize_classes_in_object_detection(dataset: ABDataset, class_to_idx_map, rename_map, fig_save_path: str,
                                          num_imgs_per_class=2, max_num_classes=20,
                                          unknown_class_idx=None,
                                          max_search_seconds=40, full_class_ratio=0.7):
    """Save a figure with example boxed objects for each class of a detection dataset.

    The dataset is scanned from index 0. For every image, at most one object
    per class is kept, together with its bounding box. Scanning stops when the
    dataset is exhausted, when ``max_search_seconds`` have elapsed, or when
    more than ``full_class_ratio`` of the shown classes are flagged as full.

    Args:
        dataset: Dataset whose items start with ``(image, labels)`` and which
            exposes ``raw_classes``. Each label row is read as
            ``[class_idx, *bbox]``; a row whose bbox values sum to 0 is treated
            as padding and ends the rows of that image.
        class_to_idx_map: Mapping from class name to its current index.
        rename_map: Mapping from original class name to its new name.
        fig_save_path: Where the figure is written.
        num_imgs_per_class: Number of example images per class.
        max_num_classes: Maximum number of class tiles (the unknown-class tile
            counts towards it).
        unknown_class_idx: Index shared by all unknown classes, or ``None``.
        max_search_seconds: Time budget for scanning the dataset.
        full_class_ratio: Fraction of full classes above which scanning stops.
    """
    idx_to_images, idx_to_class, idx_to_original_idx, truncated = _select_classes(
        dataset, class_to_idx_map, max_num_classes, unknown_class_idx)

    full_flags = {idx: False for idx in idx_to_images}
    start_time = time.monotonic()

    for sample in _iter_samples(dataset):
        x, y = sample[:2]

        seen_in_image = set()
        for label_info in y:
            if sum(label_info[1:]) == 0:  # pad label
                break

            # Normalise to a plain int so array/tensor scalars work as dict keys.
            ci = int(label_info[0])
            if ci not in idx_to_images:
                # Class is not among the shown ones.
                continue
            if ci in seen_in_image:
                # do not visualize multiple objects of a class in an image
                continue

            if len(idx_to_images[ci]) >= num_imgs_per_class:
                full_flags[ci] = True
                # NOTE(review): as in the original, meeting an already-full
                # class stops the scan of this image's remaining labels;
                # `continue` may have been intended -- confirm before changing.
                break

            idx_to_images[ci].append((x, label_info[1:]))
            seen_in_image.add(ci)

        if time.monotonic() - start_time > max_search_seconds:
            break
        if sum(full_flags.values()) > len(full_flags) * full_class_ratio:
            break

    to_tensor = ToTensor()
    idx_to_tensors = {
        class_idx: [to_tensor(_draw_bbox(img, bbox)) for img, bbox in samples]
        for class_idx, samples in idx_to_images.items()
    }

    _plot_and_save(idx_to_tensors, idx_to_class, idx_to_original_idx, rename_map,
                   fig_save_path, max_num_classes, unknown_class_idx, truncated)