File size: 48,399 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 |
# FastViT for PyTorch
#
# Original implementation and weights from https://github.com/apple/ml-fastvit
#
# For licensing see accompanying LICENSE file at https://github.com/apple/ml-fastvit/tree/main
# Original work is copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import os
from functools import partial
from typing import Tuple, Optional, Union
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
def num_groups(group_size, channels):
    """Return the conv ``groups`` value implied by a per-group channel count.

    Args:
        group_size: Number of channels per group. A falsy value (0 or None)
            selects a normal convolution with a single group; 1 gives a
            depthwise convolution.
        channels: Total number of channels to split into groups.

    Returns:
        The number of groups.
    """
    if group_size:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size
    return 1  # normal conv with 1 group
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 0,
            inference_mode: bool = False,
            use_se: bool = False,
            use_act: bool = True,
            use_scale_branch: bool = True,
            num_conv_branches: int = 1,
            act_layer: nn.Module = nn.GELU,
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_chs: Number of channels in the input.
            out_chs: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel.
            stride: Stride size.
            dilation: Kernel dilation factor.
            group_size: Convolution group size.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
            act_layer: Activation layer class, instantiated without arguments
                when ``use_act`` is True.
        """
        super().__init__()
        self.inference_mode = inference_mode
        self.groups = num_groups(group_size, in_chs)
        self.stride = stride
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_chs = in_chs
        self.out_chs = out_chs
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity()

        if inference_mode:
            self.reparam_conv = create_conv2d(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=self.groups,
                bias=True,
            )
        else:
            self.reparam_conv = None

            # Re-parameterizable skip connection, only possible when the
            # block preserves both channel count and spatial size.
            if out_chs == in_chs and stride == 1:
                self.identity = nn.BatchNorm2d(num_features=in_chs)
            else:
                self.identity = None

            # Re-parameterizable conv branches. They must use the same
            # dilation as the fused conv built by reparameterize(), otherwise
            # the fused kernel would not reproduce the train-time function.
            if num_conv_branches > 0:
                self.conv_kxk = nn.ModuleList([
                    ConvNormAct(
                        self.in_chs,
                        self.out_chs,
                        kernel_size=kernel_size,
                        stride=self.stride,
                        dilation=dilation,
                        groups=self.groups,
                        apply_act=False,
                    ) for _ in range(self.num_conv_branches)
                ])
            else:
                self.conv_kxk = None

            # Re-parameterizable scale branch (1x1 conv), dilation is
            # irrelevant for a 1x1 kernel.
            self.conv_scale = None
            if kernel_size > 1 and use_scale_branch:
                self.conv_scale = ConvNormAct(
                    self.in_chs,
                    self.out_chs,
                    kernel_size=1,
                    stride=self.stride,
                    groups=self.groups,
                    apply_act=False
                )

        self.act = act_layer() if use_act else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.reparam_conv is not None:
            return self.act(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Identity branch output
        identity_out = 0
        if self.identity is not None:
            identity_out = self.identity(x)

        # Scale branch output
        scale_out = 0
        if self.conv_scale is not None:
            scale_out = self.conv_scale(x)

        # Other kxk conv branches, summed out-of-place
        out = scale_out + identity_out
        if self.conv_kxk is not None:
            for rc in self.conv_kxk:
                out = out + rc(x)

        return self.act(self.se(out))

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        Calling this on a block that already has a fused conv is a no-op.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = create_conv2d(
            in_channels=self.in_chs,
            out_channels=self.out_chs,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()

        self.__delattr__("conv_kxk")
        self.__delattr__("conv_scale")
        if hasattr(self, "identity"):
            self.__delattr__("identity")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.conv_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.conv_kxk is not None:
            for branch in self.conv_kxk:
                _kernel, _bias = self._fuse_bn_tensor(branch)
                kernel_conv = kernel_conv + _kernel
                bias_conv = bias_conv + _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
            self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        For a bare ``nn.BatchNorm2d`` (the identity branch) an identity
        kernel is built and cached on ``self.id_tensor``.

        Args:
            branch: Sequence of ops to be fused.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.BatchNorm2d):
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_chs // self.groups
                center = self.kernel_size // 2
                kernel_value = torch.zeros(
                    (self.in_chs, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                # One 1 at the kernel centre per output channel, placed on
                # that channel's own input slot within its group.
                out_idx = torch.arange(self.in_chs, device=branch.weight.device)
                kernel_value[out_idx, out_idx % input_dim, center, center] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            bn = branch
        else:
            assert isinstance(branch, ConvNormAct)
            kernel = branch.conv.weight
            bn = branch.bn

        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps

        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int,
            stride: int,
            group_size: int,
            small_kernel: Optional[int] = None,
            inference_mode: bool = False,
            act_layer: Optional[nn.Module] = None,
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_chs: Number of input channels.
            out_chs: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            group_size: Group size.
            small_kernel: Kernel size of small kernel conv branch. ``None``
                disables the small branch.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
            act_layer: Activation layer class. ``None`` (the default) applies
                no activation.
        """
        super().__init__()
        self.stride = stride
        self.groups = num_groups(group_size, in_chs)
        self.in_chs = in_chs
        self.out_chs = out_chs

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        if inference_mode:
            self.reparam_conv = create_conv2d(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=1,
                groups=self.groups,
                bias=True,
            )
        else:
            self.reparam_conv = None
            self.large_conv = ConvNormAct(
                in_chs,
                out_chs,
                kernel_size=kernel_size,
                stride=self.stride,
                groups=self.groups,
                apply_act=False,
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = ConvNormAct(
                    in_chs,
                    out_chs,
                    kernel_size=small_kernel,
                    stride=self.stride,
                    groups=self.groups,
                    apply_act=False,
                )
            else:
                # forward() tests `self.small_conv is not None`, so the
                # attribute must exist even without a small branch.
                self.small_conv = None
        # FIXME output of this act was not used in original impl, likely due to bug
        self.act = act_layer() if act_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused conv if present, else the sum of the train-time branches."""
        if self.reparam_conv is not None:
            out = self.reparam_conv(x)
        else:
            out = self.large_conv(x)
            if self.small_conv is not None:
                out = out + self.small_conv(x)
        out = self.act(out)
        return out

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn)
        small_conv = getattr(self, "small_conv", None)
        if small_conv is not None:
            small_k, small_b = self._fuse_bn(small_conv.conv, small_conv.bn)
            eq_b = eq_b + small_b
            # Zero-pad the small kernel to the large kernel's spatial size
            # before summing.
            pad = (self.kernel_size - self.small_kernel) // 2
            eq_k = eq_k + nn.functional.pad(small_k, [pad] * 4)
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """
        Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        Calling this on a module that already has a fused conv is a no-op.
        """
        if self.reparam_conv is not None:
            return

        eq_k, eq_b = self.get_kernel_bias()
        self.reparam_conv = create_conv2d(
            self.in_chs,
            self.out_chs,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups,
            bias=True,
        )

        self.reparam_conv.weight.data = eq_k
        self.reparam_conv.bias.data = eq_b
        self.__delattr__("large_conv")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
            conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with conv layer.

        Args:
            conv: Convolution layer whose ``weight`` is fused.
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
def convolutional_stem(
        in_chs: int,
        out_chs: int,
        act_layer: nn.Module = nn.GELU,
        inference_mode: bool = False
) -> nn.Sequential:
    """Build convolutional stem with MobileOne blocks.

    The stem is three blocks in sequence: a stride-2 3x3 block mapping
    ``in_chs`` to ``out_chs``, a stride-2 3x3 block with ``group_size=1``,
    and a stride-1 1x1 block.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels.
        act_layer: Activation layer passed to every block.
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``

    Returns:
        nn.Sequential object with stem elements.
    """
    # Options common to all three blocks.
    shared = dict(act_layer=act_layer, inference_mode=inference_mode)

    first = MobileOneBlock(
        in_chs=in_chs,
        out_chs=out_chs,
        kernel_size=3,
        stride=2,
        **shared,
    )
    second = MobileOneBlock(
        in_chs=out_chs,
        out_chs=out_chs,
        kernel_size=3,
        stride=2,
        group_size=1,
        **shared,
    )
    third = MobileOneBlock(
        in_chs=out_chs,
        out_chs=out_chs,
        kernel_size=1,
        stride=1,
        **shared,
    )
    return nn.Sequential(first, second, third)
class Attention(nn.Module):
    """Multi-headed Self Attention module.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            head_dim: int = 32,
            qkv_bias: bool = False,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
    ) -> None:
        """Build MHSA module operating on 4D ``(B, C, H, W)`` input.

        Args:
            dim: Number of embedding dimensions.
            head_dim: Number of hidden dimensions per head. Default: ``32``
            qkv_bias: Use bias or not. Default: ``False``
            attn_drop: Dropout rate for attention tensor.
            proj_drop: Dropout rate for projection tensor.
        """
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten the spatial dims to tokens, attend, and restore the 4D shape."""
        batch, chans, height, width = x.shape
        num_tokens = height * width
        tokens = x.flatten(2).transpose(-2, -1)  # (B, N, C)

        qkv = self.qkv(tokens)
        qkv = qkv.reshape(batch, num_tokens, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.fused_attn:
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        # Merge heads back into the channel dimension.
        out = out.transpose(1, 2).reshape(batch, num_tokens, chans)
        out = self.proj_drop(self.proj(out))
        return out.transpose(-2, -1).reshape(batch, chans, height, width)
class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer."""

    def __init__(
            self,
            patch_size: int,
            stride: int,
            in_chs: int,
            embed_dim: int,
            act_layer: nn.Module = nn.GELU,
            lkc_use_act: bool = False,
            inference_mode: bool = False,
    ) -> None:
        """Build patch embedding layer.

        The projection is a reparameterizable large-kernel conv
        (``patch_size`` kernel with a 3x3 small branch) followed by a 1x1
        MobileOne block.

        Args:
            patch_size: Patch size for embedding computation.
            stride: Stride for convolutional embedding layer.
            in_chs: Number of channels of input tensor.
            embed_dim: Number of embedding dimensions.
            act_layer: Activation layer for the 1x1 block, and for the
                large-kernel conv when ``lkc_use_act`` is True.
            lkc_use_act: Whether the large-kernel conv applies ``act_layer``.
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
        """
        super().__init__()
        # NOTE original weights didn't use this act
        large_kernel_act = act_layer if lkc_use_act else None

        large_kernel_conv = ReparamLargeKernelConv(
            in_chs=in_chs,
            out_chs=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            group_size=1,
            small_kernel=3,
            inference_mode=inference_mode,
            act_layer=large_kernel_act,
        )
        pointwise = MobileOneBlock(
            in_chs=embed_dim,
            out_chs=embed_dim,
            kernel_size=1,
            stride=1,
            act_layer=act_layer,
            inference_mode=inference_mode,
        )
        self.proj = nn.Sequential(large_kernel_conv, pointwise)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the projection stack."""
        return self.proj(x)
class LayerScale2d(nn.Module):
    """Learnable per-channel scale with a ``(dim, 1, 1)`` parameter."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        """Create the scale parameter.

        Args:
            dim: Number of entries in the scale parameter.
            init_values: Initial value of every scale entry.
            inplace: If True, scale the input tensor in place.
        """
        super().__init__()
        self.inplace = inplace
        initial = init_values * torch.ones(dim, 1, 1)
        self.gamma = nn.Parameter(initial)

    def forward(self, x):
        """Multiply ``x`` by ``gamma``, in place when ``inplace`` is set."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
            self,
            dim,
            kernel_size=3,
            layer_scale_init_value=1e-5,
            inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5.
                ``None`` disables layer scale.
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                self.dim,
                self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
            self.reparam_conv = None
            self.norm = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                kernel_size,
                group_size=1,
                use_act=False,
            )
            if layer_scale_init_value is not None:
                self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
            else:
                # Must be an instance: forward() calls it on a tensor.
                self.layer_scale = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused conv if present, else the train-time residual mixer."""
        if self.reparam_conv is not None:
            x = self.reparam_conv(x)
        else:
            x = x + self.layer_scale(self.mixer(x) - self.norm(x))
        return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single
        convolutional layer for efficient inference.

        Calling this on a module that already has a fused conv is a no-op.
        """
        if self.inference_mode or self.reparam_conv is not None:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        # Fold `x + layer_scale(mixer(x) - norm(x))` into one kernel and bias;
        # the identity kernel stands for the residual `x` term.
        delta_w = self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
        delta_b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
        if isinstance(self.layer_scale, LayerScale2d):
            w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * delta_w
            b = torch.squeeze(self.layer_scale.gamma) * delta_b
        else:
            w = self.mixer.id_tensor + delta_w
            b = delta_b

        self.reparam_conv = create_conv2d(
            self.dim,
            self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        self.__delattr__("layer_scale")

        # Record the new state so later calls return early.
        self.inference_mode = True
class ConvMlp(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
            self,
            in_chs: int,
            hidden_channels: Optional[int] = None,
            out_chs: Optional[int] = None,
            act_layer: nn.Module = nn.GELU,
            drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_chs: Number of input channels.
            hidden_channels: Number of channels after expansion. Falls back
                to ``in_chs`` when falsy. Default: None
            out_chs: Number of output channels. Falls back to ``in_chs``
                when falsy. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        if not out_chs:
            out_chs = in_chs
        if not hidden_channels:
            hidden_channels = in_chs

        # 7x7 conv with one group per input channel, followed by norm.
        self.conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=7,
            groups=in_chs,
            apply_act=False,
        )
        # Pointwise expand -> activation -> pointwise project.
        self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal init for Conv2d weights; zero for their biases."""
        if not isinstance(m, nn.Conv2d):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, fc1, activation, dropout, fc2, dropout in that order."""
        hidden = self.fc1(self.conv(x))
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
class RepConditionalPosEnc(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    In our implementation, we can reparameterize this module to eliminate a skip connection.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
            inference_mode=False,
    ) -> None:
        """Build reparameterizable conditional positional encoding

        Args:
            dim: Number of input channels.
            dim_out: Number of embedding dimensions. Falls back to ``dim``
                when falsy.
            spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = (spatial_shape, spatial_shape)
        assert isinstance(spatial_shape, Tuple), (
            f'"spatial_shape" must by a sequence or int, '
            f"get {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.dim = dim
        self.dim_out = dim_out or dim
        self.groups = dim

        if inference_mode:
            self.reparam_conv = self._build_conv()
        else:
            self.reparam_conv = None
            self.pos_enc = self._build_conv()

    def _build_conv(self) -> nn.Conv2d:
        """Create the grouped conv used for both train and fused forms."""
        # Pad each axis by half of its own kernel extent so the output keeps
        # the input's spatial size for non-square kernels too.
        padding = (self.spatial_shape[0] // 2, self.spatial_shape[1] // 2)
        return nn.Conv2d(
            self.dim,
            self.dim_out,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=padding,
            groups=self.groups,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused conv if present, else ``pos_enc(x) + x``."""
        if self.reparam_conv is not None:
            x = self.reparam_conv(x)
        else:
            x = self.pos_enc(x) + x
        return x

    def reparameterize(self) -> None:
        """Fold the skip connection into the conv kernel.

        Calling this on a module that already has a fused conv is a no-op.
        """
        if self.reparam_conv is not None:
            return

        weight = self.pos_enc.weight
        kernel_h, kernel_w = self.spatial_shape

        # Build equivalent Id tensor: a single 1 at the kernel centre of
        # every output channel.
        input_dim = self.dim // self.groups
        id_tensor = torch.zeros(
            (self.dim, input_dim, kernel_h, kernel_w),
            dtype=weight.dtype,
            device=weight.device,
        )
        out_idx = torch.arange(self.dim, device=weight.device)
        id_tensor[out_idx, out_idx % input_dim, kernel_h // 2, kernel_w // 2] = 1

        # Reparameterize Id tensor and conv
        w_final = (id_tensor + weight).detach()
        b_final = self.pos_enc.bias.detach()

        # Introduce reparam conv
        self.reparam_conv = self._build_conv()
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__("pos_enc")
class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
            self,
            dim: int,
            kernel_size: int = 3,
            mlp_ratio: float = 4.0,
            act_layer: nn.Module = nn.GELU,
            proj_drop: float = 0.0,
            drop_path: float = 0.0,
            layer_scale_init_value: float = 1e-5,
            inference_mode: bool = False,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for repmixer. Default: 3
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            proj_drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5.
                ``None`` replaces the block's layer scale with an identity.
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )
        self.mlp = ConvMlp(
            in_chs=dim,
            hidden_channels=hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

        if layer_scale_init_value is None:
            self.layer_scale = nn.Identity()
        else:
            self.layer_scale = LayerScale2d(dim, layer_scale_init_value)

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Token-mix, then add the scaled, drop-path'd MLP branch."""
        mixed = self.token_mixer(x)
        mlp_out = self.layer_scale(self.mlp(mixed))
        return mixed + self.drop_path(mlp_out)
class AttentionBlock(nn.Module):
    """Implementation of metaformer block with MHSA as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
            self,
            dim: int,
            mlp_ratio: float = 4.0,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.BatchNorm2d,
            proj_drop: float = 0.0,
            drop_path: float = 0.0,
            layer_scale_init_value: float = 1e-5,
    ):
        """Build Attention Block.

        Args:
            dim: Number of embedding dimensions.
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
            proj_drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5.
                ``None`` replaces both layer scales with identities.
        """
        super().__init__()
        use_layer_scale = layer_scale_init_value is not None
        use_drop_path = drop_path > 0.0

        # Attention branch
        self.norm = norm_layer(dim)
        self.token_mixer = Attention(dim=dim)
        self.layer_scale_1 = (
            LayerScale2d(dim, layer_scale_init_value) if use_layer_scale else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()

        # MLP branch
        self.mlp = ConvMlp(
            in_chs=dim,
            hidden_channels=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.layer_scale_2 = (
            LayerScale2d(dim, layer_scale_init_value) if use_layer_scale else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()

    def forward(self, x):
        """Add the attention branch, then the MLP branch, each as a residual."""
        attn_out = self.token_mixer(self.norm(x))
        x = x + self.drop_path1(self.layer_scale_1(attn_out))
        mlp_out = self.mlp(x)
        x = x + self.drop_path2(self.layer_scale_2(mlp_out))
        return x
class FastVitStage(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
depth: int,
token_mixer_type: str,
downsample: bool = True,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: Optional[float] = 1e-5,
lkc_use_act=False,
inference_mode=False,
):
"""FastViT stage.
Args:
dim: Number of embedding dimensions.
depth: Number of blocks in stage
token_mixer_type: Token mixer type.
kernel_size: Kernel size for repmixer.
mlp_ratio: MLP expansion ratio.
act_layer: Activation layer.
norm_layer: Normalization layer.
proj_drop_rate: Dropout rate.
drop_path_rate: Drop path rate.
layer_scale_init_value: Layer scale value at initialization.
inference_mode: Flag to instantiate block in inference mode.
"""
super().__init__()
self.grad_checkpointing = False
if downsample:
self.downsample = PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
act_layer=act_layer,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
)
else:
assert dim == dim_out
self.downsample = nn.Identity()
if pos_emb_layer is not None:
self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
else:
self.pos_emb = nn.Identity()
blocks = []
for block_idx in range(depth):
if token_mixer_type == "repmixer":
blocks.append(RepMixerBlock(
dim_out,
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
proj_drop=proj_drop_rate,
drop_path=drop_path_rate[block_idx],
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
))
elif token_mixer_type == "attention":
blocks.append(AttentionBlock(
dim_out,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
proj_drop=proj_drop_rate,
drop_path=drop_path_rate[block_idx],
layer_scale_init_value=layer_scale_init_value,
))
else:
raise ValueError(
"Token mixer type: {} not supported".format(token_mixer_type)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
    """Run one stage on ``x``.

    The input goes through ``self.downsample``, then ``self.pos_emb``, and
    finally through ``self.blocks``. When ``self.grad_checkpointing`` is set
    (and the module is not being scripted) the blocks are run through
    ``checkpoint_seq`` instead of being called directly.

    Args:
        x: Stage input; passed to ``self.downsample`` unchanged.

    Returns:
        The output of the last block of the stage.
    """
    # Resolution/width change first (identity when the stage does not
    # downsample), positional embedding second.
    x = self.pos_emb(self.downsample(x))
    # NOTE(review): the scripting guard is deliberately kept inline in the
    # `if` condition, exactly as written originally - confirm before hoisting
    # it into a local, TorchScript may rely on this form.
    if self.grad_checkpointing and not torch.jit.is_scripting():
        return checkpoint_seq(self.blocks, x)
    return self.blocks(x)
class FastVit(nn.Module):
    """FastViT backbone with an optional classification head.

    This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_

    The network is a convolutional stem followed by one ``FastVitStage`` per
    entry of ``layers``. With ``fork_feat=False`` a final ``MobileOneBlock``
    and a ``ClassifierHead`` are appended. With ``fork_feat=True`` no head is
    built and the normalized output of every stage is returned instead.

    Args:
        in_chans: Number of input channels given to the stem.
        layers: Number of blocks in each stage.
        token_mixers: Token mixer type of each stage.
        embed_dims: Output width of each stage; ``embed_dims[0]`` is also the
            stem width.
        mlp_ratios: MLP expansion ratio of each stage.
        downsamples: Per-stage downsampling flag. A stage is also downsampled
            whenever its width differs from the previous width.
        repmixer_kernel_size: Kernel size forwarded to every stage.
        num_classes: Number of classifier outputs (recorded as 0 when
            ``fork_feat`` is set).
        pos_embs: Per-stage positional-embedding factory, or ``None``.
        down_patch_size: Patch size forwarded to every stage.
        down_stride: Stride forwarded to every stage.
        drop_rate: Dropout rate of the classifier head.
        proj_drop_rate: Projection dropout rate forwarded to every stage.
        drop_path_rate: Final drop-path rate; per-block rates are spaced
            linearly from 0 to this value over all blocks.
        layer_scale_init_value: Layer-scale init value forwarded to every stage.
        fork_feat: Return per-stage features instead of class logits.
        cls_ratio: Width multiplier applied to ``embed_dims[-1]`` for the
            final convolution in front of the classifier.
        global_pool: Pooling type of the classifier head.
        norm_layer: Normalization layer factory, forwarded to every stage and
            used for the per-stage output norms in ``fork_feat`` mode.
        act_layer: Activation layer factory.
        lkc_use_act: Forwarded to every stage.
        inference_mode: Forwarded to the stem, every stage and the final conv.
    """
    fork_feat: torch.jit.Final[bool]

    def __init__(
            self,
            in_chans: int = 3,
            layers: Tuple[int, ...] = (2, 2, 6, 2),
            token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
            embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
            mlp_ratios: Tuple[float, ...] = (4,) * 4,
            downsamples: Tuple[bool, ...] = (False, True, True, True),
            repmixer_kernel_size: int = 3,
            num_classes: int = 1000,
            pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
            down_patch_size: int = 7,
            down_stride: int = 2,
            drop_rate: float = 0.0,
            proj_drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,
            layer_scale_init_value: float = 1e-5,
            fork_feat: bool = False,
            cls_ratio: float = 2.0,
            global_pool: str = 'avg',
            norm_layer: nn.Module = nn.BatchNorm2d,
            act_layer: nn.Module = nn.GELU,
            lkc_use_act: bool = False,
            inference_mode: bool = False,
    ) -> None:
        super().__init__()
        self.num_classes = 0 if fork_feat else num_classes
        self.fork_feat = fork_feat
        self.global_pool = global_pool
        self.feature_info = []

        # Convolutional stem
        self.stem = convolutional_stem(
            in_chans,
            embed_dims[0],
            act_layer,
            inference_mode,
        )

        # Build the main stages of the network architecture
        prev_dim = embed_dims[0]
        scale = 1
        # One drop-path rate per block, increasing linearly over the whole
        # network and then regrouped so that dpr[i] holds the rates of stage i.
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
        stages = []
        for i, depth in enumerate(layers):
            # A change of width always needs the downsampling patch embedding,
            # even if the caller did not request it for this stage.
            downsample = downsamples[i] or prev_dim != embed_dims[i]
            stage = FastVitStage(
                dim=prev_dim,
                dim_out=embed_dims[i],
                depth=depth,
                downsample=downsample,
                down_patch_size=down_patch_size,
                down_stride=down_stride,
                pos_emb_layer=pos_embs[i],
                token_mixer_type=token_mixers[i],
                kernel_size=repmixer_kernel_size,
                mlp_ratio=mlp_ratios[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                proj_drop_rate=proj_drop_rate,
                drop_path_rate=dpr[i],
                layer_scale_init_value=layer_scale_init_value,
                lkc_use_act=lkc_use_act,
                inference_mode=inference_mode,
            )
            stages.append(stage)
            prev_dim = embed_dims[i]
            if downsample:
                scale *= 2
            # NOTE(review): the constant 4 presumably is the reduction of the
            # stem - confirm against convolutional_stem.
            self.feature_info.append(
                dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')
            )
        self.stages = nn.Sequential(*stages)
        self.num_features = prev_dim

        # For segmentation and detection, extract intermediate output
        if self.fork_feat:
            # Add a norm layer for each output. self.stages is slightly different than self.network
            # in the original code, the PatchEmbed layer is part of self.stages in this code where
            # it was part of self.network in the original code. So we do not need to skip out indices.
            self.out_indices = [0, 1, 2, 3]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get("FORK_LAST3", None):
                    # For RetinaNet, `start_level=1`: the first norm layer will
                    # not be used, so it is replaced by an identity.
                    # cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
                    # Any non-empty value of the variable enables this.
                    layer = nn.Identity()
                else:
                    layer = norm_layer(embed_dims[i_emb])
                layer_name = f"norm{i_layer}"
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.num_features = final_features = int(embed_dims[-1] * cls_ratio)
            self.final_conv = MobileOneBlock(
                in_chs=embed_dims[-1],
                out_chs=final_features,
                kernel_size=3,
                stride=1,
                group_size=1,
                inference_mode=inference_mode,
                use_se=True,
                act_layer=act_layer,
                num_conv_branches=1,
            )
            self.head = ClassifierHead(
                final_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Init. for classification: only ``nn.Linear`` modules are touched."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters excluded from weight decay (none)."""
        return set()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameter names.

        Args:
            coarse: If True, ``blocks`` is a single per-stage pattern;
                otherwise it is a list of ``(pattern, ordinal)`` pairs that
                separate downsample, positional embedding and each block.

        Returns:
            A dict with the keys ``stem`` and ``blocks``.
        """
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                # The separating dots are escaped so that only a literal '.'
                # is accepted between the stage index and the submodule name.
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.pos_emb', (0,)),
                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Set the ``grad_checkpointing`` flag on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return ``self.head.fc``.

        ``self.head`` is only created when the model is built with
        ``fork_feat=False``.
        """
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Record a new class count and reset the classifier head.

        ``self.head`` is only created when the model is built with
        ``fork_feat=False``.

        Args:
            num_classes: New number of classes.
            global_pool: Forwarded to ``self.head.reset``.
        """
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem and all stages.

        Returns:
            With ``fork_feat=False``: the output of ``self.final_conv``.
            With ``fork_feat=True``: a list with one normalized feature per
            index in ``self.out_indices`` (the declared return annotation
            does not cover this case).
        """
        # input embedding
        x = self.stem(x)
        outs = []
        for idx, block in enumerate(self.stages):
            x = block(x)
            if self.fork_feat:
                if idx in self.out_indices:
                    # The norm layers were registered as norm0, norm1, ...
                    norm_layer = getattr(self, f"norm{idx}")
                    x_out = norm_layer(x)
                    outs.append(x_out)
        if self.fork_feat:
            # output the features of four stages for dense prediction
            return outs
        x = self.final_conv(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
        """Apply the classifier head, optionally stopping at the pre-logits."""
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head output, or the forked stage features when ``fork_feat`` is set."""
        x = self.forward_features(x)
        if self.fork_feat:
            return x
        x = self.forward_head(x)
        return x
def _cfg(url="", **kwargs):
    """Build a pretrained-config dict for a FastViT checkpoint.

    Args:
        url: Value stored under the ``"url"`` key.
        **kwargs: Extra entries; they override the defaults below and add any
            key that is not present yet.

    Returns:
        The config dict: the defaults with ``kwargs`` merged on top.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 256, 256),
        pool_size=(8, 8),
        crop_pct=0.9,
        interpolation="bicubic",
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv=('stem.0.conv_kxk.0.conv', 'stem.0.conv_scale.conv'),
        classifier="head.fc",
    )
    # Caller-supplied values win over the defaults.
    cfg.update(kwargs)
    return cfg
# Pretrained configs: every variant exists with two pretrained tags
# ("apple_in1k" and "apple_dist_in1k"). All entries share the defaults of
# `_cfg`; only the ma36 variant overrides `crop_pct`.
default_cfgs = generate_default_cfgs({
    f"fastvit_{size}.{tag}": _cfg(hf_hub_id='timm/', **overrides)
    for tag in ("apple_in1k", "apple_dist_in1k")
    for size, overrides in (
        ("t8", {}),
        ("t12", {}),
        ("s12", {}),
        ("sa12", {}),
        ("sa24", {}),
        ("sa36", {}),
        ("ma36", {"crop_pct": 0.95}),
    )
})
def _create_fastvit(variant, pretrained=False, **kwargs):
    """Create a ``FastVit`` model through ``build_model_with_cfg``.

    Args:
        variant: Model variant name handed to ``build_model_with_cfg``.
        pretrained: Handed to ``build_model_with_cfg`` unchanged.
        **kwargs: Model arguments. ``out_indices`` is removed from them and
            placed into the feature config (default ``(0, 1, 2, 3)``).

    Returns:
        Whatever ``build_model_with_cfg`` returns.
    """
    # out_indices must not reach the model constructor, so it is popped here.
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        FastVit,
        variant,
        pretrained,
        feature_cfg=feature_cfg,
        **kwargs
    )
