# Source: splatt3r/src/mast3r_src/mast3r/catmlp_dpt_head.py
# Author: brandonsmart -- commit 5ed9923 ("Initial commit")
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R heads
# --------------------------------------------------------
import torch
import torch.nn.functional as F
from einops import rearrange
import mast3r.utils.path_to_dust3r # noqa
from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf # noqa
from dust3r.heads.dpt_head import PixelwiseTaskWithDPT # noqa
import dust3r.utils.path_to_croco # noqa
from models.blocks import Mlp # noqa
def reg_desc(desc, mode, eps=1e-8):
    """Normalise descriptors along their last dimension.

    Args:
        desc: descriptor tensor whose last dimension holds the descriptor components.
        mode: normalisation mode; any string containing ``'norm'`` selects L2 normalisation.
        eps: lower bound applied to the norm. An all-zero descriptor therefore maps to
            a zero vector instead of NaN (0 / 0). Descriptors whose norm is >= ``eps`` are
            unaffected.

    Returns:
        Tensor of the same shape as ``desc`` with unit L2 norm along the last dimension
        (zero vectors stay zero).

    Raises:
        ValueError: if ``mode`` is not a recognised descriptor mode.
    """
    if 'norm' not in mode:
        raise ValueError(f"Unknown desc mode {mode}")
    # Clamp the denominator: without it a zero-norm descriptor yields NaN and would
    # poison any downstream matching / loss computation.
    return desc / desc.norm(dim=-1, keepdim=True).clip(min=eps)
