|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Preprocessing utilities for referring-expression datasets.

Adapted from
https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
"""
|
|
|
from refer.refer import REFER |
|
from torch.utils import data |
|
|
|
|
|
class ReferDataset(data.Dataset):
    """Image-level dataset over one split of a referring-expression dataset.

    Annotations are loaded through ``REFER``. The images are those returned by
    ``REFER.getImgIds`` for the ref ids of ``split``. Each item is returned as
    ``([file_name], {ref_id: [raw_sentence, ...]})``.

    Args:
        root: Root argument forwarded to ``REFER``.
        dataset: Dataset name forwarded to ``REFER`` (default ``'refcoco'``).
        splitBy: Split convention forwarded to ``REFER`` (default ``'unc'``).
        image_transforms: Stored on the instance; not used by the methods here.
        target_transforms: Stored on the instance; not used by the methods here.
        split: Split whose ref ids select the images (default ``'train'``).
        eval_mode: Stored on the instance; not used by the methods here.
    """

    def __init__(
        self,
        root,
        dataset='refcoco',
        splitBy='unc',
        image_transforms=None,
        target_transforms=None,
        split='train',
        eval_mode=False,
    ):
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transforms = target_transforms
        self.split = split
        self.eval_mode = eval_mode
        self.refer = REFER(root, dataset=dataset, splitBy=splitBy)

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)

        # Image records for this split; __len__/__getitem__ index into this list.
        self.imgs = [self.refer.Imgs[img_id] for img_id in img_ids]
        self.ref_ids = ref_ids

        # Diagnostic output: number of refs in the split, then number of images.
        print(len(ref_ids))
        print(len(self.imgs))

        # One list of raw sentences per ref, aligned with self.ref_ids.
        self.sentence_raw = [
            self._raw_sentences(self.refer.Refs[ref_id]) for ref_id in ref_ids
        ]

    @staticmethod
    def _raw_sentences(ref):
        """Return the ``'raw'`` text of each sentence in a ref record.

        Sentences are paired with ``ref['sent_ids']``. ``zip`` stops at the
        shorter of the two lists, so a sentence without a matching id is
        skipped.
        """
        return [sent['raw'] for sent, _ in zip(ref['sentences'], ref['sent_ids'])]

    def get_classes(self):
        """Return the (currently always empty) list of class names."""
        return self.classes

    def __len__(self):
        """Return the number of images (not refs) in this split."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the file name and per-ref sentences for image ``index``.

        Returns:
            tuple: ``([file_name], sentences)``, where ``sentences`` maps each
            ref id on the image to its list of raw sentence strings.
        """
        img_id = self.imgs[index]['id']
        img = self.refer.Imgs[img_id]

        # NOTE(review): getRefIds is called with only the image id (no split
        # argument), so refs on this image are not filtered by self.split at
        # this call — confirm this is intended.
        batch_sentences = {}
        for ref_id in self.refer.getRefIds(img_id):
            ref = self.refer.loadRefs(ref_id)[0]
            batch_sentences[ref['ref_id']] = self._raw_sentences(ref)

        return [img['file_name']], batch_sentences

    def get_ref(self):
        """Print the file name of every ref in the split, plus summary counts.

        For each id in ``self.ref_ids``, this prints ``1`` if the ref's
        ``file_name`` is empty, then prints the ``file_name`` itself. It then
        prints the total number of names and the number of distinct names.

        Returns:
            list: The collected file names, in ``self.ref_ids`` order.
        """
        names = []
        for ref_id in self.ref_ids:
            file_name = self.refer.loadRefs(ref_id)[0]['file_name']
            if file_name == '':
                print(1)  # marker for a ref with an empty file name
            names.append(file_name)
            print(file_name)

        print(len(names))
        print(len(set(names)))
        return names
|
|