hyliu's picture
Upload folder using huggingface_hub
8ec10cf verified
raw
history blame contribute delete
1.76 kB
import torch.nn as nn
from .common import ResBlock, default_conv
def encoder(in_channels, n_feats):
"""RGB / IR feature encoder
"""
# in_channels == 1 or 3 or 4 or ....
# After 1st conv, B x n_feats x H x W
# After 2nd conv, B x 2n_feats x H/2 x W/2
# After 3rd conv, B x 3n_feats x H/4 x W/4
return nn.Sequential(
nn.Conv2d(in_channels, 1 * n_feats, 5, stride=1, padding=2),
nn.Conv2d(1 * n_feats, 2 * n_feats, 5, stride=2, padding=2),
nn.Conv2d(2 * n_feats, 3 * n_feats, 5, stride=2, padding=2),
)
def decoder(out_channels, n_feats):
"""RGB / IR / Depth decoder
"""
# After 1st deconv, B x 2n_feats x H/2 x W/2
# After 2nd deconv, B x n_feats x H x W
# After 3rd conv, B x out_channels x H x W
deconv_kargs = {'stride': 2, 'padding': 1, 'output_padding': 1}
return nn.Sequential(
nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, 3, **deconv_kargs),
nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, 3, **deconv_kargs),
nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2),
)
# def ResNet(n_feats, in_channels=None, out_channels=None):
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
"""sequential ResNet
"""
# if in_channels is None:
# in_channels = n_feats
# if out_channels is None:
# out_channels = n_feats
# # currently not implemented
m = []
if in_channels is not None:
m += [default_conv(in_channels, n_feats, kernel_size)]
m += [ResBlock(n_feats, 3)] * n_blocks
if out_channels is not None:
m += [default_conv(n_feats, out_channels, kernel_size)]
return nn.Sequential(*m)