def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
    """Convert the raw head output into a dictionary of activated predictions.

    Channel layout of ``out`` along dim 1 (moved to the last axis internally):
    3 point channels, then 1 confidence channel when ``conf_mode`` is set, then
    ``desc_dim`` descriptor channels, then (only if ``two_confs``) one
    descriptor-confidence channel.

    Returns:
        dict with ``'pts3d'``, plus ``'conf'`` when ``conf_mode`` is set, plus
        ``'desc'`` and ``'desc_conf'`` when ``desc_dim`` is given. Without
        ``two_confs`` the descriptor confidence is a clone of ``'conf'``.
    """
    if desc_conf_mode is None:
        desc_conf_mode = conf_mode
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,D

    result = {'pts3d': reg_dense_depth(fmap[..., 0:3], mode=depth_mode)}
    with_conf = conf_mode is not None
    if with_conf:
        result['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)

    if desc_dim is None:
        return result

    first = 3 + int(with_conf)
    last = first + desc_dim
    result['desc'] = reg_desc(fmap[..., first:last], mode=desc_mode)
    if two_confs:
        result['desc_conf'] = reg_dense_conf(fmap[..., last], mode=desc_conf_mode)
    else:
        # no dedicated channel: reuse the 3D-regression confidence
        result['desc_conf'] = result['conf'].clone()
    return result
class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
    """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).

    The DPT branch (``self.dpt``, built by the parent class) regresses the point map
    (plus confidence). The MLP branch takes the concatenation of the first and last
    entries of ``decout`` (encoder and decoder outputs). It predicts
    ``(local_feat_dim + two_confs) * patch_size**2`` values per token, which are
    pixel-shuffled back to full image resolution.
    """

    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
                 num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None,
                 head_type="regression", **kwargs):
        super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
                         dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode,
                         head_type=head_type)
        self.local_feat_dim = local_feat_dim

        # Accept either an int or a square (int, int) tuple for the patch size.
        ps = net.patch_embed.patch_size
        if isinstance(ps, tuple):
            assert len(ps) == 2 and isinstance(ps[0], int) and isinstance(
                ps[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
            assert ps[0] == ps[1], "Error, non square patches not managed"
            ps = ps[0]
        self.patch_size = ps

        # Descriptor / confidence settings are mirrored from the backbone.
        self.desc_mode = net.desc_mode
        self.has_conf = has_conf
        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
        self.desc_conf_mode = net.desc_conf_mode

        in_dim = net.enc_embed_dim + net.dec_embed_dim
        per_token_out = (self.local_feat_dim + self.two_confs) * self.patch_size ** 2
        self.head_local_features = Mlp(in_features=in_dim,
                                       hidden_features=int(hidden_dim_factor * in_dim),
                                       out_features=per_token_out)

    def forward(self, decout, img_shape):
        """Run both branches and (optionally) post-process.

        Args:
            decout: sequence of token tensors; ``decout[0]`` and ``decout[-1]`` are
                concatenated on the last dim and must be 3-D (B, S, D).
            img_shape: ``(H, W)`` of the input image.

        Returns:
            ``self.postprocess(...)`` of the concatenated maps when a postprocess
            callable is set, otherwise the raw (B, C, H, W) concatenation.
        """
        pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))

        # first + last token sets (encoder and decoder outputs)
        tokens = torch.cat([decout[0], decout[-1]], dim=-1)
        H, W = img_shape
        B, _, _ = tokens.shape
        p = self.patch_size

        # per-token MLP -> (B, C*p*p, H/p, W/p) -> pixel shuffle -> (B, C, H, W)
        feats = self.head_local_features(tokens)
        feats = feats.transpose(-1, -2).view(B, -1, H // p, W // p)
        feats = F.pixel_shuffle(feats, p)

        merged = torch.cat([pts3d, feats], dim=1)
        if not self.postprocess:
            return merged
        return self.postprocess(merged,
                                depth_mode=self.depth_mode,
                                conf_mode=self.conf_mode,
                                desc_dim=self.local_feat_dim,
                                desc_mode=self.desc_mode,
                                two_confs=self.two_confs,
                                desc_conf_mode=self.desc_conf_mode)
# @MODIFIED
def reg_dense_offsets(xyz, shift=6.0):
    """Squash raw offset vectors so they are tiny around initialisation.

    Each vector's direction is kept. Its length ``d`` is remapped to
    ``exp(d - shift) - exp(-shift)``, which is exactly 0 at ``d = 0`` and roughly
    ``exp(-shift) * d`` for small ``d``. Zero vectors map to zero; the norm is
    clipped at 1e-8 before dividing.
    """
    length = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / length.clip(min=1e-8)
    floor = torch.exp(torch.zeros_like(length) - shift)
    magnitude = torch.exp(length - shift) - floor
    return direction * magnitude
# @MODIFIED
def reg_dense_scales(scales):
    """Map raw scale predictions to strictly positive Gaussian scales.

    Applies an element-wise exponential, so every real-valued input gives a positive
    output. A raw value ``b`` becomes ``exp(b)``. For example, the ``-7.0`` bias used to
    initialise the scale channels of ``GaussianHead`` yields scales of about ``9e-4``.
    """
    return scales.exp()
# @MODIFIED
def reg_dense_rotation(rotations, eps=1e-8):
    """Normalise rotation vectors (PixelSplat-style) along the last dimension.

    ``eps`` is added to the norm, so zero vectors map to zero instead of NaN.
    """
    denom = rotations.norm(dim=-1, keepdim=True) + eps
    return rotations / denom
# @MODIFIED
def reg_dense_sh(sh):
    """Reshape flat spherical-harmonic coefficients (PixelSplat-style).

    The last dimension of ``sh`` holds ``3 * d_sh`` values stored as three consecutive
    blocks of ``d_sh``. It is split into a trailing ``(3, d_sh)`` pair of axes; the
    3 is presumably the RGB colour channel.
    """
    pattern = '... (xyz d_sh) -> ... xyz d_sh'
    return rearrange(sh, pattern, xyz=3)
# @MODIFIED
def reg_dense_opacities(opacities):
    """Squash raw opacity logits into (0, 1) with a sigmoid (PixelSplat-style)."""
    return torch.sigmoid(opacities)
# @MODIFIED
def gaussian_postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False,
                         desc_conf_mode=None, use_offsets=False, sh_degree=1):
    """Split and activate the raw GaussianHead output.

    Channel layout of ``out`` along dim 1 (moved to the last axis internally):
    pts3d (3), conf (1), desc (``desc_dim``), desc_conf (1), mean offset (3),
    scales (3), rotations (4), spherical harmonics (``3 * sh_degree``), opacity (1).

    Returns:
        dict with ``pts3d``, ``conf``, ``desc``, ``desc_conf``, ``scales``, ``rotations``,
        ``sh``, ``opacities`` and ``means``. ``means`` is the detached ``pts3d``,
        plus the activated offset when ``use_offsets`` is true.

    Raises:
        AssertionError: if ``conf_mode`` or ``desc_dim`` is missing or ``two_confs`` is false.
    """
    desc_conf_mode = conf_mode if desc_conf_mode is None else desc_conf_mode
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,D

    assert conf_mode is not None, "Confidence mode must be provided for Gaussian head"
    assert desc_dim is not None, "Descriptor dimension must be provided for Gaussian head"
    assert two_confs, "Two confidences must be provided for Gaussian head"

    split_sizes = [3, 1, desc_dim, 1, 3, 3, 4, 3 * sh_degree, 1]
    (raw_pts3d, raw_conf, raw_desc, raw_desc_conf, raw_offset,
     raw_scales, raw_rotations, raw_sh, raw_opacities) = torch.split(fmap, split_sizes, dim=-1)

    pts3d = reg_dense_depth(raw_pts3d, mode=depth_mode)
    result = {
        'pts3d': pts3d,
        'conf': reg_dense_conf(raw_conf.squeeze(-1), mode=conf_mode),
        'desc': reg_desc(raw_desc, mode=desc_mode),
        'desc_conf': reg_dense_conf(raw_desc_conf.squeeze(-1), mode=desc_conf_mode),
    }
    offset = reg_dense_offsets(raw_offset)
    result['scales'] = reg_dense_scales(raw_scales)
    result['rotations'] = reg_dense_rotation(raw_rotations)
    result['sh'] = reg_dense_sh(raw_sh)
    result['opacities'] = reg_dense_opacities(raw_opacities)

    # Gaussian centres are anchored on the (gradient-detached) point map.
    anchor = pts3d.detach()
    result['means'] = anchor + offset if use_offsets else anchor
    return result
