import torch.nn as nn | |
def build_MLP(dim_list, latent_dim):
    """Build a feed-forward network with GELU activations.

    For every consecutive pair ``(dim_list[i], dim_list[i + 1])`` a
    ``nn.Linear`` followed by ``nn.GELU`` is added. A final ``nn.Linear``
    maps ``dim_list[-1]`` to ``latent_dim``; no activation follows it.

    Args:
        dim_list: Indexable, sliceable sequence of layer widths. The first
            entry is the input width; any further entries are hidden widths.
            Must contain at least one entry.
        latent_dim: Output width of the final linear layer.

    Returns:
        An ``nn.Sequential`` holding ``2 * (len(dim_list) - 1) + 1`` modules.

    Raises:
        IndexError: If ``dim_list`` is empty.
    """
    # len() rather than truthiness so array-like inputs are not ambiguous.
    if len(dim_list) == 0:
        raise IndexError(
            "dim_list must contain at least one entry (the input dimension)"
        )

    layers = []
    # Walk adjacent width pairs; modules are created in the same order as a
    # sequential prev/cur loop, so parameter initialisation order is kept.
    for in_dim, out_dim in zip(dim_list[:-1], dim_list[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.GELU())

    # Final projection to the latent size, deliberately without activation.
    layers.append(nn.Linear(dim_list[-1], latent_dim))
    return nn.Sequential(*layers)