|
""" Attention Factory |
|
|
|
Hacked together by / Copyright 2021 Ross Wightman |
|
""" |
|
import torch |
|
from functools import partial |
|
|
|
from .bottleneck_attn import BottleneckAttn |
|
from .cbam import CbamModule, LightCbamModule |
|
from .eca import EcaModule, CecaModule |
|
from .gather_excite import GatherExcite |
|
from .global_context import GlobalContext |
|
from .halo_attn import HaloAttn |
|
from .lambda_layer import LambdaLayer |
|
from .non_local_attn import NonLocalAttn, BatNonLocalAttn |
|
from .selective_kernel import SelectiveKernel |
|
from .split_attn import SplitAttn |
|
from .squeeze_excite import SEModule, EffectiveSEModule |
|
|
|
|
|
def get_attn(attn_type):
    """Resolve an attention specifier to an attention module class or factory.

    Args:
        attn_type: What to resolve. Handled as follows, in this order:

            * a ``torch.nn.Module`` instance is returned unchanged;
            * any falsy value (``None``, ``False``, ``''``, ``0``, ...)
              resolves to ``None``, i.e. no attention;
            * ``True`` resolves to ``SEModule``;
            * any other truthy non-string object is returned unchanged
              (presumably a module class or other callable -- not checked);
            * a string is lower-cased and looked up among the known names:
              'se', 'ese', 'eca', 'ecam', 'ceca', 'ge', 'gc', 'gca', 'cbam',
              'lcbam', 'sk', 'splat', 'lambda', 'bottleneck', 'halo', 'nl',
              'bat'.

    Returns:
        The resolved class, a ``functools.partial`` wrapping a class with
        preset keyword arguments ('ecam', 'gca'), the object that was passed
        in, or ``None``.

    Raises:
        AssertionError: If ``attn_type`` is a non-empty string that does not
            name a known attention module.
    """
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    if not attn_type:
        return None
    if isinstance(attn_type, bool):
        # Only True can reach this point; it selects squeeze-excite.
        return SEModule
    if not isinstance(attn_type, str):
        # Non-string, non-bool truthy value: hand it back untouched.
        return attn_type

    name = attn_type.lower()
    if name == 'se':
        return SEModule
    if name == 'ese':
        return EffectiveSEModule
    if name == 'eca':
        return EcaModule
    if name == 'ecam':
        return partial(EcaModule, use_mlp=True)
    if name == 'ceca':
        return CecaModule
    if name == 'ge':
        return GatherExcite
    if name == 'gc':
        return GlobalContext
    if name == 'gca':
        return partial(GlobalContext, fuse_add=True, fuse_scale=False)
    if name == 'cbam':
        return CbamModule
    if name == 'lcbam':
        return LightCbamModule
    if name == 'sk':
        return SelectiveKernel
    if name == 'splat':
        return SplitAttn
    if name == 'lambda':
        return LambdaLayer
    if name == 'bottleneck':
        return BottleneckAttn
    if name == 'halo':
        return HaloAttn
    if name == 'nl':
        return NonLocalAttn
    if name == 'bat':
        return BatNonLocalAttn

    # Raised explicitly instead of `assert False, ...`: assert statements are
    # removed under `python -O`, which would let an unknown name fall through
    # and silently resolve to None. AssertionError is kept as the exception
    # type so existing callers see the same error as before.
    raise AssertionError("Invalid attn module (%s)" % name)
|
|
|
|
|
def create_attn(attn_type, channels, **kwargs):
    """Build an attention layer from a specifier, or return ``None``.

    Args:
        attn_type: Specifier forwarded to ``get_attn`` for resolution.
        channels: Passed as the first positional argument to whatever
            ``get_attn`` returns.
        **kwargs: Forwarded as keyword arguments to that same call.

    Returns:
        The result of calling the resolved object with
        ``(channels, **kwargs)``, or ``None`` when ``get_attn`` returns
        ``None``.
    """
    attn_factory = get_attn(attn_type)
    if attn_factory is None:
        return None
    # NOTE(review): get_attn passes nn.Module instances through unchanged, so
    # an instance given here is *called* with `channels` rather than returned
    # as-is -- confirm that is the intended behaviour.
    return attn_factory(channels, **kwargs)
|
|