# Spaces: Sleeping
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# post process function for all heads: extract 3D points/confidence from output
# --------------------------------------------------------
import torch | |
def postprocess(out, depth_mode, conf_mode):
    """
    Turn a raw prediction-head output into 3D points and, optionally, confidence.

    Args:
        out: head output with channels on axis 1. Channels 0..2 are read as
            the x/y/z values of the point map; channel 3 is read as the raw
            confidence (only when ``conf_mode`` is not None).
        depth_mode: ``(name, vmin, vmax)`` tuple forwarded to ``reg_dense_depth``.
        conf_mode: ``(name, vmin, vmax)`` tuple forwarded to ``reg_dense_conf``,
            or None to skip the confidence output.

    Returns:
        dict with key ``'pts3d'`` and, when ``conf_mode`` is not None, ``'conf'``.
    """
    # Move the channel axis last: axes (0, 1, 2, 3) -> (0, 2, 3, 1).
    channels_last = out.permute(0, 2, 3, 1)

    result = {'pts3d': reg_dense_depth(channels_last[..., :3], mode=depth_mode)}
    if conf_mode is None:
        return result

    result['conf'] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def reg_dense_depth(xyz, mode):
    """
    Extract 3D points from the prediction head output.

    Args:
        xyz: tensor whose last axis holds the raw x/y/z values.
        mode: ``(name, vmin, vmax)`` tuple. ``name`` is one of ``'linear'``,
            ``'square'`` or ``'exp'``. Only the unbounded range
            ``(-inf, +inf)`` is supported.

    Returns:
        ``'linear'``: ``xyz`` unchanged.
        ``'square'``: the unit direction of ``xyz`` scaled by ``d ** 2``.
        ``'exp'``: the unit direction of ``xyz`` scaled by ``expm1(d)``.
        Here ``d`` is the Euclidean norm of ``xyz`` over its last axis.

    Raises:
        AssertionError: if ``vmin``/``vmax`` are not ``-inf``/``+inf``.
        ValueError: if the mode name is not recognised.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    if not no_bounds:
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`; AssertionError is kept so existing callers see the
        # same exception type as before.
        raise AssertionError(
            f'bounded depth range is not supported: {vmin=}, {vmax=}')

    if mode == 'linear':
        return xyz  # [-inf, +inf]

    if mode not in ('square', 'exp'):
        raise ValueError(f'bad {mode=}')

    # Distance to origin: L2 norm over the last (x, y, z) axis.
    d = xyz.norm(dim=-1, keepdim=True)
    # Normalise to a unit direction; the clip avoids division by zero.
    direction = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return direction * d.square()
    return direction * torch.expm1(d)
def reg_dense_conf(x, mode):
    """
    Map raw confidence values from the prediction head into ``[vmin, vmax]``.

    Args:
        x: tensor of raw confidence values.
        mode: ``(name, vmin, vmax)`` tuple, with ``name`` either ``'exp'``
            or ``'sigmoid'``.

    Returns:
        ``'exp'``: ``vmin + exp(x)``, with ``exp(x)`` capped at ``vmax - vmin``.
        ``'sigmoid'``: ``sigmoid(x)`` rescaled to span ``vmin`` to ``vmax``.

    Raises:
        ValueError: if the mode name is not recognised.
    """
    mode, lower, upper = mode
    span = upper - lower

    if mode == 'exp':
        # Cap the exponential so the result never exceeds `upper`.
        capped = x.exp().clamp(max=span)
        return lower + capped

    if mode == 'sigmoid':
        return span * torch.sigmoid(x) + lower

    raise ValueError(f'bad {mode=}')