myownskyW7 yuhangzang commited on
Commit
40e5f64
·
verified ·
1 Parent(s): f1c3bf7

Update modeling_internlm_xcomposer2.py (#4)

Browse files

- Update modeling_internlm_xcomposer2.py (0ea5472ac715dac914f1b44c46ea6c1f464d5550)


Co-authored-by: Yuhang Zang <[email protected]>

Files changed (1) hide show
  1. modeling_internlm_xcomposer2.py +82 -45
modeling_internlm_xcomposer2.py CHANGED
@@ -286,69 +286,93 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
286
  }
287
  return inputs, wrap_im_mask, temp_len
288
 
289
- def interleav_wrap(self, img_list, text_list):
290
- wrap_embeds_list, wrap_atts_list = [], []
291
- wrap_target_list, wrap_im_mask_list = [], []
 
292
 
293
- for image, text in zip(img_list, text_list):
294
- img_embeds, atts_img, img_target = self.img2emb(image)
295
- text = text[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  parts = text.split('<ImageHere>')
297
- wrap_tokens, wrap_embeds, wrap_atts, wrap_im_mask = [], [], [], []
 
298
  temp_len = 0
299
- image_nums, im_len = img_embeds.shape[:2]
300
  need_bos = True
301
  for idx, part in enumerate(parts):
302
  if len(part) > 0:
303
- part_tokens = self.tokenizer(
304
- part,
305
- return_tensors='pt',
306
- padding='longest',
307
- add_special_tokens=need_bos).to(self.device)
308
  if need_bos:
309
  need_bos = False
310
  wrap_tokens.append(part_tokens.input_ids)
311
- part_embeds = self.model.tok_embeddings(
312
- part_tokens.input_ids)
313
  wrap_embeds.append(part_embeds)
314
- wrap_atts.append(part_tokens.attention_mask)
315
- wrap_im_mask.append(
316
- torch.zeros(part_embeds.shape[:2]).to(self.device))
317
-
318
  temp_len += part_embeds.shape[1]
319
- if idx < image_nums:
320
- wrap_tokens.append(img_target[idx].unsqueeze(0))
321
- wrap_embeds.append(img_embeds[idx].unsqueeze(0))
322
- wrap_atts.append(atts_img[idx].unsqueeze(0))
323
- wrap_im_mask.append(
324
- torch.ones_like(atts_img[idx].unsqueeze(0)))
325
-
326
- temp_len += im_len
327
  if temp_len > self.max_length:
328
  break
329
-
330
  wrap_tokens = torch.cat(wrap_tokens, dim=1)
331
  wrap_embeds = torch.cat(wrap_embeds, dim=1)
332
- wrap_atts = torch.cat(wrap_atts, dim=1)
333
  wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
334
 
335
  wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
336
 
337
- wrap_embeds = wrap_embeds[:, :self.max_length].to(self.device)
338
- wrap_atts = wrap_atts[:, :self.max_length].to(self.device)
339
- wrap_target = wrap_target[:, :self.max_length].to(self.device)
340
- wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- wrap_embeds_list.append(wrap_embeds)
343
- wrap_atts_list.append(wrap_atts)
344
- wrap_target_list.append(wrap_target)
345
- wrap_im_mask_list.append(wrap_im_mask)
346
 
347
- wrap_embeds = torch.cat(wrap_embeds_list)
348
- wrap_atts = torch.cat(wrap_atts_list)
349
- wrap_target = torch.cat(wrap_target_list)
350
- wrap_im_mask = torch.cat(wrap_im_mask_list)
351
- return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
352
 
353
  def mask_human_targets(self, input_ids, pure=False):
354
  target_batch = []
@@ -415,12 +439,25 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
415
  text = samples['text_input']
416
  # encode image
417
  if has_img:
418
- image = samples['image']
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
420
- image, text)
421
  else:
422
  to_regress_tokens, targets = self.text2emb(
423
- text, add_special=True)
424
  to_regress_embeds = self.model.tok_embeddings(
425
  to_regress_tokens.input_ids)
426
  attention_mask = to_regress_tokens.attention_mask
 
286
  }
287
  return inputs, wrap_im_mask, temp_len
288
 
289
+ def interleav_wrap(self, img_list, text_list, image_nums):
290
+ temp_embeds = []
291
+ temp_im_mask = []
292
+ temp_tars = []
293
 
294
+ # encode_image
295
+ img_embeds, img_split = self.vit(img_list, self.plora_glb_GN, self.plora_sub_GN)
296
+ img_embeds = self.vision_proj(img_embeds)
297
+
298
+ text_list = text_list[0]
299
+ for idx, text in enumerate(text_list):
300
+ image_num = image_nums[idx]
301
+ im_id = int(np.sum(image_nums[:idx]))
302
+ images = []
303
+ for i in range(image_nums[idx]):
304
+ st = int(np.sum(img_split[:im_id + i]))
305
+ sp = img_split[im_id + i]
306
+ temp_img = img_embeds[:, st:st+sp]
307
+ images.append(temp_img)
308
+ atts_img = torch.ones((len(images), images[0].shape[1]), dtype=torch.long).to(self.device)
309
+ img_target = torch.ones(
310
+ (len(images), images[0].shape[1]), dtype=torch.long).to(
311
+ self.device) * -100
312
+
313
+ if image_num == 1 and text.find('<ImageHere>') == -1:
314
+ text = '<ImageHere>' + text
315
  parts = text.split('<ImageHere>')
