|
__author__ = "licheng" |
|
|
|
""" |
|
This interface provides access to four datasets: |
|
1) refclef |
|
2) refcoco |
|
3) refcoco+ |
|
4) refcocog |
|
split by unc and google |
|
|
|
The following API functions are defined: |
|
REFER - REFER api class |
|
getRefIds - get ref ids that satisfy given filter conditions. |
|
getAnnIds - get ann ids that satisfy given filter conditions. |
|
getImgIds - get image ids that satisfy given filter conditions. |
|
getCatIds - get category ids that satisfy given filter conditions. |
|
loadRefs - load refs with the specified ref ids. |
|
loadAnns - load anns with the specified ann ids. |
|
loadImgs - load images with the specified image ids. |
|
loadCats - load category names with the specified category ids. |
|
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id |
|
showRef - show image, segmentation or box of the referred object with the ref |
|
getMask - get mask and area of the referred object given ref |
|
showMask - show mask of the referred object given ref |
|
""" |
|
|
|
import itertools |
|
import json |
|
import os.path as osp |
|
import pickle |
|
import sys |
|
import time |
|
from pprint import pprint |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import skimage.io as io |
|
from matplotlib.collections import PatchCollection |
|
from matplotlib.patches import Polygon, Rectangle |
|
from pycocotools import mask |
|
|
|
|
|
class REFER:
    """In-memory index over one referring-expression dataset.

    The constructor reads two files from ``<data_root>/<dataset>``: a pickled
    list of refs (``refs(<splitBy>).p``) and ``instances.json`` (images,
    annotations, categories).  ``createIndex`` then builds id-keyed lookup
    tables that the ``get*`` / ``load*`` / ``show*`` methods rely on.

    Throughout the ``get*`` and ``load*`` methods an id argument may be either
    a ``list`` of ids or a single id; a single id is treated as a one-element
    list.  The ``[]`` defaults are never mutated.
    """

    def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
        """Load the dataset files into memory and build the index.

        Args:
            data_root: directory holding one sub-directory per dataset plus
                the ``images`` tree.
            dataset: "refcoco", "refcoco+", "refcocog" or "refclef".
            splitBy: name placed inside the refs file name,
                ``refs(<splitBy>).p``.

        An unknown ``dataset`` prints a message and terminates the process
        through ``sys.exit()``.
        """
        print("loading dataset %s into memory..." % dataset)
        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ["refcoco", "refcoco+", "refcocog"]:
            self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014")
        elif dataset == "refclef":
            self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
        else:
            print("No refer dataset is called [%s]" % dataset)
            sys.exit()

        self.dataset = dataset

        tic = time.time()

        # Load the refs for the requested split scheme.
        ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
        print("ref_file: ", ref_file)
        self.data = {}
        self.data["dataset"] = dataset
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point this at refs files from a trusted source.
        with open(ref_file, "rb") as ref_fh:
            self.data["refs"] = pickle.load(ref_fh)

        # Load images, annotations and categories.
        instances_file = osp.join(self.DATA_DIR, "instances.json")
        with open(instances_file, "rb") as instances_fh:
            instances = json.load(instances_fh)
        self.data["images"] = instances["images"]
        self.data["annotations"] = instances["annotations"]
        self.data["categories"] = instances["categories"]

        self.createIndex()
        print("DONE (t=%.2fs)" % (time.time() - tic))

    @staticmethod
    def _as_list(value):
        """Return ``value`` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    def createIndex(self):
        """Build the lookup tables from ``self.data``.

        Tables stored on the instance:
            Refs         ref_id      -> ref
            Anns         ann id      -> ann
            Imgs         image id    -> image
            Cats         category id -> category name
            Sents        sent_id     -> sentence
            imgToRefs    image_id    -> list of refs
            imgToAnns    image_id    -> list of anns
            refToAnn     ref_id      -> ann
            annToRef     ann_id      -> ref (the last ref seen for that ann)
            catToRefs    category_id -> list of refs
            sentToRef    sent_id     -> ref
            sentToTokens sent_id     -> tokens

        Raises:
            KeyError: if a ref names an ``ann_id`` missing from the annotations.
        """
        print("creating index...")

        # Index the instance data.  setdefault(...).append(...) grows each
        # group in place instead of copying the list on every insert.
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data["annotations"]:
            Anns[ann["id"]] = ann
            imgToAnns.setdefault(ann["image_id"], []).append(ann)
        for img in self.data["images"]:
            Imgs[img["id"]] = img
        for cat in self.data["categories"]:
            Cats[cat["id"]] = cat["name"]

        # Index the refs and their sentences.
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data["refs"]:
            ref_id = ref["ref_id"]
            ann_id = ref["ann_id"]
            category_id = ref["category_id"]
            image_id = ref["image_id"]

            Refs[ref_id] = ref
            imgToRefs.setdefault(image_id, []).append(ref)
            catToRefs.setdefault(category_id, []).append(ref)
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            for sent in ref["sentences"]:
                Sents[sent["sent_id"]] = sent
                sentToRef[sent["sent_id"]] = ref
                sentToTokens[sent["sent_id"]] = sent["tokens"]

        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print("index created.")

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
        """Return the ids of the refs that pass every non-empty filter.

        Args:
            image_ids: keep refs on these images; images without refs
                contribute nothing.
            cat_ids: keep refs whose ``category_id`` is listed.
            ref_ids: keep refs whose ``ref_id`` is listed.
            split: "" (no filter), "train", "val", "test" (any split whose
                name contains "test"), "testA"/"testB"/"testC" (any split
                whose name contains that final letter) or
                "testAB"/"testBC"/"testAC" (exact match).

        An unrecognised ``split`` prints a message and terminates the process
        through ``sys.exit()``.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if image_ids:
            # imgToRefs maps to *lists* of refs, so flatten them into one list.
            refs = list(
                itertools.chain.from_iterable(
                    self.imgToRefs.get(image_id, []) for image_id in image_ids
                )
            )
        else:
            refs = self.data["refs"]

        if cat_ids:
            wanted_cats = set(cat_ids)
            refs = [ref for ref in refs if ref["category_id"] in wanted_cats]
        if ref_ids:
            wanted_refs = set(ref_ids)
            refs = [ref for ref in refs if ref["ref_id"] in wanted_refs]

        if split:
            if split in ["testA", "testB", "testC"]:
                # Matching on the final letter also selects combined splits
                # such as "testAB".
                refs = [ref for ref in refs if split[-1] in ref["split"]]
            elif split in ["testAB", "testBC", "testAC"]:
                refs = [ref for ref in refs if ref["split"] == split]
            elif split == "test":
                refs = [ref for ref in refs if "test" in ref["split"]]
            elif split == "train" or split == "val":
                refs = [ref for ref in refs if ref["split"] == split]
            else:
                print("No such split [%s]" % split)
                sys.exit()

        return [ref["ref_id"] for ref in refs]

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        """Return the ids of the annotations that pass every non-empty filter.

        Args:
            image_ids: keep annotations on these images; unknown images
                contribute nothing.
            cat_ids: keep annotations whose ``category_id`` is listed.
            ref_ids: keep only annotations referred to by these refs.

        Raises:
            KeyError: if an entry of ``ref_ids`` is not a known ref id.
        """
        image_ids = self._as_list(image_ids)
        cat_ids = self._as_list(cat_ids)
        ref_ids = self._as_list(ref_ids)

        if image_ids:
            anns = list(
                itertools.chain.from_iterable(
                    self.imgToAnns.get(image_id, []) for image_id in image_ids
                )
            )
        else:
            anns = self.data["annotations"]

        if cat_ids:
            wanted_cats = set(cat_ids)
            anns = [ann for ann in anns if ann["category_id"] in wanted_cats]

        ann_ids = [ann["id"] for ann in anns]
        if ref_ids:
            # Apply the ref filter to the returned list, keeping its order.
            referred = {self.Refs[ref_id]["ann_id"] for ref_id in ref_ids}
            ann_ids = [ann_id for ann_id in ann_ids if ann_id in referred]
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        """Return image ids as a list.

        With ``ref_ids`` given, returns the distinct ``image_id`` values of
        those refs in first-seen order; otherwise every indexed image id.

        Raises:
            KeyError: if an entry of ``ref_ids`` is not a known ref id.
        """
        ref_ids = self._as_list(ref_ids)
        if ref_ids:
            # dict.fromkeys de-duplicates while keeping a deterministic order.
            return list(
                dict.fromkeys(self.Refs[ref_id]["image_id"] for ref_id in ref_ids)
            )
        # Return a real list (not a dict view) so the result can be fed
        # straight into loadImgs.
        return list(self.Imgs)

    def getCatIds(self):
        """Return every indexed category id as a list."""
        return list(self.Cats)

    def loadRefs(self, ref_ids=[]):
        """Return the refs for ``ref_ids`` (a list of ids or a single id).

        Raises:
            KeyError: if an id is unknown.
        """
        return [self.Refs[ref_id] for ref_id in self._as_list(ref_ids)]

    def loadAnns(self, ann_ids=[]):
        """Return the annotations for ``ann_ids`` (a list of ids or a single id).

        A single id may be of any key type used by the annotations, including
        ``str``.

        Raises:
            KeyError: if an id is unknown.
        """
        return [self.Anns[ann_id] for ann_id in self._as_list(ann_ids)]

    def loadImgs(self, image_ids=[]):
        """Return the image records for ``image_ids`` (a list or a single id).

        Raises:
            KeyError: if an id is unknown.
        """
        return [self.Imgs[image_id] for image_id in self._as_list(image_ids)]

    def loadCats(self, cat_ids=[]):
        """Return the category names for ``cat_ids`` (a list or a single id).

        Raises:
            KeyError: if an id is unknown.
        """
        return [self.Cats[cat_id] for cat_id in self._as_list(cat_ids)]

    def getRefBox(self, ref_id):
        """Return the ``bbox`` of the annotation that ``ref_id`` refers to.

        Raises:
            KeyError: if ``ref_id`` is unknown.
        """
        return self.refToAnn[ref_id]["bbox"]

    def showRef(self, ref, seg_box="seg"):
        """Draw the ref's image on the current axes and overlay its object.

        The ref's sentences are printed, numbered from 1.

        Args:
            ref: ref record; ``image_id``, ``ann_id``, ``ref_id`` and
                ``sentences`` are read.
            seg_box: "seg" overlays the annotation's segmentation, "box"
                overlays its bounding box; any other value draws only the image.
        """
        ax = plt.gca()

        # Show the image.
        image = self.Imgs[ref["image_id"]]
        I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
        ax.imshow(I)

        # Print the sentences.
        for sid, sent in enumerate(ref["sentences"]):
            print("%s. %s" % (sid + 1, sent["sent"]))

        if seg_box == "seg":
            ann = self.Anns[ref["ann_id"]]
            segmentation = ann["segmentation"]
            if isinstance(segmentation[0], list):
                # Polygon form: each entry is a flat list reshaped to (n, 2).
                # Integer division keeps the reshape dimension an int.
                polygons = [
                    Polygon(
                        np.array(seg).reshape((len(seg) // 2, 2)),
                        closed=True,
                        alpha=0.4,
                    )
                    for seg in segmentation
                ]
                color = ["none"] * len(polygons)
                # Two outlines of different width are stacked on each other.
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 1, 0, 0),
                    linewidths=3,
                    alpha=1,
                )
                ax.add_collection(p)
                p = PatchCollection(
                    polygons,
                    facecolors=color,
                    edgecolors=(1, 0, 0, 0),
                    linewidths=1,
                    alpha=1,
                )
                ax.add_collection(p)
            else:
                # Non-polygon form is handed to pycocotools for decoding.
                # NOTE(review): presumably run-length encoded — confirm.
                rle = segmentation
                m = mask.decode(rle)
                img = np.ones((m.shape[0], m.shape[1], 3))
                color_mask = np.array([2.0, 166.0, 101.0]) / 255
                for i in range(3):
                    img[:, :, i] = color_mask[i]
                # The decoded mask, scaled by 0.5, is stacked after the three
                # colour channels.
                ax.imshow(np.dstack((img, m * 0.5)))

        elif seg_box == "box":
            bbox = self.getRefBox(ref["ref_id"])
            box_plot = Rectangle(
                (bbox[0], bbox[1]),
                bbox[2],
                bbox[3],
                fill=False,
                edgecolor="green",
                linewidth=3,
            )
            ax.add_patch(box_plot)

    def getMask(self, ref):
        """Return ``{"mask": ..., "area": ...}`` for the ref's annotation.

        Polygon segmentations are first converted with ``mask.frPyObjects``
        using the image's height and width; anything else is passed to
        ``mask.decode`` as is.  The decoded array is summed over axis 2 and
        cast to ``uint8``; ``area`` is the sum of ``mask.area`` over the
        segments.
        """
        ann = self.refToAnn[ref["ref_id"]]
        image = self.Imgs[ref["image_id"]]
        if isinstance(ann["segmentation"][0], list):
            rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
        else:
            rle = ann["segmentation"]
        m = mask.decode(rle)
        # Collapse the per-segment layers into a single 2-D array.
        # NOTE(review): presumably each layer is binary, so overlapping
        # segments would sum to values above 1 — confirm.
        m = np.sum(m, axis=2)
        m = m.astype(np.uint8)

        area = sum(mask.area(rle))
        return {"mask": m, "area": area}

    def showMask(self, ref):
        """Display the ref's mask (from ``getMask``) on the current axes."""
        M = self.getMask(ref)
        msk = M["mask"]
        ax = plt.gca()
        ax.imshow(msk)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: load refcocog (google split), print some counts, then display
    # every training ref that has at least two sentences.
    # REFER requires a data root; take it from the first command-line
    # argument and fall back to "data" when none is given.
    data_root = sys.argv[1] if len(sys.argv) > 1 else "data"
    refer = REFER(data_root, dataset="refcocog", splitBy="google")
    ref_ids = refer.getRefIds()
    print(len(ref_ids))

    print(len(refer.Imgs))
    print(len(refer.imgToRefs))

    ref_ids = refer.getRefIds(split="train")
    print("There are %s training referred objects." % len(ref_ids))

    for ref_id in ref_ids:
        ref = refer.loadRefs(ref_id)[0]
        if len(ref["sentences"]) < 2:
            continue

        pprint(ref)
        print("The label is %s." % refer.Cats[ref["category_id"]])
        plt.figure()
        refer.showRef(ref, seg_box="box")
        plt.show()
|
|
|
|
|
|
|
|
|
|