ChengyouJia committed
Commit b6fba5d
1 Parent(s): 9a25537
Files changed (1)
  1. modeling_internvl_chat.py +154 -16
modeling_internvl_chat.py CHANGED
@@ -6,31 +6,46 @@
 import warnings
 from typing import Any, List, Optional, Tuple, Union

+import torch.distributed as dist
 import torch.utils.checkpoint
+import transformers
+from internvl.conversation import get_conv_template
+from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
+from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
+from peft import LoraConfig, get_peft_model
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
-                          LlamaTokenizer)
+                          LlamaTokenizer, Qwen2ForCausalLM)
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, logging

 from .configuration_internvl_chat import InternVLChatConfig
-from .conversation import get_conv_template
 from .modeling_intern_vit import InternVisionModel
-from .modeling_internlm2 import InternLM2ForCausalLM

 logger = logging.get_logger(__name__)


+def version_cmp(v1, v2, op='eq'):
+    import operator
+
+    from packaging import version
+    op_func = getattr(operator, op)
+    return op_func(version.parse(v1), version.parse(v2))
+
+
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
     main_input_name = 'pixel_values'
-    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']
+    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
+                         'Phi3DecoderLayer', 'Qwen2DecoderLayer']
+    _supports_flash_attn_2 = True

     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)

+        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -39,6 +54,7 @@ class InternVLChatModel(PreTrainedModel):
         self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
+        self.llm_arch_name = config.llm_config.architectures[0]

         logger.info(f'num_image_token: {self.num_image_token}')
         logger.info(f'ps_version: {self.ps_version}')
@@ -53,6 +69,10 @@ class InternVLChatModel(PreTrainedModel):
                 self.language_model = LlamaForCausalLM(config.llm_config)
             elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
                 self.language_model = InternLM2ForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'Phi3ForCausalLM':
+                self.language_model = Phi3ForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
+                self.language_model = Qwen2ForCausalLM(config.llm_config)
             else:
                 raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

@@ -67,6 +87,50 @@ class InternVLChatModel(PreTrainedModel):
         )

         self.img_context_token_id = None
+        self.conv_template = get_conv_template(self.template)
+        if hasattr(config, 'system_message'):
+            self.system_message = config.system_message
+        else:
+            self.system_message = self.conv_template.system_message
+        self.num_samples = 0
+
+        if config.use_backbone_lora:
+            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
+
+        if config.use_llm_lora:
+            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
+
+    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
+        self.vision_model = get_peft_model(self.vision_model, lora_config)
+        self.vision_model.print_trainable_parameters()
+
+    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+        # Determine the target modules based on the architecture of the language model
+        if self.llm_arch_name == 'InternLM2ForCausalLM':
+            target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
+        elif self.llm_arch_name == 'Phi3ForCausalLM':
+            target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
+        elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
+            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
+                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
+        else:
+            raise NotImplemented
+        lora_config = LoraConfig(
+            r=r,
+            target_modules=target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            task_type='CAUSAL_LM'
+        )
+        self.language_model = get_peft_model(self.language_model, lora_config)
+        self.language_model.enable_input_require_grads()
+        self.language_model.print_trainable_parameters()

     def forward(
         self,
@@ -85,7 +149,7 @@ class InternVLChatModel(PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         image_flags = image_flags.squeeze(-1)
-        input_embeds = self.language_model.get_input_embeddings()(input_ids)
+        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

         vit_embeds = self.extract_feature(pixel_values)
         vit_embeds = vit_embeds[image_flags == 1]
@@ -94,19 +158,21 @@ class InternVLChatModel(PreTrainedModel):
         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)

-        if torch.distributed.get_rank() == 0:
+        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
             print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

         input_ids = input_ids.reshape(B * N)
         selected = (input_ids == self.img_context_token_id)
         try:
             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
+            ignore_flag = False
         except Exception as e:
             vit_embeds = vit_embeds.reshape(-1, C)
             print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                   f'vit_embeds.shape={vit_embeds.shape}')
             n_token = selected.sum()
             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
+            ignore_flag = True

         input_embeds = input_embeds.reshape(B, N, C)

@@ -134,6 +200,8 @@ class InternVLChatModel(PreTrainedModel):
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
+            if ignore_flag:
+                loss = loss * 0.0

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -183,36 +251,44 @@ class InternVLChatModel(PreTrainedModel):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

-    def batch_chat(self, tokenizer, pixel_values, num_patches_list, questions, generation_config, history=None,
-                   return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
-                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
+    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
         if history is not None or return_history:
             print('Now multi-turn chat is not supported in batch_chat.')
             raise NotImplementedError
+
+        if image_counts is not None:
+            num_patches_list = image_counts
+            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
         img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
         self.img_context_token_id = img_context_token_id

-        from .conversation import get_conv_template
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')

         queries = []
-        if verbose:
-            image_bs = pixel_values.shape[0]
-            print(f'dynamic ViT batch size: {image_bs}, num_patches_list: {num_patches_list}')
         for idx, num_patches in enumerate(num_patches_list):
-            image_token = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-            question = image_token + '\n' + questions[idx]
+            question = questions[idx]
+            if pixel_values is not None and '<image>' not in question:
+                question = '<image>\n' + question
             template = get_conv_template(self.template)
             template.append_message(template.roles[0], question)
             template.append_message(template.roles[1], None)
             query = template.get_prompt()
+
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
+
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
         input_ids = model_inputs['input_ids'].cuda()
         attention_mask = model_inputs['attention_mask'].cuda()
         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
         generation_config['eos_token_id'] = eos_token_id
-
         generation_output = self.generate(
             pixel_values=pixel_values,
             input_ids=input_ids,
@@ -226,6 +302,13 @@ class InternVLChatModel(PreTrainedModel):
     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
              num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
              verbose=False):
+
+        # add
+        if generation_config.get('token_enable', False):
+            self.language_model.token_enable = True
+        else:
+            self.language_model.token_enable = False
+        generation_config.pop('token_enable', None)

         if history is None and pixel_values is not None and '<image>' not in question:
             question = '<image>\n' + question
@@ -238,6 +321,7 @@ class InternVLChatModel(PreTrainedModel):
         self.img_context_token_id = img_context_token_id

         template = get_conv_template(self.template)
+        template.system_message = self.system_message
         eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

         history = [] if history is None else history
@@ -278,6 +362,60 @@ class InternVLChatModel(PreTrainedModel):
                 print(query_to_print, response)
             return response

+    def return_top(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
+                   num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+                   verbose=False):
+
+        # add
+        if generation_config.get('token_enable', False):
+            self.language_model.token_enable = True
+        else:
+            self.language_model.token_enable = False
+        generation_config.pop('token_enable', None)
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+        if num_patches_list is None:
+            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self.template)
+        template.system_message = self.system_message
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+        history = [] if history is None else history
+        for (old_question, old_answer) in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        for num_patches in num_patches_list:
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+        model_inputs = tokenizer(query, return_tensors='pt')
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+
+        return generation_output
+
     @torch.no_grad()
     def generate(
         self,
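
For reference, a minimal usage sketch of the interface touched by this commit is given below. It is an illustration only: the repository id, the dummy pixel_values tensor, and the generation settings are placeholders and assumptions rather than part of the commit; only the chat/batch_chat signatures and the token_enable key consumed from generation_config come from the code above.

# Hypothetical usage sketch; '<repo-id>', the dummy pixel_values, and the generation
# settings are placeholders, not part of this commit.
import torch
from transformers import AutoModel, AutoTokenizer

path = '<repo-id>'  # placeholder for this model repository
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16,
                                  trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# pixel_values: stacked image tiles, e.g. [num_patches, 3, 448, 448]; a real call would
# use the usual InternVL dynamic preprocessing, which is assumed here and not shown.
pixel_values = torch.randn(7, 3, 448, 448, dtype=torch.bfloat16).cuda()

# `token_enable` is popped from generation_config by the patched chat()/return_top()
# and forwarded to the language model as `language_model.token_enable`.
generation_config = dict(max_new_tokens=512, do_sample=False, token_enable=True)
response = model.chat(tokenizer, pixel_values, 'Describe the screenshot.', generation_config)

# batch_chat() now inserts the '<image>' placeholder itself and expands it per sample
# according to num_patches_list (here 3 + 4 tiles across two samples).
responses = model.batch_chat(tokenizer, pixel_values,
                             questions=['Describe the screenshot.', 'Which button is highlighted?'],
                             generation_config=dict(max_new_tokens=512, do_sample=False),
                             num_patches_list=[3, 4])

The new return_top method follows the same prompt-building path as chat but returns the raw generation_output from generate instead of a decoded string, leaving decoding or any inspection of the generated ids to the caller.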