File size: 8,700 Bytes
a7299bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import torch
import numpy as np
import torch.nn as nn
def init_image_coor(height, width):
x_row = np.arange(0, width)
x = np.tile(x_row, (height, 1))
x = x[np.newaxis, :, :]
x = x.astype(np.float32)
x = torch.from_numpy(x.copy()).cuda()
u_u0 = x - width/2.0
y_col = np.arange(0, height) # y_col = np.arange(0, height)
y = np.tile(y_col, (width, 1)).T
y = y[np.newaxis, :, :]
y = y.astype(np.float32)
y = torch.from_numpy(y.copy()).cuda()
v_v0 = y - height/2.0
return u_u0, v_v0
def depth_to_xyz(depth, focal_length):
b, c, h, w = depth.shape
u_u0, v_v0 = init_image_coor(h, w)
x = u_u0 * depth / focal_length
y = v_v0 * depth / focal_length
z = depth
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
return pw
def get_surface_normal(xyz, patch_size=3):
# xyz: [1, h, w, 3]
x, y, z = torch.unbind(xyz, dim=3)
x = torch.unsqueeze(x, 0)
y = torch.unsqueeze(y, 0)
z = torch.unsqueeze(z, 0)
xx = x * x
yy = y * y
zz = z * z
xy = x * y
xz = x * z
yz = y * z
patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
dim=4)
ATA = torch.squeeze(ATA)
ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
ATA = ATA + eps_identity
x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
AT1 = torch.squeeze(AT1)
AT1 = torch.unsqueeze(AT1, 3)
patch_num = 4
patch_x = int(AT1.size(1) / patch_num)
patch_y = int(AT1.size(0) / patch_num)
n_img = torch.randn(AT1.shape).cuda()
overlap = patch_size // 2 + 1
for x in range(int(patch_num)):
for y in range(int(patch_num)):
left_flg = 0 if x == 0 else 1
right_flg = 0 if x == patch_num -1 else 1
top_flg = 0 if y == 0 else 1
btm_flg = 0 if y == patch_num - 1 else 1
at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
n_img_tmp, _ = torch.solve(at1, ata)
n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :]
n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select
n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
n_img_norm = n_img / n_img_L2
# re-orient normals consistently
orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
n_img_norm[orient_mask] *= -1
return n_img_norm
def get_surface_normalv2(xyz, patch_size=3):
"""
xyz: xyz coordinates
patch: [p1, p2, p3,
p4, p5, p6,
p7, p8, p9]
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
return: normal [h, w, 3, b]
"""
b, h, w, c = xyz.shape
half_patch = patch_size // 2
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
xyz_horizon = xyz_left - xyz_right # p4p6
xyz_vertical = xyz_top - xyz_bottom # p2p8
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
# re-orient normals consistently
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
n_img_1[orient_mask] *= -1
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
n_img_2[orient_mask] *= -1
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
# average 2 norms
n_img_aver = n_img1_norm + n_img2_norm
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
# re-orient normals consistently
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
n_img_aver_norm[orient_mask] *= -1
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
# plt.imshow(np.abs(a), cmap='rainbow')
# plt.show()
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0))
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
# para depth: depth map, [b, c, h, w]
b, c, h, w = depth.shape
focal_length = focal_length[:, None, None, None]
depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
xyz = depth_to_xyz(depth_filter, focal_length)
sn_batch = []
for i in range(b):
xyz_i = xyz[i, :][None, :, :, :]
normal = get_surface_normalv2(xyz_i)
sn_batch.append(normal)
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w]
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
sn_batch[mask_invalid] = 0.0
return sn_batch
def vis_normal(normal):
"""
Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
@para normal: surface normal, [h, w, 3], numpy.array
"""
n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
n_img_norm = normal / (n_img_L2 + 1e-8)
normal_vis = n_img_norm * 127
normal_vis += 128
normal_vis = normal_vis.astype(np.uint8)
return normal_vis
def vis_normal2(normals):
'''
Montage of normal maps. Vectors are unit length and backfaces thresholded.
'''
x = normals[:, :, 0] # horizontal; pos right
y = normals[:, :, 1] # depth; pos far
z = normals[:, :, 2] # vertical; pos up
backfacing = (z > 0)
norm = np.sqrt(np.sum(normals**2, axis=2))
zero = (norm < 1e-5)
x += 1.0; x *= 0.5
y += 1.0; y *= 0.5
z = np.abs(z)
x[zero] = 0.0
y[zero] = 0.0
z[zero] = 0.0
normals[:, :, 0] = x # horizontal; pos right
normals[:, :, 1] = y # depth; pos far
normals[:, :, 2] = z # vertical; pos up
return normals
if __name__ == '__main__':
import cv2, os |