Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*- | |
""" | |
@File : visualizer.py | |
@Time : 2022/04/05 11:39:33 | |
@Author : Shilong Liu | |
@Contact : [email protected] | |
""" | |
import datetime | |
import os | |
import cv2 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from matplotlib import transforms | |
from matplotlib.collections import PatchCollection | |
from matplotlib.patches import Polygon | |
from pycocotools import mask as maskUtils | |
def renorm(
    img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> torch.FloatTensor:
    """Undo per-channel normalization, i.e. compute ``img * std + mean`` per channel.

    Args:
        img: tensor of shape (3, H, W) or (B, 3, H, W). Any device is accepted;
            the statistics are moved to ``img.device``.
        mean: three per-channel means (sequence or tensor).
        std: three per-channel standard deviations (sequence or tensor).

    Returns:
        A tensor with the same shape as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel axis is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    channel_axis = 0 if img.dim() == 3 else 1
    assert img.size(channel_axis) == 3, 'img.size(%d) should be 3 but "%d". (%s)' % (
        channel_axis,
        img.size(channel_axis),
        str(img.size()),
    )
    # Build the statistics on img's device (CUDA inputs would otherwise hit a
    # device mismatch) using the default float dtype, like torch.Tensor(...).
    # Viewing them as (3, 1, 1) lets them broadcast over H and W, and over the
    # leading batch axis for 4-D input.
    mean = torch.as_tensor(mean, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.get_default_dtype(), device=img.device).view(-1, 1, 1)
    return img * std + mean
class ColorMap:
    """Turn a uint8 attention map into an RGBA image of one constant color.

    Every output pixel gets ``basergb`` as its color channels; the attention
    value itself becomes the alpha channel.
    """

    def __init__(self, basergb=[255, 255, 0]):
        # Constant color written into every pixel of the output.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return an (h, w, 4) uint8 array for an (h, w) uint8 ``attnmap``."""
        assert attnmap.dtype == np.uint8
        height, width = attnmap.shape
        # Same color at every (row, col); broadcast_to avoids an explicit repeat.
        color_plane = np.broadcast_to(self.basergb, (height, width, *self.basergb.shape))
        alpha_plane = attnmap[..., np.newaxis]
        stacked = np.concatenate([color_plane, alpha_plane], axis=-1)
        return stacked.astype(np.uint8)
def rainbow_text(x, y, ls, lc, **kw):
    """
    Draw the strings ``ls`` side by side on the current axes, with text
    ``ls[i]`` shown in color ``lc[i]``.

    Only the horizontal layout is implemented. The first string is placed at
    data coordinates ``(x, y)``, and each following string starts where the
    previous one ended. Every string is padded with one space on both sides.
    Strings and colors are paired with ``zip``, so surplus entries in the
    longer sequence are ignored. All keyword arguments are passed to
    ``plt.text``, so font size, family, etc. can be set.
    """
    t = plt.gca().transData
    fig = plt.gcf()
    # plt.show() is intentionally not called here: on interactive backends it
    # blocks until the window is closed, before any of the text is drawn.
    for s, c in zip(ls, lc):
        text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
        # Draw once so the rendered extent (in pixels) is known, then shift
        # the next string right by exactly that width.
        text.draw(fig.canvas.get_renderer())
        ex = text.get_window_extent()
        t = transforms.offset_copy(text.get_transform(), x=ex.width, units="dots")
class COCOVisualizer:
    """Matplotlib helpers for drawing images together with targets/annotations.

    ``coco`` is kept for :meth:`showAnns`, which reads ``self.coco.imgs`` and
    calls ``self.coco.loadCats``. These are the attribute and method names of
    a ``pycocotools.coco.COCO`` object, which is presumably what callers
    pass — confirm. ``tokenlizer`` is accepted for call compatibility but is
    not used.
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        self.coco = coco

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """Render ``img`` with ``tgt`` drawn on top and save it as a PNG.

        Args:
            img: normalized image tensor of shape (3, H, W). It is passed
                through ``renorm`` before display.
            tgt: target dict, or None. Any tensors in it must already be on
                CPU. 'image_id' names the output file (0 when absent). The
                remaining keys are drawn by :meth:`addtgt`.
            caption: optional string inserted into the output file name.
            dpi: figure resolution.
            savedir: directory for the PNG; it is created if missing.
        """
        plt.figure(dpi=dpi)
        try:
            plt.rcParams["font.size"] = "5"
            ax = plt.gca()
            img = renorm(img).permute(1, 2, 0)
            ax.imshow(img)
            self.addtgt(tgt)

            if tgt is None or "image_id" not in tgt:
                image_id = 0
            else:
                image_id = tgt["image_id"]

            # The timestamp keeps repeated saves of the same image id distinct.
            timestamp = str(datetime.datetime.now()).replace(" ", "-")
            if caption is None:
                savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
            else:
                savename = "{}/{}-{}-{}.png".format(
                    savedir, caption, int(image_id), timestamp
                )
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)
        finally:
            # Always release the figure, even if drawing or saving fails, so
            # repeated calls do not accumulate open pyplot figures.
            plt.close()

    @staticmethod
    def _bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h):
        """Return a Polygon for the axis-aligned box at corner (x, y) with size (w, h)."""
        poly = [
            [bbox_x, bbox_y],
            [bbox_x, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y],
        ]
        return Polygon(np.array(poly).reshape((4, 2)))

    @staticmethod
    def _draw_box_texts(ax, boxes, colors, texts):
        """Write each text at its box's (x, y) corner on a patch in the box's color."""
        for (bbox_x, bbox_y, _, _), c, text in zip(boxes, colors, texts):
            ax.text(
                bbox_x,
                bbox_y,
                text,
                color="black",
                bbox={"facecolor": c, "alpha": 0.6, "pad": 1},
            )

    def addtgt(self, tgt):
        """Draw a target dict onto the current matplotlib axes.

        Recognized keys:
            'boxes': (N, 4) tensor of normalized (cx, cy, w, h) boxes. It
                requires 'size' = (H, W) to scale the boxes to pixels.
            'strings_positive': N iterables of strings. When non-empty, each
                is drawn as "<labels[i]>:<strings joined by spaces>".
            'box_label': N values drawn as text at the box corners.
            'caption': used as the axes title.
            'attn': an (attn_map, basergb) tuple, or a list of them. Each
                attn_map (a NumPy array) is min-max scaled and overlaid as a
                single-color heatmap. ``tgt`` itself is not modified.
        If ``tgt`` is None or has no 'boxes', only the caption (if any) is
        drawn. The axes frame is always turned off.
        """
        ax = plt.gca()
        if tgt is None or "boxes" not in tgt:
            # tgt may be None: visualize() forwards it without checking.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]
        color = []
        polygons = []
        boxes = []
        scale = torch.Tensor([W, H, W, H])
        for box in tgt["boxes"].cpu():
            # Scale normalized (cx, cy, w, h) to pixels, then move the center
            # to the top-left corner.
            unnormbbox = box * scale
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            polygons.append(self._bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
            # Random light color: each channel in [0.4, 1.0).
            color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        # Faint fill plus a solid outline for every box.
        ax.add_collection(PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1))
        ax.add_collection(
            PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        )

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} != {numbox}, "
            texts = [
                str(int(tgt["labels"][idx])) + ":" + " ".join(strlist)
                for idx, strlist in enumerate(tgt["strings_positive"])
            ]
            self._draw_box_texts(ax, boxes, color, texts)

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} != {numbox}, "
            self._draw_box_texts(ax, boxes, color, [str(bl) for bl in tgt["box_label"]])

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            attn_items = tgt["attn"]
            # A single (attn_map, basergb) tuple is shorthand for a one-item
            # list. Wrap it locally rather than rewriting the caller's dict.
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max scale to [0, 255]; the epsilon avoids division by
                # zero for a constant map.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                ax.imshow(ColorMap(basergb)(attn_map))
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """
        Display the specified annotations on the current matplotlib axes.

        The dataset type is decided by the first annotation. If it has
        'segmentation' or 'keypoints', polygons, RLE masks, keypoints and
        (optionally) boxes are drawn. If it has 'caption', the captions are
        printed.

        :param anns (array of object): annotations to display
        :param draw_bbox: also outline each annotation's 'bbox' ([x, y, w, h])
        :return: 0 when ``anns`` is empty, otherwise None
        :raises Exception: if the first annotation has no recognized key
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise Exception("datasetType not supported")

        if datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
            return None

        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            # Random light color shared by everything drawn for this annotation.
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if "segmentation" in ann:
                segm = ann["segmentation"]
                if isinstance(segm, list):
                    # Polygon format: each entry is a flat [x0, y0, x1, y1, ...] list.
                    for seg in segm:
                        polygons.append(Polygon(np.array(seg).reshape((len(seg) // 2, 2))))
                        color.append(c)
                else:
                    # RLE mask. Uncompressed RLE ('counts' is a list) needs the
                    # image size, which is looked up through self.coco.
                    if isinstance(segm["counts"], list):
                        t = self.coco.imgs[ann["image_id"]]
                        rle = maskUtils.frPyObjects([segm], t["height"], t["width"])
                    else:
                        rle = [segm]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann.get("iscrowd", 0) == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    else:
                        # Any other value gets a fresh random color instead of
                        # an unbound/stale color_mask.
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    img[...] = color_mask
                    ax.imshow(np.dstack((img, m * 0.5)))
            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # turn skeleton into zero-based index
                sks = np.array(self.coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
                kp = np.array(ann["keypoints"])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )
            if draw_bbox:
                bbox_x, bbox_y, bbox_w, bbox_h = ann["bbox"]
                polygons.append(self._bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
                color.append(c)

        # Outlines only (no fill) for segmentation polygons and boxes.
        ax.add_collection(
            PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        )
        return None