hpwang's picture
[Init]
fd5e0f7
"""Copyright (C) 2024 Apple Inc. All Rights Reserved.
Dense Prediction Transformer Decoder architecture.
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
"""
from __future__ import annotations
from typing import Iterable
import torch
from torch import nn
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings.

    Every encoder level is projected to a common channel width
    (``dim_decoder``). The levels are then fused with ``FeatureFusionBlock2d``
    modules, starting at the lowest resolution and ending at level 0.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder, ordered
                from the highest resolution (level 0) to the lowest.
            dim_decoder: Dim of decoder features.

        Raises:
        ------
            ValueError: If ``dims_encoder`` is empty.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        if not self.dims_encoder:
            raise ValueError("dims_encoder must contain at least one level.")
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)

        # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
        # when the dimensions mismatch. Otherwise we do not do anything, which is
        # the default behavior of monodepth.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )
        # Lower-resolution levels are always projected with a 3x3 conv
        # (padding=1 keeps the spatial size).
        convs = [conv0] + [
            nn.Conv2d(
                dim_in,
                dim_decoder,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            for dim_in in self.dims_encoder[1:]
        ]
        self.convs = nn.ModuleList(convs)

        # Every fusion block except the level-0 one upsamples its output
        # (deconv=True).
        self.fusions = nn.ModuleList(
            FeatureFusionBlock2d(
                num_features=dim_decoder,
                deconv=(i != 0),
                batch_norm=False,
            )
            for i in range(num_encoders)
        )

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One feature map per encoder level, ordered from the
                highest resolution (level 0) to the lowest. Level ``i`` is
                fed to a projection expecting ``dims_encoder[i]`` channels.

        Returns:
        -------
            A tuple ``(features, lowres_features)``.
            ``features`` is the decoder output after fusing every level down
            to level 0. ``lowres_features`` is the projected (pre-fusion)
            feature map of the lowest-resolution level.

        Raises:
        ------
            ValueError: If the number of encodings differs from the number of
                encoder levels.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)
        if num_levels != num_encoders:
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Project features of different encoder dims to the same decoder dim.
        # Fuse features from the lowest resolution (num_levels-1)
        # to the highest (0).
        features = self.convs[-1](encodings[-1])
        lowres_features = features
        features = self.fusions[-1](features)
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
class ResidualBlock(nn.Module):
    """Residual block computing ``skip + residual(x)``.

    This is the generic formulation from He et al., "Identity Mappings in Deep
    Residual Networks" (2016), https://arxiv.org/abs/1603.05027. The caller
    supplies both branches, so factory functions can customize the block.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Initialize ResidualBlock.

        Args:
        ----
            residual: Module producing the residual branch.
            shortcut: Optional module applied to the input on the skip path.
                When ``None``, the input itself is added back.

        """
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block: ``skip + residual(x)``, with ``skip`` = ``x`` or ``shortcut(x)``."""
        # Evaluate the residual branch first, then the skip path.
        branch_out = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + branch_out
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT.

    When a second input is given, it is refined by ``resnet1`` and added to
    the first input. The result is refined by ``resnet2``. If enabled, it is
    then upsampled with a stride-2 transposed convolution. Finally it is
    projected by a 1x1 convolution.
    """

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ):
        """Initialize feature fusion block.

        Args:
        ----
            num_features: Input and output dimensions.
            deconv: Whether to use deconv before the final output conv.
            batch_norm: Whether to use batch normalization in resnet blocks.

        """
        super().__init__()

        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            # kernel_size=2, stride=2, padding=0 yields exactly twice the
            # input height and width.
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            num_features,
            num_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # The skip addition is wrapped in a FloatFunctional module, presumably
        # so it can be observed/replaced during quantization.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features.

        Args:
        ----
            x0: Main input features.
            x1: Optional skip features. When given, they are refined by
                ``resnet1`` and added to ``x0``.

        Returns:
        -------
            The fused features after ``resnet2``, the optional deconv, and
            ``out_conv``.

        """
        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)

        if self.use_deconv:
            x = self.deconv(x)
        x = self.out_conv(x)

        return x

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool) -> ResidualBlock:
        """Create a residual block.

        The residual branch is two stages of ReLU -> 3x3 conv
        [-> BatchNorm]. The shortcut is the identity.

        Args:
        ----
            num_features: Channel count of the block's input and output.
            batch_norm: Whether each stage ends with ``BatchNorm2d``.

        Returns:
        -------
            A ``ResidualBlock`` wrapping the two-stage residual branch.

        """

        def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
            # Non-in-place ReLU: the block input is added back by
            # ResidualBlock, so it must not be overwritten here.
            # padding=1 keeps the spatial size of the 3x3 conv.
            # The conv bias is dropped when BatchNorm follows, because
            # BatchNorm's affine shift makes it redundant.
            layers: list[nn.Module] = [
                nn.ReLU(False),
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=not batch_norm,
                ),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(dim))
            return layers

        residual = nn.Sequential(
            *_create_block(dim=num_features, batch_norm=batch_norm),
            *_create_block(dim=num_features, batch_norm=batch_norm),
        )
        return ResidualBlock(residual)