# Copyright (c) Meta Platforms, Inc. and affiliates. import copy import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import torch def likelihood_overlay( prob, map_viz=None, p_rgb=0.2, p_alpha=1 / 15, thresh=None, cmap="jet" ): prob = prob / prob.max() cmap = plt.get_cmap(cmap) rgb = cmap(prob**p_rgb) alpha = prob[..., None] ** p_alpha if thresh is not None: alpha[prob <= thresh] = 0 if map_viz is not None: faded = map_viz + (1 - map_viz) * 0.5 rgb = rgb[..., :3] * alpha + faded * (1 - alpha) rgb = np.clip(rgb, 0, 1) else: rgb[..., -1] = alpha.squeeze(-1) return rgb def heatmap2rgb(scores, mask=None, clip_min=0.05, alpha=0.8, cmap="jet"): min_, max_ = np.quantile(scores, [clip_min, 1]) scores = scores.clip(min=min_) rgb = plt.get_cmap(cmap)((scores - min_) / (max_ - min_)) if mask is not None: if alpha == 0: rgb[mask] = np.nan else: rgb[..., -1] = 1 - (1 - 1.0 * mask) * (1 - alpha) return rgb def plot_pose(axs, xy, yaw=None, s=1 / 35, c="r", a=1, w=0.015, dot=True, zorder=10): if yaw is not None: yaw = np.deg2rad(yaw) uv = np.array([np.sin(yaw), -np.cos(yaw)]) xy = np.array(xy) + 0.5 if not isinstance(axs, list): axs = [axs] for ax in axs: if isinstance(ax, int): ax = plt.gcf().axes[ax] if dot: ax.scatter(*xy, c=c, s=70, zorder=zorder, linewidths=0, alpha=a) if yaw is not None: ax.quiver( *xy, *uv, scale=s, scale_units="xy", angles="xy", color=c, zorder=zorder, alpha=a, width=w, ) def plot_dense_rotations( ax, prob, thresh=0.01, skip=10, s=1 / 15, k=3, c="k", w=None, **kwargs ): t = torch.argmax(prob, -1) yaws = t.numpy() / prob.shape[-1] * 360 prob = prob.max(-1).values / prob.max() mask = prob > thresh masked = prob.masked_fill(~mask, 0) max_ = torch.nn.functional.max_pool2d( masked.float()[None, None], k, stride=1, padding=k // 2 ) mask = (max_[0, 0] == masked.float()) & mask indices = np.where(mask.numpy() > 0) plot_pose( ax, indices[::-1], yaws[indices], s=s, c=c, dot=False, zorder=0.1, w=w, **kwargs, ) def copy_image(im, ax): prop = im.properties() prop.pop("children") prop.pop("size") prop.pop("tightbbox") prop.pop("transformed_clip_path_and_affine") prop.pop("window_extent") prop.pop("figure") prop.pop("transform") return ax.imshow(im.get_array(), **prop) def add_circle_inset( ax, center, corner=None, radius_px=10, inset_size=0.4, inset_offset=0.005, color="red", ): data_t_axes = ax.transAxes + ax.transData.inverted() if corner is None: center_axes = np.array(data_t_axes.inverted().transform(center)) corner = 1 - np.round(center_axes).astype(int) corner = np.array(corner) bottom_left = corner * (1 - inset_size - inset_offset) + (1 - corner) * inset_offset axins = ax.inset_axes([*bottom_left, inset_size, inset_size]) if ax.yaxis_inverted(): axins.invert_yaxis() axins.set_axis_off() c = mpl.patches.Circle(center, radius_px, fill=False, color=color) c1 = mpl.patches.Circle(center, radius_px, fill=False, color=color) # ax.add_patch(c) ax.add_patch(c1) # ax.add_patch(c.frozen()) axins.add_patch(c) radius_inset = radius_px + 1 axins.set_xlim([center[0] - radius_inset, center[0] + radius_inset]) ylim = center[1] - radius_inset, center[1] + radius_inset if axins.yaxis_inverted(): ylim = ylim[::-1] axins.set_ylim(ylim) for im in ax.images: im2 = copy_image(im, axins) im2.set_clip_path(c) return axins def plot_bev(bev, uv, yaw, ax=None, zorder=10, **kwargs): if ax is None: ax = plt.gca() h, w = bev.shape[:2] tfm = mpl.transforms.Affine2D().translate(-w / 2, -h) tfm = tfm.rotate_deg(yaw).translate(*uv + 0.5) tfm += plt.gca().transData ax.imshow(bev, transform=tfm, zorder=zorder, **kwargs) ax.plot( [0, w - 1, w / 2, 0], [0, 0, h - 0.5, 0], transform=tfm, c="k", lw=1, zorder=zorder + 1, )