radna committed
Commit 68c0edf
1 Parent(s): 456c716

Update modeling_intern_vit.py

Files changed (1)
  1. modeling_intern_vit.py +54 -126
modeling_intern_vit.py CHANGED
@@ -12,18 +12,18 @@ from einops import rearrange
 from timm.models.layers import DropPath
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_outputs import (BaseModelOutput,
+                                           BaseModelOutputWithPooling)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
 from .configuration_intern_vit import InternVisionConfig
 
 try:
-    from .triton_flash_attn import _attention
-
+    from .flash_attention import FlashAttention
     has_flash_attn = True
 except:
-    print("attention is not installed.")
+    print('FlashAttention is not installed.')
     has_flash_attn = False
 
@@ -49,16 +49,12 @@ try:
 
     InternRMSNorm = FusedRMSNorm  # noqa
 
-    logger.info(
-        "Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm"
-    )
+    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
 except ImportError:
     # using the normal InternRMSNorm
     pass
 except Exception:
-    logger.warning(
-        "discovered apex but it failed to load, falling back to InternRMSNorm"
-    )
+    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
     pass
 
@@ -75,25 +71,18 @@ class InternVisionEmbeddings(nn.Module):
         )
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3,
-            out_channels=self.embed_dim,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
+            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
 
-        self.position_embedding = nn.Parameter(
-            torch.randn(1, self.num_positions, self.embed_dim)
-        )
+        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(
-            pixel_values
-        )  # shape = [*, width, grid, grid]
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
         class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
         embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
@@ -111,17 +100,15 @@ class InternAttention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.use_flash_attn = config.use_flash_attn and has_flash_attn
         if config.use_flash_attn and not has_flash_attn:
-            print(
-                "Warning: Flash Attention is not available, use_flash_attn is set to False."
-            )
+            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
         self.head_dim = self.embed_dim // self.num_heads
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads})."
+                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                f' {self.num_heads}).'
             )
 
-        self.scale = self.head_dim**-0.5
+        self.scale = self.head_dim ** -0.5
         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
         self.attn_drop = nn.Dropout(config.attention_dropout)
         self.proj_drop = nn.Dropout(config.dropout)
@@ -133,32 +120,20 @@ class InternAttention(nn.Module):
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
         if self.use_flash_attn:
-            self.inner_attn = _attention.apply(attention_dropout=config.attention_dropout)
+            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
         self.proj = nn.Linear(self.embed_dim, self.embed_dim)
 
     def _naive_attn(self, x):
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
-            q = (
-                self.q_norm(q.transpose(1, 2).flatten(-2, -1))
-                .view(B_, N_, H_, D_)
-                .transpose(1, 2)
-            )
-            k = (
-                self.k_norm(k.transpose(1, 2).flatten(-2, -1))
-                .view(B_, N_, H_, D_)
-                .transpose(1, 2)
-            )
+            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
-        attn = (q * self.scale) @ k.transpose(-2, -1)
+        attn = ((q * self.scale) @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
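As a quick standalone check of the shape flow in `_naive_attn` (dummy sizes, not part of the commit):

    import torch

    B, N, num_heads, head_dim = 2, 197, 16, 64           # illustrative sizes, not the model config
    C = num_heads * head_dim
    qkv = torch.randn(B, N, 3 * C)                       # output of the fused qkv nn.Linear
    qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)                              # each (B, num_heads, N, head_dim)
    attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)  # (B, num_heads, N, N)
    out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
    print(q.shape, attn.shape, out.shape)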
@@ -169,9 +144,7 @@ class InternAttention(nn.Module):
 
     def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
         qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
+        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
 
         if self.qk_normalization:
             q, k, v = qkv.unbind(2)
@@ -180,21 +153,14 @@ class InternAttention(nn.Module):
             qkv = torch.stack([q, k, v], dim=2)
 
         context, _ = self.inner_attn(
-            qkv,
-            key_padding_mask=key_padding_mask,
-            need_weights=need_weights,
-            causal=False,
+            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
         )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
+        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
         outs = self.proj_drop(outs)
         return outs
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = (
-            self._naive_attn(hidden_states)
-            if not self.use_flash_attn
-            else self._flash_attn(hidden_states)
-        )
+        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
         return x
 
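The flash path packs q, k and v into a single (batch, seq, 3, heads, head_dim) tensor before handing it to the repo's optional `FlashAttention` module; the keyword arguments below are the ones used in the hunk above. A minimal shape sketch with dummy sizes, leaving the `FlashAttention` call itself as a comment since that module may not be installed:

    import torch
    from einops import rearrange

    b, s, h, d = 2, 197, 16, 64                          # illustrative sizes only
    qkv = torch.randn(b, s, 3 * h * d)                   # fused qkv nn.Linear output
    qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=h)
    print(qkv.shape)                                     # torch.Size([2, 197, 3, 16, 64])
    # context, _ = self.inner_attn(qkv, key_padding_mask=None, need_weights=False, causal=False)
    # outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))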
@@ -226,32 +192,20 @@ class InternVisionEncoderLayer(nn.Module):
 
         self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
         self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
-        self.drop_path1 = (
-            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
-        )
-        self.drop_path2 = (
-            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
-        )
+        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> Tuple[
-        torch.FloatTensor,
-        Optional[torch.FloatTensor],
-        Optional[Tuple[torch.FloatTensor]],
-    ]:
+            self,
+            hidden_states: torch.Tensor,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
-        hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states)) * self.ls1
-        )
+        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
 
-        hidden_states = hidden_states + self.drop_path2(
-            self.mlp(self.norm2(hidden_states)) * self.ls2
-        )
+        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
 
         return hidden_states
 
@@ -270,23 +224,16 @@ class InternVisionEncoder(nn.Module):
         super().__init__()
         self.config = config
         # stochastic depth decay rule
-        dpr = [
-            x.item()
-            for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
-        ]
-        self.layers = nn.ModuleList(
-            [
-                InternVisionEncoderLayer(config, dpr[idx])
-                for idx in range(config.num_hidden_layers)
-            ]
-        )
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+        self.layers = nn.ModuleList([
+            InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
         self.gradient_checkpointing = True
 
     def forward(
-        self,
-        inputs_embeds,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+            self,
+            inputs_embeds,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
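The stochastic depth decay rule in this hunk spreads `drop_path_rate` linearly across the hidden layers, so deeper blocks are dropped more often. A quick check with example values (not the model's actual config):

    import torch

    drop_path_rate, num_hidden_layers = 0.1, 4           # example values only
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_hidden_layers)]
    print(dpr)                                           # [0.0, 0.033..., 0.066..., 0.1]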
@@ -299,13 +246,9 @@ class InternVisionEncoder(nn.Module):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
@@ -315,8 +258,8 @@ class InternVisionEncoder(nn.Module):
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    encoder_layer, hidden_states
-                )
+                    encoder_layer,
+                    hidden_states)
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
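`torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)` recomputes each encoder layer's activations during backward instead of storing them, trading compute for memory. A generic, self-contained sketch of the same API (using a stand-in `nn.Linear`, unrelated to this model):

    import torch
    from torch import nn

    layer = nn.Linear(16, 16)                            # stand-in for an encoder layer
    x = torch.randn(2, 4, 16, requires_grad=True)
    y = torch.utils.checkpoint.checkpoint(layer, x)      # forward is re-run during backward
    y.sum().backward()
    print(x.grad.shape)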
@@ -334,9 +277,9 @@ class InternVisionEncoder(nn.Module):
 
 
 class InternVisionModel(PreTrainedModel):
-    main_input_name = "pixel_values"
+    main_input_name = 'pixel_values'
     config_class = InternVisionConfig
-    _no_split_modules = ["InternVisionEncoderLayer"]
+    _no_split_modules = ['InternVisionEncoderLayer']
 
     def __init__(self, config: InternVisionConfig):
         super().__init__(config)
@@ -349,45 +292,30 @@ class InternVisionModel(PreTrainedModel):
         pos_emb = self.embeddings.position_embedding
         _, num_positions, embed_dim = pos_emb.shape
         cls_emb = pos_emb[:, :1, :]
-        pos_emb = (
-            pos_emb[:, 1:, :]
-            .reshape(1, old_size // patch_size, old_size // patch_size, -1)
-            .permute(0, 3, 1, 2)
-        )
-        pos_emb = F.interpolate(
-            pos_emb.float(),
-            size=new_size // patch_size,
-            mode="bicubic",
-            align_corners=False,
-        )
+        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
         pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
         pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
         self.embeddings.position_embedding = nn.Parameter(pos_emb)
-        logger.info(
-            "Resized position embeddings from {} to {}".format(old_size, new_size)
-        )
+        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
 
     def get_input_embeddings(self):
         return self.embeddings
 
     def forward(
-        self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        pixel_embeds: Optional[torch.FloatTensor] = None,
+            self,
+            pixel_values: Optional[torch.FloatTensor] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            pixel_embeds: Optional[torch.FloatTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if pixel_values is None and pixel_embeds is None:
-            raise ValueError("You have to specify pixel_values or pixel_embeds")
+            raise ValueError('You have to specify pixel_values or pixel_embeds')
 
         if pixel_embeds is not None:
             hidden_states = pixel_embeds
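`resize_pos_embeddings` keeps the class-token embedding and bicubically interpolates the patch-grid positions to the new resolution. A standalone sketch with made-up sizes (the real method updates `self.embeddings.position_embedding` in place):

    import torch
    import torch.nn.functional as F

    embed_dim, patch_size, old_size, new_size = 32, 14, 224, 448   # illustrative values
    old_grid, new_grid = old_size // patch_size, new_size // patch_size
    pos_emb = torch.randn(1, old_grid * old_grid + 1, embed_dim)

    cls_emb, patch_pos = pos_emb[:, :1, :], pos_emb[:, 1:, :]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos.float(), size=new_grid, mode='bicubic', align_corners=False)
    patch_pos = patch_pos.reshape(1, embed_dim, -1).permute(0, 2, 1)
    print(torch.cat([cls_emb, patch_pos], dim=1).shape)            # torch.Size([1, 1025, 32])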
@@ -395,7 +323,7 @@ class InternVisionModel(PreTrainedModel):
             if len(pixel_values.shape) == 4:
                 hidden_states = self.embeddings(pixel_values)
             else:
-                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
+                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             output_hidden_states=output_hidden_states,
 