Zevin2023 commited on
Commit
8834f72
1 Parent(s): eb97d8c

Update seagull/model/seagull_arch.py

Browse files
Files changed (1) hide show
  1. seagull/model/seagull_arch.py +11 -11
seagull/model/seagull_arch.py CHANGED
@@ -92,7 +92,7 @@ class SeagullMetaForCausalLM(ABC):
92
  image_features, image_features_dict = self.encode_images(images)
93
 
94
 
95
- mask_feats, pos_feats = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
96
 
97
  new_input_embeds = []
98
  new_labels = [] if labels is not None else None
@@ -151,10 +151,10 @@ class SeagullMetaForCausalLM(ABC):
151
  _l = 0
152
  for i, idx in enumerate(mask_idx):
153
  cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
154
- ## mask
155
- cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].detach())
156
- ## pos
157
- cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].detach())
158
  if labels is not None:
159
  cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
160
  _l = idx[0]+2
@@ -164,16 +164,16 @@ class SeagullMetaForCausalLM(ABC):
164
  else:
165
 
166
  mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
167
- assert len(mask_idx) == len(mask_feats[batch_idx]), "mask num not equal to mask feats"
168
 
169
  _l = 0
170
  for i, idx in enumerate(mask_idx):
171
  cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
172
  cur_new_input_embeds.append(cur_raw_new_input_embeds)
173
- ## mask
174
- cur_new_input_embeds.append(mask_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
175
- ## pos
176
- cur_new_input_embeds.append(pos_feats[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
177
 
178
  if labels is not None:
179
  cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
@@ -235,7 +235,7 @@ class SeagullMetaForCausalLM(ABC):
235
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
236
  self.resize_token_embeddings(len(tokenizer))
237
 
238
- mask_tokens = ['<global>', '<pos>']
239
  num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
240
 
241
  if model_args.mm_use_im_start_end:
 
92
  image_features, image_features_dict = self.encode_images(images)
93
 
94
 
95
+ global_features_, local_features_ = self.mask_extractor(image_features_dict, masks, cropped_img=cropped_img)
96
 
97
  new_input_embeds = []
98
  new_labels = [] if labels is not None else None
 
151
  _l = 0
152
  for i, idx in enumerate(mask_idx):
153
  cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[_l:idx[0]]).detach())
154
+ ## global
155
+ cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].detach())
156
+ ## local
157
+ cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].detach())
158
  if labels is not None:
159
  cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
160
  _l = idx[0]+2
 
164
  else:
165
 
166
  mask_idx = torch.nonzero(cur_input_ids==self.tokenizer.convert_tokens_to_ids(['<global>'])[0])
167
+ assert len(mask_idx) == len(global_features_[batch_idx]), "mask num not equal to mask feats"
168
 
169
  _l = 0
170
  for i, idx in enumerate(mask_idx):
171
  cur_raw_new_input_embeds = self.get_model().embed_tokens(cur_input_ids[_l:idx[0]])
172
  cur_new_input_embeds.append(cur_raw_new_input_embeds)
173
+ ## global
174
+ cur_new_input_embeds.append(global_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
175
+ ## local
176
+ cur_new_input_embeds.append(local_features_[batch_idx][i:i+1].to(cur_raw_new_input_embeds.dtype))
177
 
178
  if labels is not None:
179
  cur_labels[idx[0]:idx[0]+2] = torch.full((2,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
 
235
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
236
  self.resize_token_embeddings(len(tokenizer))
237
 
238
+ mask_tokens = ['<global>', '<local>']
239
  num_new_tokens = tokenizer.add_tokens(mask_tokens, special_tokens=True)
240
 
241
  if model_args.mm_use_im_start_end: