import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
import functools
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from ..net_util import *


class ResBlkPIFuNet(BasePIFuNet):
    def __init__(self, opt,
                 projection_mode='orthogonal'):
        if opt.color_loss_type == 'l1':
            error_term = nn.L1Loss()
        elif opt.color_loss_type == 'mse':
            error_term = nn.MSELoss()
        else:
            raise ValueError('Unsupported color_loss_type: %s' % opt.color_loss_type)

        super(ResBlkPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'respifu'
        self.opt = opt

        norm_type = get_norm_layer(norm_type=opt.norm_color)
        self.image_filter = ResnetFilter(opt, norm_layer=norm_type)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim_color,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Tanh())

        self.normalizer = DepthNormalizer(opt)

        init_net(self)

    def filter(self, images):
        '''
        Filter the input images and store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        self.im_feat = self.image_filter(images)

    def attach(self, im_feat):
        '''
        Concatenate externally computed image features (e.g. from the geometry
        network) with this network's own features along the channel dimension.
        :param im_feat: [B, C', H, W] feature map to prepend
        '''
        self.im_feat = torch.cat([im_feat, self.im_feat], 1)

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        Stores all intermediate features.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] ground truth labels
        :return: [B, Res, N] predictions for each point
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        z_feat = self.normalizer(z)

        # This is a list of [B, Feat_i, N] features
        point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
        # [B, Feat_all, N]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        self.preds = self.surface_classifier(point_local_feat)

    def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
        self.filter(images)
        self.attach(im_feat)
        self.query(points, calibs, transforms, labels)
        res = self.get_preds()
        error = self.get_error()

        return res, error
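
# Rough usage sketch of the color pipeline (illustrative only; `opt`,
# `geometry_im_feat`, and the tensors below are assumptions, not defined here):
#
#   netC = ResBlkPIFuNet(opt)                  # color network
#   netC.filter(images)                        # [B, 3, H, W] -> self.im_feat
#   netC.attach(geometry_im_feat)              # concat features from the shape network
#   netC.query(points, calibs, labels=labels)  # per-point prediction
#   preds = netC.get_preds()                   # [B, Res, N], squashed to [-1, 1] by the Tanh
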
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct the conv block with build_conv_block, and implement the
        skip connection in the forward function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of the padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers
            use_bias (bool)     -- whether the conv layer uses bias
            last (bool)         -- if True, omit the final normalization layer

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if last:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        else:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connection
        return out


class ResnetFilter(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few
    downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer
    project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- whether to use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of the padding layer in conv layers: reflect | replicate | zero
        """
        assert (n_blocks >= 0)
        super(ResnetFilter, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            if i == n_blocks - 1:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                      norm_layer=norm_layer, use_dropout=use_dropout,
                                      use_bias=use_bias, last=True)]
            else:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                      norm_layer=norm_layer, use_dropout=use_dropout,
                                      use_bias=use_bias)]

        if opt.use_tanh:
            model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
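

# Minimal smoke-test sketch, assuming this file is executed as a module inside
# its package (the relative imports above prevent running it as a standalone
# script). The SimpleNamespace below is a stand-in for the real option parser
# and only sets the single field ResnetFilter reads (use_tanh).
if __name__ == '__main__':
    from types import SimpleNamespace

    opt = SimpleNamespace(use_tanh=False)
    net = ResnetFilter(opt)
    net.eval()

    x = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        feat = net(x)
    # Two stride-2 downsamplings: 512 -> 128; channels: ngf * 4 = 256.
    print(feat.shape)  # expected: torch.Size([1, 256, 128, 128])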