|
import torch.nn as nn
|
|
|
|
from .common import ResBlock, default_conv
|
|
|
|
def encoder(in_channels, n_feats):
    """Build the convolutional feature encoder (RGB / IR).

    The encoder is three 5x5 convolutions (padding 2) applied back to back,
    with no activation or normalization layers in between:

    * ``in_channels -> n_feats``, stride 1
    * ``n_feats -> 2 * n_feats``, stride 2
    * ``2 * n_feats -> 3 * n_feats``, stride 2

    Args:
        in_channels: number of channels of the input tensor.
        n_feats: base channel width; the output has ``3 * n_feats`` channels.

    Returns:
        An ``nn.Sequential`` holding the three ``nn.Conv2d`` layers.
    """
    # Channel width at every stage boundary, and the stride of each stage.
    widths = [in_channels, 1 * n_feats, 2 * n_feats, 3 * n_feats]
    strides = [1, 2, 2]

    # Layers are created front to back, one per (input width, output width).
    stages = [
        nn.Conv2d(c_in, c_out, 5, stride=step, padding=2)
        for c_in, c_out, step in zip(widths[:-1], widths[1:], strides)
    ]
    return nn.Sequential(*stages)
|
|
|
|
def decoder(out_channels, n_feats):
    """Build the convolutional decoder (RGB / IR / Depth).

    The decoder is two 3x3 transposed convolutions followed by one 5x5
    convolution, with no activation or normalization layers in between:

    * ``3 * n_feats -> 2 * n_feats``, transposed, stride 2
    * ``2 * n_feats -> n_feats``, transposed, stride 2
    * ``n_feats -> out_channels``, 5x5, stride 1, padding 2

    Args:
        out_channels: number of channels of the produced tensor.
        n_feats: base channel width; the input must have ``3 * n_feats``
            channels.

    Returns:
        An ``nn.Sequential`` holding the three layers in the order above.
    """
    # Options shared by both transposed convolutions. With a 3x3 kernel,
    # stride 2, padding 1 and output_padding 1, each layer doubles H and W.
    upsample_opts = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

    up_first = nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, **upsample_opts)
    up_second = nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, **upsample_opts)
    project = nn.Conv2d(n_feats, out_channels, kernel_size=5, stride=1, padding=2)

    return nn.Sequential(up_first, up_second, project)
|
|
|
|
|
|
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
    """Build a sequential stack of residual blocks.

    The resulting module is, in order:

    1. an optional head ``default_conv(in_channels, n_feats, kernel_size)``,
       added only when ``in_channels`` is not ``None``;
    2. ``n_blocks`` independently constructed ``ResBlock(n_feats, 3)`` modules;
    3. an optional tail ``default_conv(n_feats, out_channels, kernel_size)``,
       added only when ``out_channels`` is not ``None``.

    Args:
        n_feats: channel width passed to every ``ResBlock`` and used as the
            output / input width of the optional head / tail convolutions.
        kernel_size: forwarded to the optional head and tail convolutions
            only. The residual blocks always receive the literal ``3`` as
            their second argument (presumably their kernel size; confirm
            against ``ResBlock`` in ``.common``).
        n_blocks: number of residual blocks; values ``<= 0`` produce none.
        in_channels: input channels of the head convolution, or ``None`` to
            omit the head.
        out_channels: output channels of the tail convolution, or ``None``
            to omit the tail.

    Returns:
        An ``nn.Sequential`` containing the layers described above.
    """
    m = []

    if in_channels is not None:
        m.append(default_conv(in_channels, n_feats, kernel_size))

    # Create a separate ResBlock for every position. Multiplying a
    # one-element list (``[ResBlock(...)] * n_blocks``) repeats the *same*
    # module object, so all blocks would share one set of parameters and the
    # stack would have the capacity of a single block.
    m.extend(ResBlock(n_feats, 3) for _ in range(n_blocks))

    if out_channels is not None:
        m.append(default_conv(n_feats, out_channels, kernel_size))

    return nn.Sequential(*m)
|
|
|
|
|