File size: 344 Bytes
373af33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch.nn as nn
def build_MLP(dim_list, latent_dim, *, activation=nn.GELU):
    """Build a feed-forward MLP as an ``nn.Sequential``.

    For every consecutive pair ``(a, b)`` in *dim_list*, a ``nn.Linear(a, b)``
    followed by an activation module is appended. A final ``nn.Linear`` maps
    the last entry of *dim_list* to *latent_dim*, with no activation after it.

    Args:
        dim_list: Sliceable sequence of layer widths; ``dim_list[0]`` is the
            input feature size. Must contain at least one element.
        latent_dim: Output feature size of the final linear layer.
        activation: Zero-argument callable (typically an ``nn.Module``
            subclass) returning a fresh activation module. It is called once
            per hidden layer. Defaults to ``nn.GELU`` (the original behavior).

    Returns:
        An ``nn.Sequential`` containing ``2 * (len(dim_list) - 1) + 1``
        modules: alternating Linear/activation pairs, then the output Linear.

    Raises:
        IndexError: If *dim_list* is empty.
    """
    layers = []
    for in_dim, out_dim in zip(dim_list, dim_list[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        # Instantiate per layer so no activation module object is shared.
        layers.append(activation())
    # Output projection; deliberately not followed by an activation.
    # dim_list[-1] raises IndexError for an empty list, as the original did.
    layers.append(nn.Linear(dim_list[-1], latent_dim))
    return nn.Sequential(*layers)
|