File size: 1,643 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from typing import Callable, Optional, Tuple, Union
from timm.layers import Mlp, PatchEmbed
from timm.models.vision_transformer import Block, VisionTransformer
from .head import DecoderLinear
class ViTForSeg(VisionTransformer):
def __init__(self, img_size: int | Tuple[int, int] = 224, patch_size: int | Tuple[int, int] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, qkv_bias: bool = True, qk_norm: bool = False, init_values: float | None = None, class_token: bool = True, no_embed_class: bool = False, pre_norm: bool = False, fc_norm: bool | None = None, drop_rate: float = 0, pos_drop_rate: float = 0, patch_drop_rate: float = 0, proj_drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, weight_init: str = '', embed_layer: Callable[..., Any] = ..., norm_layer: Callable[..., Any] | None = None, act_layer: Callable[..., Any] | None = None, block_fn: Callable[..., Any] = ..., mlp_layer: Callable[..., Any] = ...):
super().__init__(img_size, patch_size, in_chans, num_classes, global_pool, embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_norm, init_values, class_token, no_embed_class, pre_norm, fc_norm, drop_rate, pos_drop_rate, patch_drop_rate, proj_drop_rate, attn_drop_rate, drop_path_rate, weight_init, embed_layer, norm_layer, act_layer, block_fn, mlp_layer)
self.head = DecoderLinear(num)
def forward_head(self, x, pre_logits: bool = False):
return self.head(x)
def init_from_vit(self, vit):
self.load_state_dict(vit.state_dict(), strict=False)
|