# @MODIFIED
class GaussianHead(PixelwiseTaskWithDPT):
    """Variant of ``Cat_MLP_LocalFeatures_DPT_Pts3d`` that also predicts per-pixel 3D Gaussian parameters.

    Three branches share the same ``decout``:
      * ``self.dpt`` (from the parent class): point map (+ confidence);
      * ``self.head_local_features`` (MLP + pixel shuffle): descriptors (+ descriptor confidence);
      * ``self.gaussian_dpt``: Gaussian mean offsets, scales, rotations, SH coefficients
        and opacity.

    Note: despite its name, ``sh_degree`` is used as the number of SH coefficients per
    colour channel. The head emits ``3 * sh_degree`` SH channels, and the default of 1
    therefore gives a single (DC) coefficient per channel.
    """

    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
                 num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None,
                 head_type="regression", use_offsets=False, sh_degree=1, **kwargs):
        super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
                         dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode,
                         head_type=head_type)
        self.local_feat_dim = local_feat_dim

        # Accept either an int or a square (int, int) tuple for the patch size.
        patch_size = net.patch_embed.patch_size
        if isinstance(patch_size, tuple):
            assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
                patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
            assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
            patch_size = patch_size[0]
        self.patch_size = patch_size

        self.desc_mode = net.desc_mode
        self.has_conf = has_conf
        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
        self.desc_conf_mode = net.desc_conf_mode

        idim = net.enc_embed_dim + net.dec_embed_dim
        self.head_local_features = Mlp(in_features=idim,
                                       hidden_features=int(hidden_dim_factor * idim),
                                       out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)

        # Per Gaussian parameter group: (channel count, xavier gain, bias init).
        # Small gains / negative biases keep offsets, scales and opacities small at start.
        splits_and_inits = [
            (3, 0.001, 0.001),          # 3D mean offsets
            (3, 0.00003, -7.0),         # Scales (exp(-7) ~ 9e-4 after activation)
            (4, 1.0, 0.0),              # Rotations
            (3 * sh_degree, 1.0, 0.0),  # Spherical Harmonics
            (1, 1.0, -2.0),             # Opacity
        ]
        # Derive the channel count from the table so the two can never drift apart.
        gaussian_num_channels = sum(n_ch for n_ch, _, _ in splits_and_inits)
        self.gaussian_dpt = PixelwiseTaskWithDPT(
            num_channels=gaussian_num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
            dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode,
            head_type=head_type
        )

        final_conv_layer = self.gaussian_dpt.dpt.head[-1]
        assert final_conv_layer.weight.shape[0] == gaussian_num_channels, \
            "Gaussian DPT output channels do not match the parameter init table"
        # Initialise each channel group of the final conv separately. Slicing under
        # no_grad avoids building autograd views of the parameters.
        with torch.no_grad():
            start = 0
            for n_ch, gain, bias in splits_and_inits:
                torch.nn.init.xavier_uniform_(final_conv_layer.weight[start:start + n_ch], gain)
                torch.nn.init.constant_(final_conv_layer.bias[start:start + n_ch], bias)
                start += n_ch

        self.use_offsets = use_offsets
        self.sh_degree = sh_degree

    def forward(self, decout, img_shape):
        """Run the point, descriptor and Gaussian branches.

        Args:
            decout: sequence of token tensors; ``decout[0]`` and ``decout[-1]`` (encoder
                and decoder outputs) are concatenated on the last dim and must be 3-D.
            img_shape: ``(H, W)`` of the input image.

        Returns:
            The dict from ``gaussian_postprocess`` when ``self.postprocess`` is set,
            otherwise the raw channel-wise concatenation
            ``[pts3d, local_features, gaussian_features]``.
        """
        pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))

        enc_output, dec_output = decout[0], decout[-1]
        cat_output = torch.cat([enc_output, dec_output], dim=-1)
        H, W = img_shape
        B, _, _ = cat_output.shape

        # descriptors: per-token MLP, then pixel-shuffle back to (B, d, H, W)
        local_features = self.head_local_features(cat_output)
        local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        local_features = F.pixel_shuffle(local_features, self.patch_size)

        # Gaussian parameters come from a dedicated DPT branch at full resolution
        gaussian_features = self.gaussian_dpt.dpt(decout, image_size=(img_shape[0], img_shape[1]))

        out = torch.cat([pts3d, local_features, gaussian_features], dim=1)
        if self.postprocess:
            out = gaussian_postprocess(out,
                                       depth_mode=self.depth_mode,
                                       conf_mode=self.conf_mode,
                                       desc_dim=self.local_feat_dim,
                                       desc_mode=self.desc_mode,
                                       two_confs=self.two_confs,
                                       desc_conf_mode=self.desc_conf_mode,
                                       use_offsets=self.use_offsets,
                                       sh_degree=self.sh_degree)
        return out
