# File size: 1,074 Bytes
# a847ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from layers import *
from torch.nn import functional as F

class CycleGenerator(nn.Module):
    """Image-to-image generator: two convs, two residual blocks, two deconvs.

    The layers are built with the ``conv``, ``deconv`` and ``ResidualBlock``
    helpers, which are not defined in this file (presumably provided by the
    ``from layers import *`` star import -- verify there for their exact
    kernel/stride/padding/normalisation behaviour).

    Channel flow, as wired in ``__init__``::

        3 -> conv_dim -> 2 * conv_dim -> (res) -> (res) -> conv_dim -> 3

    Args:
        conv_dim: Channel width produced by the first conv layer; the
            middle of the network runs at ``2 * conv_dim`` channels.
    """

    def __init__(self, conv_dim=64):
        super(CycleGenerator, self).__init__()

        # Encoder. The third positional argument (4) is presumably the
        # kernel size -- confirm against the ``conv`` helper's signature.
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)

        # Transformation stage. The number of residual blocks is a tunable
        # design choice; two are used here.
        self.res_block1 = ResidualBlock(conv_dim * 2)
        self.res_block2 = ResidualBlock(conv_dim * 2)

        # Decoder, mirroring the encoder's channel widths. The final layer
        # is built with ``norm=False`` (presumably so the output feeding
        # tanh is not normalised -- confirm in the ``deconv`` helper).
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 3, 4, norm=False)

    def forward(self, x):
        """Generates an image conditioned on an input image.

        Every stage except the last is followed by ReLU; the last stage is
        followed by tanh, so output values lie in [-1, 1].

            Input
            -----
                x: BS x 3 x 32 x 32
            Output
            ------
                out: BS x 3 x 32 x 32
        """
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.res_block1(out))
        out = F.relu(self.res_block2(out))
        out = F.relu(self.deconv1(out))
        # ``F.tanh`` is deprecated in favour of ``torch.tanh`` / the tensor
        # method; the method form is numerically identical and needs no
        # extra import.
        out = self.deconv2(out).tanh()

        return out