316
+
317
+ wrap_tokens, wrap_embeds, wrap_im_mask = [], [], []
318
  temp_len = 0
 
319
  need_bos = True
320
  for idx, part in enumerate(parts):
321
  if len(part) > 0:
322
+ part_tokens = self.tokenizer(part, return_tensors='pt', padding='longest',
323
+ add_special_tokens=need_bos).to(self.device)
 
 
 
324
  if need_bos:
325
  need_bos = False
326
  wrap_tokens.append(part_tokens.input_ids)
327
+ part_embeds = self.model.tok_embeddings(part_tokens.input_ids)
 
328
  wrap_embeds.append(part_embeds)
329
+ wrap_im_mask.append(torch.zeros(part_embeds.shape[:2]).to(self.device))
 
 
 
330
  temp_len += part_embeds.shape[1]
331
+ if idx < image_num:
332
+ wrap_embeds.append(images[idx])
333
+ wrap_token = torch.ones(images[idx].shape[:2], dtype=torch.long).to(self.device) * -100
334
+ wrap_tokens.append(wrap_token)
335
+ wrap_im_mask.append(torch.ones(images[idx].shape[:2]).to(self.device))
336
+ temp_len += images[idx].shape[1]
 
 
337
  if temp_len > self.max_length:
338
  break
 
339
  wrap_tokens = torch.cat(wrap_tokens, dim=1)
340
  wrap_embeds = torch.cat(wrap_embeds, dim=1)
 
341
  wrap_im_mask = torch.cat(wrap_im_mask, dim=1)
342
 
343
  wrap_target = self.mask_human_targets(wrap_tokens).to(self.device)
344
 
345
+ temp_embeds.append(wrap_embeds)
346
+ temp_im_mask.append(wrap_im_mask)
347
+ temp_tars.append(wrap_target)
348
+
349
+ temp_max_len = np.max([i.shape[1] for i in temp_embeds])
350
+ temp_max_len = min(temp_max_len, self.max_length)
351
+
352
+ final_input, final_atts, final_tars, final_mask = [], [], [], []
353
+ pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
354
+ pad = pad.long().to(self.device)
355
+ pad_emb = self.model.tok_embeddings(pad)
356
+
357
+ for idx in range(len(temp_embeds)):
358
+ temp_len = temp_embeds[idx].shape[1]
359
+ if temp_len >= temp_max_len:
360
+ final_input.append(temp_embeds[idx][:, :temp_max_len])
361
+ final_atts.append(torch.ones(1, temp_max_len).to(wrap_target.dtype).to(self.device))
362
+ final_tars.append(temp_tars[idx][:, :temp_max_len])
363
+ final_mask.append(temp_im_mask[idx][:, :temp_max_len])
364
+ else:
365
+ final_input.append(torch.cat([temp_embeds[idx], pad_emb.repeat(1, temp_max_len-temp_len, 1)], dim=1))
366
+ final_atts.append(torch.cat([torch.ones(1, temp_len), torch.zeros(1, temp_max_len-temp_len)], dim=1).to(wrap_target.dtype).to(self.device))
367
+ final_tars.append(torch.cat([temp_tars[idx], (torch.ones(1, temp_max_len-temp_len)*-100).to(wrap_target.dtype).to(self.device)], dim=1))
368
+ final_mask.append(torch.cat([temp_im_mask[idx], (torch.zeros(1, temp_max_len-temp_len)).to(wrap_target.dtype).to(self.device)], dim=1))
369
 
370
+ inputs_embeds = torch.cat(final_input, dim=0)
371
+ attention_mask = torch.cat(final_atts, dim=0)
372
+ targets = torch.cat(final_tars, dim=0)
373
+ im_mask = torch.cat(final_mask, dim=0)
374
 
375
+ return inputs_embeds, attention_mask, targets, im_mask
 
 
 
 
376
 
377
  def mask_human_targets(self, input_ids, pure=False):
378
  target_batch = []
 
439
  text = samples['text_input']
440
  # encode image
441
  if has_img:
442
+ image = samples['image'][0]
443
+ bs = len(samples['text_input'][0])
444
+ image_nums = []
445
+ temp_image = []
446
+ for im in image:
447
+ if type(im) is list:
448
+ image_nums.append(len(im))
449
+ temp_image.extend(im)
450
+ else:
451
+ image_nums.append(1)
452
+ temp_image.append(im)
453
+ image = temp_image
454
+ assert type(image) is list and len(image_nums) == bs
455
+
456
  to_regress_embeds, attention_mask, targets, im_mask = self.interleav_wrap(
457
+ image, text, image_nums)
458
  else:
459
  to_regress_tokens, targets = self.text2emb(
460
+ text, add_special_tokens=True)
461
  to_regress_embeds = self.model.tok_embeddings(
462
  to_regress_tokens.input_ids)
463
  attention_mask = to_regress_tokens.attention_mask