from typing import Optional

from torch import nn, Tensor
from timm.layers.mlp import Mlp

from config import AutoConfig

class SubjectBehaviorEmbed(nn.Module):
    """Projects per-subject behavior features into a shared conditioning space."""

    def __init__(
        self,
        subject_list,
        in_dim,
        dim,
        dropout=0.2,
    ):
        super().__init__()
        self.subject_list = subject_list

        # Subject-specific projection heads: raw behavior features (in_dim)
        # mapped into the shared embedding space (dim).
        self.embed = nn.ModuleDict()
        for subject in self.subject_list:
            block = nn.Sequential(
                nn.Linear(in_dim, dim),
                nn.GELU(),
            )
            self.embed[subject] = block

        # Shared MLP applied on top of the subject-specific projection.
        self.mlp = Mlp(dim, out_features=dim)

        # Feature-wise dropout: reshape (B, dim) -> (B, dim, 1) so Dropout1d
        # zeroes whole feature channels, then flatten back to (B, dim).
        self.dropout = nn.Sequential(
            nn.Unflatten(1, (dim, 1)),
            nn.Dropout1d(dropout),
            nn.Flatten(1, -1),
        )
    def forward(self, c: Optional[Tensor], subject: str) -> Optional[Tensor]:
        # Behavior conditioning may be absent; pass None through unchanged.
        if c is not None:
            c = self.embed[subject](c)
            c = self.mlp(c)
            c = self.dropout(c)
        return c

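# Usage sketch for SubjectBehaviorEmbed (dimensions are illustrative only,
# not taken from any project config):
#
#     embed = SubjectBehaviorEmbed(["subj01"], in_dim=32, dim=64)
#     embed(torch.randn(8, 32), subject="subj01")  # -> Tensor of shape (8, 64)
#     embed(None, subject="subj01")                # -> None
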
def build_behavior_embed(cfg: AutoConfig, out_dim=None):
    """Build a SubjectBehaviorEmbed from config; out_dim overrides MODEL.COND.DIM."""
    out_dim = out_dim or cfg.MODEL.COND.DIM
    return SubjectBehaviorEmbed(
        subject_list=cfg.DATASET.SUBJECT_LIST,
        in_dim=cfg.MODEL.COND.IN_DIM,
        dim=out_dim,
        dropout=cfg.MODEL.COND.DROPOUT,
    )
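
# Minimal smoke test (a sketch): the SimpleNamespace stand-in below only
# mirrors the config keys this module actually reads (DATASET.SUBJECT_LIST
# and MODEL.COND.IN_DIM / DIM / DROPOUT); the real AutoConfig may differ.
if __name__ == "__main__":
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        DATASET=SimpleNamespace(SUBJECT_LIST=["subj01", "subj02"]),
        MODEL=SimpleNamespace(
            COND=SimpleNamespace(IN_DIM=32, DIM=64, DROPOUT=0.2),
        ),
    )
    model = build_behavior_embed(cfg)
    out = model(torch.randn(8, 32), subject="subj01")
    assert out.shape == (8, 64)
    assert model(None, subject="subj01") is None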