Spaces:
Paused
Paused
x-lai
committed on
Commit
·
e53daa9
1
Parent(s):
c39e06d
fix bug: add sample rate to control different sampling probability for each type of dataset
Browse files
- train_ds.py +2 -0
- utils/dataset.py +4 -3
train_ds.py
CHANGED
@@ -57,6 +57,7 @@ def parse_args(args):
|
|
57 |
parser.add_argument(
|
58 |
"--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
|
59 |
)
|
|
|
60 |
parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
|
61 |
parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
|
62 |
parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
|
@@ -153,6 +154,7 @@ def main(args):
|
|
153 |
num_classes_per_sample=args.num_classes_per_sample,
|
154 |
exclude_val=args.exclude_val,
|
155 |
dataset=args.dataset,
|
|
|
156 |
sem_seg_data=args.sem_seg_data,
|
157 |
refer_seg_data=args.refer_seg_data,
|
158 |
vqa_data=args.vqa_data,
|
|
|
57 |
parser.add_argument(
|
58 |
"--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
|
59 |
)
|
60 |
+
parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
|
61 |
parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
|
62 |
parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
|
63 |
parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
|
|
|
154 |
num_classes_per_sample=args.num_classes_per_sample,
|
155 |
exclude_val=args.exclude_val,
|
156 |
dataset=args.dataset,
|
157 |
+
sample_rate=[float(x) for x in args.sample_rates.split(",")],
|
158 |
sem_seg_data=args.sem_seg_data,
|
159 |
refer_seg_data=args.refer_seg_data,
|
160 |
vqa_data=args.vqa_data,
|
utils/dataset.py
CHANGED
@@ -152,6 +152,7 @@ class HybridDataset(torch.utils.data.Dataset):
|
|
152 |
num_classes_per_sample: int = 3,
|
153 |
exclude_val=False,
|
154 |
dataset="sem_seg||refer_seg||vqa||reason_seg",
|
|
|
155 |
sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
|
156 |
refer_seg_data="refclef||refcoco||refcoco+||refcocog",
|
157 |
vqa_data="llava_instruct_150k",
|
@@ -163,6 +164,8 @@ class HybridDataset(torch.utils.data.Dataset):
|
|
163 |
self.samples_per_epoch = samples_per_epoch
|
164 |
self.explanatory = explanatory
|
165 |
self.num_classes_per_sample = num_classes_per_sample
|
|
|
|
|
166 |
|
167 |
self.base_image_dir = base_image_dir
|
168 |
self.image_size = image_size
|
@@ -235,9 +238,7 @@ class HybridDataset(torch.utils.data.Dataset):
|
|
235 |
return self.samples_per_epoch
|
236 |
|
237 |
def __getitem__(self, idx):
|
238 |
-
ind =
|
239 |
-
self.datasets
|
240 |
-
) # random.randint(0, len(self.datasets)-1)
|
241 |
data = self.all_datasets[ind]
|
242 |
inference = False
|
243 |
return *data[0], inference
|
|
|
152 |
num_classes_per_sample: int = 3,
|
153 |
exclude_val=False,
|
154 |
dataset="sem_seg||refer_seg||vqa||reason_seg",
|
155 |
+
sample_rate=[9, 3, 3, 1],
|
156 |
sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
|
157 |
refer_seg_data="refclef||refcoco||refcoco+||refcocog",
|
158 |
vqa_data="llava_instruct_150k",
|
|
|
164 |
self.samples_per_epoch = samples_per_epoch
|
165 |
self.explanatory = explanatory
|
166 |
self.num_classes_per_sample = num_classes_per_sample
|
167 |
+
sample_rate = np.array(sample_rate)
|
168 |
+
self.sample_rate = sample_rate / sample_rate.sum()
|
169 |
|
170 |
self.base_image_dir = base_image_dir
|
171 |
self.image_size = image_size
|
|
|
238 |
return self.samples_per_epoch
|
239 |
|
240 |
def __getitem__(self, idx):
|
241 |
+
ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
|
|
|
|
|
242 |
data = self.all_datasets[ind]
|
243 |
inference = False
|
244 |
return *data[0], inference
|