File size: 1,036 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
def zero_module(module):
    """Fill every parameter of ``module`` with zeros, in place, and return it.

    Used to make a layer start out as an identity-to-zero mapping (its
    output is all zeros until training updates the weights).
    """
    for weight in module.parameters():
        # In-place zero fill; runs under no_grad so autograd state is untouched.
        nn.init.zeros_(weight)
    return module
class StylizationBlock(nn.Module):
    """Condition a latent sequence on a time embedding.

    The embedding is projected to a per-channel scale and shift that
    modulate the layer-normalized input; the result then passes through a
    final linear projection whose weight and bias start at zero, so the
    whole block emits zeros at initialization.
    """

    def __init__(self, latent_dim, time_embed_dim, dropout):
        super().__init__()
        # Maps the time embedding to a concatenated [scale, shift] pair.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * latent_dim),
        )
        self.norm = nn.LayerNorm(latent_dim)
        # Output projection is zero-initialized so the block is inert at
        # the start of training (same effect as wrapping in zero_module).
        projection = nn.Linear(latent_dim, latent_dim)
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            projection,
        )

    def forward(self, h, emb):
        """Apply scale/shift conditioning to ``h``.

        h:   (B, T, D) latent sequence
        emb: (B, time_embed_dim) time embedding
        Returns a (B, T, D) tensor.
        """
        # (B, 1, 2D): one modulation vector broadcast over all T positions.
        conditioning = self.emb_layers(emb).unsqueeze(1)
        # Each half is (B, 1, D).
        scale, shift = torch.chunk(conditioning, 2, dim=2)
        modulated = self.norm(h) * (1 + scale) + shift
        return self.out_layers(modulated)
|