x-lai committed
Commit 620ddd7 · Parent: 3dd44d9

Release training script

Former-commit-id: 4fc97979a3cbc5e07342bc87370a566bbf0d9855

utils/reason_seg_dataset.py CHANGED
@@ -59,10 +59,9 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
 
         if explanatory != -1:
-            self.img_to_why = {}
+            self.img_to_explanation = {}
             for sub_data in [
-                "20230711_2000_0_processed_masked_finished_masked.json",
-                "20230711_2000_0_processed_masked_partial_masked.json",
+                "train.json",
             ]:
                 with open(
                     os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data)
@@ -70,7 +69,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
                     items = json.load(f)
                     for item in items:
                         img_name = item["image_path"].split("/")[-1]
-                        self.img_to_why[img_name] = {
+                        self.img_to_explanation[img_name] = {
                             "query": item["query"],
                             "outputs": item["outputs"],
                         }
@@ -136,8 +135,8 @@ class ReasonSegDataset(torch.utils.data.Dataset):
 
         image_name = image_path.split("/")[-1]
         if (
-            self.explanatory != -1 and image_name in self.img_to_why
-        ):  # ds in ['20230711_2000_0_processed_masked_partial_masked', '20230711_2000_0_processed_masked_finished_masked', 'trainval_rephrased_20230730_checked_final_masked', 'rephrased_20230730_checked_final_masked']:
+            self.explanatory != -1 and image_name in self.img_to_explanation
+        ):
             if random.random() < self.explanatory:
                 choice = 2
             else:
@@ -145,7 +144,6 @@ class ReasonSegDataset(torch.utils.data.Dataset):
 
         questions = []
         answers = []
-        class_ids = []
         for text in sampled_sents:
             if is_sentence:
                 question_template = random.choice(self.long_question_list)
@@ -155,13 +153,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
             questions.append(question_template.format(class_name=text.lower()))
 
             img_name = image_path.split("/")[-1]
-            if self.explanatory != -1 and img_name in self.img_to_why:
+            if self.explanatory != -1 and img_name in self.img_to_explanation:
                 # choice = random.randint(0, 2)
                 if choice == 0:  # [SEG] token
                     answers.append(random.choice(self.answer_list))
                 elif choice == 1:  # [SEG] token + text answer
                     image_name = image_path.split("/")[-1]
-                    answer = self.img_to_why[image_name]["outputs"]
+                    answer = self.img_to_explanation[image_name]["outputs"]
                     answer = random.choice(self.answer_list) + " {}".format(answer)
                     questions[-1] = (
                         DEFAULT_IMAGE_TOKEN
@@ -172,7 +170,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
                     answers.append(answer)
                 elif choice == 2:  # vanilla text answer
                     image_name = image_path.split("/")[-1]
-                    answer = self.img_to_why[image_name]["outputs"]
+                    answer = self.img_to_explanation[image_name]["outputs"]
                     questions[-1] = DEFAULT_IMAGE_TOKEN + " " + text
                     answers.append(answer)
                 else:
@@ -192,7 +190,6 @@ class ReasonSegDataset(torch.utils.data.Dataset):
             conversations.append(conv.get_prompt())
             i += 1
 
-        # ==============================
         # replace <image> token
         for i in range(len(conversations)):
             replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
@@ -202,38 +199,18 @@ class ReasonSegDataset(torch.utils.data.Dataset):
             conversations[i] = conversations[i].replace(
                 DEFAULT_IMAGE_TOKEN, replace_token
             )
-        # ==============================
 
         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
 
         image_name = image_path.split("/")[-1]
-        if self.explanatory != -1 and image_name in self.img_to_why and choice == 2:
-            # print("e1")
-
+        if self.explanatory != -1 and image_name in self.img_to_explanation and choice == 2:
             masks = torch.rand(0, *ori_size)
             label = torch.ones(ori_size) * self.ignore_label
         else:
-            # print("e2")
-
             masks = np.stack(sampled_masks, axis=0)
             masks = torch.from_numpy(masks)
             label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
 
-        # print("reason_seg: {}".format(conversations))
-
-        # # debug
-        # if masks.shape[0] != 0:
-        #     save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        #     os.makedirs(save_dir, exist_ok=True)
-        #     print("masks.shape: ", masks.shape)
-        #     for i in range(masks.shape[0]):
-        #         cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i].numpy().astype(np.uint8)*100)
-        #     assert len(conversations) == masks.shape[0]
-        #     with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #         for i in range(len(conversations)):
-        #             f.write("{}. ".format(i) + conversations[i] + "\n")
-        #     shutil.copy(image_path, save_dir)
-
         return (
             image_path,
             images,
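The substance of this change: the two dated `*_masked.json` annotation files are consolidated into a single `train.json`, `img_to_why` is renamed to `img_to_explanation`, and leftover debug prints and `# ====` markers are dropped. For reference, a minimal sketch of the resulting data flow, assuming the JSON schema shown in the diff (`image_path`, `query`, `outputs`); the standalone-function form and the 0/1 fallback in the sampler are assumptions, since the `else:` branch is cut off at the hunk boundary:

```python
import json
import os
import random


def load_explanations(base_image_dir):
    # Map image file name -> explanatory annotation, mirroring the
    # img_to_explanation dict built in ReasonSegDataset.__init__.
    # Field names ("image_path", "query", "outputs") come from the diff.
    img_to_explanation = {}
    path = os.path.join(base_image_dir, "reason_seg", "explanatory", "train.json")
    with open(path) as f:
        for item in json.load(f):
            img_name = item["image_path"].split("/")[-1]
            img_to_explanation[img_name] = {
                "query": item["query"],
                "outputs": item["outputs"],
            }
    return img_to_explanation


def sample_answer_mode(explanatory_ratio):
    # With probability explanatory_ratio, pick the text-only explanation
    # (choice == 2); the fallback to a random 0/1 choice is an assumption,
    # as the diff truncates the else-branch.
    if random.random() < explanatory_ratio:
        return 2  # vanilla text answer, no mask supervision
    return random.randint(0, 1)  # 0: [SEG] token; 1: [SEG] token + text
```

Note the matching branch later in `__getitem__`: when `choice == 2` the sample returns `masks = torch.rand(0, *ori_size)`, a tensor with zero mask channels, so text-only explanation samples contribute no segmentation targets.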
utils/refer_seg_dataset.py CHANGED
@@ -63,7 +63,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             ref_ids_train = refer_api.getRefIds(split="train")
             images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
             refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
-            ref_file = os.path.join(DATA_DIR, ds, "refs(" + splitBy + ").p")
 
             refer_seg_ds = {}
             refer_seg_ds["images"] = []
@@ -149,7 +148,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             sampled_classes = sampled_sents
         img = cv2.imread(image_path)
         images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        ori_size = images.shape[:2]
 
         # preprocess images for clip
         images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
@@ -163,7 +161,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
 
         questions = []
         answers = []
-        class_ids = []
         for text in sampled_classes:
             text = text.strip()
             assert len(text.split("||")) == 1
@@ -183,7 +180,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             conversations.append(conv.get_prompt())
             i += 1
 
-        # ==============================
         # replace <image> token
         for i in range(len(conversations)):
             replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
@@ -193,7 +189,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             conversations[i] = conversations[i].replace(
                 DEFAULT_IMAGE_TOKEN, replace_token
             )
-        # ==============================
 
         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
 
@@ -223,42 +218,9 @@ class ReferSegDataset(torch.utils.data.Dataset):
                 masks.append(m)
 
         masks = np.stack(masks, axis=0)
-
-        # debug
-        # print("masks.shape: ", masks.shape)
-        # for i in range(masks.shape[0]):
-        #     cv2.imwrite("debug/{}_mask_{}.png".format(image_path.split("refer_seg/images")[-1].replace("/", "-").split(".")[0], sampled_sents[i]), masks[i]*100)
-
-        # debug
-        # if ds.endswith("masked"):
-        #     save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        #     os.makedirs(save_dir, exist_ok=True)
-        #     print("masks.shape: ", masks.shape)
-        #     for i in range(masks.shape[0]):
-        #         cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i]*100)
-        #     assert len(conversations) == masks.shape[0]
-        #     with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #         for i in range(len(conversations)):
-        #             f.write("{}. ".format(i) + conversations[i] + "\n")
-        #     shutil.copy(image_path, save_dir)
-
         masks = torch.from_numpy(masks)
         label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
 
-        # print("refer_seg: {}".format(conversations))
-
-        # # debug
-        # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        # os.makedirs(save_dir, exist_ok=True)
-        # print("masks.shape: ", masks.shape)
-        # for i in range(masks.shape[0]):
-        #     cv2.imwrite("{}/mask_{}_{}.jpg".format(save_dir, i, sampled_classes[i]), masks[i].numpy().astype(np.uint8)*100)
-        # assert len(conversations) == masks.shape[0]
-        # with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #     for i in range(len(conversations)):
-        #         f.write("{}. ".format(i) + conversations[i] + "\n")
-        # shutil.copy(image_path, save_dir)
-
         return (
             image_path,
             images,
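Both datasets share the same `<image>`-token expansion step (the deleted `# ====` lines only fenced it off visually). A self-contained sketch of that step; the token strings and the `use_im_start_end` wrapping, which falls in the gap between the two hunks above, are assumptions modeled on LLaVA-style preprocessing rather than values confirmed by this diff:

```python
# Assumed LLaVA-style constants; in the repo these would live in a shared
# constants/conversation module, not in the dataset files themselves.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def expand_image_tokens(conversations, image_token_len, use_im_start_end=True):
    # Replace each "<image>" placeholder with image_token_len patch tokens,
    # optionally wrapped in start/end markers, as in the loops above.
    replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
    if use_im_start_end:
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
    return [c.replace(DEFAULT_IMAGE_TOKEN, replace_token) for c in conversations]


# e.g. expand_image_tokens(["<image> What can the object be used for?"], 256)
# yields one prompt whose placeholder becomes 256 <im_patch> tokens wrapped
# in <im_start>/<im_end>.
```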