x-lai committed on
Commit e53daa9 · 1 Parent(s): c39e06d

fix bug: add sample rate to control different sampling probability for each type of dataset

Files changed (2)
  1. train_ds.py +2 -0
  2. utils/dataset.py +4 -3
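
In short, the commit replaces the pseudo-random modulo indexing in HybridDataset.__getitem__ with probability-weighted dataset selection driven by a new --sample_rates flag. A minimal standalone sketch of the idea (the names datasets, rates, and probs below are illustrative, not taken from the repository):

import numpy as np

# Illustrative stand-ins for the four dataset types mixed by HybridDataset.
datasets = ["sem_seg", "refer_seg", "vqa", "reason_seg"]

# Unnormalized rates, matching the new --sample_rates default "9,3,3,1".
rates = np.array([9.0, 3.0, 3.0, 1.0])
probs = rates / rates.sum()  # normalize so the weights sum to 1

# Each draw picks a dataset index with probability proportional to its rate.
ind = np.random.choice(len(datasets), p=probs)
print(datasets[ind])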
train_ds.py CHANGED
@@ -57,6 +57,7 @@ def parse_args(args):
     parser.add_argument(
         "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
     )
+    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
     parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
     parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
     parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
@@ -153,6 +154,7 @@ def main(args):
         num_classes_per_sample=args.num_classes_per_sample,
         exclude_val=args.exclude_val,
         dataset=args.dataset,
+        sample_rate=[float(x) for x in args.sample_rates.split(",")],
         sem_seg_data=args.sem_seg_data,
         refer_seg_data=args.refer_seg_data,
         vqa_data=args.vqa_data,
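
For context, this is roughly how the new flag flows into the dataset at the call site above; a hedged sketch of the parsing only, not the full main() (variable names below stand in for the argparse values):

# Parse --sample_rates the same way the hunk above does.
sample_rates = "9,3,3,1"  # stands in for args.sample_rates
sample_rate = [float(x) for x in sample_rates.split(",")]
print(sample_rate)  # [9.0, 3.0, 3.0, 1.0]

The resulting list is forwarded as HybridDataset(..., sample_rate=sample_rate, ...); presumably the i-th rate pairs with the i-th entry of --dataset ("sem_seg||refer_seg||vqa||reason_seg" by default).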
utils/dataset.py CHANGED
@@ -152,6 +152,7 @@ class HybridDataset(torch.utils.data.Dataset):
         num_classes_per_sample: int = 3,
         exclude_val=False,
         dataset="sem_seg||refer_seg||vqa||reason_seg",
+        sample_rate=[9, 3, 3, 1],
         sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
         refer_seg_data="refclef||refcoco||refcoco+||refcocog",
         vqa_data="llava_instruct_150k",
@@ -163,6 +164,8 @@ class HybridDataset(torch.utils.data.Dataset):
         self.samples_per_epoch = samples_per_epoch
         self.explanatory = explanatory
         self.num_classes_per_sample = num_classes_per_sample
+        sample_rate = np.array(sample_rate)
+        self.sample_rate = sample_rate / sample_rate.sum()
 
         self.base_image_dir = base_image_dir
         self.image_size = image_size
@@ -235,9 +238,7 @@ class HybridDataset(torch.utils.data.Dataset):
         return self.samples_per_epoch
 
     def __getitem__(self, idx):
-        ind = (random.randint(0, 2023) * (idx + 1)) % len(
-            self.datasets
-        )  # random.randint(0, len(self.datasets)-1)
+        ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
         data = self.all_datasets[ind]
         inference = False
         return *data[0], inference
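
Why this fixes a real bug: the removed expression (random.randint(0, 2023) * (idx + 1)) % len(self.datasets) ignores any desired mix and skews toward index 0 (whenever len(self.datasets) divides idx + 1 the result is 0 regardless of the random draw), while np.random.choice honors the normalized sample_rate. A small self-contained comparison, assuming four datasets as in the defaults:

import random
from collections import Counter

import numpy as np

n = 4  # number of mixed datasets with the default --dataset string

# Old selection: modular arithmetic on a random factor.
old_counts = Counter(
    (random.randint(0, 2023) * (idx + 1)) % n for idx in range(10000)
)

# New selection: explicit probabilities (the default 9:3:3:1 mix, normalized).
p = np.array([9, 3, 3, 1]) / 16
new_counts = Counter(np.random.choice(n, p=p) for _ in range(10000))

print(old_counts)  # index 0 is heavily over-represented
print(new_counts)  # roughly a 9:3:3:1 split across indices 0..3

Running this shows index 0 claiming roughly half of the old draws, while the new draws land near the intended 9:3:3:1 split.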