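# NOTE(review): the next line appears to be repository-page residue
# (user name, commit message "add file", short commit hash) - confirm and drop.
# ZJF-Thunder / e26e560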
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.backbones.res2net import Bottle2neck
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.utils import SimplifiedBasicBlock
def is_block(modules):
    """Check whether an object is a ResNet-style building block.

    Args:
        modules: The object to test. Despite the plural name, a single
            object is expected; it is passed straight to ``isinstance``.

    Returns:
        bool: True if ``modules`` is an instance of ``BasicBlock``,
        ``Bottleneck``, ``BottleneckX``, ``Bottle2neck`` or
        ``SimplifiedBasicBlock``; False otherwise.
    """
    block_types = (BasicBlock, Bottleneck, BottleneckX, Bottle2neck,
                   SimplifiedBasicBlock)
    return isinstance(modules, block_types)
def is_norm(modules):
    """Check whether an object is one of the recognised norm layers.

    Args:
        modules: The object to test. Despite the plural name, a single
            object is expected; it is passed straight to ``isinstance``.

    Returns:
        bool: True if ``modules`` is a ``GroupNorm`` or any ``_BatchNorm``
        subclass instance; False otherwise.
    """
    return isinstance(modules, (GroupNorm, _BatchNorm))
def check_norm_state(modules, train_state):
    """Check that every batch-norm layer is in the expected train state.

    Only ``_BatchNorm`` instances are inspected; every other object in
    ``modules`` (including ``GroupNorm`` layers) is ignored.

    Args:
        modules: Iterable of objects to inspect, e.g. the result of
            ``model.modules()``.
        train_state: The value each batch-norm layer's ``training``
            attribute is expected to equal.

    Returns:
        bool: False as soon as a ``_BatchNorm`` whose ``training`` differs
        from ``train_state`` is found; True otherwise, including when
        ``modules`` contains no ``_BatchNorm`` at all.
    """
    return all(
        mod.training == train_state
        for mod in modules
        if isinstance(mod, _BatchNorm)
    )