Minuano's picture
Add applications files and generator saved models.
a847ff6
raw
history blame
1.07 kB
from layers import *
from torch.nn import functional as F
class CycleGenerator(nn.Module):
    """Image-to-image generator: two conv layers, residual blocks, two deconv layers.

    Built from the ``conv`` / ``deconv`` / ``ResidualBlock`` helpers imported
    from ``layers``:

    * ``conv1``: 3 -> conv_dim channels, kernel size 4
    * ``conv2``: conv_dim -> 2 * conv_dim channels, kernel size 4
    * ``res_block1`` ... ``res_block{n_res_blocks}``: at 2 * conv_dim channels
    * ``deconv1``: 2 * conv_dim -> conv_dim channels, kernel size 4
    * ``deconv2``: conv_dim -> 3 channels, kernel size 4, ``norm=False``

    Args:
        conv_dim: Base channel width; intermediate layers use ``conv_dim`` and
            ``conv_dim * 2`` channels.
        n_res_blocks: Number of residual blocks between the conv and deconv
            stages. Defaults to 2, the original architecture.

    Raises:
        ValueError: If ``n_res_blocks`` is negative.
    """

    def __init__(self, conv_dim=64, n_res_blocks=2):
        super().__init__()
        if n_res_blocks < 0:
            raise ValueError(
                "n_res_blocks must be >= 0, got {}".format(n_res_blocks))
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        # Register residual blocks individually as res_block1, res_block2, ...
        # (rather than wrapping them in nn.Sequential) so parameter names in
        # the state_dict match previously saved models for the default count.
        for i in range(1, n_res_blocks + 1):
            self.add_module("res_block%d" % i, ResidualBlock(conv_dim * 2))
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        # norm=False presumably disables normalization inside layers.deconv for
        # the output layer -- confirm against layers.py.
        self.deconv2 = deconv(conv_dim, 3, 4, norm=False)

    def forward(self, x):
        """Generate an image conditioned on an input image.

        Input
        -----
        x: BS x 3 x H x W tensor. The original author documented 32 x 32
           inputs; how H and W change through the network depends on the
           stride/padding used by ``conv`` / ``deconv`` in ``layers``
           (not visible here) -- TODO confirm.

        Output
        ------
        out: BS x 3 x H' x W' tensor with values in [-1, 1] (tanh output);
             documented by the original author as BS x 3 x 32 x 32.
        """
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        # Apply residual blocks in registration order (res_block1, res_block2,
        # ...). Iterating registered children, rather than a stored count,
        # keeps this working for modules unpickled from the old class too.
        for name, block in self.named_children():
            if name.startswith("res_block"):
                out = F.relu(block(out))
        out = F.relu(self.deconv1(out))
        # Tensor.tanh() is equivalent to F.tanh(); it avoids the functional
        # alias that some PyTorch releases warned was deprecated.
        return self.deconv2(out).tanh()