import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np import seaborn as sns import torch import torchvision from matplotlib import colors def get_part_color(n_parts): colormap = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen', 'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid', 'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson', 'honeydew', 'thistle', 'red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen', 'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid', 'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson', 'honeydew', 'thistle')[:n_parts] part_color = [] for i in range(n_parts): part_color.append(colors.to_rgb(colormap[i])) part_color = np.array(part_color) return part_color def denormalize(img): mean = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1) std = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1) img = img * std + mean img = torch.clamp(img, min=0, max=1) return img def draw_matrix(mat): fig = plt.figure() sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu") ncols, nrows = fig.canvas.get_width_height() fig.canvas.draw() plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3) plt.close(fig) return plot def draw_kp_grid(img, kp): kp_color = get_part_color(kp.shape[1]) img = img[:64].permute(0, 2, 3, 1).detach().cpu() kp = kp.detach().cpu()[:64] fig = plt.figure(figsize=(8, 8)) gs = gridspec.GridSpec(8, 8) gs.update(wspace=0, hspace=0) for i, sample in enumerate(img): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.imshow(sample, vmin=0, vmax=1) ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+') ncols, nrows = fig.canvas.get_width_height() fig.canvas.draw() plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3) plt.close(fig) return plot def draw_kp_grid_unnorm(img, kp): kp_color = get_part_color(kp.shape[1]) img = img[:64].permute(0, 2, 3, 1).detach().cpu() kp = kp.detach().cpu()[:64] fig = plt.figure(figsize=(8, 8)) gs = gridspec.GridSpec(8, 8) gs.update(wspace=0, hspace=0) for i, sample in enumerate(img): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.imshow(sample) ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+') ncols, nrows = fig.canvas.get_width_height() fig.canvas.draw() plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3) plt.close(fig) return plot def draw_img_grid(img): img = img[:64].detach().cpu() nrow = min(8, img.shape[0]) img = torchvision.utils.make_grid(img[:64], nrow=nrow).permute(1, 2, 0) return torch.clamp(img * 255, min=0, max=255).numpy().astype(np.uint8)