File size: 5,857 Bytes
18dd6ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Copyright (c) Facebook, Inc. and its affiliates.
import fvcore.nn.weight_init as weight_init
import torch.nn.functional as F
from annotator.oneformer.detectron2.layers import CNNBlockBase, Conv2d, get_norm
from annotator.oneformer.detectron2.modeling import BACKBONE_REGISTRY
from annotator.oneformer.detectron2.modeling.backbone.resnet import (
BasicStem,
BottleneckBlock,
DeformBottleneckBlock,
ResNet,
)
class DeepLabStem(CNNBlockBase):
    """
    Stem of the DeepLab ResNet variant, i.e. the layers that run before the
    first residual block.

    It consists of three 3x3 convolutions (the first one with stride 2), each
    constructed with a norm from :func:`get_norm` and followed by an in-place
    ReLU, and a final 3x3 max-pool with stride 2.
    """

    def __init__(self, in_channels=3, out_channels=128, norm="BN"):
        """
        Args:
            in_channels (int): number of channels of the input.
            out_channels (int): number of channels produced by the last conv;
                the first two convs produce ``out_channels // 2`` channels.
            norm (str or callable): norm passed to every conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        # The third positional argument is 4: conv1 (stride 2) and the
        # max-pool in forward() (stride 2) downsample by 4 in total.
        # NOTE(review): presumably this is the base class' overall stride.
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels

        half_channels = out_channels // 2

        def conv3x3(c_in, c_out, stride):
            # 3x3, padding 1, no bias, with a norm sized for the conv output.
            return Conv2d(
                c_in,
                c_out,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
                norm=get_norm(norm, c_out),
            )

        # Build all three convs first, then initialize them, in this order.
        self.conv1 = conv3x3(in_channels, half_channels, 2)
        self.conv2 = conv3x3(half_channels, half_channels, 1)
        self.conv3 = conv3x3(half_channels, out_channels, 1)

        for conv in (self.conv1, self.conv2, self.conv3):
            weight_init.c2_msra_fill(conv)

    def forward(self, x):
        """Run the three conv + ReLU steps, then the 3x3 stride-2 max-pool."""
        x = F.relu_(self.conv1(x))
        x = F.relu_(self.conv2(x))
        x = F.relu_(self.conv3(x))
        return F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
@BACKBONE_REGISTRY.register()
def build_resnet_deeplab_backbone(cfg, input_shape):
    """
    Build a ResNet backbone, with either a basic or a DeepLab stem, from config.

    Args:
        cfg: config object; ``cfg.MODEL.RESNETS.*`` and
            ``cfg.MODEL.BACKBONE.FREEZE_AT`` are read.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        ResNet: the value returned by ``ResNet(...).freeze(FREEZE_AT)``.

    Raises:
        ValueError: if ``STEM_TYPE`` is neither ``"basic"`` nor ``"deeplab"``.
        AssertionError: if the res4/res5 dilation values are not supported.
        KeyError: if ``DEPTH`` is not 50, 101 or 152, or if an entry of
            ``OUT_FEATURES`` is not one of ``res2`` .. ``res5``.
    """
    resnets = cfg.MODEL.RESNETS

    # need registration of new blocks/stems?
    norm = resnets.NORM

    stem_type = resnets.STEM_TYPE
    if stem_type == "basic":
        stem_class = BasicStem
    elif stem_type == "deeplab":
        stem_class = DeepLabStem
    else:
        raise ValueError("Unknown stem type: {}".format(cfg.MODEL.RESNETS.STEM_TYPE))
    stem = stem_class(
        in_channels=input_shape.channels,
        out_channels=resnets.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnets.OUT_FEATURES
    depth = resnets.DEPTH
    num_groups = resnets.NUM_GROUPS
    width_per_group = resnets.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets.STEM_OUT_CHANNELS
    out_channels = resnets.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets.STRIDE_IN_1X1
    res4_dilation = resnets.RES4_DILATION
    res5_dilation = resnets.RES5_DILATION
    deform_on_per_stage = resnets.DEFORM_ON_PER_STAGE
    deform_modulated = resnets.DEFORM_MODULATED
    deform_num_groups = resnets.DEFORM_NUM_GROUPS
    res5_multi_grid = resnets.RES5_MULTI_GRID

    assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation)
    assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation)
    if res4_dilation == 2:
        # Always dilate res5 if res4 is dilated.
        assert res5_dilation == 4

    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]

    # Only build stages up to the deepest requested output feature:
    # avoid creating variables without gradients, which consume extra
    # memory and may cause allreduce to fail.
    stage_number_of = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    last_stage = max([stage_number_of[name] for name in out_features])

    # Stages other than res4/res5 are never dilated.
    dilation_of = {4: res4_dilation, 5: res5_dilation}

    stages = []
    for pos in range(last_stage - 1):
        stage_number = pos + 2  # pos 0 is res2
        dilation = dilation_of.get(stage_number, 1)

        # The first stage and any dilated stage keep the resolution;
        # every other stage starts with a stride-2 block.
        if pos == 0 or dilation > 1:
            first_stride = 1
        else:
            first_stride = 2

        block_count = num_blocks_per_stage[pos]
        stage_kargs = {
            "num_blocks": block_count,
            "stride_per_block": [first_stride] + [1] * (block_count - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
            "bottleneck_channels": bottleneck_channels,
            "stride_in_1x1": stride_in_1x1,
            "dilation": dilation,
            "num_groups": num_groups,
        }

        if deform_on_per_stage[pos]:
            stage_kargs.update(
                block_class=DeformBottleneckBlock,
                deform_modulated=deform_modulated,
                deform_num_groups=deform_num_groups,
            )
        else:
            stage_kargs["block_class"] = BottleneckBlock

        if stage_number == 5:
            # res5 takes a per-block dilation (base dilation scaled by the
            # multi-grid factors) instead of a single stage-wide value.
            del stage_kargs["dilation"]
            stage_kargs["dilation_per_block"] = [dilation * grid for grid in res5_multi_grid]

        stages.append(ResNet.make_stage(**stage_kargs))

        # Channel widths double from one stage to the next.
        in_channels, out_channels = out_channels, out_channels * 2
        bottleneck_channels = bottleneck_channels * 2

    backbone = ResNet(stem, stages, out_features=out_features)
    return backbone.freeze(freeze_at)
|