import torch
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    """Zero out every parameter of ``module`` in place and return it.

    The same module object is returned (not a copy), so the call can be
    used inline while building a container such as ``nn.Sequential``.
    """
    for p in module.parameters():
        # detach() yields a tensor sharing storage with the parameter, so the
        # in-place zero_() is not recorded by autograd but still clears ``p``.
        p.detach().zero_()
    return module


class StylizationBlock(nn.Module):
    """Modulate features with an embedding-derived scale and shift.

    The conditioning embedding is projected to ``2 * latent_dim`` values,
    split into a scale and a shift, and applied to the layer-normalised
    input before an output projection.

    The output projection is zero-initialised via ``zero_module``, so a
    freshly constructed block returns all zeros until its weights change.
    """

    def __init__(self, latent_dim: int, time_embed_dim: int, dropout: float):
        """Build the embedding projection, the norm and the output layers.

        Args:
            latent_dim: Size of the last dimension of the features ``h``.
            time_embed_dim: Size of the last dimension of the embedding.
            dropout: Drop probability passed to ``nn.Dropout``.
        """
        super().__init__()
        # Produces scale and shift concatenated along the last dimension.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * latent_dim),
        )
        self.norm = nn.LayerNorm(latent_dim)
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Weight and bias both start at zero.
            zero_module(nn.Linear(latent_dim, latent_dim)),
        )

    def forward(self, h: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply the scale/shift modulation and the output projection.

        Args:
            h: Features of shape (B, T, latent_dim).
            emb: Conditioning embedding of shape (B, time_embed_dim).

        Returns:
            Tensor of shape (B, T, latent_dim).
        """
        # (B, 2 * latent_dim) -> (B, 1, 2 * latent_dim) so that the
        # modulation broadcasts over the T axis of ``h``.
        emb_out = self.emb_layers(emb).unsqueeze(1)
        # scale: (B, 1, latent_dim) / shift: (B, 1, latent_dim)
        scale, shift = torch.chunk(emb_out, 2, dim=2)
        # (1 + scale) makes a zero scale the identity on the normalised input.
        h = self.norm(h) * (1 + scale) + shift
        h = self.out_layers(h)
        return h