import torch
import torch.nn as nn


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class StylizationBlock(nn.Module):
    """
    Conditions sequence features on a timestep embedding via a learned
    scale and shift (FiLM-style modulation) applied after LayerNorm.
    """

    def __init__(self, latent_dim, time_embed_dim, dropout):
        super().__init__()
        # Project the timestep embedding to a per-channel scale and shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * latent_dim),
        )
        self.norm = nn.LayerNorm(latent_dim)
        # Zero-initializing the output projection makes the block emit zeros
        # at the start of training, so a surrounding residual connection
        # begins as the identity.
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Linear(latent_dim, latent_dim)),
        )

    def forward(self, h, emb):
        """
        h: (B, T, D) sequence features
        emb: (B, time_embed_dim) timestep embedding
        """
        # (B, 2 * D) -> (B, 1, 2 * D) so the modulation broadcasts over T
        emb_out = self.emb_layers(emb).unsqueeze(1)
        # Split into scale and shift, each (B, 1, D)
        scale, shift = torch.chunk(emb_out, 2, dim=2)
        # Scale-and-shift the normalized features, then project back to D
        h = self.norm(h) * (1 + scale) + shift
        h = self.out_layers(h)
        return h
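

# A minimal usage sketch (shapes are illustrative assumptions, not from the
# original file): condition a batch of sequence features on a timestep
# embedding. Because the output projection is zero-initialized, the block
# returns all zeros at init, so `x + block(x, emb)` starts as the identity.
if __name__ == "__main__":
    B, T, D, E = 2, 16, 64, 128  # batch, seq length, latent_dim, time_embed_dim
    block = StylizationBlock(latent_dim=D, time_embed_dim=E, dropout=0.1)
    h = torch.randn(B, T, D)
    emb = torch.randn(B, E)
    out = block(h, emb)
    print(out.shape)        # torch.Size([2, 16, 64])
    print(out.abs().max())  # tensor(0.) at initialization (zero_module)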