jadechoghari committed on
Commit
56f618d
1 Parent(s): 3fab70c

Create ferret_arch.py

Files changed (1)
  1. ferret_arch.py +926 -0
ferret_arch.py ADDED
@@ -0,0 +1,926 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+ import math
18
+
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.distributed as dist
24
+
25
+ from .multimodal_encoder.builder import build_vision_tower
26
+ from .multimodal_projector.builder import build_vision_projector
27
+
28
+ from .constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
29
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
30
+ DEFAULT_IM_END_TOKEN, DEFAULT_REGION_FEA_TOKEN)
31
+
32
+ from .mm_utils import get_anyres_image_grid_shape
33
+
34
+ import os
35
+
36
+ def rand_sample(x, max_len):
37
+ if x.shape[0] <= max_len:
38
+ return x
39
+ else:
40
+ rand_idx = torch.randperm(x.shape[0])[:max_len]
41
+ return x[rand_idx, :]
42
+
43
+
44
+ def rand_sample_repeat(x, max_len):
45
+ if x.shape[0] < max_len:
46
+ indices = torch.randint(0, x.shape[0], (max_len-x.shape[0],))
47
+ # pdb.set_trace()
48
+ return torch.cat((x, x[indices]), dim=0)
49
+ elif x.shape[0] == max_len:
50
+ return x
51
+ else:
52
+ rand_idx = torch.randperm(x.shape[0])[:max_len]
53
+ return x[rand_idx, :]
54
+
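+ # Illustrative sketch (not part of the original file): shows how the two sampling
+ # helpers above behave on a toy point set. The helper name `_example_rand_sampling`
+ # is an assumption added purely for documentation; it is never called at import time.
+ def _example_rand_sampling():
+     pts = torch.arange(10, dtype=torch.float32).reshape(5, 2)   # 5 points, 2D
+     sub = rand_sample(pts, 3)            # more points than max_len: random subset -> (3, 2)
+     same = rand_sample(pts, 8)           # fewer points than max_len: returned unchanged -> (5, 2)
+     padded = rand_sample_repeat(pts, 8)  # pads by repeating random rows -> (8, 2)
+     return sub.shape, same.shape, padded.shape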
55
+
56
+ def point_sample(input, point_coords, return_dtype, **kwargs):
57
+ """
58
+ A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
59
+ Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
60
+ [0, 1] x [0, 1] square.
61
+ Args:
62
+ input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on an H x W grid.
63
+ point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
64
+ [0, 1] x [0, 1] normalized point coordinates.
65
+ Returns:
66
+ output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
67
+ features for points in `point_coords`. The features are obtained via bilinear
68
+ interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
69
+ """
70
+ add_dim = False
71
+ if point_coords.dim() == 3:
72
+ add_dim = True
73
+ point_coords = point_coords.unsqueeze(2)
74
+ # output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
75
+ output = F.grid_sample(input.float(), (2.0 * point_coords - 1.0).float(), **kwargs)
76
+ output = output.to(return_dtype)
77
+ if add_dim:
78
+ output = output.squeeze(3)
79
+ return output
80
+
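+ # Illustrative sketch (not part of the original file): samples features from a tiny
+ # 4x4 feature map at two normalized points. `_example_point_sample` is an assumed
+ # name used only for documentation.
+ def _example_point_sample():
+     feat = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)  # (N=1, C=1, H=4, W=4)
+     coords = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]])                 # (N=1, P=2, 2) in [0, 1], (x, y)
+     out = point_sample(feat, coords, return_dtype=torch.float32, align_corners=True)
+     return out  # shape (1, 1, 2): the two corner values, tensor([[[0., 15.]]])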
81
+
82
+ def farthest_point_sample(xyz, npoint):
83
+ """
84
+ Input:
85
+ xyz: pointcloud data, [B, N, 2]
86
+ npoint: number of samples
87
+ Return:
88
+ centroids: sampled pointcloud index, [B, npoint]
89
+ """
90
+ device = xyz.device
91
+ B, N, C = xyz.shape
92
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
93
+ distance = torch.ones(B, N).to(device) * 1e10
94
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
95
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
96
+ for i in range(npoint):
97
+ centroids[:, i] = farthest
98
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 2)
99
+ dist = torch.sum((xyz - centroid) ** 2, -1)
100
+ distance = torch.min(distance, dist)
101
+ farthest = torch.max(distance, -1)[1]
102
+ return centroids
103
+
104
+
105
+ def index_points(points, idx):
106
+ """
107
+ Input:
108
+ points: input points data, [B, N, C]
109
+ idx: sample index data, [B, S]
110
+ Return:
111
+ new_points:, indexed points data, [B, S, C]
112
+ """
113
+ device = points.device
114
+ B = points.shape[0]
115
+ view_shape = list(idx.shape)
116
+ view_shape[1:] = [1] * (len(view_shape) - 1)
117
+ repeat_shape = list(idx.shape)
118
+ repeat_shape[0] = 1
119
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
120
+ new_points = points[batch_indices, idx, :]
121
+ return new_points
122
+
123
+
124
+ def square_distance(src, dst):
125
+ """
126
+ Calculate the squared Euclidean distance between each pair of points.
127
+ src^T * dst = xn * xm + yn * ym + zn * zm;
128
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
129
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
130
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
131
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
132
+ Input:
133
+ src: source points, [B, N, C]
134
+ dst: target points, [B, M, C]
135
+ Output:
136
+ dist: per-point square distance, [B, N, M]
137
+ """
138
+ B, N, _ = src.shape
139
+ _, M, _ = dst.shape
140
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
141
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
142
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
143
+ return dist
144
+
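+ # Illustrative sketch (not part of the original file): a tiny numeric check of the
+ # expansion derived in the docstring above. `_example_square_distance` is an assumed
+ # helper name used only for documentation.
+ def _example_square_distance():
+     src = torch.tensor([[[0.0, 0.0], [1.0, 0.0]]])  # [B=1, N=2, 2]
+     dst = torch.tensor([[[0.0, 1.0]]])              # [B=1, M=1, 2]
+     return square_distance(src, dst)                # tensor([[[1.], [2.]]]): squared distances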
145
+
146
+ def knn_point(nsample, xyz, new_xyz):
147
+ """
148
+ Input:
149
+ nsample: max sample number in local region
150
+ xyz: all points, [B, N, C]
151
+ new_xyz: query points, [B, S, C]
152
+ Return:
153
+ group_idx: grouped points index, [B, S, nsample]
154
+ """
155
+ sqrdists = square_distance(new_xyz, xyz)
156
+ _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
157
+ return group_idx
158
+
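+ # Illustrative sketch (not part of the original file): chains farthest_point_sample,
+ # index_points and knn_point on a toy 2D point set, mirroring how GeoRegionSampler
+ # uses them below. `_example_point_grouping` is an assumed helper name.
+ def _example_point_grouping():
+     xyz = torch.rand(2, 16, 2)               # [B=2, N=16, 2] points in [0, 1]^2
+     fps_idx = farthest_point_sample(xyz, 4)  # [B, 4] indices of 4 spread-out anchors
+     anchors = index_points(xyz, fps_idx)     # [B, 4, 2] anchor coordinates
+     knn_idx = knn_point(3, xyz, anchors)     # [B, 4, 3] indices of 3 nearest neighbors
+     groups = index_points(xyz, knn_idx)      # [B, 4, 3, 2] grouped neighbor coordinates
+     return anchors.shape, groups.shape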
159
+
160
+ class ConvReLULN1D(nn.Module):
161
+ def __init__(self, in_channels, out_channels, kernel_size=1, bias=True):
162
+ super(ConvReLULN1D, self).__init__()
163
+ self.act = nn.ReLU(inplace=True)
164
+ self.net = nn.Sequential(
165
+ nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
166
+ self.act
167
+ )
168
+ self.norm = nn.LayerNorm(out_channels)
169
+
170
+ def forward(self, x):
171
+ # (B, C, N) -> (B, C_1, N)
172
+ x = self.net(x)
173
+ x = x.permute(0, 2, 1)
174
+ x = self.norm(x)
175
+ x = x.permute(0, 2, 1)
176
+
177
+ return x
178
+
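+ # Illustrative sketch (not part of the original file): the block maps (B, C_in, N) to
+ # (B, C_out, N), applying LayerNorm over the channel dimension. `_example_conv_relu_ln`
+ # is an assumed helper name used only for documentation.
+ def _example_conv_relu_ln():
+     block = ConvReLULN1D(in_channels=8, out_channels=16)
+     x = torch.randn(2, 8, 5)   # (batch, channels, points)
+     return block(x).shape      # torch.Size([2, 16, 5])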
179
+
180
+ def normal_init(module, mean=0, std=1, bias=0):
181
+ if hasattr(module, 'weight') and module.weight is not None:
182
+ nn.init.normal_(module.weight, mean, std)
183
+ if hasattr(module, 'bias') and module.bias is not None:
184
+ nn.init.constant_(module.bias, bias)
185
+
186
+
187
+ class GeoRegionSampler(nn.Module):
188
+ def __init__(self,
189
+ input_dim,
190
+ output_dim,
191
+ num_init_point,
192
+ num_sub_point,
193
+ num_neighbor,
194
+ pooler_mode='mean'):
195
+ super(GeoRegionSampler, self).__init__()
196
+ self.input_dim = input_dim
197
+ self.output_dim = output_dim
198
+ self.num_init_point = num_init_point
199
+ self.num_sub_point = num_sub_point
200
+ self.num_neighbor = num_neighbor
201
+
202
+ self.diff_projector_list = nn.ModuleList()
203
+ self.agg_projector_list = nn.ModuleList()
204
+ self.pooler_list = nn.ModuleList()
205
+
206
+ for ii in range(len(num_sub_point)):
207
+ self.diff_projector_list.append(nn.Linear(self.input_dim + 2, self.input_dim + 2))
208
+ self.agg_projector_list.append(ConvReLULN1D(in_channels=2*(self.input_dim + 2),
209
+ out_channels=self.input_dim,
210
+ ))
211
+ if pooler_mode == 'mean':
212
+ self.pooler_list.append(nn.AvgPool1d(kernel_size=num_neighbor[ii]))
213
+ elif pooler_mode =='max':
214
+ self.pooler_list.append(nn.AdaptiveMaxPool1d(output_size=1))
215
+ else:
216
+ raise NotImplementedError(f'{pooler_mode} is not supported.')
217
+
218
+ self.flatten_projector = nn.Linear(self.input_dim * num_sub_point[-1], self.input_dim)
219
+ self.dim_projector = nn.Linear(self.input_dim, self.output_dim)
220
+ # self.dim_projector = nn.Sequential(*[
221
+ # nn.Linear(self.input_dim, self.output_dim),
222
+ # nn.GELU(),
223
+ # nn.Linear(self.output_dim, self.output_dim)
224
+ # ])
225
+
226
+ self.norm_init_weights()
227
+
228
+ # self.dtype = torch.float32
229
+ def norm_init_weights(self):
230
+ for m in self.modules():
231
+ if isinstance(m, nn.Conv2d):
232
+ normal_init(m, 0, 0.01)
233
+
234
+
235
+ def forward(self,
236
+ feature_map,
237
+ region_masks,
238
+ original_dtype,
239
+ return_dtype):
240
+
241
+ assert len(feature_map) == len(region_masks)
242
+
243
+ all_points = []
244
+ all_points_fea = []
245
+ all_points_img_ids = []
246
+
247
+ # Sample points and their features
248
+ for img_idx, (region_feature_map_i, region_masks_list_i) in enumerate(zip(feature_map, region_masks)):
249
+ if len(region_masks_list_i) != 0:
250
+ # (w, h)
251
+ ori_image_wh = torch.tensor([region_masks_list_i[0].shape[0], region_masks_list_i[0].shape[1]], device=region_masks_list_i[0].device)[None,]
252
+ # list of elements of shape [num_sample_point, 2]
253
+ cur_non_zero_pos = [rand_sample_repeat((m.nonzero()/ori_image_wh), self.num_init_point) for m in region_masks_list_i]
254
+ # list -> [num_mask, num_sample_point, 2]
255
+ cur_non_zero_pos = torch.stack(cur_non_zero_pos)
256
+ # [HxW, C] -> [H, W, C] -> [C, H, W] -> [N, C, H, W]
257
+ if region_feature_map_i.ndim == 2:
258
+ h = w = int(math.sqrt(region_feature_map_i.shape[0]))
259
+ c = region_feature_map_i.shape[-1]
260
+ region_feature_map_i = region_feature_map_i.reshape(h, w, c)
261
+ else:
262
+ assert region_feature_map_i.ndim == 3
263
+ dup_region_feature_map_i = region_feature_map_i.permute(2, 0, 1)
264
+ dup_region_feature_map_i = dup_region_feature_map_i.unsqueeze(0).repeat(cur_non_zero_pos.shape[0], 1, 1, 1)
265
+ # [num_mask, C, H, W] x [num_mask, num_sample_point, 2] -> [num_mask, C, num_sample_point] -> [num_mask, num_sample_point, C]
266
+ # F.grid_sample doesn't support BF16. Need to transform into float32, then transform back.
267
+ dup_region_feature_map_i_ori_type = dup_region_feature_map_i.to(original_dtype)
268
+ region_feature_i = point_sample(dup_region_feature_map_i_ori_type,
269
+ cur_non_zero_pos.flip(dims=(2,)).type(original_dtype),
270
+ return_dtype,
271
+ align_corners=True,
272
+ )
273
+ # region_feature_i = region_feature_i.to(dup_region_feature_map_i.dtype)
274
+ region_feature_i = region_feature_i.transpose(-2, -1)
275
+
276
+ cur_img_ids = [img_idx] * len(cur_non_zero_pos)
277
+ # save to global list
278
+ all_points.append(cur_non_zero_pos)
279
+ all_points_fea.append(region_feature_i)
280
+ all_points_img_ids.extend(cur_img_ids)
281
+
282
+ # No region found, return list of None.
283
+ if len(all_points) == 0:
284
+ return [None] * len(region_masks)
285
+
286
+ all_points = torch.cat(all_points, dim=0).to(return_dtype) # [B*num_mask, num_sample_point, 2]
287
+ all_points_fea = torch.cat(all_points_fea, dim=0) # [B*num_mask, num_sample_point, C]
288
+ all_points_img_ids = torch.tensor(all_points_img_ids, device=all_points_fea.device)
289
+
290
+ assert all_points_fea.shape[:-1] == all_points.shape[:-1]
291
+
292
+ # Processing.
293
+ for stage_i in range(len(self.num_sub_point)):
294
+ cur_num_sub_point = self.num_sub_point[stage_i]
295
+ cur_num_neighbor = self.num_neighbor[stage_i]
296
+
297
+ all_points = all_points.contiguous() # xy [batch, points, xy]
298
+ fps_idx = farthest_point_sample(all_points, cur_num_sub_point).long()
299
+
300
+ new_points = index_points(all_points, fps_idx) # [B, npoint, 2]
301
+ new_points_fea = index_points(all_points_fea, fps_idx) # [B, npoint, d]
302
+
303
+ idx = knn_point(cur_num_neighbor, all_points, new_points)
304
+ grouped_points = index_points(all_points, idx) # [B, npoint, k, 2]
305
+ grouped_points_fea = index_points(all_points_fea, idx) # [B, npoint, k, d]
306
+
307
+ local_points_fea = torch.cat([grouped_points_fea, grouped_points],dim=-1) # [B, npoint, k, d+2]
308
+ anchor_points_fea = torch.cat([new_points_fea, new_points],dim=-1).unsqueeze(-2)
309
+ diff_points_fea = local_points_fea-anchor_points_fea
310
+
311
+ diff_points_fea = self.diff_projector_list[stage_i](diff_points_fea)
312
+ gather_points_fea = torch.cat([diff_points_fea, anchor_points_fea.repeat(1, 1, cur_num_neighbor, 1)], dim=-1) # [B, npoint, k, 2(d+2)]
313
+
314
+ b, n, s, d = gather_points_fea.size()
315
+ gather_points_fea = gather_points_fea.permute(0, 1, 3, 2) # [B, npoint, 2(d+2), k]
316
+ gather_points_fea = gather_points_fea.reshape(-1, d, s) # [B*npoint, 2(d+2), k]
317
+ gather_points_fea = self.agg_projector_list[stage_i](gather_points_fea) # [B*npoint, d, k]
318
+
319
+ batch_size, new_dim, _ = gather_points_fea.size()
320
+ gather_points_fea = self.pooler_list[stage_i](gather_points_fea).view(batch_size, new_dim) # [B*npoint, d]
321
+
322
+ gather_points_fea = gather_points_fea.reshape(b, n, -1) # [B, npoint, d]
323
+
324
+ all_points = new_points
325
+ all_points_fea = gather_points_fea
326
+
327
+ x = all_points_fea.flatten(1, -1) # [B, npoint x d]
328
+ x = self.flatten_projector(x)
329
+ all_region_fea = self.dim_projector(x) # [B, d]
330
+
331
+ output_region_fea = []
332
+ for img_idx in range(len(region_masks)):
333
+ cur_mask = all_points_img_ids == img_idx
334
+
335
+ if not cur_mask.any():
336
+ output_region_fea.append(None)
337
+ else:
338
+ output_region_fea.append(all_region_fea[cur_mask])
339
+
340
+ return output_region_fea
341
+
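+ # Illustrative sketch (not part of the original file): runs a small GeoRegionSampler on
+ # one dummy image feature map and one binary region mask. All sizes here are assumptions
+ # chosen to keep the example tiny; the real model uses the config values set in
+ # FerretMetaModel below. `_example_geo_region_sampler` is an assumed helper name.
+ def _example_geo_region_sampler():
+     sampler = GeoRegionSampler(input_dim=8, output_dim=16,
+                                num_init_point=64, num_sub_point=[32, 8],
+                                num_neighbor=[4, 4], pooler_mode='mean')
+     feature_map = [torch.randn(24 * 24, 8)]   # one image: flattened 24x24 grid, C=8
+     mask = torch.zeros(100, 100)
+     mask[10:40, 20:60] = 1                    # one rectangular region
+     region_masks = [[mask]]                   # per-image list of region masks
+     out = sampler(feature_map, region_masks,
+                   original_dtype=torch.float32, return_dtype=torch.float32)
+     return out[0].shape                       # torch.Size([1, 16]): one feature per region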
342
+
343
+ class FerretMetaModel:
344
+
345
+ def __init__(self, config):
346
+ super(FerretMetaModel, self).__init__(config)
347
+ self.max_sample_point = 512
348
+ if hasattr(config, "mm_vision_tower"):
349
+ self.vision_tower = build_vision_tower(config, delay_load=True)
350
+ self.mm_projector = build_vision_projector(config)
351
+
352
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
353
+ self.image_newline = nn.Parameter(
354
+ torch.empty(config.hidden_size, dtype=self.dtype)
355
+ )
356
+
357
+ if hasattr(config, "region_fea_adapter"):
358
+ self.region_fea_adapter = nn.Linear(config.mm_hidden_size, config.hidden_size)
359
+
360
+ if hasattr(config, "region_geo_sampler"):
361
+ if getattr(config, 'mm_patch_merge_type', 'flat').startswith('spatial'):
362
+ self.region_geo_sampler = GeoRegionSampler(input_dim=config.mm_hidden_size,
363
+ output_dim=config.hidden_size,
364
+ num_init_point=self.max_sample_point,
365
+ num_sub_point=[128, 32],
366
+ num_neighbor=[24, 24],
367
+ pooler_mode=config.sampler_pooler_mode
368
+ )
369
+ else:
370
+ self.region_geo_sampler = GeoRegionSampler(input_dim=config.mm_hidden_size,
371
+ output_dim=config.hidden_size,
372
+ num_init_point=self.max_sample_point,
373
+ num_sub_point=[128, 32],
374
+ num_neighbor=[24, 24],
375
+ pooler_mode=config.sampler_pooler_mode
376
+ )
377
+
378
+ def get_vision_tower(self):
379
+ vision_tower = getattr(self, 'vision_tower', None)
380
+ if type(vision_tower) is list:
381
+ vision_tower = vision_tower[0]
382
+ return vision_tower
383
+
384
+ def initialize_vision_modules(self, model_args, fsdp=None,
385
+ add_region_feature=False,
386
+ region_geo_sampler=False,
387
+ sampler_pooler_mode='mean',
388
+ ):
389
+ vision_tower = model_args.vision_tower
390
+ mm_vision_select_layer = model_args.mm_vision_select_layer
391
+ mm_vision_select_feature = model_args.mm_vision_select_feature
392
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
393
+ mm_patch_merge_type = model_args.mm_patch_merge_type
394
+
395
+ self.config.mm_vision_tower = vision_tower
396
+
397
+ if self.get_vision_tower() is None:
398
+ vision_tower = build_vision_tower(model_args)
399
+
400
+ if fsdp is not None and len(fsdp) > 0:
401
+ self.vision_tower = [vision_tower]
402
+ else:
403
+ self.vision_tower = vision_tower
404
+ else:
405
+ if fsdp is not None and len(fsdp) > 0:
406
+ vision_tower = self.vision_tower[0]
407
+ else:
408
+ vision_tower = self.vision_tower
409
+ vision_tower.load_model()
410
+
411
+ self.config.use_mm_proj = True
412
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
413
+ self.config.mm_hidden_size = vision_tower.hidden_size
414
+ self.config.mm_vision_select_layer = mm_vision_select_layer
415
+ self.config.mm_vision_select_feature = mm_vision_select_feature
416
+ self.config.mm_patch_merge_type = mm_patch_merge_type
417
+
418
+ if getattr(self, 'mm_projector', None) is None:
419
+ self.mm_projector = build_vision_projector(self.config)
420
+
421
+ if 'unpad' in mm_patch_merge_type:
422
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
423
+ self.image_newline = nn.Parameter(
424
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
425
+ )
426
+
427
+ if add_region_feature:
428
+ if region_geo_sampler:
429
+ self.config.region_geo_sampler = True
430
+ self.config.sampler_pooler_mode = sampler_pooler_mode
431
+
432
+ if not hasattr(self, 'region_geo_sampler'):
433
+ if mm_patch_merge_type.startswith('spatial'):
434
+ # === if feature is concated ===
435
+ # self.region_geo_sampler = GeoRegionSampler(input_dim=self.config.mm_hidden_size * 2,
436
+ # output_dim=self.config.hidden_size,
437
+ # num_init_point=self.max_sample_point,
438
+ # num_sub_point=[128, 32],
439
+ # num_neighbor=[24, 24],
440
+ # pooler_mode=sampler_pooler_mode
441
+ # )
442
+ # === if feature is added ===
443
+ self.region_geo_sampler = GeoRegionSampler(input_dim=self.config.mm_hidden_size,
444
+ output_dim=self.config.hidden_size,
445
+ num_init_point=self.max_sample_point,
446
+ num_sub_point=[128, 32],
447
+ num_neighbor=[24, 24],
448
+ pooler_mode=sampler_pooler_mode
449
+ )
450
+ else:
451
+ self.region_geo_sampler = GeoRegionSampler(input_dim=self.config.mm_hidden_size,
452
+ output_dim=self.config.hidden_size,
453
+ num_init_point=self.max_sample_point,
454
+ num_sub_point=[128, 32],
455
+ num_neighbor=[24, 24],
456
+ pooler_mode=sampler_pooler_mode
457
+ )
458
+ else:
459
+ self.config.region_fea_adapter = True
460
+ if not hasattr(self, 'region_fea_adapter'):
461
+ self.region_fea_adapter = nn.Linear(self.config.mm_hidden_size, self.config.hidden_size)
462
+
463
+ else:
464
+ # In case it is frozen by LoRA
465
+ for p in self.mm_projector.parameters():
466
+ p.requires_grad = True
467
+
468
+ # print(f"pretrain mm mlp adapter: {type(pretrain_mm_mlp_adapter)}") # String
469
+ if pretrain_mm_mlp_adapter is not None and pretrain_mm_mlp_adapter != "None":
470
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
471
+ def get_w(weights, keyword):
472
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
473
+
474
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
475
+
476
+
477
+ def unpad_image(tensor, original_size):
478
+ """
479
+ Unpads a PyTorch tensor of a padded and resized image.
480
+
481
+ Args:
482
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
483
+ original_size (tuple): The original size of PIL image (width, height).
484
+
485
+ Returns:
486
+ torch.Tensor: The unpadded image tensor.
487
+ """
488
+ original_width, original_height = original_size
489
+ current_height, current_width = tensor.shape[1:]
490
+
491
+ original_aspect_ratio = original_width / original_height
492
+ current_aspect_ratio = current_width / current_height
493
+
494
+ if original_aspect_ratio > current_aspect_ratio:
495
+ scale_factor = current_width / original_width
496
+ new_height = int(original_height * scale_factor)
497
+ padding = (current_height - new_height) // 2
498
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
499
+ else:
500
+ scale_factor = current_height / original_height
501
+ new_width = int(original_width * scale_factor)
502
+ padding = (current_width - new_width) // 2
503
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
504
+
505
+ return unpadded_tensor
506
+
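+ # Illustrative sketch (not part of the original file): removes the vertical padding that
+ # was added when a 200x100 (width x height) image was letterboxed into a square 24x24
+ # feature grid. `_example_unpad_image` is an assumed helper name.
+ def _example_unpad_image():
+     padded = torch.randn(3, 24, 24)                           # (C, H, W) padded square grid
+     unpadded = unpad_image(padded, original_size=(200, 100))  # original (width, height)
+     return unpadded.shape                                     # torch.Size([3, 12, 24])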
507
+
508
+ class FerretMetaForCausalLM(ABC):
509
+
510
+ @abstractmethod
511
+ def get_model(self):
512
+ pass
513
+
514
+ def get_vision_tower(self):
515
+ return self.get_model().get_vision_tower()
516
+
517
+ def encode_images(self, images, region_flag=False, region_geo_sampler=False):
518
+ image_features = self.get_model().get_vision_tower()(images)
519
+ projected_image_features = self.get_model().mm_projector(image_features)
520
+ if region_flag:
521
+ if region_geo_sampler:
522
+ new_region_feature_map = image_features
523
+ else:
524
+ new_region_feature_map = self.get_model().region_fea_adapter(image_features)
525
+ else:
526
+ new_region_feature_map = None
527
+
528
+ return image_features, projected_image_features, new_region_feature_map
529
+
530
+ def extract_region_feature(self, region_feature_map, region_masks, original_dtype, return_dtype):
531
+ all_region_features = []
532
+ assert len(region_feature_map) == len(region_masks)
533
+ for region_feature_map_i, region_masks_list_i in zip(region_feature_map, region_masks):
534
+ if len(region_masks_list_i) == 0:
535
+ all_region_features.append(None)
536
+ else:
537
+ # (w, h)
538
+ ori_image_wh = torch.tensor([region_masks_list_i[0].shape[0], region_masks_list_i[0].shape[1]], device=region_masks_list_i[0].device)[None,]
539
+ # list of elements of shape [num_sample_point, 2]
540
+ non_zero_pos = [rand_sample((m.nonzero()/ori_image_wh), self.get_model().max_sample_point) for m in region_masks_list_i]
541
+ # [num_mask, num_sample_point(padded), 2]
542
+ non_zero_pos = nn.utils.rnn.pad_sequence(non_zero_pos, padding_value=-1, batch_first=True)
543
+ non_zero_pos_mask = ~(non_zero_pos.sum(dim=-1) < 0)
544
+ # [HxW, C] -> [H, W, C] -> [C, H, W] -> [N, C, H, W]
545
+ h = w = int(math.sqrt(region_feature_map_i.shape[0]))
546
+ c = region_feature_map_i.shape[-1]
547
+ dup_region_feature_map_i = region_feature_map_i.reshape(h, w, c).permute(2, 0, 1)
548
+ dup_region_feature_map_i = dup_region_feature_map_i.unsqueeze(0).repeat(non_zero_pos.shape[0], 1, 1, 1)
549
+ # [num_mask, C, H, W] x [num_mask, num_sample_point(padded), 2] -> [num_mask, C, num_sample_point(padded)]
550
+ # F.grid_sample doesn't support BF16. Need to transform into float32, then transform back.
551
+ dup_region_feature_map_i_ori_type = dup_region_feature_map_i.to(original_dtype)
552
+ # pdb.set_trace()
553
+ region_feature_i = point_sample(dup_region_feature_map_i_ori_type,
554
+ non_zero_pos.flip(dims=(2,)).type(original_dtype),
555
+ return_dtype,
556
+ align_corners=True
557
+ )
558
+ region_feature_i = region_feature_i.to(dup_region_feature_map_i.dtype)
559
+ # [num_mask, C]
560
+ region_feature_i = torch.stack([x[m].mean(dim=0) for x, m in zip(region_feature_i.transpose(1,2), non_zero_pos_mask)]).nan_to_num()
561
+ all_region_features.append(region_feature_i)
562
+
563
+ return all_region_features
564
+
565
+ def prepare_inputs_labels_for_multimodal(
566
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
567
+ images, image_sizes=None, region_masks=None
568
+ ):
569
+ if region_masks is not None:
570
+ region_flag = True
571
+ else:
572
+ region_flag = False
573
+ region_geo_sampler = region_flag and getattr(self.config, 'region_geo_sampler', False)
574
+
575
+ vision_tower = self.get_vision_tower()
576
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
577
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
578
+
579
+ if type(images) is list or images.ndim == 5:
580
+ if type(images) is list:
581
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
582
+
583
+ concat_images = torch.cat([image for image in images], dim=0)
584
+ raw_image_features, image_features, region_feature_map = self.encode_images(concat_images, region_flag=region_flag, region_geo_sampler=region_geo_sampler)
585
+ split_sizes = [image.shape[0] for image in images]
586
+ image_features = torch.split(image_features, split_sizes, dim=0)
587
+
588
+ if region_flag:
589
+ region_feature_maps = torch.split(region_feature_map, split_sizes, dim=0) # (#images, #patches, h*w, c)
590
+ # ======== This only takes the global image feature map for referring ======
591
+ # region_feature_map = torch.split(region_feature_map, split_sizes, dim=0)
592
+ # first_region_feature_map = [x[0:1] for x in region_feature_map]
593
+ # region_feature_map = torch.cat(first_region_feature_map, dim=0)
594
+
595
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
596
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square_nocrop')
597
+
598
+ if mm_patch_merge_type == 'flat':
599
+ image_features = [x.flatten(0, 1) for x in image_features]
600
+ # TODO: by default, use the first (global) feature map of each batch for referring
601
+ first_region_feature_map = [x[0:1] for x in region_feature_map]
602
+ region_feature_map = torch.cat(first_region_feature_map, dim=0) # (#images, h, w, c)
603
+ elif mm_patch_merge_type.startswith('spatial'):
604
+ new_image_features = []
605
+ new_region_features = []
606
+ for image_idx, image_feature in enumerate(image_features):
607
+ if image_feature.shape[0] > 1:
608
+ base_image_feature = image_feature[0]
609
+ image_feature = image_feature[1:]
610
+ height = width = self.get_vision_tower().num_patches_per_side
611
+ assert height * width == base_image_feature.shape[0]
612
+ if region_flag:
613
+ cur_region_feature_map = region_feature_maps[image_idx] # (#patches, h*w, c)
614
+ cur_region_feature_map = cur_region_feature_map.view(cur_region_feature_map.shape[0], height, width, cur_region_feature_map.shape[-1]) # (#patches, h, w, c)
615
+ base_region_feature = cur_region_feature_map[0]
616
+ region_feature = cur_region_feature_map[1:]
617
+ # pdb.set_trace()
618
+ if image_aspect_ratio == 'anyres':
619
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
620
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
621
+ if region_flag:
622
+ region_feature = region_feature.view(num_patch_height, num_patch_width, height, width, -1)
623
+ else:
624
+ raise NotImplementedError
625
+
626
+ if 'unpad' in mm_patch_merge_type:
627
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
628
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
629
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
630
+ image_feature = torch.cat((
631
+ image_feature,
632
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
633
+ ), dim=-1)
634
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
635
+ else:
636
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
637
+ image_feature = image_feature.flatten(0, 3)
638
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
639
+ if region_flag:
640
+ region_feature = region_feature.permute(0, 2, 1, 3, 4).contiguous() # (patch_h, patch_w, h, w, c) -> (patch_h, h, patch_w, w, c)
641
+ region_feature = region_feature.flatten(0, 1).flatten(1, 2) # (patch_h, h, patch_w, w, c) -> (all_h, all_w, c)
642
+ # Transform dtype; with PyTorch 2.1+ this is not needed.
643
+ base_region_feature = base_region_feature.to(dtype=torch.float32)
644
+ base_region_feature_resized = F.interpolate(base_region_feature.unsqueeze(0).permute(0, 3, 1, 2), (region_feature.shape[0], region_feature.shape[1])) # (1, c, all_h, all_w)
645
+ base_region_feature_resized = base_region_feature_resized.to(region_feature.dtype)
646
+ base_region_feature_resized = base_region_feature_resized.squeeze(0).permute(1, 2, 0) # (all_h, all_w, c)
647
+ # === Add:
648
+ new_region_feature = base_region_feature_resized + region_feature
649
+ # === Concat: slightly lower performance, ~1/3 more GPU memory consumption.
650
+ # new_region_feature = torch.cat((base_region_feature_resized, region_feature), dim=2) # (all_h, all_w, 2c)
651
+ else:
652
+ image_feature = image_feature[0]
653
+ if 'unpad' in mm_patch_merge_type:
654
+ image_feature = torch.cat((
655
+ image_feature,
656
+ self.model.image_newline[None].to(image_feature.device)
657
+ ), dim=0)
658
+ if region_flag:
659
+ new_region_feature = region_feature_maps[image_idx][0] # (h, w, c)
660
+ new_image_features.append(image_feature)
661
+ if region_flag:
662
+ new_region_features.append(new_region_feature)
663
+ # pdb.set_trace()
664
+ image_features = new_image_features
665
+ if region_flag:
666
+ # region_feature_map = torch.stack(new_region_features, dim=0) # (#images, h, w, c or 2c)
667
+ region_feature_map = new_region_features
668
+ # pdb.set_trace()
669
+ else:
670
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
671
+ else:
672
+ raw_image_features, image_features, region_feature_map = self.encode_images(images, region_flag=region_flag, region_geo_sampler=region_geo_sampler)
673
+
674
+ if region_flag:
675
+ assert len(region_masks) == len(input_ids)
676
+ for img_idx, (cur_input_id, cur_region_mask) in enumerate(zip(input_ids, region_masks)):
677
+ cur_region_token_num = (cur_input_id == self.config.im_region_fea_token).sum()
678
+ if cur_region_token_num != len(cur_region_mask):
679
+ print('Found regions cropped because the text exceeded max_len; removed them.')
680
+ region_masks[img_idx] = cur_region_mask[:cur_region_token_num]
681
+
682
+ # dump_region_mask = torch.zeros(100, 100).to(device='cuda')
683
+ dump_region_mask = torch.zeros(100, 100, device='cuda')
684
+ dump_region_mask[10:20, 10:20] = 1
685
+ dump_region_masks = [[dump_region_mask.clone()]]
686
+ for _ in range(len(region_feature_map)-1):
687
+ dump_region_masks.append([])
688
+
689
+ if region_geo_sampler:
690
+ if type(image_features) is list:
691
+ region_features = self.get_model().region_geo_sampler(region_feature_map, region_masks,
692
+ original_dtype=raw_image_features.dtype,
693
+ return_dtype=image_features[0].dtype)
694
+ dump_region_features = self.get_model().region_geo_sampler(region_feature_map, dump_region_masks,
695
+ original_dtype=raw_image_features.dtype,
696
+ return_dtype=image_features[0].dtype)
697
+ else:
698
+ region_features = self.get_model().region_geo_sampler(region_feature_map, region_masks,
699
+ original_dtype=raw_image_features.dtype,
700
+ return_dtype=image_features.dtype)
701
+ dump_region_features = self.get_model().region_geo_sampler(region_feature_map, dump_region_masks,
702
+ original_dtype=raw_image_features.dtype,
703
+ return_dtype=image_features.dtype)
704
+ else:
705
+ if type(image_features) is list:
706
+ region_features = self.extract_region_feature(region_feature_map, region_masks,
707
+ original_dtype=raw_image_features.dtype,
708
+ return_dtype=image_features[0].dtype)
709
+ dump_region_features = self.extract_region_feature(region_feature_map, dump_region_masks,
710
+ original_dtype=raw_image_features.dtype,
711
+ return_dtype=image_features[0].dtype)
712
+ else:
713
+ region_features = self.extract_region_feature(region_feature_map, region_masks,
714
+ original_dtype=raw_image_features.dtype,
715
+ return_dtype=image_features.dtype)
716
+ dump_region_features = self.extract_region_feature(region_feature_map, dump_region_masks,
717
+ original_dtype=raw_image_features.dtype,
718
+ return_dtype=image_features.dtype)
719
+ # assert len(dump_region_features) == 1
720
+ assert len([df for df in dump_region_features if df is not None]) == 1
721
+ assert len(dump_region_features[0]) == 1
722
+ assert len(region_features) == len(input_ids)
723
+
724
+ # TODO: image start / end is not implemented here to support pretraining.
725
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
726
+ raise NotImplementedError
727
+
728
+ # Let's just add dummy tensors if they do not exist,
729
+ # it is a headache to deal with None all the time.
730
+ # But it is not ideal, and if you have a better idea,
731
+ # please open an issue / submit a PR, thanks.
732
+ _labels = labels
733
+ _position_ids = position_ids
734
+ _attention_mask = attention_mask
735
+ if attention_mask is None:
736
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
737
+ else:
738
+ attention_mask = attention_mask.bool()
739
+ if position_ids is None:
740
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
741
+ if labels is None:
742
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
743
+
744
+ # remove the padding using attention_mask -- FIXME
745
+ _input_ids = input_ids
746
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
747
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
748
+
749
+ new_input_embeds = []
750
+ new_labels = []
751
+ cur_image_idx = 0
752
+ for batch_idx, cur_input_ids in enumerate(input_ids):
753
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
754
+ if num_images == 0:
755
+ cur_image_features = image_features[cur_image_idx]
756
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
757
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
758
+ new_input_embeds.append(cur_input_embeds)
759
+ new_labels.append(labels[batch_idx])
760
+ cur_image_idx += 1
761
+ continue
762
+
763
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
764
+ cur_input_id_with_im = []
765
+ cur_input_ids_noim = []
766
+ cur_labels = labels[batch_idx]
767
+ cur_labels_noim = []
768
+ for i in range(len(image_token_indices) - 1):
769
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
770
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
771
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
772
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
773
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
774
+ cur_new_input_embeds = []
775
+ cur_new_labels = []
776
+ assert len(cur_input_ids_noim) == len(cur_input_embeds_no_im)
777
+ for i in range(num_images + 1):
778
+ cur_input_id_with_im.append(cur_input_ids_noim[i])
779
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
780
+ cur_new_labels.append(cur_labels_noim[i])
781
+ if i < num_images:
782
+ cur_image_features = image_features[cur_image_idx]
783
+ cur_image_idx += 1
784
+ cur_input_id_with_im.append(torch.full((cur_image_features.shape[0],), IMAGE_TOKEN_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
785
+ cur_new_input_embeds.append(cur_image_features)
786
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
787
+
788
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
789
+
790
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
791
+ cur_new_labels = torch.cat(cur_new_labels)
792
+ cur_input_id_with_im = torch.cat(cur_input_id_with_im)
793
+
794
+ assert len(cur_input_id_with_im) == len(cur_new_input_embeds)
795
+ # Add region feature into text feature embeddings.
796
+ # Currently only support one image in each input.
797
+ assert batch_idx+1 == cur_image_idx
798
+ if region_flag and region_features[batch_idx] is not None:
799
+ region_embs = torch.zeros_like(cur_new_input_embeds)
800
+ region_replace_mask = (cur_input_id_with_im == self.config.im_region_fea_token)
801
+ # region_embs[region_replace_mask] = region_features[batch_idx].to(cur_new_input_embeds.dtype)
802
+ if len(region_embs[region_replace_mask]) != len(region_features[batch_idx]):
803
+ # ("Found a region cropped in text")
804
+ region_embs[region_replace_mask] = region_features[batch_idx][:len(region_embs[region_replace_mask])].to(cur_new_input_embeds.dtype)
805
+ else:
806
+ region_embs[region_replace_mask] = region_features[batch_idx].to(cur_new_input_embeds.dtype)
807
+ cur_new_input_embeds = cur_new_input_embeds * (~region_replace_mask).to(cur_new_input_embeds.dtype)[:, None] + region_embs
808
+ else:
809
+ if hasattr(self.config, 'im_region_fea_token'):
810
+ assert (cur_input_id_with_im == self.config.im_region_fea_token).sum() == 0
811
+
812
+ # Add the dummy region feature to the input embedding so the region sampler always receives a gradient when region_flag is on.
813
+ if region_flag:
814
+ # cur_new_input_embeds[0] = cur_new_input_embeds[0] + 0 * dump_region_features[0, 0].to(cur_new_input_embeds.dtype)
815
+ cur_new_input_embeds[0] = cur_new_input_embeds[0] + 0.0 * dump_region_features[0][0].to(cur_new_input_embeds.dtype)
816
+
817
+ new_input_embeds.append(cur_new_input_embeds)
818
+ new_labels.append(cur_new_labels)
819
+
820
+ # Truncate sequences to max length as image embeddings can make the sequence longer
821
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
822
+ if tokenizer_model_max_length is not None:
823
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
824
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
825
+
826
+ # Combine them
827
+ max_len = max(x.shape[0] for x in new_input_embeds)
828
+ batch_size = len(new_input_embeds)
829
+
830
+ new_input_embeds_padded = []
831
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
832
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
833
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
834
+
835
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
836
+ cur_len = cur_new_embed.shape[0]
837
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
838
+ new_input_embeds_padded.append(torch.cat((
839
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
840
+ cur_new_embed
841
+ ), dim=0))
842
+ if cur_len > 0:
843
+ new_labels_padded[i, -cur_len:] = cur_new_labels
844
+ attention_mask[i, -cur_len:] = True
845
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
846
+ else:
847
+ new_input_embeds_padded.append(torch.cat((
848
+ cur_new_embed,
849
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
850
+ ), dim=0))
851
+ if cur_len > 0:
852
+ new_labels_padded[i, :cur_len] = cur_new_labels
853
+ attention_mask[i, :cur_len] = True
854
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
855
+
856
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
857
+
858
+ if _labels is None:
859
+ new_labels = None
860
+ else:
861
+ new_labels = new_labels_padded
862
+
863
+ if _attention_mask is None:
864
+ attention_mask = None
865
+ else:
866
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
867
+
868
+ if _position_ids is None:
869
+ position_ids = None
870
+
871
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
872
+
873
+ def initialize_vision_tokenizer(self, model_args, tokenizer, add_region_feature=False):
874
+ if model_args.mm_use_im_patch_token:
875
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
876
+ self.resize_token_embeddings(len(tokenizer))
877
+
878
+ if add_region_feature:
879
+ region_token_id = tokenizer.convert_tokens_to_ids([DEFAULT_REGION_FEA_TOKEN])[0]
880
+ # If region_token doesn't exist, add it.
881
+ if region_token_id == tokenizer.unk_token_id:
882
+ num_region_fea_tokens = tokenizer.add_tokens([DEFAULT_REGION_FEA_TOKEN], special_tokens=True)
883
+ self.config.im_region_fea_token = tokenizer.convert_tokens_to_ids([DEFAULT_REGION_FEA_TOKEN])[0]
884
+ self.resize_token_embeddings(len(tokenizer))
885
+
886
+ if model_args.mm_use_im_start_end:
887
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
888
+ self.resize_token_embeddings(len(tokenizer))
889
+
890
+ if add_region_feature:
891
+ num_new_tokens = num_new_tokens + num_region_fea_tokens
892
+
893
+ if num_new_tokens > 0:
894
+ input_embeddings = self.get_input_embeddings().weight.data
895
+ output_embeddings = self.get_output_embeddings().weight.data
896
+
897
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
898
+ dim=0, keepdim=True)
899
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
900
+ dim=0, keepdim=True)
901
+
902
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
903
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
904
+
905
+ if model_args.tune_mm_mlp_adapter:
906
+ for p in self.get_input_embeddings().parameters():
907
+ p.requires_grad = True
908
+ for p in self.get_output_embeddings().parameters():
909
+ p.requires_grad = False
910
+
911
+ if model_args.pretrain_mm_mlp_adapter:
912
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
913
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
914
+ assert num_new_tokens == 2
915
+ if input_embeddings.shape == embed_tokens_weight.shape:
916
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
917
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
918
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
919
+ else:
920
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
921
+ elif model_args.mm_use_im_patch_token:
922
+ if model_args.tune_mm_mlp_adapter:
923
+ for p in self.get_input_embeddings().parameters():
924
+ p.requires_grad = False
925
+ for p in self.get_output_embeddings().parameters():
926
+ p.requires_grad = False