# Upload metadata: author prajwalsahu5, commit b3c2eb7 ("Upload 141 files"), size 2.53 kB.
import torch
from torch import nn
def _selu_mlp(widths):
    """Chain nn.Linear layers over consecutive ``widths``, with SELU between them (none after the last)."""
    layers = []
    for i, (in_dim, out_dim) in enumerate(zip(widths, widths[1:])):
        if i:
            # Activation goes between Linear layers only; the final Linear stays linear.
            layers.append(nn.SELU())
        layers.append(nn.Linear(in_dim, out_dim))
    return nn.Sequential(*layers)


def make_encoder(input_dim, enc_dec_dims):
    """Build a mirrored pair of SELU MLPs: an encoder and its decoder.

    Despite the name, this returns BOTH halves of the autoencoder.

    The encoder maps ``input_dim`` through each width in ``enc_dec_dims``; the
    last width is the latent size, and the layer producing it has no activation.
    The decoder walks the same widths in reverse and maps back to ``input_dim``,
    again with no activation on its final layer.

    Example: ``input_dim=10, enc_dec_dims=[8, 4, 2]`` gives
    encoder 10->8->4->2 and decoder 2->4->8->10.

    Args:
        input_dim: Width of the encoder input (and of the decoder output).
        enc_dec_dims: Non-empty iterable of layer widths; the last entry is the
            latent (bottleneck) width.

    Returns:
        ``(encoder, decoder)`` as two ``nn.Sequential`` modules.

    Raises:
        ValueError: If ``enc_dec_dims`` is empty.
    """
    # Materialise once so any iterable (tuple, generator, ...) is accepted.
    dims = list(enc_dec_dims)
    if not dims:
        raise ValueError("enc_dec_dims must contain at least one width (the latent size)")
    # Encoder layers are created before decoder layers, as before, so weight
    # initialisation consumes the RNG in the same order.
    encoder = _selu_mlp([input_dim, *dims])
    decoder = _selu_mlp([*reversed(dims), input_dim])
    return encoder, decoder
class FsrFgModel(nn.Module):
    """Autoencoder whose latent code also feeds an MLP prediction head.

    ``method`` selects what is encoded:

    * ``'FG'``       -- ``fg`` alone.
    * ``'MFG'``      -- ``mfg`` alone.
    * ``'FGR'``      -- ``fg`` and ``mfg`` concatenated along dim 1.
    * ``'FGR_desc'`` -- as ``'FGR'``; additionally ``num_features`` is
      concatenated to the latent code before the prediction head.

    Args:
        fg_input_dim: Feature width of ``fg``.
        mfg_input_dim: Feature width of ``mfg``.
        num_input_dim: Feature width of ``num_features`` (used only by ``'FGR_desc'``).
        enc_dec_dims: Encoder widths; the last entry is the latent width. The
            decoder mirrors them back to the encoder's input width.
        output_dims: Hidden widths of the prediction head; each hidden layer is
            Linear -> SELU -> BatchNorm1d.
        num_tasks: Output width of the prediction head.
        dropout: Dropout probability applied before the head's final Linear.
        method: One of ``'FG'``, ``'MFG'``, ``'FGR'``, ``'FGR_desc'``.

    Raises:
        ValueError: If ``method`` is not one of the supported values.
    """

    # Accepted values for ``method``.
    _METHODS = ('FG', 'MFG', 'FGR', 'FGR_desc')

    def __init__(self, fg_input_dim, mfg_input_dim, num_input_dim, enc_dec_dims, output_dims,
                 num_tasks, dropout, method):
        super().__init__()
        # Reject unknown methods instead of silently treating them like 'FGR'
        # (e.g. a typo would otherwise drop num_features without any error).
        if method not in self._METHODS:
            raise ValueError(f"unknown method {method!r}; expected one of {self._METHODS}")
        self.method = method

        if method == 'FG':
            input_dim = fg_input_dim
        elif method == 'MFG':
            input_dim = mfg_input_dim
        else:  # 'FGR' and 'FGR_desc' both encode the fg/mfg concatenation.
            input_dim = fg_input_dim + mfg_input_dim

        latent_dim = enc_dec_dims[-1]
        # 'FGR_desc' widens the head input by the appended num_features.
        fcn_input_dim = latent_dim + num_input_dim if method == 'FGR_desc' else latent_dim

        # Submodules are registered in the original order (encoder, decoder,
        # dropout, batch_norm, predictor) so parameters() order -- which saved
        # optimizer state is indexed by -- does not change.
        self.encoder, self.decoder = make_encoder(input_dim, enc_dec_dims)
        self.dropout = nn.Dropout(dropout)
        self.predict_out_dim = num_tasks
        # NOTE: not used by forward(). Kept so existing state_dicts containing
        # 'batch_norm.*' keys still load with strict=True.
        self.batch_norm = nn.BatchNorm1d(fcn_input_dim)

        layers = []
        in_dim = fcn_input_dim
        for out_dim in output_dims:
            layers.extend([nn.Linear(in_dim, out_dim), nn.SELU(), nn.BatchNorm1d(out_dim)])
            in_dim = out_dim
        # self.dropout is reused here: it is the same module object as the
        # predictor's dropout layer.
        layers.extend([self.dropout, nn.Linear(in_dim, num_tasks)])
        self.predictor = nn.Sequential(*layers)

    def forward(self, fg=None, mfg=None, num_features=None):
        """Encode the selected inputs, reconstruct them, and predict.

        Inputs are concatenated along dim 1, presumably shaped
        (batch, features) -- confirm with callers.

        Args:
            fg: Required for 'FG', 'FGR' and 'FGR_desc'.
            mfg: Required for 'MFG', 'FGR' and 'FGR_desc'.
            num_features: Required for 'FGR_desc'; ignored otherwise.

        Returns:
            ``(output, v_d_hat)``: the prediction-head output and the decoder's
            reconstruction of the encoder input.

        Raises:
            ValueError: If ``method`` is 'FGR_desc' and ``num_features`` is None.
        """
        if self.method == 'FGR_desc' and num_features is None:
            raise ValueError("method 'FGR_desc' requires num_features")

        if self.method == 'FG':
            encoder_input = fg
        elif self.method == 'MFG':
            encoder_input = mfg
        else:  # 'FGR' / 'FGR_desc'
            encoder_input = torch.cat([fg, mfg], dim=1)

        z_d = self.encoder(encoder_input)
        # Reconstruct from the pure latent code, before descriptors are appended.
        v_d_hat = self.decoder(z_d)
        if self.method == 'FGR_desc':
            z_d = torch.cat([z_d, num_features], dim=1)
        output = self.predictor(z_d)
        return output, v_d_hat