""" NormAct (Normalization + Activation Layer) Factory

Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
instances in models. Where these are used it will be possible to swap separate BN + act layers with
combined modules like IABN or EvoNorms.

Hacked together by / Copyright 2020 Ross Wightman
"""

import types
import functools

import torch

from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
from .inplace_abn import InplaceAbn

_NORM_ACT_MAP = dict(
    batchnorm=BatchNormAct2d,
    batchnorm2d=BatchNormAct2d,
    groupnorm=GroupNormAct,
    groupnorm1=functools.partial(GroupNormAct, num_groups=1),
    layernorm=LayerNormAct,
    layernorm2d=LayerNormAct2d,
    evonormb0=EvoNorm2dB0,
    evonormb1=EvoNorm2dB1,
    evonormb2=EvoNorm2dB2,
    evonorms0=EvoNorm2dS0,
    evonorms0a=EvoNorm2dS0a,
    evonorms1=EvoNorm2dS1,
    evonorms1a=EvoNorm2dS1a,
    evonorms2=EvoNorm2dS2,
    evonorms2a=EvoNorm2dS2a,
    frn=FilterResponseNormAct2d,
    frntlu=FilterResponseNormTlu2d,
    inplaceabn=InplaceAbn,
    iabn=InplaceAbn,
)
_NORM_ACT_TYPES = set(_NORM_ACT_MAP.values())
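
# Note: multiple aliases map to the same class, e.g. 'batch_norm', 'batchnorm', and 'batchnorm2d'
# all resolve to BatchNormAct2d once get_norm_act_layer() normalizes the name (see below).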

# Layer types that accept an `act_layer` constructor argument.
_NORM_ACT_REQUIRES_ARG = {
    BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}


def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
    layer = get_norm_act_layer(layer_name, act_layer=act_layer)
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance
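
# A minimal usage sketch (the `nn.ReLU` act_layer here is just an illustrative choice):
#   import torch.nn as nn
#   norm_act = create_norm_act_layer('batchnorm2d', 64, act_layer=nn.ReLU)
#   y = norm_act(torch.randn(2, 64, 8, 8))  # norm + activation applied by a single module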


def get_norm_act_layer(norm_layer, act_layer=None):
    if norm_layer is None:
        return None
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
    norm_act_kwargs = {}

    # unbind partial fn, so its kwargs can be merged and rebound with any new kwargs below
    if isinstance(norm_layer, functools.partial):
        norm_act_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        if not norm_layer:
            return None
        # normalize the name: drop underscores, lowercase, and keep only the part before any '-'
        layer_name = norm_layer.replace('_', '').lower().split('-')[0]
        norm_act_layer = _NORM_ACT_MAP[layer_name]
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, assumed to be a lambda/fn that creates a norm_act layer
        norm_act_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower()
        if type_name.startswith('batchnorm'):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith('groupnorm1'):
            # must check 'groupnorm1' before plain 'groupnorm' or this branch is unreachable
            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
        elif type_name.startswith('groupnorm'):
            norm_act_layer = GroupNormAct
        elif type_name.startswith('layernorm2d'):
            norm_act_layer = LayerNormAct2d
        elif type_name.startswith('layernorm'):
            norm_act_layer = LayerNormAct
        else:
            assert False, f"No equivalent norm_act layer for {type_name}"

    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
        # pass `act_layer` through for backwards compat, where `act_layer=None` implies no activation
        norm_act_kwargs.setdefault('act_layer', act_layer)
    if norm_act_kwargs:
        # bind/rebind any collected kwargs via partial
        norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)
    return norm_act_layer
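
# A minimal usage sketch showing kwargs carried through a partial (illustrative values only):
#   import torch.nn as nn
#   layer_cls = get_norm_act_layer(functools.partial(nn.GroupNorm, num_groups=8), act_layer=nn.SiLU)
#   gn_act = layer_cls(64)  # -> GroupNormAct(64, num_groups=8) with a SiLU activation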