Update seagull/model/seagull_arch.py
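This commit splices region-level features into the language stream: the mask extractor now produces per-mask global and local features, which stand in for the new <global>/<local> placeholder tokens in the input embeddings, and those two tokens are registered with the tokenizer.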
seagull/model/seagull_arch.py (+11 -11)
@@ -92,7 +92,7 @@ class SeagullMetaForCausalLM(ABC):
         image_features, image_features_dict = self.encode_images(images)
 
 
-
+        global_features_, local_features_ = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
 
         new_input_embeds = []
         new_labels = [] if labels is not None else None
@@ -151,10 +151,10 @@ class SeagullMetaForCausalLM(ABC):
                 _l = 0
                 for i, idx in enumerate(mask_idx):
                     cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
-                    ##
-                    cur_new_input_embeds.append(
-                    ##
-                    cur_new_input_embeds.append(
+                    ## global
+                    cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].detach())
+                    ## local
+                    cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].detach())
                     if labels is not None:
                         cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
                     _l = idx[0]+2
@@ -164,16 +164,16 @@ class SeagullMetaForCausalLM(ABC):
             else:
 
                 mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
-                assert len(mask_idx) == len(
+                assert len(mask_idx) == len(global_features_[batch_idx]), "mask num not equal to mask feats"
 
                 _l = 0
                 for i, idx in enumerate(mask_idx):
                     cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
                     cur_new_input_embeds.append(cur_raw_new_input_embeds)
-                    ##
-                    cur_new_input_embeds.append(
-                    ##
-                    cur_new_input_embeds.append(
+                    ## global
+                    cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
+                    ## local
+                    cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
 
                     if labels is not None:
                         cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
@@ -235,7 +235,7 @@ class SeagullMetaForCausalLM(ABC):
         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
         self.resize_token_embeddings(len(tokenizer))
 
-        mask_tokens = ['<global>', '<
+        mask_tokens = ['<global>', '<local>']
         num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
 
         if model_args.mm_use_im_start_end:
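The splicing pattern in the second and third hunks is easier to see in isolation. Below is a minimal sketch of the else-branch logic for a single unbatched sequence; the function name, its signature, and the IGNORE_INDEX value of -100 are illustrative assumptions rather than names taken from the repository (the real code indexes features per batch element via global_features_[batch_idx], and the training branch additionally detaches the embeddings):

import torch

IGNORE_INDEX = -100  # assumed constant, in the style of LLaVA-derived code

def splice_region_features(embed_tokens, input_ids, global_feats, local_feats,
                           global_token_id, labels=None):
    # Every <global> token is assumed to be immediately followed by <local>;
    # the pair is replaced by one global feature row and one local feature row.
    mask_idx = torch.nonzero(input_ids == global_token_id)
    assert len(mask_idx) == len(global_feats), "mask num not equal to mask feats"

    pieces, _l = [], 0
    for i, idx in enumerate(mask_idx):
        text_embeds = embed_tokens(input_ids[_l:idx[0]])
        pieces.append(text_embeds)
        ## global
        pieces.append(global_feats[i:i+1].to(text_embeds.dtype))
        ## local
        pieces.append(local_feats[i:i+1].to(text_embeds.dtype))
        if labels is not None:
            # the two placeholder positions must never contribute to the loss
            labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX,
                                                 device=labels.device, dtype=labels.dtype)
        _l = idx[0] + 2  # skip past the <global><local> pair
    pieces.append(embed_tokens(input_ids[_l:]))  # trailing text after the last pair
    return torch.cat(pieces, dim=0), labels

The _l = idx[0] + 2 step encodes the assumption that <global> is always immediately followed by <local>, which is also why exactly two label positions are overwritten with IGNORE_INDEX.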
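The last hunk registers the two placeholders as special tokens. That part is plain Hugging Face tokenizer API; here is a standalone sketch with a placeholder checkpoint path (in the repository this happens inside the model's tokenizer-initialization routine, alongside DEFAULT_IMAGE_PATCH_TOKEN):

from transformers import AutoTokenizer, AutoModelForCausalLM

# "path/to/base-llm" is a placeholder, not the checkpoint Seagull uses.
tokenizer = AutoTokenizer.from_pretrained("path/to/base-llm")
model = AutoModelForCausalLM.from_pretrained("path/to/base-llm")

mask_tokens = ['<global>', '<local>']
num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)

# Grow the embedding matrix so the new token ids have rows to look up.
if num_new_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))

# This is the id the forward pass later searches for with torch.nonzero.
global_token_id = tokenizer.convert_tokens_to_ids(['<global>'])[0]

Resizing after add_tokens matters: without it, the new ids would index past the end of the embedding table.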