File size: 3,271 Bytes
6cd90b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset wrapper used to preprocess referring-expression datasets.

Adapted from
https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
"""
# pylint: disable=all
from refer.refer import REFER
from torch.utils import data


class ReferDataset(data.Dataset):
  """Image-indexed view over one split of a REFER referring dataset.

  Each item corresponds to one image of the requested split and yields the
  image's file name together with the raw sentences of every reference
  attached to that image.

  Attributes:
    classes: Always an empty list here.
    image_transforms: Stored as given; not applied by this class.
    target_transforms: Stored as given; not applied by this class.
    split: Split name passed to ``REFER.getRefIds``.
    refer: The underlying ``REFER`` instance.
    imgs: Image records (taken from ``refer.Imgs``) for the split's images.
    ref_ids: Reference ids belonging to the split.
    sentence_raw: One list of raw sentences per entry of ``ref_ids``.
    eval_mode: Stored as given; not consulted by this class.
  """

  def __init__(
      self,
      root,
      dataset='refcoco',
      splitBy='unc',
      image_transforms=None,
      target_transforms=None,
      split='train',
      eval_mode=False,
  ):
    """Loads the references, images and sentences of one split.

    Args:
      root: Data root forwarded to ``REFER``.
      dataset: Dataset name forwarded to ``REFER``.
      splitBy: Split-scheme name forwarded to ``REFER``.
      image_transforms: Stored on the instance; unused here.
      target_transforms: Stored on the instance; unused here.
      split: Split whose references and images are loaded.
      eval_mode: Stored on the instance; unused here. Every sentence of
        every reference is collected regardless of this flag.
    """
    self.classes = []
    self.image_transforms = image_transforms
    self.target_transforms = target_transforms
    self.split = split
    self.refer = REFER(root, dataset=dataset, splitBy=splitBy)

    ref_ids = self.refer.getRefIds(split=self.split)
    img_ids = self.refer.getImgIds(ref_ids)

    all_imgs = self.refer.Imgs
    self.imgs = [all_imgs[i] for i in img_ids]
    self.ref_ids = ref_ids
    # Summary counts for the loaded split: references, then images.
    print(len(ref_ids))
    print(len(self.imgs))

    self.eval_mode = eval_mode
    # All sentences of every reference are kept; no sampling happens here.
    self.sentence_raw = [
        self._raw_sentences(self.refer.Refs[r]) for r in ref_ids
    ]

  @staticmethod
  def _raw_sentences(ref):
    """Returns the ``'raw'`` text of each sentence of a reference record.

    Sentences are paired with ``ref['sent_ids']`` through ``zip``, so the
    result is truncated to the shorter of the two lists.

    Args:
      ref: Reference record with ``'sentences'`` (dicts holding ``'raw'``)
        and ``'sent_ids'`` entries.

    Returns:
      List of raw sentence strings, in their stored order.
    """
    return [sent['raw'] for sent, _ in zip(ref['sentences'], ref['sent_ids'])]

  def get_classes(self):
    """Returns ``self.classes`` (an empty list unless set externally)."""
    return self.classes

  def __len__(self):
    """Returns the number of images in the split."""
    return len(self.imgs)

  def __getitem__(self, index):
    """Returns the file name and per-reference sentences of one image.

    Args:
      index: Position in ``self.imgs``.

    Returns:
      A tuple ``([file_name], sentences)``. ``sentences`` maps the
      ``'ref_id'`` of each reference attached to the image to its list of
      raw sentences.
    """
    this_img_id = self.imgs[index]['id']
    # NOTE(review): refs are looked up by image only, not filtered by
    # self.split; confirm images never span splits for the chosen splitBy.
    this_ref_ids = self.refer.getRefIds(this_img_id)
    this_img = self.refer.Imgs[this_img_id]

    batch_sentences = {}
    for ref_id in this_ref_ids:
      # loadRefs returns a list; the single record is its first element.
      ref = self.refer.loadRefs(ref_id)[0]
      batch_sentences[ref['ref_id']] = self._raw_sentences(ref)

    return [this_img['file_name']], batch_sentences

  def get_ref(self):
    """Prints diagnostic information about the split's image file names.

    For every reference id in ``self.ref_ids`` this prints the file name
    (preceded by ``1`` when the file name is empty). It then prints the total
    number of file names and the number of distinct file names.

    Returns:
      The list of file names, one per entry of ``self.ref_ids``.
    """
    name_lis = []
    for rid in self.ref_ids:
      file_name = self.refer.loadRefs(rid)[0]['file_name']
      if file_name == '':
        print(1)  # Marker for a reference whose file name is empty.
      name_lis.append(file_name)
      print(file_name)
    print(len(name_lis))
    print(len(set(name_lis)))
    return name_lis