hwonheo commited on
Commit
f83098b
·
verified ·
1 Parent(s): eb0787f

Upload 3 files

Browse files
network/__init__.py ADDED
File without changes
network/discriminator.py CHANGED
@@ -10,12 +10,12 @@ def conv_block(ndf, in_channels, out_channels, kernel_size, stride, padding):
10
  )
11
 
12
  class PatchDiscriminator(nn.Module):
13
- def __init__(self, input_nc=1, ndf=16):
14
  """Initializes the Patch Discriminator model.
15
 
16
  Args:
17
  input_nc (int): Number of input channels. Default is 1 (e.g., for grayscale images).
18
- ndf (int): Number of filters in the first convolution layer. Default is 16.
19
  """
20
  super(PatchDiscriminator, self).__init__()
21
 
@@ -25,27 +25,25 @@ class PatchDiscriminator(nn.Module):
25
  self.conv3 = conv_block(ndf * 2, ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
26
  self.conv4 = conv_block(ndf * 4, ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
27
 
28
- # Final convolution layer to reduce to a single channel output
29
- self.conv5 = nn.Conv3d(ndf * 8, 1, kernel_size=4, padding=1)
30
 
31
  # Flatten layer
32
  self.flatten = nn.Flatten()
33
 
34
- # Fully connected layer to adjust output size
35
- self.fc = nn.Linear(539, 1) # Adjust '539' based on the flattened output size
36
 
37
- # Sigmoid activation to obtain a probability
38
  self.sigmoid = nn.Sigmoid()
39
 
40
  def forward(self, x):
41
- """Defines the forward pass of the discriminator."""
42
  x = self.conv1(x)
43
  x = self.conv2(x)
44
  x = self.conv3(x)
45
  x = self.conv4(x)
46
- x = self.conv5(x)
47
  x = self.flatten(x)
48
  x = self.fc(x)
49
  x = self.sigmoid(x)
50
  return x
51
-
 
10
  )
11
 
12
  class PatchDiscriminator(nn.Module):
13
+ def __init__(self, input_nc=1, ndf=3, output_size=1):
14
  """Initializes the Patch Discriminator model.
15
 
16
  Args:
17
  input_nc (int): Number of input channels. Default is 1 (e.g., for grayscale images).
18
+ ndf (int): Number of filters in the first convolution layer. Default is 64.
19
  """
20
  super(PatchDiscriminator, self).__init__()
21
 
 
25
  self.conv3 = conv_block(ndf * 2, ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
26
  self.conv4 = conv_block(ndf * 4, ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
27
 
28
+ # Final convolution layer
29
+ self.conv5 = nn.Conv3d(ndf * 8, ndf * 8, kernel_size=4, padding=1)
30
 
31
  # Flatten layer
32
  self.flatten = nn.Flatten()
33
 
34
+ # Fully connected layer
35
+ self.fc = nn.Linear(ndf * 8 * 7 * 11 * 7, output_size)
36
 
37
+ # Sigmoid activation
38
  self.sigmoid = nn.Sigmoid()
39
 
40
  def forward(self, x):
 
41
  x = self.conv1(x)
42
  x = self.conv2(x)
43
  x = self.conv3(x)
44
  x = self.conv4(x)
45
+ x = self.conv5(x)
46
  x = self.flatten(x)
47
  x = self.fc(x)
48
  x = self.sigmoid(x)
49
  return x
 
network/generator.py CHANGED
@@ -27,9 +27,8 @@ class ResnetBlock(nn.Module):
27
  class DeUpBlock(nn.Module):
28
  def __init__(self, inf, onf):
29
  super(DeUpBlock, self).__init__()
30
- # Upsampling only in the width dimension
31
  self.deupblock = nn.Sequential(
32
- nn.ConvTranspose3d(inf, onf, kernel_size=(1, 6, 1), stride=(1, 6, 1), padding=(0, 0, 0)),
33
  nn.LeakyReLU(0.2)
34
  )
35
 
@@ -38,7 +37,7 @@ class DeUpBlock(nn.Module):
38
 
39
  # Resnet Generator
40
  class ResnetGenerator(nn.Module):
41
- def __init__(self, input_nc=1, output_nc=1, ngf=16, n_residual_blocks=4):
42
  super(ResnetGenerator, self).__init__()
43
  self.n_residual_blocks = n_residual_blocks
44
 
@@ -66,4 +65,3 @@ class ResnetGenerator(nn.Module):
66
  x = self.conv_block2(y) + x
67
  x = self.deup(x)
68
  return self.conv3(x)
69
-
 
27
  class DeUpBlock(nn.Module):
28
  def __init__(self, inf, onf):
29
  super(DeUpBlock, self).__init__()
 
30
  self.deupblock = nn.Sequential(
31
+ nn.ConvTranspose3d(inf, onf, kernel_size=(1, 3, 1), stride=(1, 3, 1), padding=(0, 0, 0)),
32
  nn.LeakyReLU(0.2)
33
  )
34
 
 
37
 
38
  # Resnet Generator
39
  class ResnetGenerator(nn.Module):
40
+ def __init__(self, input_nc=1, output_nc=1, ngf=32, n_residual_blocks=2):
41
  super(ResnetGenerator, self).__init__()
42
  self.n_residual_blocks = n_residual_blocks
43
 
 
65
  x = self.conv_block2(y) + x
66
  x = self.deup(x)
67
  return self.conv3(x)