def mast3r_head_factory(head_type, output_mode, net, has_conf=False, use_offsets=False, sh_degree=1):
    """Build a prediction head for the decoder.

    Supported combinations:
      * ``head_type='catmlp+dpt'`` with ``output_mode='pts3d+desc<N>'`` builds a
        ``Cat_MLP_LocalFeatures_DPT_Pts3d`` (3D points + N-dim descriptors).
      * ``head_type='gaussian_head'`` with ``output_mode='pts3d+gaussian+desc<N>'``
        builds a ``GaussianHead`` (additionally predicts Gaussian parameters).

    Args:
        head_type: head architecture identifier.
        output_mode: output specification; the integer after the prefix is the
            descriptor dimension.
        net: backbone; its ``dec_depth``, ``enc_embed_dim``, ``dec_embed_dim``,
            ``depth_mode`` and ``conf_mode`` attributes are read here.
        has_conf: whether the DPT point branch emits an extra confidence channel.
        use_offsets, sh_degree: forwarded to ``GaussianHead`` only.

    Raises:
        NotImplementedError: for an unsupported ``head_type`` / ``output_mode`` pair.
        ValueError: if the text after the prefix is not an integer.
        AssertionError: if ``net.dec_depth <= 9``.
    """
    if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
        head_cls, prefix, extra_kwargs = Cat_MLP_LocalFeatures_DPT_Pts3d, 'pts3d+desc', {}
    # @MODIFIED
    elif head_type == 'gaussian_head' and output_mode.startswith('pts3d+gaussian+desc'):
        head_cls, prefix = GaussianHead, 'pts3d+gaussian+desc'
        extra_kwargs = dict(use_offsets=use_offsets, sh_degree=sh_degree)
    else:
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")

    local_feat_dim = int(output_mode[len(prefix):])
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim // 2
    out_nchan = 3  # xyz; +1 channel when has_conf
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    return head_cls(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
                    num_channels=out_nchan + has_conf,
                    feature_dim=feature_dim,
                    last_dim=last_dim,
                    hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                    dim_tokens=[ed, dd, dd, dd],
                    postprocess=postprocess,
                    depth_mode=net.depth_mode,
                    conf_mode=net.conf_mode,
                    head_type='regression',
                    **extra_kwargs)