gent commited on
Commit
ed01e82
·
1 Parent(s): 18ecb91

Add model files

Browse files
Files changed (2) hide show
  1. configuration_sit.py +132 -0
  2. modeling_sit.py +921 -0
configuration_sit.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ ViT SiT model configuration"""
16
+
17
+ from transformers import PretrainedConfig
18
+ from transformers import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ VIT_SiT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "erow/vit-SiT-base": "https://huggingface.co/erow/SiT/resolve/main/config.json",
25
+ }
26
+
27
+
28
+ class ViTSiTConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`ViTSiTModel`]. It is used to instantiate an ViT
31
+ SiT model according to the specified arguments, defining the model architecture. Instantiating a configuration with
32
+ the defaults will yield a similar configuration to that of the ViT
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
51
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
53
+ The dropout ratio for the attention probabilities.
54
+ initializer_range (`float`, *optional*, defaults to 0.02):
55
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
56
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
57
+ The epsilon used by the layer normalization layers.
58
+ image_size (`int`, *optional*, defaults to 224):
59
+ The size (resolution) of each image.
60
+ patch_size (`int`, *optional*, defaults to 16):
61
+ The size (resolution) of each patch.
62
+ num_channels (`int`, *optional*, defaults to 3):
63
+ The number of input channels.
64
+ qkv_bias (`bool`, *optional*, defaults to `True`):
65
+ Whether to add a bias to the queries, keys and values.
66
+ decoder_num_attention_heads (`int`, *optional*, defaults to 16):
67
+ Number of attention heads for each attention layer in the decoder.
68
+ decoder_hidden_size (`int`, *optional*, defaults to 512):
69
+ Dimensionality of the decoder.
70
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
71
+ Number of hidden layers in the decoder.
72
+ decoder_intermediate_size (`int`, *optional*, defaults to 2048):
73
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
74
+ mask_ratio (`float`, *optional*, defaults to 0.75):
75
+ The ratio of the number of masked tokens in the input sequence.
76
+ norm_pix_loss (`bool`, *optional*, defaults to `False`):
77
+ Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
78
+ representation quality in the experiments of the authors.
79
+
80
+ Example:
81
+
82
+ ```python
83
+ >>> from transformers import ViTSiTConfig, ViTSiTModel
84
+
85
+ >>> # Initializing a ViT SiT vit-SiT-base style configuration
86
+ >>> configuration = ViTSiTConfig()
87
+
88
+ >>> # Initializing a model (with random weights) from the vit-SiT-base style configuration
89
+ >>> model = ViTSiTModel(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+
95
+ model_type = "vit_sit"
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_size=768,
100
+ out_dim = 256,
101
+ num_hidden_layers=12,
102
+ num_attention_heads=12,
103
+ intermediate_size=3072,
104
+ hidden_act="gelu",
105
+ hidden_dropout_prob=0.0,
106
+ attention_probs_dropout_prob=0.0,
107
+ initializer_range=0.02,
108
+ layer_norm_eps=1e-12,
109
+ image_size=224,
110
+ patch_size=16,
111
+ num_channels=3,
112
+ qkv_bias=True,
113
+ mask_ratio=0.75,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+
118
+ self.hidden_size = hidden_size
119
+ self.out_dim = out_dim
120
+ self.num_hidden_layers = num_hidden_layers
121
+ self.num_attention_heads = num_attention_heads
122
+ self.intermediate_size = intermediate_size
123
+ self.hidden_act = hidden_act
124
+ self.hidden_dropout_prob = hidden_dropout_prob
125
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
126
+ self.initializer_range = initializer_range
127
+ self.layer_norm_eps = layer_norm_eps
128
+ self.image_size = image_size
129
+ self.patch_size = patch_size
130
+ self.num_channels = num_channels
131
+ self.qkv_bias = qkv_bias
132
+ self.mask_ratio = mask_ratio
modeling_sit.py ADDED
@@ -0,0 +1,921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch ViT SiT (masked autoencoder) model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from copy import deepcopy
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Set, Tuple, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import BaseModelOutput
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
33
+ from transformers.utils import (
34
+ ModelOutput,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from configuration_sit import ViTSiTConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CONFIG_FOR_DOC = "ViTSiTConfig"
46
+ _CHECKPOINT_FOR_DOC = "erow/vit-sit-base"
47
+
48
+ VIT_SiT_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "erow/vit-sit-base",
50
+ # See all ViTSiT models at https://huggingface.co/models?filter=vit_sit
51
+ ]
52
+
53
+
54
+ @dataclass
55
+ class ViTSiTModelOutput(ModelOutput):
56
+ """
57
+ Class for ViTSiTModel's outputs, with potential hidden states and attentions.
58
+
59
+ Args:
60
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
61
+ Sequence of hidden-states at the output of the last layer of the model.
62
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
63
+ Tensor indicating which patches are masked (1) and which are not (0).
64
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
65
+ Tensor containing the original index of the (shuffled) masked patches.
66
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
67
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
68
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
69
+ plus the initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
73
+ the self-attention heads.
74
+ """
75
+
76
+ last_hidden_state: torch.FloatTensor = None
77
+ noise: torch.LongTensor = None
78
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
79
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
80
+
81
+
82
+
83
+ @dataclass
84
+ class ViTSiTForPreTrainingOutput(ModelOutput):
85
+ """
86
+ Class for ViTSiTForPreTraining's outputs, with potential hidden states and attentions.
87
+
88
+ Args:
89
+ loss (`torch.FloatTensor` of shape `(1,)`):
90
+ Pixel reconstruction loss.
91
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
92
+ Pixel reconstruction logits.
93
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
94
+ Tensor indicating which patches are masked (1) and which are not (0).
95
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
96
+ Tensor containing the original index of the (shuffled) masked patches.
97
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
98
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
99
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
100
+ plus the initial embedding outputs.
101
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
102
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
103
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
104
+ the self-attention heads.
105
+ """
106
+
107
+ loss: Optional[torch.FloatTensor] = None
108
+ logits: torch.FloatTensor = None
109
+ noise: torch.LongTensor = None
110
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
112
+
113
+
114
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
115
+ """
116
+ Create 2D sin/cos positional embeddings.
117
+
118
+ Args:
119
+ embed_dim (`int`):
120
+ Embedding dimension.
121
+ grid_size (`int`):
122
+ The grid height and width.
123
+ add_cls_token (`bool`, *optional*, defaults to `False`):
124
+ Whether or not to add a classification (CLS) token.
125
+
126
+ Returns:
127
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
128
+ position embeddings (with or without classification token)
129
+ """
130
+ grid_h = np.arange(grid_size, dtype=np.float32)
131
+ grid_w = np.arange(grid_size, dtype=np.float32)
132
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
133
+ grid = np.stack(grid, axis=0)
134
+
135
+ grid = grid.reshape([2, 1, grid_size, grid_size])
136
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
137
+ if add_cls_token:
138
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
139
+ return pos_embed
140
+
141
+
142
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
143
+ if embed_dim % 2 != 0:
144
+ raise ValueError("embed_dim must be even")
145
+
146
+ # use half of dimensions to encode grid_h
147
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
148
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
149
+
150
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
151
+ return emb
152
+
153
+
154
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
155
+ """
156
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
157
+ """
158
+ if embed_dim % 2 != 0:
159
+ raise ValueError("embed_dim must be even")
160
+
161
+ omega = np.arange(embed_dim // 2, dtype=float)
162
+ omega /= embed_dim / 2.0
163
+ omega = 1.0 / 10000**omega # (D/2,)
164
+
165
+ pos = pos.reshape(-1) # (M,)
166
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
167
+
168
+ emb_sin = np.sin(out) # (M, D/2)
169
+ emb_cos = np.cos(out) # (M, D/2)
170
+
171
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
172
+ return emb
173
+
174
+
175
+ class ViTSiTEmbeddings(nn.Module):
176
+ """
177
+ Construct the CLS token, position and patch embeddings.
178
+
179
+ """
180
+
181
+ def __init__(self, config):
182
+ super().__init__()
183
+
184
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
185
+ self.patch_embeddings = ViTSiTPatchEmbeddings(config)
186
+ self.num_patches = self.patch_embeddings.num_patches
187
+ # fixed sin-cos embedding
188
+ self.position_embeddings = nn.Parameter(
189
+ torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
190
+ )
191
+ self.config = config
192
+ self.initialize_weights()
193
+
194
+ def initialize_weights(self):
195
+ # initialize (and freeze) position embeddings by sin-cos embedding
196
+ pos_embed = get_2d_sincos_pos_embed(
197
+ self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
198
+ )
199
+ self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
200
+
201
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
202
+ w = self.patch_embeddings.projection.weight.data
203
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
204
+
205
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
206
+ torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
207
+
208
+
209
+ def forward(self, pixel_values, noise=None):
210
+ batch_size, num_channels, height, width = pixel_values.shape
211
+ embeddings = self.patch_embeddings(pixel_values)
212
+
213
+ # add position embeddings w/o cls token
214
+ embeddings = embeddings + self.position_embeddings[:, 1:, :]
215
+
216
+ # append cls token
217
+ cls_token = self.cls_token + self.position_embeddings[:, :1, :]
218
+ cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
219
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
220
+
221
+ return embeddings
222
+
223
+ class ViTSiTPatchEmbeddings(nn.Module):
224
+ """
225
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
226
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
227
+ Transformer.
228
+ """
229
+
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ image_size, patch_size = config.image_size, config.patch_size
233
+ num_channels, hidden_size = config.num_channels, config.hidden_size
234
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
235
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
236
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
237
+ self.image_size = image_size
238
+ self.patch_size = patch_size
239
+ self.num_channels = num_channels
240
+ self.num_patches = num_patches
241
+
242
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
243
+
244
+ def forward(self, pixel_values):
245
+ batch_size, num_channels, height, width = pixel_values.shape
246
+ if num_channels != self.num_channels:
247
+ raise ValueError(
248
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
249
+ )
250
+ if height != self.image_size[0] or width != self.image_size[1]:
251
+ raise ValueError(
252
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
253
+ )
254
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
255
+ return x
256
+
257
+
258
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTSiT
259
+ class ViTSiTSelfAttention(nn.Module):
260
+ def __init__(self, config: ViTSiTConfig) -> None:
261
+ super().__init__()
262
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
263
+ raise ValueError(
264
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
265
+ f"heads {config.num_attention_heads}."
266
+ )
267
+
268
+ self.num_attention_heads = config.num_attention_heads
269
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
270
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
271
+
272
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
273
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
274
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
275
+
276
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
277
+
278
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
279
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
280
+ x = x.view(new_x_shape)
281
+ return x.permute(0, 2, 1, 3)
282
+
283
+ def forward(
284
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
285
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
286
+ mixed_query_layer = self.query(hidden_states)
287
+
288
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
289
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
290
+ query_layer = self.transpose_for_scores(mixed_query_layer)
291
+
292
+ # Take the dot product between "query" and "key" to get the raw attention scores.
293
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
294
+
295
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
296
+
297
+ # Normalize the attention scores to probabilities.
298
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
299
+
300
+ # This is actually dropping out entire tokens to attend to, which might
301
+ # seem a bit unusual, but is taken from the original Transformer paper.
302
+ attention_probs = self.dropout(attention_probs)
303
+
304
+ # Mask heads if we want to
305
+ if head_mask is not None:
306
+ attention_probs = attention_probs * head_mask
307
+
308
+ context_layer = torch.matmul(attention_probs, value_layer)
309
+
310
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
311
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
312
+ context_layer = context_layer.view(new_context_layer_shape)
313
+
314
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
315
+
316
+ return outputs
317
+
318
+
319
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTSiT
320
+ class ViTSiTSelfOutput(nn.Module):
321
+ """
322
+ The residual connection is defined in ViTSiTLayer instead of here (as is the case with other models), due to the
323
+ layernorm applied before each block.
324
+ """
325
+
326
+ def __init__(self, config: ViTSiTConfig) -> None:
327
+ super().__init__()
328
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
329
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
330
+
331
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.dropout(hidden_states)
334
+
335
+ return hidden_states
336
+
337
+
338
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTSiT
339
+ class ViTSiTAttention(nn.Module):
340
+ def __init__(self, config: ViTSiTConfig) -> None:
341
+ super().__init__()
342
+ self.attention = ViTSiTSelfAttention(config)
343
+ self.output = ViTSiTSelfOutput(config)
344
+ self.pruned_heads = set()
345
+
346
+ def prune_heads(self, heads: Set[int]) -> None:
347
+ if len(heads) == 0:
348
+ return
349
+ heads, index = find_pruneable_heads_and_indices(
350
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
351
+ )
352
+
353
+ # Prune linear layers
354
+ self.attention.query = prune_linear_layer(self.attention.query, index)
355
+ self.attention.key = prune_linear_layer(self.attention.key, index)
356
+ self.attention.value = prune_linear_layer(self.attention.value, index)
357
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
358
+
359
+ # Update hyper params and store pruned heads
360
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
361
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
362
+ self.pruned_heads = self.pruned_heads.union(heads)
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ output_attentions: bool = False,
369
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
+
372
+ attention_output = self.output(self_outputs[0], hidden_states)
373
+
374
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
375
+ return outputs
376
+
377
+
378
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTSiT
379
+ class ViTSiTIntermediate(nn.Module):
380
+ def __init__(self, config: ViTSiTConfig) -> None:
381
+ super().__init__()
382
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
383
+ if isinstance(config.hidden_act, str):
384
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
385
+ else:
386
+ self.intermediate_act_fn = config.hidden_act
387
+
388
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
389
+ hidden_states = self.dense(hidden_states)
390
+ hidden_states = self.intermediate_act_fn(hidden_states)
391
+
392
+ return hidden_states
393
+
394
+
395
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTSiT
396
+ class ViTSiTOutput(nn.Module):
397
+ def __init__(self, config: ViTSiTConfig) -> None:
398
+ super().__init__()
399
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
400
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
401
+
402
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
403
+ hidden_states = self.dense(hidden_states)
404
+ hidden_states = self.dropout(hidden_states)
405
+
406
+ hidden_states = hidden_states + input_tensor
407
+
408
+ return hidden_states
409
+
410
+
411
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTSiT
412
+ class ViTSiTLayer(nn.Module):
413
+ """This corresponds to the Block class in the timm implementation."""
414
+
415
+ def __init__(self, config: ViTSiTConfig) -> None:
416
+ super().__init__()
417
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
418
+ self.seq_len_dim = 1
419
+ self.attention = ViTSiTAttention(config)
420
+ self.intermediate = ViTSiTIntermediate(config)
421
+ self.output = ViTSiTOutput(config)
422
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
423
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+
425
+ def forward(
426
+ self,
427
+ hidden_states: torch.Tensor,
428
+ head_mask: Optional[torch.Tensor] = None,
429
+ output_attentions: bool = False,
430
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
431
+ self_attention_outputs = self.attention(
432
+ self.layernorm_before(hidden_states), # in ViTSiT, layernorm is applied before self-attention
433
+ head_mask,
434
+ output_attentions=output_attentions,
435
+ )
436
+ attention_output = self_attention_outputs[0]
437
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
438
+
439
+ # first residual connection
440
+ hidden_states = attention_output + hidden_states
441
+
442
+ # in ViTSiT, layernorm is also applied after self-attention
443
+ layer_output = self.layernorm_after(hidden_states)
444
+ layer_output = self.intermediate(layer_output)
445
+
446
+ # second residual connection is done here
447
+ layer_output = self.output(layer_output, hidden_states)
448
+
449
+ outputs = (layer_output,) + outputs
450
+
451
+ return outputs
452
+
453
+
454
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTSiT
455
+ class ViTSiTEncoder(nn.Module):
456
+ def __init__(self, config: ViTSiTConfig) -> None:
457
+ super().__init__()
458
+ self.config = config
459
+ self.layer = nn.ModuleList([ViTSiTLayer(config) for _ in range(config.num_hidden_layers)])
460
+ self.gradient_checkpointing = False
461
+
462
+ def forward(
463
+ self,
464
+ hidden_states: torch.Tensor,
465
+ head_mask: Optional[torch.Tensor] = None,
466
+ output_attentions: bool = False,
467
+ output_hidden_states: bool = False,
468
+ return_dict: bool = True,
469
+ ) -> Union[tuple, BaseModelOutput]:
470
+ all_hidden_states = () if output_hidden_states else None
471
+ all_self_attentions = () if output_attentions else None
472
+
473
+ for i, layer_module in enumerate(self.layer):
474
+ if output_hidden_states:
475
+ all_hidden_states = all_hidden_states + (hidden_states,)
476
+
477
+ layer_head_mask = head_mask[i] if head_mask is not None else None
478
+
479
+ if self.gradient_checkpointing and self.training:
480
+ layer_outputs = self._gradient_checkpointing_func(
481
+ layer_module.__call__,
482
+ hidden_states,
483
+ layer_head_mask,
484
+ output_attentions,
485
+ )
486
+ else:
487
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
488
+
489
+ hidden_states = layer_outputs[0]
490
+
491
+ if output_attentions:
492
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
493
+
494
+ if output_hidden_states:
495
+ all_hidden_states = all_hidden_states + (hidden_states,)
496
+
497
+ if not return_dict:
498
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
499
+ return BaseModelOutput(
500
+ last_hidden_state=hidden_states,
501
+ hidden_states=all_hidden_states,
502
+ attentions=all_self_attentions,
503
+ )
504
+
505
+
506
+ class ViTSiTPreTrainedModel(PreTrainedModel):
507
+ """
508
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
509
+ models.
510
+ """
511
+
512
+ config_class = ViTSiTConfig
513
+ base_model_prefix = "vit"
514
+ main_input_name = "pixel_values"
515
+ supports_gradient_checkpointing = True
516
+
517
+ def _init_weights(self, module):
518
+ """Initialize the weights"""
519
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
520
+ # Slightly different from the TF version which uses truncated_normal for initialization
521
+ # cf https://github.com/pytorch/pytorch/pull/5617
522
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
523
+ if module.bias is not None:
524
+ module.bias.data.zero_()
525
+ elif isinstance(module, nn.LayerNorm):
526
+ module.bias.data.zero_()
527
+ module.weight.data.fill_(1.0)
528
+
529
+
530
+ VIT_SiT_START_DOCSTRING = r"""
531
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
532
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
533
+ behavior.
534
+
535
+ Parameters:
536
+ config ([`ViTSiTConfig`]): Model configuration class with all the parameters of the model.
537
+ Initializing with a config file does not load the weights associated with the model, only the
538
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
539
+ """
540
+
541
+ VIT_SiT_INPUTS_DOCSTRING = r"""
542
+ Args:
543
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
544
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
545
+ for details.
546
+
547
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
548
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
549
+
550
+ - 1 indicates the head is **not masked**,
551
+ - 0 indicates the head is **masked**.
552
+
553
+ output_attentions (`bool`, *optional*):
554
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
555
+ tensors for more detail.
556
+ output_hidden_states (`bool`, *optional*):
557
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
558
+ more detail.
559
+ return_dict (`bool`, *optional*):
560
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
561
+ """
562
+
563
+
564
+ @add_start_docstrings(
565
+ "The bare ViTSiT Model transformer outputting raw hidden-states without any specific head on top.",
566
+ VIT_SiT_START_DOCSTRING,
567
+ )
568
+ class ViTSiTModel(ViTSiTPreTrainedModel):
569
+ def __init__(self, config):
570
+ super().__init__(config)
571
+ self.config = config
572
+
573
+ self.embeddings = ViTSiTEmbeddings(config)
574
+ self.encoder = ViTSiTEncoder(config)
575
+
576
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
577
+
578
+ # Initialize weights and apply final processing
579
+ self.post_init()
580
+
581
+ def get_input_embeddings(self):
582
+ return self.embeddings.patch_embeddings
583
+
584
+ def _prune_heads(self, heads_to_prune):
585
+ """
586
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
587
+ class PreTrainedModel
588
+ """
589
+ for layer, heads in heads_to_prune.items():
590
+ self.encoder.layer[layer].attention.prune_heads(heads)
591
+
592
+ @add_start_docstrings_to_model_forward(VIT_SiT_INPUTS_DOCSTRING)
593
+ @replace_return_docstrings(output_type=ViTSiTModelOutput, config_class=_CONFIG_FOR_DOC)
594
+ def forward(
595
+ self,
596
+ pixel_values: Optional[torch.FloatTensor] = None,
597
+ noise: Optional[torch.FloatTensor] = None,
598
+ head_mask: Optional[torch.FloatTensor] = None,
599
+ output_attentions: Optional[bool] = None,
600
+ output_hidden_states: Optional[bool] = None,
601
+ return_dict: Optional[bool] = None,
602
+ ) -> Union[Tuple, ViTSiTModelOutput]:
603
+ r"""
604
+ Returns:
605
+
606
+ Examples:
607
+
608
+ ```python
609
+ >>> from transformers import AutoImageProcessor, ViTSiTModel
610
+ >>> from PIL import Image
611
+ >>> import requests
612
+
613
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
614
+ >>> image = Image.open(requests.get(url, stream=True).raw)
615
+
616
+ >>> image_processor = AutoImageProcessor.from_pretrained("erow/vit-sit-base")
617
+ >>> model = ViTSiTModel.from_pretrained("erow/vit-sit-base")
618
+
619
+ >>> inputs = image_processor(images=image, return_tensors="pt")
620
+ >>> outputs = model(**inputs)
621
+ >>> last_hidden_states = outputs.last_hidden_state
622
+ ```"""
623
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
624
+ output_hidden_states = (
625
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
626
+ )
627
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
628
+
629
+ if pixel_values is None:
630
+ raise ValueError("You have to specify pixel_values")
631
+
632
+ # Prepare head mask if needed
633
+ # 1.0 in head_mask indicate we keep the head
634
+ # attention_probs has shape bsz x n_heads x N x N
635
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
636
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
637
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
638
+
639
+ embedding_output = self.embeddings(pixel_values, noise=noise)
640
+
641
+ encoder_outputs = self.encoder(
642
+ embedding_output,
643
+ head_mask=head_mask,
644
+ output_attentions=output_attentions,
645
+ output_hidden_states=output_hidden_states,
646
+ return_dict=return_dict,
647
+ )
648
+ sequence_output = encoder_outputs[0]
649
+ sequence_output = self.layernorm(sequence_output)
650
+
651
+ if not return_dict:
652
+ return (sequence_output, ) + encoder_outputs[1:]
653
+
654
+ return ViTSiTModelOutput(
655
+ last_hidden_state=sequence_output,
656
+ hidden_states=encoder_outputs.hidden_states,
657
+ attentions=encoder_outputs.attentions,
658
+ )
659
+
660
+
661
+ class CLSHead(nn.Module):
662
+ def __init__(self, in_dim, bottleneck_dim, nlayers=3, hidden_dim=4096):
663
+ super().__init__()
664
+ nlayers = max(nlayers, 1)
665
+ if nlayers == 1:
666
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
667
+ else:
668
+ layers = [nn.Linear(in_dim, hidden_dim)]
669
+ layers.append(nn.BatchNorm1d(hidden_dim))
670
+ layers.append(nn.ReLU(inplace=True))
671
+ for _ in range(nlayers - 2):
672
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
673
+ layers.append(nn.BatchNorm1d(hidden_dim))
674
+ layers.append(nn.GELU())
675
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
676
+ layers.append(nn.BatchNorm1d(bottleneck_dim, affine=False))
677
+
678
+ self.mlp = nn.Sequential(*layers)
679
+ self.apply(self._init_weights)
680
+
681
+ def _init_weights(self, m):
682
+ if isinstance(m, nn.Linear):
683
+ nn.init.normal_(m.weight, std=.02)
684
+ if isinstance(m, nn.Linear) and m.bias is not None:
685
+ nn.init.constant_(m.bias, 0)
686
+
687
+ def forward(self, x):
688
+ x = self.mlp(x)
689
+ return x
690
+
691
+ class RECHead(nn.Module):
692
+ def __init__(self, in_dim, in_chans=3, patch_size=16):
693
+ super().__init__()
694
+
695
+ layers = [nn.Linear(in_dim, in_dim)]
696
+ layers.append(nn.GELU())
697
+ layers.append(nn.Linear(in_dim, in_dim))
698
+ layers.append(nn.GELU())
699
+ layers.append(nn.Linear(in_dim, in_dim))
700
+ layers.append(nn.GELU())
701
+
702
+ self.mlp = nn.Sequential(*layers)
703
+ self.apply(self._init_weights)
704
+
705
+ self.convTrans = nn.ConvTranspose2d(in_dim, in_chans, kernel_size=(patch_size, patch_size),
706
+ stride=(patch_size, patch_size))
707
+
708
+
709
+ def _init_weights(self, m):
710
+ if isinstance(m, nn.Linear):
711
+ torch.nn.init.normal_(m.weight, std=.02)
712
+ if isinstance(m, nn.Linear) and m.bias is not None:
713
+ nn.init.constant_(m.bias, 0)
714
+
715
+ def forward(self, x):
716
+ x = self.mlp(x)
717
+
718
+ x_rec = x.transpose(1, 2)
719
+ out_sz = tuple( ( int(math.sqrt(x_rec.size()[2])) , int(math.sqrt(x_rec.size()[2])) ) )
720
+ x_rec = self.convTrans(x_rec.unflatten(2, out_sz))
721
+ return x_rec
722
+
723
+
724
+ @add_start_docstrings(
725
+ """The ViTSiT Model transformer with the decoder on top for self-supervised pre-training.
726
+
727
+ <Tip>
728
+
729
+ Note that we provide a script to pre-train this model on custom data in our [examples
730
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
731
+
732
+ </Tip>
733
+
734
+ """,
735
+ VIT_SiT_START_DOCSTRING,
736
+ )
737
+ class ViTSiTForPreTraining(ViTSiTPreTrainedModel):
738
+ def __init__(self, config):
739
+ super().__init__(config)
740
+ self.config = config
741
+
742
+ self.vit = ViTSiTModel(config)
743
+ self.head_recons = RECHead(config.hidden_size, config.num_channels, config.patch_size)
744
+ self.head = CLSHead(config.hidden_size, config.out_dim)
745
+
746
+ # Initialize weights and apply final processing
747
+ self.post_init()
748
+
749
+ def get_input_embeddings(self):
750
+ return self.vit.embeddings.patch_embeddings
751
+
752
+ def _prune_heads(self, heads_to_prune):
753
+ """
754
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
755
+ class PreTrainedModel
756
+ """
757
+ for layer, heads in heads_to_prune.items():
758
+ self.encoder.layer[layer].attention.prune_heads(heads)
759
+
760
+ def patchify(self, pixel_values):
761
+ """
762
+ Args:
763
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
764
+ Pixel values.
765
+
766
+ Returns:
767
+ `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
768
+ Patchified pixel values.
769
+ """
770
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
771
+ # sanity checks
772
+ if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
773
+ raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
774
+ if pixel_values.shape[1] != num_channels:
775
+ raise ValueError(
776
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
777
+ )
778
+
779
+ # patchify
780
+ batch_size = pixel_values.shape[0]
781
+ num_patches_one_direction = pixel_values.shape[2] // patch_size
782
+ patchified_pixel_values = pixel_values.reshape(
783
+ batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
784
+ )
785
+ patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
786
+ patchified_pixel_values = patchified_pixel_values.reshape(
787
+ batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
788
+ )
789
+ return patchified_pixel_values
790
+
791
+ def unpatchify(self, patchified_pixel_values):
792
+ """
793
+ Args:
794
+ patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
795
+ Patchified pixel values.
796
+
797
+ Returns:
798
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
799
+ Pixel values.
800
+ """
801
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
802
+ num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
803
+ # sanity check
804
+ if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
805
+ raise ValueError("Make sure that the number of patches can be squared")
806
+
807
+ # unpatchify
808
+ batch_size = patchified_pixel_values.shape[0]
809
+ patchified_pixel_values = patchified_pixel_values.reshape(
810
+ batch_size,
811
+ num_patches_one_direction,
812
+ num_patches_one_direction,
813
+ patch_size,
814
+ patch_size,
815
+ num_channels,
816
+ )
817
+ patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
818
+ pixel_values = patchified_pixel_values.reshape(
819
+ batch_size,
820
+ num_channels,
821
+ num_patches_one_direction * patch_size,
822
+ num_patches_one_direction * patch_size,
823
+ )
824
+ return pixel_values
825
+
826
+ def forward_loss(self, pixel_values, pred, mask):
827
+ """
828
+ Args:
829
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
830
+ Pixel values.
831
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
832
+ Predicted pixel values.
833
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
834
+ Tensor indicating which patches are masked (1) and which are not (0).
835
+
836
+ Returns:
837
+ `torch.FloatTensor`: Pixel reconstruction loss.
838
+ """
839
+ target = pixel_values
840
+ pred = self.unpatchify(pred)
841
+
842
+ loss = (pred - target) ** 2
843
+
844
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
845
+ return loss
846
+
847
+ @add_start_docstrings_to_model_forward(VIT_SiT_INPUTS_DOCSTRING)
848
+ @replace_return_docstrings(output_type=ViTSiTForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
849
+ def forward(
850
+ self,
851
+ pixel_values: Optional[torch.FloatTensor] = None,
852
+ noise: Optional[torch.FloatTensor] = None,
853
+ head_mask: Optional[torch.FloatTensor] = None,
854
+ output_attentions: Optional[bool] = None,
855
+ output_hidden_states: Optional[bool] = None,
856
+ return_dict: Optional[bool] = None,
857
+ ) -> Union[Tuple, ViTSiTForPreTrainingOutput]:
858
+ r"""
859
+ Returns:
860
+
861
+ Examples:
862
+
863
+ ```python
864
+ >>> from transformers import AutoImageProcessor, ViTSiTForPreTraining
865
+ >>> from PIL import Image
866
+ >>> import requests
867
+
868
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
869
+ >>> image = Image.open(requests.get(url, stream=True).raw)
870
+
871
+ >>> image_processor = AutoImageProcessor.from_pretrained("erow/vit-sit-base")
872
+ >>> model = ViTSiTForPreTraining.from_pretrained("erow/vit-sit-base")
873
+
874
+ >>> inputs = image_processor(images=image, return_tensors="pt")
875
+ >>> outputs = model(**inputs)
876
+ >>> loss = outputs.loss
877
+ >>> mask = outputs.mask
878
+ >>> ids_restore = outputs.ids_restore
879
+ ```"""
880
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
881
+
882
+ outputs = self.vit(
883
+ pixel_values,
884
+ head_mask=head_mask,
885
+ output_attentions=output_attentions,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ )
889
+
890
+ latent = outputs.last_hidden_state
891
+
892
+ logits = self.decoder(latent) # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
893
+
894
+ loss = self.forward_loss(pixel_values, logits, noise)
895
+
896
+ if not return_dict:
897
+ output = (logits, ) + outputs[2:]
898
+ return ((loss,) + output) if loss is not None else output
899
+
900
+ return ViTSiTForPreTrainingOutput(
901
+ loss=loss,
902
+ logits=logits,
903
+ noise=noise,
904
+ hidden_states=outputs.hidden_states,
905
+ attentions=outputs.attentions,
906
+ )
907
+
908
+
909
+ if __name__=="__main__":
910
+
911
+ # Initializing a ViT MAE vit-mae-base style configuration
912
+ configuration = ViTSiTConfig()
913
+
914
+ # Initializing a model (with random weights) from the vit-mae-base style configuration
915
+ model = ViTSiTModel(configuration)
916
+
917
+ # Accessing the model configuration
918
+ configuration = model.config
919
+
920
+ x = torch.randn(1, 3, 224, 224)
921
+ output = model(x)