Spaces:
Runtime error
Runtime error
from layers import * | |
from torch.nn import functional as F | |
class CycleGenerator(nn.Module):
    """Image-to-image generator (CycleGAN-style encoder / residual / decoder).

    Architecture, in forward order:
      * two encoder layers built with the ``conv`` helper (3 -> conv_dim ->
        conv_dim * 2 channels), each followed by ReLU;
      * ``n_res_blocks`` ``ResidualBlock`` modules at ``conv_dim * 2``
        channels, each followed by ReLU;
      * two decoder layers built with the ``deconv`` helper
        (conv_dim * 2 -> conv_dim -> 3 channels), the first followed by ReLU
        and the last by tanh.

    Args:
        conv_dim: Channel width of the first encoder layer. The residual
            stack runs at ``conv_dim * 2`` channels.
        n_res_blocks: Number of residual blocks between encoder and decoder.
            The default of 2 reproduces the original architecture. It also
            keeps the original ``res_block1`` / ``res_block2`` submodule
            names, so existing state_dicts still load.

    Raises:
        ValueError: If ``n_res_blocks`` is negative.
    """

    def __init__(self, conv_dim=64, n_res_blocks=2):
        super(CycleGenerator, self).__init__()
        if n_res_blocks < 0:
            raise ValueError(f"n_res_blocks must be >= 0, got {n_res_blocks}")
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        # Residual blocks are registered as res_block1..res_blockN (rather than
        # in an nn.Sequential) so the default configuration keeps exactly the
        # same submodule / state_dict names as before.
        self.n_res_blocks = n_res_blocks
        for i in range(1, n_res_blocks + 1):
            self.add_module(f"res_block{i}", ResidualBlock(conv_dim * 2))
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        # norm=False presumably disables normalization on the output layer;
        # confirm against the deconv helper in layers.
        self.deconv2 = deconv(conv_dim, 3, 4, norm=False)

    def forward(self, x):
        """Generate an image conditioned on an input image.

        Input
        -----
        x: BS x 3 x 32 x 32

        Output
        ------
        out: BS x 3 x 32 x 32, values in [-1, 1] (tanh output)

        NOTE(review): the 32 x 32 size comes from the original docstring. The
        actual spatial size depends on the stride/padding defaults of the
        conv/deconv helpers, so confirm it there.
        """
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        # Fall back to 2 for instances pickled before n_res_blocks existed.
        for i in range(1, getattr(self, "n_res_blocks", 2) + 1):
            out = F.relu(getattr(self, f"res_block{i}")(out))
        out = F.relu(self.deconv1(out))
        # Tensor.tanh() is numerically identical to F.tanh, which PyTorch
        # deprecates, and it needs no extra import.
        return self.deconv2(out).tanh()