Spaces:
Sleeping
Sleeping
File size: 1,807 Bytes
56cd6b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Post-processing functions shared by all heads: extract 3D points/confidence from the head output
# --------------------------------------------------------
import torch
def postprocess(out, depth_mode, conf_mode):
    """
    Extract 3D points and, optionally, confidence from a prediction head output.

    Args:
        out: 4-D tensor whose dimension 1 is treated as the channel axis
            (presumably ``(B, C, H, W)`` -- confirm against the heads).
            Channels 0-2 hold the raw xyz values. When ``conf_mode`` is given,
            channel 3 holds the raw confidence.
        depth_mode: ``(mode_name, vmin, vmax)`` triple forwarded to ``reg_dense_depth``.
        conf_mode: ``(mode_name, vmin, vmax)`` triple forwarded to ``reg_dense_conf``,
            or ``None`` to skip confidence extraction.

    Returns:
        dict with key ``'pts3d'`` (channels-last points) and, when ``conf_mode``
        is not ``None``, key ``'conf'``.
    """
    # Move the channel axis last so each pixel's values sit on the final axis.
    channels_last = out.permute(0, 2, 3, 1)
    result = {'pts3d': reg_dense_depth(channels_last[..., 0:3], mode=depth_mode)}
    if conf_mode is not None:
        result['conf'] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def reg_dense_depth(xyz, mode):
    """
    Extract 3D points from the prediction head output.

    Args:
        xyz: tensor whose last dimension holds the 3 raw point values
            (``postprocess`` passes channels 0-2 of a channels-last map).
        mode: ``(mode_name, vmin, vmax)`` triple. ``mode_name`` is one of
            ``'linear'``, ``'square'`` or ``'exp'``. Only the unbounded range
            ``vmin == -inf`` and ``vmax == +inf`` is supported.

    Returns:
        Tensor with the same shape as ``xyz``:
          * ``'linear'``: ``xyz`` returned unchanged.
          * ``'square'``: unit direction of ``xyz`` scaled by ``|xyz| ** 2``.
          * ``'exp'``: unit direction of ``xyz`` scaled by ``exp(|xyz|) - 1``.

    Raises:
        ValueError: if a finite bound is given, or if ``mode_name`` is unknown.
    """
    mode, vmin, vmax = mode

    # Bounds are not implemented for any mode. This check used to be an
    # ``assert``, which ``python -O`` strips. Stripping it silently re-enabled
    # a linear-only clip and ignored bounds for 'square'/'exp', so the range
    # is now rejected explicitly in every run mode.
    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    if not no_bounds:
        raise ValueError(f'bounded depth range is not supported: {vmin=}, {vmax=}')

    if mode == 'linear':
        return xyz  # [-inf, +inf]

    # Split each point into its distance to the origin (L2 norm over the last
    # axis) and a unit direction. The clamp avoids 0/0 for points at the origin.
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return xyz * d.square()
    if mode == 'exp':
        return xyz * torch.expm1(d)
    raise ValueError(f'bad {mode=}')
def reg_dense_conf(x, mode):
    """
    Map raw prediction-head values to confidence scores.

    Args:
        x: tensor of raw confidence values.
        mode: ``(mode_name, vmin, vmax)`` triple:
            * ``'exp'``: ``vmin + exp(x)``, with the ``exp(x)`` term capped at
              ``vmax - vmin`` so the result never exceeds ``vmax``.
            * ``'sigmoid'``: ``sigmoid(x)`` rescaled linearly from (0, 1) onto
              the ``vmin``..``vmax`` range.

    Raises:
        ValueError: if ``mode_name`` is neither ``'exp'`` nor ``'sigmoid'``.
    """
    mode, lo, hi = mode
    if mode == 'sigmoid':
        return torch.sigmoid(x) * (hi - lo) + lo
    if mode == 'exp':
        return x.exp().clip(max=hi - lo) + lo
    raise ValueError(f'bad {mode=}')
|