@register_model
def fastvit_t8(pretrained=False, **kwargs):
    """Instantiate FastViT-T8 model variant.

    Caller-supplied ``kwargs`` override the variant defaults below.
    """
    variant_args = {
        "layers": (2, 2, 4, 2),
        "embed_dims": (48, 96, 192, 384),
        "mlp_ratios": (3,) * 4,
        "token_mixers": ("repmixer",) * 4,
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_t8', pretrained=pretrained, **variant_args)
@register_model
def fastvit_t12(pretrained=False, **kwargs):
    """Instantiate FastViT-T12 model variant.

    Caller-supplied ``kwargs`` override the variant defaults below.
    """
    variant_args = {
        "layers": (2, 2, 6, 2),
        "embed_dims": (64, 128, 256, 512),
        "mlp_ratios": (3,) * 4,
        "token_mixers": ("repmixer",) * 4,
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_t12', pretrained=pretrained, **variant_args)
@register_model
def fastvit_s12(pretrained=False, **kwargs):
    """Instantiate FastViT-S12 model variant.

    Caller-supplied ``kwargs`` override the variant defaults below.
    """
    variant_args = {
        "layers": (2, 2, 6, 2),
        "embed_dims": (64, 128, 256, 512),
        "mlp_ratios": (4,) * 4,
        "token_mixers": ("repmixer",) * 4,
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_s12', pretrained=pretrained, **variant_args)
@register_model
def fastvit_sa12(pretrained=False, **kwargs):
    """Instantiate FastViT-SA12 model variant.

    The last stage uses attention with a ``RepConditionalPosEnc`` positional
    embedding. Caller-supplied ``kwargs`` override the variant defaults below.
    """
    last_stage_pos_emb = partial(RepConditionalPosEnc, spatial_shape=(7, 7))
    variant_args = {
        "layers": (2, 2, 6, 2),
        "embed_dims": (64, 128, 256, 512),
        "mlp_ratios": (4,) * 4,
        "pos_embs": (None, None, None, last_stage_pos_emb),
        "token_mixers": ("repmixer",) * 3 + ("attention",),
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_sa12', pretrained=pretrained, **variant_args)
@register_model
def fastvit_sa24(pretrained=False, **kwargs):
    """Instantiate FastViT-SA24 model variant.

    The last stage uses attention with a ``RepConditionalPosEnc`` positional
    embedding. Caller-supplied ``kwargs`` override the variant defaults below.
    """
    last_stage_pos_emb = partial(RepConditionalPosEnc, spatial_shape=(7, 7))
    variant_args = {
        "layers": (4, 4, 12, 4),
        "embed_dims": (64, 128, 256, 512),
        "mlp_ratios": (4,) * 4,
        "pos_embs": (None, None, None, last_stage_pos_emb),
        "token_mixers": ("repmixer",) * 3 + ("attention",),
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_sa24', pretrained=pretrained, **variant_args)
@register_model
def fastvit_sa36(pretrained=False, **kwargs):
    """Instantiate FastViT-SA36 model variant.

    The last stage uses attention with a ``RepConditionalPosEnc`` positional
    embedding. Caller-supplied ``kwargs`` override the variant defaults below.
    """
    last_stage_pos_emb = partial(RepConditionalPosEnc, spatial_shape=(7, 7))
    variant_args = {
        "layers": (6, 6, 18, 6),
        "embed_dims": (64, 128, 256, 512),
        "mlp_ratios": (4,) * 4,
        "pos_embs": (None, None, None, last_stage_pos_emb),
        "token_mixers": ("repmixer",) * 3 + ("attention",),
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_sa36', pretrained=pretrained, **variant_args)
@register_model
def fastvit_ma36(pretrained=False, **kwargs):
    """Instantiate FastViT-MA36 model variant.

    The last stage uses attention with a ``RepConditionalPosEnc`` positional
    embedding. Caller-supplied ``kwargs`` override the variant defaults below.
    """
    last_stage_pos_emb = partial(RepConditionalPosEnc, spatial_shape=(7, 7))
    variant_args = {
        "layers": (6, 6, 18, 6),
        "embed_dims": (76, 152, 304, 608),
        "mlp_ratios": (4,) * 4,
        "pos_embs": (None, None, None, last_stage_pos_emb),
        "token_mixers": ("repmixer",) * 3 + ("attention",),
    }
    variant_args.update(kwargs)
    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **variant_args)
|