CLIP_as_RNN / data /preprocess.py
Kevin Sun
init commit
6cd90b7
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset loading for referring-expression datasets (via the REFER API).
Adapted from
https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
"""
# pylint: disable=all
import logging

from refer.refer import REFER
from torch.utils import data
class ReferDataset(data.Dataset):
  """Image-level dataset over one split of a REFER-loaded referring dataset.

  Each item corresponds to one image of the requested split and yields the
  image's file name together with the raw sentences of every ref in that
  image which belongs to the same split.

  Attributes:
    classes: always an empty list (returned by ``get_classes``).
    image_transforms: stored from the constructor; not applied here.
    target_transforms: stored from the constructor; not applied here.
    split: split name passed to ``REFER.getRefIds``.
    eval_mode: stored from the constructor; not used by this class.
    refer: the ``REFER`` instance backing this dataset.
    ref_ids: ids of all refs in ``split``.
    imgs: image records (values of ``refer.Imgs``) covered by ``ref_ids``.
    sentence_raw: for each id in ``ref_ids`` (same order), the list of that
      ref's raw sentences.
  """

  def __init__(
      self,
      root,
      dataset='refcoco',
      splitBy='unc',
      image_transforms=None,
      target_transforms=None,
      split='train',
      eval_mode=False,
  ):
    """Loads the REFER annotations and indexes the requested split.

    Args:
      root: directory passed through to ``REFER``.
      dataset: dataset name passed to ``REFER``.
      splitBy: split-scheme name passed to ``REFER``.
      image_transforms: stored on the instance; not applied by this class.
      target_transforms: stored on the instance; not applied by this class.
      split: split whose refs and images are loaded.
      eval_mode: stored on the instance; not used by this class.
    """
    self.classes = []
    self.image_transforms = image_transforms
    self.target_transforms = target_transforms
    self.split = split
    self.eval_mode = eval_mode

    self.refer = REFER(root, dataset=dataset, splitBy=splitBy)
    self.ref_ids = self.refer.getRefIds(split=self.split)
    # Set view of ref_ids so __getitem__ can cheaply drop refs that belong to
    # a different split when an image is shared between splits.
    self._split_ref_ids = set(self.ref_ids)
    img_ids = self.refer.getImgIds(self.ref_ids)
    self.imgs = [self.refer.Imgs[img_id] for img_id in img_ids]
    logging.getLogger(__name__).info(
        'ReferDataset(%s, split=%s): %d refs over %d images',
        dataset, split, len(self.ref_ids), len(self.imgs))

    # Every sentence of every ref is kept; eval_mode does not change this.
    self.sentence_raw = [
        self._raw_sentences(self.refer.Refs[ref_id])
        for ref_id in self.ref_ids
    ]

  @staticmethod
  def _raw_sentences(ref):
    """Returns the 'raw' text of each sentence of a REFER ref record."""
    # zip() with 'sent_ids' truncates to the shorter of the two lists
    # (they are normally the same length).
    return [sent['raw'] for sent, _ in zip(ref['sentences'], ref['sent_ids'])]

  def get_classes(self):
    """Returns the (always empty) class list."""
    return self.classes

  def __len__(self):
    """Returns the number of images in the split."""
    return len(self.imgs)

  def __getitem__(self, index):
    """Returns the file name and per-ref sentences for image ``index``.

    Args:
      index: position in ``self.imgs``.

    Returns:
      A pair ``([file_name], sentences)`` where ``sentences`` maps each ref id
      of this image that belongs to ``self.split`` to the list of its raw
      sentences, in the order returned by ``REFER.getRefIds``.
    """
    img = self.imgs[index]
    sentences = {
        ref_id: self._raw_sentences(self.refer.Refs[ref_id])
        for ref_id in self.refer.getRefIds(img['id'])
        # getRefIds(image_id) is not split-aware; skip refs of other splits.
        if ref_id in self._split_ref_ids
    }
    return [img['file_name']], sentences

  def get_ref(self):
    """Debug helper: prints every ref's file name plus duplicate statistics.

    Prints each ref's ``file_name`` (with a warning line for empty names),
    then the total number of names and the number of distinct names.

    Returns:
      The list of ``file_name`` values, one per entry of ``self.ref_ids``.
    """
    name_lis = []
    for rid in self.ref_ids:
      file_name = self.refer.Refs[rid]['file_name']
      if not file_name:
        print(f'ref {rid} has an empty file_name')
      name_lis.append(file_name)
      print(file_name)
    print(len(name_lis))
    print(len(set(name_lis)))
    return name_lis