# LKCell/models/utils/residual.py -- uploaded by qingke1, "initial commit" (aea73e2)
# -*- coding: utf-8 -*-
# Residual block as defined in:
# He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
# for image recognition." In Proceedings of the IEEE conference on computer vision
# and pattern recognition, pp. 770-778. 2016.
#
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import torch
import torch.nn as nn
from collections import OrderedDict
from models.utils.tf_utils import TFSamepaddingLayer
class ResidualBlock(nn.Module):
    """Residual stage built from ``unit_count`` three-convolution units.

    Based on He et al., "Deep Residual Learning for Image Recognition"
    (CVPR 2016, pp. 770-778), adapted from the HoverNet implementation.

    Each unit is ``conv1 -> BN -> ReLU -> pad -> conv2 -> BN -> ReLU -> conv3``.
    Every unit except the first is additionally prefixed with a BN + ReLU
    pre-activation, and the output of the whole stage is passed through a
    closing BN + ReLU (``blk_bna``).

    Args:
        in_ch: Channel count of the incoming feature map.
        unit_ksize: Kernel sizes; entries 0, 1 and 2 are used for conv1,
            conv2 and conv3. Must have the same length as ``unit_ch``.
        unit_ch: Output channels; entries 0, 1 and 2 are used for conv1,
            conv2 and conv3, and the last entry is the stage's output width.
        unit_count: Number of residual units to stack.
        stride: Stride of conv2 in the first unit and of the shortcut
            projection. Defaults to 1.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1):
        super().__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # NOTE: meant for inference, so the BatchNorm initial values may not
        # match the TensorFlow reference implementation.
        self.units = nn.ModuleList()
        unit_in_ch = in_ch
        for idx in range(unit_count):
            is_first = idx == 0
            # Only the first unit applies the requested stride.
            unit = self._make_unit(
                unit_in_ch,
                unit_ksize,
                unit_ch,
                stride if is_first else 1,
                is_first,
            )
            self.units.append(unit)
            unit_in_ch = unit_ch[-1]

        # A 1x1 projection is needed whenever the identity path would not
        # match the residual path in channel count or spatial size.
        needs_projection = in_ch != unit_ch[-1] or stride != 1
        if needs_projection:
            self.shortcut = nn.Conv2d(
                in_ch, unit_ch[-1], 1, stride=stride, bias=False
            )
        else:
            self.shortcut = None

        self.blk_bna = nn.Sequential(
            OrderedDict(
                bn=nn.BatchNorm2d(unit_in_ch, eps=1e-5),
                relu=nn.ReLU(inplace=True),
            )
        )

    @staticmethod
    def _make_unit(in_channels, ksizes, channels, conv2_stride, skip_preact):
        """Assemble one residual unit as a named ``nn.Sequential``."""
        preact = [
            ("preact/bn", nn.BatchNorm2d(in_channels, eps=1e-5)),
            ("preact/relu", nn.ReLU(inplace=True)),
        ]
        conv1 = nn.Conv2d(
            in_channels, channels[0], ksizes[0], stride=1, padding=0, bias=False
        )
        conv1_bn = nn.BatchNorm2d(channels[0], eps=1e-5)
        # Project helper; presumably emulates TensorFlow "SAME" padding so
        # that conv2 itself can run with padding=0 -- confirm in tf_utils.
        conv2_pad = TFSamepaddingLayer(ksize=ksizes[1], stride=conv2_stride)
        conv2 = nn.Conv2d(
            channels[0],
            channels[1],
            ksizes[1],
            stride=conv2_stride,
            padding=0,
            bias=False,
        )
        conv2_bn = nn.BatchNorm2d(channels[1], eps=1e-5)
        conv3 = nn.Conv2d(
            channels[1], channels[2], ksizes[2], stride=1, padding=0, bias=False
        )
        body = [
            ("conv1", conv1),
            ("conv1/bn", conv1_bn),
            ("conv1/relu", nn.ReLU(inplace=True)),
            ("conv2/pad", conv2_pad),
            ("conv2", conv2),
            ("conv2/bn", conv2_bn),
            ("conv2/relu", nn.ReLU(inplace=True)),
            ("conv3", conv3),
        ]
        # The previous block already ends with BN + ReLU, so the first unit
        # of this block must not start with another pre-activation.
        layers = body if skip_preact else preact + body
        return nn.Sequential(OrderedDict(layers))

    def out_ch(self):
        """Return the number of output channels (last entry of ``unit_ch``)."""
        return self.unit_ch[-1]

    def init_weights(self):
        """Re-initialise all sub-modules of this block.

        Conv2d weights get Kaiming-normal (fan_out, ReLU) initialisation.
        Modules whose class name contains "norm" get weight 1 and bias 0.
        Modules whose class name contains "linear" get their bias zeroed
        when they have one.
        """
        for module in self.modules():
            kind = type(module).__name__.lower()
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
            if "norm" in kind:
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            if "linear" in kind and module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, prev_feat, freeze=False):
        """Run the residual units followed by the closing BN + ReLU.

        Args:
            prev_feat: Input feature map.
            freeze: When True and the module is in training mode, the units
                are evaluated with gradient tracking disabled. It has no
                effect in eval mode. The shortcut projection and the closing
                BN + ReLU are not covered by this switch.

        Returns:
            The feature map produced by ``blk_bna``.
        """
        if self.shortcut is not None:
            shortcut = self.shortcut(prev_feat)
        else:
            shortcut = prev_feat

        for unit in self.units:
            if not self.training:
                # Eval mode: leave the caller's grad mode untouched.
                residual = unit(prev_feat)
            else:
                with torch.set_grad_enabled(not freeze):
                    residual = unit(prev_feat)
            prev_feat = residual + shortcut
            # Later units use the running sum as their identity path.
            shortcut = prev_feat

        return self.blk_bna(prev_feat)