dveranieto committed on
Commit
036a350
1 Parent(s): 39bb931

Added usage helpers

Files changed (4)
  1. .gitignore +2 -0
  2. model.py +109 -0
  3. test.py +23 -0
  4. vision_transformer.py +824 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ .models/
2
+ __pycache__
model.py ADDED
@@ -0,0 +1,109 @@
1
+ # * SPDX-License-Identifier: Apache-2.0
2
+ # * © 2023 ETH Zurich and other contributors, see AUTHORS.txt for details
3
+
4
+ import os
5
+ from typing import List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from fastapi import HTTPException
10
+ from mtc_api_utils.base_model import MLBaseModel
11
+ from mtc_api_utils.init_api import download_if_not_exists
12
+ from torchvision import transforms
13
+
14
+ from aesthetics_model.api_types import AestheticsModelResponse
15
+ from aesthetics_model.config import AestheticsModelConfig
16
+ from aesthetics_model.model.vision_transformer import vit_large_patch16_224_in21k
17
+
18
+
19
+ class AestheticModel(MLBaseModel):
20
+ model = None
21
+ transform = None
22
+ device = None
23
+ model_checkpoint_path = None
24
+
25
+ def init_model(self):
26
+ print("Initializing Aesthetic Model...")
27
+ # Download the model checkpoint if it does not exist locally
28
+ self.model_checkpoint_path = download_if_not_exists(
29
+ artifact_url=os.path.join(AestheticsModelConfig.model_base_url, AestheticsModelConfig.model_url),
30
+ download_dir=AestheticsModelConfig.model_checkpoint_path,
31
+ polybox_auth=AestheticsModelConfig.polybox_credentials,
32
+ is_tar=False,
33
+ )
34
+ print("Model downloaded, loading to device..")
35
+
36
+ self.model, self.transform, self.device = self.load_model()
37
+ print("Model initialized")
38
+
39
+ def is_ready(self):
40
+ return self.model is not None
41
+
42
+ def inference(self, images: List[Image.Image]) -> List[AestheticsModelResponse]:
43
+ if not self.is_ready():
44
+ raise HTTPException(status_code=503, detail="Model is not ready")
45
+
46
+ results = []
47
+ for image in images:
48
+ im = self.validate_and_resize_image(image, max_size=AestheticsModelConfig.max_image_size)
49
+ image_tensor = self.transform(im).unsqueeze(0).to(self.device)
50
+ with torch.no_grad():
51
+ image_embedding = self.model.forward_features(image_tensor)
52
+ score = self.model.head(image_embedding).squeeze().to("cpu")
53
+ result: AestheticsModelResponse = AestheticsModelResponse(
54
+ aesthetics_score=float(score),
55
+ aesthetics_embedding=image_embedding.to("cpu").tolist(),
56
+ # TODO: success is currently always True
57
+ success=True
58
+ )
59
+ results.append(result)
60
+ return results
61
+
62
+ def load_model(self):
63
+ # Check whether a CUDA-capable GPU is available
64
+ use_cuda = torch.cuda.is_available()
65
+ device = torch.device("cuda:0" if use_cuda else "cpu")
66
+ print(f"Running on {device}")
67
+
68
+ model = vit_large_patch16_224_in21k()
69
+ model.reset_classifier(num_classes=1)
70
+ model.load_state_dict(torch.load(self.model_checkpoint_path, map_location=device))
71
+
72
+ model.eval()
73
+ model.to(device)
74
+
75
+ transform = transforms.Compose(
76
+ [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
77
+ )
78
+
79
+ return model, transform, device
80
+
81
+ def validate_and_resize_image(self, image: Image.Image, max_size: int):
82
+ """Check whether an image is too large and resize it if so. Convert grayscale images to RGB.
83
+
84
+ Args:
85
+ image : PIL Image
86
+ max_size : int
87
+ """
88
+ if not image.mode == "RGB":
89
+ image = image.convert(mode="RGB")
90
+
91
+ size_factor = (max_size + 0.0) / max(image.size)
92
+ if size_factor < 1:
93
+ new_w = round(image.size[0] * size_factor)  # PIL sizes are (width, height)
94
+ new_h = round(image.size[1] * size_factor)
95
+ image = image.resize((new_w, new_h), resample=Image.HAMMING)
96
+ return image
97
+
98
+
99
+ if __name__ == '__main__':
100
+ model = AestheticModel()
101
+ model.__wait_until_ready__()
102
+
103
+ example_images = ["../tests/test_image.png"] # model.get_examples()
104
+
105
+ # Load the example images
106
+ images = [Image.open(image) for image in example_images]
107
+
108
+ model_result = model.inference(images=images)
109
+ print(model_result)
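
For reference, a minimal standalone sketch of the same scoring path without the MLBaseModel/API wrapper (not part of this commit). The checkpoint and image paths are placeholders, the import follows test.py's flat layout, and the preprocessing plus forward_features()/head() calls mirror load_model() and inference() above.

import torch
from PIL import Image
from torchvision import transforms

from vision_transformer import vit_large_patch16_224_in21k

CHECKPOINT = ".models/pytorch_model.bin"  # placeholder: path to the downloaded checkpoint

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same setup as AestheticModel.load_model()
model = vit_large_patch16_224_in21k()
model.reset_classifier(num_classes=1)
model.load_state_dict(torch.load(CHECKPOINT, map_location=device))
model.eval()
model.to(device)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
)

# Same per-image path as AestheticModel.inference()
image = Image.open("../tests/test_image.png").convert("RGB")  # placeholder image path
x = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
    embedding = model.forward_features(x)             # CLS-token representation
    score = model.head(embedding).squeeze().to("cpu")
print(float(score))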
test.py ADDED
@@ -0,0 +1,23 @@
1
+ # from transformers import AutoModel
2
+ from huggingface_hub import hf_hub_download
3
+ from vision_transformer import vit_large_patch16_224_in21k
4
+ import torch
5
+ import numpy as np
6
+
7
+ REPO_ID = "ethz-mtc/aesthetics_vit"
8
+ FILENAME = "pytorch_model.bin"
9
+
10
+ vit_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, cache_dir=".models")
11
+ print(vit_path)
12
+ REPO_ID = "ethz-mtc/shot_scale_classifier-resnet50"
13
+ path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, cache_dir=".models")
14
+ print(path)
15
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+
17
+ model = vit_large_patch16_224_in21k()
18
+ model.reset_classifier(num_classes=1)
19
+ model.load_state_dict(torch.load(vit_path, map_location=device))  # load the ViT checkpoint, not the ResNet one
20
+
21
+ print(
22
+ f"Model has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
23
+ )
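
A possible follow-up sanity check (not part of this commit): run a dummy forward pass to confirm that the reset classifier head returns one aesthetics logit per image. Weights are randomly initialized here, so only the output shape is meaningful.

import torch
from vision_transformer import vit_large_patch16_224_in21k

model = vit_large_patch16_224_in21k()
model.reset_classifier(num_classes=1)
model.eval()

with torch.no_grad():
    dummy = torch.zeros(2, 3, 224, 224)  # batch of two blank 224x224 RGB images
    scores = model(dummy)
print(scores.shape)  # torch.Size([2, 1]) -- one score per image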
vision_transformer.py ADDED
@@ -0,0 +1,824 @@
1
+ # * SPDX-License-Identifier: Apache-2.0
2
+ # * © 2023 ETH Zurich and other contributors, see AUTHORS.txt for details
3
+
4
+ # Based on timm library: git+https://github.com/rwightman/pytorch-image-models.git@95feb1da41c1fe95ce9634b83db343e08224a8c5
5
+ """ Vision Transformer (ViT) in PyTorch
6
+
7
+ A PyTorch implementation of Vision Transformers as described in
8
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
9
+
10
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
11
+
12
+ Acknowledgments:
13
+ * The paper authors for releasing code and weights, thanks!
14
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
15
+ for some einops/einsum fun
16
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
17
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
18
+
19
+ DeiT model defs and weights from https://github.com/facebookresearch/deit,
20
+ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
21
+
22
+ Hacked together by / Copyright 2020 Ross Wightman
23
+ """
24
+ import math
25
+ import logging
26
+ from functools import partial
27
+ from collections import OrderedDict
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ from timm.models.layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
34
+ from timm.models.resnet import resnet26d, resnet50d
35
+ from timm.models.resnetv2 import ResNetV2
36
+ from timm.models.registry import register_model
37
+
38
+ _logger = logging.getLogger(__name__)
39
+
40
+
41
+ def _cfg(url='', **kwargs):
42
+ return {
43
+ 'url': url,
44
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
45
+ 'crop_pct': .9, 'interpolation': 'bicubic',
46
+ 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
47
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
48
+ **kwargs
49
+ }
50
+
51
+
52
+ default_cfgs = {
53
+ # patch models (my experiments)
54
+ 'vit_small_patch16_224': _cfg(
55
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
56
+ ),
57
+
58
+ # patch models (weights ported from official Google JAX impl)
59
+ 'vit_base_patch16_224': _cfg(
60
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
61
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
62
+ ),
63
+ 'vit_base_patch32_224': _cfg(
64
+ url='', # no official model weights for this combo, only for in21k
65
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
66
+ 'vit_base_patch16_384': _cfg(
67
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
68
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
69
+ 'vit_base_patch32_384': _cfg(
70
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
71
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
72
+ 'vit_large_patch16_224': _cfg(
73
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
74
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
75
+ 'vit_large_patch32_224': _cfg(
76
+ url='', # no official model weights for this combo, only for in21k
77
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
78
+ 'vit_large_patch16_384': _cfg(
79
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
80
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
81
+ 'vit_large_patch32_384': _cfg(
82
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
83
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
84
+
85
+ # patch models, imagenet21k (weights ported from official Google JAX impl)
86
+ 'vit_base_patch16_224_in21k': _cfg(
87
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
88
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
89
+ 'vit_base_patch32_224_in21k': _cfg(
90
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
91
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
92
+ 'vit_large_patch16_224_in21k': _cfg(
93
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
94
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
95
+ 'vit_large_patch32_224_in21k': _cfg(
96
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
97
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
98
+ 'vit_huge_patch14_224_in21k': _cfg(
99
+ url='', # FIXME I have weights for this but > 2GB limit for github release binaries
100
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
101
+
102
+ # hybrid models (weights ported from official Google JAX impl)
103
+ 'vit_base_resnet50_224_in21k': _cfg(
104
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
105
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
106
+ 'vit_base_resnet50_384': _cfg(
107
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
108
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
109
+
110
+ # hybrid models (my experiments)
111
+ 'vit_small_resnet26d_224': _cfg(),
112
+ 'vit_small_resnet50d_s3_224': _cfg(),
113
+ 'vit_base_resnet26d_224': _cfg(),
114
+ 'vit_base_resnet50d_224': _cfg(),
115
+
116
+ # deit models (FB weights)
117
+ 'vit_deit_tiny_patch16_224': _cfg(
118
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
119
+ 'vit_deit_small_patch16_224': _cfg(
120
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
121
+ 'vit_deit_base_patch16_224': _cfg(
122
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
123
+ 'vit_deit_base_patch16_384': _cfg(
124
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
125
+ input_size=(3, 384, 384), crop_pct=1.0),
126
+ 'vit_deit_tiny_distilled_patch16_224': _cfg(
127
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
128
+ 'vit_deit_small_distilled_patch16_224': _cfg(
129
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
130
+ 'vit_deit_base_distilled_patch16_224': _cfg(
131
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
132
+ 'vit_deit_base_distilled_patch16_384': _cfg(
133
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
134
+ input_size=(3, 384, 384), crop_pct=1.0),
135
+ }
136
+
137
+
138
+ class Mlp(nn.Module):
139
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
140
+ super().__init__()
141
+ out_features = out_features or in_features
142
+ hidden_features = hidden_features or in_features
143
+ self.fc1 = nn.Linear(in_features, hidden_features)
144
+ self.act = act_layer()
145
+ self.fc2 = nn.Linear(hidden_features, out_features)
146
+ self.drop = nn.Dropout(drop)
147
+
148
+ def forward(self, x):
149
+ x = self.fc1(x)
150
+ x = self.act(x)
151
+ x = self.drop(x)
152
+ x = self.fc2(x)
153
+ x = self.drop(x)
154
+ return x
155
+
156
+
157
+ class Attention(nn.Module):
158
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
159
+ super().__init__()
160
+ self.num_heads = num_heads
161
+ head_dim = dim // num_heads
162
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
163
+ self.scale = qk_scale or head_dim ** -0.5
164
+
165
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
166
+ self.attn_drop = nn.Dropout(attn_drop)
167
+ self.proj = nn.Linear(dim, dim)
168
+ self.proj_drop = nn.Dropout(proj_drop)
169
+
170
+ def forward(self, x):
171
+ B, N, C = x.shape
172
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
173
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
174
+
175
+ attn = (q @ k.transpose(-2, -1)) * self.scale
176
+ # TODO: this is where the masking of the pads should happen
177
+ #if mask is not None:
178
+ # attn = attn.masked_fill(mask == 0, -1e9) or float('-1e20')
179
+
180
+ attn = attn.softmax(dim=-1)
181
+ attn = self.attn_drop(attn)
182
+
183
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
184
+ x = self.proj(x)
185
+ x = self.proj_drop(x)
186
+ return x
187
+
188
+
189
+ class Block(nn.Module):
190
+
191
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
192
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
193
+ super().__init__()
194
+ self.norm1 = norm_layer(dim)
195
+ self.attn = Attention(
196
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
197
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
198
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
199
+ self.norm2 = norm_layer(dim)
200
+ mlp_hidden_dim = int(dim * mlp_ratio)
201
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
202
+
203
+ def forward(self, x):
204
+ x = x + self.drop_path(self.attn(self.norm1(x)))
205
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
206
+ return x
207
+
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ Image to Patch Embedding
211
+ """
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
217
+ self.img_size = img_size
218
+ self.patch_size = patch_size
219
+ self.num_patches = num_patches
220
+
221
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
222
+
223
+ def forward(self, x):
224
+ B, C, H, W = x.shape
225
+ # FIXME look at relaxing size constraints
226
+ #assert H == self.img_size[0] and W == self.img_size[1], \
227
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
228
+ x = self.proj(x).flatten(2).transpose(1, 2)
229
+ #print("Patch embedding output shape {} for input image {}".format(x.shape, [B,C,H,W]))
230
+ #print("Number of patches: {}".format(self.num_patches))
231
+ return x
232
+
233
+
234
+ class HybridEmbed(nn.Module):
235
+ """ CNN Feature Map Embedding
236
+ Extract feature map from CNN, flatten, project to embedding dim.
237
+ """
238
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
239
+ super().__init__()
240
+ assert isinstance(backbone, nn.Module)
241
+ img_size = to_2tuple(img_size)
242
+ self.img_size = img_size
243
+ self.backbone = backbone
244
+ if feature_size is None:
245
+ with torch.no_grad():
246
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
247
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
248
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
249
+ training = backbone.training
250
+ if training:
251
+ backbone.eval()
252
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
253
+ if isinstance(o, (list, tuple)):
254
+ o = o[-1] # last feature if backbone outputs list/tuple of features
255
+ feature_size = o.shape[-2:]
256
+ feature_dim = o.shape[1]
257
+ backbone.train(training)
258
+ else:
259
+ feature_size = to_2tuple(feature_size)
260
+ if hasattr(self.backbone, 'feature_info'):
261
+ feature_dim = self.backbone.feature_info.channels()[-1]
262
+ else:
263
+ feature_dim = self.backbone.num_features
264
+ self.num_patches = feature_size[0] * feature_size[1]
265
+ self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
266
+
267
+ def forward(self, x):
268
+ x = self.backbone(x)
269
+ if isinstance(x, (list, tuple)):
270
+ x = x[-1] # last feature if backbone outputs list/tuple of features
271
+ x = self.proj(x).flatten(2).transpose(1, 2)
272
+ return x
273
+
274
+
275
+ class VisionTransformer(nn.Module):
276
+ """ Vision Transformer
277
+
278
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
279
+ https://arxiv.org/abs/2010.11929
280
+ """
281
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
282
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
283
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, variable_input_len=True):
284
+ """
285
+ Args:
286
+ img_size (int, tuple): input image size
287
+ patch_size (int, tuple): patch size
288
+ in_chans (int): number of input channels
289
+ num_classes (int): number of classes for classification head
290
+ embed_dim (int): embedding dimension
291
+ depth (int): depth of transformer
292
+ num_heads (int): number of attention heads
293
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
294
+ qkv_bias (bool): enable bias for qkv if True
295
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
296
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
297
+ drop_rate (float): dropout rate
298
+ attn_drop_rate (float): attention dropout rate
299
+ drop_path_rate (float): stochastic depth rate
300
+ hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
301
+ norm_layer: (nn.Module): normalization layer
302
+ """
303
+ super().__init__()
304
+ self.num_classes = num_classes
305
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
306
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
307
+
308
+ if hybrid_backbone is not None:
309
+ self.patch_embed = HybridEmbed(
310
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
311
+ else:
312
+ self.patch_embed = PatchEmbed(
313
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
314
+ num_patches = self.patch_embed.num_patches
315
+
316
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
317
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
318
+ self.pos_drop = nn.Dropout(p=drop_rate)
319
+
320
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
321
+ self.blocks = nn.ModuleList([
322
+ Block(
323
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
324
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
325
+ for i in range(depth)])
326
+ self.norm = norm_layer(embed_dim)
327
+
328
+ # Representation layer
329
+ if representation_size:
330
+ self.num_features = representation_size
331
+ self.pre_logits = nn.Sequential(OrderedDict([
332
+ ('fc', nn.Linear(embed_dim, representation_size)),
333
+ ('act', nn.Tanh())
334
+ ]))
335
+ else:
336
+ self.pre_logits = nn.Identity()
337
+
338
+ # Classifier head
339
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
340
+
341
+ trunc_normal_(self.pos_embed, std=.02)
342
+ trunc_normal_(self.cls_token, std=.02)
343
+ self.apply(self._init_weights)
344
+
345
+ self.variable_input_len = variable_input_len
346
+ self.patch_size = patch_size
347
+
348
+ def _init_weights(self, m):
349
+ if isinstance(m, nn.Linear):
350
+ trunc_normal_(m.weight, std=.02)
351
+ if isinstance(m, nn.Linear) and m.bias is not None:
352
+ nn.init.constant_(m.bias, 0)
353
+ elif isinstance(m, nn.LayerNorm):
354
+ nn.init.constant_(m.bias, 0)
355
+ nn.init.constant_(m.weight, 1.0)
356
+
357
+ @torch.jit.ignore
358
+ def no_weight_decay(self):
359
+ return {'pos_embed', 'cls_token'}
360
+
361
+ def get_classifier(self):
362
+ return self.head
363
+
364
+ def reset_classifier(self, num_classes, global_pool=''):
365
+ self.num_classes = num_classes
366
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
367
+
368
+ def forward_features(self, x) -> torch.Tensor:
369
+ #B = x.shape[0]
370
+ B, C, H, W = x.shape
371
+ x = self.patch_embed(x)
372
+
373
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
374
+ x = torch.cat((cls_tokens, x), dim=1)
375
+ #print("ViT shapes of x and pos_embed: {} | {}".format(x.shape, self.pos_embed.shape))
376
+ #for an image of size 710,541: ViT shapes of x and pos_embed: torch.Size([1, 1453, 768]) | torch.Size([1, 197, 768])
377
+ pos_embed = None
378
+ if self.variable_input_len:
379
+ patches_height = int(H/self.patch_size)
380
+ patches_width = int(W/self.patch_size)
381
+
382
+ pos_embed = resize_pos_embed(self.pos_embed, patches_height*patches_width, patches_height, patches_width)
383
+
384
+ else:
385
+ pos_embed = self.pos_embed
386
+
387
+ x = x + pos_embed
388
+
389
+ x = self.pos_drop(x)
390
+
391
+ for blk in self.blocks:
392
+ x = blk(x)
393
+
394
+ #print("Shapes after block: {}".format(x.shape)) # Shapes after block: torch.Size([1, 1453, 768])
395
+ #temp = self.norm(x)
396
+ #print("Shapes after norm original: {}".format(temp.shape)) # Shapes after norm original: torch.Size([1, 1453, 768])
397
+
398
+ x = self.norm(x)[:, 0] # Important! The state of the class token at the output of the transformer encoder serves as the image representation
399
+
400
+ #print("Shapes after norm: {}".format(x.shape)) # Shapes after norm: torch.Size([1, 768])
401
+ x = self.pre_logits(x)
402
+
403
+ return x
404
+
405
+ def forward(self, x):
406
+ x = self.forward_features(x)
407
+ #print("Shapes before head: {}".format(x.shape)) # Shapes before head: torch.Size([1, 768])
408
+ x = self.head(x)
409
+ return x
410
+
411
+
412
+ class DistilledVisionTransformer(VisionTransformer):
413
+ """ Vision Transformer with distillation token.
414
+
415
+ Paper: `Training data-efficient image transformers & distillation through attention` -
416
+ https://arxiv.org/abs/2012.12877
417
+
418
+ This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
419
+ """
420
+ def __init__(self, *args, **kwargs):
421
+ super().__init__(*args, **kwargs)
422
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
423
+ num_patches = self.patch_embed.num_patches
424
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
425
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
426
+
427
+ trunc_normal_(self.dist_token, std=.02)
428
+ trunc_normal_(self.pos_embed, std=.02)
429
+ self.head_dist.apply(self._init_weights)
430
+
431
+ def forward_features(self, x):
432
+ B = x.shape[0]
433
+ x = self.patch_embed(x)
434
+
435
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
436
+ dist_token = self.dist_token.expand(B, -1, -1)
437
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
438
+
439
+ x = x + self.pos_embed
440
+ x = self.pos_drop(x)
441
+
442
+ for blk in self.blocks:
443
+ x = blk(x)
444
+
445
+ x = self.norm(x)
446
+ return x[:, 0], x[:, 1]
447
+
448
+ def forward(self, x):
449
+ x, x_dist = self.forward_features(x)
450
+ x = self.head(x)
451
+ x_dist = self.head_dist(x_dist)
452
+ if self.training:
453
+ return x, x_dist
454
+ else:
455
+ # during inference, return the average of both classifier predictions
456
+ return (x + x_dist) / 2
457
+
458
+
459
+ def resize_pos_embed(posemb, ntok_new, new_height, new_width):
460
+ #uzpaka: make it work for non-square images
461
+
462
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
463
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
464
+ #_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
465
+ #ntok_new = posemb_new.shape[1]
466
+ if True:
467
+ posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
468
+ ntok_new -= 1
469
+ else:
470
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
471
+ gs_old = int(math.sqrt(len(posemb_grid)))
472
+
473
+ gs_new_h = new_height
474
+ gs_new_w = new_width
475
+ if new_height is None or new_width is None:
476
+ gs_new = int(math.sqrt(ntok_new))
477
+ gs_new_h = gs_new
478
+ gs_new_w = gs_new
479
+
480
+ #_logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
481
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
482
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_new_h, gs_new_w), mode='bilinear')
483
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new_h * gs_new_w, -1)
484
+ posemb_ret = torch.cat([posemb_tok, posemb_grid], dim=1)
485
+ return posemb_ret
486
+
487
+
488
+ def checkpoint_filter_fn(state_dict, model):
489
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
490
+ out_dict = {}
491
+ if 'model' in state_dict:
492
+ # For deit models
493
+ state_dict = state_dict['model']
494
+ for k, v in state_dict.items():
495
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
496
+ # For old models that I trained prior to conv based patchification
497
+ O, I, H, W = model.patch_embed.proj.weight.shape
498
+ v = v.reshape(O, -1, H, W)
499
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
500
+ # To resize pos embedding when using model at different size from pretrained weights
501
+ v = resize_pos_embed(v, model.pos_embed.shape[1], None, None)  # match the local resize_pos_embed signature
502
+ out_dict[k] = v
503
+ return out_dict
504
+
505
+
506
+ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
507
+ default_cfg = default_cfgs[variant]
508
+ default_num_classes = default_cfg['num_classes']
509
+ default_img_size = default_cfg['input_size'][-1]
510
+
511
+ num_classes = kwargs.pop('num_classes', default_num_classes)
512
+ img_size = kwargs.pop('img_size', default_img_size)
513
+ repr_size = kwargs.pop('representation_size', None)
514
+ if repr_size is not None and num_classes != default_num_classes:
515
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
516
+ # but it feels better than doing nothing by default when fine-tuning. Perhaps a better interface?
517
+ _logger.warning("Removing representation layer for fine-tuning.")
518
+ repr_size = None
519
+
520
+ model_cls = DistilledVisionTransformer if distilled else VisionTransformer
521
+ model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
522
+ model.default_cfg = default_cfg
523
+ return model
524
+
525
+
526
+ @register_model
527
+ def vit_small_patch16_224(pretrained=False, **kwargs):
528
+ """ My custom 'small' ViT model. Depth=8, heads=8, mlp_ratio=3."""
529
+ model_kwargs = dict(
530
+ patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
531
+ qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
532
+ if pretrained:
533
+ # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
534
+ model_kwargs.setdefault('qk_scale', 768 ** -0.5)
535
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
536
+ return model
537
+
538
+
539
+ @register_model
540
+ def vit_base_patch16_224(pretrained=False, **kwargs):
541
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
542
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
543
+ """
544
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
545
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
546
+ return model
547
+
548
+
549
+ @register_model
550
+ def vit_base_patch32_224(pretrained=False, **kwargs):
551
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
552
+ """
553
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
554
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
555
+ return model
556
+
557
+
558
+ @register_model
559
+ def vit_base_patch16_384(pretrained=False, **kwargs):
560
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
561
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
562
+ """
563
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
564
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
565
+ return model
566
+
567
+
568
+ @register_model
569
+ def vit_base_patch32_384(pretrained=False, **kwargs):
570
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
571
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
572
+ """
573
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
574
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
575
+ return model
576
+
577
+
578
+ @register_model
579
+ def vit_large_patch16_224(pretrained=False, **kwargs):
580
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
581
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
582
+ """
583
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
584
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
585
+ return model
586
+
587
+
588
+ @register_model
589
+ def vit_large_patch32_224(pretrained=False, **kwargs):
590
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
591
+ """
592
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
593
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
594
+ return model
595
+
596
+
597
+ @register_model
598
+ def vit_large_patch16_384(pretrained=False, **kwargs):
599
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
600
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
601
+ """
602
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
603
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
604
+ return model
605
+
606
+
607
+ @register_model
608
+ def vit_large_patch32_384(pretrained=False, **kwargs):
609
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
610
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
611
+ """
612
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
613
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
614
+ return model
615
+
616
+
617
+ @register_model
618
+ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
619
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
620
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
621
+ """
622
+ model_kwargs = dict(
623
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
624
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
625
+ return model
626
+
627
+
628
+ @register_model
629
+ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
630
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
631
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
632
+ """
633
+ model_kwargs = dict(
634
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
635
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
636
+ return model
637
+
638
+
639
+ @register_model
640
+ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
641
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
642
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
643
+ """
644
+ model_kwargs = dict(
645
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
646
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
647
+ return model
648
+
649
+
650
+ @register_model
651
+ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
652
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
653
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
654
+ """
655
+ model_kwargs = dict(
656
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
657
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
658
+ return model
659
+
660
+
661
+ @register_model
662
+ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
663
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
664
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
665
+ NOTE: converted weights not currently available, too large for github release hosting.
666
+ """
667
+ model_kwargs = dict(
668
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
669
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
670
+ return model
671
+
672
+
673
+ @register_model
674
+ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
675
+ """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
676
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
677
+ """
678
+ # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
679
+ backbone = ResNetV2(
680
+ layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
681
+ preact=False, stem_type='same', conv_layer=StdConv2dSame)
682
+ model_kwargs = dict(
683
+ embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
684
+ representation_size=768, **kwargs)
685
+ model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
686
+ return model
687
+
688
+
689
+ @register_model
690
+ def vit_base_resnet50_384(pretrained=False, **kwargs):
691
+ """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
692
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
693
+ """
694
+ # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
695
+ backbone = ResNetV2(
696
+ layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
697
+ preact=False, stem_type='same', conv_layer=StdConv2dSame)
698
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
699
+ model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
700
+ return model
701
+
702
+
703
+ @register_model
704
+ def vit_small_resnet26d_224(pretrained=False, **kwargs):
705
+ """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
706
+ """
707
+ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
708
+ model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
709
+ model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
710
+ return model
711
+
712
+
713
+ @register_model
714
+ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
715
+ """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
716
+ """
717
+ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
718
+ model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
719
+ model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
720
+ return model
721
+
722
+
723
+ @register_model
724
+ def vit_base_resnet26d_224(pretrained=False, **kwargs):
725
+ """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
726
+ """
727
+ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
728
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
729
+ model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
730
+ return model
731
+
732
+
733
+ @register_model
734
+ def vit_base_resnet50d_224(pretrained=False, **kwargs):
735
+ """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
736
+ """
737
+ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
738
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
739
+ model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
740
+ return model
741
+
742
+
743
+ @register_model
744
+ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
745
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
746
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
747
+ """
748
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
749
+ model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
750
+ return model
751
+
752
+
753
+ @register_model
754
+ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
755
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
756
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
757
+ """
758
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
759
+ model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
760
+ return model
761
+
762
+
763
+ @register_model
764
+ def vit_deit_base_patch16_224(pretrained=False, **kwargs):
765
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
766
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
767
+ """
768
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
769
+ model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
770
+ return model
771
+
772
+
773
+ @register_model
774
+ def vit_deit_base_patch16_384(pretrained=False, **kwargs):
775
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
776
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
777
+ """
778
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
779
+ model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
780
+ return model
781
+
782
+
783
+ @register_model
784
+ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
785
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
786
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
787
+ """
788
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
789
+ model = _create_vision_transformer(
790
+ 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
791
+ return model
792
+
793
+
794
+ @register_model
795
+ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
796
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
797
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
798
+ """
799
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
800
+ model = _create_vision_transformer(
801
+ 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
802
+ return model
803
+
804
+
805
+ @register_model
806
+ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
807
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
808
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
809
+ """
810
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
811
+ model = _create_vision_transformer(
812
+ 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
813
+ return model
814
+
815
+
816
+ @register_model
817
+ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
818
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
819
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
820
+ """
821
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
822
+ model = _create_vision_transformer(
823
+ 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
824
+ return model
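
A note on the local modification relative to upstream timm: VisionTransformer defaults to variable_input_len=True, so forward_features() resizes the pretrained 14x14 position-embedding grid with bilinear interpolation (resize_pos_embed) to match the input's patch grid. A minimal sketch of that behaviour with a non-square input (shapes assume the ViT-Large defaults above; the input sides are multiples of the 16-pixel patch size):

import torch
from vision_transformer import vit_large_patch16_224_in21k

model = vit_large_patch16_224_in21k()  # variable_input_len=True by default
model.reset_classifier(num_classes=1)
model.eval()

with torch.no_grad():
    # A 320x480 input yields a patch grid of 320/16 x 480/16 = 20 x 30, so the
    # position embeddings are interpolated from 14x14 to 20x30 inside forward_features().
    x = torch.zeros(1, 3, 320, 480)
    feats = model.forward_features(x)  # CLS-token representation
    score = model.head(feats)
print(feats.shape, score.shape)  # torch.Size([1, 1024]) torch.Size([1, 1])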