EdgeTA / data /visualize.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
from .datasets.ab_dataset import ABDataset
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import math
import torch
def visualize_classes_image_classification(dataset: ABDataset, class_to_idx_map, rename_map,
                                           fig_save_path: str, num_imgs_per_class=2, max_num_classes=20,
                                           unknown_class_idx=None):
    """Save a figure showing a few sample images for each (mapped) class of a classification dataset.

    One subplot is drawn per shown class index. It tiles up to ``num_imgs_per_class`` samples
    with ``torchvision.utils.make_grid``. If some classes had to be dropped to respect
    ``max_num_classes``, an extra blank panel titled "(Show up to N classes...)" is appended.

    Args:
        dataset: Dataset indexed by integer position. ``dataset[i]`` must unpack into
            ``(x, y)``, where ``x`` is an image tensor accepted by ``make_grid``
            (presumably CHW — TODO confirm) and ``int(y)`` is the mapped class index.
            It must also expose ``raw_classes`` (a sequence supporting ``.index``).
        class_to_idx_map: Mapping from raw class name to its current (mapped) index.
            Names mapped to ``unknown_class_idx`` are not looked up in ``dataset.raw_classes``.
        rename_map: Mapping from raw class name to a new name. Renamed classes show both
            names in their subplot title.
        fig_save_path: Output path passed to ``plt.savefig`` (dpi=300).
        num_imgs_per_class: Maximum number of samples collected per class.
        max_num_classes: Maximum number of class panels, including the unknown-class panel
            when ``unknown_class_idx`` is given.
        unknown_class_idx: Index that gathers all "unknown" classes into a single panel.
            ``None`` disables it.
    """
    has_unknown = unknown_class_idx is not None
    # Reserve one panel for the unknown classes so the total stays within max_num_classes.
    known_class_limit = max_num_classes - 1 if has_unknown else max_num_classes

    idx_to_images = {}
    idx_to_class = {}
    idx_to_original_idx = {}
    reach_max_num_class_limit = False
    for c, idx in class_to_idx_map.items():
        if has_unknown and idx == unknown_class_idx:
            continue
        if idx not in idx_to_images and len(idx_to_images) >= known_class_limit:
            # Only report truncation when a new class index actually has to be dropped.
            reach_max_num_class_limit = True
            break
        idx_to_images[idx] = []
        idx_to_class[idx] = c
        idx_to_original_idx[idx] = dataset.raw_classes.index(c)
    if has_unknown:
        idx_to_images[unknown_class_idx] = []
        idx_to_class[unknown_class_idx] = '(unknown classes)'

    # Class indices that still need more samples.
    pending = set(idx_to_images) if num_imgs_per_class > 0 else set()
    sample_i = 0
    while pending:
        try:
            x, y = dataset[sample_i]
        except IndexError:
            # Dataset exhausted before every class was filled; draw what was collected.
            break
        sample_i += 1
        y = int(y)
        if y not in pending:
            # Either the class is already full or it was cut by max_num_classes.
            continue
        idx_to_images[y].append(x)
        if len(idx_to_images[y]) >= num_imgs_per_class:
            pending.discard(y)

    shown_num_classes = len(idx_to_images) + (1 if reach_max_num_class_limit else 0)
    num_cols = 3
    num_rows = max(1, math.ceil(shown_num_classes / num_cols))
    fig = plt.figure(figsize=(6.4, 4.8 * num_rows // 2))

    grid = None
    for draw_i, (class_idx, imgs) in enumerate(idx_to_images.items(), start=1):
        if not imgs:
            # make_grid cannot tile an empty list; leave this panel slot blank.
            continue
        class_name = idx_to_class[class_idx]
        grid = make_grid(imgs, normalize=True)
        plt.subplot(num_rows, num_cols, draw_i)
        plt.axis('off')
        plt.imshow(grid.permute(1, 2, 0).numpy())
        if has_unknown and class_idx == unknown_class_idx:
            plt.title(f'(unknown classes)\n'
                      f'current index: {class_idx}')
        else:
            class_i = idx_to_original_idx[class_idx]
            if class_name in rename_map:
                renamed_class = rename_map[class_name]
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}" (→ "{renamed_class}")\n'
                          f'current index: {class_idx}')
            else:
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}"\n'
                          f'current index: {class_idx}')

    if reach_max_num_class_limit:
        # White placeholder panel; fall back to a tiny image if no grid was drawn.
        placeholder = torch.ones_like(grid) if grid is not None else torch.ones(3, 8, 8)
        plt.subplot(num_rows, num_cols, len(idx_to_images) + 1)
        plt.axis('off')
        plt.imshow(placeholder.permute(1, 2, 0).numpy())
        plt.title(f'(Show up to {max_num_classes} classes...)')

    plt.tight_layout()
    plt.savefig(fig_save_path, dpi=300)
    # Close (not just clear) the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def visualize_classes_in_object_detection(dataset: ABDataset, class_to_idx_map, rename_map,
                                          fig_save_path: str, num_imgs_per_class=2, max_num_classes=20,
                                          unknown_class_idx=None, *, max_search_seconds=40,
                                          min_full_class_ratio=0.7):
    """Save a figure showing sample images with one highlighted box per (mapped) detection class.

    The dataset is scanned in order. For every class, up to ``num_imgs_per_class`` images
    are collected, and in each one a single object of that class is outlined in red. The
    scan stops when any of these happens:
    - every class is full;
    - more than ``min_full_class_ratio`` of the classes are full;
    - ``max_search_seconds`` have elapsed;
    - the dataset is exhausted.

    Args:
        dataset: Dataset indexed by integer position. The first two items of ``dataset[i]``
            are ``(x, y)``. ``x`` is an image array transposed with three axes (CHW
            presumably, with 0-255 values — TODO confirm). ``y`` is a sequence of label rows
            ``[class_idx, *bbox]``; an all-zero bbox marks padding.
            ``dataset.raw_classes`` must support ``.index``.
        class_to_idx_map: Mapping from raw class name to its current (mapped) index.
            Names mapped to ``unknown_class_idx`` are not looked up in ``dataset.raw_classes``.
        rename_map: Mapping from raw class name to a new name, shown in the subplot title.
        fig_save_path: Output path passed to ``plt.savefig`` (dpi=300).
        num_imgs_per_class: Maximum number of images collected per class.
        max_num_classes: Maximum number of class panels, including the unknown-class panel
            when ``unknown_class_idx`` is given.
        unknown_class_idx: Index that gathers all "unknown" classes into a single panel.
            ``None`` disables it.
        max_search_seconds: Time budget for scanning the dataset.
        min_full_class_ratio: Stop scanning once the fraction of full classes exceeds this.
    """
    import time

    import numpy as np
    from PIL import Image, ImageDraw
    from torchvision.transforms import ToTensor

    has_unknown = unknown_class_idx is not None
    # Reserve one panel for the unknown classes so the total stays within max_num_classes.
    known_class_limit = max_num_classes - 1 if has_unknown else max_num_classes

    idx_to_images = {}
    idx_to_class = {}
    idx_to_original_idx = {}
    reach_max_num_class_limit = False
    for c, idx in class_to_idx_map.items():
        if has_unknown and idx == unknown_class_idx:
            continue
        if idx not in idx_to_images and len(idx_to_images) >= known_class_limit:
            # Only report truncation when a new class index actually has to be dropped.
            reach_max_num_class_limit = True
            break
        idx_to_images[idx] = []
        idx_to_class[idx] = c
        idx_to_original_idx[idx] = dataset.raw_classes.index(c)
    if has_unknown:
        idx_to_images[unknown_class_idx] = []
        idx_to_class[unknown_class_idx] = '(unknown classes)'

    num_tracked = len(idx_to_images)
    # Class indices that still need more images.
    pending = set(idx_to_images) if num_imgs_per_class > 0 else set()
    start_time = time.monotonic()
    sample_i = 0
    while pending:
        try:
            sample = dataset[sample_i]
        except IndexError:
            # Dataset exhausted before every class was filled; draw what was collected.
            break
        sample_i += 1
        x, y = sample[:2]
        taken_in_image = set()
        for label_info in y:
            if sum(label_info[1:]) == 0:  # pad label: no more real objects in this image
                break
            ci = int(label_info[0])
            if ci in taken_in_image:
                continue  # do not visualize multiple objects of one class from the same image
            if ci not in pending:
                # Class already full or cut by max_num_classes; other objects may still be useful.
                continue
            idx_to_images[ci].append((x, label_info[1:]))
            taken_in_image.add(ci)
            if len(idx_to_images[ci]) >= num_imgs_per_class:
                pending.discard(ci)
        if time.monotonic() - start_time > max_search_seconds:
            break
        if num_tracked - len(pending) > num_tracked * min_full_class_ratio:
            break

    shown_num_classes = num_tracked + (1 if reach_max_num_class_limit else 0)
    num_cols = 3
    num_rows = max(1, math.ceil(shown_num_classes / num_cols))
    fig = plt.figure(figsize=(6.4, 4.8 * num_rows // 2))

    def draw_bbox(img, bbox):
        # Convert the (presumably CHW, 0-255) array to an HWC uint8 PIL image and outline the box.
        pil_img = Image.fromarray(np.uint8(img.transpose(1, 2, 0)))
        ImageDraw.Draw(pil_img).rectangle([float(v) for v in bbox], outline=(255, 0, 0), width=6)
        return np.array(pil_img)

    to_tensor = ToTensor()
    grid = None
    for draw_i, (class_idx, samples) in enumerate(idx_to_images.items(), start=1):
        if not samples:
            # make_grid cannot tile an empty list; leave this panel slot blank.
            continue
        class_name = idx_to_class[class_idx]
        imgs = [to_tensor(draw_bbox(img, bbox)) for img, bbox in samples]
        grid = make_grid(imgs, normalize=True)
        plt.subplot(num_rows, num_cols, draw_i)
        plt.axis('off')
        plt.imshow(grid.permute(1, 2, 0).numpy())
        if has_unknown and class_idx == unknown_class_idx:
            plt.title(f'(unknown classes)\n'
                      f'current index: {class_idx}')
        else:
            class_i = idx_to_original_idx[class_idx]
            if class_name in rename_map:
                renamed_class = rename_map[class_name]
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}" (→ "{renamed_class}")\n'
                          f'current index: {class_idx}')
            else:
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}"\n'
                          f'current index: {class_idx}')

    if reach_max_num_class_limit:
        # White placeholder panel; fall back to a tiny image if no grid was drawn.
        placeholder = torch.ones_like(grid) if grid is not None else torch.ones(3, 8, 8)
        plt.subplot(num_rows, num_cols, num_tracked + 1)
        plt.axis('off')
        plt.imshow(placeholder.permute(1, 2, 0).numpy())
        plt.title(f'(Show up to {max_num_classes} classes...)')

    plt.tight_layout()
    plt.savefig(fig_save_path, dpi=300)
    # Close (not just clear) the figure so repeated calls do not accumulate open figures.
    plt.close(fig)