jadechoghari committed
Commit 15fca6e
1 Parent(s): f3448bf

add initial files

README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ tags:
+ - video-text-to-text
+ ---
+ # Citation
+
+ ```
+ @article{shen2024longvu,
+   title={LongVU: Spatiotemporal Adaptive Compression for Long Video-Language Understanding},
+   author={Shen, Xiaoqian and Xiong, Yunyang and Zhao, Changsheng and Wu, Lemeng and Chen, Jun and Zhu, Chenchen and Liu, Zechun and Xiao, Fanyi and Varadarajan, Balakrishnan and Bordes, Florian and Liu, Zhuang and Xu, Hu and J. Kim, Hyunwoo and Soran, Bilge and Krishnamoorthi, Raghuraman and Elhoseiny, Mohamed and Chandra, Vikas},
+   journal={arXiv:2410.17434},
+   year={2024}
+ }
+ ```
cambrian_arch.py ADDED
@@ -0,0 +1,1712 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import math
+ import random
+ from abc import ABC, abstractmethod
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # define the constants
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+ from .multimodal_encoder_builder import build_vision_tower_aux_list
+ from .multimodal_projector_builder import build_vision_projector
+ from .vision_sampler import VisionTokenSampler
+
+ IS_XLA_AVAILABLE = False
+
+
+ class CambrianMetaModel:
+
+     def __init__(self, config):
+         super(CambrianMetaModel, self).__init__(config)
+
+         if hasattr(config, "mm_vision_tower_aux_list"):
+
+             projector_type = getattr(config, "mm_projector_type", "linear")
+             if projector_type == "sva":
+
+                 vision_hidden_size = config.vision_hidden_size
+                 num_query_group = config.num_query_group
+                 query_num_list = config.query_num_list
+                 connector_only = config.connector_only
+                 connector_depth = config.connector_depth
+                 self.vision_tower_aux_list = build_vision_tower_aux_list(
+                     config, delay_load=True
+                 )
+                 self.mm_projector = nn.Sequential(
+                     nn.Linear(vision_hidden_size * num_query_group, config.hidden_size),
+                     nn.GELU(),
+                     nn.Linear(config.hidden_size, config.hidden_size),
+                 )
+
+                 image_token_len = config.image_token_len
+                 vision_tower_aux_token_len_list = (
+                     self.config.mm_vision_tower_aux_token_len_list
+                 )
+                 cross_att_token_len_list = [
+                     int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5)
+                     for vision_tower_aux_token_len in vision_tower_aux_token_len_list
+                 ]
+
+                 for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list):
+                     setattr(
+                         self,
+                         "mm_projector_aux_{}".format(aux_i),
+                         nn.Sequential(
+                             nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
+                             nn.GELU(),
+                             nn.Linear(vision_hidden_size, vision_hidden_size),
+                             nn.LayerNorm(vision_hidden_size),
+                         ),
+                     )
+
+                 for query_group_i in range(num_query_group):
+                     cross_att_token_len_list = [
+                         int(vision_tower_aux_token_len**0.5)
+                         // int(query_num_list[query_group_i] ** 0.5)
+                         for vision_tower_aux_token_len in vision_tower_aux_token_len_list
+                     ]
+                     setattr(
+                         self,
+                         "vision_sampler_{}".format(query_group_i),
+                         VisionTokenSampler(
+                             vision_hidden_size,
+                             vision_hidden_size,
+                             [vision_hidden_size] * len(self.vision_tower_aux_list),
+                             cross_att_token_len_list,
+                             vision_hidden_size,
+                             connector_depth,
+                         ),
+                     )
+
+                 if not connector_only:
+                     num_of_vision_sampler_layers = (
+                         config.num_of_vision_sampler_layers
+                     ) = config.num_of_vision_sampler_layers
+                     config.start_of_vision_sampler_layers = (
+                         config.start_of_vision_sampler_layers
+                     )
+                     config.stride_of_vision_sampler_layers = (
+                         config.stride_of_vision_sampler_layers
+                     )
+                     cross_att_token_len_list = [
+                         int(vision_tower_aux_token_len**0.5)
+                         // int(image_token_len**0.5)
+                         for vision_tower_aux_token_len in vision_tower_aux_token_len_list
+                     ]
+                     self.vision_sampler_layers = nn.ModuleList(
+                         [
+                             VisionTokenSampler(
+                                 config.hidden_size,
+                                 vision_hidden_size,
+                                 [vision_hidden_size] * len(self.vision_tower_aux_list),
+                                 cross_att_token_len_list,
+                                 vision_hidden_size,
+                                 1,
+                             )
+                             for layer_idx in range(0, num_of_vision_sampler_layers)
+                         ]
+                     )
+
+                 self.vision_query = nn.Parameter(
+                     torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
+                 )
+
+                 self.image_newline = nn.Parameter(
+                     torch.empty(config.hidden_size, dtype=self.dtype)
+                 )
+
+                 self.frame_pos = torch.stack(
+                     [
+                         1
+                         / torch.pow(
+                             torch.tensor(10000),
+                             torch.tensor(2 * (hid_j // 2) / config.hidden_size),
+                         )
+                         for hid_j in range(config.hidden_size)
+                     ]
+                 )
+
+             else:
+                 self.vision_tower_aux_list = build_vision_tower_aux_list(
+                     config, delay_load=True
+                 )
+                 config.mm_hidden_size = sum(
+                     [
+                         vision_tower_aux.hidden_size
+                         for vision_tower_aux in self.vision_tower_aux_list
+                     ]
+                 )
+                 self.mm_projector = build_vision_projector(config)
+                 self.image_newline = nn.Parameter(
+                     torch.empty(config.hidden_size, dtype=self.dtype)
+                 )
+
+     def get_frame_pos(self, time_range):
+         frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
+             self.frame_pos.device
+         )
+         frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
+         frame_pos[:, 1::2] = torch.cos(frame_pos[:, 0::2])
+         frame_pos = frame_pos.unsqueeze(1)
+         return frame_pos
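For readers skimming the diff, the sinusoidal frame-position encoding that `__init__` precomputes (`self.frame_pos`) and `get_frame_pos` applies can be reproduced in isolation. This is a minimal sketch, not part of the committed file; `sinusoidal_frame_pos` and its arguments are illustrative names only.

```python
import torch

def sinusoidal_frame_pos(frame_indices: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Toy stand-in for self.frame_pos + get_frame_pos above."""
    # inverse frequencies: 1 / 10000^(2 * (j // 2) / hidden_size), one per hidden dim
    inv_freq = torch.stack([
        1 / torch.pow(torch.tensor(10000.0), torch.tensor(2 * (j // 2) / hidden_size))
        for j in range(hidden_size)
    ])
    pos = inv_freq.reshape(1, -1) * frame_indices.reshape(-1, 1).float()
    pos[:, 0::2] = torch.sin(pos[:, 0::2])
    # mirrors the method above, which takes the cosine of the (already sine-mapped) even dims
    pos[:, 1::2] = torch.cos(pos[:, 0::2])
    return pos.unsqueeze(1)  # (num_frames, 1, hidden_size)

# e.g. sinusoidal_frame_pos(torch.arange(8), 128) -> one positional vector per frame
```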
+
+     # def get_vision_tower(self):
+     #     vision_tower = getattr(self, 'vision_tower', None)
+     #     if type(vision_tower) is list:
+     #         vision_tower = vision_tower[0]
+     #     return vision_tower
+
+     def get_vision_tower_aux_list(self):
+         vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
+         return vision_tower_aux_list
+
+     def initialize_vision_modules(self, model_args, fsdp=None):
+         # vision_tower = model_args.vision_tower
+         num_query_group = model_args.num_query_group
+         query_num_list = model_args.query_num_list
+         vision_hidden_size = model_args.vision_hidden_size
+         vision_tower_aux_list = model_args.vision_tower_aux_list
+         vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
+         image_token_len = model_args.image_token_len
+         mm_vision_select_layer = model_args.mm_vision_select_layer
+         mm_vision_select_feature = model_args.mm_vision_select_feature
+         pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+         connector_only = model_args.connector_only
+         connector_depth = model_args.connector_depth
+
+         # self.config.mm_vision_tower = vision_tower
+         self.config.image_token_len = image_token_len
+         self.config.num_query_group = num_query_group
+         self.config.query_num_list = query_num_list
+         assert num_query_group == len(query_num_list)
+         self.config.connector_depth = connector_depth
+         self.config.mm_vision_tower_aux_list = vision_tower_aux_list
+         self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
+         self.config.connector_only = connector_only
+         self.config.highres_connect = model_args.highres_connect
+         self.config.highres = model_args.highres
+         self.config.frame_pos = model_args.frame_pos
+         self.config.lowres_token = model_args.lowres_token
+         self.config.connect_layer = model_args.connect_layer
+         self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83)
+         self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6)
+         self.config.is_image_newline = getattr(model_args, "is_image_newline", True)
+
+         if self.get_vision_tower_aux_list() is None:
+             vision_tower_aux_list = build_vision_tower_aux_list(model_args)
+             if model_args.unfreeze_mm_vision_tower:
+                 self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
+             else:
+                 self.vision_tower_aux_list = vision_tower_aux_list
+         else:
+             vision_tower_aux_list = self.vision_tower_aux_list
+             for vision_tower_aux in vision_tower_aux_list:
+                 vision_tower_aux.load_model()
+
+         self.config.use_mm_proj = True
+         self.config.mm_projector_type = getattr(
+             model_args, "mm_projector_type", "linear"
+         )
+         self.config.vision_hidden_size = vision_hidden_size
+         self.config.mm_vision_select_layer = mm_vision_select_layer
+         self.config.mm_vision_select_feature = mm_vision_select_feature
+
+         if getattr(self, "mm_projector", None) is None:
+
+             if self.config.mm_projector_type == "sva":
+                 self.mm_projector = nn.Sequential(
+                     nn.Linear(
+                         vision_hidden_size * num_query_group, self.config.hidden_size
+                     ),
+                     nn.GELU(),
+                     nn.Linear(self.config.hidden_size, self.config.hidden_size),
+                 )
+                 for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list):
+                     setattr(
+                         self,
+                         "mm_projector_aux_{}".format(aux_i),
+                         nn.Sequential(
+                             nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
+                             nn.GELU(),
+                             nn.Linear(vision_hidden_size, vision_hidden_size),
+                             nn.LayerNorm(vision_hidden_size),
+                         ),
+                     )
+
+                 # vision sampler for each group of query as the connector before the LLM
+                 for query_group_i in range(num_query_group):
+                     cross_att_token_len_list = [
+                         int(vision_tower_aux_token_len**0.5)
+                         // int(query_num_list[query_group_i] ** 0.5)
+                         for vision_tower_aux_token_len in vision_tower_aux_token_len_list
+                     ]
+                     setattr(
+                         self,
+                         "vision_sampler_{}".format(query_group_i),
+                         VisionTokenSampler(
+                             vision_hidden_size,
+                             vision_hidden_size,
+                             [vision_hidden_size] * len(vision_tower_aux_list),
+                             cross_att_token_len_list,
+                             vision_hidden_size,
+                             connector_depth,
+                         ),
+                     )
+
+                 # sampler layers within LLM
+                 if not connector_only:
+                     num_of_vision_sampler_layers = (
+                         self.config.num_of_vision_sampler_layers
+                     ) = model_args.num_of_vision_sampler_layers
+                     self.config.start_of_vision_sampler_layers = (
+                         model_args.start_of_vision_sampler_layers
+                     )
+                     self.config.stride_of_vision_sampler_layers = (
+                         model_args.stride_of_vision_sampler_layers
+                     )
+                     cross_att_token_len_list = [
+                         int(vision_tower_aux_token_len**0.5)
+                         // int(image_token_len**0.5)
+                         for vision_tower_aux_token_len in vision_tower_aux_token_len_list
+                     ]
+                     self.vision_sampler_layers = nn.ModuleList(
+                         [
+                             VisionTokenSampler(
+                                 self.config.hidden_size,
+                                 vision_hidden_size,
+                                 [vision_hidden_size] * len(vision_tower_aux_list),
+                                 cross_att_token_len_list,
+                                 vision_hidden_size,
+                                 1,
+                             )
+                             for layer_idx in range(0, num_of_vision_sampler_layers)
+                         ]
+                     )
+                 vision_embed_std = 1 / torch.sqrt(
+                     torch.tensor(vision_hidden_size, dtype=self.dtype)
+                 )
+                 self.vision_query = nn.Parameter(
+                     torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
+                     * vision_embed_std
+                 )
+
+                 embed_std = 1 / torch.sqrt(
+                     torch.tensor(self.config.hidden_size, dtype=self.dtype)
+                 )
+                 self.image_newline = nn.Parameter(
+                     torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+                 )
+
+             else:
+                 self.config.mm_hidden_size = sum(
+                     [
+                         vision_tower_aux.hidden_size
+                         for vision_tower_aux in vision_tower_aux_list
+                     ]
+                 )
+                 self.mm_projector = build_vision_projector(self.config)
+                 embed_std = 1 / torch.sqrt(
+                     torch.tensor(self.config.hidden_size, dtype=self.dtype)
+                 )
+                 self.image_newline = nn.Parameter(
+                     torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
+                 )
+         else:
+             # In case it is frozen by LoRA
+             for p in self.mm_projector.parameters():
+                 p.requires_grad = True
+
+         if pretrain_mm_mlp_adapter is not None:
+             mm_projector_weights = torch.load(
+                 pretrain_mm_mlp_adapter, map_location="cpu"
+             )
+
+             def get_w(weights, keyword):
+                 return {
+                     k.split(keyword + ".")[1]: v
+                     for k, v in weights.items()
+                     if keyword + "." in k
+                 }
+
+             self.mm_projector.load_state_dict(
+                 get_w(mm_projector_weights, "mm_projector"), strict=True
+             )
+
+             if self.config.mm_projector_type == "sva":
+                 for aux_i in range(len(vision_tower_aux_list)):
+                     getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict(
+                         get_w(
+                             mm_projector_weights, "mm_projector_aux_{}".format(aux_i)
+                         ),
+                         strict=True,
+                     )
+
+                 for query_group_i in range(num_query_group):
+                     getattr(
+                         self, "vision_sampler_{}".format(query_group_i)
+                     ).load_state_dict(
+                         get_w(
+                             mm_projector_weights,
+                             "vision_sampler_{}".format(query_group_i),
+                         ),
+                         strict=True,
+                     )
+
+                 if not connector_only:
+                     self.vision_sampler_layers.load_state_dict(
+                         get_w(mm_projector_weights, "vision_sampler_layers"),
+                         strict=True,
+                     )
+                 self.vision_query.data = mm_projector_weights["model.vision_query"]
+             self.image_newline.data = mm_projector_weights["model.image_newline"]
+
+
+ def unmask_attention_mask(mask, original_size):
+     original_w, original_h = original_size
+     cur_h, cur_w = mask.shape[1:3]
+
+     original_aspect_ratio = original_w / original_h
+     current_aspect_ratio = cur_w / cur_h
+
+     if original_aspect_ratio > current_aspect_ratio:
+         scale_factor = cur_w / original_w
+         new_height = int(original_h * scale_factor)
+         padding = (cur_h - new_height) // 2
+         if padding > 0:
+             mask[:, :padding, :] = 0
+             mask[:, -padding:, :] = 0
+         return mask
+     else:
+         scale_factor = cur_h / original_h
+         new_width = int(original_w * scale_factor)
+         padding = (cur_w - new_width) // 2
+         if padding > 0:
+             mask[:, :, :padding] = 0
+             mask[:, :, -padding:] = 0
+         return mask
+
+
+ def unpad_image(tensor, original_size):
+     """
+     Unpads a PyTorch tensor of a padded and resized image.
+
+     Args:
+         tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+         original_size (tuple): The original size of the image (height, width).
+
+     Returns:
+         torch.Tensor: The unpadded image tensor.
+     """
+     original_width, original_height = original_size
+     current_height, current_width = tensor.shape[1:3]
+
+     original_aspect_ratio = original_width / original_height
+     current_aspect_ratio = current_width / current_height
+
+     if original_aspect_ratio > current_aspect_ratio:
+         scale_factor = current_width / original_width
+         new_height = int(original_height * scale_factor)
+         padding = (current_height - new_height) // 2
+         unpadded_tensor = tensor[:, padding : current_height - padding, :]
+         # if 0 in unpadded_tensor.shape:
+         #     print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
+     else:
+         scale_factor = current_height / original_height
+         new_width = int(original_width * scale_factor)
+         padding = (current_width - new_width) // 2
+         unpadded_tensor = tensor[:, :, padding : current_width - padding]
+         # if 0 in unpadded_tensor.shape:
+         #     print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
+
+     return unpadded_tensor
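A small usage sketch of `unpad_image` (assuming the function defined above is in scope); the sizes are made up. Note that the code unpacks `original_size` as `(width, height)`, even though the docstring describes it as `(height, width)`.

```python
import torch

# a 1 x 24 x 24 feature map produced from a 1920 x 1080 image that was padded to a square
feat = torch.randn(1, 24, 24)
unpadded = unpad_image(feat, original_size=(1920, 1080))  # unpacked as (width, height)
print(unpadded.shape)  # torch.Size([1, 14, 24]): the top/bottom padding rows are removed
```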
+
+
453
+ class CambrianMetaForCausalLM(ABC):
454
+
455
+ @abstractmethod
456
+ def get_model(self):
457
+ pass
458
+
459
+ # def get_vision_tower(self):
460
+ # return self.get_model().get_vision_tower()
461
+
462
+ def get_vision_tower_aux_list(self):
463
+ return self.get_model().get_vision_tower_aux_list()
464
+
465
+ def rearrange_vision_tower_features_train(
466
+ self,
467
+ vision_tower_aux_feature_list,
468
+ vision_tower_aux_attention_masks_list,
469
+ query_side_len,
470
+ ):
471
+ vision_tower_aux_feature_rearranged_list = []
472
+ vision_tower_aux_attention_masks_rearranged_list = []
473
+ bs = vision_tower_aux_feature_list[0].shape[0]
474
+ for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip(
475
+ vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list
476
+ ):
477
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
478
+ assert (aux_height // query_side_len) * query_side_len == aux_height
479
+
480
+ reduce_factor = aux_height // query_side_len
481
+ vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view(
482
+ bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1
483
+ )
484
+ vision_tower_aux_feature_rearranged = (
485
+ vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5)
486
+ .contiguous()
487
+ .flatten(0, 2)
488
+ .flatten(1, 2)
489
+ )
490
+
491
+ vision_tower_aux_attention_masks_rearranged = (
492
+ vision_tower_aux_attention_masks.view(
493
+ bs * query_side_len * query_side_len, reduce_factor * reduce_factor
494
+ )
495
+ )
496
+
497
+ vision_tower_aux_feature_rearranged_list.append(
498
+ vision_tower_aux_feature_rearranged
499
+ )
500
+ vision_tower_aux_attention_masks_rearranged_list.append(
501
+ vision_tower_aux_attention_masks_rearranged
502
+ )
503
+ return (
504
+ vision_tower_aux_feature_rearranged_list,
505
+ vision_tower_aux_attention_masks_rearranged_list,
506
+ )
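The reshape/permute sequence in `rearrange_vision_tower_features_train` groups each auxiliary feature map into one window of `reduce_factor x reduce_factor` tokens per query position. A shape-only sketch with made-up sizes (not repository code):

```python
import torch

bs, aux_side, query_side, C = 2, 24, 8, 16
reduce_factor = aux_side // query_side          # 3 aux tokens per query along each axis
feat = torch.randn(bs, aux_side * aux_side, C)  # flattened aux feature map
windows = (
    feat.view(bs, query_side, reduce_factor, query_side, reduce_factor, C)
    .permute(0, 1, 3, 2, 4, 5)
    .contiguous()
    .flatten(0, 2)                              # (bs * query_side * query_side, ...)
    .flatten(1, 2)                              # (..., reduce_factor * reduce_factor, C)
)
print(windows.shape)                            # torch.Size([128, 9, 16])
```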
507
+
508
+ def rearrange_vision_tower_features_inference(
509
+ self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False
510
+ ):
511
+ vision_tower_aux_feature_rearranged_list = []
512
+ vision_tower_aux_attention_masks_rearranged_list = []
513
+ bs = vision_tower_aux_feature_list[0].shape[0]
514
+ for vision_tower_aux_feature in vision_tower_aux_feature_list:
515
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
516
+ assert (aux_height // query_side_len) * query_side_len == aux_height
517
+
518
+ reduce_factor = aux_height // query_side_len
519
+
520
+ vision_tower_aux_feature_rearranged = []
521
+ vision_tower_aux_attention_masks_rearranged = []
522
+ for batch_i in range(bs):
523
+ image_size = image_sizes[batch_i]
524
+ cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i]
525
+
526
+ cur_vision_tower_aux_attention_masks_rearranged = torch.ones(
527
+ (1, aux_height, aux_width),
528
+ dtype=torch.bool,
529
+ device=cur_vision_tower_aux_feature.device,
530
+ )
531
+ cur_vision_tower_aux_feature_rearranged = (
532
+ cur_vision_tower_aux_feature.view(
533
+ 1,
534
+ query_side_len,
535
+ reduce_factor,
536
+ query_side_len,
537
+ reduce_factor,
538
+ -1,
539
+ )
540
+ )
541
+ cur_vision_tower_aux_feature_rearranged = (
542
+ cur_vision_tower_aux_feature_rearranged.permute(
543
+ 0, 1, 3, 2, 4, 5
544
+ ).contiguous()
545
+ )
546
+ if unpad:
547
+ cur_vision_tower_aux_feature_rearranged = unpad_image(
548
+ cur_vision_tower_aux_feature_rearranged, image_size
549
+ )
550
+ cur_vision_tower_aux_feature_rearranged = (
551
+ cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2)
552
+ ) # query_side_len*query_side_len X reduce_factor*reduce_factor X C
553
+
554
+ cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask(
555
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
556
+ )
557
+ cur_vision_tower_aux_attention_masks_rearranged = (
558
+ cur_vision_tower_aux_attention_masks_rearranged.view(
559
+ 1, query_side_len, reduce_factor, query_side_len, reduce_factor
560
+ )
561
+ .permute(0, 1, 3, 2, 4)
562
+ .contiguous()
563
+ )
564
+ if unpad:
565
+ cur_vision_tower_aux_attention_masks_rearranged = unpad_image(
566
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
567
+ )
568
+ cur_vision_tower_aux_attention_masks_rearranged = (
569
+ cur_vision_tower_aux_attention_masks_rearranged.flatten(
570
+ 0, 2
571
+ ).flatten(1, 2)
572
+ )
573
+
574
+ cur_vision_tower_aux_attention_masks_rearranged[
575
+ cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0
576
+ ] = True
577
+
578
+ vision_tower_aux_feature_rearranged.append(
579
+ cur_vision_tower_aux_feature_rearranged
580
+ )
581
+ vision_tower_aux_attention_masks_rearranged.append(
582
+ cur_vision_tower_aux_attention_masks_rearranged
583
+ )
584
+
585
+ vision_tower_aux_feature_rearranged = torch.cat(
586
+ vision_tower_aux_feature_rearranged, 0
587
+ )
588
+ vision_tower_aux_attention_masks_rearranged = torch.cat(
589
+ vision_tower_aux_attention_masks_rearranged, 0
590
+ )
591
+
592
+ vision_tower_aux_feature_rearranged_list.append(
593
+ vision_tower_aux_feature_rearranged
594
+ )
595
+ vision_tower_aux_attention_masks_rearranged_list.append(
596
+ vision_tower_aux_attention_masks_rearranged
597
+ )
598
+
599
+ return (
600
+ vision_tower_aux_feature_rearranged_list,
601
+ vision_tower_aux_attention_masks_rearranged_list,
602
+ )
603
+
604
+ def encode_images(self, image_aux_list, encode_type=None):
605
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
606
+ image_aux_features_list = []
607
+ chunk_size = 64
608
+ if encode_type == "dino":
609
+ image_aux = image_aux_list[-1]
610
+ vision_tower_aux = vision_tower_aux_list[-1]
611
+ if image_aux.shape[0] > chunk_size:
612
+ image_aux_features_chunks = []
613
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
614
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
615
+ chunk = image_aux[start_idx:end_idx]
616
+ image_aux_features_chunk = vision_tower_aux(chunk)
617
+ image_aux_features_chunks.append(image_aux_features_chunk)
618
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
619
+ else:
620
+ image_aux_features = vision_tower_aux(image_aux)
621
+ return image_aux_features
622
+ elif encode_type == "siglip":
623
+ image_aux = image_aux_list[0]
624
+ vision_tower_aux = vision_tower_aux_list[0]
625
+ if image_aux.shape[0] > chunk_size:
626
+ image_aux_features_chunks = []
627
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
628
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
629
+ chunk = image_aux[start_idx:end_idx]
630
+ image_aux_features_chunk = vision_tower_aux(chunk)
631
+ image_aux_features_chunks.append(image_aux_features_chunk)
632
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
633
+ else:
634
+ image_aux_features = vision_tower_aux(image_aux)
635
+ return image_aux_features
636
+ else:
637
+ for image_aux, vision_tower_aux in zip(
638
+ image_aux_list, vision_tower_aux_list
639
+ ):
640
+ if image_aux.shape[0] > chunk_size:
641
+ image_aux_features_chunks = []
642
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
643
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
644
+ chunk = image_aux[start_idx:end_idx]
645
+ image_aux_features_chunk = vision_tower_aux(chunk)
646
+ image_aux_features_chunks.append(image_aux_features_chunk)
647
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
648
+ else:
649
+ image_aux_features = vision_tower_aux(image_aux)
650
+ image_aux_features_list.append(image_aux_features)
651
+ return image_aux_features_list
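`encode_images` repeats the same chunking pattern in all three branches: when more than 64 frames are passed, they are run through the vision tower in slices so peak activation memory stays bounded. A standalone sketch of that pattern (the `tower` callable here is a stand-in, not a repository API):

```python
import torch

def encode_in_chunks(frames: torch.Tensor, tower, chunk_size: int = 64) -> torch.Tensor:
    """Run `tower` over `frames` in slices of `chunk_size` to bound peak memory."""
    if frames.shape[0] <= chunk_size:
        return tower(frames)
    outputs = [tower(frames[i : i + chunk_size]) for i in range(0, frames.shape[0], chunk_size)]
    return torch.cat(outputs, dim=0)

tower = lambda x: x.mean(dim=(2, 3))            # stand-in vision tower: (N, C, H, W) -> (N, C)
feats = encode_in_chunks(torch.randn(200, 3, 8, 8), tower)
print(feats.shape)                              # torch.Size([200, 3])
```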
652
+
653
+ def select_frame(
654
+ self,
655
+ feature_list,
656
+ split_sizes,
657
+ input_ids,
658
+ new_image_aux_list,
659
+ image_sizes,
660
+ window_size=16,
661
+ threshold=0.83,
662
+ ):
663
+ dino_features_batch = torch.split(feature_list, split_sizes, dim=0)
664
+ new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0)
665
+ new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0)
666
+ new_split_sizes = []
667
+ selected_frames_all_0 = []
668
+ selected_frames_all_1 = []
669
+ selected_frames_feature_all = []
670
+ selected_frame_indices_all = []
671
+ for i_batch, frame_features in enumerate(dino_features_batch):
672
+ try:
673
+ if "llama" in self.get_model().config.model_type:
674
+ text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
675
+ else:
676
+ text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
677
+ except:
678
+ text_len = len(input_ids[i_batch])
679
+ original_width, original_height = image_sizes[i_batch]
680
+ if getattr(self.get_model().config, "highres", False):
681
+ token_per_frame = self.get_model().config.lowres_token ** 2
682
+ else:
683
+ token_per_frame = self.get_model().config.image_token_len
684
+ # current_height, current_width = token_per_side, token_per_side
685
+ # original_aspect_ratio = original_width / original_height
686
+ # current_aspect_ratio = current_width / current_height
687
+ # if original_aspect_ratio > current_aspect_ratio:
688
+ # scale_factor = current_width / original_width
689
+ # new_height = int(original_height * scale_factor)
690
+ # padding = math.ceil((current_height - new_height) / 2.0)
691
+ # token_per_frame = (
692
+ # current_height - padding * 2
693
+ # ) * token_per_side + token_per_side
694
+ # else:
695
+ # scale_factor = current_height / original_height
696
+ # new_width = int(original_width * scale_factor)
697
+ # padding = math.ceil((current_width - new_width) / 2.0)
698
+ # token_per_frame = (current_width - padding * 2) * token_per_side + (
699
+ # current_width - padding * 2
700
+ # )
701
+ # token_per_frame = (
702
+ # token_per_side**2 if token_per_frame < 1 else token_per_frame
703
+ # )
704
+ max_num_frames = max(
705
+ 1,
706
+ (
707
+ self.get_model().config.tokenizer_model_max_length
708
+ - text_len
709
+ - getattr(self.get_model().config, "inference_max_length", 16)
710
+ )
711
+ // token_per_frame,
712
+ )
713
+ if len(frame_features) < max_num_frames:
714
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch])
715
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch])
716
+ selected_frames_feature_all.append(frame_features)
717
+ new_split_sizes.append(len(frame_features))
718
+ selected_frame_indices_all.append(torch.arange(len(frame_features)))
719
+ continue
720
+
721
+ num_segments = len(frame_features) // window_size
722
+ if num_segments == 0:
723
+ query_feature = frame_features.flatten(1, 2)
724
+ query_feature = query_feature / torch.norm(
725
+ (query_feature), dim=1, keepdim=True
726
+ )
727
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
728
+ similarities[len(frame_features) // 2] = 0
729
+ indices = torch.where(similarities < threshold)[0]
730
+ selected_frame_indices_all.append(indices)
731
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices])
732
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices])
733
+ selected_frames_feature_all.append(frame_features[indices])
734
+ new_split_sizes.append(len(indices))
735
+ continue
736
+ segments_frames_0 = []
737
+ segments_frames_1 = []
738
+ segments_features = []
739
+ for start_idx in range(0, len(frame_features), window_size):
740
+ end_idx = min(start_idx + window_size, len(frame_features))
741
+ segments_frames_0.append(
742
+ new_image_aux_batch_0[i_batch][start_idx:end_idx]
743
+ )
744
+ segments_frames_1.append(
745
+ new_image_aux_batch_1[i_batch][start_idx:end_idx]
746
+ )
747
+ segments_features.append(frame_features[start_idx:end_idx])
748
+ selected_frames_0 = []
749
+ selected_frames_1 = []
750
+ selected_features = []
751
+ selected_frame_indices = []
752
+ for i, segment in enumerate(segments_features):
753
+ query_feature = segment.flatten(1, 2)
754
+ query_feature = query_feature / torch.norm(
755
+ (query_feature), dim=1, keepdim=True
756
+ )
757
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
758
+ similarities[len(segment) // 2] = 0
759
+ indices = torch.where(similarities < threshold)[0]
760
+ selected_frames_0.append(segments_frames_0[i][indices])
761
+ selected_frames_1.append(segments_frames_1[i][indices])
762
+ selected_features.append(segment[indices])
763
+ selected_frame_indices.extend(indices + i * window_size)
764
+ selected_frames_0 = torch.cat(selected_frames_0, dim=0)
765
+ selected_frames_1 = torch.cat(selected_frames_1, dim=0)
766
+ selected_features = torch.cat(selected_features, dim=0)
767
+ selected_frame_indices = torch.tensor(selected_frame_indices)
768
+ # ablation
769
+ max_num_frames = 400 # in case of OOM
770
+ if len(selected_frames_0) > max_num_frames:
771
+ interval = len(selected_frames_0) / float(max_num_frames)
772
+ indices = [int(interval * i) for i in range(max_num_frames)]
773
+ new_split_sizes.append(len(indices))
774
+ selected_frames_all_0.append(selected_frames_0[indices])
775
+ selected_frames_all_1.append(selected_frames_1[indices])
776
+ selected_frames_feature_all.append(selected_features[indices])
777
+ selected_frame_indices = selected_frame_indices[indices]
778
+ else:
779
+ new_split_sizes.append(len(selected_frames_0))
780
+ selected_frames_all_0.append(selected_frames_0)
781
+ selected_frames_all_1.append(selected_frames_1)
782
+ selected_frames_feature_all.append(selected_features)
783
+ selected_frame_indices_all.append(selected_frame_indices)
784
+ selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0)
785
+ selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0)
786
+ selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0)
787
+ return (
788
+ selected_frames_feature_all,
789
+ new_split_sizes,
790
+ [selected_frames_all_0, selected_frames_all_1],
791
+ selected_frame_indices_all,
792
+ )
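The core test inside `select_frame` scores each frame of a window by its mean normalized similarity to the other frames and keeps only frames below the DINO threshold, always forcing the middle frame of the window to be kept. A minimal sketch of that per-window step (function name and shapes are illustrative only):

```python
import torch

def prune_similar_frames(frame_features: torch.Tensor, threshold: float = 0.83) -> torch.Tensor:
    """Return indices of frames whose mean similarity to the window is below `threshold`."""
    q = frame_features.flatten(1)               # (T, tokens * C)
    q = q / torch.norm(q, dim=1, keepdim=True)
    similarities = torch.mean(q @ q.T, dim=1)   # mean similarity of each frame to all frames
    similarities[len(frame_features) // 2] = 0  # always keep the middle frame of the window
    return torch.where(similarities < threshold)[0]

# e.g. prune_similar_frames(torch.randn(16, 144, 32)) -> indices of the retained frames
```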
793
+
794
+ def prepare_inputs_labels_for_multimodal(
795
+ self,
796
+ input_ids,
797
+ position_ids,
798
+ attention_mask,
799
+ past_key_values,
800
+ labels,
801
+ images,
802
+ image_aux_attention_masks_list=None,
803
+ image_sizes=None,
804
+ ):
805
+ # vision_tower = self.get_vision_tower()
806
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
807
+ if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1:
808
+ return (
809
+ input_ids,
810
+ position_ids,
811
+ attention_mask,
812
+ past_key_values,
813
+ None,
814
+ labels,
815
+ None,
816
+ None,
817
+ None,
818
+ None,
819
+ )
820
+
821
+ image_aux_list = images
822
+
823
+ split_sizes = None
824
+
825
+ if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5:
826
+ split_sizes_ori = [
827
+ 1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0]
828
+ ]
829
+ new_image_aux_list = []
830
+ for image_aux in image_aux_list:
831
+ if type(image_aux) is list:
832
+ image_aux = [
833
+ x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux
834
+ ]
835
+ concat_image_aux = torch.cat([image for image in image_aux], dim=0)
836
+ new_image_aux_list.append(concat_image_aux)
837
+ image_aux_features_dino = self.encode_images(
838
+ new_image_aux_list, encode_type="dino"
839
+ )
840
+
841
+ (
842
+ image_aux_features_dino,
843
+ split_sizes,
844
+ new_image_aux_list,
845
+ selected_frame_indices_all,
846
+ ) = self.select_frame(
847
+ image_aux_features_dino,
848
+ split_sizes_ori,
849
+ input_ids,
850
+ new_image_aux_list,
851
+ image_sizes,
852
+ threshold=getattr(self.get_model().config, "dino_threshold", 0.83),
853
+ )
854
+
855
+ image_aux_features_siglip = self.encode_images(
856
+ new_image_aux_list, encode_type="siglip"
857
+ )
858
+ image_aux_features_list = [
859
+ image_aux_features_siglip,
860
+ image_aux_features_dino,
861
+ ]
862
+
863
+ bs = image_aux_features_list[0].shape[0]
864
+ dtype = new_image_aux_list[0].dtype
865
+
866
+ frame_sizes = []
867
+ for i in range(len(image_sizes)):
868
+ for j in range(split_sizes[i]):
869
+ frame_sizes.append(image_sizes[i])
870
+ image_sizes = frame_sizes
871
+ else:
872
+ image_aux_features_list = self.encode_images(image_aux_list)
873
+ bs = image_aux_list[0].shape[0]
874
+ dtype = image_aux_list[0].dtype
875
+
876
+ image_token_len = self.get_model().config.image_token_len
877
+ query_num_list = self.get_model().config.query_num_list
878
+
879
+ final_height = final_width = int(image_token_len**0.5)
880
+
881
+ final_image_features_list = []
882
+ final_image_features_down_list = []
883
+
884
+ # only needed for sva
885
+ vision_tower_aux_feature_list_final = None
886
+ vision_tower_aux_attention_masks_list_final = None
887
+ global_context_feature_final = None
888
+
889
+ if self.get_model().config.mm_projector_type == "sva":
890
+ vision_tower_aux_feature_list = []
891
+ vision_tower_aux_attention_masks_list = []
892
+ # get vision tokens from each vision tower
893
+ for aux_i in range(len(vision_tower_aux_list)):
894
+ image_aux_features = image_aux_features_list[aux_i]
895
+
896
+ image_aux_features = getattr(
897
+ self.get_model(), "mm_projector_aux_{}".format(aux_i)
898
+ )(image_aux_features).to(dtype)
899
+ if aux_i == 0:
900
+ global_context_feature = image_aux_features.mean(1).view(
901
+ bs, 1, 1, -1
902
+ )
903
+
904
+ vision_tower_aux_feature_list.append(image_aux_features)
905
+ input_mix_res = True
906
+ input_high_res = True
907
+ # perform vision sampling for each query group
908
+ for query_group_i, query_num in enumerate(query_num_list):
909
+ query_features_i = (
910
+ self.get_model()
911
+ .vision_query[query_group_i, :]
912
+ .view(1, 1, 1, -1)
913
+ .expand(bs, query_num, -1, -1)
914
+ )
915
+ global_context_feature_i = global_context_feature.expand(
916
+ -1, query_num, 1, -1
917
+ ).flatten(0, 1)
918
+ query_side_len = int(query_num**0.5)
919
+ if IS_XLA_AVAILABLE:
920
+ (
921
+ vision_tower_aux_feature_list_i,
922
+ vision_tower_aux_attention_masks_list_i,
923
+ ) = self.rearrange_vision_tower_features_train(
924
+ vision_tower_aux_feature_list,
925
+ image_aux_attention_masks_list,
926
+ query_side_len,
927
+ )
928
+ else:
929
+ (
930
+ vision_tower_aux_feature_list_i,
931
+ vision_tower_aux_attention_masks_list_i,
932
+ ) = self.rearrange_vision_tower_features_inference(
933
+ vision_tower_aux_feature_list, query_side_len, image_sizes
934
+ )
935
+
936
+ query_features_i = getattr(
937
+ self.get_model(), "vision_sampler_{}".format(query_group_i)
938
+ )(
939
+ query_features_i.flatten(0, 1),
940
+ global_context_feature_i,
941
+ *vision_tower_aux_feature_list_i,
942
+ *vision_tower_aux_attention_masks_list_i,
943
+ )
944
+ query_features_i = query_features_i.view(bs, query_num, -1)
945
+
946
+ if split_sizes is not None:
947
+ try:
948
+ if "llama" in self.get_model().config.model_type:
949
+ text_len = torch.where(input_ids[0] == 128002)[-1][0]
950
+ else:
951
+ text_len = torch.where(input_ids[0] == 151643)[-1][0]
952
+ except:
953
+ text_len = len(input_ids[0])
954
+ max_visual_len = (
955
+ self.get_model().config.tokenizer_model_max_length
956
+ - text_len
957
+ - getattr(self.get_model().config, "inference_max_length", 16)
958
+ )
959
+ max_num_frames = max(
960
+ 1,
961
+ math.floor(max_visual_len // (final_height * final_width)),
962
+ )
963
+ max_num_frames_low = max(
964
+ 1,
965
+ math.floor(
966
+ max_visual_len
967
+ // (self.get_model().config.lowres_token ** 2)
968
+ ),
969
+ )
970
+ if split_sizes[0] < max_num_frames:
971
+ input_mix_res = False
972
+ elif split_sizes[0] > max_num_frames_low:
973
+ input_mix_res = False
974
+ input_high_res = False
975
+
976
+ # input_mix_res = False # ablation
977
+
978
+ if (getattr(self.config, "highres", False)) and input_mix_res:
979
+ _query_features_i = (
980
+ query_features_i.permute(0, 2, 1)
981
+ .contiguous()
982
+ .view(bs, -1, query_side_len, query_side_len)
983
+ )
984
+ _query_features_i = F.interpolate(
985
+ _query_features_i.float(),
986
+ size=(
987
+ self.get_model().config.lowres_token,
988
+ self.get_model().config.lowres_token,
989
+ ),
990
+ mode="bilinear",
991
+ align_corners=False,
992
+ ).to(dtype=query_features_i.dtype)
993
+ _query_features_i = (
994
+ _query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
995
+ )
996
+ final_image_features_down_list.append(_query_features_i)
997
+
998
+ # interpolate to the final target size
999
+ if query_side_len != final_height:
1000
+ query_features_i = (
1001
+ query_features_i.permute(0, 2, 1)
1002
+ .contiguous()
1003
+ .view(bs, -1, query_side_len, query_side_len)
1004
+ )
1005
+ if input_high_res:
1006
+ query_features_i = F.interpolate(
1007
+ query_features_i.float(),
1008
+ size=(final_height, final_width),
1009
+ mode="bilinear",
1010
+ align_corners=False,
1011
+ ).to(dtype=query_features_i.dtype)
1012
+ else:
1013
+ query_features_i = F.interpolate(
1014
+ query_features_i.float(),
1015
+ size=(8, 8),
1016
+ mode="bilinear",
1017
+ align_corners=False,
1018
+ ).to(dtype=query_features_i.dtype)
1019
+ query_features_i = (
1020
+ query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
1021
+ )
1022
+ final_image_features_list.append(query_features_i)
1023
+
1024
+ if IS_XLA_AVAILABLE:
1025
+ (
1026
+ vision_tower_aux_feature_list_final,
1027
+ vision_tower_aux_attention_masks_list_final,
1028
+ ) = self.rearrange_vision_tower_features_train(
1029
+ vision_tower_aux_feature_list,
1030
+ image_aux_attention_masks_list,
1031
+ final_height,
1032
+ )
1033
+ global_context_feature_final = global_context_feature.expand(
1034
+ -1, final_height * final_width, 1, -1
1035
+ ).flatten(0, 1)
1036
+ else:
1037
+ final_image_features_list = image_aux_features_list
1038
+
1039
+ image_features = torch.cat(final_image_features_list, -1)
1040
+ image_features = self.get_model().mm_projector(image_features).to(dtype)
1041
+
1042
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1043
+ image_features_down = torch.cat(final_image_features_down_list, -1)
1044
+ image_features_down = (
1045
+ self.get_model().mm_projector(image_features_down).to(dtype)
1046
+ )
1047
+
1048
+ if IS_XLA_AVAILABLE:
1049
+ image_features = image_features.view(
1050
+ image_features.shape[0], final_height, final_width, -1
1051
+ )
1052
+ image_features = torch.cat(
1053
+ (
1054
+ image_features,
1055
+ self.model.image_newline[None, None, None, :].expand(
1056
+ image_features.shape[0], final_height, 1, -1
1057
+ ),
1058
+ ),
1059
+ dim=2,
1060
+ )
1061
+ image_features = image_features.flatten(1, 2)
1062
+ final_size = [(final_height, final_width)] * bs
1063
+
1064
+ else:
1065
+ image_features = image_features.view(bs, final_height, final_width, -1)
1066
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1067
+ image_features_down = image_features_down.view(
1068
+ bs,
1069
+ self.get_model().config.lowres_token,
1070
+ self.get_model().config.lowres_token,
1071
+ -1,
1072
+ )
1073
+ image_features_unpadded = []
1074
+ image_features_downsample = []
1075
+ final_size = []
1076
+ if self.get_model().config.mm_projector_type == "sva":
1077
+ (
1078
+ vision_tower_aux_feature_list_final,
1079
+ vision_tower_aux_attention_masks_list_final,
1080
+ ) = self.rearrange_vision_tower_features_inference(
1081
+ vision_tower_aux_feature_list, final_height, image_sizes, unpad=True
1082
+ )
1083
+ global_context_feature_final = []
1084
+ for batch_i in range(bs):
1085
+ cur_image_feature = image_features[batch_i]
1086
+ image_size = image_sizes[batch_i]
1087
+
1088
+ cur_image_feature = unpad_image(
1089
+ cur_image_feature.unsqueeze(0), image_size
1090
+ )
1091
+
1092
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1093
+ try: # fix bug for some invalid image
1094
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1095
+ final_size.append((cur_h, cur_w))
1096
+ except:
1097
+ # print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True)
1098
+ cur_image_feature = image_features[batch_i].unsqueeze(0)
1099
+ image_size = image_sizes[batch_i]
1100
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1101
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1102
+ final_size.append((cur_h, cur_w))
1103
+
1104
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1105
+ cur_image_feature_down = unpad_image(
1106
+ image_features_down[batch_i].unsqueeze(0),
1107
+ (
1108
+ int(
1109
+ image_size[0]
1110
+ / (
1111
+ image_token_len**0.5
1112
+ / self.get_model().config.lowres_token
1113
+ )
1114
+ ),
1115
+ int(
1116
+ image_size[1]
1117
+ / (
1118
+ image_token_len**0.5
1119
+ / self.get_model().config.lowres_token
1120
+ )
1121
+ ),
1122
+ ),
1123
+ )
1124
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1125
+
1126
+ try: # fix bug for some invalid image
1127
+ cur_image_feature_down = cur_image_feature_down.view(
1128
+ 1, _cur_h, _cur_w, -1
1129
+ )
1130
+ except:
1131
+ print("invalid after unpad", flush=True)
1132
+ cur_image_feature_down = image_features_down[batch_i].unsqueeze(
1133
+ 0
1134
+ )
1135
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1136
+ cur_image_feature_down = cur_image_feature_down.view(
1137
+ 1, _cur_h, _cur_w, -1
1138
+ )
1139
+
1140
+ cur_image_feature_down = torch.cat(
1141
+ (
1142
+ cur_image_feature_down,
1143
+ self.model.image_newline.view(1, 1, 1, -1)
1144
+ .expand(1, _cur_h, 1, -1)
1145
+ .to(cur_image_feature_down.device),
1146
+ ),
1147
+ dim=2,
1148
+ ).flatten(1, 2)
1149
+
1150
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1151
+ frame_pos = (
1152
+ self.get_model()
1153
+ .get_frame_pos(torch.arange(1))
1154
+ .to(cur_image_feature_down.device)
1155
+ .to(cur_image_feature_down.dtype)
1156
+ )
1157
+ cur_image_feature_down += frame_pos
1158
+
1159
+ image_features_downsample.append(cur_image_feature_down.squeeze(0))
1160
+
1161
+ cur_image_feature = torch.cat(
1162
+ (
1163
+ cur_image_feature,
1164
+ self.model.image_newline.view(1, 1, 1, -1)
1165
+ .expand(1, cur_h, 1, -1)
1166
+ .to(cur_image_feature.device),
1167
+ ),
1168
+ dim=2,
1169
+ )
1170
+
1171
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1172
+ frame_pos = (
1173
+ self.get_model()
1174
+ .get_frame_pos(torch.arange(1))
1175
+ .to(cur_image_feature.device)
1176
+ .to(cur_image_feature.dtype)
1177
+ )
1178
+ cur_image_feature += frame_pos
1179
+
1180
+ cur_image_feature = cur_image_feature.flatten(1, 2)
1181
+ image_features_unpadded.append(cur_image_feature.squeeze(0))
1182
+
1183
+ if self.get_model().config.mm_projector_type == "sva":
1184
+ cur_global_context_feature = global_context_feature[batch_i].expand(
1185
+ cur_h * cur_w, 1, -1
1186
+ )
1187
+ global_context_feature_final.append(cur_global_context_feature)
1188
+ if self.get_model().config.mm_projector_type == "sva":
1189
+ global_context_feature_final = torch.cat(
1190
+ global_context_feature_final, 0
1191
+ )
1192
+
1193
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1194
+ image_features = image_features_downsample
1195
+ else:
1196
+ image_features = image_features_unpadded
1197
+
1198
+ # TODO: image start / end is not implemented here to support pretraining.
1199
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
1200
+ self.config, "mm_use_im_start_end", False
1201
+ ):
1202
+ raise NotImplementedError
1203
+
1204
+ split_image_features_unpadded = None
1205
+ frame_split_sizes = None
1206
+
1207
+ if split_sizes is not None:
1208
+ split_image_features = []
1209
+ split_image_features_unpadded = (
1210
+ []
1211
+ if (getattr(self.config, "highres", False)) and input_mix_res
1212
+ else None
1213
+ )
1214
+ start_idx = 0
1215
+ for split_batch_idx, split_size in enumerate(split_sizes):
1216
+ if isinstance(image_features[start_idx : start_idx + split_size], list):
1217
+ if getattr(self.config, "frame_pos", False):
1218
+ frame_feature = torch.cat(
1219
+ image_features[start_idx : start_idx + split_size], dim=0
1220
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1221
+ frame_pos = (
1222
+ self.get_model()
1223
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1224
+ .to(frame_feature.device)
1225
+ .to(frame_feature.dtype)
1226
+ )
1227
+ frame_feature += frame_pos
1228
+ split_image_features.append(
1229
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1230
+ )
1231
+ else:
1232
+ split_image_features.append(
1233
+ torch.cat(
1234
+ image_features[start_idx : start_idx + split_size],
1235
+ dim=0,
1236
+ )
1237
+ )
1238
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1239
+ if getattr(self.config, "frame_pos", False):
1240
+ frame_feature = torch.cat(
1241
+ image_features_unpadded[
1242
+ start_idx : start_idx + split_size
1243
+ ],
1244
+ dim=0,
1245
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1246
+ frame_pos = (
1247
+ self.get_model()
1248
+ .get_frame_pos(
1249
+ selected_frame_indices_all[split_batch_idx]
1250
+ )
1251
+ .to(frame_feature.device)
1252
+ .to(frame_feature.dtype)
1253
+ )
1254
+ frame_feature += frame_pos
1255
+ split_image_features_unpadded.append(
1256
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1257
+ )
1258
+ else:
1259
+ split_image_features_unpadded.append(
1260
+ torch.cat(
1261
+ image_features_unpadded[
1262
+ start_idx : start_idx + split_size
1263
+ ],
1264
+ dim=0,
1265
+ )
1266
+ )
1267
+ else:
1268
+ if getattr(self.config, "frame_pos", False):
1269
+ frame_feature = image_features[
1270
+ start_idx : start_idx + split_size
1271
+ ].reshape(split_size, -1, image_features[0].shape[-1])
1272
+ frame_pos = (
1273
+ self.get_model()
1274
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1275
+ .to(frame_feature.device)
1276
+ .to(frame_feature.dtype)
1277
+ )
1278
+ frame_feature += frame_pos
1279
+ split_image_features.append(
1280
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1281
+ )
1282
+ else:
1283
+ split_image_features.append(
1284
+ image_features[start_idx : start_idx + split_size]
1285
+ )
1286
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1287
+ if getattr(self.config, "frame_pos", False):
1288
+ frame_feature = image_features_unpadded[
1289
+ start_idx : start_idx + split_size
1290
+ ]
1291
+ frame_pos = (
1292
+ self.get_model()
1293
+ .get_frame_pos(
1294
+ selected_frame_indices_all[split_batch_idx]
1295
+ )
1296
+ .to(frame_feature.device)
1297
+ .to(frame_feature.dtype)
1298
+ )
1299
+ frame_feature += frame_pos
1300
+ split_image_features_unpadded.append(
1301
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1302
+ )
1303
+ else:
1304
+ split_image_features_unpadded.append(
1305
+ image_features_unpadded[
1306
+ start_idx : start_idx + split_size
1307
+ ]
1308
+ )
1309
+ start_idx += split_size
1310
+ image_features = split_image_features
1311
+ frame_split_sizes = split_sizes
1312
+
1313
+ _labels = labels
1314
+ _position_ids = position_ids
1315
+ _attention_mask = attention_mask
1316
+ if attention_mask is None:
1317
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1318
+ else:
1319
+ attention_mask = attention_mask.bool()
1320
+ if position_ids is None:
1321
+ position_ids = torch.arange(
1322
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
1323
+ )
1324
+ if labels is None:
1325
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1326
+
1327
+ # remove the padding using attention_mask -- FIXME
1328
+ _input_ids = input_ids
1329
+
1330
+ attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
1331
+
1332
+ input_ids = [
1333
+ cur_input_ids[cur_attention_mask]
1334
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
1335
+ ]
1336
+ labels = [
1337
+ cur_labels[cur_attention_mask]
1338
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
1339
+ ]
1340
+
1341
+ new_input_embeds = []
1342
+ new_labels = []
1343
+ image_token_indices_batch = []
1344
+ cur_image_idx = 0
1345
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1346
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1347
+ if num_images == 0:
1348
+ cur_image_features = image_features[cur_image_idx]
1349
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1350
+ cur_input_embeds = torch.cat(
1351
+ [cur_input_embeds_1, cur_image_features[0:0]], dim=0
1352
+ )
1353
+ new_input_embeds.append(cur_input_embeds)
1354
+ new_labels.append(labels[batch_idx])
1355
+ cur_image_idx += 1
1356
+ continue
1357
+
1358
+ image_token_indices = (
1359
+ [-1]
1360
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
1361
+ + [cur_input_ids.shape[0]]
1362
+ )
1363
+ image_token_indices_batch.append(
1364
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0]
1365
+ )
1366
+ cur_input_ids_noim = []
1367
+ cur_labels = labels[batch_idx]
1368
+ cur_labels_noim = []
1369
+ for i in range(len(image_token_indices) - 1):
1370
+ cur_input_ids_noim.append(
1371
+ cur_input_ids[
1372
+ image_token_indices[i] + 1 : image_token_indices[i + 1]
1373
+ ]
1374
+ )
1375
+ cur_labels_noim.append(
1376
+ cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
1377
+ )
1378
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1379
+ cur_input_embeds = self.get_model().embed_tokens(
1380
+ torch.cat(cur_input_ids_noim)
1381
+ )
1382
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1383
+ cur_new_input_embeds = []
1384
+ cur_new_labels = []
1385
+
1386
+ text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
1387
+ visual_len = len(image_features[cur_image_idx])
1388
+ max_visual_len = (
1389
+ self.get_model().config.tokenizer_model_max_length
1390
+ - getattr(self.get_model().config, "inference_max_length", 16)
1391
+ - text_len
1392
+ )
1393
+ mix_token = False
1394
+
1395
+ # ablation mix
1396
+ if (
1397
+ input_mix_res
1398
+ and (
1399
+ self.get_model().config.image_token_len
1400
+ > getattr(self.get_model().config, "lowres_token", 8) ** 2
1401
+ )
1402
+ and frame_split_sizes is not None
1403
+ and getattr(self.config, "highres", False)
1404
+ ):
1405
+ if max_visual_len > visual_len:
1406
+ visual_emb = image_features[cur_image_idx]
1407
+ text_emb = cur_input_embeds_no_im[-1]
1408
+ highres_num = math.floor(
1409
+ (max_visual_len - visual_len)
1410
+ / (
1411
+ split_image_features_unpadded[cur_image_idx].shape[0]
1412
+ // frame_split_sizes[cur_image_idx]
1413
+ - visual_emb.shape[0] // frame_split_sizes[cur_image_idx]
1414
+ )
1415
+ )
1416
+ if highres_num >= 1:
1417
+ mix_token = True
1418
+ sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean(
1419
+ dim=-1
1420
+ )
1421
+ sim_frame = sim.reshape(
1422
+ frame_split_sizes[cur_image_idx], -1
1423
+ ).mean(dim=-1)
1424
+ highres_num = min(highres_num, sim_frame.shape[0])
1425
+ top_values, top_indices = torch.topk(sim_frame, highres_num)
1426
+ if len(top_indices) > 0:
1427
+ sorted_indices = torch.sort(top_indices)[1]
1428
+ top_indices = top_indices[sorted_indices]
1429
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1430
+ frame_split_sizes[cur_image_idx],
1431
+ -1,
1432
+ image_features[cur_image_idx].shape[-1],
1433
+ )
1434
+ visual_emb_frame_highres = split_image_features_unpadded[
1435
+ cur_image_idx
1436
+ ].reshape(
1437
+ frame_split_sizes[cur_image_idx],
1438
+ -1,
1439
+ split_image_features_unpadded[cur_image_idx].shape[-1],
1440
+ )
1441
+ current_point = 0
1442
+ mix_visual_emb_frame = []
1443
+ for frame_i in range(len(visual_emb_frame)):
1444
+ if current_point > len(top_indices) - 1:
1445
+ mix_visual_emb_frame.append(
1446
+ visual_emb_frame[frame_i]
1447
+ )
1448
+ continue
1449
+ if frame_i == top_indices[current_point]:
1450
+ mix_visual_emb_frame.append(
1451
+ visual_emb_frame_highres[frame_i]
1452
+ )
1453
+ current_point += 1
1454
+ else:
1455
+ mix_visual_emb_frame.append(
1456
+ visual_emb_frame[frame_i]
1457
+ )
1458
+ image_features[cur_image_idx] = torch.cat(
1459
+ mix_visual_emb_frame, dim=0
1460
+ )
1461
+ # ablation drop
1462
+
1463
+ if (
1464
+ max_visual_len < visual_len
1465
+ and frame_split_sizes is not None
1466
+ and not mix_token
1467
+ ):
1468
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1469
+ frame_split_sizes[cur_image_idx],
1470
+ -1,
1471
+ image_features[cur_image_idx].shape[-1],
1472
+ )
1473
+
1474
+ sim = F.cosine_similarity(
1475
+ visual_emb_frame[:-1],
1476
+ visual_emb_frame[1:],
1477
+ dim=-1,
1478
+ )
1479
+
1480
+ new_visual_emb_frames = []
1481
+ for start_idx in range(0, len(visual_emb_frame), 8):
1482
+ end_idx = min(start_idx + 8, len(visual_emb_frame))
1483
+ chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C
1484
+ if len(chunk_feature) == 1:
1485
+ new_visual_emb_frames.append(chunk_feature[0])
1486
+ continue
1487
+ sim = F.cosine_similarity(
1488
+ chunk_feature[0]
1489
+ .unsqueeze(0)
1490
+ .repeat_interleave(len(chunk_feature[1:]), dim=0),
1491
+ chunk_feature[1:],
1492
+ dim=-1,
1493
+ )
1494
+ new_visual_emb_frame = torch.cat(
1495
+ [
1496
+ chunk_feature[0],
1497
+ chunk_feature[1:].flatten(0, 1)[
1498
+ sim.flatten(0, 1)
1499
+ < getattr(
1500
+ self.get_model().config, "drop_threshold", 0.7
1501
+ )
1502
+ ],
1503
+ ],
1504
+ dim=0,
1505
+ )
1506
+ new_visual_emb_frames.append(new_visual_emb_frame)
1507
+
1508
+ reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames])
1509
+
1510
+ if reduced_visual_len > max_visual_len:
1511
+ force_remove = math.ceil(
1512
+ (reduced_visual_len - max_visual_len)
1513
+ / len(new_visual_emb_frames)
1514
+ )
1515
+ for chunk_i in range(len(new_visual_emb_frames)):
1516
+ new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][
1517
+ :-force_remove
1518
+ ]
1519
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1520
+ else:
1521
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1522
+
1523
+ image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len]
1524
+
1525
+ for i in range(num_images + 1):
1526
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1527
+ cur_new_labels.append(cur_labels_noim[i])
1528
+ if i < num_images:
1529
+ cur_image_features = image_features[cur_image_idx]
1530
+ cur_image_idx += 1
1531
+ cur_new_input_embeds.append(cur_image_features)
1532
+ cur_new_labels.append(
1533
+ torch.full(
1534
+ (cur_image_features.shape[0],),
1535
+ IGNORE_INDEX,
1536
+ device=cur_labels.device,
1537
+ dtype=cur_labels.dtype,
1538
+ )
1539
+ )
1540
+
1541
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1542
+
1543
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1544
+ cur_new_labels = torch.cat(cur_new_labels)
1545
+
1546
+ new_input_embeds.append(cur_new_input_embeds)
1547
+ new_labels.append(cur_new_labels)
1548
+
1549
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1550
+ tokenizer_model_max_length = getattr(
1551
+ self.config, "tokenizer_model_max_length", None
1552
+ )
1553
+ if tokenizer_model_max_length is not None:
1554
+ new_input_embeds = [
1555
+ x[:tokenizer_model_max_length] for x in new_input_embeds
1556
+ ]
1557
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1558
+
1559
+ # Combine them
1560
+ max_len = max(x.shape[0] for x in new_input_embeds)
1561
+ batch_size = len(new_input_embeds)
1562
+
1563
+ new_input_embeds_padded = []
1564
+ new_labels_padded = torch.full(
1565
+ (batch_size, max_len),
1566
+ IGNORE_INDEX,
1567
+ dtype=new_labels[0].dtype,
1568
+ device=new_labels[0].device,
1569
+ )
1570
+ attention_mask = torch.zeros(
1571
+ (batch_size, max_len),
1572
+ dtype=attention_mask.dtype,
1573
+ device=attention_mask.device,
1574
+ )
1575
+ position_ids = torch.zeros(
1576
+ (batch_size, max_len),
1577
+ dtype=position_ids.dtype,
1578
+ device=position_ids.device,
1579
+ )
1580
+
1581
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
1582
+ zip(new_input_embeds, new_labels)
1583
+ ):
1584
+ cur_len = cur_new_embed.shape[0]
1585
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
1586
+ new_input_embeds_padded.append(
1587
+ torch.cat(
1588
+ (
1589
+ torch.zeros(
1590
+ (max_len - cur_len, cur_new_embed.shape[1]),
1591
+ dtype=cur_new_embed.dtype,
1592
+ device=cur_new_embed.device,
1593
+ ),
1594
+ cur_new_embed,
1595
+ ),
1596
+ dim=0,
1597
+ )
1598
+ )
1599
+ if cur_len > 0:
1600
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1601
+ attention_mask[i, -cur_len:] = True
1602
+ position_ids[i, -cur_len:] = torch.arange(
1603
+ 0,
1604
+ cur_len,
1605
+ dtype=position_ids.dtype,
1606
+ device=position_ids.device,
1607
+ )
1608
+ else:
1609
+ new_input_embeds_padded.append(
1610
+ torch.cat(
1611
+ (
1612
+ cur_new_embed,
1613
+ torch.zeros(
1614
+ (max_len - cur_len, cur_new_embed.shape[1]),
1615
+ dtype=cur_new_embed.dtype,
1616
+ device=cur_new_embed.device,
1617
+ ),
1618
+ ),
1619
+ dim=0,
1620
+ )
1621
+ )
1622
+ if cur_len > 0:
1623
+ new_labels_padded[i, :cur_len] = cur_new_labels
1624
+ attention_mask[i, :cur_len] = True
1625
+ position_ids[i, :cur_len] = torch.arange(
1626
+ 0,
1627
+ cur_len,
1628
+ dtype=position_ids.dtype,
1629
+ device=position_ids.device,
1630
+ )
1631
+
1632
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1633
+
1634
+ if _labels is None:
1635
+ new_labels = None
1636
+ else:
1637
+ new_labels = new_labels_padded
1638
+
1639
+ if _attention_mask is None:
1640
+ attention_mask = None
1641
+ else:
1642
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1643
+
1644
+ if _position_ids is None:
1645
+ position_ids = None
1646
+
1647
+ return (
1648
+ None,
1649
+ position_ids,
1650
+ attention_mask,
1651
+ past_key_values,
1652
+ new_input_embeds,
1653
+ new_labels,
1654
+ vision_tower_aux_feature_list_final,
1655
+ vision_tower_aux_attention_masks_list_final,
1656
+ final_size,
1657
+ global_context_feature_final,
1658
+ )
1659
+
1660
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
1661
+ if model_args.mm_use_im_patch_token:
1662
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
1663
+ self.resize_token_embeddings(len(tokenizer))
1664
+
1665
+ if model_args.mm_use_im_start_end:
1666
+ num_new_tokens = tokenizer.add_tokens(
1667
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
1668
+ )
1669
+ self.resize_token_embeddings(len(tokenizer))
1670
+
1671
+ if num_new_tokens > 0:
1672
+ input_embeddings = self.get_input_embeddings().weight.data
1673
+ output_embeddings = self.get_output_embeddings().weight.data
1674
+
1675
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
1676
+ dim=0, keepdim=True
1677
+ )
1678
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
1679
+ dim=0, keepdim=True
1680
+ )
1681
+
1682
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
1683
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
1684
+
1685
+ if model_args.tune_mm_mlp_adapter:
1686
+ for p in self.get_input_embeddings().parameters():
1687
+ p.requires_grad = True
1688
+ for p in self.get_output_embeddings().parameters():
1689
+ p.requires_grad = False
1690
+
1691
+ if model_args.pretrain_mm_mlp_adapter:
1692
+ mm_projector_weights = torch.load(
1693
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
1694
+ )
1695
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
1696
+ assert num_new_tokens == 2
1697
+ if input_embeddings.shape == embed_tokens_weight.shape:
1698
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
1699
+ -num_new_tokens:
1700
+ ]
1701
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
1702
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
1703
+ else:
1704
+ raise ValueError(
1705
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
1706
+ )
1707
+ elif model_args.mm_use_im_patch_token:
1708
+ if model_args.tune_mm_mlp_adapter:
1709
+ for p in self.get_input_embeddings().parameters():
1710
+ p.requires_grad = False
1711
+ for p in self.get_output_embeddings().parameters():
1712
+ p.requires_grad = False
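The chunk-wise token dropping in `prepare_inputs_labels_for_multimodal` above can be hard to follow inline. Below is a minimal standalone sketch of that idea, assuming made-up tensor shapes and a hypothetical helper name (it is not part of the uploaded files): all tokens of the first frame in each chunk of 8 are kept, and for the remaining frames only tokens whose cosine similarity to the first frame falls below the drop threshold survive.

```python
# Minimal sketch of the adjacent-frame token dropping above (not part of the uploaded files).
# Shapes, the helper name, and the default threshold are illustrative assumptions.
import torch
import torch.nn.functional as F

def drop_redundant_frame_tokens(frame_features: torch.Tensor,
                                drop_threshold: float = 0.8,
                                chunk_size: int = 8) -> torch.Tensor:
    """frame_features: (num_frames, tokens_per_frame, hidden_dim)."""
    kept = []
    for start in range(0, frame_features.shape[0], chunk_size):
        chunk = frame_features[start:start + chunk_size]            # (<=8, HW, C)
        if len(chunk) == 1:
            kept.append(chunk[0])
            continue
        # Compare every later frame's tokens against the chunk's first frame.
        anchor = chunk[0].unsqueeze(0).repeat_interleave(len(chunk) - 1, dim=0)
        sim = F.cosine_similarity(anchor, chunk[1:], dim=-1)        # (<=7, HW)
        keep_mask = sim.flatten(0, 1) < drop_threshold              # drop near-duplicate tokens
        kept.append(torch.cat([chunk[0], chunk[1:].flatten(0, 1)[keep_mask]], dim=0))
    return torch.cat(kept, dim=0)

# Example: 16 frames, 144 tokens per frame, 3072-dim embeddings.
features = torch.randn(16, 144, 3072)
print(drop_redundant_frame_tokens(features).shape)
```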
config.json ADDED
@@ -0,0 +1,91 @@
1
+ {
2
+ "_name_or_path": "jadechoghari/LongVU_Llama3_2_3B",
3
+ "architectures": [
4
+ "CambrianLlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 128000,
9
+ "connect_layer": 2,
10
+ "connector_depth": 3,
11
+ "connector_only": true,
12
+ "dino_threshold": 0.83,
13
+ "drop_threshold": 0.8,
14
+ "eos_token_id": [
15
+ 128001,
16
+ 128008,
17
+ 128009
18
+ ],
19
+ "frame_pos": false,
20
+ "freeze_mm_mlp_adapter": false,
21
+ "hidden_act": "silu",
22
+ "hidden_size": 3072,
23
+ "highres": true,
24
+ "highres_connect": false,
25
+ "image_aspect_ratio": "pad",
26
+ "image_position": 91,
27
+ "image_token_len": 144,
28
+ "initializer_range": 0.02,
29
+ "intermediate_size": 8192,
30
+ "is_image_newline": true,
31
+ "is_st_sampler": false,
32
+ "lowres_token": 8,
33
+ "max_position_embeddings": 131072,
34
+ "mlp_bias": false,
35
+ "mm_patch_merge_type": "flat",
36
+ "mm_projector_lr": null,
37
+ "mm_projector_type": "sva",
38
+ "mm_use_im_patch_token": false,
39
+ "mm_use_im_start_end": false,
40
+ "mm_vision_sampler_lr": null,
41
+ "mm_vision_select_feature": "patch",
42
+ "mm_vision_select_layer": -2,
43
+ "mm_vision_tower_aux_list": [
44
+ "siglip/CLIP-ViT-SO400M-14-384",
45
+ "facebook/dinov2-giant-res378"
46
+ ],
47
+ "mm_vision_tower_aux_token_len_list": [
48
+ 576,
49
+ 576
50
+ ],
51
+ "mm_vision_tower_lr": null,
52
+ "model_type": "cambrian_llama",
53
+ "num_attention_heads": 24,
54
+ "num_hidden_layers": 28,
55
+ "num_key_value_heads": 8,
56
+ "num_of_vision_sampler_layers": 10,
57
+ "num_query_group": 1,
58
+ "pretraining_tp": 1,
59
+ "query_num_list": [
60
+ 144
61
+ ],
62
+ "rms_norm_eps": 1e-05,
63
+ "rope_scaling": {
64
+ "factor": 32.0,
65
+ "high_freq_factor": 4.0,
66
+ "low_freq_factor": 1.0,
67
+ "original_max_position_embeddings": 8192,
68
+ "rope_type": "llama3"
69
+ },
70
+ "rope_theta": 500000.0,
71
+ "spmd_debug": null,
72
+ "spmd_fsdp_sharding": null,
73
+ "spmd_mesh": null,
74
+ "start_of_vision_sampler_layers": 0,
75
+ "stride_of_vision_sampler_layers": 3,
76
+ "tie_word_embeddings": false,
77
+ "tokenizer_model_max_length": 8192,
78
+ "tokenizer_padding_side": "right",
79
+ "torch_dtype": "float32",
80
+ "transformers_version": "4.44.2",
81
+ "tune_mm_mlp_adapter": false,
82
+ "unfreeze_mm_vision_tower": false,
83
+ "use_cache": false,
84
+ "use_mm_proj": true,
85
+ "vision_hidden_size": 1024,
86
+ "vision_tower_aux_token_len_list": [
87
+ 576,
88
+ 576
89
+ ],
90
+ "vocab_size": 128256
91
+ }
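A hypothetical loading sketch for the config above. It assumes the repository's Python files are importable (so that `modeling.py` has already registered `cambrian_llama` with `AutoConfig`) and uses the repo id given in `_name_or_path`; adjust the path to wherever the files actually live.

```python
# Assumes cambrian_arch.py and modeling.py are on the Python path; importing modeling
# runs the AutoConfig / AutoModelForCausalLM registration at the bottom of that file.
from transformers import AutoConfig

import modeling  # noqa: F401

config = AutoConfig.from_pretrained("jadechoghari/LongVU_Llama3_2_3B")
print(config.model_type)                                        # "cambrian_llama"
print(config.image_token_len, config.lowres_token, config.drop_threshold)
```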
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "bos_token_id": 128000,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 128001,
6
+ 128008,
7
+ 128009
8
+ ],
9
+ "temperature": 0.6,
10
+ "top_p": 0.9,
11
+ "transformers_version": "4.45.0.dev0"
12
+ }
modeling.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import CrossEntropyLoss
22
+
23
+ from transformers import (
24
+ AutoConfig,
25
+ AutoModelForCausalLM,
26
+ LlamaConfig,
27
+ LlamaForCausalLM,
28
+ LlamaModel,
29
+ )
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.generation.utils import GenerateOutput
32
+
33
+ from transformers.modeling_attn_mask_utils import (
34
+ _prepare_4d_causal_attention_mask,
35
+ _prepare_4d_causal_attention_mask_for_sdpa,
36
+ )
37
+
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ CausalLMOutputWithPast,
41
+ )
42
+ from transformers.utils import logging
43
+
44
+ from cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
45
+
46
+ IS_XLA_AVAILABLE = False
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class CambrianConfig(LlamaConfig):
52
+ model_type = "cambrian_llama"
53
+
54
+ debug = "debug"
55
+
56
+
57
+ class CambrianLlamaModel(CambrianMetaModel, LlamaModel):
58
+ config_class = CambrianConfig
59
+
60
+ def __init__(self, config: LlamaConfig):
61
+ super(CambrianLlamaModel, self).__init__(config)
62
+
63
+ def forward(
64
+ self,
65
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
66
+ input_ids: torch.LongTensor = None,
67
+ attention_mask: Optional[torch.Tensor] = None,
68
+ position_ids: Optional[torch.LongTensor] = None,
69
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
70
+ inputs_embeds: Optional[torch.FloatTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ return_dict: Optional[bool] = None,
75
+ vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
76
+ vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
77
+ final_vision_feature_size: Optional[List[tuple]] = None,
78
+ global_context_feature: Optional[torch.Tensor] = None,
79
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
80
+
81
+ output_attentions = (
82
+ output_attentions
83
+ if output_attentions is not None
84
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `config`.
85
+ else self.config.output_attentions
86
+ )
87
+
88
+ output_hidden_states = (
89
+ output_hidden_states
90
+ if output_hidden_states is not None
91
+ else self.config.output_hidden_states
92
+ )
93
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
94
+
95
+ return_dict = (
96
+ return_dict if return_dict is not None else self.config.use_return_dict
97
+ )
98
+
99
+ # retrieve input_ids and inputs_embeds
100
+ if input_ids is not None and inputs_embeds is not None:
101
+ raise ValueError(
102
+ "You cannot specify both input_ids and inputs_embeds at the same time"
103
+ )
104
+ elif input_ids is not None:
105
+ batch_size, seq_length = input_ids.shape[:2]
106
+ elif inputs_embeds is not None:
107
+ batch_size, seq_length = inputs_embeds.shape[:2]
108
+ else:
109
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
110
+
111
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
112
+ # `gradient_checkpointing`.
113
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `training`.
114
+ if self.gradient_checkpointing and self.training:
115
+ if use_cache:
116
+ logger.warning_once(
117
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
118
+ )
119
+ use_cache = False
120
+
121
+ past_key_values_length = 0
122
+ if use_cache:
123
+ use_legacy_cache = not isinstance(past_key_values, Cache)
124
+ if use_legacy_cache:
125
+ # pyre-fixme[9]: past_key_values has type
126
+ # `Optional[List[FloatTensor]]`; used as `DynamicCache`.
127
+ # pyre-fixme[6]: For 1st argument expected
128
+ # `Optional[Tuple[Tuple[FloatTensor]]]` but got
129
+ # `Optional[List[FloatTensor]]`.
130
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
131
+ # pyre-fixme[16]: `Optional` has no attribute `get_usable_length`.
132
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
133
+
134
+ if position_ids is None:
135
+ # pyre-fixme[16]: `Optional` has no attribute `device`.
136
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
137
+ position_ids = torch.arange(
138
+ past_key_values_length,
139
+ seq_length + past_key_values_length,
140
+ dtype=torch.long,
141
+ device=device,
142
+ )
143
+ position_ids = position_ids.unsqueeze(0)
144
+
145
+ if inputs_embeds is None:
146
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `embed_tokens`.
147
+ inputs_embeds = self.embed_tokens(input_ids)
148
+
149
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
150
+ # `_use_flash_attention_2`.
151
+ self._use_flash_attention_2 = getattr(self, "_use_flash_attention_2", False)
152
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `_use_sdpa`.
153
+ self._use_sdpa = getattr(self, "_use_sdpa", True)
154
+ if self._use_flash_attention_2:
155
+ # 2d mask is passed through the layers
156
+ attention_mask = (
157
+ attention_mask
158
+ if (attention_mask is not None and 0 in attention_mask)
159
+ else None
160
+ )
161
+ elif self._use_sdpa and not output_attentions:
162
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
163
+ # the manual implementation that requires a 4D causal mask in all cases.
164
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
165
+ attention_mask,
166
+ (batch_size, seq_length),
167
+ inputs_embeds,
168
+ past_key_values_length,
169
+ )
170
+ else:
171
+ # 4d mask is passed through the layers
172
+ attention_mask = _prepare_4d_causal_attention_mask(
173
+ attention_mask,
174
+ (batch_size, seq_length),
175
+ inputs_embeds,
176
+ past_key_values_length,
177
+ )
178
+
179
+ # embed positions
180
+ hidden_states = inputs_embeds
181
+ # decoder layers
182
+ all_hidden_states = () if output_hidden_states else None
183
+ all_self_attns = () if output_attentions else None
184
+ next_decoder_cache = None
185
+
186
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `layers`.
187
+ for i, decoder_layer in enumerate(self.layers):
188
+ if output_hidden_states:
189
+ all_hidden_states += (hidden_states,)
190
+
191
+ if self.gradient_checkpointing and self.training:
192
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
193
+ # `_gradient_checkpointing_func`.
194
+ layer_outputs = self._gradient_checkpointing_func(
195
+ decoder_layer.__call__,
196
+ hidden_states,
197
+ attention_mask,
198
+ position_ids,
199
+ past_key_values,
200
+ output_attentions,
201
+ use_cache,
202
+ )
203
+ else:
204
+ layer_outputs = decoder_layer(
205
+ hidden_states,
206
+ attention_mask=attention_mask,
207
+ position_ids=position_ids,
208
+ past_key_value=past_key_values,
209
+ output_attentions=output_attentions,
210
+ use_cache=use_cache,
211
+ )
212
+
213
+ hidden_states = layer_outputs[0]
214
+
215
+ if use_cache:
216
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
217
+
218
+ if output_attentions:
219
+ all_self_attns += (layer_outputs[1],)
220
+
221
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute `norm`.
222
+ hidden_states = self.norm(hidden_states)
223
+
224
+ # add hidden states from the last decoder layer
225
+ if output_hidden_states:
226
+ all_hidden_states += (hidden_states,)
227
+
228
+ next_cache = None
229
+ if use_cache:
230
+ next_cache = (
231
+ next_decoder_cache.to_legacy_cache()
232
+ # pyre-fixme[61]: `use_legacy_cache` is undefined, or not always
233
+ # defined.
234
+ if use_legacy_cache
235
+ else next_decoder_cache
236
+ )
237
+ if not return_dict:
238
+ return tuple(
239
+ v
240
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
241
+ if v is not None
242
+ )
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=next_cache,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_self_attns,
248
+ )
249
+
250
+
251
+ class CambrianLlamaForCausalLM(LlamaForCausalLM, CambrianMetaForCausalLM):
252
+ config_class = CambrianConfig
253
+
254
+ def __init__(self, config):
255
+ super(LlamaForCausalLM, self).__init__(config)
256
+
257
+ self.model = CambrianLlamaModel(config)
258
+ self.pretraining_tp = config.pretraining_tp
259
+ self.vocab_size = config.vocab_size
260
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_model(self):
266
+ return self.model
267
+
268
+ def forward(
269
+ self,
270
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
271
+ input_ids: torch.LongTensor = None,
272
+ attention_mask: Optional[torch.Tensor] = None,
273
+ position_ids: Optional[torch.LongTensor] = None,
274
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
276
+ labels: Optional[torch.LongTensor] = None,
277
+ use_cache: Optional[bool] = None,
278
+ output_attentions: Optional[bool] = None,
279
+ output_hidden_states: Optional[bool] = None,
280
+ images: Optional[torch.FloatTensor] = None,
281
+ image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
282
+ image_sizes: Optional[List[List[int]]] = None,
283
+ return_dict: Optional[bool] = None,
284
+ cache_position=None,
285
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
286
+
287
+ final_vision_feature_size = None
288
+
289
+ if inputs_embeds is None:
290
+ (
291
+ input_ids,
292
+ position_ids,
293
+ attention_mask,
294
+ past_key_values,
295
+ inputs_embeds,
296
+ labels,
297
+ vision_tower_aux_feature_list,
298
+ vision_tower_aux_attention_masks_list,
299
+ final_vision_feature_size,
300
+ global_context_feature,
301
+ ) = self.prepare_inputs_labels_for_multimodal(
302
+ input_ids,
303
+ position_ids,
304
+ attention_mask,
305
+ past_key_values,
306
+ labels,
307
+ images,
308
+ image_aux_attention_masks_list,
309
+ image_sizes,
310
+ )
311
+ if IS_XLA_AVAILABLE:
312
+ # Very Important for TorchXLA
313
+ # self.model.gradient_checkpointing = False
314
+
315
+ # pyre-fixme[21]: Could not find module `torch_xla.utils.checkpoint`.
316
+ from torch_xla.utils.checkpoint import checkpoint
317
+
318
+ # self.model.gradient_checkpointing = True
319
+ # pyre-fixme[16]: `CambrianLlamaModel` has no attribute
320
+ # `_gradient_checkpointing_func`.
321
+ self.model._gradient_checkpointing_func = checkpoint
322
+
323
+ output_attentions = (
324
+ output_attentions
325
+ if output_attentions is not None
326
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute `config`.
327
+ else self.config.output_attentions
328
+ )
329
+ output_hidden_states = (
330
+ output_hidden_states
331
+ if output_hidden_states is not None
332
+ else self.config.output_hidden_states
333
+ )
334
+ return_dict = (
335
+ return_dict if return_dict is not None else self.config.use_return_dict
336
+ )
337
+
338
+ # training
339
+ if IS_XLA_AVAILABLE:
340
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
341
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
342
+ outputs = self.model(
343
+ input_ids=input_ids,
344
+ attention_mask=attention_mask,
345
+ position_ids=position_ids,
346
+ past_key_values=past_key_values,
347
+ inputs_embeds=inputs_embeds,
348
+ use_cache=use_cache,
349
+ output_attentions=output_attentions,
350
+ output_hidden_states=output_hidden_states,
351
+ return_dict=return_dict,
352
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is undefined, or
353
+ # not always defined.
354
+ vision_tower_aux_feature_list=vision_tower_aux_feature_list,
355
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
356
+ # undefined, or not always defined.
357
+ vision_tower_aux_attention_masks_list=vision_tower_aux_attention_masks_list,
358
+ final_vision_feature_size=final_vision_feature_size,
359
+ # pyre-fixme[61]: `global_context_feature` is undefined, or not
360
+ # always defined.
361
+ global_context_feature=global_context_feature,
362
+ )
363
+
364
+ # inference
365
+ else:
366
+ if hasattr(self, "vision_tower_aux_feature_list"):
367
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
368
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
369
+ outputs = self.model(
370
+ input_ids=input_ids,
371
+ attention_mask=attention_mask,
372
+ position_ids=position_ids,
373
+ past_key_values=past_key_values,
374
+ inputs_embeds=inputs_embeds,
375
+ use_cache=use_cache,
376
+ output_attentions=output_attentions,
377
+ output_hidden_states=output_hidden_states,
378
+ return_dict=return_dict,
379
+ vision_tower_aux_feature_list=(
380
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is
381
+ # undefined, or not always defined.
382
+ vision_tower_aux_feature_list
383
+ if inputs_embeds is None
384
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
385
+ # attribute `vision_tower_aux_feature_list`.
386
+ else self.vision_tower_aux_feature_list
387
+ ),
388
+ vision_tower_aux_attention_masks_list=(
389
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
390
+ # undefined, or not always defined.
391
+ vision_tower_aux_attention_masks_list
392
+ if inputs_embeds is None
393
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
394
+ # attribute `vision_tower_aux_attention_masks_list`.
395
+ else self.vision_tower_aux_attention_masks_list
396
+ ),
397
+ final_vision_feature_size=(
398
+ final_vision_feature_size
399
+ if inputs_embeds is None
400
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
401
+ # attribute `final_vision_feature_size`.
402
+ else self.final_vision_feature_size
403
+ ),
404
+ global_context_feature=(
405
+ # pyre-fixme[61]: `global_context_feature` is undefined, or
406
+ # not always defined.
407
+ global_context_feature
408
+ if inputs_embeds is None
409
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no
410
+ # attribute `global_context_feature`.
411
+ else self.global_context_feature
412
+ ),
413
+ )
414
+ else:
415
+ # pyre-fixme[29]: `CambrianLlamaModel` is not a function.
416
+ outputs = self.model(
417
+ input_ids=input_ids,
418
+ attention_mask=attention_mask,
419
+ position_ids=position_ids,
420
+ past_key_values=past_key_values,
421
+ inputs_embeds=inputs_embeds,
422
+ use_cache=use_cache,
423
+ output_attentions=output_attentions,
424
+ output_hidden_states=output_hidden_states,
425
+ return_dict=return_dict,
426
+ # final_vision_feature_size=final_vision_feature_size,
427
+ )
428
+
429
+ hidden_states = outputs[0]
430
+ if self.config.pretraining_tp > 1:
431
+ lm_head_slices = self.lm_head.weight.split(
432
+ self.vocab_size // self.config.pretraining_tp, dim=0
433
+ )
434
+ logits = [
435
+ F.linear(hidden_states, lm_head_slices[i])
436
+ for i in range(self.config.pretraining_tp)
437
+ ]
438
+ logits = torch.cat(logits, dim=-1)
439
+ else:
440
+ logits = self.lm_head(hidden_states)
441
+ logits = logits.float()
442
+
443
+ loss = None
444
+ if labels is not None:
445
+ # Shift so that tokens < n predict n
446
+ shift_logits = logits[..., :-1, :].contiguous()
447
+ shift_labels = labels[..., 1:].contiguous()
448
+ # Flatten the tokens
449
+ loss_fct = CrossEntropyLoss()
450
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
451
+ shift_labels = shift_labels.view(-1)
452
+ # Enable model parallelism
453
+ shift_labels = shift_labels.to(shift_logits.device)
454
+ loss = loss_fct(shift_logits, shift_labels)
455
+
456
+ if not return_dict:
457
+ output = (logits,) + outputs[1:]
458
+ return (loss,) + output if loss is not None else output
459
+
460
+ return CausalLMOutputWithPast(
461
+ loss=loss,
462
+ logits=logits,
463
+ past_key_values=outputs.past_key_values,
464
+ hidden_states=outputs.hidden_states,
465
+ attentions=outputs.attentions,
466
+ )
467
+
468
+ @torch.no_grad()
469
+ def generate(
470
+ self,
471
+ inputs: Optional[torch.Tensor] = None,
472
+ images: Optional[torch.Tensor] = None,
473
+ image_sizes: Optional[torch.Tensor] = None,
474
+ **kwargs,
475
+ ) -> Union[GenerateOutput, torch.LongTensor]:
476
+ position_ids = kwargs.pop("position_ids", None)
477
+ attention_mask = kwargs.pop("attention_mask", None)
478
+ if "inputs_embeds" in kwargs:
479
+ raise NotImplementedError("`inputs_embeds` is not supported")
480
+
481
+ if images is not None:
482
+ (
483
+ inputs,
484
+ position_ids,
485
+ attention_mask,
486
+ _,
487
+ inputs_embeds,
488
+ _,
489
+ vision_tower_aux_feature_list,
490
+ vision_tower_aux_attention_masks_list,
491
+ final_vision_feature_size,
492
+ global_context_feature,
493
+ ) = self.prepare_inputs_labels_for_multimodal(
494
+ inputs,
495
+ position_ids,
496
+ attention_mask,
497
+ None,
498
+ None,
499
+ images,
500
+ image_sizes=image_sizes,
501
+ )
502
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
503
+ # `vision_tower_aux_feature_list`.
504
+ self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
505
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
506
+ # `vision_tower_aux_attention_masks_list`.
507
+ self.vision_tower_aux_attention_masks_list = (
508
+ vision_tower_aux_attention_masks_list
509
+ )
510
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
511
+ # `final_vision_feature_size`.
512
+ self.final_vision_feature_size = final_vision_feature_size
513
+ # pyre-fixme[16]: `CambrianLlamaForCausalLM` has no attribute
514
+ # `global_context_feature`.
515
+ self.global_context_feature = global_context_feature
516
+ else:
517
+ inputs_embeds = self.get_model().embed_tokens(inputs)
518
+
519
+ # pyre-fixme[16]: `LlamaForCausalLM` has no attribute `generate`.
520
+ return super().generate(
521
+ position_ids=position_ids,
522
+ attention_mask=attention_mask,
523
+ inputs_embeds=inputs_embeds,
524
+ **kwargs,
525
+ )
526
+
527
+ def prepare_inputs_for_generation(
528
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
529
+ ):
530
+ images = kwargs.pop("images", None)
531
+ image_sizes = kwargs.pop("image_sizes", None)
532
+ inputs = super().prepare_inputs_for_generation(
533
+ input_ids,
534
+ past_key_values=past_key_values,
535
+ inputs_embeds=inputs_embeds,
536
+ **kwargs,
537
+ )
538
+ if images is not None:
539
+ inputs["images"] = images
540
+ if image_sizes is not None:
541
+ inputs["image_sizes"] = image_sizes
542
+ return inputs
543
+
544
+
545
+ AutoConfig.register("cambrian_llama", CambrianConfig)
546
+ AutoModelForCausalLM.register(CambrianConfig, CambrianLlamaForCausalLM)
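For reference, the label handling in `CambrianLlamaForCausalLM.forward` above follows the standard causal-LM shift: position t predicts token t+1, and positions labeled with the ignore index (-100, used for image tokens and padding) do not contribute to the loss. A small self-contained illustration with dummy tensors:

```python
# Standalone illustration of the shifted cross-entropy used in forward() above.
# Batch size, sequence length, and the number of masked positions are arbitrary.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 128256
logits = torch.randn(2, 10, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 10))
labels[:, :4] = -100                             # e.g. image/prompt positions are ignored

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)   # ignore_index defaults to -100
print(loss.item())
```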
multimodal_encoder_builder.py ADDED
@@ -0,0 +1,368 @@
1
+ # pyre-unsafe
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
6
+ from abc import ABC, abstractmethod
7
+ import torch.nn as nn
8
+
9
+
10
+ class ProcessorWrapper:
11
+ def __init__(
12
+ self,
13
+ transform,
14
+ height=378,
15
+ width=378,
16
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
17
+ ):
18
+ self._crop_size = {
19
+ "height": height,
20
+ "width": width,
21
+ }
22
+ self._transforms = transform
23
+ # print(transform)
24
+ self.image_mean = image_mean
25
+
26
+ @property
27
+ def crop_size(self):
28
+ return self._crop_size
29
+
30
+ def preprocess(self, image, return_tensors="pt"):
31
+ # Ensure image is a PIL Image
32
+ output = {}
33
+ output["pixel_values"] = [self._transforms(image)]
34
+ return output
35
+
36
+
37
+ class BaseVisionTower(nn.Module):
38
+ def __init__(self, vision_tower_name, args, delay_load=False):
39
+ super().__init__()
40
+
41
+ self.is_loaded = False
42
+ self.args = args
43
+
44
+ self.vision_tower_name = vision_tower_name
45
+ self.select_layer = args.mm_vision_select_layer
46
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
47
+ self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
48
+ self.delay_load = delay_load
49
+
50
+ @abstractmethod
51
+ def load_model(self, device_map=None):
52
+ raise NotImplementedError("Subclasses must implement load_model")
53
+
54
+ @abstractmethod
55
+ def _forward(self, images):
56
+ raise NotImplementedError("Subclasses must implement forward")
57
+
58
+ def forward(self, images):
59
+ if type(images) is list:
60
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
61
+ else:
62
+ image_features = self._forward(images)
63
+
64
+ return image_features
65
+
66
+ @property
67
+ def dummy_feature(self):
68
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
69
+
70
+ @property
71
+ def dtype(self):
72
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
73
+ if hasattr(self.vision_tower, "dtype"):
74
+ return self.vision_tower.dtype
75
+ else:
76
+ params = list(self.vision_tower.parameters())
77
+ return (
78
+ params[0].dtype if len(params) > 0 else torch.float32
79
+ ) # Default to torch.float32 if no parameters
80
+
81
+ @property
82
+ def device(self):
83
+ # Dynamically infer the device from the first parameter, if not explicitly specified
84
+ if hasattr(self.vision_tower, "device"):
85
+ return self.vision_tower.device
86
+ else:
87
+ params = list(self.vision_tower.parameters())
88
+ return (
89
+ params[0].device if len(params) > 0 else torch.device("cpu")
90
+ ) # Default to CPU if no parameters
91
+
92
+ @property
93
+ def config(self):
94
+ if self.is_loaded:
95
+ return self.vision_tower.config
96
+ else:
97
+ return self.cfg_only
98
+
99
+ @property
100
+ def hidden_size(self):
101
+ try:
102
+ return self.config.hidden_size
103
+ except:
104
+ return self._hidden_size
105
+
106
+ @property
107
+ def image_size(self): # resolution
108
+ # return self.config.image_size
109
+ try:
110
+ return self.config.image_size
111
+ except:
112
+ return self._image_size
113
+
114
+ @property
115
+ def patch_size(self):
116
+ # return self.config.patch_size
117
+ try:
118
+ return self.config.patch_size
119
+ except:
120
+ return self._patch_size
121
+
122
+ @property
123
+ def num_patches_per_side(self):
124
+ if self._interp_size is not None:
125
+ return int(self._interp_size**0.5)
126
+ try:
127
+ return self.image_size // self.patch_size
128
+ except:
129
+ return self._num_patches_per_side
130
+
131
+ @property
132
+ def num_patches(self):
133
+ if self._interp_size is not None:
134
+ return self._interp_size
135
+ try:
136
+ return self.num_patches_per_side**2
137
+ except:
138
+ return self._num_patches
139
+
140
+
141
+ class DinoVisionTower(BaseVisionTower):
142
+ def __init__(self, vision_tower, args, delay_load=False):
143
+ super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)
144
+
145
+ model_path = "facebook/dinov2-giant"
146
+ base_model_name, res, interp = model_path, 378, 576
147
+ self._vision_tower_name = vision_tower
148
+ self.vision_tower_name = base_model_name
149
+ self._image_size = res
150
+ self._interp_size = interp
151
+ self._patch_size = 14 # default patch size
152
+
153
+ if not self.delay_load:
154
+ self.load_model()
155
+ else:
156
+ self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)
157
+
158
+ def load_model(self, device_map=None):
159
+
160
+ self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
161
+ """ValueError: Dinov2Model does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute."""
162
+ self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]
163
+
164
+ _image_size = self.vision_tower.config.image_size
165
+ if self._image_size is None:
166
+ self._image_size = _image_size
167
+
168
+ # increase shortest edge to prevent edge case crops
169
+ default_shortest_ratio = 8 / 7 # 256/224
170
+ # shortest_edge = int(default_shortest_ratio * self._image_size)
171
+ shortest_edge = self._image_size
172
+
173
+ processor = AutoImageProcessor.from_pretrained(
174
+ self.vision_tower_name,
175
+ crop_size=dict(height=self._image_size, width=self._image_size),
176
+ size=dict(shortest_edge=shortest_edge),
177
+ )
178
+ self.image_processor = processor
179
+
180
+ # Assign the output channels of the projection convolution as the hidden size
181
+ self._hidden_size = (
182
+ self.vision_tower.embeddings.patch_embeddings.projection.out_channels
183
+ )
184
+ # Assign the first value of the stride of the projection convolution as the patch size
185
+ self._patch_size = (
186
+ self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
187
+ )
188
+
189
+ # print(self._hidden_size, self._patch_size)
190
+
191
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
192
+ self.is_loaded = True
193
+
194
+ @property
195
+ def image_size(self):
196
+ return self._image_size
197
+
198
+ def feature_select(self, outputs):
199
+ sequence_output = outputs[
200
+ "last_hidden_state"
201
+ ] # batch_size, sequence_length, hidden_size
202
+
203
+ if self.select_feature == "cls_patch":
204
+ image_features = sequence_output
205
+ elif self.select_feature == "patch":
206
+ image_features = sequence_output[:, 1:]
207
+ elif self.select_feature == "cls":
208
+ image_features = sequence_output[:, 0]
209
+ else:
210
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
211
+ return image_features
212
+
213
+ def interpolate(self, image_features):
214
+ if self._interp_size is None:
215
+ return image_features
216
+
217
+ b, num_tokens, dim = image_features.shape
218
+
219
+ if num_tokens != self.num_patches:
220
+ target_h = target_w = int(self._interp_size**0.5)
221
+ h = w = int(num_tokens**0.5)
222
+
223
+ image_features = image_features.view(b, h, w, dim)
224
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
225
+
226
+ image_features = F.interpolate(
227
+ image_features.to(torch.float32),
228
+ size=(target_h, target_w),
229
+ mode="bilinear",
230
+ align_corners=False,
231
+ ).to(image_features.dtype)
232
+
233
+ # Permute the dimensions back to (b, target_h, target_w, dim)
234
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
235
+
236
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
237
+ image_features = image_features.flatten(1, 2)
238
+
239
+ return image_features
240
+
241
+ def _forward(self, images):
242
+ # logger.warning(f"images shape: {images.shape}")
243
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
244
+ image_forward_outs = self.vision_tower.forward(
245
+ images.to(device=self.device, dtype=self.dtype)
246
+ )
247
+ # logger.warning(f"image_forward_outs shape: {image_forward_outs['last_hidden_state'].shape}")
248
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
249
+ # logger.warning(f"image_features shape: {image_features.shape}")
250
+ interp_features = self.interpolate(image_features)
251
+ # logger.warning(f"interp_features shape: {interp_features.shape}")
252
+ return interp_features
253
+
254
+ @property
255
+ def num_patches_per_side(self):
256
+ return int(self.num_patches**0.5)
257
+
258
+ @property
259
+ def num_patches(self):
260
+ if self._interp_size is None:
261
+ return (self._image_size // self._patch_size) ** 2
262
+ else:
263
+ return self._interp_size
264
+
265
+
266
+ # from .siglip_encoder import SiglipVisionTower
267
+ class SiglipVisionTower(BaseVisionTower):
268
+ def __init__(self, vision_tower_name, args, delay_load=False):
269
+ super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
270
+
271
+ model_path = "google/siglip-so400m-patch14-384"
272
+ base_model_name, res, interp = model_path, 384, 576
273
+ self.vision_tower_name = base_model_name
274
+ self._image_size = res if res is not None else 512
275
+ self._interp_size = interp
276
+ if not self.delay_load:
277
+ self.load_model()
278
+ elif self.unfreeze_mm_vision_tower:
279
+ self.load_model()
280
+ else:
281
+ self._hidden_size = 1152
282
+
283
+ def load_model(self, device_map=None):
284
+ self.vision_model = "siglip"
285
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
286
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
287
+
288
+ # self.vision_tower = clip_model.visual.trunk
289
+ self.vision_tower.output_tokens = True
290
+
291
+ self._hidden_size = self.vision_tower.config.hidden_size
292
+ self._image_size = self.vision_tower.config.image_size
293
+ self._patch_size = self.vision_tower.config.patch_size
294
+ self.image_processor = SiglipImageProcessor.from_pretrained(
295
+ self.vision_tower_name
296
+ )
297
+
298
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
299
+ self.is_loaded = True
300
+
301
+ def interpolate(self, image_features):
302
+ if self._interp_size is None:
303
+ return image_features
304
+
305
+ b, num_tokens, dim = image_features.shape
306
+
307
+ if num_tokens != self.num_patches:
308
+ target_h = target_w = int(self._interp_size**0.5)
309
+ h = w = int(num_tokens**0.5)
310
+
311
+ image_features = image_features.view(b, h, w, dim)
312
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
313
+
314
+ image_features = F.interpolate(
315
+ image_features.to(torch.float32),
316
+ size=(target_h, target_w),
317
+ mode="bilinear",
318
+ align_corners=False,
319
+ ).to(image_features.dtype)
320
+
321
+ # Permute the dimensions back to (b, target_h, target_w, dim)
322
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
323
+
324
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
325
+ image_features = image_features.flatten(1, 2)
326
+
327
+ return image_features
328
+
329
+ def _forward(self, images, interpolate_token=576):
330
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
331
+ image_features = self.vision_tower.forward(
332
+ images.to(device=self.device, dtype=self.dtype),
333
+ output_hidden_states=True,
334
+ ).hidden_states[-1]
335
+ interp_features = self.interpolate(image_features)
336
+ return interp_features
337
+
338
+
339
+ def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
340
+ vision_tower_aux_name_list = getattr(
341
+ vision_tower_cfg,
342
+ "mm_vision_tower_aux_list",
343
+ getattr(vision_tower_cfg, "vision_tower_aux_list", None),
344
+ )
345
+ vision_tower_aux_token_len_list = getattr(
346
+ vision_tower_cfg,
347
+ "mm_vision_tower_aux_token_len_list",
348
+ getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
349
+ )
350
+ vision_tower_aux_list = []
351
+ for vision_tower_aux_name, vision_tower_aux_token_len in zip(
352
+ vision_tower_aux_name_list, vision_tower_aux_token_len_list
353
+ ):
354
+ config = copy.deepcopy(vision_tower_cfg)
355
+ vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
356
+ if "siglip" in vision_tower_aux_name.lower():
357
+ vision_tower_aux_list.append(
358
+ SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
359
+ )
360
+
361
+ # SSL-based Vision Towers
362
+ elif "dinov2" in vision_tower_aux_name.lower():
363
+ vision_tower_aux_list.append(
364
+ DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
365
+ )
366
+ else:
367
+ raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
368
+ return vision_tower_aux_list
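A hypothetical illustration of how `build_vision_tower_aux_list` is driven by the fields shown in `config.json` (`mm_vision_tower_aux_list`, `mm_vision_tower_aux_token_len_list`). The `SimpleNamespace` stands in for the real `CambrianConfig`, and `delay_load=True` keeps the sketch light (only the DINOv2 config is fetched from the Hub, no vision weights).

```python
# Assumes the builder and tower classes above have been imported.
from types import SimpleNamespace

cfg = SimpleNamespace(
    mm_vision_tower_aux_list=["siglip/CLIP-ViT-SO400M-14-384", "facebook/dinov2-giant-res378"],
    mm_vision_tower_aux_token_len_list=[576, 576],
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
    unfreeze_mm_vision_tower=False,
)
towers = build_vision_tower_aux_list(cfg, delay_load=True)
print([type(t).__name__ for t in towers])   # ['SiglipVisionTower', 'DinoVisionTower']
```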
multimodal_projector_builder.py ADDED
@@ -0,0 +1,52 @@
1
+ # pyre-unsafe
2
+ import re
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": "identity"}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+ config.mm_hidden_size = 256
36
+
37
+ if projector_type == "linear":
38
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
39
+
40
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+
49
+ if projector_type == "identity":
50
+ return IdentityMap()
51
+
52
+ raise ValueError(f"Unknown projector type: {projector_type}")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b65748460ffda13e94d55cd97b7ea248faf4f515e56ceb59eba2126315dca3c
3
+ size 7317897562
special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|eot_id|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,2062 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "128000": {
4
+ "content": "<|begin_of_text|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "128001": {
12
+ "content": "<|end_of_text|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "128002": {
20
+ "content": "<|reserved_special_token_0|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "128003": {
28
+ "content": "<|reserved_special_token_1|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "128004": {
36
+ "content": "<|finetune_right_pad_id|>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "128005": {
44
+ "content": "<|reserved_special_token_2|>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "128006": {
52
+ "content": "<|start_header_id|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "128007": {
60
+ "content": "<|end_header_id|>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "128008": {
68
+ "content": "<|eom_id|>",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "128009": {
76
+ "content": "<|eot_id|>",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ },
83
+ "128010": {
84
+ "content": "<|python_tag|>",
85
+ "lstrip": false,
86
+ "normalized": false,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": true
90
+ },
91
+ "128011": {
92
+ "content": "<|reserved_special_token_3|>",
93
+ "lstrip": false,
94
+ "normalized": false,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": true
98
+ },
99
+ "128012": {
100
+ "content": "<|reserved_special_token_4|>",
101
+ "lstrip": false,
102
+ "normalized": false,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": true
106
+ },
107
+ "128013": {
108
+ "content": "<|reserved_special_token_5|>",
109
+ "lstrip": false,
110
+ "normalized": false,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": true
114
+ },
115
+ "128014": {
116
+ "content": "<|reserved_special_token_6|>",
117
+ "lstrip": false,
118
+ "normalized": false,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": true
122
+ },
123
+ "128015": {
124
+ "content": "<|reserved_special_token_7|>",
125
+ "lstrip": false,
126
+ "normalized": false,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": true
130
+ },
131
+ "128016": {
132
+ "content": "<|reserved_special_token_8|>",
133
+ "lstrip": false,
134
+ "normalized": false,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": true
138
+ },
139
+ "128017": {
140
+ "content": "<|reserved_special_token_9|>",
141
+ "lstrip": false,
142
+ "normalized": false,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": true
146
+ },
147
+ "128018": {
148
+ "content": "<|reserved_special_token_10|>",
149
+ "lstrip": false,
150
+ "normalized": false,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": true
154
+ },
155
+ "128019": {
156
+ "content": "<|reserved_special_token_11|>",
157
+ "lstrip": false,
158
+ "normalized": false,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": true
162
+ },
163
+ "128020": {
164
+ "content": "<|reserved_special_token_12|>",
165
+ "lstrip": false,
166
+ "normalized": false,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": true
170
+ },
171
+ "128021": {
172
+ "content": "<|reserved_special_token_13|>",
173
+ "lstrip": false,
174
+ "normalized": false,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": true
178
+ },
179
+ "128022": {
180
+ "content": "<|reserved_special_token_14|>",
181
+ "lstrip": false,
182
+ "normalized": false,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": true
186
+ },
187
+ "128023": {
188
+ "content": "<|reserved_special_token_15|>",
189
+ "lstrip": false,
190
+ "normalized": false,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": true
194
+ },
195
+ "128024": {
196
+ "content": "<|reserved_special_token_16|>",
197
+ "lstrip": false,
198
+ "normalized": false,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": true
202
+ },
203
+ "128025": {
204
+ "content": "<|reserved_special_token_17|>",
205
+ "lstrip": false,
206
+ "normalized": false,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": true
210
+ },
211
+ "128026": {
212
+ "content": "<|reserved_special_token_18|>",
213
+ "lstrip": false,
214
+ "normalized": false,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": true
218
+ },
219
+ "128027": {
220
+ "content": "<|reserved_special_token_19|>",
221
+ "lstrip": false,
222
+ "normalized": false,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": true
226
+ },
227
+ "128028": {
228
+ "content": "<|reserved_special_token_20|>",
229
+ "lstrip": false,
230
+ "normalized": false,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": true
234
+ },
235
+ "128029": {
236
+ "content": "<|reserved_special_token_21|>",
237
+ "lstrip": false,
238
+ "normalized": false,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": true
242
+ },
243
+ "128030": {
244
+ "content": "<|reserved_special_token_22|>",
245
+ "lstrip": false,
246
+ "normalized": false,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": true
250
+ },
251
+ "128031": {
252
+ "content": "<|reserved_special_token_23|>",
253
+ "lstrip": false,
254
+ "normalized": false,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": true
258
+ },
259
+ "128032": {
260
+ "content": "<|reserved_special_token_24|>",
261
+ "lstrip": false,
262
+ "normalized": false,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": true
266
+ },
267
+ "128033": {
268
+ "content": "<|reserved_special_token_25|>",
269
+ "lstrip": false,
270
+ "normalized": false,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": true
274
+ },
275
+ "128034": {
276
+ "content": "<|reserved_special_token_26|>",
277
+ "lstrip": false,
278
+ "normalized": false,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": true
282
+ },
283
+ "128035": {
284
+ "content": "<|reserved_special_token_27|>",
285
+ "lstrip": false,
286
+ "normalized": false,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": true
290
+ },
291
+ "128036": {
292
+ "content": "<|reserved_special_token_28|>",
293
+ "lstrip": false,
294
+ "normalized": false,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": true
298
+ },
299
+ "128037": {
300
+ "content": "<|reserved_special_token_29|>",
301
+ "lstrip": false,
302
+ "normalized": false,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": true
306
+ },
307
+ "128038": {
308
+ "content": "<|reserved_special_token_30|>",
309
+ "lstrip": false,
310
+ "normalized": false,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": true
314
+ },
315
+ "128039": {
316
+ "content": "<|reserved_special_token_31|>",
317
+ "lstrip": false,
318
+ "normalized": false,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": true
322
+ },
323
+ "128040": {
324
+ "content": "<|reserved_special_token_32|>",
325
+ "lstrip": false,
326
+ "normalized": false,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": true
330
+ },
331
+ "128041": {
332
+ "content": "<|reserved_special_token_33|>",
333
+ "lstrip": false,
334
+ "normalized": false,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": true
338
+ },
339
+ "128042": {
340
+ "content": "<|reserved_special_token_34|>",
341
+ "lstrip": false,
342
+ "normalized": false,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": true
346
+ },
347
+ "128043": {
348
+ "content": "<|reserved_special_token_35|>",
349
+ "lstrip": false,
350
+ "normalized": false,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": true
354
+ },
355
+ "128044": {
356
+ "content": "<|reserved_special_token_36|>",
357
+ "lstrip": false,
358
+ "normalized": false,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": true
362
+ },
363
+ "128045": {
364
+ "content": "<|reserved_special_token_37|>",
365
+ "lstrip": false,
366
+ "normalized": false,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": true
370
+ },
371
+ "128046": {
372
+ "content": "<|reserved_special_token_38|>",
373
+ "lstrip": false,
374
+ "normalized": false,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": true
378
+ },
379
+ "128047": {
380
+ "content": "<|reserved_special_token_39|>",
381
+ "lstrip": false,
382
+ "normalized": false,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": true
386
+ },
387
+ "128048": {
388
+ "content": "<|reserved_special_token_40|>",
389
+ "lstrip": false,
390
+ "normalized": false,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": true
394
+ },
395
+ "128049": {
396
+ "content": "<|reserved_special_token_41|>",
397
+ "lstrip": false,
398
+ "normalized": false,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": true
402
+ },
403
+ "128050": {
404
+ "content": "<|reserved_special_token_42|>",
405
+ "lstrip": false,
406
+ "normalized": false,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": true
410
+ },
411
+ "128051": {
412
+ "content": "<|reserved_special_token_43|>",
413
+ "lstrip": false,
414
+ "normalized": false,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": true
418
+ },
419
+ "128052": {
420
+ "content": "<|reserved_special_token_44|>",
421
+ "lstrip": false,
422
+ "normalized": false,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": true
426
+ },
427
+ "128053": {
428
+ "content": "<|reserved_special_token_45|>",
429
+ "lstrip": false,
430
+ "normalized": false,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": true
434
+ },
435
+ "128054": {
436
+ "content": "<|reserved_special_token_46|>",
437
+ "lstrip": false,
438
+ "normalized": false,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": true
442
+ },
443
+ "128055": {
444
+ "content": "<|reserved_special_token_47|>",
445
+ "lstrip": false,
446
+ "normalized": false,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": true
450
+ },
451
+ "128056": {
452
+ "content": "<|reserved_special_token_48|>",
453
+ "lstrip": false,
454
+ "normalized": false,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": true
458
+ },
459
+ "128057": {
460
+ "content": "<|reserved_special_token_49|>",
461
+ "lstrip": false,
462
+ "normalized": false,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": true
466
+ },
467
+ "128058": {
468
+ "content": "<|reserved_special_token_50|>",
469
+ "lstrip": false,
470
+ "normalized": false,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": true
474
+ },
475
+ "128059": {
476
+ "content": "<|reserved_special_token_51|>",
477
+ "lstrip": false,
478
+ "normalized": false,
479
+ "rstrip": false,
480
+ "single_word": false,
481
+ "special": true
482
+ },
483
+ "128060": {
484
+ "content": "<|reserved_special_token_52|>",
485
+ "lstrip": false,
486
+ "normalized": false,
487
+ "rstrip": false,
488
+ "single_word": false,
489
+ "special": true
490
+ },
491
+ "128061": {
492
+ "content": "<|reserved_special_token_53|>",
493
+ "lstrip": false,
494
+ "normalized": false,
495
+ "rstrip": false,
496
+ "single_word": false,
497
+ "special": true
498
+ },
499
+ "128062": {
500
+ "content": "<|reserved_special_token_54|>",
501
+ "lstrip": false,
502
+ "normalized": false,
503
+ "rstrip": false,
504
+ "single_word": false,
505
+ "special": true
506
+ },
507
+ "128063": {
508
+ "content": "<|reserved_special_token_55|>",
509
+ "lstrip": false,
510
+ "normalized": false,
511
+ "rstrip": false,
512
+ "single_word": false,
513
+ "special": true
514
+ },
515
+ "128064": {
516
+ "content": "<|reserved_special_token_56|>",
517
+ "lstrip": false,
518
+ "normalized": false,
519
+ "rstrip": false,
520
+ "single_word": false,
521
+ "special": true
522
+ },
523
+ "128065": {
524
+ "content": "<|reserved_special_token_57|>",
525
+ "lstrip": false,
526
+ "normalized": false,
527
+ "rstrip": false,
528
+ "single_word": false,
529
+ "special": true
530
+ },
531
+ "128066": {
532
+ "content": "<|reserved_special_token_58|>",
533
+ "lstrip": false,
534
+ "normalized": false,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": true
538
+ },
539
+ "128067": {
540
+ "content": "<|reserved_special_token_59|>",
541
+ "lstrip": false,
542
+ "normalized": false,
543
+ "rstrip": false,
544
+ "single_word": false,
545
+ "special": true
546
+ },
547
+ "128068": {
548
+ "content": "<|reserved_special_token_60|>",
549
+ "lstrip": false,
550
+ "normalized": false,
551
+ "rstrip": false,
552
+ "single_word": false,
553
+ "special": true
554
+ },
555
+ "128069": {
556
+ "content": "<|reserved_special_token_61|>",
557
+ "lstrip": false,
558
+ "normalized": false,
559
+ "rstrip": false,
560
+ "single_word": false,
561
+ "special": true
562
+ },
563
+ "128070": {
564
+ "content": "<|reserved_special_token_62|>",
565
+ "lstrip": false,
566
+ "normalized": false,
567
+ "rstrip": false,
568
+ "single_word": false,
569
+ "special": true
570
+ },
571
+ "128071": {
572
+ "content": "<|reserved_special_token_63|>",
573
+ "lstrip": false,
574
+ "normalized": false,
575
+ "rstrip": false,
576
+ "single_word": false,
577
+ "special": true
578
+ },
579
+ "128072": {
580
+ "content": "<|reserved_special_token_64|>",
581
+ "lstrip": false,
582
+ "normalized": false,
583
+ "rstrip": false,
584
+ "single_word": false,
585
+ "special": true
586
+ },
587
+ "128073": {
588
+ "content": "<|reserved_special_token_65|>",
589
+ "lstrip": false,
590
+ "normalized": false,
591
+ "rstrip": false,
592
+ "single_word": false,
593
+ "special": true
594
+ },
595
+ "128074": {
596
+ "content": "<|reserved_special_token_66|>",
597
+ "lstrip": false,
598
+ "normalized": false,
599
+ "rstrip": false,
600
+ "single_word": false,
601
+ "special": true
602
+ },
603
+ "128075": {
604
+ "content": "<|reserved_special_token_67|>",
605
+ "lstrip": false,
606
+ "normalized": false,
607
+ "rstrip": false,
608
+ "single_word": false,
609
+ "special": true
610
+ },
611
+ "128076": {
612
+ "content": "<|reserved_special_token_68|>",
613
+ "lstrip": false,
614
+ "normalized": false,
615
+ "rstrip": false,
616
+ "single_word": false,
617
+ "special": true
618
+ },
619
+ "128077": {
620
+ "content": "<|reserved_special_token_69|>",
621
+ "lstrip": false,
622
+ "normalized": false,
623
+ "rstrip": false,
624
+ "single_word": false,
625
+ "special": true
626
+ },
627
+ "128078": {
628
+ "content": "<|reserved_special_token_70|>",
629
+ "lstrip": false,
630
+ "normalized": false,
631
+ "rstrip": false,
632
+ "single_word": false,
633
+ "special": true
634
+ },
635
+ "128079": {
636
+ "content": "<|reserved_special_token_71|>",
637
+ "lstrip": false,
638
+ "normalized": false,
639
+ "rstrip": false,
640
+ "single_word": false,
641
+ "special": true
642
+ },
643
+ "128080": {
644
+ "content": "<|reserved_special_token_72|>",
645
+ "lstrip": false,
646
+ "normalized": false,
647
+ "rstrip": false,
648
+ "single_word": false,
649
+ "special": true
650
+ },
651
+ "128081": {
652
+ "content": "<|reserved_special_token_73|>",
653
+ "lstrip": false,
654
+ "normalized": false,
655
+ "rstrip": false,
656
+ "single_word": false,
657
+ "special": true
658
+ },
659
+ "128082": {
660
+ "content": "<|reserved_special_token_74|>",
661
+ "lstrip": false,
662
+ "normalized": false,
663
+ "rstrip": false,
664
+ "single_word": false,
665
+ "special": true
666
+ },
667
+ "128083": {
668
+ "content": "<|reserved_special_token_75|>",
669
+ "lstrip": false,
670
+ "normalized": false,
671
+ "rstrip": false,
672
+ "single_word": false,
673
+ "special": true
674
+ },
675
+ "128084": {
676
+ "content": "<|reserved_special_token_76|>",
677
+ "lstrip": false,
678
+ "normalized": false,
679
+ "rstrip": false,
680
+ "single_word": false,
681
+ "special": true
682
+ },
683
+ "128085": {
684
+ "content": "<|reserved_special_token_77|>",
685
+ "lstrip": false,
686
+ "normalized": false,
687
+ "rstrip": false,
688
+ "single_word": false,
689
+ "special": true
690
+ },
691
+ "128086": {
692
+ "content": "<|reserved_special_token_78|>",
693
+ "lstrip": false,
694
+ "normalized": false,
695
+ "rstrip": false,
696
+ "single_word": false,
697
+ "special": true
698
+ },
699
+ "128087": {
700
+ "content": "<|reserved_special_token_79|>",
701
+ "lstrip": false,
702
+ "normalized": false,
703
+ "rstrip": false,
704
+ "single_word": false,
705
+ "special": true
706
+ },
707
+ "128088": {
708
+ "content": "<|reserved_special_token_80|>",
709
+ "lstrip": false,
710
+ "normalized": false,
711
+ "rstrip": false,
712
+ "single_word": false,
713
+ "special": true
714
+ },
715
+ "128089": {
716
+ "content": "<|reserved_special_token_81|>",
717
+ "lstrip": false,
718
+ "normalized": false,
719
+ "rstrip": false,
720
+ "single_word": false,
721
+ "special": true
722
+ },
723
+ "128090": {
724
+ "content": "<|reserved_special_token_82|>",
725
+ "lstrip": false,
726
+ "normalized": false,
727
+ "rstrip": false,
728
+ "single_word": false,
729
+ "special": true
730
+ },
731
+ "128091": {
732
+ "content": "<|reserved_special_token_83|>",
733
+ "lstrip": false,
734
+ "normalized": false,
735
+ "rstrip": false,
736
+ "single_word": false,
737
+ "special": true
738
+ },
739
+ "128092": {
740
+ "content": "<|reserved_special_token_84|>",
741
+ "lstrip": false,
742
+ "normalized": false,
743
+ "rstrip": false,
744
+ "single_word": false,
745
+ "special": true
746
+ },
747
+ "128093": {
748
+ "content": "<|reserved_special_token_85|>",
749
+ "lstrip": false,
750
+ "normalized": false,
751
+ "rstrip": false,
752
+ "single_word": false,
753
+ "special": true
754
+ },
755
+ "128094": {
756
+ "content": "<|reserved_special_token_86|>",
757
+ "lstrip": false,
758
+ "normalized": false,
759
+ "rstrip": false,
760
+ "single_word": false,
761
+ "special": true
762
+ },
763
+ "128095": {
764
+ "content": "<|reserved_special_token_87|>",
765
+ "lstrip": false,
766
+ "normalized": false,
767
+ "rstrip": false,
768
+ "single_word": false,
769
+ "special": true
770
+ },
771
+ "128096": {
772
+ "content": "<|reserved_special_token_88|>",
773
+ "lstrip": false,
774
+ "normalized": false,
775
+ "rstrip": false,
776
+ "single_word": false,
777
+ "special": true
778
+ },
779
+ "128097": {
780
+ "content": "<|reserved_special_token_89|>",
781
+ "lstrip": false,
782
+ "normalized": false,
783
+ "rstrip": false,
784
+ "single_word": false,
785
+ "special": true
786
+ },
787
+ "128098": {
788
+ "content": "<|reserved_special_token_90|>",
789
+ "lstrip": false,
790
+ "normalized": false,
791
+ "rstrip": false,
792
+ "single_word": false,
793
+ "special": true
794
+ },
795
+ "128099": {
796
+ "content": "<|reserved_special_token_91|>",
797
+ "lstrip": false,
798
+ "normalized": false,
799
+ "rstrip": false,
800
+ "single_word": false,
801
+ "special": true
802
+ },
803
+ "128100": {
804
+ "content": "<|reserved_special_token_92|>",
805
+ "lstrip": false,
806
+ "normalized": false,
807
+ "rstrip": false,
808
+ "single_word": false,
809
+ "special": true
810
+ },
811
+ "128101": {
812
+ "content": "<|reserved_special_token_93|>",
813
+ "lstrip": false,
814
+ "normalized": false,
815
+ "rstrip": false,
816
+ "single_word": false,
817
+ "special": true
818
+ },
819
+ "128102": {
820
+ "content": "<|reserved_special_token_94|>",
821
+ "lstrip": false,
822
+ "normalized": false,
823
+ "rstrip": false,
824
+ "single_word": false,
825
+ "special": true
826
+ },
827
+ "128103": {
828
+ "content": "<|reserved_special_token_95|>",
829
+ "lstrip": false,
830
+ "normalized": false,
831
+ "rstrip": false,
832
+ "single_word": false,
833
+ "special": true
834
+ },
835
+ "128104": {
836
+ "content": "<|reserved_special_token_96|>",
837
+ "lstrip": false,
838
+ "normalized": false,
839
+ "rstrip": false,
840
+ "single_word": false,
841
+ "special": true
842
+ },
843
+ "128105": {
844
+ "content": "<|reserved_special_token_97|>",
845
+ "lstrip": false,
846
+ "normalized": false,
847
+ "rstrip": false,
848
+ "single_word": false,
849
+ "special": true
850
+ },
851
+ "128106": {
852
+ "content": "<|reserved_special_token_98|>",
853
+ "lstrip": false,
854
+ "normalized": false,
855
+ "rstrip": false,
856
+ "single_word": false,
857
+ "special": true
858
+ },
859
+ "128107": {
860
+ "content": "<|reserved_special_token_99|>",
861
+ "lstrip": false,
862
+ "normalized": false,
863
+ "rstrip": false,
864
+ "single_word": false,
865
+ "special": true
866
+ },
867
+ "128108": {
868
+ "content": "<|reserved_special_token_100|>",
869
+ "lstrip": false,
870
+ "normalized": false,
871
+ "rstrip": false,
872
+ "single_word": false,
873
+ "special": true
874
+ },
875
+ "128109": {
876
+ "content": "<|reserved_special_token_101|>",
877
+ "lstrip": false,
878
+ "normalized": false,
879
+ "rstrip": false,
880
+ "single_word": false,
881
+ "special": true
882
+ },
883
+ "128110": {
884
+ "content": "<|reserved_special_token_102|>",
885
+ "lstrip": false,
886
+ "normalized": false,
887
+ "rstrip": false,
888
+ "single_word": false,
889
+ "special": true
890
+ },
891
+ "128111": {
892
+ "content": "<|reserved_special_token_103|>",
893
+ "lstrip": false,
894
+ "normalized": false,
895
+ "rstrip": false,
896
+ "single_word": false,
897
+ "special": true
898
+ },
899
+ "128112": {
900
+ "content": "<|reserved_special_token_104|>",
901
+ "lstrip": false,
902
+ "normalized": false,
903
+ "rstrip": false,
904
+ "single_word": false,
905
+ "special": true
906
+ },
907
+ "128113": {
908
+ "content": "<|reserved_special_token_105|>",
909
+ "lstrip": false,
910
+ "normalized": false,
911
+ "rstrip": false,
912
+ "single_word": false,
913
+ "special": true
914
+ },
915
+ "128114": {
916
+ "content": "<|reserved_special_token_106|>",
917
+ "lstrip": false,
918
+ "normalized": false,
919
+ "rstrip": false,
920
+ "single_word": false,
921
+ "special": true
922
+ },
923
+ "128115": {
924
+ "content": "<|reserved_special_token_107|>",
925
+ "lstrip": false,
926
+ "normalized": false,
927
+ "rstrip": false,
928
+ "single_word": false,
929
+ "special": true
930
+ },
931
+ "128116": {
932
+ "content": "<|reserved_special_token_108|>",
933
+ "lstrip": false,
934
+ "normalized": false,
935
+ "rstrip": false,
936
+ "single_word": false,
937
+ "special": true
938
+ },
939
+ "128117": {
940
+ "content": "<|reserved_special_token_109|>",
941
+ "lstrip": false,
942
+ "normalized": false,
943
+ "rstrip": false,
944
+ "single_word": false,
945
+ "special": true
946
+ },
947
+ "128118": {
948
+ "content": "<|reserved_special_token_110|>",
949
+ "lstrip": false,
950
+ "normalized": false,
951
+ "rstrip": false,
952
+ "single_word": false,
953
+ "special": true
954
+ },
955
+ "128119": {
956
+ "content": "<|reserved_special_token_111|>",
957
+ "lstrip": false,
958
+ "normalized": false,
959
+ "rstrip": false,
960
+ "single_word": false,
961
+ "special": true
962
+ },
963
+ "128120": {
964
+ "content": "<|reserved_special_token_112|>",
965
+ "lstrip": false,
966
+ "normalized": false,
967
+ "rstrip": false,
968
+ "single_word": false,
969
+ "special": true
970
+ },
971
+ "128121": {
972
+ "content": "<|reserved_special_token_113|>",
973
+ "lstrip": false,
974
+ "normalized": false,
975
+ "rstrip": false,
976
+ "single_word": false,
977
+ "special": true
978
+ },
979
+ "128122": {
980
+ "content": "<|reserved_special_token_114|>",
981
+ "lstrip": false,
982
+ "normalized": false,
983
+ "rstrip": false,
984
+ "single_word": false,
985
+ "special": true
986
+ },
987
+ "128123": {
988
+ "content": "<|reserved_special_token_115|>",
989
+ "lstrip": false,
990
+ "normalized": false,
991
+ "rstrip": false,
992
+ "single_word": false,
993
+ "special": true
994
+ },
995
+ "128124": {
996
+ "content": "<|reserved_special_token_116|>",
997
+ "lstrip": false,
998
+ "normalized": false,
999
+ "rstrip": false,
1000
+ "single_word": false,
1001
+ "special": true
1002
+ },
1003
+ "128125": {
1004
+ "content": "<|reserved_special_token_117|>",
1005
+ "lstrip": false,
1006
+ "normalized": false,
1007
+ "rstrip": false,
1008
+ "single_word": false,
1009
+ "special": true
1010
+ },
1011
+ "128126": {
1012
+ "content": "<|reserved_special_token_118|>",
1013
+ "lstrip": false,
1014
+ "normalized": false,
1015
+ "rstrip": false,
1016
+ "single_word": false,
1017
+ "special": true
1018
+ },
1019
+ "128127": {
1020
+ "content": "<|reserved_special_token_119|>",
1021
+ "lstrip": false,
1022
+ "normalized": false,
1023
+ "rstrip": false,
1024
+ "single_word": false,
1025
+ "special": true
1026
+ },
1027
+ "128128": {
1028
+ "content": "<|reserved_special_token_120|>",
1029
+ "lstrip": false,
1030
+ "normalized": false,
1031
+ "rstrip": false,
1032
+ "single_word": false,
1033
+ "special": true
1034
+ },
1035
+ "128129": {
1036
+ "content": "<|reserved_special_token_121|>",
1037
+ "lstrip": false,
1038
+ "normalized": false,
1039
+ "rstrip": false,
1040
+ "single_word": false,
1041
+ "special": true
1042
+ },
1043
+ "128130": {
1044
+ "content": "<|reserved_special_token_122|>",
1045
+ "lstrip": false,
1046
+ "normalized": false,
1047
+ "rstrip": false,
1048
+ "single_word": false,
1049
+ "special": true
1050
+ },
1051
+ "128131": {
1052
+ "content": "<|reserved_special_token_123|>",
1053
+ "lstrip": false,
1054
+ "normalized": false,
1055
+ "rstrip": false,
1056
+ "single_word": false,
1057
+ "special": true
1058
+ },
1059
+ "128132": {
1060
+ "content": "<|reserved_special_token_124|>",
1061
+ "lstrip": false,
1062
+ "normalized": false,
1063
+ "rstrip": false,
1064
+ "single_word": false,
1065
+ "special": true
1066
+ },
1067
+ "128133": {
1068
+ "content": "<|reserved_special_token_125|>",
1069
+ "lstrip": false,
1070
+ "normalized": false,
1071
+ "rstrip": false,
1072
+ "single_word": false,
1073
+ "special": true
1074
+ },
1075
+ "128134": {
1076
+ "content": "<|reserved_special_token_126|>",
1077
+ "lstrip": false,
1078
+ "normalized": false,
1079
+ "rstrip": false,
1080
+ "single_word": false,
1081
+ "special": true
1082
+ },
1083
+ "128135": {
1084
+ "content": "<|reserved_special_token_127|>",
1085
+ "lstrip": false,
1086
+ "normalized": false,
1087
+ "rstrip": false,
1088
+ "single_word": false,
1089
+ "special": true
1090
+ },
1091
+ "128136": {
1092
+ "content": "<|reserved_special_token_128|>",
1093
+ "lstrip": false,
1094
+ "normalized": false,
1095
+ "rstrip": false,
1096
+ "single_word": false,
1097
+ "special": true
1098
+ },
1099
+ "128137": {
1100
+ "content": "<|reserved_special_token_129|>",
1101
+ "lstrip": false,
1102
+ "normalized": false,
1103
+ "rstrip": false,
1104
+ "single_word": false,
1105
+ "special": true
1106
+ },
1107
+ "128138": {
1108
+ "content": "<|reserved_special_token_130|>",
1109
+ "lstrip": false,
1110
+ "normalized": false,
1111
+ "rstrip": false,
1112
+ "single_word": false,
1113
+ "special": true
1114
+ },
1115
+ "128139": {
1116
+ "content": "<|reserved_special_token_131|>",
1117
+ "lstrip": false,
1118
+ "normalized": false,
1119
+ "rstrip": false,
1120
+ "single_word": false,
1121
+ "special": true
1122
+ },
1123
+ "128140": {
1124
+ "content": "<|reserved_special_token_132|>",
1125
+ "lstrip": false,
1126
+ "normalized": false,
1127
+ "rstrip": false,
1128
+ "single_word": false,
1129
+ "special": true
1130
+ },
1131
+ "128141": {
1132
+ "content": "<|reserved_special_token_133|>",
1133
+ "lstrip": false,
1134
+ "normalized": false,
1135
+ "rstrip": false,
1136
+ "single_word": false,
1137
+ "special": true
1138
+ },
1139
+ "128142": {
1140
+ "content": "<|reserved_special_token_134|>",
1141
+ "lstrip": false,
1142
+ "normalized": false,
1143
+ "rstrip": false,
1144
+ "single_word": false,
1145
+ "special": true
1146
+ },
1147
+ "128143": {
1148
+ "content": "<|reserved_special_token_135|>",
1149
+ "lstrip": false,
1150
+ "normalized": false,
1151
+ "rstrip": false,
1152
+ "single_word": false,
1153
+ "special": true
1154
+ },
1155
+ "128144": {
1156
+ "content": "<|reserved_special_token_136|>",
1157
+ "lstrip": false,
1158
+ "normalized": false,
1159
+ "rstrip": false,
1160
+ "single_word": false,
1161
+ "special": true
1162
+ },
1163
+ "128145": {
1164
+ "content": "<|reserved_special_token_137|>",
1165
+ "lstrip": false,
1166
+ "normalized": false,
1167
+ "rstrip": false,
1168
+ "single_word": false,
1169
+ "special": true
1170
+ },
1171
+ "128146": {
1172
+ "content": "<|reserved_special_token_138|>",
1173
+ "lstrip": false,
1174
+ "normalized": false,
1175
+ "rstrip": false,
1176
+ "single_word": false,
1177
+ "special": true
1178
+ },
1179
+ "128147": {
1180
+ "content": "<|reserved_special_token_139|>",
1181
+ "lstrip": false,
1182
+ "normalized": false,
1183
+ "rstrip": false,
1184
+ "single_word": false,
1185
+ "special": true
1186
+ },
1187
+ "128148": {
1188
+ "content": "<|reserved_special_token_140|>",
1189
+ "lstrip": false,
1190
+ "normalized": false,
1191
+ "rstrip": false,
1192
+ "single_word": false,
1193
+ "special": true
1194
+ },
1195
+ "128149": {
1196
+ "content": "<|reserved_special_token_141|>",
1197
+ "lstrip": false,
1198
+ "normalized": false,
1199
+ "rstrip": false,
1200
+ "single_word": false,
1201
+ "special": true
1202
+ },
1203
+ "128150": {
1204
+ "content": "<|reserved_special_token_142|>",
1205
+ "lstrip": false,
1206
+ "normalized": false,
1207
+ "rstrip": false,
1208
+ "single_word": false,
1209
+ "special": true
1210
+ },
1211
+ "128151": {
1212
+ "content": "<|reserved_special_token_143|>",
1213
+ "lstrip": false,
1214
+ "normalized": false,
1215
+ "rstrip": false,
1216
+ "single_word": false,
1217
+ "special": true
1218
+ },
1219
+ "128152": {
1220
+ "content": "<|reserved_special_token_144|>",
1221
+ "lstrip": false,
1222
+ "normalized": false,
1223
+ "rstrip": false,
1224
+ "single_word": false,
1225
+ "special": true
1226
+ },
1227
+ "128153": {
1228
+ "content": "<|reserved_special_token_145|>",
1229
+ "lstrip": false,
1230
+ "normalized": false,
1231
+ "rstrip": false,
1232
+ "single_word": false,
1233
+ "special": true
1234
+ },
1235
+ "128154": {
1236
+ "content": "<|reserved_special_token_146|>",
1237
+ "lstrip": false,
1238
+ "normalized": false,
1239
+ "rstrip": false,
1240
+ "single_word": false,
1241
+ "special": true
1242
+ },
1243
+ "128155": {
1244
+ "content": "<|reserved_special_token_147|>",
1245
+ "lstrip": false,
1246
+ "normalized": false,
1247
+ "rstrip": false,
1248
+ "single_word": false,
1249
+ "special": true
1250
+ },
1251
+ "128156": {
1252
+ "content": "<|reserved_special_token_148|>",
1253
+ "lstrip": false,
1254
+ "normalized": false,
1255
+ "rstrip": false,
1256
+ "single_word": false,
1257
+ "special": true
1258
+ },
1259
+ "128157": {
1260
+ "content": "<|reserved_special_token_149|>",
1261
+ "lstrip": false,
1262
+ "normalized": false,
1263
+ "rstrip": false,
1264
+ "single_word": false,
1265
+ "special": true
1266
+ },
1267
+ "128158": {
1268
+ "content": "<|reserved_special_token_150|>",
1269
+ "lstrip": false,
1270
+ "normalized": false,
1271
+ "rstrip": false,
1272
+ "single_word": false,
1273
+ "special": true
1274
+ },
1275
+ "128159": {
1276
+ "content": "<|reserved_special_token_151|>",
1277
+ "lstrip": false,
1278
+ "normalized": false,
1279
+ "rstrip": false,
1280
+ "single_word": false,
1281
+ "special": true
1282
+ },
1283
+ "128160": {
1284
+ "content": "<|reserved_special_token_152|>",
1285
+ "lstrip": false,
1286
+ "normalized": false,
1287
+ "rstrip": false,
1288
+ "single_word": false,
1289
+ "special": true
1290
+ },
1291
+ "128161": {
1292
+ "content": "<|reserved_special_token_153|>",
1293
+ "lstrip": false,
1294
+ "normalized": false,
1295
+ "rstrip": false,
1296
+ "single_word": false,
1297
+ "special": true
1298
+ },
1299
+ "128162": {
1300
+ "content": "<|reserved_special_token_154|>",
1301
+ "lstrip": false,
1302
+ "normalized": false,
1303
+ "rstrip": false,
1304
+ "single_word": false,
1305
+ "special": true
1306
+ },
1307
+ "128163": {
1308
+ "content": "<|reserved_special_token_155|>",
1309
+ "lstrip": false,
1310
+ "normalized": false,
1311
+ "rstrip": false,
1312
+ "single_word": false,
1313
+ "special": true
1314
+ },
1315
+ "128164": {
1316
+ "content": "<|reserved_special_token_156|>",
1317
+ "lstrip": false,
1318
+ "normalized": false,
1319
+ "rstrip": false,
1320
+ "single_word": false,
1321
+ "special": true
1322
+ },
1323
+ "128165": {
1324
+ "content": "<|reserved_special_token_157|>",
1325
+ "lstrip": false,
1326
+ "normalized": false,
1327
+ "rstrip": false,
1328
+ "single_word": false,
1329
+ "special": true
1330
+ },
1331
+ "128166": {
1332
+ "content": "<|reserved_special_token_158|>",
1333
+ "lstrip": false,
1334
+ "normalized": false,
1335
+ "rstrip": false,
1336
+ "single_word": false,
1337
+ "special": true
1338
+ },
1339
+ "128167": {
1340
+ "content": "<|reserved_special_token_159|>",
1341
+ "lstrip": false,
1342
+ "normalized": false,
1343
+ "rstrip": false,
1344
+ "single_word": false,
1345
+ "special": true
1346
+ },
1347
+ "128168": {
1348
+ "content": "<|reserved_special_token_160|>",
1349
+ "lstrip": false,
1350
+ "normalized": false,
1351
+ "rstrip": false,
1352
+ "single_word": false,
1353
+ "special": true
1354
+ },
1355
+ "128169": {
1356
+ "content": "<|reserved_special_token_161|>",
1357
+ "lstrip": false,
1358
+ "normalized": false,
1359
+ "rstrip": false,
1360
+ "single_word": false,
1361
+ "special": true
1362
+ },
1363
+ "128170": {
1364
+ "content": "<|reserved_special_token_162|>",
1365
+ "lstrip": false,
1366
+ "normalized": false,
1367
+ "rstrip": false,
1368
+ "single_word": false,
1369
+ "special": true
1370
+ },
1371
+ "128171": {
1372
+ "content": "<|reserved_special_token_163|>",
1373
+ "lstrip": false,
1374
+ "normalized": false,
1375
+ "rstrip": false,
1376
+ "single_word": false,
1377
+ "special": true
1378
+ },
1379
+ "128172": {
1380
+ "content": "<|reserved_special_token_164|>",
1381
+ "lstrip": false,
1382
+ "normalized": false,
1383
+ "rstrip": false,
1384
+ "single_word": false,
1385
+ "special": true
1386
+ },
1387
+ "128173": {
1388
+ "content": "<|reserved_special_token_165|>",
1389
+ "lstrip": false,
1390
+ "normalized": false,
1391
+ "rstrip": false,
1392
+ "single_word": false,
1393
+ "special": true
1394
+ },
1395
+ "128174": {
1396
+ "content": "<|reserved_special_token_166|>",
1397
+ "lstrip": false,
1398
+ "normalized": false,
1399
+ "rstrip": false,
1400
+ "single_word": false,
1401
+ "special": true
1402
+ },
1403
+ "128175": {
1404
+ "content": "<|reserved_special_token_167|>",
1405
+ "lstrip": false,
1406
+ "normalized": false,
1407
+ "rstrip": false,
1408
+ "single_word": false,
1409
+ "special": true
1410
+ },
1411
+ "128176": {
1412
+ "content": "<|reserved_special_token_168|>",
1413
+ "lstrip": false,
1414
+ "normalized": false,
1415
+ "rstrip": false,
1416
+ "single_word": false,
1417
+ "special": true
1418
+ },
1419
+ "128177": {
1420
+ "content": "<|reserved_special_token_169|>",
1421
+ "lstrip": false,
1422
+ "normalized": false,
1423
+ "rstrip": false,
1424
+ "single_word": false,
1425
+ "special": true
1426
+ },
1427
+ "128178": {
1428
+ "content": "<|reserved_special_token_170|>",
1429
+ "lstrip": false,
1430
+ "normalized": false,
1431
+ "rstrip": false,
1432
+ "single_word": false,
1433
+ "special": true
1434
+ },
1435
+ "128179": {
1436
+ "content": "<|reserved_special_token_171|>",
1437
+ "lstrip": false,
1438
+ "normalized": false,
1439
+ "rstrip": false,
1440
+ "single_word": false,
1441
+ "special": true
1442
+ },
1443
+ "128180": {
1444
+ "content": "<|reserved_special_token_172|>",
1445
+ "lstrip": false,
1446
+ "normalized": false,
1447
+ "rstrip": false,
1448
+ "single_word": false,
1449
+ "special": true
1450
+ },
1451
+ "128181": {
1452
+ "content": "<|reserved_special_token_173|>",
1453
+ "lstrip": false,
1454
+ "normalized": false,
1455
+ "rstrip": false,
1456
+ "single_word": false,
1457
+ "special": true
1458
+ },
1459
+ "128182": {
1460
+ "content": "<|reserved_special_token_174|>",
1461
+ "lstrip": false,
1462
+ "normalized": false,
1463
+ "rstrip": false,
1464
+ "single_word": false,
1465
+ "special": true
1466
+ },
1467
+ "128183": {
1468
+ "content": "<|reserved_special_token_175|>",
1469
+ "lstrip": false,
1470
+ "normalized": false,
1471
+ "rstrip": false,
1472
+ "single_word": false,
1473
+ "special": true
1474
+ },
1475
+ "128184": {
1476
+ "content": "<|reserved_special_token_176|>",
1477
+ "lstrip": false,
1478
+ "normalized": false,
1479
+ "rstrip": false,
1480
+ "single_word": false,
1481
+ "special": true
1482
+ },
1483
+ "128185": {
1484
+ "content": "<|reserved_special_token_177|>",
1485
+ "lstrip": false,
1486
+ "normalized": false,
1487
+ "rstrip": false,
1488
+ "single_word": false,
1489
+ "special": true
1490
+ },
1491
+ "128186": {
1492
+ "content": "<|reserved_special_token_178|>",
1493
+ "lstrip": false,
1494
+ "normalized": false,
1495
+ "rstrip": false,
1496
+ "single_word": false,
1497
+ "special": true
1498
+ },
1499
+ "128187": {
1500
+ "content": "<|reserved_special_token_179|>",
1501
+ "lstrip": false,
1502
+ "normalized": false,
1503
+ "rstrip": false,
1504
+ "single_word": false,
1505
+ "special": true
1506
+ },
1507
+ "128188": {
1508
+ "content": "<|reserved_special_token_180|>",
1509
+ "lstrip": false,
1510
+ "normalized": false,
1511
+ "rstrip": false,
1512
+ "single_word": false,
1513
+ "special": true
1514
+ },
1515
+ "128189": {
1516
+ "content": "<|reserved_special_token_181|>",
1517
+ "lstrip": false,
1518
+ "normalized": false,
1519
+ "rstrip": false,
1520
+ "single_word": false,
1521
+ "special": true
1522
+ },
1523
+ "128190": {
1524
+ "content": "<|reserved_special_token_182|>",
1525
+ "lstrip": false,
1526
+ "normalized": false,
1527
+ "rstrip": false,
1528
+ "single_word": false,
1529
+ "special": true
1530
+ },
1531
+ "128191": {
1532
+ "content": "<|reserved_special_token_183|>",
1533
+ "lstrip": false,
1534
+ "normalized": false,
1535
+ "rstrip": false,
1536
+ "single_word": false,
1537
+ "special": true
1538
+ },
1539
+ "128192": {
1540
+ "content": "<|reserved_special_token_184|>",
1541
+ "lstrip": false,
1542
+ "normalized": false,
1543
+ "rstrip": false,
1544
+ "single_word": false,
1545
+ "special": true
1546
+ },
1547
+ "128193": {
1548
+ "content": "<|reserved_special_token_185|>",
1549
+ "lstrip": false,
1550
+ "normalized": false,
1551
+ "rstrip": false,
1552
+ "single_word": false,
1553
+ "special": true
1554
+ },
1555
+ "128194": {
1556
+ "content": "<|reserved_special_token_186|>",
1557
+ "lstrip": false,
1558
+ "normalized": false,
1559
+ "rstrip": false,
1560
+ "single_word": false,
1561
+ "special": true
1562
+ },
1563
+ "128195": {
1564
+ "content": "<|reserved_special_token_187|>",
1565
+ "lstrip": false,
1566
+ "normalized": false,
1567
+ "rstrip": false,
1568
+ "single_word": false,
1569
+ "special": true
1570
+ },
1571
+ "128196": {
1572
+ "content": "<|reserved_special_token_188|>",
1573
+ "lstrip": false,
1574
+ "normalized": false,
1575
+ "rstrip": false,
1576
+ "single_word": false,
1577
+ "special": true
1578
+ },
1579
+ "128197": {
1580
+ "content": "<|reserved_special_token_189|>",
1581
+ "lstrip": false,
1582
+ "normalized": false,
1583
+ "rstrip": false,
1584
+ "single_word": false,
1585
+ "special": true
1586
+ },
1587
+ "128198": {
1588
+ "content": "<|reserved_special_token_190|>",
1589
+ "lstrip": false,
1590
+ "normalized": false,
1591
+ "rstrip": false,
1592
+ "single_word": false,
1593
+ "special": true
1594
+ },
1595
+ "128199": {
1596
+ "content": "<|reserved_special_token_191|>",
1597
+ "lstrip": false,
1598
+ "normalized": false,
1599
+ "rstrip": false,
1600
+ "single_word": false,
1601
+ "special": true
1602
+ },
1603
+ "128200": {
1604
+ "content": "<|reserved_special_token_192|>",
1605
+ "lstrip": false,
1606
+ "normalized": false,
1607
+ "rstrip": false,
1608
+ "single_word": false,
1609
+ "special": true
1610
+ },
1611
+ "128201": {
1612
+ "content": "<|reserved_special_token_193|>",
1613
+ "lstrip": false,
1614
+ "normalized": false,
1615
+ "rstrip": false,
1616
+ "single_word": false,
1617
+ "special": true
1618
+ },
1619
+ "128202": {
1620
+ "content": "<|reserved_special_token_194|>",
1621
+ "lstrip": false,
1622
+ "normalized": false,
1623
+ "rstrip": false,
1624
+ "single_word": false,
1625
+ "special": true
1626
+ },
1627
+ "128203": {
1628
+ "content": "<|reserved_special_token_195|>",
1629
+ "lstrip": false,
1630
+ "normalized": false,
1631
+ "rstrip": false,
1632
+ "single_word": false,
1633
+ "special": true
1634
+ },
1635
+ "128204": {
1636
+ "content": "<|reserved_special_token_196|>",
1637
+ "lstrip": false,
1638
+ "normalized": false,
1639
+ "rstrip": false,
1640
+ "single_word": false,
1641
+ "special": true
1642
+ },
1643
+ "128205": {
1644
+ "content": "<|reserved_special_token_197|>",
1645
+ "lstrip": false,
1646
+ "normalized": false,
1647
+ "rstrip": false,
1648
+ "single_word": false,
1649
+ "special": true
1650
+ },
1651
+ "128206": {
1652
+ "content": "<|reserved_special_token_198|>",
1653
+ "lstrip": false,
1654
+ "normalized": false,
1655
+ "rstrip": false,
1656
+ "single_word": false,
1657
+ "special": true
1658
+ },
1659
+ "128207": {
1660
+ "content": "<|reserved_special_token_199|>",
1661
+ "lstrip": false,
1662
+ "normalized": false,
1663
+ "rstrip": false,
1664
+ "single_word": false,
1665
+ "special": true
1666
+ },
1667
+ "128208": {
1668
+ "content": "<|reserved_special_token_200|>",
1669
+ "lstrip": false,
1670
+ "normalized": false,
1671
+ "rstrip": false,
1672
+ "single_word": false,
1673
+ "special": true
1674
+ },
1675
+ "128209": {
1676
+ "content": "<|reserved_special_token_201|>",
1677
+ "lstrip": false,
1678
+ "normalized": false,
1679
+ "rstrip": false,
1680
+ "single_word": false,
1681
+ "special": true
1682
+ },
1683
+ "128210": {
1684
+ "content": "<|reserved_special_token_202|>",
1685
+ "lstrip": false,
1686
+ "normalized": false,
1687
+ "rstrip": false,
1688
+ "single_word": false,
1689
+ "special": true
1690
+ },
1691
+ "128211": {
1692
+ "content": "<|reserved_special_token_203|>",
1693
+ "lstrip": false,
1694
+ "normalized": false,
1695
+ "rstrip": false,
1696
+ "single_word": false,
1697
+ "special": true
1698
+ },
1699
+ "128212": {
1700
+ "content": "<|reserved_special_token_204|>",
1701
+ "lstrip": false,
1702
+ "normalized": false,
1703
+ "rstrip": false,
1704
+ "single_word": false,
1705
+ "special": true
1706
+ },
1707
+ "128213": {
1708
+ "content": "<|reserved_special_token_205|>",
1709
+ "lstrip": false,
1710
+ "normalized": false,
1711
+ "rstrip": false,
1712
+ "single_word": false,
1713
+ "special": true
1714
+ },
1715
+ "128214": {
1716
+ "content": "<|reserved_special_token_206|>",
1717
+ "lstrip": false,
1718
+ "normalized": false,
1719
+ "rstrip": false,
1720
+ "single_word": false,
1721
+ "special": true
1722
+ },
1723
+ "128215": {
1724
+ "content": "<|reserved_special_token_207|>",
1725
+ "lstrip": false,
1726
+ "normalized": false,
1727
+ "rstrip": false,
1728
+ "single_word": false,
1729
+ "special": true
1730
+ },
1731
+ "128216": {
1732
+ "content": "<|reserved_special_token_208|>",
1733
+ "lstrip": false,
1734
+ "normalized": false,
1735
+ "rstrip": false,
1736
+ "single_word": false,
1737
+ "special": true
1738
+ },
1739
+ "128217": {
1740
+ "content": "<|reserved_special_token_209|>",
1741
+ "lstrip": false,
1742
+ "normalized": false,
1743
+ "rstrip": false,
1744
+ "single_word": false,
1745
+ "special": true
1746
+ },
1747
+ "128218": {
1748
+ "content": "<|reserved_special_token_210|>",
1749
+ "lstrip": false,
1750
+ "normalized": false,
1751
+ "rstrip": false,
1752
+ "single_word": false,
1753
+ "special": true
1754
+ },
1755
+ "128219": {
1756
+ "content": "<|reserved_special_token_211|>",
1757
+ "lstrip": false,
1758
+ "normalized": false,
1759
+ "rstrip": false,
1760
+ "single_word": false,
1761
+ "special": true
1762
+ },
1763
+ "128220": {
1764
+ "content": "<|reserved_special_token_212|>",
1765
+ "lstrip": false,
1766
+ "normalized": false,
1767
+ "rstrip": false,
1768
+ "single_word": false,
1769
+ "special": true
1770
+ },
1771
+ "128221": {
1772
+ "content": "<|reserved_special_token_213|>",
1773
+ "lstrip": false,
1774
+ "normalized": false,
1775
+ "rstrip": false,
1776
+ "single_word": false,
1777
+ "special": true
1778
+ },
1779
+ "128222": {
1780
+ "content": "<|reserved_special_token_214|>",
1781
+ "lstrip": false,
1782
+ "normalized": false,
1783
+ "rstrip": false,
1784
+ "single_word": false,
1785
+ "special": true
1786
+ },
1787
+ "128223": {
1788
+ "content": "<|reserved_special_token_215|>",
1789
+ "lstrip": false,
1790
+ "normalized": false,
1791
+ "rstrip": false,
1792
+ "single_word": false,
1793
+ "special": true
1794
+ },
1795
+ "128224": {
1796
+ "content": "<|reserved_special_token_216|>",
1797
+ "lstrip": false,
1798
+ "normalized": false,
1799
+ "rstrip": false,
1800
+ "single_word": false,
1801
+ "special": true
1802
+ },
1803
+ "128225": {
1804
+ "content": "<|reserved_special_token_217|>",
1805
+ "lstrip": false,
1806
+ "normalized": false,
1807
+ "rstrip": false,
1808
+ "single_word": false,
1809
+ "special": true
1810
+ },
1811
+ "128226": {
1812
+ "content": "<|reserved_special_token_218|>",
1813
+ "lstrip": false,
1814
+ "normalized": false,
1815
+ "rstrip": false,
1816
+ "single_word": false,
1817
+ "special": true
1818
+ },
1819
+ "128227": {
1820
+ "content": "<|reserved_special_token_219|>",
1821
+ "lstrip": false,
1822
+ "normalized": false,
1823
+ "rstrip": false,
1824
+ "single_word": false,
1825
+ "special": true
1826
+ },
1827
+ "128228": {
1828
+ "content": "<|reserved_special_token_220|>",
1829
+ "lstrip": false,
1830
+ "normalized": false,
1831
+ "rstrip": false,
1832
+ "single_word": false,
1833
+ "special": true
1834
+ },
1835
+ "128229": {
1836
+ "content": "<|reserved_special_token_221|>",
1837
+ "lstrip": false,
1838
+ "normalized": false,
1839
+ "rstrip": false,
1840
+ "single_word": false,
1841
+ "special": true
1842
+ },
1843
+ "128230": {
1844
+ "content": "<|reserved_special_token_222|>",
1845
+ "lstrip": false,
1846
+ "normalized": false,
1847
+ "rstrip": false,
1848
+ "single_word": false,
1849
+ "special": true
1850
+ },
1851
+ "128231": {
1852
+ "content": "<|reserved_special_token_223|>",
1853
+ "lstrip": false,
1854
+ "normalized": false,
1855
+ "rstrip": false,
1856
+ "single_word": false,
1857
+ "special": true
1858
+ },
1859
+ "128232": {
1860
+ "content": "<|reserved_special_token_224|>",
1861
+ "lstrip": false,
1862
+ "normalized": false,
1863
+ "rstrip": false,
1864
+ "single_word": false,
1865
+ "special": true
1866
+ },
1867
+ "128233": {
1868
+ "content": "<|reserved_special_token_225|>",
1869
+ "lstrip": false,
1870
+ "normalized": false,
1871
+ "rstrip": false,
1872
+ "single_word": false,
1873
+ "special": true
1874
+ },
1875
+ "128234": {
1876
+ "content": "<|reserved_special_token_226|>",
1877
+ "lstrip": false,
1878
+ "normalized": false,
1879
+ "rstrip": false,
1880
+ "single_word": false,
1881
+ "special": true
1882
+ },
1883
+ "128235": {
1884
+ "content": "<|reserved_special_token_227|>",
1885
+ "lstrip": false,
1886
+ "normalized": false,
1887
+ "rstrip": false,
1888
+ "single_word": false,
1889
+ "special": true
1890
+ },
1891
+ "128236": {
1892
+ "content": "<|reserved_special_token_228|>",
1893
+ "lstrip": false,
1894
+ "normalized": false,
1895
+ "rstrip": false,
1896
+ "single_word": false,
1897
+ "special": true
1898
+ },
1899
+ "128237": {
1900
+ "content": "<|reserved_special_token_229|>",
1901
+ "lstrip": false,
1902
+ "normalized": false,
1903
+ "rstrip": false,
1904
+ "single_word": false,
1905
+ "special": true
1906
+ },
1907
+ "128238": {
1908
+ "content": "<|reserved_special_token_230|>",
1909
+ "lstrip": false,
1910
+ "normalized": false,
1911
+ "rstrip": false,
1912
+ "single_word": false,
1913
+ "special": true
1914
+ },
1915
+ "128239": {
1916
+ "content": "<|reserved_special_token_231|>",
1917
+ "lstrip": false,
1918
+ "normalized": false,
1919
+ "rstrip": false,
1920
+ "single_word": false,
1921
+ "special": true
1922
+ },
1923
+ "128240": {
1924
+ "content": "<|reserved_special_token_232|>",
1925
+ "lstrip": false,
1926
+ "normalized": false,
1927
+ "rstrip": false,
1928
+ "single_word": false,
1929
+ "special": true
1930
+ },
1931
+ "128241": {
1932
+ "content": "<|reserved_special_token_233|>",
1933
+ "lstrip": false,
1934
+ "normalized": false,
1935
+ "rstrip": false,
1936
+ "single_word": false,
1937
+ "special": true
1938
+ },
1939
+ "128242": {
1940
+ "content": "<|reserved_special_token_234|>",
1941
+ "lstrip": false,
1942
+ "normalized": false,
1943
+ "rstrip": false,
1944
+ "single_word": false,
1945
+ "special": true
1946
+ },
1947
+ "128243": {
1948
+ "content": "<|reserved_special_token_235|>",
1949
+ "lstrip": false,
1950
+ "normalized": false,
1951
+ "rstrip": false,
1952
+ "single_word": false,
1953
+ "special": true
1954
+ },
1955
+ "128244": {
1956
+ "content": "<|reserved_special_token_236|>",
1957
+ "lstrip": false,
1958
+ "normalized": false,
1959
+ "rstrip": false,
1960
+ "single_word": false,
1961
+ "special": true
1962
+ },
1963
+ "128245": {
1964
+ "content": "<|reserved_special_token_237|>",
1965
+ "lstrip": false,
1966
+ "normalized": false,
1967
+ "rstrip": false,
1968
+ "single_word": false,
1969
+ "special": true
1970
+ },
1971
+ "128246": {
1972
+ "content": "<|reserved_special_token_238|>",
1973
+ "lstrip": false,
1974
+ "normalized": false,
1975
+ "rstrip": false,
1976
+ "single_word": false,
1977
+ "special": true
1978
+ },
1979
+ "128247": {
1980
+ "content": "<|reserved_special_token_239|>",
1981
+ "lstrip": false,
1982
+ "normalized": false,
1983
+ "rstrip": false,
1984
+ "single_word": false,
1985
+ "special": true
1986
+ },
1987
+ "128248": {
1988
+ "content": "<|reserved_special_token_240|>",
1989
+ "lstrip": false,
1990
+ "normalized": false,
1991
+ "rstrip": false,
1992
+ "single_word": false,
1993
+ "special": true
1994
+ },
1995
+ "128249": {
1996
+ "content": "<|reserved_special_token_241|>",
1997
+ "lstrip": false,
1998
+ "normalized": false,
1999
+ "rstrip": false,
2000
+ "single_word": false,
2001
+ "special": true
2002
+ },
2003
+ "128250": {
2004
+ "content": "<|reserved_special_token_242|>",
2005
+ "lstrip": false,
2006
+ "normalized": false,
2007
+ "rstrip": false,
2008
+ "single_word": false,
2009
+ "special": true
2010
+ },
2011
+ "128251": {
2012
+ "content": "<|reserved_special_token_243|>",
2013
+ "lstrip": false,
2014
+ "normalized": false,
2015
+ "rstrip": false,
2016
+ "single_word": false,
2017
+ "special": true
2018
+ },
2019
+ "128252": {
2020
+ "content": "<|reserved_special_token_244|>",
2021
+ "lstrip": false,
2022
+ "normalized": false,
2023
+ "rstrip": false,
2024
+ "single_word": false,
2025
+ "special": true
2026
+ },
2027
+ "128253": {
2028
+ "content": "<|reserved_special_token_245|>",
2029
+ "lstrip": false,
2030
+ "normalized": false,
2031
+ "rstrip": false,
2032
+ "single_word": false,
2033
+ "special": true
2034
+ },
2035
+ "128254": {
2036
+ "content": "<|reserved_special_token_246|>",
2037
+ "lstrip": false,
2038
+ "normalized": false,
2039
+ "rstrip": false,
2040
+ "single_word": false,
2041
+ "special": true
2042
+ },
2043
+ "128255": {
2044
+ "content": "<|reserved_special_token_247|>",
2045
+ "lstrip": false,
2046
+ "normalized": false,
2047
+ "rstrip": false,
2048
+ "single_word": false,
2049
+ "special": true
2050
+ }
2051
+ },
2052
+ "bos_token": "<|begin_of_text|>",
2053
+ "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
2054
+ "clean_up_tokenization_spaces": true,
2055
+ "eos_token": "<|eot_id|>",
2056
+ "model_input_names": [
2057
+ "input_ids",
2058
+ "attention_mask"
2059
+ ],
2060
+ "model_max_length": 131072,
2061
+ "tokenizer_class": "PreTrainedTokenizerFast"
2062
+ }
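The tokenizer configuration above registers the Llama-3-style reserved special tokens and a chat template. As a minimal illustration (not part of this commit), the template can be rendered with `transformers`; the repository id below is a placeholder, not the actual location of these files.

```python
# Hedged usage sketch: render a prompt with the chat template from tokenizer_config.json.
# "path/or/repo-id" is a placeholder; substitute wherever these tokenizer files live.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/or/repo-id")

messages = [{"role": "user", "content": "Describe the video."}]

# Each turn is wrapped as <|start_header_id|>role<|end_header_id|>\n\n content <|eot_id|>,
# the first turn is prefixed with <|begin_of_text|>, and the assistant header is appended
# so generation continues from the assistant turn.
prompt = tok.apply_chat_template(messages, tokenize=False)
print(prompt)
```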
vision_sampler.py ADDED
@@ -0,0 +1,566 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+
8
+
9
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
10
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
11
+ """
12
+ grid_size: int of the grid height and width
13
+ return:
14
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
15
+ """
16
+ grid_h = np.arange(grid_size, dtype=np.float32)
17
+ grid_w = np.arange(grid_size, dtype=np.float32)
18
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
19
+ grid = np.stack(grid, axis=0)
20
+
21
+ grid = grid.reshape([2, 1, grid_size, grid_size])
22
+
23
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24
+ if cls_token:
25
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
26
+ return pos_embed
27
+
28
+
29
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
30
+ assert embed_dim % 2 == 0
31
+
32
+ # use half of dimensions to encode grid_h
33
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
34
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
35
+
36
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
37
+ return emb
38
+
39
+
40
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
41
+ """
42
+ embed_dim: output dimension for each position
43
+ pos: a list of positions to be encoded: size (M,)
44
+ out: (M, D)
45
+ """
46
+ assert embed_dim % 2 == 0
47
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
48
+ omega /= embed_dim / 2.0
49
+ omega = 1.0 / 10000**omega # (D/2,)
50
+
51
+ pos = pos.reshape(-1) # (M,)
52
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
53
+
54
+ emb_sin = np.sin(out) # (M, D/2)
55
+ emb_cos = np.cos(out) # (M, D/2)
56
+
57
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
58
+ return emb
59
+
60
+
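A quick sanity check (not part of this file) for the sin-cos helpers above; `embed_dim=1024` and `grid_size=24` are arbitrary example values. The 2-D table has one row per grid cell, plus an optional leading zero row for the cls token.

```python
# Shape check for get_2d_sincos_pos_embed (embed_dim must be divisible by 4).
pos = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=24)                       # (576, 1024)
pos_cls = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=24, cls_token=True)   # (577, 1024)
assert pos.shape == (24 * 24, 1024)
assert pos_cls.shape == (1 + 24 * 24, 1024)
```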
61
+ class CrossAttention(nn.Module):
62
+
63
+ def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
64
+ super().__init__()
65
+ self.hidden_dim = hidden_dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = self.hidden_dim // self.num_heads
68
+
69
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
70
+ raise ValueError(
71
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
72
+ f" and `num_heads`: {self.num_heads})."
73
+ )
74
+
75
+ self.q_proj = nn.Sequential(
76
+ nn.LayerNorm(q_dim),
77
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
78
+ )
79
+ self.k_proj = nn.Sequential(
80
+ nn.LayerNorm(kv_dim),
81
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
82
+ )
83
+ self.v_proj = nn.Sequential(
84
+ nn.LayerNorm(kv_dim),
85
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
86
+ )
87
+ self.o_proj = nn.Linear(
88
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
89
+ )
90
+
91
+ def forward(self, vision_latents, queries, attention_mask):
92
+
93
+ bsz, q_len, _ = queries.size()
94
+ bsz, v_len, _ = vision_latents.size()
95
+
96
+ query_states = self.q_proj(queries)
97
+ key_states = self.k_proj(vision_latents)
98
+ value_states = self.v_proj(vision_latents)
99
+
100
+ query_states = query_states.view(
101
+ bsz, q_len, self.num_heads, self.head_dim
102
+ ).transpose(1, 2)
103
+ key_states = key_states.view(
104
+ bsz, v_len, self.num_heads, self.head_dim
105
+ ).transpose(1, 2)
106
+ value_states = value_states.view(
107
+ bsz, v_len, self.num_heads, self.head_dim
108
+ ).transpose(1, 2)
109
+
110
+ if attention_mask is not None:
111
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
112
+ raise ValueError(
113
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
114
+ )
115
+
116
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
117
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
118
+ if query_states.device.type == "cuda" and attention_mask is not None:
119
+ query_states = query_states.contiguous()
120
+ key_states = key_states.contiguous()
121
+ value_states = value_states.contiguous()
122
+
123
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
124
+ query_states,
125
+ key_states,
126
+ value_states,
127
+ attn_mask=attention_mask,
128
+ )
129
+
130
+ attn_output = attn_output.transpose(1, 2).contiguous()
131
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
132
+
133
+ attn_output = self.o_proj(attn_output)
134
+
135
+ return attn_output
136
+
137
+
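A minimal usage sketch for `CrossAttention` (not part of this file); all sizes are illustrative. The additive attention mask, when provided, must already be expanded to `(bsz, 1, q_len, v_len)`: zeros keep a position, large negative values mask it out.

```python
import torch

xattn = CrossAttention(q_dim=4096, kv_dim=1024, hidden_dim=1024, num_heads=16)
queries = torch.randn(2, 16, 4096)          # (bsz, q_len, q_dim)
vision_latents = torch.randn(2, 576, 1024)  # (bsz, v_len, kv_dim)
mask = torch.zeros(2, 1, 16, 576)           # additive mask; all positions visible here
out = xattn(vision_latents, queries, mask)  # -> (2, 16, 4096)
```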
138
+ class AggregationBlock(nn.Module):
139
+ def __init__(
140
+ self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
141
+ ):
142
+ super().__init__()
143
+ self.hidden_dim = hidden_dim
144
+ self.num_heads = num_heads
145
+ self.head_dim = self.hidden_dim // self.num_heads
146
+
147
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
148
+ raise ValueError(
149
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
150
+ f" and `num_heads`: {self.num_heads})."
151
+ )
152
+
153
+ self.attention = attention
154
+ if attention:
155
+ self.attention_layer = CrossAttention(
156
+ q_dim, kv_dim, hidden_dim, num_heads, attention_bias
157
+ )
158
+ else:
159
+ self.attention_layer = MLP(kv_dim, q_dim, q_dim)
160
+
161
+ def forward(self, vision_latents, queries, attention_mask):
162
+ if self.attention:
163
+ queries = self.attention_layer(vision_latents, queries, attention_mask)
164
+ else:
165
+ queries = self.attention_layer(vision_latents)
166
+
167
+ return queries
168
+
169
+
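`AggregationBlock` switches between the cross-attention path and a plain MLP over the vision latents (the MLP path ignores the queries and the mask). A hedged sketch with illustrative dimensions:

```python
import torch

attn_block = AggregationBlock(True, q_dim=4096, kv_dim=1024, hidden_dim=1024, num_heads=16)
mlp_block = AggregationBlock(False, q_dim=4096, kv_dim=1024, hidden_dim=1024, num_heads=16)

queries = torch.randn(2, 16, 4096)
vision_latents = torch.randn(2, 576, 1024)
mask = torch.zeros(2, 1, 16, 576)

out_attn = attn_block(vision_latents, queries, mask)  # cross-attention -> (2, 16, 4096)
out_mlp = mlp_block(vision_latents, queries, mask)    # MLP over latents -> (2, 576, 4096)
```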
170
+ class MultiKVCrossAttention(nn.Module):
+
+     def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
+         super().__init__()
+
+         self.hidden_dim = hidden_dim
+         self.num_heads = num_heads
+         self.head_dim = self.hidden_dim // self.num_heads
+
+         if (self.head_dim * self.num_heads) != self.hidden_dim:
+             raise ValueError(
+                 f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
+                 f" and `num_heads`: {self.num_heads})."
+             )
+
+         self.q_proj = nn.Sequential(
+             nn.LayerNorm(q_dim),
+             nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
+         )
+         self.num_of_kvs = len(kv_dim_list)
+         for i, kv_dim in enumerate(kv_dim_list):
+             setattr(
+                 self,
+                 "k_proj_{}".format(i),
+                 nn.Sequential(
+                     nn.LayerNorm(kv_dim),
+                     nn.Linear(
+                         kv_dim, self.num_heads * self.head_dim, bias=attention_bias
+                     ),
+                 ),
+             )
+             setattr(
+                 self,
+                 "v_proj_{}".format(i),
+                 nn.Sequential(
+                     nn.LayerNorm(kv_dim),
+                     nn.Linear(
+                         kv_dim, self.num_heads * self.head_dim, bias=attention_bias
+                     ),
+                 ),
+             )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, q_dim, bias=attention_bias
+         )
+
+     def forward(
+         self,
+         queries,
+         *vision_latents_attention_mask_list,
+     ):
+
+         vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
+         attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
+
+         bsz, q_len, _ = queries.size()
+
+         query_states = self.q_proj(queries)
+         key_states = torch.cat(
+             [
+                 getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
+                 for i in range(self.num_of_kvs)
+             ],
+             dim=1,
+         )
+         value_states = torch.cat(
+             [
+                 getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
+                 for i in range(self.num_of_kvs)
+             ],
+             dim=1,
+         )
+
+         v_len = key_states.shape[1]
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, v_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, v_len, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+
+         # if kv_weight is not None:
+         #     kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+
+         attention_mask = torch.cat(attention_mask_list, dim=-1)
+
+         if attention_mask is not None:
+             if attention_mask.size() != (bsz, 1, q_len, v_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
+                 )
+
+         # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with
+         # non-contiguous inputs and a custom attn_mask.
+         # Reference: https://github.com/pytorch/pytorch/issues/112577.
+         if query_states.device.type == "cuda" and attention_mask is not None:
+             query_states = query_states.contiguous()
+             key_states = key_states.contiguous()
+             value_states = value_states.contiguous()
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+         )
+         # attn_output = spda(
+         #     query_states,
+         #     key_states,
+         #     value_states,
+         #     attn_mask=attention_mask,
+         #     additional_score=kv_weight
+         # )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
+
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output
+
+
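`MultiKVCrossAttention` projects each key/value source separately, concatenates them along the sequence dimension, and runs a single scaled-dot-product attention over the combined set. A minimal usage sketch, assuming the class is importable and using made-up shapes (two sources of 576 and 1 tokens; additive float masks of zeros mean nothing is masked out):

```python
import torch

attn = MultiKVCrossAttention(q_dim=1024, kv_dim_list=[1152, 1024], hidden_dim=1024, num_heads=16)

bsz, q_len = 2, 144
queries = torch.randn(bsz, q_len, 1024)
latents_0 = torch.randn(bsz, 576, 1152)   # e.g. a 24x24 feature grid
latents_1 = torch.randn(bsz, 1, 1024)     # e.g. a single global token
mask_0 = torch.zeros(bsz, 1, q_len, 576)  # additive mask: 0 = attend everywhere
mask_1 = torch.zeros(bsz, 1, q_len, 1)

# Positional layout expected by forward(): all vision latents first, then all masks.
out = attn(queries, latents_0, latents_1, mask_0, mask_1)
print(out.shape)  # torch.Size([2, 144, 1024])
```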
+ class MLP(nn.Module):
+     def __init__(self, d_in, d_hidden, d_out):
+         super().__init__()
+         self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
+         self.act = nn.GELU()
+         self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)
+
+     def forward(self, x):
+         return self.linear_2(self.act(self.linear_1(x)))
+
+
+ class VisionCrossAttentionLayer(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         hidden_dim=1024,
+         layer_idx=0,
+     ):
+         super().__init__()
+         num_heads = 16
+         self.num_of_kvs = len(kv_dim_list)
+
+         self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
+         self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
+         # if self.num_of_kvs > 1:
+         #     self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
+         #     self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
+         self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
+
+         self.norm = nn.LayerNorm(hidden_dim)
+
+         self.cross_attn = MultiKVCrossAttention(
+             hidden_dim, kv_dim_list, hidden_dim, num_heads
+         )
+         self.kv_size_list = kv_size_list
+         for i, kv_size in enumerate(kv_size_list):
+             if kv_size > 1:
+                 setattr(
+                     self,
+                     "pos_embed_{}".format(i),
+                     nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
+                 )
+                 # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)
+
+     def forward(
+         self,
+         queries,
+         context_feature,
+         *vision_latents_attention_mask_list,
+     ) -> torch.FloatTensor:
+
+         residual = queries
+         # queries = self.proj_in(queries)
+         context_feature = self.proj_context(context_feature)
+         # queries = queries + context_feature
+         queries = torch.cat([queries, context_feature], -1)
+
+         # if self.num_of_kvs > 1:
+         #     kv_weight = self.weight_mlp(queries)  # B * 1 * num_tower
+         #     kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
+         #     kv_weight = kv_weight.softmax(-1)
+         #     kv_number_list = [size**2 for size in self.kv_size_list]
+         #     kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
+         # else:
+         #     kv_weight = None
+
+         queries = self.proj_in(queries)
+
+         vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
+         attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
+
+         attention_mask_list_reshaped = []
+         if attention_mask_list is not None:
+             for attention_mask in attention_mask_list:
+                 attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+                 attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
+                 attention_mask_list_reshaped.append(attention_mask)
+
+         vision_latents_pos_list = []
+         for i, vision_latents in enumerate(vision_latents_list):
+             if vision_latents.shape[1] > 1:
+                 vision_latents_pos_list.append(
+                     vision_latents
+                     + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
+                         vision_latents.dtype
+                     )
+                 )
+             else:
+                 vision_latents_pos_list.append(vision_latents)
+
+         # Cross Attention
+         attention_output = self.cross_attn(
+             queries, *vision_latents_pos_list, *attention_mask_list_reshaped
+         )
+
+         # attention_output = (attention_output * combination_weight).sum(2)
+         queries = queries + attention_output
+
+         queries = self.norm(queries)
+
+         queries = self.proj_out(queries)
+
+         queries = queries + residual
+
+         return queries
+
+
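`VisionCrossAttentionLayer` is the "joint" variant: the queries are concatenated with a context feature, projected into the shared hidden space, cross-attend to all vision sources at once through `MultiKVCrossAttention`, and are projected back with a residual connection. A rough usage sketch with invented dimensions (the learned `pos_embed_i` parameters are created with width `hidden_dim`, so this sketch keeps every kv dimension equal to `hidden_dim`; boolean masks with `True` mean "keep this token"):

```python
import torch

layer = VisionCrossAttentionLayer(
    q_dim=4096,            # e.g. the LLM hidden size
    context_dim=4096,
    kv_dim_list=[1024, 1024],
    kv_size_list=[24, 1],  # a 24x24 grid plus a single global token
    hidden_dim=1024,
)

bsz, q_len = 2, 144
queries = torch.randn(bsz, q_len, 4096)
context = torch.randn(bsz, q_len, 4096)
grid_latents = torch.randn(bsz, 24 * 24, 1024)
global_latent = torch.randn(bsz, 1, 1024)
grid_mask = torch.ones(bsz, 24 * 24, dtype=torch.bool)
global_mask = torch.ones(bsz, 1, dtype=torch.bool)

out = layer(queries, context, grid_latents, global_latent, grid_mask, global_mask)
print(out.shape)  # torch.Size([2, 144, 4096])
```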
+ class VisionAggregationLayer(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         hidden_dim=1024,
+         layer_idx=0,
+     ):
+         super().__init__()
+         num_heads = 16
+         self.num_of_kvs = len(kv_dim_list)
+
+         self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
+         self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
+
+         self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
+
+         self.norm = nn.LayerNorm(hidden_dim)
+
+         if self.num_of_kvs > 1:
+             self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)
+
+         for i, kv_size in enumerate(kv_size_list):
+             if kv_size > 1:
+                 setattr(
+                     self,
+                     "pos_embed_{}".format(i),
+                     nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
+                 )
+                 setattr(
+                     self,
+                     "aggregate_{}".format(i),
+                     AggregationBlock(
+                         True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
+                     ),
+                 )
+             else:
+                 setattr(
+                     self,
+                     "aggregate_{}".format(i),
+                     AggregationBlock(
+                         False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
+                     ),
+                 )
+
+     def forward(
+         self,
+         queries,
+         context_feature,
+         *vision_latents_attention_mask_list,
+     ) -> torch.FloatTensor:
+
+         residual = queries
+         # queries = self.proj_in(queries)
+         context_feature = self.proj_context(context_feature)
+         # queries = queries + context_feature
+         queries = torch.cat([queries, context_feature], -1)
+
+         if self.num_of_kvs > 1:
+             combination_weight = self.weight_mlp(queries).softmax(-1)  # B * 1 * num_tower
+             combination_weight = combination_weight.unsqueeze(-1)
+         else:
+             combination_weight = 1
+
+         queries = self.proj_in(queries)
+
+         vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
+         attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
+
+         attention_mask_list_reshaped = []
+         if attention_mask_list is not None:
+             for attention_mask in attention_mask_list:
+                 attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+                 attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
+                 attention_mask_list_reshaped.append(attention_mask)
+
+         vision_latents_pos_list = []
+         for i, vision_latents in enumerate(vision_latents_list):
+             if vision_latents.shape[1] > 1:
+                 vision_latents_pos_list.append(
+                     vision_latents
+                     + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
+                         vision_latents.dtype
+                     )
+                 )
+             else:
+                 vision_latents_pos_list.append(vision_latents)
+
+         aggregated_vision_latents_list = []
+         for i, (vision_latents, attention_mask) in enumerate(
+             zip(vision_latents_pos_list, attention_mask_list_reshaped)
+         ):
+             aggregated_vision_latents_list.append(
+                 getattr(self, "aggregate_{}".format(i))(
+                     vision_latents, queries, attention_mask
+                 )
+             )
+
+         aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)
+
+         queries = queries + (aggregated_vision_latents * combination_weight).sum(2)
+
+         queries = self.norm(queries)
+
+         queries = self.proj_out(queries)
+
+         queries = queries + residual
+
+         return queries
+
+
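`VisionAggregationLayer` is the "sep" variant: each vision source is aggregated by its own `AggregationBlock`, and the per-source results are mixed with query-dependent softmax weights from `weight_mlp`. The weighting step in isolation, as a standalone sketch with hypothetical shapes (the tensors stand in for the stacked `aggregate_i` outputs and the softmaxed weights produced in the forward pass above):

```python
import torch

bsz, q_len, hidden_dim, num_sources = 2, 144, 1024, 3

# Stand-in for torch.stack(aggregated_vision_latents_list, 2): one slot per source.
aggregated = torch.randn(bsz, q_len, num_sources, hidden_dim)

# Stand-in for weight_mlp(queries).softmax(-1).unsqueeze(-1): one weight per source and token.
combination_weight = torch.randn(bsz, q_len, num_sources).softmax(-1).unsqueeze(-1)

mixed = (aggregated * combination_weight).sum(2)  # weighted sum over the source dimension
print(mixed.shape)  # torch.Size([2, 144, 1024])
```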
+ class VisionTokenSampler(nn.Module):
+     def __init__(
+         self,
+         q_dim,
+         context_dim,
+         kv_dim_list,
+         kv_size_list,
+         vision_hidden_size,
+         num_of_layers=1,
+         layer_type="joint",
+     ):
+         super().__init__()
+         assert layer_type in ["joint", "sep"]
+         if layer_type == "joint":
+             self.layers = nn.ModuleList(
+                 [
+                     VisionCrossAttentionLayer(
+                         q_dim,
+                         context_dim,
+                         kv_dim_list,
+                         kv_size_list,
+                         vision_hidden_size,
+                         idx,
+                     )
+                     for idx in range(num_of_layers)
+                 ]
+             )
+         else:
+             self.layers = nn.ModuleList(
+                 [
+                     VisionAggregationLayer(
+                         q_dim,
+                         context_dim,
+                         kv_dim_list,
+                         kv_size_list,
+                         vision_hidden_size,
+                         idx,
+                     )
+                     for idx in range(num_of_layers)
+                 ]
+             )
+
+     def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
+         for layer in self.layers:
+             queries = layer(
+                 queries, context_feature, *vision_latents_attention_mask_list
+             )
+         return queries
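`VisionTokenSampler` simply stacks `num_of_layers` copies of the chosen layer type and refines the queries through them in sequence. Putting the pieces together, a hedged end-to-end sketch with the default "joint" layers (all sizes are illustrative, and the kv dimensions are kept equal to `vision_hidden_size` for the reason noted in the `VisionCrossAttentionLayer` sketch above):

```python
import torch

sampler = VisionTokenSampler(
    q_dim=4096,
    context_dim=4096,
    kv_dim_list=[1024, 1024],
    kv_size_list=[24, 1],
    vision_hidden_size=1024,
    num_of_layers=3,
    layer_type="joint",
)

bsz, q_len = 2, 144
queries = torch.randn(bsz, q_len, 4096)
context = torch.randn(bsz, q_len, 4096)
latents = [torch.randn(bsz, 24 * 24, 1024), torch.randn(bsz, 1, 1024)]
masks = [
    torch.ones(bsz, 24 * 24, dtype=torch.bool),
    torch.ones(bsz, 1, dtype=torch.bool),
]

out = sampler(queries, context, *latents, *masks)
print(out.shape)  # torch.Size([2, 144, 4096])
```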