HGB committed
Commit e065dd1 · 1 Parent(s): 9c8bb9e

remove formatting

Files changed (1)
  1. modeling_intern_vit.py  +207 -122
modeling_intern_vit.py CHANGED
@@ -12,13 +12,13 @@ from einops import rearrange
 from timm.models.layers import DropPath
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (BaseModelOutput,
-                                           BaseModelOutputWithPooling)
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging

 from .configuration_intern_vit import InternVisionConfig

+
 try:
     from triton_flash_atn import _attention

@@ -26,7 +26,7 @@ try:

     has_flash_attn = True
 except:
-    print('FlashAttention is not installed.')
+    print("FlashAttention is not installed.")
     has_flash_attn = False

 logger = logging.get_logger(__name__)
@@ -43,13 +43,22 @@ class FlashAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+    def __init__(
+        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
+    ):
         super().__init__()
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
-                max_s=None, need_weights=False):
+    def forward(
+        self,
+        qkv,
+        key_padding_mask=None,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False,
+    ):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -65,35 +74,58 @@ class FlashAttention(nn.Module):
             batch_size = qkv.shape[0]
             seqlen = qkv.shape[1]
             if key_padding_mask is None:
-                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                 max_s = seqlen
-                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                          device=qkv.device)
+                cu_seqlens = torch.arange(
+                    0,
+                    (batch_size + 1) * seqlen,
+                    step=seqlen,
+                    dtype=torch.int32,
+                    device=qkv.device,
+                )
                 output = _attention.apply(
-                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                    sm_scale=self.softmax_scale, causal=causal
+                    qkv,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    sm_scale=self.softmax_scale,
+                    causal=causal,
                 )
-                output = rearrange(
-                    output, '(b s) ... -> b s ...', b=batch_size)
+                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
             else:
                 nheads = qkv.shape[-2]
-                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-                x_unpad, indices, cu_seqlens, max_s = unpad_input(
-                    x, key_padding_mask)
+                x = rearrange(qkv, "b s three h d -> b s (three h d)")
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                 x_unpad = rearrange(
-                    x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                    x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+                )
                 output_unpad = _attention.apply(
-                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                    sm_scale=self.softmax_scale, causal=causal
+                    x_unpad,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    sm_scale=self.softmax_scale,
+                    causal=causal,
+                )
+                output = rearrange(
+                    pad_input(
+                        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                        indices,
+                        batch_size,
+                        seqlen,
+                    ),
+                    "b s (h d) -> b s h d",
+                    h=nheads,
                 )
-                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
-                                             indices, batch_size, seqlen),
-                                   'b s (h d) -> b s h d', h=nheads)
         else:
             assert max_s is not None
             output = _attention.apply(
-                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                sm_scale=self.softmax_scale, causal=causal
+                qkv,
+                cu_seqlens,
+                max_s,
+                self.dropout_p if self.training else 0.0,
+                sm_scale=self.softmax_scale,
+                causal=causal,
             )

         return output, None
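
Note (not part of the commit, illustration only): the packed-QKV layout and the cu_seqlens bookkeeping that FlashAttention.forward builds when no key_padding_mask is given can be sketched in plain PyTorch; all sizes below are made up.

import torch

batch_size, seqlen, nheads, head_dim = 2, 5, 4, 8
# Packed layout the module expects: (batch, seq, 3, heads, head_dim).
qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim)

# With no padding mask, batch and sequence are flattened into one token axis...
qkv_flat = qkv.reshape(batch_size * seqlen, 3, nheads, head_dim)

# ...and cu_seqlens records where each sample starts and ends on that flat axis.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
print(cu_seqlens.tolist())  # [0, 5, 10]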
@@ -109,8 +141,7 @@ class InternRMSNorm(nn.Module):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * \
-            torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)


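Note (illustration only, not from the commit): the InternRMSNorm forward above is plain RMS normalization. A self-contained equivalent, with made-up sizes and a hypothetical eps value:

import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as InternRMSNorm.forward: normalize by the root mean square,
    # then rescale with a learned weight, preserving the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

x = torch.randn(2, 7, 64, dtype=torch.float16)
out = rms_norm(x, weight=torch.ones(64, dtype=torch.float16))
assert out.shape == x.shape and out.dtype == torch.float16
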
@@ -120,19 +151,21 @@ try:
     InternRMSNorm = FusedRMSNorm  # noqa

     logger.info(
-        'Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
+        "Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm"
+    )
 except ImportError:
     # using the normal InternRMSNorm
     pass
 except Exception:
     logger.warning(
-        'discovered apex but it failed to load, falling back to InternRMSNorm')
+        "discovered apex but it failed to load, falling back to InternRMSNorm"
+    )
     pass


 NORM2FN = {
-    'rms_norm': InternRMSNorm,
-    'layer_norm': nn.LayerNorm,
+    "rms_norm": InternRMSNorm,
+    "layer_norm": nn.LayerNorm,
 }


@@ -149,21 +182,37 @@ class InternVisionEmbeddings(nn.Module):
         )

         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+            in_channels=3,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
         )

         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1

         self.position_embedding = nn.Parameter(
-            torch.randn(1, self.num_positions, self.embed_dim))
+            torch.randn(1, self.num_positions, self.embed_dim)
+        )

     def _get_pos_embed(self, pos_embed, H, W):
         target_dtype = pos_embed.dtype
-        pos_embed = pos_embed.float().reshape(
-            1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
-        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
-            reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
+        pos_embed = (
+            pos_embed.float()
+            .reshape(
+                1,
+                self.image_size // self.patch_size,
+                self.image_size // self.patch_size,
+                -1,
+            )
+            .permute(0, 3, 1, 2)
+        )
+        pos_embed = (
+            F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
+            .reshape(1, -1, H * W)
+            .permute(0, 2, 1)
+            .to(target_dtype)
+        )
         return pos_embed

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
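
Note (not from the diff): what _get_pos_embed does, shown standalone with toy sizes (a 24x24 patch grid resized to 16x16):

import torch
import torch.nn.functional as F

embed_dim, old_grid, H, W = 32, 24, 16, 16
pos_embed = torch.randn(1, old_grid * old_grid, embed_dim)

# Reshape tokens back to a 2-D grid, resize with bicubic interpolation,
# then flatten back to a token sequence.
pos_embed = pos_embed.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)
pos_embed = F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
pos_embed = pos_embed.reshape(1, embed_dim, H * W).permute(0, 2, 1)
assert pos_embed.shape == (1, H * W, embed_dim)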
@@ -172,14 +221,15 @@ class InternVisionEmbeddings(nn.Module):
         patch_embeds = self.patch_embedding(pixel_values)
         batch_size, _, height, width = patch_embeds.shape
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-        class_embeds = self.class_embedding.expand(
-            batch_size, 1, -1).to(target_dtype)
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-        position_embedding = torch.cat([
-            self.position_embedding[:, :1, :],
-            self._get_pos_embed(
-                self.position_embedding[:, 1:, :], height, width)
-        ], dim=1)
+        position_embedding = torch.cat(
+            [
+                self.position_embedding[:, :1, :],
+                self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
+            ],
+            dim=1,
+        )
         embeddings = embeddings + position_embedding.to(target_dtype)
         return embeddings

@@ -195,49 +245,54 @@ class InternAttention(nn.Module):
         self.use_flash_attn = config.use_flash_attn and has_flash_attn
         if config.use_flash_attn and not has_flash_attn:
             print(
-                'Warning: Flash Attention is not available, use_flash_attn is set to False.')
+                "Warning: Flash Attention is not available, use_flash_attn is set to False."
+            )
         self.head_dim = self.embed_dim // self.num_heads
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
-                f'embed_dim must be divisible by num_heads (got `embed_dim`: {
-                    self.embed_dim} and `num_heads`:'
-                f' {self.num_heads}).'
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
             )

-        self.scale = self.head_dim ** -0.5
-        self.qkv = nn.Linear(self.embed_dim, 3 *
-                             self.embed_dim, bias=config.qkv_bias)
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
         self.attn_drop = nn.Dropout(config.attention_dropout)
         self.proj_drop = nn.Dropout(config.dropout)

         self.qk_normalization = config.qk_normalization

         if self.qk_normalization:
-            self.q_norm = InternRMSNorm(
-                self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = InternRMSNorm(
-                self.embed_dim, eps=config.layer_norm_eps)
+            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

         if self.use_flash_attn:
-            self.inner_attn = FlashAttention(
-                attention_dropout=config.attention_dropout)
+            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
         self.proj = nn.Linear(self.embed_dim, self.embed_dim)

     def _naive_attn(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
-                                  self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
         # make torchscript happy (cannot use tensor as tuple)
         q, k, v = qkv.unbind(0)

         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
-            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)
-                            ).view(B_, N_, H_, D_).transpose(1, 2)
-            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)
-                            ).view(B_, N_, H_, D_).transpose(1, 2)
+            q = (
+                self.q_norm(q.transpose(1, 2).flatten(-2, -1))
+                .view(B_, N_, H_, D_)
+                .transpose(1, 2)
+            )
+            k = (
+                self.k_norm(k.transpose(1, 2).flatten(-2, -1))
+                .view(B_, N_, H_, D_)
+                .transpose(1, 2)
+            )

-        attn = ((q * self.scale) @ k.transpose(-2, -1))
+        attn = (q * self.scale) @ k.transpose(-2, -1)
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)

@@ -248,8 +303,9 @@ class InternAttention(nn.Module):

     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
         qkv = self.qkv(x)
-        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d',
-                        three=3, h=self.num_heads)
+        qkv = rearrange(
+            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
+        )

         if self.qk_normalization:
             q, k, v = qkv.unbind(2)
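
Note (toy example, not part of the commit): the core of _naive_attn above is a standard scaled dot-product over per-head tensors; with made-up sizes:

import torch

B, H, N, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
scale = D ** -0.5

attn = (q * scale) @ k.transpose(-2, -1)              # (B, H, N, N) attention logits
attn = attn.softmax(dim=-1)                           # each row sums to 1
x = (attn @ v).transpose(1, 2).reshape(B, N, H * D)   # merge heads back to (B, N, C)
assert x.shape == (B, N, H * D)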
@@ -258,15 +314,21 @@ class InternAttention(nn.Module):
             qkv = torch.stack([q, k, v], dim=2)

         context, _ = self.inner_attn(
-            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
+            qkv,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            causal=False,
         )
-        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
+        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
         outs = self.proj_drop(outs)
         return outs

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._naive_attn(
-            hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+        x = (
+            self._naive_attn(hidden_states)
+            if not self.use_flash_attn
+            else self._flash_attn(hidden_states)
+        )
         return x


@@ -294,33 +356,37 @@ class InternVisionEncoderLayer(nn.Module):

         self.attn = InternAttention(config)
         self.mlp = InternMLP(config)
-        self.norm1 = NORM2FN[self.norm_type](
-            self.embed_dim, eps=config.layer_norm_eps)
-        self.norm2 = NORM2FN[self.norm_type](
-            self.embed_dim, eps=config.layer_norm_eps)
-
-        self.ls1 = nn.Parameter(
-            config.initializer_factor * torch.ones(self.embed_dim))
-        self.ls2 = nn.Parameter(
-            config.initializer_factor * torch.ones(self.embed_dim))
-        self.drop_path1 = DropPath(
-            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-        self.drop_path2 = DropPath(
-            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+
+        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+        self.drop_path1 = (
+            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )
+        self.drop_path2 = (
+            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Tuple[torch.FloatTensor]],
+    ]:
         """
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
-        hidden_states = hidden_states + \
-            self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
+        hidden_states = hidden_states + self.drop_path1(
+            self.attn(self.norm1(hidden_states)) * self.ls1
+        )

-        hidden_states = hidden_states + \
-            self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+        hidden_states = hidden_states + self.drop_path2(
+            self.mlp(self.norm2(hidden_states)) * self.ls2
+        )

         return hidden_states

@@ -339,17 +405,23 @@ class InternVisionEncoder(nn.Module):
         super().__init__()
         self.config = config
         # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(
-            0, config.drop_path_rate, config.num_hidden_layers)]
-        self.layers = nn.ModuleList([
-            InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
+        ]
+        self.layers = nn.ModuleList(
+            [
+                InternVisionEncoderLayer(config, dpr[idx])
+                for idx in range(config.num_hidden_layers)
+            ]
+        )
         self.gradient_checkpointing = True

     def forward(
-        self,
-        inputs_embeds,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        inputs_embeds,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
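
Note (illustrative values only, not from any real config): the stochastic depth decay rule above spreads drop_path_rate linearly across layers, for example:

import torch

drop_path_rate, num_hidden_layers = 0.1, 5  # made-up config values
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_hidden_layers)]
print([round(r, 3) for r in dpr])  # [0.0, 0.025, 0.05, 0.075, 0.1]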
@@ -362,9 +434,13 @@ class InternVisionEncoder(nn.Module):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
@@ -374,8 +450,8 @@ class InternVisionEncoder(nn.Module):
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    encoder_layer,
-                    hidden_states)
+                    encoder_layer, hidden_states
+                )
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
@@ -393,9 +469,9 @@ class InternVisionModel(PreTrainedModel):


 class InternVisionModel(PreTrainedModel):
-    main_input_name = 'pixel_values'
+    main_input_name = "pixel_values"
     config_class = InternVisionConfig
-    _no_split_modules = ['InternVisionEncoderLayer']
+    _no_split_modules = ["InternVisionEncoderLayer"]

     def __init__(self, config: InternVisionConfig):
         super().__init__(config)
@@ -408,36 +484,46 @@ class InternVisionModel(PreTrainedModel):
         pos_emb = self.embeddings.position_embedding
         _, num_positions, embed_dim = pos_emb.shape
         cls_emb = pos_emb[:, :1, :]
-        pos_emb = pos_emb[:, 1:, :].reshape(
-            1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
-        pos_emb = F.interpolate(pos_emb.float(
-        ), size=new_size // patch_size, mode='bicubic', align_corners=False)
-        pos_emb = pos_emb.to(cls_emb.dtype).reshape(
-            1, embed_dim, -1).permute(0, 2, 1)
+        pos_emb = (
+            pos_emb[:, 1:, :]
+            .reshape(1, old_size // patch_size, old_size // patch_size, -1)
+            .permute(0, 3, 1, 2)
+        )
+        pos_emb = F.interpolate(
+            pos_emb.float(),
+            size=new_size // patch_size,
+            mode="bicubic",
+            align_corners=False,
+        )
+        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
         pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
         self.embeddings.position_embedding = nn.Parameter(pos_emb)
         self.embeddings.image_size = new_size
-        logger.info('Resized position embeddings from {} to {}'.format(
-            old_size, new_size))
+        logger.info(
+            "Resized position embeddings from {} to {}".format(old_size, new_size)
+        )

     def get_input_embeddings(self):
         return self.embeddings

     def forward(
-        self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        pixel_embeds: Optional[torch.FloatTensor] = None,
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        pixel_embeds: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if pixel_values is None and pixel_embeds is None:
-            raise ValueError(
-                'You have to specify pixel_values or pixel_embeds')
+            raise ValueError("You have to specify pixel_values or pixel_embeds")

         if pixel_embeds is not None:
             hidden_states = pixel_embeds
@@ -445,8 +531,7 @@ class InternVisionModel(PreTrainedModel):
             if len(pixel_values.shape) == 4:
                 hidden_states = self.embeddings(pixel_values)
             else:
-                raise ValueError(f'wrong pixel_values size: {
-                    pixel_values.shape}')
+                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             output_hidden_states=output_hidden_states,
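
Note: a hypothetical usage sketch for the vision model defined in this file. The repository id and the 448x448 input size are placeholders (this page does not specify them), and loading custom modeling code requires trust_remote_code=True.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "your-namespace/your-internvit-checkpoint",  # placeholder repo id
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

pixel_values = torch.randn(1, 3, 448, 448, dtype=torch.bfloat16)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, return_dict=True)
print(outputs.last_hidden_state.shape)  # (1, 1 + num_patches, hidden_size)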
 