Upload 3 files
Browse files- network/__init__.py +0 -0
- network/discriminator.py +8 -10
- network/generator.py +2 -4
network/__init__.py
ADDED
File without changes
|
network/discriminator.py
CHANGED
@@ -10,12 +10,12 @@ def conv_block(ndf, in_channels, out_channels, kernel_size, stride, padding):
|
|
10 |
)
|
11 |
|
12 |
class PatchDiscriminator(nn.Module):
|
13 |
-
def __init__(self, input_nc=1, ndf=
|
14 |
"""Initializes the Patch Discriminator model.
|
15 |
|
16 |
Args:
|
17 |
input_nc (int): Number of input channels. Default is 1 (e.g., for grayscale images).
|
18 |
-
ndf (int): Number of filters in the first convolution layer. Default is
|
19 |
"""
|
20 |
super(PatchDiscriminator, self).__init__()
|
21 |
|
@@ -25,27 +25,25 @@ class PatchDiscriminator(nn.Module):
|
|
25 |
self.conv3 = conv_block(ndf * 2, ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
|
26 |
self.conv4 = conv_block(ndf * 4, ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
|
27 |
|
28 |
-
# Final convolution layer
|
29 |
-
self.conv5 = nn.Conv3d(ndf * 8,
|
30 |
|
31 |
# Flatten layer
|
32 |
self.flatten = nn.Flatten()
|
33 |
|
34 |
-
# Fully connected layer
|
35 |
-
self.fc = nn.Linear(
|
36 |
|
37 |
-
# Sigmoid activation
|
38 |
self.sigmoid = nn.Sigmoid()
|
39 |
|
40 |
def forward(self, x):
|
41 |
-
"""Defines the forward pass of the discriminator."""
|
42 |
x = self.conv1(x)
|
43 |
x = self.conv2(x)
|
44 |
x = self.conv3(x)
|
45 |
x = self.conv4(x)
|
46 |
-
x = self.conv5(x)
|
47 |
x = self.flatten(x)
|
48 |
x = self.fc(x)
|
49 |
x = self.sigmoid(x)
|
50 |
return x
|
51 |
-
|
|
|
10 |
)
|
11 |
|
12 |
class PatchDiscriminator(nn.Module):
|
13 |
+
def __init__(self, input_nc=1, ndf=3, output_size=1):
|
14 |
"""Initializes the Patch Discriminator model.
|
15 |
|
16 |
Args:
|
17 |
input_nc (int): Number of input channels. Default is 1 (e.g., for grayscale images).
|
18 |
+
ndf (int): Number of filters in the first convolution layer. Default is 64.
|
19 |
"""
|
20 |
super(PatchDiscriminator, self).__init__()
|
21 |
|
|
|
25 |
self.conv3 = conv_block(ndf * 2, ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
|
26 |
self.conv4 = conv_block(ndf * 4, ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
|
27 |
|
28 |
+
# Final convolution layer
|
29 |
+
self.conv5 = nn.Conv3d(ndf * 8, ndf * 8, kernel_size=4, padding=1)
|
30 |
|
31 |
# Flatten layer
|
32 |
self.flatten = nn.Flatten()
|
33 |
|
34 |
+
# Fully connected layer
|
35 |
+
self.fc = nn.Linear(ndf * 8 * 7 * 11 * 7, output_size)
|
36 |
|
37 |
+
# Sigmoid activation
|
38 |
self.sigmoid = nn.Sigmoid()
|
39 |
|
40 |
def forward(self, x):
|
|
|
41 |
x = self.conv1(x)
|
42 |
x = self.conv2(x)
|
43 |
x = self.conv3(x)
|
44 |
x = self.conv4(x)
|
45 |
+
x = self.conv5(x)
|
46 |
x = self.flatten(x)
|
47 |
x = self.fc(x)
|
48 |
x = self.sigmoid(x)
|
49 |
return x
|
|
network/generator.py
CHANGED
@@ -27,9 +27,8 @@ class ResnetBlock(nn.Module):
|
|
27 |
class DeUpBlock(nn.Module):
|
28 |
def __init__(self, inf, onf):
|
29 |
super(DeUpBlock, self).__init__()
|
30 |
-
# Upsampling only in the width dimension
|
31 |
self.deupblock = nn.Sequential(
|
32 |
-
nn.ConvTranspose3d(inf, onf, kernel_size=(1,
|
33 |
nn.LeakyReLU(0.2)
|
34 |
)
|
35 |
|
@@ -38,7 +37,7 @@ class DeUpBlock(nn.Module):
|
|
38 |
|
39 |
# Resnet Generator
|
40 |
class ResnetGenerator(nn.Module):
|
41 |
-
def __init__(self, input_nc=1, output_nc=1, ngf=
|
42 |
super(ResnetGenerator, self).__init__()
|
43 |
self.n_residual_blocks = n_residual_blocks
|
44 |
|
@@ -66,4 +65,3 @@ class ResnetGenerator(nn.Module):
|
|
66 |
x = self.conv_block2(y) + x
|
67 |
x = self.deup(x)
|
68 |
return self.conv3(x)
|
69 |
-
|
|
|
27 |
class DeUpBlock(nn.Module):
|
28 |
def __init__(self, inf, onf):
|
29 |
super(DeUpBlock, self).__init__()
|
|
|
30 |
self.deupblock = nn.Sequential(
|
31 |
+
nn.ConvTranspose3d(inf, onf, kernel_size=(1, 3, 1), stride=(1, 3, 1), padding=(0, 0, 0)),
|
32 |
nn.LeakyReLU(0.2)
|
33 |
)
|
34 |
|
|
|
37 |
|
38 |
# Resnet Generator
|
39 |
class ResnetGenerator(nn.Module):
|
40 |
+
def __init__(self, input_nc=1, output_nc=1, ngf=32, n_residual_blocks=2):
|
41 |
super(ResnetGenerator, self).__init__()
|
42 |
self.n_residual_blocks = n_residual_blocks
|
43 |
|
|
|
65 |
x = self.conv_block2(y) + x
|
66 |
x = self.deup(x)
|
67 |
return self.conv3(x)
|
|