x-lai committed
Commit 620ddd7 · 1 parent: 3dd44d9

Release training script

Former-commit-id: 4fc97979a3cbc5e07342bc87370a566bbf0d9855
- utils/reason_seg_dataset.py +9 -32
- utils/refer_seg_dataset.py +0 -38
utils/reason_seg_dataset.py
CHANGED
@@ -59,10 +59,9 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         self.explanatory_question_list = EXPLANATORY_QUESTION_LIST

         if explanatory != -1:
-            self.
+            self.img_to_explanation = {}
             for sub_data in [
-                "
-                "20230711_2000_0_processed_masked_partial_masked.json",
+                "train.json",
             ]:
                 with open(
                     os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data)
@@ -70,7 +69,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
                     items = json.load(f)
                 for item in items:
                     img_name = item["image_path"].split("/")[-1]
-                    self.
+                    self.img_to_explanation[img_name] = {
                         "query": item["query"],
                         "outputs": item["outputs"],
                     }
@@ -136,8 +135,8 @@ class ReasonSegDataset(torch.utils.data.Dataset):

         image_name = image_path.split("/")[-1]
         if (
-            self.explanatory != -1 and image_name in self.
-        ):
+            self.explanatory != -1 and image_name in self.img_to_explanation
+        ):
             if random.random() < self.explanatory:
                 choice = 2
             else:
@@ -145,7 +144,6 @@ class ReasonSegDataset(torch.utils.data.Dataset):

         questions = []
         answers = []
-        class_ids = []
         for text in sampled_sents:
             if is_sentence:
                 question_template = random.choice(self.long_question_list)
@@ -155,13 +153,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
                 questions.append(question_template.format(class_name=text.lower()))

             img_name = image_path.split("/")[-1]
-            if self.explanatory != -1 and img_name in self.
+            if self.explanatory != -1 and img_name in self.img_to_explanation:
                 # choice = random.randint(0, 2)
                 if choice == 0:  # [SEG] token
                     answers.append(random.choice(self.answer_list))
                 elif choice == 1:  # [SEG] token + text answer
                     image_name = image_path.split("/")[-1]
-                    answer = self.
+                    answer = self.img_to_explanation[image_name]["outputs"]
                     answer = random.choice(self.answer_list) + " {}".format(answer)
                     questions[-1] = (
                         DEFAULT_IMAGE_TOKEN
@@ -172,7 +170,7 @@ class ReasonSegDataset(torch.utils.data.Dataset):
                     answers.append(answer)
                 elif choice == 2:  # vanilla text answer
                     image_name = image_path.split("/")[-1]
-                    answer = self.
+                    answer = self.img_to_explanation[image_name]["outputs"]
                     questions[-1] = DEFAULT_IMAGE_TOKEN + " " + text
                     answers.append(answer)
                 else:
@@ -192,7 +190,6 @@ class ReasonSegDataset(torch.utils.data.Dataset):
            conversations.append(conv.get_prompt())
            i += 1

-        # ==============================
         # replace <image> token
         for i in range(len(conversations)):
             replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
@@ -202,38 +199,18 @@ class ReasonSegDataset(torch.utils.data.Dataset):
             conversations[i] = conversations[i].replace(
                 DEFAULT_IMAGE_TOKEN, replace_token
             )
-        # ==============================

         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())

         image_name = image_path.split("/")[-1]
-        if self.explanatory != -1 and image_name in self.
-            # print("e1")
-
+        if self.explanatory != -1 and image_name in self.img_to_explanation and choice == 2:
             masks = torch.rand(0, *ori_size)
             label = torch.ones(ori_size) * self.ignore_label
         else:
-            # print("e2")
-
             masks = np.stack(sampled_masks, axis=0)
             masks = torch.from_numpy(masks)
             label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

-        # print("reason_seg: {}".format(conversations))
-
-        # # debug
-        # if masks.shape[0] != 0:
-        #     save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        #     os.makedirs(save_dir, exist_ok=True)
-        #     print("masks.shape: ", masks.shape)
-        #     for i in range(masks.shape[0]):
-        #         cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i].numpy().astype(np.uint8)*100)
-        #     assert len(conversations) == masks.shape[0]
-        #     with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #         for i in range(len(conversations)):
-        #             f.write("{}. ".format(i) + conversations[i] + "\n")
-        #     shutil.copy(image_path, save_dir)
-
         return (
             image_path,
             images,
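In short, this commit points the explanatory branch of ReasonSegDataset at a single train.json file and stores the loaded records in an img_to_explanation dict keyed by image name; at sampling time, choice picks between a plain [SEG] answer, a [SEG] answer followed by the stored explanation text, and a text-only explanation, in which case the last hunk returns an empty mask tensor via torch.rand(0, *ori_size). The sketch below is a minimal standalone rendering of that loading and answer-selection logic, not the repository's actual module: the ANSWER_LIST placeholder, the function names, and the file layout are assumptions for illustration.

    import json
    import os
    import random

    # Hypothetical stand-in; the real answer templates live elsewhere in the codebase.
    ANSWER_LIST = ["It is [SEG].", "Sure, the segmentation result is [SEG]."]


    def load_explanations(base_image_dir, sub_files=("train.json",)):
        """Build an image-name -> {"query", "outputs"} mapping, mirroring the new __init__ code."""
        img_to_explanation = {}
        for sub_data in sub_files:
            path = os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data)
            with open(path) as f:
                items = json.load(f)
            for item in items:
                img_name = item["image_path"].split("/")[-1]
                img_to_explanation[img_name] = {
                    "query": item["query"],
                    "outputs": item["outputs"],
                }
        return img_to_explanation


    def build_answer(choice, explanation):
        """Three answer styles keyed on `choice`, as in the per-sentence loop above."""
        if choice == 0:    # [SEG] token only
            return random.choice(ANSWER_LIST)
        if choice == 1:    # [SEG] token + text explanation
            return random.choice(ANSWER_LIST) + " {}".format(explanation["outputs"])
        if choice == 2:    # vanilla text answer; no mask is supervised for this sample
            return explanation["outputs"]
        raise ValueError("Not implemented yet.")

Keying the mask target on choice == 2 is what lets one loader mix explanation-only samples with segmentation samples: when no mask is supervised, masks has shape (0, H, W) and the label map is filled with ignore_label.
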
utils/refer_seg_dataset.py
CHANGED
@@ -63,7 +63,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             ref_ids_train = refer_api.getRefIds(split="train")
             images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
             refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
-            ref_file = os.path.join(DATA_DIR, ds, "refs(" + splitBy + ").p")

             refer_seg_ds = {}
             refer_seg_ds["images"] = []
@@ -149,7 +148,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
         sampled_classes = sampled_sents
         img = cv2.imread(image_path)
         images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        ori_size = images.shape[:2]

         # preprocess images for clip
         images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
@@ -163,7 +161,6 @@ class ReferSegDataset(torch.utils.data.Dataset):

         questions = []
         answers = []
-        class_ids = []
         for text in sampled_classes:
             text = text.strip()
             assert len(text.split("||")) == 1
@@ -183,7 +180,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             conversations.append(conv.get_prompt())
             i += 1

-        # ==============================
         # replace <image> token
         for i in range(len(conversations)):
             replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
@@ -193,7 +189,6 @@ class ReferSegDataset(torch.utils.data.Dataset):
             conversations[i] = conversations[i].replace(
                 DEFAULT_IMAGE_TOKEN, replace_token
             )
-        # ==============================

         images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())

@@ -223,42 +218,9 @@ class ReferSegDataset(torch.utils.data.Dataset):
             masks.append(m)

         masks = np.stack(masks, axis=0)
-
-        # debug
-        # print("masks.shape: ", masks.shape)
-        # for i in range(masks.shape[0]):
-        #     cv2.imwrite("debug/{}_mask_{}.png".format(image_path.split("refer_seg/images")[-1].replace("/", "-").split(".")[0], sampled_sents[i]), masks[i]*100)
-
-        # debug
-        # if ds.endswith("masked"):
-        #     save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        #     os.makedirs(save_dir, exist_ok=True)
-        #     print("masks.shape: ", masks.shape)
-        #     for i in range(masks.shape[0]):
-        #         cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i]*100)
-        #     assert len(conversations) == masks.shape[0]
-        #     with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #         for i in range(len(conversations)):
-        #             f.write("{}. ".format(i) + conversations[i] + "\n")
-        #     shutil.copy(image_path, save_dir)
-
         masks = torch.from_numpy(masks)
         label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

-        # print("refer_seg: {}".format(conversations))
-
-        # # debug
-        # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
-        # os.makedirs(save_dir, exist_ok=True)
-        # print("masks.shape: ", masks.shape)
-        # for i in range(masks.shape[0]):
-        #     cv2.imwrite("{}/mask_{}_{}.jpg".format(save_dir, i, sampled_classes[i]), masks[i].numpy().astype(np.uint8)*100)
-        # assert len(conversations) == masks.shape[0]
-        # with open("{}/conversations.txt".format(save_dir), "w+") as f:
-        #     for i in range(len(conversations)):
-        #         f.write("{}. ".format(i) + conversations[i] + "\n")
-        # shutil.copy(image_path, save_dir)
-
         return (
             image_path,
             images,
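The refer_seg_dataset.py side of the commit adds nothing (+0 -38): it removes the ref_file, ori_size, and class_ids assignments, the # ============ divider comments, and the commented-out debug dumps. The mask handling kept in the final hunk is unchanged; for reference, here is a minimal self-contained sketch of that surviving pattern (stack per-referent binary masks into one array, convert to a tensor, and build an all-ignore label map). The dummy 4x4 masks and the ignore value of 255 are illustrative assumptions; the dataset uses its own self.ignore_label.

    import numpy as np
    import torch


    def stack_masks(mask_list, ignore_label=255):
        """Stack H x W binary masks into (N, H, W) and build an all-ignore label map,
        mirroring the lines kept at the end of ReferSegDataset.__getitem__."""
        masks = np.stack(mask_list, axis=0)      # (N, H, W)
        masks = torch.from_numpy(masks)
        label = torch.ones(masks.shape[1], masks.shape[2]) * ignore_label
        return masks, label


    # Tiny usage example with two dummy 4 x 4 masks.
    m1 = np.zeros((4, 4), dtype=np.uint8)
    m2 = np.ones((4, 4), dtype=np.uint8)
    masks, label = stack_masks([m1, m2])
    assert masks.shape == (2, 4, 4) and label.shape == (4, 4)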