hpwang's picture
[Init]
fd5e0f7
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# Field of View network architecture.
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
class FOVNetwork(nn.Module):
"""Field of View estimation network."""
def __init__(
self,
num_features: int,
fov_encoder: Optional[nn.Module] = None,
):
"""Initialize the Field of View estimation block.
Args:
----
num_features: Number of features used.
fov_encoder: Optional encoder to bring additional network capacity.
"""
super().__init__()
# Create FOV head.
fov_head0 = [
nn.Conv2d(
num_features, num_features // 2, kernel_size=3, stride=2, padding=1
), # 128 x 24 x 24
nn.ReLU(True),
]
fov_head = [
nn.Conv2d(
num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=1
), # 64 x 12 x 12
nn.ReLU(True),
nn.Conv2d(
num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=1
), # 32 x 6 x 6
nn.ReLU(True),
nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
]
if fov_encoder is not None:
self.encoder = nn.Sequential(
fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
)
self.downsample = nn.Sequential(*fov_head0)
else:
fov_head = fov_head0 + fov_head
self.head = nn.Sequential(*fov_head)
def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
"""Forward the fov network.
Args:
----
x (torch.Tensor): Input image.
lowres_feature (torch.Tensor): Low resolution feature.
Returns:
-------
The field of view tensor.
"""
if hasattr(self, "encoder"):
x = F.interpolate(
x,
size=None,
scale_factor=0.25,
mode="bilinear",
align_corners=False,
)
x = self.encoder(x)[:, 1:].permute(0, 2, 1)
lowres_feature = self.downsample(lowres_feature)
x = x.reshape_as(lowres_feature) + lowres_feature
else:
x = lowres_feature
return self.head(x)