import torch
from layers import *
from torch.nn import functional as F


class CycleGenerator(nn.Module):
    """Image-to-image generator: downsample, residual bottleneck, upsample.

    Layout: two ``conv`` layers, ``n_res_blocks`` ``ResidualBlock``s at
    ``conv_dim * 2`` channels, then two ``deconv`` layers producing a
    3-channel output squashed with ``tanh``. ``conv``, ``deconv`` and
    ``ResidualBlock`` are provided by the project's ``layers`` module.
    """

    def __init__(self, conv_dim=64, n_res_blocks=2):
        """Build the generator.

        Args:
            conv_dim: Channel count after the first conv layer; the residual
                bottleneck runs at ``conv_dim * 2`` channels.
            n_res_blocks: Number of residual blocks in the bottleneck.
                Defaults to 2 (the previously hard-coded value); 0 is allowed
                and skips the bottleneck entirely.

        Raises:
            ValueError: If ``n_res_blocks`` is negative.
        """
        super(CycleGenerator, self).__init__()
        if n_res_blocks < 0:
            raise ValueError(
                "n_res_blocks must be >= 0, got {}".format(n_res_blocks))

        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)

        # Register the blocks as res_block1, res_block2, ... instead of using
        # an nn.ModuleList. This keeps state_dict keys (and module creation
        # order) identical to the original two-block model, so existing
        # checkpoints load unchanged.
        self.n_res_blocks = n_res_blocks
        for i in range(1, n_res_blocks + 1):
            self.add_module('res_block{}'.format(i),
                            ResidualBlock(conv_dim * 2))

        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 3, 4, norm=False)

    def forward(self, x):
        """Generates an image conditioned on an input image.

        Input
        -----
            x: BS x 3 x 32 x 32 (as documented by the original author; the
               actual spatial sizes depend on the conv/deconv strides defined
               in ``layers`` -- confirm there)

        Output
        ------
            out: BS x 3 x 32 x 32, values in [-1, 1] from the final tanh
        """
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        for i in range(1, self.n_res_blocks + 1):
            out = F.relu(getattr(self, 'res_block{}'.format(i))(out))
        out = F.relu(self.deconv1(out))
        # torch.tanh is the canonical spelling; some older PyTorch releases
        # warned that nn.functional.tanh was deprecated. Same computation.
        out = torch.tanh(self.deconv2(out))
        return out