i3d / spar3d /models /camera.py
John6666's picture
Upload 77 files
b572032 verified
raw
history blame contribute delete
993 Bytes
from dataclasses import dataclass, field
from typing import List
import torch
import torch.nn as nn
from spar3d.models.utils import BaseModule
class LinearCameraEmbedder(BaseModule):
    """Embed camera conditions with a single linear projection.

    Every condition named in ``cfg.conditions`` is read from the keyword
    arguments of :meth:`forward`. Each is flattened over all dimensions after
    its first two, and the results are concatenated along the last dimension
    in ``cfg.conditions`` order. The concatenation is then projected from
    ``cfg.in_channels`` to ``cfg.out_channels`` features.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Total flattened width of all conditions combined; forward() checks
        # that the concatenated tensor has exactly this many features.
        in_channels: int = 25
        # Width of the produced embedding (output features of the linear layer).
        out_channels: int = 768
        # Names of forward() keyword arguments to embed, concatenated in this order.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Create the linear layer mapping ``in_channels`` -> ``out_channels``."""
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs) -> torch.Tensor:
        """Flatten, concatenate and linearly project the configured conditions.

        Args:
            **kwargs: Must contain a tensor for every name in
                ``cfg.conditions``. The first two dimensions of each tensor
                are kept and the rest are flattened, so all conditions must
                agree on those leading dimensions for the concatenation to
                succeed. Presumably these are (B, Nv) — batch and views — as
                the original comment suggested; confirm with callers.

        Returns:
            Tensor of shape ``(*leading, cfg.out_channels)``, where
            ``leading`` are the first two dimensions of the inputs.

        Raises:
            KeyError: if a configured condition is absent from ``kwargs``.
            ValueError: if ``cfg.conditions`` is empty, or if the concatenated
                feature width differs from ``cfg.in_channels``.
        """
        conditions = self.cfg.conditions

        # Explicit checks instead of ``assert`` so validation survives
        # ``python -O``. KeyError matches what the ``kwargs[...]`` lookup
        # would raise anyway.
        missing = [name for name in conditions if name not in kwargs]
        if missing:
            raise KeyError(f"LinearCameraEmbedder: missing condition(s) {missing}")
        if not conditions:
            # torch.cat on an empty list would otherwise fail with an opaque error.
            raise ValueError("LinearCameraEmbedder: cfg.conditions is empty")

        # reshape rather than view: a condition tensor may be non-contiguous
        # (e.g. produced by transpose/permute), where view() raises.
        # reshape still returns a view whenever the memory layout allows it.
        cond_tensors = []
        for name in conditions:
            cond = kwargs[name]
            cond_tensors.append(cond.reshape(*cond.shape[:2], -1))

        cond_tensor = torch.cat(cond_tensors, dim=-1)
        if cond_tensor.shape[-1] != self.cfg.in_channels:
            raise ValueError(
                f"LinearCameraEmbedder: concatenated conditions have "
                f"{cond_tensor.shape[-1]} features, expected "
                f"cfg.in_channels={self.cfg.in_channels}"
            )
        return self.linear(cond_tensor)