File size: 1,491 Bytes
ec378c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from torch import nn, Tensor
from einops import rearrange
from typing import Dict, Optional
from config import AutoConfig
from timm.layers.mlp import Mlp
from timm.layers.norm import LayerNorm
class SubjectBehaviorEmbed(nn.Module):
    """Embed a behavior vector with a subject-specific projection.

    Each subject in ``subject_list`` gets its own ``Linear(in_dim, dim) -> GELU``
    projection. The result goes through an ``Mlp`` shared by all subjects and
    then through a condition-dropout layer.

    In training mode, the condition dropout zeroes a sample's *entire*
    ``dim``-sized embedding with probability ``dropout``; kept embeddings are
    rescaled by ``1 / (1 - dropout)``. The downstream model therefore also sees
    the "no behavior data" case. In ``eval()`` mode the layer is the identity.

    Args:
        subject_list: Subject identifiers. Each one is used as a key of an
            ``nn.ModuleDict`` holding that subject's projection.
        in_dim: Feature size of the behavior input. The input is expected to
            be ``[B, in_dim]``, because the dropout reshapes dim 1 to ``dim``.
        dim: Size of the output embedding.
        dropout: Probability of dropping a sample's whole embedding.
    """

    def __init__(
        self,
        subject_list,
        in_dim,
        dim,
        dropout=0.2,  # probability of dropping the whole condition (behavior-free case)
    ):
        super().__init__()
        self.subject_list = subject_list
        # One independent input projection per subject.
        self.embed = nn.ModuleDict()
        for subject in self.subject_list:
            self.embed[subject] = nn.Sequential(
                nn.Linear(in_dim, dim),
                nn.GELU(),
            )
        # Shared by all subjects.
        self.mlp = Mlp(dim, out_features=dim)
        self.dropout = self._make_condition_dropout(dim, dropout)

    @staticmethod
    def _make_condition_dropout(dim, p):
        """Build a layer that drops each sample's entire ``[dim]`` vector with prob ``p``.

        ``nn.Dropout1d`` zeroes whole channels, i.e. dim 1 of an ``[N, C, L]``
        input. The embedding is therefore viewed as ``[B, 1, D]`` so that the
        single channel is the full vector. A ``[B, D, 1]`` layout would make
        every feature its own channel, which degenerates to element-wise
        dropout.
        """
        return nn.Sequential(
            nn.Unflatten(1, (1, dim)),  # [B, D] -> [B, 1, D]
            nn.Dropout1d(p),  # zero the single channel, i.e. the entire D vector
            nn.Flatten(1, -1),  # [B, 1, D] -> [B, D]
        )

    def forward(self, c: Optional[Tensor], subject: str) -> Optional[Tensor]:
        """Embed ``c`` with ``subject``'s projection, the shared MLP and the dropout.

        Args:
            c: Behavior input, expected ``[B, in_dim]``. ``None`` means no
                behavior data is available.
            subject: Key selecting the per-subject projection.

        Returns:
            The ``[B, dim]`` embedding, or ``None`` when ``c`` is ``None``.
        """
        if c is None:
            # No behavior data: nothing to embed, propagate None.
            return None
        c = self.embed[subject](c)
        c = self.mlp(c)
        # Active only in training mode; identity under eval().
        return self.dropout(c)
def build_behavior_embed(cfg: AutoConfig, out_dim=None):
    """Construct a ``SubjectBehaviorEmbed`` from the experiment config.

    Reads ``cfg.DATASET.SUBJECT_LIST`` and ``cfg.MODEL.COND.IN_DIM`` /
    ``DIM`` / ``DROPOUT``.

    Args:
        cfg: Experiment configuration object.
        out_dim: Optional override for the embedding size. Any falsy value
            (``None``, ``0``) falls back to ``cfg.MODEL.COND.DIM``.

    Returns:
        A freshly built ``SubjectBehaviorEmbed``.
    """
    cond_cfg = cfg.MODEL.COND
    embed_dim = out_dim if out_dim else cond_cfg.DIM
    return SubjectBehaviorEmbed(
        subject_list=cfg.DATASET.SUBJECT_LIST,
        in_dim=cond_cfg.IN_DIM,
        dim=embed_dim,
        dropout=cond_cfg.DROPOUT,
    )
|