Muhammad Rama Nurimani committed on
Commit dfaf93c
1 Parent(s): bed47c1

test deploy

Files changed (1)
  1. networks.py +616 -0
networks.py ADDED
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


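# --- Illustrative usage sketch (not part of the upstream file): how the
# callable returned above is typically consumed when assembling a conv block.
def _demo_get_norm_layer():
    for norm_type in ['batch', 'instance', 'none']:
        norm_layer = get_norm_layer(norm_type)
        block = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            norm_layer(64),  # for 'none', this yields an Identity module
            nn.ReLU(True))
        out = block(torch.randn(2, 3, 32, 32))
        assert out.shape == (2, 64, 32, 32)

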
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


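# --- Illustrative usage sketch (not part of the upstream file): driving the
# 'linear' policy. SimpleNamespace stands in for the repo's options class (an
# assumption); the field names follow the docstring above.
def _demo_get_scheduler():
    from types import SimpleNamespace
    opt = SimpleNamespace(lr_policy='linear', epoch_count=1, n_epochs=100, n_epochs_decay=100)
    optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=0.0002)
    scheduler = get_scheduler(optimizer, opt)
    for epoch in range(opt.n_epochs + opt.n_epochs_decay):
        optimizer.step()   # lr stays at 0.0002 for the first 100 epochs,
        scheduler.step()   # then decays linearly toward zero
    assert optimizer.param_groups[0]['lr'] < 1e-5

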
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


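# --- Illustrative usage sketch (not part of the upstream file): initializing
# a toy network on CPU (an empty gpu_ids list skips the DataParallel wrapping).
def _demo_init_net():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 1, 3))
    net = init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[])
    # conv weights are now Xavier-initialized; BatchNorm weights ~ N(1.0, 0.02)
    assert abs(net[1].weight.data.mean().item() - 1.0) < 0.1

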
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator

    Our current implementation provides two types of generators:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        A Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

    The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)


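# --- Illustrative usage sketch (not part of the upstream file): a CycleGAN-
# style generator mapping a 256x256 RGB batch to an RGB batch of the same size.
def _demo_define_G():
    netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                    norm='instance', use_dropout=False, gpu_ids=[])
    fake = netG(torch.randn(1, 3, 256, 256))
    assert fake.shape == (1, 3, 256, 256)  # Tanh output in [-1, 1]

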
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current implementation provides three types of discriminators:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70×70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses LeakyReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


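# --- Illustrative usage sketch (not part of the upstream file): the default
# 70x70 PatchGAN maps a 256x256 input to a 30x30 grid of per-patch logits.
def _demo_define_D():
    netD = define_D(input_nc=3, ndf=64, netD='basic', norm='instance', gpu_ids=[])
    pred = netD(torch.randn(1, 3, 256, 256))
    assert pred.shape == (1, 1, 30, 30)

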
##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss


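# --- Illustrative usage sketch (not part of the upstream file): LSGAN loss on
# raw (non-sigmoid) discriminator logits, as the docstring above prescribes.
def _demo_gan_loss():
    criterion = GANLoss('lsgan')
    pred_fake = torch.randn(4, 1, 30, 30)           # e.g., D(G(x)) patch logits
    loss_G = criterion(pred_fake, True)             # generator wants fakes labeled real
    loss_D = criterion(pred_fake.detach(), False)   # discriminator wants them labeled fake
    assert loss_G.dim() == 0 and loss_D.dim() == 0  # scalar losses

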
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)         -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':   # either use real images, fake images, or a linear interpolation of the two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1, device=device)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


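# --- Illustrative usage sketch (not part of the upstream file): WGAN-GP
# penalty on interpolated samples, with a tiny conv net standing in for netD.
def _demo_gradient_penalty():
    netD = nn.Sequential(nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1))
    real = torch.randn(4, 3, 16, 16)
    fake = torch.randn(4, 3, 16, 16)
    gp, grads = cal_gradient_penalty(netD, real, fake, device='cpu',
                                     type='mixed', constant=1.0, lambda_gp=10.0)
    gp.backward()  # contributes to the discriminator's loss

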
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


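# --- Illustrative usage sketch (not part of the upstream file): the generator
# preserves spatial size (downsample x4, residual blocks, upsample x4) for
# inputs whose sides are divisible by 4.
def _demo_resnet_generator():
    net = ResnetGenerator(3, 3, ngf=64, norm_layer=nn.InstanceNorm2d, n_blocks=6)
    out = net(torch.randn(1, 3, 128, 128))
    assert out.shape == (1, 3, 128, 128)

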
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out


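# --- Illustrative usage sketch (not part of the upstream file): the block
# computes x + F(x), so output and input shapes must match.
def _demo_resnet_block():
    block = ResnetBlock(64, padding_type='reflect', norm_layer=nn.InstanceNorm2d,
                        use_dropout=False, use_bias=True)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == x.shape

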
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


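# --- Illustrative usage sketch (not part of the upstream file): num_downs must
# halve the input down to 1x1 at the bottleneck, so a 256x256 image pairs with
# num_downs=8 (this is what define_G's 'unet_256' uses).
def _demo_unet_generator():
    net = UnetGenerator(3, 3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d)
    out = net(torch.randn(2, 3, 256, 256))
    assert out.shape == (2, 3, 256, 256)

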
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:   # add skip connections
            return torch.cat([x, self.model(x)], 1)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


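# --- Illustrative usage sketch (not part of the upstream file): adding conv
# layers shrinks the prediction map (and enlarges each patch's receptive field).
def _demo_n_layer_discriminator():
    x = torch.randn(1, 3, 256, 256)
    for n_layers, hw in [(3, 30), (4, 14), (5, 6)]:
        netD = NLayerDiscriminator(3, ndf=64, n_layers=n_layers, norm_layer=nn.InstanceNorm2d)
        assert netD(x).shape == (1, 1, hw, hw)

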
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
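

# --- Illustrative usage sketch (not part of the upstream file): all convs are
# 1x1, so the PixelGAN returns one real/fake logit per input pixel.
def _demo_pixel_discriminator():
    netD = PixelDiscriminator(3, ndf=64, norm_layer=nn.InstanceNorm2d)
    pred = netD(torch.randn(1, 3, 70, 70))
    assert pred.shape == (1, 1, 70, 70)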