|
from dataclasses import dataclass, field
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from spar3d.models.utils import BaseModule
|
|
|
|
|
|
class LinearCameraEmbedder(BaseModule):
    """Project concatenated conditioning tensors through one linear layer.

    Each tensor named in ``cfg.conditions`` is looked up in the keyword
    arguments of ``forward``, flattened so that only its first two
    dimensions are kept, concatenated with the others along the last
    dimension, and passed through ``nn.Linear(in_channels, out_channels)``.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Required size of the last dimension once every condition has been
        # flattened and concatenated (verified in ``forward``).
        in_channels: int = 25
        # Size of the last dimension of the embedding returned by ``forward``.
        out_channels: int = 768
        # Names of the ``forward`` keyword arguments to read, in the order
        # in which they are concatenated.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Create the linear layer mapping ``in_channels`` to ``out_channels``.

        NOTE(review): presumably invoked by ``BaseModule`` during
        construction -- confirm against the base class.
        """
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs) -> torch.Tensor:
        """Embed the configured conditions.

        Args:
            **kwargs: Must contain one tensor per name in
                ``cfg.conditions``; other entries are ignored. All tensors
                must agree on their first two dimensions so they can be
                concatenated along the last one.
                (Presumably these are batch and view axes -- TODO confirm.)

        Returns:
            The output of the linear layer; its last dimension has size
            ``cfg.out_channels``.

        Raises:
            AssertionError: If a configured condition is missing from
                ``kwargs`` or the concatenated feature size differs from
                ``cfg.in_channels``.

        Note:
            ``torch.cat`` is applied to the collected tensors, so
            ``cfg.conditions`` needs at least one entry for this call to
            succeed.
        """
        cond_tensors = []
        for cond_name in self.cfg.conditions:
            assert cond_name in kwargs, (
                f"Condition '{cond_name}' is listed in cfg.conditions but "
                f"was not passed to forward()"
            )
            cond = kwargs[cond_name]
            # Keep the two leading dims and flatten everything after them.
            # ``reshape`` (rather than ``view``) also accepts non-contiguous
            # inputs, e.g. tensors produced by ``permute``/``transpose``;
            # for contiguous inputs it still returns a view without copying.
            cond_tensors.append(cond.reshape(*cond.shape[:2], -1))

        cond_tensor = torch.cat(cond_tensors, dim=-1)
        assert cond_tensor.shape[-1] == self.cfg.in_channels, (
            f"Concatenated condition size {cond_tensor.shape[-1]} does not "
            f"match cfg.in_channels={self.cfg.in_channels}"
        )
        embedding = self.linear(cond_tensor)
        return embedding
|
|
|