tt1225 committed on
Commit 284a833
1 parent: 96b7bf9

Update modeling_internvl_chat.py

Files changed (1)
  1. modeling_internvl_chat.py +352 -344
modeling_internvl_chat.py CHANGED
@@ -1,344 +1,352 @@
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-import warnings
-from typing import Any, List, Optional, Tuple, Union
-
-import torch.utils.checkpoint
-import transformers
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
-                          Qwen2ForCausalLM)
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import ModelOutput, logging
-
-from .configuration_internvl_chat import InternVLChatConfig
-from .conversation import get_conv_template
-from .modeling_intern_vit import InternVisionModel
-
-logger = logging.get_logger(__name__)
-
-
-def version_cmp(v1, v2, op='eq'):
-    import operator
-
-    from packaging import version
-    op_func = getattr(operator, op)
-    return op_func(version.parse(v1), version.parse(v2))
-
-
-class InternVLChatModel(PreTrainedModel):
-    config_class = InternVLChatConfig
-    main_input_name = 'pixel_values'
-    _supports_flash_attn_2 = True
-    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
-
-    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
-        super().__init__(config)
-
-        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
-        image_size = config.force_image_size or config.vision_config.image_size
-        patch_size = config.vision_config.patch_size
-        self.patch_size = patch_size
-        self.select_layer = config.select_layer
-        self.template = config.template
-        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
-        self.downsample_ratio = config.downsample_ratio
-        self.ps_version = config.ps_version
-
-        logger.info(f'num_image_token: {self.num_image_token}')
-        logger.info(f'ps_version: {self.ps_version}')
-        if vision_model is not None:
-            self.vision_model = vision_model
-        else:
-            self.vision_model = InternVisionModel(config.vision_config)
-        if language_model is not None:
-            self.language_model = language_model
-        else:
-            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
-                self.language_model = LlamaForCausalLM(config.llm_config)
-            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
-                self.language_model = Qwen2ForCausalLM(config.llm_config)
-            else:
-                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
-
-        vit_hidden_size = config.vision_config.hidden_size
-        llm_hidden_size = config.llm_config.hidden_size
-
-        self.mlp1 = nn.Sequential(
-            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
-            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
-            nn.GELU(),
-            nn.Linear(llm_hidden_size, llm_hidden_size)
-        )
-
-        self.img_context_token_id = None
-        self.conv_template = get_conv_template(self.template)
-        self.system_message = self.conv_template.system_message
-
-    def forward(
-            self,
-            pixel_values: torch.FloatTensor,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            image_flags: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        image_flags = image_flags.squeeze(-1)
-        input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
-        vit_embeds = self.extract_feature(pixel_values)
-        vit_embeds = vit_embeds[image_flags == 1]
-        vit_batch_size = pixel_values.shape[0]
-
-        B, N, C = input_embeds.shape
-        input_embeds = input_embeds.reshape(B * N, C)
-
-        if torch.distributed.get_rank() == 0:
-            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
-
-        input_ids = input_ids.reshape(B * N)
-        selected = (input_ids == self.img_context_token_id)
-        try:
-            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
-        except Exception as e:
-            vit_embeds = vit_embeds.reshape(-1, C)
-            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
-                  f'vit_embeds.shape={vit_embeds.shape}')
-            n_token = selected.sum()
-            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
-
-        input_embeds = input_embeds.reshape(B, N, C)
-
-        outputs = self.language_model(
-            inputs_embeds=input_embeds,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        logits = outputs.logits
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def pixel_shuffle(self, x, scale_factor=0.5):
-        n, w, h, c = x.size()
-        # N, W, H, C --> N, W, H * scale, C // scale
-        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
-        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
-        x = x.permute(0, 2, 1, 3).contiguous()
-        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
-        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
-                   int(c / (scale_factor * scale_factor)))
-        if self.ps_version == 'v1':
-            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
-                          'which results in a transposed image.')
-        else:
-            x = x.permute(0, 2, 1, 3).contiguous()
-        return x
-
-    def extract_feature(self, pixel_values):
-        if self.select_layer == -1:
-            vit_embeds = self.vision_model(
-                pixel_values=pixel_values,
-                output_hidden_states=False,
-                return_dict=True).last_hidden_state
-        else:
-            vit_embeds = self.vision_model(
-                pixel_values=pixel_values,
-                output_hidden_states=True,
-                return_dict=True).hidden_states[self.select_layer]
-        vit_embeds = vit_embeds[:, 1:, :]
-
-        h = w = int(vit_embeds.shape[1] ** 0.5)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
-        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
-        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
-        vit_embeds = self.mlp1(vit_embeds)
-        return vit_embeds
-
-    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
-                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
-                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
-        if history is not None or return_history:
-            print('Now multi-turn chat is not supported in batch_chat.')
-            raise NotImplementedError
-
-        if image_counts is not None:
-            num_patches_list = image_counts
-            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
-
-        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-        self.img_context_token_id = img_context_token_id
-
-        if verbose and pixel_values is not None:
-            image_bs = pixel_values.shape[0]
-            print(f'dynamic ViT batch size: {image_bs}')
-
-        queries = []
-        for idx, num_patches in enumerate(num_patches_list):
-            question = questions[idx]
-            if pixel_values is not None and '<image>' not in question:
-                question = '<image>\n' + question
-            template = get_conv_template(self.template)
-            template.append_message(template.roles[0], question)
-            template.append_message(template.roles[1], None)
-            query = template.get_prompt()
-
-            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-            query = query.replace('<image>', image_tokens, 1)
-            queries.append(query)
-
-        tokenizer.padding_side = 'left'
-        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-        input_ids = model_inputs['input_ids'].cuda()
-        attention_mask = model_inputs['attention_mask'].cuda()
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-        generation_config['eos_token_id'] = eos_token_id
-        generation_output = self.generate(
-            pixel_values=pixel_values,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            **generation_config
-        )
-        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
-        responses = [response.split(template.sep)[0].strip() for response in responses]
-        return responses
-
-    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
-             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
-             verbose=False):
-
-        if history is None and pixel_values is not None and '<image>' not in question:
-            question = '<image>\n' + question
-
-        if num_patches_list is None:
-            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
-
-        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-        self.img_context_token_id = img_context_token_id
-
-        template = get_conv_template(self.template)
-        template.system_message = self.system_message
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
-
-        history = [] if history is None else history
-        for (old_question, old_answer) in history:
-            template.append_message(template.roles[0], old_question)
-            template.append_message(template.roles[1], old_answer)
-        template.append_message(template.roles[0], question)
-        template.append_message(template.roles[1], None)
-        query = template.get_prompt()
-
-        if verbose and pixel_values is not None:
-            image_bs = pixel_values.shape[0]
-            print(f'dynamic ViT batch size: {image_bs}')
-
-        for num_patches in num_patches_list:
-            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-            query = query.replace('<image>', image_tokens, 1)
-
-        model_inputs = tokenizer(query, return_tensors='pt')
-        input_ids = model_inputs['input_ids'].cuda()
-        attention_mask = model_inputs['attention_mask'].cuda()
-        generation_config['eos_token_id'] = eos_token_id
-        generation_output = self.generate(
-            pixel_values=pixel_values,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            **generation_config
-        )
-        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-        response = response.split(template.sep)[0].strip()
-        history.append((question, response))
-        if return_history:
-            return response, history
-        else:
-            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
-            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
-            if verbose:
-                print(query_to_print, response)
-            return response
-
-    @torch.no_grad()
-    def generate(
-            self,
-            pixel_values: Optional[torch.FloatTensor] = None,
-            input_ids: Optional[torch.FloatTensor] = None,
-            attention_mask: Optional[torch.LongTensor] = None,
-            visual_features: Optional[torch.FloatTensor] = None,
-            generation_config: Optional[GenerationConfig] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-            **generate_kwargs,
-    ) -> torch.LongTensor:
-
-        assert self.img_context_token_id is not None
-        if pixel_values is not None:
-            if visual_features is not None:
-                vit_embeds = visual_features
-            else:
-                vit_embeds = self.extract_feature(pixel_values)
-            input_embeds = self.language_model.get_input_embeddings()(input_ids)
-            B, N, C = input_embeds.shape
-            input_embeds = input_embeds.reshape(B * N, C)
-
-            input_ids = input_ids.reshape(B * N)
-            selected = (input_ids == self.img_context_token_id)
-            assert selected.sum() != 0
-            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
-
-            input_embeds = input_embeds.reshape(B, N, C)
-        else:
-            input_embeds = self.language_model.get_input_embeddings()(input_ids)
-
-        outputs = self.language_model.generate(
-            inputs_embeds=input_embeds,
-            attention_mask=attention_mask,
-            generation_config=generation_config,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            use_cache=True,
-            **generate_kwargs,
-        )
-
-        return outputs
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import warnings
+from typing import Any, List, Optional, Tuple, Union
+
+import torch.utils.checkpoint
+import transformers
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
+                          Qwen2ForCausalLM)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput, logging
+
+from .configuration_internvl_chat import InternVLChatConfig
+from .conversation import get_conv_template
+from .modeling_intern_vit import InternVisionModel
+
+logger = logging.get_logger(__name__)
+
+
+def version_cmp(v1, v2, op='eq'):
+    import operator
+
+    from packaging import version
+    op_func = getattr(operator, op)
+    return op_func(version.parse(v1), version.parse(v2))
+
+
+class InternVLChatModel(PreTrainedModel):
+    config_class = InternVLChatConfig
+    main_input_name = 'pixel_values'
+    _supports_flash_attn_2 = True
+    _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
+
+    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
+        super().__init__(config)
+
+        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
+        image_size = config.force_image_size or config.vision_config.image_size
+        patch_size = config.vision_config.patch_size
+        self.patch_size = patch_size
+        self.select_layer = config.select_layer
+        self.template = config.template
+        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
+        self.downsample_ratio = config.downsample_ratio
+        self.ps_version = config.ps_version
+
+        logger.info(f'num_image_token: {self.num_image_token}')
+        logger.info(f'ps_version: {self.ps_version}')
+        if vision_model is not None:
+            self.vision_model = vision_model
+        else:
+            self.vision_model = InternVisionModel(config.vision_config)
+        if language_model is not None:
+            self.language_model = language_model
+        else:
+            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
+                self.language_model = LlamaForCausalLM(config.llm_config)
+            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
+                self.language_model = Qwen2ForCausalLM(config.llm_config)
+            else:
+                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
+
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.llm_config.hidden_size
+
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size)
+        )
+
+        self.img_context_token_id = None
+        self.conv_template = get_conv_template(self.template)
+        self.system_message = self.conv_template.system_message
+
+    def forward(
+            self,
+            pixel_values: torch.FloatTensor,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            image_flags: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        image_flags = image_flags.squeeze(-1)
+        input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        vit_embeds = self.extract_feature(pixel_values)
+        vit_embeds = vit_embeds[image_flags == 1]
+        vit_batch_size = pixel_values.shape[0]
+
+        B, N, C = input_embeds.shape
+        input_embeds = input_embeds.reshape(B * N, C)
+
+        if torch.distributed.get_rank() == 0:
+            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
+
+        input_ids = input_ids.reshape(B * N)
+        selected = (input_ids == self.img_context_token_id)
+        try:
+            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
+        except Exception as e:
+            vit_embeds = vit_embeds.reshape(-1, C)
+            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
+                  f'vit_embeds.shape={vit_embeds.shape}')
+            n_token = selected.sum()
+            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
+
+        input_embeds = input_embeds.reshape(B, N, C)
+
+        outputs = self.language_model(
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = outputs.logits
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def pixel_shuffle(self, x, scale_factor=0.5):
+        n, w, h, c = x.size()
+        # N, W, H, C --> N, W, H * scale, C // scale
+        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+        x = x.permute(0, 2, 1, 3).contiguous()
+        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+                   int(c / (scale_factor * scale_factor)))
+        if self.ps_version == 'v1':
+            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
+                          'which results in a transposed image.')
+        else:
+            x = x.permute(0, 2, 1, 3).contiguous()
+        return x
+
+    def extract_feature(self, pixel_values):
+        if self.select_layer == -1:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values,
+                output_hidden_states=False,
+                return_dict=True).last_hidden_state
+        else:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values,
+                output_hidden_states=True,
+                return_dict=True).hidden_states[self.select_layer]
+        vit_embeds = vit_embeds[:, 1:, :]
+
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+        vit_embeds = self.mlp1(vit_embeds)
+        return vit_embeds
+
+    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
+        if history is not None or return_history:
+            print('Now multi-turn chat is not supported in batch_chat.')
+            raise NotImplementedError
+
+        if image_counts is not None:
+            num_patches_list = image_counts
+            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        queries = []
+        for idx, num_patches in enumerate(num_patches_list):
+            question = questions[idx]
+            if pixel_values is not None and '<image>' not in question:
+                question = '<image>\n' + question
+            template = get_conv_template(self.template)
+            template.append_message(template.roles[0], question)
+            template.append_message(template.roles[1], None)
+            query = template.get_prompt()
+
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+            queries.append(query)
+
+        tokenizer.padding_side = 'left'
+        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+        responses = [response.split(template.sep)[0].strip() for response in responses]
+        return responses
+
+    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
+             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+             verbose=False):
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+        if num_patches_list is None:
+            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+
+        template = get_conv_template(self.template)
+        template.system_message = self.system_message
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+        history = [] if history is None else history
+        for (old_question, old_answer) in history:
+            template.append_message(template.roles[0], old_question)
+            template.append_message(template.roles[1], old_answer)
+        template.append_message(template.roles[0], question)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        if verbose and pixel_values is not None:
+            image_bs = pixel_values.shape[0]
+            print(f'dynamic ViT batch size: {image_bs}')
+
+        for num_patches in num_patches_list:
+            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+        print(self.num_image_token * num_patches)
+        print(query)
+
+        model_inputs = tokenizer(query, return_tensors='pt')
+        input_ids = model_inputs['input_ids'].cuda()
+        attention_mask = model_inputs['attention_mask'].cuda()
+
+        print(input_ids.shape)
+        print(attention_mask.shape)
+
+
+        generation_config['eos_token_id'] = eos_token_id
+        generation_output = self.generate(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **generation_config
+        )
+        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+        response = response.split(template.sep)[0].strip()
+        history.append((question, response))
+        if return_history:
+            return response, history
+        else:
+            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+            if verbose:
+                print(query_to_print, response)
+            return response

+    @torch.no_grad()
+    def generate(
+            self,
+            pixel_values: Optional[torch.FloatTensor] = None,
+            input_ids: Optional[torch.FloatTensor] = None,
+            attention_mask: Optional[torch.LongTensor] = None,
+            visual_features: Optional[torch.FloatTensor] = None,
+            generation_config: Optional[GenerationConfig] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **generate_kwargs,
+    ) -> torch.LongTensor:
+
+        assert self.img_context_token_id is not None
+        if pixel_values is not None:
+            if visual_features is not None:
+                vit_embeds = visual_features
+            else:
+                vit_embeds = self.extract_feature(pixel_values)
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+            B, N, C = input_embeds.shape
+            input_embeds = input_embeds.reshape(B * N, C)
+
+            input_ids = input_ids.reshape(B * N)
+            selected = (input_ids == self.img_context_token_id)
+            assert selected.sum() != 0
+            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+            input_embeds = input_embeds.reshape(B, N, C)
+        else:
+            input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        outputs = self.language_model.generate(
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=True,
+            **generate_kwargs,
+        )
+
+        return outputs
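
For readers who want to exercise the updated chat() entry point, here is a minimal usage sketch. It is not part of the commit: the checkpoint name is only an assumed example, and the random pixel_values tensor stands in for the image tiles normally produced by the repo's own dynamic preprocessing.

# Minimal usage sketch (assumptions: example checkpoint name; random tensor in
# place of real preprocessed image tiles of shape (num_patches, 3, 448, 448)).
import torch
from transformers import AutoModel, AutoTokenizer

path = 'OpenGVLab/InternVL2-8B'  # example checkpoint; substitute the repo you are using
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16).cuda()  # placeholder tiles
generation_config = dict(max_new_tokens=512, do_sample=False)

# chat() expands each '<image>' placeholder into IMG_START + num_image_token *
# num_patches IMG_CONTEXT tokens + IMG_END, tokenizes the prompt, and calls
# generate(); with this commit it additionally prints the expanded token count,
# the final query string, and the tokenized input shapes before generation.
response = model.chat(tokenizer, pixel_values, 'Describe the image.', generation_config)
print(response)

Passing verbose=True makes chat() also print the de-tokenized prompt (with the IMG_CONTEXT run collapsed back to '<image>') alongside the response.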