Update modeling_internlm_xcomposer2.py (#4)
Browse files- Update modeling_internlm_xcomposer2.py (0ea5472ac715dac914f1b44c46ea6c1f464d5550)
Co-authored-by: Yuhang Zang <[email protected]>
- modeling_internlm_xcomposer2.py +82 -45
modeling_internlm_xcomposer2.py
CHANGED
@@ -286,69 +286,93 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
286 |
}
|
287 |
return inputs, wrap_im_mask, temp_len
|
288 |
|
289 |
-
def interleav_wrap(self, img_list, text_list):
|
290 |
-
|
291 |
-
|
|
|
292 |
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
parts = text.split('<ImageHere>')
|
297 |
-
|
|
|
298 |
temp_len = 0
|
299 |
-
image_nums, im_len = img_embeds.shape[:2]
|
300 |
need_bos = True
|
301 |
for idx, part in enumerate(parts):
|
302 |
if len(part) > 0:
|
303 |
-
part_tokens = self.tokenizer(
|
304 |
-
|
305 |
-
return_tensors='pt',
|
306 |
-
padding='longest',
|
307 |
-
add_special_tokens=need_bos).to(self.device)
|
308 |
if need_bos:
|
309 |
need_bos = False
|
310 |
wrap_tokens.append(part_tokens.input_ids)
|
311 |
-
part_embeds = self.model.tok_embeddings(
|
312 |
-
part_tokens.input_ids)
|
313 |
wrap_embeds.append(part_embeds)
|
314 |
-
|
315 |
-
wrap_im_mask.append(
|
316 |
-
torch.zeros(part_embeds.shape[:2]).to(self.device))
|
317 |
-
|
318 |
temp_len += part_embeds.shape[1]
|
319 |
-
if idx <
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
wrap_im_mask.append(
|
324 |
-
|
325 |
-
|
326 |
-
temp_len += im_len
|
327 |
if temp_len > self.max_length:
|
328 |
break
|
329 |
-
|
330 |
wrap_tokens = torch.cat(wrap_tokens, dim=1)
|
331 |
wrap_embeds = torch.cat(wrap_embeds, dim=1)
|
332 |
-
wrap_atts = torch.cat(wrap_atts, dim=1)
|
333 |
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
|
334 |
|
335 |
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
|
336 |
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
|
347 |
-
|
348 |
-
wrap_atts = torch.cat(wrap_atts_list)
|
349 |
-
wrap_target = torch.cat(wrap_target_list)
|
350 |
-
wrap_im_mask = torch.cat(wrap_im_mask_list)
|
351 |
-
return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
|
352 |
|
353 |
def mask_human_targets(self, input_ids, pure=False):
|
354 |
target_batch = []
|
@@ -415,12 +439,25 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
|
|
415 |
text = samples['text_input']
|
416 |
# encode image
|
417 |
if has_img:
|
418 |
-
image = samples['image']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
|
420 |
-
image, text)
|
421 |
else:
|
422 |
to_regress_tokens, targets = self.text2emb(
|
423 |
-
text,
|
424 |
to_regress_embeds = self.model.tok_embeddings(
|
425 |
to_regress_tokens.input_ids)
|
426 |
attention_mask = to_regress_tokens.attention_mask
|
|
|
286 |
}
|
287 |
return inputs, wrap_im_mask, temp_len
|
288 |
|
289 |
+
def interleav_wrap(self, img_list, text_list, image_nums):
|
290 |
+
temp_embeds = []
|
291 |
+
temp_im_mask = []
|
292 |
+
temp_tars = []
|
293 |
|
294 |
+
# encode_image
|
295 |
+
img_embeds, img_split = self.vit(img_list, self.plora_glb_GN, self.plora_sub_GN)
|
296 |
+
img_embeds = self.vision_proj(img_embeds)
|
297 |
+
|
298 |
+
text_list = text_list[0]
|
299 |
+
for idx, text in enumerate(text_list):
|
300 |
+
image_num = image_nums[idx]
|
301 |
+
im_id = int(np.sum(image_nums[:idx]))
|
302 |
+
images = []
|
303 |
+
for i in range(image_nums[idx]):
|
304 |
+
st = int(np.sum(img_split[:im_id + i]))
|
305 |
+
sp = img_split[im_id + i]
|
306 |
+
temp_img = img_embeds[:, st:st+sp]
|
307 |
+
images.append(temp_img)
|
308 |
+
atts_img = torch.ones((len(images), images[0].shape[1]), dtype=torch.long).to(self.device)
|
309 |
+
img_target = torch.ones(
|
310 |
+
(len(images), images[0].shape[1]), dtype=torch.long).to(
|
311 |
+
self.device) * -100
|
312 |
+
|
313 |
+
if image_num == 1 and text.find('<ImageHere>') == -1:
|
314 |
+
text = '<ImageHere>' + text
|
315 |
parts = text.split('<ImageHere>')
|
316 |
+
|
317 |
+
wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
|
318 |
temp_len = 0
|
|
|
319 |
need_bos = True
|
320 |
for idx, part in enumerate(parts):
|
321 |
if len(part) > 0:
|
322 |
+
part_tokens = self.tokenizer(part, return_tensors='pt', padding='longest',
|
323 |
+
add_special_tokens=need_bos).to(self.device)
|
|
|
|
|
|
|
324 |
if need_bos:
|
325 |
need_bos = False
|
326 |
wrap_tokens.append(part_tokens.input_ids)
|
327 |
+
part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
|
|
|
328 |
wrap_embeds.append(part_embeds)
|
329 |
+
wrap_im_mask.append(torch.zeros(part_embeds.shape[:2]).to(self.device))
|
|
|
|
|
|
|
330 |
temp_len += part_embeds.shape[1]
|
331 |
+
if idx < image_num:
|
332 |
+
wrap_embeds.append(images[idx])
|
333 |
+
wrap_token = torch.ones(images[idx].shape[:2], dtype=torch.long).to(self.device) * -100
|
334 |
+
wrap_tokens.append(wrap_token)
|
335 |
+
wrap_im_mask.append(torch.ones(images[idx].shape[:2]).to(self.device))
|
336 |
+
temp_len += images[idx].shape[1]
|
|
|
|
|
337 |
if temp_len > self.max_length:
|
338 |
break
|
|
|
339 |
wrap_tokens = torch.cat(wrap_tokens, dim=1)
|
340 |
wrap_embeds = torch.cat(wrap_embeds, dim=1)
|
|
|
341 |
wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
|
342 |
|
343 |
wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
|
344 |
|
345 |
+
temp_embeds.append(wrap_embeds)
|
346 |
+
temp_im_mask.append(wrap_im_mask)
|
347 |
+
temp_tars.append(wrap_target)
|
348 |
+
|
349 |
+
temp_max_len = np.max([i.shape[1] for i in temp_embeds])
|
350 |
+
temp_max_len = min(temp_max_len, self.max_length)
|
351 |
+
|
352 |
+
final_input, final_atts, final_tars, final_mask = [], [], [], []
|
353 |
+
pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
|
354 |
+
pad = pad.long().to(self.device)
|
355 |
+
pad_emb = self.model.tok_embeddings(pad)
|
356 |
+
|
357 |
+
for idx in range(len(temp_embeds)):
|
358 |
+
temp_len = temp_embeds[idx].shape[1]
|
359 |
+
if temp_len >= temp_max_len:
|
360 |
+
final_input.append(temp_embeds[idx][:, :temp_max_len])
|
361 |
+
final_atts.append(torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device))
|
362 |
+
final_tars.append(temp_tars[idx][:, :temp_max_len])
|
363 |
+
final_mask.append(temp_im_mask[idx][:, :temp_max_len])
|
364 |
+
else:
|
365 |
+
final_input.append(torch.cat([temp_embeds[idx], pad_emb.repeat(1, temp_max_len-temp_len, 1)], dim=1))
|
366 |
+
final_atts.append(torch.cat([torch.ones(1, temp_len), torch.zeros(1, temp_max_len-temp_len)], dim=1).to(wrap_target.dtype).to(self.device))
|
367 |
+
final_tars.append(torch.cat([temp_tars[idx], (torch.ones(1, temp_max_len-temp_len)*-100).to(wrap_target.dtype).to(self.device)], dim=1))
|
368 |
+
final_mask.append(torch.cat([temp_im_mask[idx], (torch.zeros(1, temp_max_len-temp_len)).to(wrap_target.dtype).to(self.device)], dim=1))
|
369 |
|
370 |
+
inputs_embeds = torch.cat(final_input, dim=0)
|
371 |
+
attention_mask = torch.cat(final_atts, dim=0)
|
372 |
+
targets = torch.cat(final_tars, dim=0)
|
373 |
+
im_mask = torch.cat(final_mask, dim=0)
|
374 |
|
375 |
+
return inputs_embeds, attention_mask, targets, im_mask
|
|
|
|
|
|
|
|
|
376 |
|
377 |
def mask_human_targets(self, input_ids, pure=False):
|
378 |
target_batch = []
|
|
|
439 |
text = samples['text_input']
|
440 |
# encode image
|
441 |
if has_img:
|
442 |
+
image = samples['image'][0]
|
443 |
+
bs = len(samples['text_input'][0])
|
444 |
+
image_nums = []
|
445 |
+
temp_image = []
|
446 |
+
for im in image:
|
447 |
+
if type(im) is list:
|
448 |
+
image_nums.append(len(im))
|
449 |
+
temp_image.extend(im)
|
450 |
+
else:
|
451 |
+
image_nums.append(1)
|
452 |
+
temp_image.append(im)
|
453 |
+
image = temp_image
|
454 |
+
assert type(image) is list and len(image_nums) == bs
|
455 |
+
|
456 |
to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
|
457 |
+
image, text, image_nums)
|
458 |
else:
|
459 |
to_regress_tokens, targets = self.text2emb(
|
460 |
+
text, add_special_tokens=True)
|
461 |
to_regress_embeds = self.model.tok_embeddings(
|
462 |
to_regress_tokens.input_ids)
|
463 |
attention_mask = to_regress_tokens.attention_mask
|