""" | |
Dataset object for Panoptic Narrative Grounding. | |
Paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Gonzalez_Panoptic_Narrative_Grounding_ICCV_2021_paper.pdf | |
""" | |
from os.path import join, isdir, exists

import cv2
import numpy as np
from PIL import Image
from skimage import io
from torch.utils.data import Dataset
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import matplotlib.colors as mc

from clip_grounding.utils.io import load_json
from clip_grounding.datasets.png_utils import show_image_and_caption


class PNG(Dataset):
    """Panoptic Narrative Grounding dataset."""

    def __init__(self, dataset_root, split) -> None:
        """
        Initializer.

        Args:
            dataset_root (str): path to the folder containing the PNG dataset
            split (str): MS-COCO split such as train2017/val2017
        """
        super().__init__()

        assert isdir(dataset_root)
        self.dataset_root = dataset_root

        assert split in ["val2017"], f"Split {split} not supported. " \
            "Currently, only the `val2017` split is supported."
        self.split = split

        self.ann_dir = join(self.dataset_root, "annotations")
        # feat_dir = join(self.dataset_root, "features")

        panoptic = load_json(join(self.ann_dir, "panoptic_{:s}.json".format(split)))
        images = panoptic["images"]
        self.images_info = {i["id"]: i for i in images}
        panoptic_anns = panoptic["annotations"]
        self.panoptic_anns = {int(a["image_id"]): a for a in panoptic_anns}

        # self.panoptic_pred_path = join(
        #     feat_dir, split, "panoptic_seg_predictions"
        # )
        # assert isdir(self.panoptic_pred_path)

        panoptic_narratives_path = join(self.ann_dir, f"png_coco_{split}.json")
        self.panoptic_narratives = load_json(panoptic_narratives_path)

    def __len__(self):
        return len(self.panoptic_narratives)

    def get_image_path(self, image_id: str):
        image_path = join(self.dataset_root, "images", self.split, f"{image_id.zfill(12)}.jpg")
        return image_path

    def __getitem__(self, idx: int):
        """
        Returns a list of dicts, one per narrative segment that has panoptic
        annotations. Each dict contains the `image` (PIL), `text` (all
        utterances), a binary `image_mask` over the image, a one-hot
        `text_mask` over the utterances, and the `full_caption`.
        """
        narr = self.panoptic_narratives[idx]

        image_id = narr["image_id"]
        image_path = self.get_image_path(image_id)
        assert exists(image_path)
        image = Image.open(image_path)

        caption = narr["caption"]
        # show_single_image(image, title=caption, titlesize=12)

        segments = narr["segments"]
        image_id = int(narr["image_id"])
        panoptic_ann = self.panoptic_anns[image_id]
        segment_infos = {s["id"]: s for s in panoptic_ann["segments_info"]}
        image_info = self.images_info[image_id]

        panoptic_segm = io.imread(
            join(
                self.ann_dir,
                "panoptic_segmentation",
                self.split,
                "{:012d}.png".format(image_id),
            )
        )
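        # COCO panoptic format stores each segment id split across the RGB
        # channels of the annotation PNG: id = R + G * 256 + B * 256 ** 2.
        # Recombine the three channels into a single integer id map.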
        panoptic_segm = (
            panoptic_segm[:, :, 0]
            + panoptic_segm[:, :, 1] * 256
            + panoptic_segm[:, :, 2] * 256 ** 2
        )

        # panoptic_pred = io.imread(
        #     join(self.panoptic_pred_path, "{:012d}.png".format(image_id))
        # )[:, :, 0]

        # # select a single utterance to visualize
        # segment = segments[7]
        # segment_ids = segment["segment_ids"]
        # segment_mask = np.zeros((image_info["height"], image_info["width"]))
        # for segment_id in segment_ids:
        #     segment_id = int(segment_id)
        #     segment_mask[panoptic_segm == segment_id] = 1.

        utterances = [s["utterance"] for s in segments]

        outputs = []
        for i, segment in enumerate(segments):

            # create segmentation mask on image
            segment_ids = segment["segment_ids"]

            # if no annotation for this word, skip
            if not len(segment_ids):
                continue

            segment_mask = np.zeros((image_info["height"], image_info["width"]))
            for segment_id in segment_ids:
                segment_id = int(segment_id)
                segment_mask[panoptic_segm == segment_id] = 1.

            # store the outputs
            text_mask = np.zeros(len(utterances))
            text_mask[i] = 1.
            segment_data = dict(
                image=image,
                text=utterances,
                image_mask=segment_mask,
                text_mask=text_mask,
                full_caption=caption,
            )
            outputs.append(segment_data)

            # # visualize segmentation mask with associated text
            # segment_color = "red"
            # segmap = SegmentationMapsOnImage(
            #     segment_mask.astype(np.uint8), shape=segment_mask.shape,
            # )
            # image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, COLORS[segment_color]])[0]
            # image_with_segmap = Image.fromarray(image_with_segmap)
            # colors = ["black" for _ in range(len(utterances))]
            # colors[i] = segment_color
            # show_image_and_caption(image_with_segmap, utterances, colors)

        return outputs


def overlay_segmask_on_image(image, image_mask, segment_color="red"):
    """Overlays a binary segmentation mask on a PIL image in the given color."""
    segmap = SegmentationMapsOnImage(
        image_mask.astype(np.uint8), shape=image_mask.shape,
    )
    rgb_color = mc.to_rgb(segment_color)
    rgb_color = 255 * np.array(rgb_color)
    image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
    image_with_segmap = Image.fromarray(image_with_segmap)
    return image_with_segmap


def get_text_colors(text, text_mask, segment_color="red"):
    """Returns per-utterance colors: `segment_color` for the utterance marked
    by `text_mask`, black for all others."""
    colors = ["black" for _ in range(len(text))]
    colors[text_mask.nonzero()[0][0]] = segment_color
    return colors


def overlay_relevance_map_on_image(image, heatmap):
    """Blends a [0, 1] relevance heatmap onto a PIL image using the JET colormap."""
    width, height = image.size

    # resize the heatmap to image size
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # create overlapped super image
    img = np.asarray(image)
    super_img = heatmap * 0.4 + img * 0.6
    super_img = np.uint8(super_img)
    super_img = Image.fromarray(super_img)

    return super_img


def visualize_item(image, text, image_mask, text_mask, segment_color="red"):
    """Shows the image with its segmentation overlay next to the colored caption."""
    image_with_segmap = overlay_segmask_on_image(image, image_mask, segment_color=segment_color)
    colors = get_text_colors(text, text_mask, segment_color=segment_color)
    show_image_and_caption(image_with_segmap, text, colors)


if __name__ == "__main__":
    from clip_grounding.utils.paths import REPO_PATH, DATASET_ROOTS

    PNG_ROOT = DATASET_ROOTS["PNG"]
    dataset = PNG(dataset_root=PNG_ROOT, split="val2017")

    item = dataset[0]
    sub_item = item[1]
    visualize_item(
        image=sub_item["image"],
        text=sub_item["text"],
        image_mask=sub_item["image_mask"],
        text_mask=sub_item["text_mask"],
        segment_color="red",
    )
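
    # A minimal sketch of the standalone helpers above, assuming the same
    # `sub_item` structure; this composes the same steps that
    # `visualize_item` performs internally.
    overlaid = overlay_segmask_on_image(sub_item["image"], sub_item["image_mask"], segment_color="red")
    text_colors = get_text_colors(sub_item["text"], sub_item["text_mask"], segment_color="red")
    show_image_and_caption(overlaid, sub_item["text"], text_colors)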