File size: 870 Bytes
d2e7940 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from torch import nn
class SimpleDenseNet(nn.Module):
    """Fully connected network with three hidden layers.

    Each hidden layer is ``Linear -> BatchNorm1d -> ReLU``; the final layer is
    a plain ``Linear`` projection with no activation applied to its output.
    """

    def __init__(self, hparams: dict) -> None:
        """Build the layer stack.

        Args:
            hparams: Mapping that must provide the integer keys
                ``"input_size"``, ``"lin1_size"``, ``"lin2_size"``,
                ``"lin3_size"`` and ``"output_size"``.

        Raises:
            KeyError: If any of the required keys is missing.
        """
        super().__init__()

        # The order of the modules is kept stable so that the parameter names
        # in ``state_dict()`` (``model.0.weight``, ``model.1.weight``, ...)
        # do not change.
        self.model = nn.Sequential(
            nn.Linear(hparams["input_size"], hparams["lin1_size"]),
            nn.BatchNorm1d(hparams["lin1_size"]),
            nn.ReLU(),
            nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
            nn.BatchNorm1d(hparams["lin2_size"]),
            nn.ReLU(),
            nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
            nn.BatchNorm1d(hparams["lin3_size"]),
            nn.ReLU(),
            nn.Linear(hparams["lin3_size"], hparams["output_size"]),
        )

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and run the dense stack.

        Args:
            x: Tensor whose first dimension is the batch, e.g.
                ``(batch, channels, width, height)`` or an already flattened
                ``(batch, features)``. The product of the remaining
                dimensions must equal ``hparams["input_size"]``.

        Returns:
            Tensor of shape ``(batch, hparams["output_size"])``.
        """
        # (batch, d1, d2, ...) -> (batch, d1*d2*...). Unlike
        # ``x.view(batch, -1)`` this accepts any rank >= 2 and also works on
        # non-contiguous inputs (it copies only when a view is impossible).
        x = x.flatten(start_dim=1)
        return self.model(x)
|