File size: 344 Bytes
373af33
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch.nn as nn


def build_MLP(dim_list, latent_dim, *, activation=nn.GELU):
    """Build a feed-forward MLP as an ``nn.Sequential``.

    Each consecutive pair of widths in ``dim_list`` becomes an ``nn.Linear``
    followed by an activation. A final ``nn.Linear`` with no activation maps
    the last width in ``dim_list`` to ``latent_dim``.

    Layout for ``dim_list=[a, b, c]``::

        Linear(a, b) -> act -> Linear(b, c) -> act -> Linear(c, latent_dim)

    Args:
        dim_list: Sliceable sequence of layer widths. The first entry is the
            input feature size. Must contain at least one element.
        latent_dim: Output feature size of the final linear layer.
        activation: Zero-argument callable, typically an ``nn.Module`` class.
            It is called once per hidden layer so that every activation is a
            distinct module instance. Defaults to ``nn.GELU``.

    Returns:
        nn.Sequential: The assembled network.

    Raises:
        IndexError: If ``dim_list`` is empty.
    """
    layers = []
    # Walk adjacent (in, out) width pairs; a single-element dim_list yields none.
    for in_dim, out_dim in zip(dim_list[:-1], dim_list[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(activation())  # fresh instance per hidden layer
    # Output projection is left linear (no trailing activation).
    layers.append(nn.Linear(dim_list[-1], latent_dim))
    return nn.Sequential(*layers)