jadechoghari commited on
Commit
be195f0
·
verified ·
1 Parent(s): 54b2327

add files - PR (with the other config PR)- check description

Browse files

This PR allows anyone to easily load the model wiht transformers.
Instead of requiring users to manually clone the model and place it in the folder
Now anyone could easily use the model with transofmrers as such:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("Vision-CAIR/LongVU_Qwen2_7B", trust_remote_code=True)

```

try it out ! with "jadechoghari/LongVU_Qwen2_7B") instead :)

also linked in this issue: https://github.com/Vision-CAIR/LongVU/issues/5


will be updating the other model, however they're ready here: https://huggingface.co/models?sort=trending&search=LongVU+jad

cambrian_arch.py ADDED
@@ -0,0 +1,1712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import random
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ # define the constants
25
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
26
+ WORKER_HEART_BEAT_INTERVAL = 15
27
+
28
+ LOGDIR = "."
29
+
30
+ # Model Constants
31
+ IGNORE_INDEX = -100
32
+ IMAGE_TOKEN_INDEX = -200
33
+ DEFAULT_IMAGE_TOKEN = "<image>"
34
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
35
+ DEFAULT_IM_START_TOKEN = "<im_start>"
36
+ DEFAULT_IM_END_TOKEN = "<im_end>"
37
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
38
+
39
+ from .multimodal_encoder_builder import build_vision_tower_aux_list
40
+ from .multimodal_projector_builder import build_vision_projector
41
+ from .vision_sampler import VisionTokenSampler
42
+
43
+ IS_XLA_AVAILABLE = False
44
+
45
+
46
+ class CambrianMetaModel:
47
+
48
+ def __init__(self, config):
49
+ super(CambrianMetaModel, self).__init__(config)
50
+
51
+ if hasattr(config, "mm_vision_tower_aux_list"):
52
+
53
+ projector_type = getattr(config, "mm_projector_type", "linear")
54
+ if projector_type == "sva":
55
+
56
+ vision_hidden_size = config.vision_hidden_size
57
+ num_query_group = config.num_query_group
58
+ query_num_list = config.query_num_list
59
+ connector_only = config.connector_only
60
+ connector_depth = config.connector_depth
61
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
62
+ config, delay_load=True
63
+ )
64
+ self.mm_projector = nn.Sequential(
65
+ nn.Linear(vision_hidden_size * num_query_group, config.hidden_size),
66
+ nn.GELU(),
67
+ nn.Linear(config.hidden_size, config.hidden_size),
68
+ )
69
+
70
+ image_token_len = config.image_token_len
71
+ vision_tower_aux_token_len_list = (
72
+ self.config.mm_vision_tower_aux_token_len_list
73
+ )
74
+ cross_att_token_len_list = [
75
+ int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5)
76
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
77
+ ]
78
+
79
+ for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list):
80
+ setattr(
81
+ self,
82
+ "mm_projector_aux_{}".format(aux_i),
83
+ nn.Sequential(
84
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
85
+ nn.GELU(),
86
+ nn.Linear(vision_hidden_size, vision_hidden_size),
87
+ nn.LayerNorm(vision_hidden_size),
88
+ ),
89
+ )
90
+
91
+ for query_group_i in range(num_query_group):
92
+ cross_att_token_len_list = [
93
+ int(vision_tower_aux_token_len**0.5)
94
+ // int(query_num_list[query_group_i] ** 0.5)
95
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
96
+ ]
97
+ setattr(
98
+ self,
99
+ "vision_sampler_{}".format(query_group_i),
100
+ VisionTokenSampler(
101
+ vision_hidden_size,
102
+ vision_hidden_size,
103
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
104
+ cross_att_token_len_list,
105
+ vision_hidden_size,
106
+ connector_depth,
107
+ ),
108
+ )
109
+
110
+ if not connector_only:
111
+ num_of_vision_sampler_layers = (
112
+ config.num_of_vision_sampler_layers
113
+ ) = config.num_of_vision_sampler_layers
114
+ config.start_of_vision_sampler_layers = (
115
+ config.start_of_vision_sampler_layers
116
+ )
117
+ config.stride_of_vision_sampler_layers = (
118
+ config.stride_of_vision_sampler_layers
119
+ )
120
+ cross_att_token_len_list = [
121
+ int(vision_tower_aux_token_len**0.5)
122
+ // int(image_token_len**0.5)
123
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
124
+ ]
125
+ self.vision_sampler_layers = nn.ModuleList(
126
+ [
127
+ VisionTokenSampler(
128
+ config.hidden_size,
129
+ vision_hidden_size,
130
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
131
+ cross_att_token_len_list,
132
+ vision_hidden_size,
133
+ 1,
134
+ )
135
+ for layer_idx in range(0, num_of_vision_sampler_layers)
136
+ ]
137
+ )
138
+
139
+ self.vision_query = nn.Parameter(
140
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
141
+ )
142
+
143
+ self.image_newline = nn.Parameter(
144
+ torch.empty(config.hidden_size, dtype=self.dtype)
145
+ )
146
+
147
+ self.frame_pos = torch.stack(
148
+ [
149
+ 1
150
+ / torch.pow(
151
+ torch.tensor(10000),
152
+ torch.tensor(2 * (hid_j // 2) / config.hidden_size),
153
+ )
154
+ for hid_j in range(config.hidden_size)
155
+ ]
156
+ )
157
+
158
+ else:
159
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
160
+ config, delay_load=True
161
+ )
162
+ config.mm_hidden_size = sum(
163
+ [
164
+ vision_tower_aux.hidden_size
165
+ for vision_tower_aux in self.vision_tower_aux_list
166
+ ]
167
+ )
168
+ self.mm_projector = build_vision_projector(config)
169
+ self.image_newline = nn.Parameter(
170
+ torch.empty(config.hidden_size, dtype=self.dtype)
171
+ )
172
+
173
+ def get_frame_pos(self, time_range):
174
+ frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
175
+ self.frame_pos.device
176
+ )
177
+ frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
178
+ frame_pos[:, 1::2] = torch.cos(frame_pos[:, 0::2])
179
+ frame_pos = frame_pos.unsqueeze(1)
180
+ return frame_pos
181
+
182
+ # def get_vision_tower(self):
183
+ # vision_tower = getattr(self, 'vision_tower', None)
184
+ # if type(vision_tower) is list:
185
+ # vision_tower = vision_tower[0]
186
+ # return vision_tower
187
+
188
+ def get_vision_tower_aux_list(self):
189
+ vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
190
+ return vision_tower_aux_list
191
+
192
+ def initialize_vision_modules(self, model_args, fsdp=None):
193
+ # vision_tower = model_args.vision_tower
194
+ num_query_group = model_args.num_query_group
195
+ query_num_list = model_args.query_num_list
196
+ vision_hidden_size = model_args.vision_hidden_size
197
+ vision_tower_aux_list = model_args.vision_tower_aux_list
198
+ vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
199
+ image_token_len = model_args.image_token_len
200
+ mm_vision_select_layer = model_args.mm_vision_select_layer
201
+ mm_vision_select_feature = model_args.mm_vision_select_feature
202
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
203
+ connector_only = model_args.connector_only
204
+ connector_depth = model_args.connector_depth
205
+
206
+ # self.config.mm_vision_tower = vision_tower
207
+ self.config.image_token_len = image_token_len
208
+ self.config.num_query_group = num_query_group
209
+ self.config.query_num_list = query_num_list
210
+ assert num_query_group == len(query_num_list)
211
+ self.config.connector_depth = connector_depth
212
+ self.config.mm_vision_tower_aux_list = vision_tower_aux_list
213
+ self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
214
+ self.config.connector_only = connector_only
215
+ self.config.highres_connect = model_args.highres_connect
216
+ self.config.highres = model_args.highres
217
+ self.config.frame_pos = model_args.frame_pos
218
+ self.config.lowres_token = model_args.lowres_token
219
+ self.config.connect_layer = model_args.connect_layer
220
+ self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83)
221
+ self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6)
222
+ self.config.is_image_newline = getattr(model_args, "is_image_newline", True)
223
+
224
+ if self.get_vision_tower_aux_list() is None:
225
+ vision_tower_aux_list = build_vision_tower_aux_list(model_args)
226
+ if model_args.unfreeze_mm_vision_tower:
227
+ self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
228
+ else:
229
+ self.vision_tower_aux_list = vision_tower_aux_list
230
+ else:
231
+ vision_tower_aux_list = self.vision_tower_aux_list
232
+ for vision_tower_aux in vision_tower_aux_list:
233
+ vision_tower_aux.load_model()
234
+
235
+ self.config.use_mm_proj = True
236
+ self.config.mm_projector_type = getattr(
237
+ model_args, "mm_projector_type", "linear"
238
+ )
239
+ self.config.vision_hidden_size = vision_hidden_size
240
+ self.config.mm_vision_select_layer = mm_vision_select_layer
241
+ self.config.mm_vision_select_feature = mm_vision_select_feature
242
+
243
+ if getattr(self, "mm_projector", None) is None:
244
+
245
+ if self.config.mm_projector_type == "sva":
246
+ self.mm_projector = nn.Sequential(
247
+ nn.Linear(
248
+ vision_hidden_size * num_query_group, self.config.hidden_size
249
+ ),
250
+ nn.GELU(),
251
+ nn.Linear(self.config.hidden_size, self.config.hidden_size),
252
+ )
253
+ for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list):
254
+ setattr(
255
+ self,
256
+ "mm_projector_aux_{}".format(aux_i),
257
+ nn.Sequential(
258
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
259
+ nn.GELU(),
260
+ nn.Linear(vision_hidden_size, vision_hidden_size),
261
+ nn.LayerNorm(vision_hidden_size),
262
+ ),
263
+ )
264
+
265
+ # vision sampler for each group of query as the connector before the LLM
266
+ for query_group_i in range(num_query_group):
267
+ cross_att_token_len_list = [
268
+ int(vision_tower_aux_token_len**0.5)
269
+ // int(query_num_list[query_group_i] ** 0.5)
270
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
271
+ ]
272
+ setattr(
273
+ self,
274
+ "vision_sampler_{}".format(query_group_i),
275
+ VisionTokenSampler(
276
+ vision_hidden_size,
277
+ vision_hidden_size,
278
+ [vision_hidden_size] * len(vision_tower_aux_list),
279
+ cross_att_token_len_list,
280
+ vision_hidden_size,
281
+ connector_depth,
282
+ ),
283
+ )
284
+
285
+ # sampler layers within LLM
286
+ if not connector_only:
287
+ num_of_vision_sampler_layers = (
288
+ self.config.num_of_vision_sampler_layers
289
+ ) = model_args.num_of_vision_sampler_layers
290
+ self.config.start_of_vision_sampler_layers = (
291
+ model_args.start_of_vision_sampler_layers
292
+ )
293
+ self.config.stride_of_vision_sampler_layers = (
294
+ model_args.stride_of_vision_sampler_layers
295
+ )
296
+ cross_att_token_len_list = [
297
+ int(vision_tower_aux_token_len**0.5)
298
+ // int(image_token_len**0.5)
299
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
300
+ ]
301
+ self.vision_sampler_layers = nn.ModuleList(
302
+ [
303
+ VisionTokenSampler(
304
+ self.config.hidden_size,
305
+ vision_hidden_size,
306
+ [vision_hidden_size] * len(vision_tower_aux_list),
307
+ cross_att_token_len_list,
308
+ vision_hidden_size,
309
+ 1,
310
+ )
311
+ for layer_idx in range(0, num_of_vision_sampler_layers)
312
+ ]
313
+ )
314
+ vision_embed_std = 1 / torch.sqrt(
315
+ torch.tensor(vision_hidden_size, dtype=self.dtype)
316
+ )
317
+ self.vision_query = nn.Parameter(
318
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
319
+ * vision_embed_std
320
+ )
321
+
322
+ embed_std = 1 / torch.sqrt(
323
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
324
+ )
325
+ self.image_newline = nn.Parameter(
326
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
327
+ )
328
+
329
+ else:
330
+ self.config.mm_hidden_size = sum(
331
+ [
332
+ vision_tower_aux.hidden_size
333
+ for vision_tower_aux in vision_tower_aux_list
334
+ ]
335
+ )
336
+ self.mm_projector = build_vision_projector(self.config)
337
+ embed_std = 1 / torch.sqrt(
338
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
339
+ )
340
+ self.image_newline = nn.Parameter(
341
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
342
+ )
343
+ else:
344
+ # In case it is frozen by LoRA
345
+ for p in self.mm_projector.parameters():
346
+ p.requires_grad = True
347
+
348
+ if pretrain_mm_mlp_adapter is not None:
349
+ mm_projector_weights = torch.load(
350
+ pretrain_mm_mlp_adapter, map_location="cpu"
351
+ )
352
+
353
+ def get_w(weights, keyword):
354
+ return {
355
+ k.split(keyword + ".")[1]: v
356
+ for k, v in weights.items()
357
+ if keyword + "." in k
358
+ }
359
+
360
+ self.mm_projector.load_state_dict(
361
+ get_w(mm_projector_weights, "mm_projector"), strict=True
362
+ )
363
+
364
+ if self.config.mm_projector_type == "sva":
365
+ for aux_i in range(len(vision_tower_aux_list)):
366
+ getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict(
367
+ get_w(
368
+ mm_projector_weights, "mm_projector_aux_{}".format(aux_i)
369
+ ),
370
+ strict=True,
371
+ )
372
+
373
+ for query_group_i in range(num_query_group):
374
+ getattr(
375
+ self, "vision_sampler_{}".format(query_group_i)
376
+ ).load_state_dict(
377
+ get_w(
378
+ mm_projector_weights,
379
+ "vision_sampler_{}".format(query_group_i),
380
+ ),
381
+ strict=True,
382
+ )
383
+
384
+ if not connector_only:
385
+ self.vision_sampler_layers.load_state_dict(
386
+ get_w(mm_projector_weights, "vision_sampler_layers"),
387
+ strict=True,
388
+ )
389
+ self.vision_query.data = mm_projector_weights["model.vision_query"]
390
+ self.image_newline.data = mm_projector_weights["model.image_newline"]
391
+
392
+
393
+ def unmask_attention_mask(mask, original_size):
394
+ original_w, original_h = original_size
395
+ cur_h, cur_w = mask.shape[1:3]
396
+
397
+ original_aspect_ratio = original_w / original_h
398
+ current_aspect_ratio = cur_w / cur_h
399
+
400
+ if original_aspect_ratio > current_aspect_ratio:
401
+ scale_factor = cur_w / original_w
402
+ new_height = int(original_h * scale_factor)
403
+ padding = (cur_h - new_height) // 2
404
+ if padding > 0:
405
+ mask[:, :padding, :] = 0
406
+ mask[:, -padding:, :] = 0
407
+ return mask
408
+ else:
409
+ scale_factor = cur_h / original_h
410
+ new_width = int(original_w * scale_factor)
411
+ padding = (cur_w - new_width) // 2
412
+ if padding > 0:
413
+ mask[:, :, :padding] = 0
414
+ mask[:, :, -padding:] = 0
415
+ return mask
416
+
417
+
418
+ def unpad_image(tensor, original_size):
419
+ """
420
+ Unpads a PyTorch tensor of a padded and resized image.
421
+
422
+ Args:
423
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
424
+ original_size (tuple): The original size of the image (height, width).
425
+
426
+ Returns:
427
+ torch.Tensor: The unpadded image tensor.
428
+ """
429
+ original_width, original_height = original_size
430
+ current_height, current_width = tensor.shape[1:3]
431
+
432
+ original_aspect_ratio = original_width / original_height
433
+ current_aspect_ratio = current_width / current_height
434
+
435
+ if original_aspect_ratio > current_aspect_ratio:
436
+ scale_factor = current_width / original_width
437
+ new_height = int(original_height * scale_factor)
438
+ padding = (current_height - new_height) // 2
439
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
440
+ # if 0 in unpadded_tensor.shape:
441
+ # print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
442
+ else:
443
+ scale_factor = current_height / original_height
444
+ new_width = int(original_width * scale_factor)
445
+ padding = (current_width - new_width) // 2
446
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
447
+ # if 0 in unpadded_tensor.shape:
448
+ # print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
449
+
450
+ return unpadded_tensor
451
+
452
+
453
+ class CambrianMetaForCausalLM(ABC):
454
+
455
+ @abstractmethod
456
+ def get_model(self):
457
+ pass
458
+
459
+ # def get_vision_tower(self):
460
+ # return self.get_model().get_vision_tower()
461
+
462
+ def get_vision_tower_aux_list(self):
463
+ return self.get_model().get_vision_tower_aux_list()
464
+
465
+ def rearrange_vision_tower_features_train(
466
+ self,
467
+ vision_tower_aux_feature_list,
468
+ vision_tower_aux_attention_masks_list,
469
+ query_side_len,
470
+ ):
471
+ vision_tower_aux_feature_rearranged_list = []
472
+ vision_tower_aux_attention_masks_rearranged_list = []
473
+ bs = vision_tower_aux_feature_list[0].shape[0]
474
+ for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip(
475
+ vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list
476
+ ):
477
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
478
+ assert (aux_height // query_side_len) * query_side_len == aux_height
479
+
480
+ reduce_factor = aux_height // query_side_len
481
+ vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view(
482
+ bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1
483
+ )
484
+ vision_tower_aux_feature_rearranged = (
485
+ vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5)
486
+ .contiguous()
487
+ .flatten(0, 2)
488
+ .flatten(1, 2)
489
+ )
490
+
491
+ vision_tower_aux_attention_masks_rearranged = (
492
+ vision_tower_aux_attention_masks.view(
493
+ bs * query_side_len * query_side_len, reduce_factor * reduce_factor
494
+ )
495
+ )
496
+
497
+ vision_tower_aux_feature_rearranged_list.append(
498
+ vision_tower_aux_feature_rearranged
499
+ )
500
+ vision_tower_aux_attention_masks_rearranged_list.append(
501
+ vision_tower_aux_attention_masks_rearranged
502
+ )
503
+ return (
504
+ vision_tower_aux_feature_rearranged_list,
505
+ vision_tower_aux_attention_masks_rearranged_list,
506
+ )
507
+
508
+ def rearrange_vision_tower_features_inference(
509
+ self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False
510
+ ):
511
+ vision_tower_aux_feature_rearranged_list = []
512
+ vision_tower_aux_attention_masks_rearranged_list = []
513
+ bs = vision_tower_aux_feature_list[0].shape[0]
514
+ for vision_tower_aux_feature in vision_tower_aux_feature_list:
515
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
516
+ assert (aux_height // query_side_len) * query_side_len == aux_height
517
+
518
+ reduce_factor = aux_height // query_side_len
519
+
520
+ vision_tower_aux_feature_rearranged = []
521
+ vision_tower_aux_attention_masks_rearranged = []
522
+ for batch_i in range(bs):
523
+ image_size = image_sizes[batch_i]
524
+ cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i]
525
+
526
+ cur_vision_tower_aux_attention_masks_rearranged = torch.ones(
527
+ (1, aux_height, aux_width),
528
+ dtype=torch.bool,
529
+ device=cur_vision_tower_aux_feature.device,
530
+ )
531
+ cur_vision_tower_aux_feature_rearranged = (
532
+ cur_vision_tower_aux_feature.view(
533
+ 1,
534
+ query_side_len,
535
+ reduce_factor,
536
+ query_side_len,
537
+ reduce_factor,
538
+ -1,
539
+ )
540
+ )
541
+ cur_vision_tower_aux_feature_rearranged = (
542
+ cur_vision_tower_aux_feature_rearranged.permute(
543
+ 0, 1, 3, 2, 4, 5
544
+ ).contiguous()
545
+ )
546
+ if unpad:
547
+ cur_vision_tower_aux_feature_rearranged = unpad_image(
548
+ cur_vision_tower_aux_feature_rearranged, image_size
549
+ )
550
+ cur_vision_tower_aux_feature_rearranged = (
551
+ cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2)
552
+ ) # query_side_len*query_side_len X reduce_factor*reduce_factor X C
553
+
554
+ cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask(
555
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
556
+ )
557
+ cur_vision_tower_aux_attention_masks_rearranged = (
558
+ cur_vision_tower_aux_attention_masks_rearranged.view(
559
+ 1, query_side_len, reduce_factor, query_side_len, reduce_factor
560
+ )
561
+ .permute(0, 1, 3, 2, 4)
562
+ .contiguous()
563
+ )
564
+ if unpad:
565
+ cur_vision_tower_aux_attention_masks_rearranged = unpad_image(
566
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
567
+ )
568
+ cur_vision_tower_aux_attention_masks_rearranged = (
569
+ cur_vision_tower_aux_attention_masks_rearranged.flatten(
570
+ 0, 2
571
+ ).flatten(1, 2)
572
+ )
573
+
574
+ cur_vision_tower_aux_attention_masks_rearranged[
575
+ cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0
576
+ ] = True
577
+
578
+ vision_tower_aux_feature_rearranged.append(
579
+ cur_vision_tower_aux_feature_rearranged
580
+ )
581
+ vision_tower_aux_attention_masks_rearranged.append(
582
+ cur_vision_tower_aux_attention_masks_rearranged
583
+ )
584
+
585
+ vision_tower_aux_feature_rearranged = torch.cat(
586
+ vision_tower_aux_feature_rearranged, 0
587
+ )
588
+ vision_tower_aux_attention_masks_rearranged = torch.cat(
589
+ vision_tower_aux_attention_masks_rearranged, 0
590
+ )
591
+
592
+ vision_tower_aux_feature_rearranged_list.append(
593
+ vision_tower_aux_feature_rearranged
594
+ )
595
+ vision_tower_aux_attention_masks_rearranged_list.append(
596
+ vision_tower_aux_attention_masks_rearranged
597
+ )
598
+
599
+ return (
600
+ vision_tower_aux_feature_rearranged_list,
601
+ vision_tower_aux_attention_masks_rearranged_list,
602
+ )
603
+
604
+ def encode_images(self, image_aux_list, encode_type=None):
605
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
606
+ image_aux_features_list = []
607
+ chunk_size = 64
608
+ if encode_type == "dino":
609
+ image_aux = image_aux_list[-1]
610
+ vision_tower_aux = vision_tower_aux_list[-1]
611
+ if image_aux.shape[0] > chunk_size:
612
+ image_aux_features_chunks = []
613
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
614
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
615
+ chunk = image_aux[start_idx:end_idx]
616
+ image_aux_features_chunk = vision_tower_aux(chunk)
617
+ image_aux_features_chunks.append(image_aux_features_chunk)
618
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
619
+ else:
620
+ image_aux_features = vision_tower_aux(image_aux)
621
+ return image_aux_features
622
+ elif encode_type == "siglip":
623
+ image_aux = image_aux_list[0]
624
+ vision_tower_aux = vision_tower_aux_list[0]
625
+ if image_aux.shape[0] > chunk_size:
626
+ image_aux_features_chunks = []
627
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
628
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
629
+ chunk = image_aux[start_idx:end_idx]
630
+ image_aux_features_chunk = vision_tower_aux(chunk)
631
+ image_aux_features_chunks.append(image_aux_features_chunk)
632
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
633
+ else:
634
+ image_aux_features = vision_tower_aux(image_aux)
635
+ return image_aux_features
636
+ else:
637
+ for image_aux, vision_tower_aux in zip(
638
+ image_aux_list, vision_tower_aux_list
639
+ ):
640
+ if image_aux.shape[0] > chunk_size:
641
+ image_aux_features_chunks = []
642
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
643
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
644
+ chunk = image_aux[start_idx:end_idx]
645
+ image_aux_features_chunk = vision_tower_aux(chunk)
646
+ image_aux_features_chunks.append(image_aux_features_chunk)
647
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
648
+ else:
649
+ image_aux_features = vision_tower_aux(image_aux)
650
+ image_aux_features_list.append(image_aux_features)
651
+ return image_aux_features_list
652
+
653
+ def select_frame(
654
+ self,
655
+ feature_list,
656
+ split_sizes,
657
+ input_ids,
658
+ new_image_aux_list,
659
+ image_sizes,
660
+ window_size=16,
661
+ threshold=0.83,
662
+ ):
663
+ dino_features_batch = torch.split(feature_list, split_sizes, dim=0)
664
+ new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0)
665
+ new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0)
666
+ new_split_sizes = []
667
+ selected_frames_all_0 = []
668
+ selected_frames_all_1 = []
669
+ selected_frames_feature_all = []
670
+ selected_frame_indices_all = []
671
+ for i_batch, frame_features in enumerate(dino_features_batch):
672
+ try:
673
+ if "llama" in self.get_model().config.model_type:
674
+ text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
675
+ else:
676
+ text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
677
+ except:
678
+ text_len = len(input_ids[i_batch])
679
+ original_width, original_height = image_sizes[i_batch]
680
+ if getattr(self.get_model().config, "highres", False):
681
+ token_per_frame = self.get_model().config.lowres_token ** 2
682
+ else:
683
+ token_per_frame = self.get_model().config.image_token_len
684
+ # current_height, current_width = token_per_side, token_per_side
685
+ # original_aspect_ratio = original_width / original_height
686
+ # current_aspect_ratio = current_width / current_height
687
+ # if original_aspect_ratio > current_aspect_ratio:
688
+ # scale_factor = current_width / original_width
689
+ # new_height = int(original_height * scale_factor)
690
+ # padding = math.ceil((current_height - new_height) / 2.0)
691
+ # token_per_frame = (
692
+ # current_height - padding * 2
693
+ # ) * token_per_side + token_per_side
694
+ # else:
695
+ # scale_factor = current_height / original_height
696
+ # new_width = int(original_width * scale_factor)
697
+ # padding = math.ceil((current_width - new_width) / 2.0)
698
+ # token_per_frame = (current_width - padding * 2) * token_per_side + (
699
+ # current_width - padding * 2
700
+ # )
701
+ # token_per_frame = (
702
+ # token_per_side**2 if token_per_frame < 1 else token_per_frame
703
+ # )
704
+ max_num_frames = max(
705
+ 1,
706
+ (
707
+ self.get_model().config.tokenizer_model_max_length
708
+ - text_len
709
+ - getattr(self.get_model().config, "inference_max_length", 16)
710
+ )
711
+ // token_per_frame,
712
+ )
713
+ if len(frame_features) < max_num_frames:
714
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch])
715
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch])
716
+ selected_frames_feature_all.append(frame_features)
717
+ new_split_sizes.append(len(frame_features))
718
+ selected_frame_indices_all.append(torch.arange(len(frame_features)))
719
+ continue
720
+
721
+ num_segments = len(frame_features) // window_size
722
+ if num_segments == 0:
723
+ query_feature = frame_features.flatten(1, 2)
724
+ query_feature = query_feature / torch.norm(
725
+ (query_feature), dim=1, keepdim=True
726
+ )
727
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
728
+ similarities[len(frame_features) // 2] = 0
729
+ indices = torch.where(similarities < threshold)[0]
730
+ selected_frame_indices_all.append(indices)
731
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices])
732
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices])
733
+ selected_frames_feature_all.append(frame_features[indices])
734
+ new_split_sizes.append(len(indices))
735
+ continue
736
+ segments_frames_0 = []
737
+ segments_frames_1 = []
738
+ segments_features = []
739
+ for start_idx in range(0, len(frame_features), window_size):
740
+ end_idx = min(start_idx + window_size, len(frame_features))
741
+ segments_frames_0.append(
742
+ new_image_aux_batch_0[i_batch][start_idx:end_idx]
743
+ )
744
+ segments_frames_1.append(
745
+ new_image_aux_batch_1[i_batch][start_idx:end_idx]
746
+ )
747
+ segments_features.append(frame_features[start_idx:end_idx])
748
+ selected_frames_0 = []
749
+ selected_frames_1 = []
750
+ selected_features = []
751
+ selected_frame_indices = []
752
+ for i, segment in enumerate(segments_features):
753
+ query_feature = segment.flatten(1, 2)
754
+ query_feature = query_feature / torch.norm(
755
+ (query_feature), dim=1, keepdim=True
756
+ )
757
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
758
+ similarities[len(segment) // 2] = 0
759
+ indices = torch.where(similarities < threshold)[0]
760
+ selected_frames_0.append(segments_frames_0[i][indices])
761
+ selected_frames_1.append(segments_frames_1[i][indices])
762
+ selected_features.append(segment[indices])
763
+ selected_frame_indices.extend(indices + i * window_size)
764
+ selected_frames_0 = torch.cat(selected_frames_0, dim=0)
765
+ selected_frames_1 = torch.cat(selected_frames_1, dim=0)
766
+ selected_features = torch.cat(selected_features, dim=0)
767
+ selected_frame_indices = torch.tensor(selected_frame_indices)
768
+ # ablation
769
+ max_num_frames = 400 # in case of OOM
770
+ if len(selected_frames_0) > max_num_frames:
771
+ interval = len(selected_frames_0) / float(max_num_frames)
772
+ indices = [int(interval * i) for i in range(max_num_frames)]
773
+ new_split_sizes.append(len(indices))
774
+ selected_frames_all_0.append(selected_frames_0[indices])
775
+ selected_frames_all_1.append(selected_frames_1[indices])
776
+ selected_frames_feature_all.append(selected_features[indices])
777
+ selected_frame_indices = selected_frame_indices[indices]
778
+ else:
779
+ new_split_sizes.append(len(selected_frames_0))
780
+ selected_frames_all_0.append(selected_frames_0)
781
+ selected_frames_all_1.append(selected_frames_1)
782
+ selected_frames_feature_all.append(selected_features)
783
+ selected_frame_indices_all.append(selected_frame_indices)
784
+ selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0)
785
+ selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0)
786
+ selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0)
787
+ return (
788
+ selected_frames_feature_all,
789
+ new_split_sizes,
790
+ [selected_frames_all_0, selected_frames_all_1],
791
+ selected_frame_indices_all,
792
+ )
793
+
794
+ def prepare_inputs_labels_for_multimodal(
795
+ self,
796
+ input_ids,
797
+ position_ids,
798
+ attention_mask,
799
+ past_key_values,
800
+ labels,
801
+ images,
802
+ image_aux_attention_masks_list=None,
803
+ image_sizes=None,
804
+ ):
805
+ # vision_tower = self.get_vision_tower()
806
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
807
+ if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1:
808
+ return (
809
+ input_ids,
810
+ position_ids,
811
+ attention_mask,
812
+ past_key_values,
813
+ None,
814
+ labels,
815
+ None,
816
+ None,
817
+ None,
818
+ None,
819
+ )
820
+
821
+ image_aux_list = images
822
+
823
+ split_sizes = None
824
+
825
+ if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5:
826
+ split_sizes_ori = [
827
+ 1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0]
828
+ ]
829
+ new_image_aux_list = []
830
+ for image_aux in image_aux_list:
831
+ if type(image_aux) is list:
832
+ image_aux = [
833
+ x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux
834
+ ]
835
+ concat_image_aux = torch.cat([image for image in image_aux], dim=0)
836
+ new_image_aux_list.append(concat_image_aux)
837
+ image_aux_features_dino = self.encode_images(
838
+ new_image_aux_list, encode_type="dino"
839
+ )
840
+
841
+ (
842
+ image_aux_features_dino,
843
+ split_sizes,
844
+ new_image_aux_list,
845
+ selected_frame_indices_all,
846
+ ) = self.select_frame(
847
+ image_aux_features_dino,
848
+ split_sizes_ori,
849
+ input_ids,
850
+ new_image_aux_list,
851
+ image_sizes,
852
+ threshold=getattr(self.get_model().config, "dino_threshold", 0.83),
853
+ )
854
+
855
+ image_aux_features_siglip = self.encode_images(
856
+ new_image_aux_list, encode_type="siglip"
857
+ )
858
+ image_aux_features_list = [
859
+ image_aux_features_siglip,
860
+ image_aux_features_dino,
861
+ ]
862
+
863
+ bs = image_aux_features_list[0].shape[0]
864
+ dtype = new_image_aux_list[0].dtype
865
+
866
+ frame_sizes = []
867
+ for i in range(len(image_sizes)):
868
+ for j in range(split_sizes[i]):
869
+ frame_sizes.append(image_sizes[i])
870
+ image_sizes = frame_sizes
871
+ else:
872
+ image_aux_features_list = self.encode_images(image_aux_list)
873
+ bs = image_aux_list[0].shape[0]
874
+ dtype = image_aux_list[0].dtype
875
+
876
+ image_token_len = self.get_model().config.image_token_len
877
+ query_num_list = self.get_model().config.query_num_list
878
+
879
+ final_height = final_width = int(image_token_len**0.5)
880
+
881
+ final_image_features_list = []
882
+ final_image_features_down_list = []
883
+
884
+ # only needed for sva
885
+ vision_tower_aux_feature_list_final = None
886
+ vision_tower_aux_attention_masks_list_final = None
887
+ global_context_feature_final = None
888
+
889
+ if self.get_model().config.mm_projector_type == "sva":
890
+ vision_tower_aux_feature_list = []
891
+ vision_tower_aux_attention_masks_list = []
892
+ # get vision tokens from each vision tower
893
+ for aux_i in range(len(vision_tower_aux_list)):
894
+ image_aux_features = image_aux_features_list[aux_i]
895
+
896
+ image_aux_features = getattr(
897
+ self.get_model(), "mm_projector_aux_{}".format(aux_i)
898
+ )(image_aux_features).to(dtype)
899
+ if aux_i == 0:
900
+ global_context_feature = image_aux_features.mean(1).view(
901
+ bs, 1, 1, -1
902
+ )
903
+
904
+ vision_tower_aux_feature_list.append(image_aux_features)
905
+ input_mix_res = True
906
+ input_high_res = True
907
+ # perform vision sampling for each query group
908
+ for query_group_i, query_num in enumerate(query_num_list):
909
+ query_features_i = (
910
+ self.get_model()
911
+ .vision_query[query_group_i, :]
912
+ .view(1, 1, 1, -1)
913
+ .expand(bs, query_num, -1, -1)
914
+ )
915
+ global_context_feature_i = global_context_feature.expand(
916
+ -1, query_num, 1, -1
917
+ ).flatten(0, 1)
918
+ query_side_len = int(query_num**0.5)
919
+ if IS_XLA_AVAILABLE:
920
+ (
921
+ vision_tower_aux_feature_list_i,
922
+ vision_tower_aux_attention_masks_list_i,
923
+ ) = self.rearrange_vision_tower_features_train(
924
+ vision_tower_aux_feature_list,
925
+ image_aux_attention_masks_list,
926
+ query_side_len,
927
+ )
928
+ else:
929
+ (
930
+ vision_tower_aux_feature_list_i,
931
+ vision_tower_aux_attention_masks_list_i,
932
+ ) = self.rearrange_vision_tower_features_inference(
933
+ vision_tower_aux_feature_list, query_side_len, image_sizes
934
+ )
935
+
936
+ query_features_i = getattr(
937
+ self.get_model(), "vision_sampler_{}".format(query_group_i)
938
+ )(
939
+ query_features_i.flatten(0, 1),
940
+ global_context_feature_i,
941
+ *vision_tower_aux_feature_list_i,
942
+ *vision_tower_aux_attention_masks_list_i,
943
+ )
944
+ query_features_i = query_features_i.view(bs, query_num, -1)
945
+
946
+ if split_sizes is not None:
947
+ try:
948
+ if "llama" in self.get_model().config.model_type:
949
+ text_len = torch.where(input_ids[0] == 128002)[-1][0]
950
+ else:
951
+ text_len = torch.where(input_ids[0] == 151643)[-1][0]
952
+ except:
953
+ text_len = len(input_ids[0])
954
+ max_visual_len = (
955
+ self.get_model().config.tokenizer_model_max_length
956
+ - text_len
957
+ - getattr(self.get_model().config, "inference_max_length", 16)
958
+ )
959
+ max_num_frames = max(
960
+ 1,
961
+ math.floor(max_visual_len // (final_height * final_width)),
962
+ )
963
+ max_num_frames_low = max(
964
+ 1,
965
+ math.floor(
966
+ max_visual_len
967
+ // (self.get_model().config.lowres_token ** 2)
968
+ ),
969
+ )
970
+ if split_sizes[0] < max_num_frames:
971
+ input_mix_res = False
972
+ elif split_sizes[0] > max_num_frames_low:
973
+ input_mix_res = False
974
+ input_high_res = False
975
+
976
+ # input_mix_res = False # ablation
977
+
978
+ if (getattr(self.config, "highres", False)) and input_mix_res:
979
+ _query_features_i = (
980
+ query_features_i.permute(0, 2, 1)
981
+ .contiguous()
982
+ .view(bs, -1, query_side_len, query_side_len)
983
+ )
984
+ _query_features_i = F.interpolate(
985
+ _query_features_i.float(),
986
+ size=(
987
+ self.get_model().config.lowres_token,
988
+ self.get_model().config.lowres_token,
989
+ ),
990
+ mode="bilinear",
991
+ align_corners=False,
992
+ ).to(dtype=query_features_i.dtype)
993
+ _query_features_i = (
994
+ _query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
995
+ )
996
+ final_image_features_down_list.append(_query_features_i)
997
+
998
+ # interpolate to the final target size
999
+ if query_side_len != final_height:
1000
+ query_features_i = (
1001
+ query_features_i.permute(0, 2, 1)
1002
+ .contiguous()
1003
+ .view(bs, -1, query_side_len, query_side_len)
1004
+ )
1005
+ if input_high_res:
1006
+ query_features_i = F.interpolate(
1007
+ query_features_i.float(),
1008
+ size=(final_height, final_width),
1009
+ mode="bilinear",
1010
+ align_corners=False,
1011
+ ).to(dtype=query_features_i.dtype)
1012
+ else:
1013
+ query_features_i = F.interpolate(
1014
+ query_features_i.float(),
1015
+ size=(8, 8),
1016
+ mode="bilinear",
1017
+ align_corners=False,
1018
+ ).to(dtype=query_features_i.dtype)
1019
+ query_features_i = (
1020
+ query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
1021
+ )
1022
+ final_image_features_list.append(query_features_i)
1023
+
1024
+ if IS_XLA_AVAILABLE:
1025
+ (
1026
+ vision_tower_aux_feature_list_final,
1027
+ vision_tower_aux_attention_masks_list_final,
1028
+ ) = self.rearrange_vision_tower_features_train(
1029
+ vision_tower_aux_feature_list,
1030
+ image_aux_attention_masks_list,
1031
+ final_height,
1032
+ )
1033
+ global_context_feature_final = global_context_feature.expand(
1034
+ -1, final_height * final_width, 1, -1
1035
+ ).flatten(0, 1)
1036
+ else:
1037
+ final_image_features_list = image_aux_features_list
1038
+
1039
+ image_features = torch.cat(final_image_features_list, -1)
1040
+ image_features = self.get_model().mm_projector(image_features).to(dtype)
1041
+
1042
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1043
+ image_features_down = torch.cat(final_image_features_down_list, -1)
1044
+ image_features_down = (
1045
+ self.get_model().mm_projector(image_features_down).to(dtype)
1046
+ )
1047
+
1048
+ if IS_XLA_AVAILABLE:
1049
+ image_features = image_features.view(
1050
+ image_features.shape[0], final_height, final_width, -1
1051
+ )
1052
+ image_features = torch.cat(
1053
+ (
1054
+ image_features,
1055
+ self.model.image_newline[None, None, None, :].expand(
1056
+ image_features.shape[0], final_height, 1, -1
1057
+ ),
1058
+ ),
1059
+ dim=2,
1060
+ )
1061
+ image_features = image_features.flatten(1, 2)
1062
+ final_size = [(final_height, final_width)] * bs
1063
+
1064
+ else:
1065
+ image_features = image_features.view(bs, final_height, final_width, -1)
1066
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1067
+ image_features_down = image_features_down.view(
1068
+ bs,
1069
+ self.get_model().config.lowres_token,
1070
+ self.get_model().config.lowres_token,
1071
+ -1,
1072
+ )
1073
+ image_features_unpadded = []
1074
+ image_features_downsample = []
1075
+ final_size = []
1076
+ if self.get_model().config.mm_projector_type == "sva":
1077
+ (
1078
+ vision_tower_aux_feature_list_final,
1079
+ vision_tower_aux_attention_masks_list_final,
1080
+ ) = self.rearrange_vision_tower_features_inference(
1081
+ vision_tower_aux_feature_list, final_height, image_sizes, unpad=True
1082
+ )
1083
+ global_context_feature_final = []
1084
+ for batch_i in range(bs):
1085
+ cur_image_feature = image_features[batch_i]
1086
+ image_size = image_sizes[batch_i]
1087
+
1088
+ cur_image_feature = unpad_image(
1089
+ cur_image_feature.unsqueeze(0), image_size
1090
+ )
1091
+
1092
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1093
+ try: # fix bug for some invalid image
1094
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1095
+ final_size.append((cur_h, cur_w))
1096
+ except:
1097
+ # print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True)
1098
+ cur_image_feature = image_features[batch_i].unsqueeze(0)
1099
+ image_size = image_sizes[batch_i]
1100
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1101
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1102
+ final_size.append((cur_h, cur_w))
1103
+
1104
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1105
+ cur_image_feature_down = unpad_image(
1106
+ image_features_down[batch_i].unsqueeze(0),
1107
+ (
1108
+ int(
1109
+ image_size[0]
1110
+ / (
1111
+ image_token_len**0.5
1112
+ / self.get_model().config.lowres_token
1113
+ )
1114
+ ),
1115
+ int(
1116
+ image_size[1]
1117
+ / (
1118
+ image_token_len**0.5
1119
+ / self.get_model().config.lowres_token
1120
+ )
1121
+ ),
1122
+ ),
1123
+ )
1124
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1125
+
1126
+ try: # fix bug for some invalid image
1127
+ cur_image_feature_down = cur_image_feature_down.view(
1128
+ 1, _cur_h, _cur_w, -1
1129
+ )
1130
+ except:
1131
+ print("invalid after unpad", flush=True)
1132
+ cur_image_feature_down = image_features_down[batch_i].unsqueeze(
1133
+ 0
1134
+ )
1135
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1136
+ cur_image_feature_down = cur_image_feature_down.view(
1137
+ 1, _cur_h, _cur_w, -1
1138
+ )
1139
+
1140
+ cur_image_feature_down = torch.cat(
1141
+ (
1142
+ cur_image_feature_down,
1143
+ self.model.image_newline.view(1, 1, 1, -1)
1144
+ .expand(1, _cur_h, 1, -1)
1145
+ .to(cur_image_feature_down.device),
1146
+ ),
1147
+ dim=2,
1148
+ ).flatten(1, 2)
1149
+
1150
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1151
+ frame_pos = (
1152
+ self.get_model()
1153
+ .get_frame_pos(torch.arange(1))
1154
+ .to(cur_image_feature_down.device)
1155
+ .to(cur_image_feature_down.dtype)
1156
+ )
1157
+ cur_image_feature_down += frame_pos
1158
+
1159
+ image_features_downsample.append(cur_image_feature_down.squeeze(0))
1160
+
1161
+ cur_image_feature = torch.cat(
1162
+ (
1163
+ cur_image_feature,
1164
+ self.model.image_newline.view(1, 1, 1, -1)
1165
+ .expand(1, cur_h, 1, -1)
1166
+ .to(cur_image_feature.device),
1167
+ ),
1168
+ dim=2,
1169
+ )
1170
+
1171
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1172
+ frame_pos = (
1173
+ self.get_model()
1174
+ .get_frame_pos(torch.arange(1))
1175
+ .to(cur_image_feature.device)
1176
+ .to(cur_image_feature.dtype)
1177
+ )
1178
+ cur_image_feature += frame_pos
1179
+
1180
+ cur_image_feature = cur_image_feature.flatten(1, 2)
1181
+ image_features_unpadded.append(cur_image_feature.squeeze(0))
1182
+
1183
+ if self.get_model().config.mm_projector_type == "sva":
1184
+ cur_global_context_feature = global_context_feature[batch_i].expand(
1185
+ cur_h * cur_w, 1, -1
1186
+ )
1187
+ global_context_feature_final.append(cur_global_context_feature)
1188
+ if self.get_model().config.mm_projector_type == "sva":
1189
+ global_context_feature_final = torch.cat(
1190
+ global_context_feature_final, 0
1191
+ )
1192
+
1193
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1194
+ image_features = image_features_downsample
1195
+ else:
1196
+ image_features = image_features_unpadded
1197
+
1198
+ # TODO: image start / end is not implemented here to support pretraining.
1199
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
1200
+ self.config, "mm_use_im_start_end", False
1201
+ ):
1202
+ raise NotImplementedError
1203
+
1204
+ split_image_features_unpadded = None
1205
+ frame_split_sizes = None
1206
+
1207
+ if split_sizes is not None:
1208
+ split_image_features = []
1209
+ split_image_features_unpadded = (
1210
+ []
1211
+ if (getattr(self.config, "highres", False)) and input_mix_res
1212
+ else None
1213
+ )
1214
+ start_idx = 0
1215
+ for split_batch_idx, split_size in enumerate(split_sizes):
1216
+ if isinstance(image_features[start_idx : start_idx + split_size], list):
1217
+ if getattr(self.config, "frame_pos", False):
1218
+ frame_feature = torch.cat(
1219
+ image_features[start_idx : start_idx + split_size], dim=0
1220
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1221
+ frame_pos = (
1222
+ self.get_model()
1223
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1224
+ .to(frame_feature.device)
1225
+ .to(frame_feature.dtype)
1226
+ )
1227
+ frame_feature += frame_pos
1228
+ split_image_features.append(
1229
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1230
+ )
1231
+ else:
1232
+ split_image_features.append(
1233
+ torch.cat(
1234
+ image_features[start_idx : start_idx + split_size],
1235
+ dim=0,
1236
+ )
1237
+ )
1238
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1239
+ if getattr(self.config, "frame_pos", False):
1240
+ frame_feature = torch.cat(
1241
+ image_features_unpadded[
1242
+ start_idx : start_idx + split_size
1243
+ ],
1244
+ dim=0,
1245
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1246
+ frame_pos = (
1247
+ self.get_model()
1248
+ .get_frame_pos(
1249
+ selected_frame_indices_all[split_batch_idx]
1250
+ )
1251
+ .to(frame_feature.device)
1252
+ .to(frame_feature.dtype)
1253
+ )
1254
+ frame_feature += frame_pos
1255
+ split_image_features_unpadded.append(
1256
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1257
+ )
1258
+ else:
1259
+ split_image_features_unpadded.append(
1260
+ torch.cat(
1261
+ image_features_unpadded[
1262
+ start_idx : start_idx + split_size
1263
+ ],
1264
+ dim=0,
1265
+ )
1266
+ )
1267
+ else:
1268
+ if getattr(self.config, "frame_pos", False):
1269
+ frame_feature = image_features[
1270
+ start_idx : start_idx + split_size
1271
+ ].reshape(split_size, -1, image_features[0].shape[-1])
1272
+ frame_pos = (
1273
+ self.get_model()
1274
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1275
+ .to(frame_feature.device)
1276
+ .to(frame_feature.dtype)
1277
+ )
1278
+ frame_feature += frame_pos
1279
+ split_image_features.append(
1280
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1281
+ )
1282
+ else:
1283
+ split_image_features.append(
1284
+ image_features[start_idx : start_idx + split_size]
1285
+ )
1286
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1287
+ if getattr(self.config, "frame_pos", False):
1288
+ frame_feature = image_features_unpadded[
1289
+ start_idx : start_idx + split_size
1290
+ ]
1291
+ frame_pos = (
1292
+ self.get_model()
1293
+ .get_frame_pos(
1294
+ selected_frame_indices_all[split_batch_idx]
1295
+ )
1296
+ .to(frame_feature.device)
1297
+ .to(frame_feature.dtype)
1298
+ )
1299
+ frame_feature += frame_pos
1300
+ split_image_features_unpadded.append(
1301
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1302
+ )
1303
+ else:
1304
+ split_image_features_unpadded.append(
1305
+ image_features_unpadded[
1306
+ start_idx : start_idx + split_size
1307
+ ]
1308
+ )
1309
+ start_idx += split_size
1310
+ image_features = split_image_features
1311
+ frame_split_sizes = split_sizes
1312
+
1313
+ _labels = labels
1314
+ _position_ids = position_ids
1315
+ _attention_mask = attention_mask
1316
+ if attention_mask is None:
1317
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1318
+ else:
1319
+ attention_mask = attention_mask.bool()
1320
+ if position_ids is None:
1321
+ position_ids = torch.arange(
1322
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
1323
+ )
1324
+ if labels is None:
1325
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1326
+
1327
+ # remove the padding using attention_mask -- FIXME
1328
+ _input_ids = input_ids
1329
+
1330
+ attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
1331
+
1332
+ input_ids = [
1333
+ cur_input_ids[cur_attention_mask]
1334
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
1335
+ ]
1336
+ labels = [
1337
+ cur_labels[cur_attention_mask]
1338
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
1339
+ ]
1340
+
1341
+ new_input_embeds = []
1342
+ new_labels = []
1343
+ image_token_indices_batch = []
1344
+ cur_image_idx = 0
1345
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1346
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1347
+ if num_images == 0:
1348
+ cur_image_features = image_features[cur_image_idx]
1349
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1350
+ cur_input_embeds = torch.cat(
1351
+ [cur_input_embeds_1, cur_image_features[0:0]], dim=0
1352
+ )
1353
+ new_input_embeds.append(cur_input_embeds)
1354
+ new_labels.append(labels[batch_idx])
1355
+ cur_image_idx += 1
1356
+ continue
1357
+
1358
+ image_token_indices = (
1359
+ [-1]
1360
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
1361
+ + [cur_input_ids.shape[0]]
1362
+ )
1363
+ image_token_indices_batch.append(
1364
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0]
1365
+ )
1366
+ cur_input_ids_noim = []
1367
+ cur_labels = labels[batch_idx]
1368
+ cur_labels_noim = []
1369
+ for i in range(len(image_token_indices) - 1):
1370
+ cur_input_ids_noim.append(
1371
+ cur_input_ids[
1372
+ image_token_indices[i] + 1 : image_token_indices[i + 1]
1373
+ ]
1374
+ )
1375
+ cur_labels_noim.append(
1376
+ cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
1377
+ )
1378
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1379
+ cur_input_embeds = self.get_model().embed_tokens(
1380
+ torch.cat(cur_input_ids_noim)
1381
+ )
1382
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1383
+ cur_new_input_embeds = []
1384
+ cur_new_labels = []
1385
+
1386
+ text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
1387
+ visual_len = len(image_features[cur_image_idx])
1388
+ max_visual_len = (
1389
+ self.get_model().config.tokenizer_model_max_length
1390
+ - getattr(self.get_model().config, "inference_max_length", 16)
1391
+ - text_len
1392
+ )
1393
+ mix_token = False
1394
+
1395
+ # ablation mix
1396
+ if (
1397
+ input_mix_res
1398
+ and (
1399
+ self.get_model().config.image_token_len
1400
+ > getattr(self.get_model().config, "lowres_token", 8) ** 2
1401
+ )
1402
+ and frame_split_sizes is not None
1403
+ and getattr(self.config, "highres", False)
1404
+ ):
1405
+ if max_visual_len > visual_len:
1406
+ visual_emb = image_features[cur_image_idx]
1407
+ text_emb = cur_input_embeds_no_im[-1]
1408
+ highres_num = math.floor(
1409
+ (max_visual_len - visual_len)
1410
+ / (
1411
+ split_image_features_unpadded[cur_image_idx].shape[0]
1412
+ // frame_split_sizes[cur_image_idx]
1413
+ - visual_emb.shape[0] // frame_split_sizes[cur_image_idx]
1414
+ )
1415
+ )
1416
+ if highres_num >= 1:
1417
+ mix_token = True
1418
+ sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean(
1419
+ dim=-1
1420
+ )
1421
+ sim_frame = sim.reshape(
1422
+ frame_split_sizes[cur_image_idx], -1
1423
+ ).mean(dim=-1)
1424
+ highres_num = min(highres_num, sim_frame.shape[0])
1425
+ top_values, top_indices = torch.topk(sim_frame, highres_num)
1426
+ if len(top_indices) > 0:
1427
+ sorted_indices = torch.sort(top_indices)[1]
1428
+ top_indices = top_indices[sorted_indices]
1429
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1430
+ frame_split_sizes[cur_image_idx],
1431
+ -1,
1432
+ image_features[cur_image_idx].shape[-1],
1433
+ )
1434
+ visual_emb_frame_highres = split_image_features_unpadded[
1435
+ cur_image_idx
1436
+ ].reshape(
1437
+ frame_split_sizes[cur_image_idx],
1438
+ -1,
1439
+ split_image_features_unpadded[cur_image_idx].shape[-1],
1440
+ )
1441
+ current_point = 0
1442
+ mix_visual_emb_frame = []
1443
+ for frame_i in range(len(visual_emb_frame)):
1444
+ if current_point > len(top_indices) - 1:
1445
+ mix_visual_emb_frame.append(
1446
+ visual_emb_frame[frame_i]
1447
+ )
1448
+ continue
1449
+ if frame_i == top_indices[current_point]:
1450
+ mix_visual_emb_frame.append(
1451
+ visual_emb_frame_highres[frame_i]
1452
+ )
1453
+ current_point += 1
1454
+ else:
1455
+ mix_visual_emb_frame.append(
1456
+ visual_emb_frame[frame_i]
1457
+ )
1458
+ image_features[cur_image_idx] = torch.cat(
1459
+ mix_visual_emb_frame, dim=0
1460
+ )
1461
+ # ablation drop
1462
+
1463
+ if (
1464
+ max_visual_len < visual_len
1465
+ and frame_split_sizes is not None
1466
+ and not mix_token
1467
+ ):
1468
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1469
+ frame_split_sizes[cur_image_idx],
1470
+ -1,
1471
+ image_features[cur_image_idx].shape[-1],
1472
+ )
1473
+
1474
+ sim = F.cosine_similarity(
1475
+ visual_emb_frame[:-1],
1476
+ visual_emb_frame[1:],
1477
+ dim=-1,
1478
+ )
1479
+
1480
+ new_visual_emb_frames = []
1481
+ for start_idx in range(0, len(visual_emb_frame), 8):
1482
+ end_idx = min(start_idx + 8, len(visual_emb_frame))
1483
+ chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C
1484
+ if len(chunk_feature) == 1:
1485
+ new_visual_emb_frames.append(chunk_feature[0])
1486
+ continue
1487
+ sim = F.cosine_similarity(
1488
+ chunk_feature[0]
1489
+ .unsqueeze(0)
1490
+ .repeat_interleave(len(chunk_feature[1:]), dim=0),
1491
+ chunk_feature[1:],
1492
+ dim=-1,
1493
+ )
1494
+ new_visual_emb_frame = torch.cat(
1495
+ [
1496
+ chunk_feature[0],
1497
+ chunk_feature[1:].flatten(0, 1)[
1498
+ sim.flatten(0, 1)
1499
+ < getattr(
1500
+ self.get_model().config, "drop_threshold", 0.7
1501
+ )
1502
+ ],
1503
+ ],
1504
+ dim=0,
1505
+ )
1506
+ new_visual_emb_frames.append(new_visual_emb_frame)
1507
+
1508
+ reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames])
1509
+
1510
+ if reduced_visual_len > max_visual_len:
1511
+ force_remove = math.ceil(
1512
+ (reduced_visual_len - max_visual_len)
1513
+ / len(new_visual_emb_frames)
1514
+ )
1515
+ for chunk_i in range(len(new_visual_emb_frames)):
1516
+ new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][
1517
+ :-force_remove
1518
+ ]
1519
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1520
+ else:
1521
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1522
+
1523
+ image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len]
1524
+
1525
+ for i in range(num_images + 1):
1526
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1527
+ cur_new_labels.append(cur_labels_noim[i])
1528
+ if i < num_images:
1529
+ cur_image_features = image_features[cur_image_idx]
1530
+ cur_image_idx += 1
1531
+ cur_new_input_embeds.append(cur_image_features)
1532
+ cur_new_labels.append(
1533
+ torch.full(
1534
+ (cur_image_features.shape[0],),
1535
+ IGNORE_INDEX,
1536
+ device=cur_labels.device,
1537
+ dtype=cur_labels.dtype,
1538
+ )
1539
+ )
1540
+
1541
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1542
+
1543
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1544
+ cur_new_labels = torch.cat(cur_new_labels)
1545
+
1546
+ new_input_embeds.append(cur_new_input_embeds)
1547
+ new_labels.append(cur_new_labels)
1548
+
1549
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1550
+ tokenizer_model_max_length = getattr(
1551
+ self.config, "tokenizer_model_max_length", None
1552
+ )
1553
+ if tokenizer_model_max_length is not None:
1554
+ new_input_embeds = [
1555
+ x[:tokenizer_model_max_length] for x in new_input_embeds
1556
+ ]
1557
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1558
+
1559
+ # Combine them
1560
+ max_len = max(x.shape[0] for x in new_input_embeds)
1561
+ batch_size = len(new_input_embeds)
1562
+
1563
+ new_input_embeds_padded = []
1564
+ new_labels_padded = torch.full(
1565
+ (batch_size, max_len),
1566
+ IGNORE_INDEX,
1567
+ dtype=new_labels[0].dtype,
1568
+ device=new_labels[0].device,
1569
+ )
1570
+ attention_mask = torch.zeros(
1571
+ (batch_size, max_len),
1572
+ dtype=attention_mask.dtype,
1573
+ device=attention_mask.device,
1574
+ )
1575
+ position_ids = torch.zeros(
1576
+ (batch_size, max_len),
1577
+ dtype=position_ids.dtype,
1578
+ device=position_ids.device,
1579
+ )
1580
+
1581
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
1582
+ zip(new_input_embeds, new_labels)
1583
+ ):
1584
+ cur_len = cur_new_embed.shape[0]
1585
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
1586
+ new_input_embeds_padded.append(
1587
+ torch.cat(
1588
+ (
1589
+ torch.zeros(
1590
+ (max_len - cur_len, cur_new_embed.shape[1]),
1591
+ dtype=cur_new_embed.dtype,
1592
+ device=cur_new_embed.device,
1593
+ ),
1594
+ cur_new_embed,
1595
+ ),
1596
+ dim=0,
1597
+ )
1598
+ )
1599
+ if cur_len > 0:
1600
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1601
+ attention_mask[i, -cur_len:] = True
1602
+ position_ids[i, -cur_len:] = torch.arange(
1603
+ 0,
1604
+ cur_len,
1605
+ dtype=position_ids.dtype,
1606
+ device=position_ids.device,
1607
+ )
1608
+ else:
1609
+ new_input_embeds_padded.append(
1610
+ torch.cat(
1611
+ (
1612
+ cur_new_embed,
1613
+ torch.zeros(
1614
+ (max_len - cur_len, cur_new_embed.shape[1]),
1615
+ dtype=cur_new_embed.dtype,
1616
+ device=cur_new_embed.device,
1617
+ ),
1618
+ ),
1619
+ dim=0,
1620
+ )
1621
+ )
1622
+ if cur_len > 0:
1623
+ new_labels_padded[i, :cur_len] = cur_new_labels
1624
+ attention_mask[i, :cur_len] = True
1625
+ position_ids[i, :cur_len] = torch.arange(
1626
+ 0,
1627
+ cur_len,
1628
+ dtype=position_ids.dtype,
1629
+ device=position_ids.device,
1630
+ )
1631
+
1632
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1633
+
1634
+ if _labels is None:
1635
+ new_labels = None
1636
+ else:
1637
+ new_labels = new_labels_padded
1638
+
1639
+ if _attention_mask is None:
1640
+ attention_mask = None
1641
+ else:
1642
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1643
+
1644
+ if _position_ids is None:
1645
+ position_ids = None
1646
+
1647
+ return (
1648
+ None,
1649
+ position_ids,
1650
+ attention_mask,
1651
+ past_key_values,
1652
+ new_input_embeds,
1653
+ new_labels,
1654
+ vision_tower_aux_feature_list_final,
1655
+ vision_tower_aux_attention_masks_list_final,
1656
+ final_size,
1657
+ global_context_feature_final,
1658
+ )
1659
+
1660
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
1661
+ if model_args.mm_use_im_patch_token:
1662
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
1663
+ self.resize_token_embeddings(len(tokenizer))
1664
+
1665
+ if model_args.mm_use_im_start_end:
1666
+ num_new_tokens = tokenizer.add_tokens(
1667
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
1668
+ )
1669
+ self.resize_token_embeddings(len(tokenizer))
1670
+
1671
+ if num_new_tokens > 0:
1672
+ input_embeddings = self.get_input_embeddings().weight.data
1673
+ output_embeddings = self.get_output_embeddings().weight.data
1674
+
1675
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
1676
+ dim=0, keepdim=True
1677
+ )
1678
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
1679
+ dim=0, keepdim=True
1680
+ )
1681
+
1682
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
1683
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
1684
+
1685
+ if model_args.tune_mm_mlp_adapter:
1686
+ for p in self.get_input_embeddings().parameters():
1687
+ p.requires_grad = True
1688
+ for p in self.get_output_embeddings().parameters():
1689
+ p.requires_grad = False
1690
+
1691
+ if model_args.pretrain_mm_mlp_adapter:
1692
+ mm_projector_weights = torch.load(
1693
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
1694
+ )
1695
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
1696
+ assert num_new_tokens == 2
1697
+ if input_embeddings.shape == embed_tokens_weight.shape:
1698
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
1699
+ -num_new_tokens:
1700
+ ]
1701
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
1702
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
1703
+ else:
1704
+ raise ValueError(
1705
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
1706
+ )
1707
+ elif model_args.mm_use_im_patch_token:
1708
+ if model_args.tune_mm_mlp_adapter:
1709
+ for p in self.get_input_embeddings().parameters():
1710
+ p.requires_grad = False
1711
+ for p in self.get_output_embeddings().parameters():
1712
+ p.requires_grad = False
multimodal_encoder_builder.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-unsafe
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
6
+ from abc import ABC, abstractmethod
7
+ import torch.nn as nn
8
+
9
+
10
+ class ProcessorWrapper:
11
+ def __init__(
12
+ self,
13
+ transform,
14
+ height=378,
15
+ width=378,
16
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
17
+ ):
18
+ self._crop_size = {
19
+ "height": height,
20
+ "width": width,
21
+ }
22
+ self._transforms = transform
23
+ # print(transform)
24
+ self.image_mean = image_mean
25
+
26
+ @property
27
+ def crop_size(self):
28
+ return self._crop_size
29
+
30
+ def preprocess(self, image, return_tensors="pt"):
31
+ # Ensure image is a PIL Image
32
+ output = {}
33
+ output["pixel_values"] = [self._transforms(image)]
34
+ return output
35
+
36
+
37
+ class BaseVisionTower(nn.Module):
38
+ def __init__(self, vision_tower_name, args, delay_load=False):
39
+ super().__init__()
40
+
41
+ self.is_loaded = False
42
+ self.args = args
43
+
44
+ self.vision_tower_name = vision_tower_name
45
+ self.select_layer = args.mm_vision_select_layer
46
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
47
+ self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
48
+ self.delay_load = delay_load
49
+
50
+ @abstractmethod
51
+ def load_model(self, device_map=None):
52
+ raise NotImplementedError("Subclasses must implement load_model")
53
+
54
+ @abstractmethod
55
+ def _forward(self, images):
56
+ raise NotImplementedError("Subclasses must implement forward")
57
+
58
+ def forward(self, images):
59
+ if type(images) is list:
60
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
61
+ else:
62
+ image_features = self._forward(images)
63
+
64
+ return image_features
65
+
66
+ @property
67
+ def dummy_feature(self):
68
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
69
+
70
+ @property
71
+ def dtype(self):
72
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
73
+ if hasattr(self.vision_tower, "dtype"):
74
+ return self.vision_tower.dtype
75
+ else:
76
+ params = list(self.vision_tower.parameters())
77
+ return (
78
+ params[0].dtype if len(params) > 0 else torch.float32
79
+ ) # Default to torch.float32 if no parameters
80
+
81
+ @property
82
+ def device(self):
83
+ # Dynamically infer the device from the first parameter, if not explicitly specified
84
+ if hasattr(self.vision_tower, "device"):
85
+ return self.vision_tower.device
86
+ else:
87
+ params = list(self.vision_tower.parameters())
88
+ return (
89
+ params[0].device if len(params) > 0 else torch.device("cpu")
90
+ ) # Default to CPU if no parameters
91
+
92
+ @property
93
+ def config(self):
94
+ if self.is_loaded:
95
+ return self.vision_tower.config
96
+ else:
97
+ return self.cfg_only
98
+
99
+ @property
100
+ def hidden_size(self):
101
+ try:
102
+ return self.config.hidden_size
103
+ except:
104
+ return self._hidden_size
105
+
106
+ @property
107
+ def image_size(self): # resolution
108
+ # return self.config.image_size
109
+ try:
110
+ return self.config.image_size
111
+ except:
112
+ return self._image_size
113
+
114
+ @property
115
+ def patch_size(self):
116
+ # return self.config.patch_size
117
+ try:
118
+ return self.config.patch_size
119
+ except:
120
+ return self._patch_size
121
+
122
+ @property
123
+ def num_patches_per_side(self):
124
+ if self._interp_size is not None:
125
+ return int(self._interp_size**0.5)
126
+ try:
127
+ return self.image_size // self.patch_size
128
+ except:
129
+ return self._num_patches_per_side
130
+
131
+ @property
132
+ def num_patches(self):
133
+ if self._interp_size is not None:
134
+ return self._interp_size
135
+ try:
136
+ return self.num_patches_per_side**2
137
+ except:
138
+ return self._num_patches
139
+
140
+
141
+ class DinoVisionTower(BaseVisionTower):
142
+ def __init__(self, vision_tower, args, delay_load=False):
143
+ super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)
144
+
145
+ model_path = "facebook/dinov2-giant"
146
+ base_model_name, res, interp = model_path, 378, 576
147
+ self._vision_tower_name = vision_tower
148
+ self.vision_tower_name = base_model_name
149
+ self._image_size = res
150
+ self._interp_size = interp
151
+ self._patch_size = 14 # default patch size
152
+
153
+ if not self.delay_load:
154
+ self.load_model()
155
+ else:
156
+ self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)
157
+
158
+ def load_model(self, device_map=None):
159
+
160
+ self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
161
+ """ValueError: Dinov2Model does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute."""
162
+ self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]
163
+
164
+ _image_size = self.vision_tower.config.image_size
165
+ if self._image_size is None:
166
+ self._image_size = _image_size
167
+
168
+ # increase shortest edge to prevent edge case crops
169
+ default_shortest_ratio = 8 / 7 # 224/256
170
+ # shortest_edge = int(default_shortest_ratio * self._image_size)
171
+ shortest_edge = self._image_size
172
+
173
+ processor = AutoImageProcessor.from_pretrained(
174
+ self.vision_tower_name,
175
+ crop_size=dict(height=self._image_size, width=self._image_size),
176
+ size=dict(shortest_edge=shortest_edge),
177
+ )
178
+ self.image_processor = processor
179
+
180
+ # Assign the output channels of the projection convolution as the hidden size
181
+ self._hidden_size = (
182
+ self.vision_tower.embeddings.patch_embeddings.projection.out_channels
183
+ )
184
+ # Assign the first value of the stride of the projection convolution as the patch size
185
+ self._patch_size = (
186
+ self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
187
+ )
188
+
189
+ # print(self._hidden_size, self._patch_size)
190
+
191
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
192
+ self.is_loaded = True
193
+
194
+ @property
195
+ def image_size(self):
196
+ return self._image_size
197
+
198
+ def feature_select(self, outputs):
199
+ sequence_output = outputs[
200
+ "last_hidden_state"
201
+ ] # batch_size, sequence_length, hidden_size
202
+
203
+ if self.select_feature == "cls_patch":
204
+ image_features = sequence_output
205
+ elif self.select_feature == "patch":
206
+ image_features = sequence_output[:, 1:]
207
+ elif self.select_feature == "cls":
208
+ image_features = sequence_output[:, 0]
209
+ else:
210
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
211
+ return image_features
212
+
213
+ def interpolate(self, image_features):
214
+ if self._interp_size is None:
215
+ return image_features
216
+
217
+ b, num_tokens, dim = image_features.shape
218
+
219
+ if num_tokens != self.num_patches:
220
+ target_h = target_w = int(self._interp_size**0.5)
221
+ h = w = int(num_tokens**0.5)
222
+
223
+ image_features = image_features.view(b, h, w, dim)
224
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
225
+
226
+ image_features = F.interpolate(
227
+ image_features.to(torch.float32),
228
+ size=(target_h, target_w),
229
+ mode="bilinear",
230
+ align_corners=False,
231
+ ).to(image_features.dtype)
232
+
233
+ # Permute the dimensions back to (b, target_h, target_w, dim)
234
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
235
+
236
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
237
+ image_features = image_features.flatten(1, 2)
238
+
239
+ return image_features
240
+
241
+ def _forward(self, images):
242
+ # logger.warning(f"images shape: {images.shape}")
243
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
244
+ image_forward_outs = self.vision_tower.forward(
245
+ images.to(device=self.device, dtype=self.dtype)
246
+ )
247
+ # logger.warning(f"image_forward_outs shape: {image_forward_outs['last_hidden_state'].shape}")
248
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
249
+ # logger.warning(f"image_features shape: {image_features.shape}")
250
+ interp_features = self.interpolate(image_features)
251
+ # logger.warning(f"interp_features shape: {interp_features.shape}")
252
+ return interp_features
253
+
254
+ @property
255
+ def num_patches_per_side(self):
256
+ return int(self.num_patches**0.5)
257
+
258
+ @property
259
+ def num_patches(self):
260
+ if self._interp_size is None:
261
+ return (self._image_size // self._patch_size) ** 2
262
+ else:
263
+ return self._interp_size
264
+
265
+
266
+ # from .siglip_encoder import SiglipVisionTower
267
+ class SiglipVisionTower(BaseVisionTower):
268
+ def __init__(self, vision_tower_name, args, delay_load=False):
269
+ super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
270
+
271
+ model_path = "google/siglip-so400m-patch14-384"
272
+ base_model_name, res, interp = model_path, 384, 576
273
+ self.vision_tower_name = base_model_name
274
+ self._image_size = res if res is not None else 512
275
+ self._interp_size = interp
276
+ if not self.delay_load:
277
+ self.load_model()
278
+ elif self.unfreeze_mm_vision_tower:
279
+ self.load_model()
280
+ else:
281
+ self._hidden_size = 1152
282
+
283
+ def load_model(self, device_map=None):
284
+ self.vision_model = "siglip"
285
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
286
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
287
+
288
+ # self.vision_tower = clip_model.visual.trunk
289
+ self.vision_tower.output_tokens = True
290
+
291
+ self._hidden_size = self.vision_tower.config.hidden_size
292
+ self._image_size = self.vision_tower.config.image_size
293
+ self._patch_size = self.vision_tower.config.patch_size
294
+ self.image_processor = SiglipImageProcessor.from_pretrained(
295
+ self.vision_tower_name
296
+ )
297
+
298
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
299
+ self.is_loaded = True
300
+
301
+ def interpolate(self, image_features):
302
+ if self._interp_size is None:
303
+ return image_features
304
+
305
+ b, num_tokens, dim = image_features.shape
306
+
307
+ if num_tokens != self.num_patches:
308
+ target_h = target_w = int(self._interp_size**0.5)
309
+ h = w = int(num_tokens**0.5)
310
+
311
+ image_features = image_features.view(b, h, w, dim)
312
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
313
+
314
+ image_features = F.interpolate(
315
+ image_features.to(torch.float32),
316
+ size=(target_h, target_w),
317
+ mode="bilinear",
318
+ align_corners=False,
319
+ ).to(image_features.dtype)
320
+
321
+ # Permute the dimensions back to (b, target_h, target_w, dim)
322
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
323
+
324
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
325
+ image_features = image_features.flatten(1, 2)
326
+
327
+ return image_features
328
+
329
+ def _forward(self, images, interpolate_token=576):
330
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
331
+ image_features = self.vision_tower.forward(
332
+ images.to(device=self.device, dtype=self.dtype),
333
+ output_hidden_states=True,
334
+ ).hidden_states[-1]
335
+ interp_features = self.interpolate(image_features)
336
+ return interp_features
337
+
338
+
339
+ def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
340
+ vision_tower_aux_name_list = getattr(
341
+ vision_tower_cfg,
342
+ "mm_vision_tower_aux_list",
343
+ getattr(vision_tower_cfg, "vision_tower_aux_list", None),
344
+ )
345
+ vision_tower_aux_token_len_list = getattr(
346
+ vision_tower_cfg,
347
+ "mm_vision_tower_aux_token_len_list",
348
+ getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
349
+ )
350
+ vision_tower_aux_list = []
351
+ for vision_tower_aux_name, vision_tower_aux_token_len in zip(
352
+ vision_tower_aux_name_list, vision_tower_aux_token_len_list
353
+ ):
354
+ config = copy.deepcopy(vision_tower_cfg)
355
+ vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
356
+ if "siglip" in vision_tower_aux_name.lower():
357
+ vision_tower_aux_list.append(
358
+ SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
359
+ )
360
+
361
+ # SSL-based Vision Towers
362
+ elif "dinov2" in vision_tower_aux_name.lower():
363
+ vision_tower_aux_list.append(
364
+ DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
365
+ )
366
+ else:
367
+ raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
368
+ return vision_tower_aux_list
multimodal_projector_builder.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-unsafe
2
+ import re
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": "identity"}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+ config.mm_hidden_size = 256
36
+
37
+ if projector_type == "linear":
38
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
39
+
40
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+
49
+ if projector_type == "identity":
50
+ return IdentityMap()
51
+
52
+ raise ValueError(f"Unknown projector type: {projector_type}")
vision_sampler.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+
8
+
9
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
10
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
11
+ """
12
+ grid_size: int of the grid height and width
13
+ return:
14
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
15
+ """
16
+ grid_h = np.arange(grid_size, dtype=np.float32)
17
+ grid_w = np.arange(grid_size, dtype=np.float32)
18
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
19
+ grid = np.stack(grid, axis=0)
20
+
21
+ grid = grid.reshape([2, 1, grid_size, grid_size])
22
+
23
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24
+ if cls_token:
25
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
26
+ return pos_embed
27
+
28
+
29
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
30
+ assert embed_dim % 2 == 0
31
+
32
+ # use half of dimensions to encode grid_h
33
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
34
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
35
+
36
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
37
+ return emb
38
+
39
+
40
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
41
+ """
42
+ embed_dim: output dimension for each position
43
+ pos: a list of positions to be encoded: size (M,)
44
+ out: (M, D)
45
+ """
46
+ assert embed_dim % 2 == 0
47
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
48
+ omega /= embed_dim / 2.0
49
+ omega = 1.0 / 10000**omega # (D/2,)
50
+
51
+ pos = pos.reshape(-1) # (M,)
52
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
53
+
54
+ emb_sin = np.sin(out) # (M, D/2)
55
+ emb_cos = np.cos(out) # (M, D/2)
56
+
57
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
58
+ return emb
59
+
60
+
61
+ class CrossAttention(nn.Module):
62
+
63
+ def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
64
+ super().__init__()
65
+ self.hidden_dim = hidden_dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = self.hidden_dim // self.num_heads
68
+
69
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
70
+ raise ValueError(
71
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
72
+ f" and `num_heads`: {self.num_heads})."
73
+ )
74
+
75
+ self.q_proj = nn.Sequential(
76
+ nn.LayerNorm(q_dim),
77
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
78
+ )
79
+ self.k_proj = nn.Sequential(
80
+ nn.LayerNorm(kv_dim),
81
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
82
+ )
83
+ self.v_proj = nn.Sequential(
84
+ nn.LayerNorm(kv_dim),
85
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
86
+ )
87
+ self.o_proj = nn.Linear(
88
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
89
+ )
90
+
91
+ def forward(self, vision_latents, queries, attention_mask):
92
+
93
+ bsz, q_len, _ = queries.size()
94
+ bsz, v_len, _ = vision_latents.size()
95
+
96
+ query_states = self.q_proj(queries)
97
+ key_states = self.k_proj(vision_latents)
98
+ value_states = self.v_proj(vision_latents)
99
+
100
+ query_states = query_states.view(
101
+ bsz, q_len, self.num_heads, self.head_dim
102
+ ).transpose(1, 2)
103
+ key_states = key_states.view(
104
+ bsz, v_len, self.num_heads, self.head_dim
105
+ ).transpose(1, 2)
106
+ value_states = value_states.view(
107
+ bsz, v_len, self.num_heads, self.head_dim
108
+ ).transpose(1, 2)
109
+
110
+ if attention_mask is not None:
111
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
112
+ raise ValueError(
113
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
114
+ )
115
+
116
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
117
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
118
+ if query_states.device.type == "cuda" and attention_mask is not None:
119
+ query_states = query_states.contiguous()
120
+ key_states = key_states.contiguous()
121
+ value_states = value_states.contiguous()
122
+
123
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
124
+ query_states,
125
+ key_states,
126
+ value_states,
127
+ attn_mask=attention_mask,
128
+ )
129
+
130
+ attn_output = attn_output.transpose(1, 2).contiguous()
131
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
132
+
133
+ attn_output = self.o_proj(attn_output)
134
+
135
+ return attn_output
136
+
137
+
138
+ class AggregationBlock(nn.Module):
139
+ def __init__(
140
+ self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
141
+ ):
142
+ super().__init__()
143
+ self.hidden_dim = hidden_dim
144
+ self.num_heads = num_heads
145
+ self.head_dim = self.hidden_dim // self.num_heads
146
+
147
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
148
+ raise ValueError(
149
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
150
+ f" and `num_heads`: {self.num_heads})."
151
+ )
152
+
153
+ self.attention = attention
154
+ if attention:
155
+ self.attention_layer = CrossAttention(
156
+ q_dim, kv_dim, hidden_dim, num_heads, attention_bias
157
+ )
158
+ else:
159
+ self.attention_layer = MLP(kv_dim, q_dim, q_dim)
160
+
161
+ def forward(self, vision_latents, queries, attention_mask):
162
+ if self.attention:
163
+ queries = self.attention_layer(vision_latents, queries, attention_mask)
164
+ else:
165
+ queries = self.attention_layer(vision_latents)
166
+
167
+ return queries
168
+
169
+
170
+ class MultiKVCrossAttention(nn.Module):
171
+
172
+ def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
173
+ super().__init__()
174
+
175
+ self.hidden_dim = hidden_dim
176
+ self.num_heads = num_heads
177
+ self.head_dim = self.hidden_dim // self.num_heads
178
+
179
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
180
+ raise ValueError(
181
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
182
+ f" and `num_heads`: {self.num_heads})."
183
+ )
184
+
185
+ self.q_proj = nn.Sequential(
186
+ nn.LayerNorm(q_dim),
187
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
188
+ )
189
+ self.num_of_kvs = len(kv_dim_list)
190
+ for i, kv_dim in enumerate(kv_dim_list):
191
+ setattr(
192
+ self,
193
+ "k_proj_{}".format(i),
194
+ nn.Sequential(
195
+ nn.LayerNorm(kv_dim),
196
+ nn.Linear(
197
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
198
+ ),
199
+ ),
200
+ )
201
+ setattr(
202
+ self,
203
+ "v_proj_{}".format(i),
204
+ nn.Sequential(
205
+ nn.LayerNorm(kv_dim),
206
+ nn.Linear(
207
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
208
+ ),
209
+ ),
210
+ )
211
+ self.o_proj = nn.Linear(
212
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
213
+ )
214
+
215
+ def forward(
216
+ self,
217
+ queries,
218
+ *vision_latents_attention_mask_list,
219
+ ):
220
+
221
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
222
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
223
+
224
+ bsz, q_len, _ = queries.size()
225
+
226
+ query_states = self.q_proj(queries)
227
+ key_states = torch.cat(
228
+ [
229
+ getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
230
+ for i in range(self.num_of_kvs)
231
+ ],
232
+ dim=1,
233
+ )
234
+ value_states = torch.cat(
235
+ [
236
+ getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
237
+ for i in range(self.num_of_kvs)
238
+ ],
239
+ dim=1,
240
+ )
241
+
242
+ v_len = key_states.shape[1]
243
+
244
+ query_states = query_states.view(
245
+ bsz, q_len, self.num_heads, self.head_dim
246
+ ).transpose(1, 2)
247
+ key_states = key_states.view(
248
+ bsz, v_len, self.num_heads, self.head_dim
249
+ ).transpose(1, 2)
250
+ value_states = value_states.view(
251
+ bsz, v_len, self.num_heads, self.head_dim
252
+ ).transpose(1, 2)
253
+
254
+ # if kv_weight is not None:
255
+ # kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
256
+
257
+ attention_mask = torch.cat(attention_mask_list, dim=-1)
258
+
259
+ if attention_mask is not None:
260
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
261
+ raise ValueError(
262
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
263
+ )
264
+
265
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
266
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
267
+ if query_states.device.type == "cuda" and attention_mask is not None:
268
+ query_states = query_states.contiguous()
269
+ key_states = key_states.contiguous()
270
+ value_states = value_states.contiguous()
271
+
272
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
273
+ query_states,
274
+ key_states,
275
+ value_states,
276
+ attn_mask=attention_mask,
277
+ )
278
+ # attn_output = spda(
279
+ # query_states,
280
+ # key_states,
281
+ # value_states,
282
+ # attn_mask=attention_mask,
283
+ # additional_score=kv_weight
284
+ # )
285
+
286
+ attn_output = attn_output.transpose(1, 2).contiguous()
287
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
288
+
289
+ attn_output = self.o_proj(attn_output)
290
+
291
+ return attn_output
292
+
293
+
294
+ class MLP(nn.Module):
295
+ def __init__(self, d_in, d_hidden, d_out):
296
+ super().__init__()
297
+ self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
298
+ self.act = nn.GELU()
299
+ self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)
300
+
301
+ def forward(self, x):
302
+ return self.linear_2(self.act(self.linear_1(x)))
303
+
304
+
305
+ class VisionCrossAttentionLayer(nn.Module):
306
+ def __init__(
307
+ self,
308
+ q_dim,
309
+ context_dim,
310
+ kv_dim_list,
311
+ kv_size_list,
312
+ hidden_dim=1024,
313
+ layer_idx=0,
314
+ ):
315
+ super().__init__()
316
+ num_heads = 16
317
+ self.num_of_kvs = len(kv_dim_list)
318
+
319
+ self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
320
+ self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
321
+ # if self.num_of_kvs > 1:
322
+ # self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
323
+ # self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
324
+ self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
325
+
326
+ self.norm = nn.LayerNorm(hidden_dim)
327
+
328
+ self.cross_attn = MultiKVCrossAttention(
329
+ hidden_dim, kv_dim_list, hidden_dim, num_heads
330
+ )
331
+ self.kv_size_list = kv_size_list
332
+ for i, kv_size in enumerate(kv_size_list):
333
+ if kv_size > 1:
334
+ setattr(
335
+ self,
336
+ "pos_embed_{}".format(i),
337
+ nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
338
+ )
339
+ # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)
340
+
341
+ def forward(
342
+ self,
343
+ queries,
344
+ context_feature,
345
+ *vision_latents_attention_mask_list,
346
+ ) -> torch.FloatTensor:
347
+
348
+ residual = queries
349
+ # queries = self.proj_in(queries)
350
+ context_feature = self.proj_context(context_feature)
351
+ # queries = queries + context_feature
352
+ queries = torch.cat([queries, context_feature], -1)
353
+
354
+ # if self.num_of_kvs > 1:
355
+ # kv_weight = self.weight_mlp(queries) # B * 1 * num_tower
356
+ # kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
357
+ # kv_weight = kv_weight.softmax(-1)
358
+ # kv_number_list = [size**2 for size in self.kv_size_list]
359
+ # kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
360
+ # else:
361
+ # kv_weight = None
362
+
363
+ queries = self.proj_in(queries)
364
+
365
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
366
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
367
+
368
+ attention_mask_list_reshaped = []
369
+ if attention_mask_list is not None:
370
+ for attention_mask in attention_mask_list:
371
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
372
+ attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
373
+ attention_mask_list_reshaped.append(attention_mask)
374
+
375
+ vision_latents_pos_list = []
376
+ for i, vision_latents in enumerate(vision_latents_list):
377
+ if vision_latents.shape[1] > 1:
378
+ vision_latents_pos_list.append(
379
+ vision_latents
380
+ + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
381
+ vision_latents.dtype
382
+ )
383
+ )
384
+ else:
385
+ vision_latents_pos_list.append(vision_latents)
386
+
387
+ # Cross Attention
388
+ attention_output = self.cross_attn(
389
+ queries, *vision_latents_pos_list, *attention_mask_list_reshaped
390
+ )
391
+
392
+ # attention_output = (attention_output * combination_weight).sum(2)
393
+ queries = queries + attention_output
394
+
395
+ queries = self.norm(queries)
396
+
397
+ queries = self.proj_out(queries)
398
+
399
+ queries = queries + residual
400
+
401
+ return queries
402
+
403
+
404
+ class VisionAggregationLayer(nn.Module):
405
+ def __init__(
406
+ self,
407
+ q_dim,
408
+ context_dim,
409
+ kv_dim_list,
410
+ kv_size_list,
411
+ hidden_dim=1024,
412
+ layer_idx=0,
413
+ ):
414
+ super().__init__()
415
+ num_heads = 16
416
+ self.num_of_kvs = len(kv_dim_list)
417
+
418
+ self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
419
+ self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
420
+
421
+ self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
422
+
423
+ self.norm = nn.LayerNorm(hidden_dim)
424
+
425
+ if self.num_of_kvs > 1:
426
+ self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)
427
+
428
+ for i, kv_size in enumerate(kv_size_list):
429
+ if kv_size > 1:
430
+ setattr(
431
+ self,
432
+ "pos_embed_{}".format(i),
433
+ nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
434
+ )
435
+ setattr(
436
+ self,
437
+ "aggregate_{}".format(i),
438
+ AggregationBlock(
439
+ True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
440
+ ),
441
+ )
442
+ else:
443
+ setattr(
444
+ self,
445
+ "aggregate_{}".format(i),
446
+ AggregationBlock(
447
+ False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
448
+ ),
449
+ )
450
+
451
+ def forward(
452
+ self,
453
+ queries,
454
+ context_feature,
455
+ *vision_latents_attention_mask_list,
456
+ ) -> torch.FloatTensor:
457
+
458
+ residual = queries
459
+ # queries = self.proj_in(queries)
460
+ context_feature = self.proj_context(context_feature)
461
+ # queries = queries + context_feature
462
+ queries = torch.cat([queries, context_feature], -1)
463
+
464
+ if self.num_of_kvs > 1:
465
+ combination_weight = self.weight_mlp(queries).softmax(
466
+ -1
467
+ ) # B * 1 * num_tower
468
+ combination_weight = combination_weight.unsqueeze(-1)
469
+ else:
470
+ combination_weight = 1
471
+
472
+ queries = self.proj_in(queries)
473
+
474
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
475
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
476
+
477
+ attention_mask_list_reshaped = []
478
+ if attention_mask_list is not None:
479
+ for attention_mask in attention_mask_list:
480
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
481
+ attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
482
+ attention_mask_list_reshaped.append(attention_mask)
483
+
484
+ vision_latents_pos_list = []
485
+ for i, vision_latents in enumerate(vision_latents_list):
486
+ if vision_latents.shape[1] > 1:
487
+ vision_latents_pos_list.append(
488
+ vision_latents
489
+ + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
490
+ vision_latents.dtype
491
+ )
492
+ )
493
+ else:
494
+ vision_latents_pos_list.append(vision_latents)
495
+
496
+ aggregated_vision_latents_list = []
497
+ for i, (vision_latents, attention_mask) in enumerate(
498
+ zip(vision_latents_pos_list, attention_mask_list_reshaped)
499
+ ):
500
+ aggregated_vision_latents_list.append(
501
+ getattr(self, "aggregate_{}".format(i))(
502
+ vision_latents, queries, attention_mask
503
+ )
504
+ )
505
+
506
+ aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)
507
+
508
+ queries = queries + (aggregated_vision_latents * combination_weight).sum(2)
509
+
510
+ queries = self.norm(queries)
511
+
512
+ queries = self.proj_out(queries)
513
+
514
+ queries = queries + residual
515
+
516
+ return queries
517
+
518
+
519
+ class VisionTokenSampler(nn.Module):
520
+ def __init__(
521
+ self,
522
+ q_dim,
523
+ context_dim,
524
+ kv_dim_list,
525
+ kv_size_list,
526
+ vision_hidden_size,
527
+ num_of_layers=1,
528
+ layer_type="joint",
529
+ ):
530
+ super().__init__()
531
+ assert layer_type in ["joint", "sep"]
532
+ if layer_type == "joint":
533
+ self.layers = nn.ModuleList(
534
+ [
535
+ VisionCrossAttentionLayer(
536
+ q_dim,
537
+ context_dim,
538
+ kv_dim_list,
539
+ kv_size_list,
540
+ vision_hidden_size,
541
+ idx,
542
+ )
543
+ for idx in range(num_of_layers)
544
+ ]
545
+ )
546
+ else:
547
+ self.layers = nn.ModuleList(
548
+ [
549
+ VisionAggregationLayer(
550
+ q_dim,
551
+ context_dim,
552
+ kv_dim_list,
553
+ kv_size_list,
554
+ vision_hidden_size,
555
+ idx,
556
+ )
557
+ for idx in range(num_of_layers)
558
+ ]
559
+ )
560
+
561
+ def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
562
+ for layer in self.layers:
563
+ queries = layer(
564
+ queries, context_feature, *vision_latents_attention_mask_list
565
+ )
566
+ return queries