radna committed on
Commit
4af4b92
1 Parent(s): a28e379

Update modeling_intern_vit.py

Files changed (1)
  1. modeling_intern_vit.py +267 -59
modeling_intern_vit.py CHANGED
@@ -12,24 +12,125 @@ from einops import rearrange
12
  from timm.models.layers import DropPath
13
  from torch import nn
14
  from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import (BaseModelOutput,
16
- BaseModelOutputWithPooling)
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import logging
19
 
20
  from .configuration_intern_vit import InternVisionConfig
21
 
 
22
  try:
23
- from .flash_attention import FlashAttention
24
  has_flash_attn = True
25
  except:
26
- print('FlashAttention is not installed.')
27
  has_flash_attn = False
28
 
29
-
30
  logger = logging.get_logger(__name__)
31
 
32
33
  class InternRMSNorm(nn.Module):
34
  def __init__(self, hidden_size, eps=1e-6):
35
  super().__init__()
@@ -49,15 +150,25 @@ try:
49
 
50
  InternRMSNorm = FusedRMSNorm # noqa
51
 
52
- logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
  except ImportError:
54
  # using the normal InternRMSNorm
55
  pass
56
  except Exception:
57
- logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
  pass
59
 
60
61
  class InternVisionEmbeddings(nn.Module):
62
  def __init__(self, config: InternVisionConfig):
63
  super().__init__()
@@ -71,22 +182,55 @@ class InternVisionEmbeddings(nn.Module):
71
  )
72
 
73
  self.patch_embedding = nn.Conv2d(
74
- in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
75
  )
76
 
77
  self.num_patches = (self.image_size // self.patch_size) ** 2
78
  self.num_positions = self.num_patches + 1
79
 
80
- self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
81
 
82
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
83
- batch_size = pixel_values.shape[0]
84
  target_dtype = self.patch_embedding.weight.dtype
85
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
86
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
87
  class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
88
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
89
- embeddings = embeddings + self.position_embedding.to(target_dtype)
90
  return embeddings
91
 
92
 
@@ -100,15 +244,17 @@ class InternAttention(nn.Module):
100
  self.num_heads = config.num_attention_heads
101
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
102
  if config.use_flash_attn and not has_flash_attn:
103
- print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
104
  self.head_dim = self.embed_dim // self.num_heads
105
  if self.head_dim * self.num_heads != self.embed_dim:
106
  raise ValueError(
107
- f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
108
- f' {self.num_heads}).'
109
  )
110
 
111
- self.scale = self.head_dim ** -0.5
112
  self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
113
  self.attn_drop = nn.Dropout(config.attention_dropout)
114
  self.proj_drop = nn.Dropout(config.dropout)
@@ -125,15 +271,28 @@ class InternAttention(nn.Module):
125
 
126
  def _naive_attn(self, x):
127
  B, N, C = x.shape
128
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
129
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
130
 
131
  if self.qk_normalization:
132
  B_, H_, N_, D_ = q.shape
133
- q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
134
- k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
135
 
136
- attn = ((q * self.scale) @ k.transpose(-2, -1))
137
  attn = attn.softmax(dim=-1)
138
  attn = self.attn_drop(attn)
139
 
@@ -144,7 +303,9 @@ class InternAttention(nn.Module):
144
 
145
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
146
  qkv = self.qkv(x)
147
- qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
148
 
149
  if self.qk_normalization:
150
  q, k, v = qkv.unbind(2)
@@ -153,14 +314,21 @@ class InternAttention(nn.Module):
153
  qkv = torch.stack([q, k, v], dim=2)
154
 
155
  context, _ = self.inner_attn(
156
- qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
157
  )
158
- outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
159
  outs = self.proj_drop(outs)
160
  return outs
161
 
162
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
163
- x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
164
  return x
165
 
166
 
@@ -184,28 +352,41 @@ class InternVisionEncoderLayer(nn.Module):
184
  super().__init__()
185
  self.embed_dim = config.hidden_size
186
  self.intermediate_size = config.intermediate_size
 
187
 
188
  self.attn = InternAttention(config)
189
  self.mlp = InternMLP(config)
190
- self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
191
- self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
192
 
193
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
194
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
195
- self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
196
- self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
197
 
198
  def forward(
199
- self,
200
- hidden_states: torch.Tensor,
201
- ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
202
  """
203
  Args:
204
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
205
  """
206
- hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
207
 
208
- hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
209
 
210
  return hidden_states
211
 
@@ -224,16 +405,23 @@ class InternVisionEncoder(nn.Module):
224
  super().__init__()
225
  self.config = config
226
  # stochastic depth decay rule
227
- dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
228
- self.layers = nn.ModuleList([
229
- InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
230
  self.gradient_checkpointing = True
231
 
232
  def forward(
233
- self,
234
- inputs_embeds,
235
- output_hidden_states: Optional[bool] = None,
236
- return_dict: Optional[bool] = None,
237
  ) -> Union[Tuple, BaseModelOutput]:
238
  r"""
239
  Args:
@@ -246,9 +434,13 @@ class InternVisionEncoder(nn.Module):
246
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
  """
248
  output_hidden_states = (
249
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
250
  )
251
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
252
 
253
  encoder_states = () if output_hidden_states else None
254
  hidden_states = inputs_embeds
@@ -258,8 +450,8 @@ class InternVisionEncoder(nn.Module):
258
  encoder_states = encoder_states + (hidden_states,)
259
  if self.gradient_checkpointing and self.training:
260
  layer_outputs = torch.utils.checkpoint.checkpoint(
261
- encoder_layer,
262
- hidden_states)
263
  else:
264
  layer_outputs = encoder_layer(
265
  hidden_states,
@@ -277,9 +469,9 @@ class InternVisionEncoder(nn.Module):
277
 
278
 
279
  class InternVisionModel(PreTrainedModel):
280
- main_input_name = 'pixel_values'
281
  config_class = InternVisionConfig
282
- _no_split_modules = ['InternVisionEncoderLayer']
283
 
284
  def __init__(self, config: InternVisionConfig):
285
  super().__init__(config)
@@ -292,30 +484,46 @@ class InternVisionModel(PreTrainedModel):
292
  pos_emb = self.embeddings.position_embedding
293
  _, num_positions, embed_dim = pos_emb.shape
294
  cls_emb = pos_emb[:, :1, :]
295
- pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
296
- pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
297
  pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
298
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
299
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
300
- logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
301
 
302
  def get_input_embeddings(self):
303
  return self.embeddings
304
 
305
  def forward(
306
- self,
307
- pixel_values: Optional[torch.FloatTensor] = None,
308
- output_hidden_states: Optional[bool] = None,
309
- return_dict: Optional[bool] = None,
310
- pixel_embeds: Optional[torch.FloatTensor] = None,
311
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
312
  output_hidden_states = (
313
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
314
  )
315
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
316
 
317
  if pixel_values is None and pixel_embeds is None:
318
- raise ValueError('You have to specify pixel_values or pixel_embeds')
319
 
320
  if pixel_embeds is not None:
321
  hidden_states = pixel_embeds
@@ -323,7 +531,7 @@ class InternVisionModel(PreTrainedModel):
323
  if len(pixel_values.shape) == 4:
324
  hidden_states = self.embeddings(pixel_values)
325
  else:
326
- raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
327
  encoder_outputs = self.encoder(
328
  inputs_embeds=hidden_states,
329
  output_hidden_states=output_hidden_states,
 
12
  from timm.models.layers import DropPath
13
  from torch import nn
14
  from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.utils import logging
18
 
19
  from .configuration_intern_vit import InternVisionConfig
20
 
21
+
22
  try:
23
+ from triton_flash_atn import _attention
24
+
25
+ from triton_bert_pading import pad_input, unpad_input
26
+
27
  has_flash_attn = True
28
  except:
29
+ print("FlashAttention is not installed.")
30
  has_flash_attn = False
31
 
 
32
  logger = logging.get_logger(__name__)
33
 
34
 
35
+ class FlashAttention(nn.Module):
36
+ """Implement the scaled dot product attention with softmax.
37
+ Arguments
38
+ ---------
39
+ softmax_scale: The temperature to use for the softmax attention.
40
+ (default: 1/sqrt(d_keys) where d_keys is computed at
41
+ runtime)
42
+ attention_dropout: The dropout rate to apply to the attention
43
+ (default: 0.0)
44
+ """
45
+
46
+ def __init__(
47
+ self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
48
+ ):
49
+ super().__init__()
50
+ self.softmax_scale = softmax_scale
51
+ self.dropout_p = attention_dropout
52
+
53
+ def forward(
54
+ self,
55
+ qkv,
56
+ key_padding_mask=None,
57
+ causal=False,
58
+ cu_seqlens=None,
59
+ max_s=None,
60
+ need_weights=False,
61
+ ):
62
+ """Implements the multihead softmax attention.
63
+ Arguments
64
+ ---------
65
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
66
+ if unpadded: (nnz, 3, h, d)
67
+ key_padding_mask: a bool tensor of shape (B, S)
68
+ """
69
+ assert not need_weights
70
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
71
+ assert qkv.is_cuda
72
+
73
+ if cu_seqlens is None:
74
+ batch_size = qkv.shape[0]
75
+ seqlen = qkv.shape[1]
76
+ if key_padding_mask is None:
77
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
78
+ max_s = seqlen
79
+ cu_seqlens = torch.arange(
80
+ 0,
81
+ (batch_size + 1) * seqlen,
82
+ step=seqlen,
83
+ dtype=torch.int32,
84
+ device=qkv.device,
85
+ )
86
+ output = _attention.apply(
87
+ qkv,
88
+ cu_seqlens,
89
+ max_s,
90
+ self.dropout_p if self.training else 0.0,
91
+ sm_scale=self.softmax_scale,
92
+ causal=causal,
93
+ )
94
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
95
+ else:
96
+ nheads = qkv.shape[-2]
97
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
98
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
99
+ x_unpad = rearrange(
100
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
101
+ )
102
+ output_unpad = _attention.apply(
103
+ x_unpad,
104
+ cu_seqlens,
105
+ max_s,
106
+ self.dropout_p if self.training else 0.0,
107
+ sm_scale=self.softmax_scale,
108
+ causal=causal,
109
+ )
110
+ output = rearrange(
111
+ pad_input(
112
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"),
113
+ indices,
114
+ batch_size,
115
+ seqlen,
116
+ ),
117
+ "b s (h d) -> b s h d",
118
+ h=nheads,
119
+ )
120
+ else:
121
+ assert max_s is not None
122
+ output = _attention.apply(
123
+ qkv,
124
+ cu_seqlens,
125
+ max_s,
126
+ self.dropout_p if self.training else 0.0,
127
+ sm_scale=self.softmax_scale,
128
+ causal=causal,
129
+ )
130
+
131
+ return output, None
132
+
133
+
134
  class InternRMSNorm(nn.Module):
135
  def __init__(self, hidden_size, eps=1e-6):
136
  super().__init__()
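The try/except above no longer imports a separate `.flash_attention` module; instead the `FlashAttention` wrapper defined here drives the Triton kernels from `triton_flash_atn` and `triton_bert_pading` directly. Below is a minimal sketch of how the wrapper is called with a packed qkv tensor, assuming the class and `has_flash_attn` are in scope (i.e. inside this module) and that the Triton kernels are installed; all sizes are illustrative.

import torch
from einops import rearrange

B, S, H, D = 2, 1025, 16, 64  # illustrative batch, sequence length, heads, head_dim

if has_flash_attn and torch.cuda.is_available():
    attn = FlashAttention(softmax_scale=D ** -0.5, attention_dropout=0.0)
    # Packed qkv of shape (B, S, 3, H, D); the kernel requires fp16/bf16 tensors on CUDA.
    qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16, device="cuda")
    # key_padding_mask=None and cu_seqlens=None select the dense (padded) code path.
    context, _ = attn(qkv, key_padding_mask=None, causal=False)
    out = rearrange(context, "b s h d -> b s (h d)")  # (B, S, H*D), as in InternAttention._flash_attn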
 
150
 
151
  InternRMSNorm = FusedRMSNorm # noqa
152
 
153
+ logger.info(
154
+ "Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm"
155
+ )
156
  except ImportError:
157
  # using the normal InternRMSNorm
158
  pass
159
  except Exception:
160
+ logger.warning(
161
+ "discovered apex but it failed to load, falling back to InternRMSNorm"
162
+ )
163
  pass
164
 
165
 
166
+ NORM2FN = {
167
+ "rms_norm": InternRMSNorm,
168
+ "layer_norm": nn.LayerNorm,
169
+ }
170
+
171
+
172
  class InternVisionEmbeddings(nn.Module):
173
  def __init__(self, config: InternVisionConfig):
174
  super().__init__()
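The new `NORM2FN` table lets `config.norm_type` pick the normalization used by the encoder layers, either `InternRMSNorm` or `nn.LayerNorm`. A small illustration of the lookup, assuming `NORM2FN` is in scope; the sizes are illustrative.

import torch

hidden = torch.randn(2, 1025, 1024)              # (batch, tokens, hidden_size)
norm = NORM2FN["rms_norm"](1024, eps=1e-6)       # "layer_norm" would return nn.LayerNorm instead
print(norm(hidden).shape)                        # torch.Size([2, 1025, 1024])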
 
182
  )
183
 
184
  self.patch_embedding = nn.Conv2d(
185
+ in_channels=3,
186
+ out_channels=self.embed_dim,
187
+ kernel_size=self.patch_size,
188
+ stride=self.patch_size,
189
  )
190
 
191
  self.num_patches = (self.image_size // self.patch_size) ** 2
192
  self.num_positions = self.num_patches + 1
193
 
194
+ self.position_embedding = nn.Parameter(
195
+ torch.randn(1, self.num_positions, self.embed_dim)
196
+ )
197
+
198
+ def _get_pos_embed(self, pos_embed, H, W):
199
+ target_dtype = pos_embed.dtype
200
+ pos_embed = (
201
+ pos_embed.float()
202
+ .reshape(
203
+ 1,
204
+ self.image_size // self.patch_size,
205
+ self.image_size // self.patch_size,
206
+ -1,
207
+ )
208
+ .permute(0, 3, 1, 2)
209
+ )
210
+ pos_embed = (
211
+ F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
212
+ .reshape(1, -1, H * W)
213
+ .permute(0, 2, 1)
214
+ .to(target_dtype)
215
+ )
216
+ return pos_embed
217
 
218
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
 
219
  target_dtype = self.patch_embedding.weight.dtype
220
+ # shape = [*, channel, width, height]
221
+ patch_embeds = self.patch_embedding(pixel_values)
222
+ batch_size, _, height, width = patch_embeds.shape
223
  patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
224
  class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
225
  embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
226
+ position_embedding = torch.cat(
227
+ [
228
+ self.position_embedding[:, :1, :],
229
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
230
+ ],
231
+ dim=1,
232
+ )
233
+ embeddings = embeddings + position_embedding.to(target_dtype)
234
  return embeddings
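With `_get_pos_embed`, the position table is no longer added verbatim: the learned grid (everything after the class token) is reshaped to 2-D, bicubically resampled to the actual patch grid of the input, and re-concatenated with the class-token slot. A standalone sketch of that resampling step with illustrative sizes:

import torch
import torch.nn.functional as F

embed_dim, old_grid = 1024, 16        # pretrained 16x16 patch grid (illustrative)
new_h, new_w = 24, 32                 # patch grid of the incoming image

pos = torch.randn(1, old_grid * old_grid, embed_dim)  # stands in for position_embedding[:, 1:, :]
pos_2d = pos.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)      # (1, C, 16, 16)
pos_2d = F.interpolate(pos_2d.float(), size=(new_h, new_w), mode="bicubic", align_corners=False)
pos_new = pos_2d.reshape(1, embed_dim, new_h * new_w).permute(0, 2, 1)          # (1, new_h*new_w, C)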
235
 
236
 
 
244
  self.num_heads = config.num_attention_heads
245
  self.use_flash_attn = config.use_flash_attn and has_flash_attn
246
  if config.use_flash_attn and not has_flash_attn:
247
+ print(
248
+ "Warning: Flash Attention is not available, use_flash_attn is set to False."
249
+ )
250
  self.head_dim = self.embed_dim // self.num_heads
251
  if self.head_dim * self.num_heads != self.embed_dim:
252
  raise ValueError(
253
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
254
+ f" {self.num_heads})."
255
  )
256
 
257
+ self.scale = self.head_dim**-0.5
258
  self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
259
  self.attn_drop = nn.Dropout(config.attention_dropout)
260
  self.proj_drop = nn.Dropout(config.dropout)
 
271
 
272
  def _naive_attn(self, x):
273
  B, N, C = x.shape
274
+ qkv = (
275
+ self.qkv(x)
276
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
277
+ .permute(2, 0, 3, 1, 4)
278
+ )
279
+ # make torchscript happy (cannot use tensor as tuple)
280
+ q, k, v = qkv.unbind(0)
281
 
282
  if self.qk_normalization:
283
  B_, H_, N_, D_ = q.shape
284
+ q = (
285
+ self.q_norm(q.transpose(1, 2).flatten(-2, -1))
286
+ .view(B_, N_, H_, D_)
287
+ .transpose(1, 2)
288
+ )
289
+ k = (
290
+ self.k_norm(k.transpose(1, 2).flatten(-2, -1))
291
+ .view(B_, N_, H_, D_)
292
+ .transpose(1, 2)
293
+ )
294
 
295
+ attn = (q * self.scale) @ k.transpose(-2, -1)
296
  attn = attn.softmax(dim=-1)
297
  attn = self.attn_drop(attn)
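Most of the `_naive_attn` hunk is a formatting reflow, but it shows where QK-normalization happens: q and k are flattened across heads per token, normalized over that combined dimension, and reshaped back before the scaled dot product. A minimal sketch of the reshaping; the weight-free `rms_norm` helper below is an illustrative stand-in for `InternRMSNorm`.

import torch

def rms_norm(x, eps=1e-6):
    # Stand-in for InternRMSNorm without the learned weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

B, H, N, D = 2, 16, 1025, 64                      # illustrative sizes
q = torch.randn(B, H, N, D)

q_flat = q.transpose(1, 2).flatten(-2, -1)        # (B, N, H*D): all heads of one token together
q = rms_norm(q_flat).view(B, N, H, D).transpose(1, 2)   # back to (B, H, N, D) for attention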
298
 
 
303
 
304
  def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
305
  qkv = self.qkv(x)
306
+ qkv = rearrange(
307
+ qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
308
+ )
309
 
310
  if self.qk_normalization:
311
  q, k, v = qkv.unbind(2)
 
314
  qkv = torch.stack([q, k, v], dim=2)
315
 
316
  context, _ = self.inner_attn(
317
+ qkv,
318
+ key_padding_mask=key_padding_mask,
319
+ need_weights=need_weights,
320
+ causal=False,
321
  )
322
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
323
  outs = self.proj_drop(outs)
324
  return outs
325
 
326
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
327
+ x = (
328
+ self._naive_attn(hidden_states)
329
+ if not self.use_flash_attn
330
+ else self._flash_attn(hidden_states)
331
+ )
332
  return x
333
 
334
 
 
352
  super().__init__()
353
  self.embed_dim = config.hidden_size
354
  self.intermediate_size = config.intermediate_size
355
+ self.norm_type = config.norm_type
356
 
357
  self.attn = InternAttention(config)
358
  self.mlp = InternMLP(config)
359
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
360
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
361
 
362
  self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
363
  self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
364
+ self.drop_path1 = (
365
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
366
+ )
367
+ self.drop_path2 = (
368
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
369
+ )
370
 
371
  def forward(
372
+ self,
373
+ hidden_states: torch.Tensor,
374
+ ) -> Tuple[
375
+ torch.FloatTensor,
376
+ Optional[torch.FloatTensor],
377
+ Optional[Tuple[torch.FloatTensor]],
378
+ ]:
379
  """
380
  Args:
381
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
382
  """
383
+ hidden_states = hidden_states + self.drop_path1(
384
+ self.attn(self.norm1(hidden_states)) * self.ls1
385
+ )
386
 
387
+ hidden_states = hidden_states + self.drop_path2(
388
+ self.mlp(self.norm2(hidden_states)) * self.ls2
389
+ )
390
 
391
  return hidden_states
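The encoder layer keeps its pre-norm residual structure, now spelled out over several lines: each branch is normalized with `NORM2FN[config.norm_type]`, scaled element-wise by the learnable layer-scale vector `ls1`/`ls2`, and passed through stochastic depth (`DropPath`) before the residual addition. A compact sketch of one such residual step with plain stand-in modules and illustrative sizes:

import torch
from torch import nn
from timm.models.layers import DropPath

dim = 1024
norm = nn.LayerNorm(dim)                      # the real layer picks this via NORM2FN[config.norm_type]
branch = nn.Linear(dim, dim)                  # stand-in for the attention or MLP branch
ls = nn.Parameter(0.1 * torch.ones(dim))      # layer scale, like ls1 / ls2
drop_path = DropPath(0.1)                     # nn.Identity() when drop_path_rate == 0

x = torch.randn(2, 1025, dim)
x = x + drop_path(branch(norm(x)) * ls)       # same shape in and out: (2, 1025, 1024)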
392
 
 
405
  super().__init__()
406
  self.config = config
407
  # stochastic depth decay rule
408
+ dpr = [
409
+ x.item()
410
+ for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
411
+ ]
412
+ self.layers = nn.ModuleList(
413
+ [
414
+ InternVisionEncoderLayer(config, dpr[idx])
415
+ for idx in range(config.num_hidden_layers)
416
+ ]
417
+ )
418
  self.gradient_checkpointing = True
419
 
420
  def forward(
421
+ self,
422
+ inputs_embeds,
423
+ output_hidden_states: Optional[bool] = None,
424
+ return_dict: Optional[bool] = None,
425
  ) -> Union[Tuple, BaseModelOutput]:
426
  r"""
427
  Args:
 
434
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
435
  """
436
  output_hidden_states = (
437
+ output_hidden_states
438
+ if output_hidden_states is not None
439
+ else self.config.output_hidden_states
440
+ )
441
+ return_dict = (
442
+ return_dict if return_dict is not None else self.config.use_return_dict
443
  )
 
444
 
445
  encoder_states = () if output_hidden_states else None
446
  hidden_states = inputs_embeds
 
450
  encoder_states = encoder_states + (hidden_states,)
451
  if self.gradient_checkpointing and self.training:
452
  layer_outputs = torch.utils.checkpoint.checkpoint(
453
+ encoder_layer, hidden_states
454
+ )
455
  else:
456
  layer_outputs = encoder_layer(
457
  hidden_states,
 
469
 
470
 
471
  class InternVisionModel(PreTrainedModel):
472
+ main_input_name = "pixel_values"
473
  config_class = InternVisionConfig
474
+ _no_split_modules = ["InternVisionEncoderLayer"]
475
 
476
  def __init__(self, config: InternVisionConfig):
477
  super().__init__(config)
 
484
  pos_emb = self.embeddings.position_embedding
485
  _, num_positions, embed_dim = pos_emb.shape
486
  cls_emb = pos_emb[:, :1, :]
487
+ pos_emb = (
488
+ pos_emb[:, 1:, :]
489
+ .reshape(1, old_size // patch_size, old_size // patch_size, -1)
490
+ .permute(0, 3, 1, 2)
491
+ )
492
+ pos_emb = F.interpolate(
493
+ pos_emb.float(),
494
+ size=new_size // patch_size,
495
+ mode="bicubic",
496
+ align_corners=False,
497
+ )
498
  pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
499
  pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
500
  self.embeddings.position_embedding = nn.Parameter(pos_emb)
501
+ self.embeddings.image_size = new_size
502
+ logger.info(
503
+ "Resized position embeddings from {} to {}".format(old_size, new_size)
504
+ )
505
 
506
  def get_input_embeddings(self):
507
  return self.embeddings
508
 
509
  def forward(
510
+ self,
511
+ pixel_values: Optional[torch.FloatTensor] = None,
512
+ output_hidden_states: Optional[bool] = None,
513
+ return_dict: Optional[bool] = None,
514
+ pixel_embeds: Optional[torch.FloatTensor] = None,
515
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
516
  output_hidden_states = (
517
+ output_hidden_states
518
+ if output_hidden_states is not None
519
+ else self.config.output_hidden_states
520
+ )
521
+ return_dict = (
522
+ return_dict if return_dict is not None else self.config.use_return_dict
523
  )
 
524
 
525
  if pixel_values is None and pixel_embeds is None:
526
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
527
 
528
  if pixel_embeds is not None:
529
  hidden_states = pixel_embeds
 
531
  if len(pixel_values.shape) == 4:
532
  hidden_states = self.embeddings(pixel_values)
533
  else:
534
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
535
  encoder_outputs = self.encoder(
536
  inputs_embeds=hidden_states,
537
  output_hidden_states=output_hidden_states,
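Taken together, the changes leave the public interface of `InternVisionModel.forward` untouched: it still accepts either `pixel_values` or precomputed `pixel_embeds` and, with `return_dict=True`, returns a `BaseModelOutputWithPooling`. A minimal smoke test; `<internvl-vision-checkpoint>` is a placeholder for a repository that ships these files, and `trust_remote_code=True` is how transformers normally picks up `configuration_intern_vit.py` and `modeling_intern_vit.py`.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<internvl-vision-checkpoint>", trust_remote_code=True).eval()

size = model.config.image_size                 # the pretrained input resolution
pixel_values = torch.randn(1, 3, size, size).to(model.dtype)
with torch.no_grad():
    out = model(pixel_values=pixel_values, return_dict=True)
print(out.last_hidden_state.shape)             # (1, (size // patch_size) ** 2 + 1, hidden_size)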