# HGPIFuNet — Hourglass-based PIFu network (web-listing header removed;
# the original page chrome "Spaces / Runtime error / file size / commit
# hash / line-number gutter" was scraping residue, not source code).
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from .HGFilters import *
from ..net_util import init_net
class HGPIFuNet(BasePIFuNet):
    '''
    HG PIFu network uses Hourglass stacks as the image filter.

    It does the following:
        1. Compute image feature stacks and store them in self.im_feat_list;
           self.im_feat_list[-1] is the last stack (output stack).
        2. Calculate calibration (projection of query points into image space).
        3. If training, it indexes into every intermediate stack;
           if testing, it indexes into the last stack only.
        4. Classification via an MLP surface classifier.
        5. During training, error is averaged over all stacks
           (intermediate supervision).
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=None,
                 ):
        '''
        :param opt: options namespace; fields read here: num_views, mlp_dim,
                    no_residual (plus whatever HGFilter/DepthNormalizer read).
        :param projection_mode: 'orthogonal' or 'perspective', forwarded to
                                BasePIFuNet.
        :param error_term: loss module used by get_error(). Defaults to a
                           fresh nn.MSELoss() per instance.
        '''
        # BUGFIX: the original signature used `error_term=nn.MSELoss()`, a
        # mutable default evaluated once at class-definition time and shared
        # by every instance constructed without an explicit error_term.
        # Constructing it per-instance preserves behavior for all callers.
        if error_term is None:
            error_term = nn.MSELoss()
        super(HGPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'hgpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = HGFilter(opt)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # This is a list of [B x Feat_i x H x W] features
        self.im_feat_list = []
        self.tmpx = None
        self.normx = None

        self.intermediate_preds_list = []

        init_net(self)

    def filter(self, images):
        '''
        Filter the input images and store all intermediate features.

        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list, self.tmpx, self.normx = self.image_filter(images)
        # If not in training, only keep the last stack's features: test-time
        # predictions use the output stack alone.
        if not self.training:
            self.im_feat_list = [self.im_feat_list[-1]]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed (via filter()) before this call;
        stores all intermediate predictions. May behave differently during
        training vs. testing because filter() keeps one stack at test time.

        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling, kept for get_error()
        :return: None; predictions stored in self.preds ([B, Res, N])
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        # Points projecting outside the normalized [-1, 1] image plane get
        # their predictions zeroed below.
        in_img = (xy[:, 0] >= -1.0) & (xy[:, 0] <= 1.0) & (xy[:, 1] >= -1.0) & (xy[:, 1] <= 1.0)

        z_feat = self.normalizer(z, calibs=calibs)

        if self.opt.skip_hourglass:
            tmpx_local_feature = self.index(self.tmpx, xy)

        self.intermediate_preds_list = []

        for im_feat in self.im_feat_list:
            # [B, Feat_i + z, N] — image feature sampled at xy, plus depth
            point_local_feat_list = [self.index(im_feat, xy), z_feat]

            if self.opt.skip_hourglass:
                point_local_feat_list.append(tmpx_local_feature)

            point_local_feat = torch.cat(point_local_feat_list, 1)

            # out of image plane is always set to 0
            pred = in_img[:, None].float() * self.surface_classifier(point_local_feat)
            self.intermediate_preds_list.append(pred)

        self.preds = self.intermediate_preds_list[-1]

    def get_im_feat(self):
        '''
        Get the image feature of the last (output) hourglass stack.

        :return: [B, C_feat, H, W] image feature after filtering
        '''
        return self.im_feat_list[-1]

    def get_error(self):
        '''
        Hourglass has its own intermediate supervision scheme: average the
        error over every intermediate prediction produced by query().

        :return: scalar loss averaged over all stacks
        '''
        # NOTE(review): raises ZeroDivisionError if called before query()
        # populates intermediate_preds_list — caller contract assumed here.
        error = 0
        for preds in self.intermediate_preds_list:
            error += self.error_term(preds, self.labels)
        error /= len(self.intermediate_preds_list)
        return error

    def forward(self, images, points, calibs, transforms=None, labels=None):
        '''
        Full pass: filter images, query points, return predictions and loss.

        :param images: [B, C, H, W] input images
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: [B, Res, N] gt labeling used for the loss
        :return: (res, error) — [B, Res, N] predictions and the scalar loss
        '''
        # Phase 1: get image feature
        self.filter(images)

        # Phase 2: point query
        self.query(points=points, calibs=calibs, transforms=transforms, labels=labels)

        # Get the prediction
        res = self.get_preds()

        # Get the error
        error = self.get_error()

        return res, error