hr16 committed
Commit 480bfbc · 1 Parent(s): 689d937

Fork adriansahlman's stylegan2_pytorch

.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ results/
LICENSE.txt ADDED
@@ -0,0 +1,7 @@
+ Copyright 2020 Adrian Sahlman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy
+ pillow
+ pyyaml
+ requests
+ scipy
+ tensorboard
+ torch
+ torchvision
+ tqdm
run_convert_from_tf.py ADDED
@@ -0,0 +1,373 @@
1
+ import os
2
+ import re
3
+ import pickle
4
+ import argparse
5
+ import io
6
+ import requests
7
+ import torch
8
+ import stylegan2
9
+ from stylegan2 import utils
10
+
11
+
12
+ pretrained_model_urls = {
13
+ 'car-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl',
14
+ 'car-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl',
15
+ 'cat-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl',
16
+ 'church-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl',
17
+ 'ffhq-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
18
+ 'ffhq-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
19
+ 'horse-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl',
20
+ 'car-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl',
21
+ 'car-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl',
22
+ 'car-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl',
23
+ 'car-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl',
24
+ 'car-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl',
25
+ 'car-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl',
26
+ 'car-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl',
27
+ 'car-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl',
28
+ 'car-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl',
29
+ 'ffhq-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl',
30
+ 'ffhq-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl',
31
+ 'ffhq-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl',
32
+ 'ffhq-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl',
33
+ 'ffhq-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl',
34
+ 'ffhq-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl',
35
+ 'ffhq-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl',
36
+ 'ffhq-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl',
37
+ 'ffhq-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
38
+ }
39
+
40
+
41
+ class Unpickler(pickle.Unpickler):
42
+ def find_class(self, module, name):
43
+ if module == 'dnnlib.tflib.network' and name == 'Network':
44
+ return utils.AttributeDict
45
+ return super(Unpickler, self).find_class(module, name)
46
+
47
+
48
+ def load_tf_models_file(fpath):
49
+ with open(fpath, 'rb') as fp:
50
+ return Unpickler(fp).load()
51
+
52
+
53
+ def load_tf_models_url(url):
54
+ print('Downloading file {}...'.format(url))
55
+ with requests.Session() as session:
56
+ with session.get(url) as ret:
57
+ fp = io.BytesIO(ret.content)
58
+ return Unpickler(fp).load()
59
+
60
+
61
+ def convert_kwargs(static_kwargs, kwargs_mapping):
62
+ kwargs = utils.AttributeDict()
63
+ for key, value in static_kwargs.items():
64
+ if key in kwargs_mapping:
65
+ if value == 'lrelu':
66
+ value = 'leaky:0.2'
67
+ for k in utils.to_list(kwargs_mapping[key]):
68
+ kwargs[k] = value
69
+ return kwargs
70
+
71
+
72
+ _PERMITTED_MODELS = ['G_main', 'G_mapping', 'G_synthesis_stylegan2', 'D_stylegan2', 'D_main', 'G_synthesis']
73
+ def convert_from_tf(tf_state):
74
+ tf_state = utils.AttributeDict.convert_dict_recursive(tf_state)
75
+ model_type = tf_state.build_func_name
76
+ assert model_type in _PERMITTED_MODELS, \
77
+ 'Found model type {}. '.format(model_type) + \
78
+ 'Allowed model types are: {}'.format(_PERMITTED_MODELS)
79
+
80
+ if model_type == 'G_main':
81
+ kwargs = convert_kwargs(
82
+ static_kwargs=tf_state.static_kwargs,
83
+ kwargs_mapping={
84
+ 'dlatent_avg_beta': 'dlatent_avg_beta'
85
+ }
86
+ )
87
+ kwargs.G_mapping = convert_from_tf(tf_state.components.mapping)
88
+ kwargs.G_synthesis = convert_from_tf(tf_state.components.synthesis)
89
+ G = stylegan2.models.Generator(**kwargs)
90
+ for name, var in tf_state.variables:
91
+ if name == 'dlatent_avg':
92
+ G.dlatent_avg.data.copy_(torch.from_numpy(var))
93
+ kwargs = convert_kwargs(
94
+ static_kwargs=tf_state.static_kwargs,
95
+ kwargs_mapping={
96
+ 'truncation_psi': 'truncation_psi',
97
+ 'truncation_cutoff': 'truncation_cutoff',
98
+ 'truncation_psi_val': 'truncation_psi',
99
+ 'truncation_cutoff_val': 'truncation_cutoff'
100
+ }
101
+ )
102
+ G.set_truncation(**kwargs)
103
+ return G
104
+
105
+ if model_type == 'G_mapping':
106
+ kwargs = convert_kwargs(
107
+ static_kwargs=tf_state.static_kwargs,
108
+ kwargs_mapping={
109
+ 'mapping_nonlinearity': 'activation',
110
+ 'normalize_latents': 'normalize_input',
111
+ 'mapping_lr_mul': 'lr_mul'
112
+ }
113
+ )
114
+ kwargs.num_layers = sum(
115
+ 1 for var_name, _ in tf_state.variables
116
+ if re.match('Dense[0-9]+/weight', var_name)
117
+ )
118
+ for var_name, var in tf_state.variables:
119
+ if var_name == 'LabelConcat/weight':
120
+ kwargs.label_size = var.shape[0]
121
+ if var_name == 'Dense0/weight':
122
+ kwargs.latent_size = var.shape[0]
123
+ kwargs.hidden = var.shape[1]
124
+ if var_name == 'Dense{}/bias'.format(kwargs.num_layers - 1):
125
+ kwargs.out_size = var.shape[0]
126
+ G_mapping = stylegan2.models.GeneratorMapping(**kwargs)
127
+ for var_name, var in tf_state.variables:
128
+ if re.match('Dense[0-9]+/[a-zA-Z]*', var_name):
129
+ layer_idx = int(re.search(r'Dense(\d+)/[a-zA-Z]*', var_name).groups()[0])
130
+ if var_name.endswith('weight'):
131
+ G_mapping.main[layer_idx].layer.weight.data.copy_(
132
+ torch.from_numpy(var.T).contiguous())
133
+ elif var_name.endswith('bias'):
134
+ G_mapping.main[layer_idx].bias.data.copy_(torch.from_numpy(var))
135
+ if var_name == 'LabelConcat/weight':
136
+ G_mapping.embedding.weight.data.copy_(torch.from_numpy(var))
137
+ return G_mapping
138
+
139
+ if model_type == 'G_synthesis_stylegan2' or model_type == 'G_synthesis':
140
+ assert tf_state.static_kwargs.get('fused_modconv', True), \
141
+ 'Can not load TF networks that use `fused_modconv=False`'
142
+ noise_tensors = []
143
+ conv_vars = {}
144
+ for var_name, var in tf_state.variables:
145
+ if var_name.startswith('noise'):
146
+ noise_tensors.append(torch.from_numpy(var))
147
+ else:
148
+ layer_size = int(re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
149
+ if layer_size not in conv_vars:
150
+ conv_vars[layer_size] = {}
151
+ var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
152
+ conv_vars[layer_size][var_name] = var
153
+ noise_tensors = sorted(noise_tensors, key=lambda x:x.size(-1))
154
+ kwargs = convert_kwargs(
155
+ static_kwargs=tf_state.static_kwargs,
156
+ kwargs_mapping={
157
+ 'nonlinearity': 'activation',
158
+ 'resample_filter': ['conv_filter', 'skip_filter']
159
+ }
160
+ )
161
+ kwargs.skip = False
162
+ kwargs.resnet = True
163
+ kwargs.channels = []
164
+ for size in sorted(conv_vars.keys(), reverse=True):
165
+ if size == 4:
166
+ if 'ToRGB/weight' in conv_vars[size]:
167
+ kwargs.skip = True
168
+ kwargs.resnet = False
169
+ kwargs.latent_size = conv_vars[size]['Conv/mod_weight'].shape[0]
170
+ kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
171
+ else:
172
+ kwargs.channels.append(conv_vars[size]['Conv1/bias'].shape[0])
173
+ if 'ToRGB/bias' in conv_vars[size]:
174
+ kwargs.data_channels = conv_vars[size]['ToRGB/bias'].shape[0]
175
+ G_synthesis = stylegan2.models.GeneratorSynthesis(**kwargs)
176
+ G_synthesis.const.data.copy_(torch.from_numpy(conv_vars[4]['Const/const']).squeeze(0))
177
+ def assign_weights(layer, weight, bias, mod_weight, mod_bias, noise_strength, transposed=False):
178
+ layer.bias.data.copy_(torch.from_numpy(bias))
179
+ layer.layer.weight.data.copy_(torch.tensor(noise_strength))
180
+ layer.layer.layer.dense.layer.weight.data.copy_(
181
+ torch.from_numpy(mod_weight.T).contiguous())
182
+ layer.layer.layer.dense.bias.data.copy_(torch.from_numpy(mod_bias + 1))
183
+ weight = torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous()
184
+ if transposed:
185
+ weight = weight.flip(dims=[2,3])
186
+ layer.layer.layer.weight.data.copy_(weight)
187
+ conv_blocks = G_synthesis.conv_blocks
188
+ for i, size in enumerate(sorted(conv_vars.keys())):
189
+ block = conv_blocks[i]
190
+ if size == 4:
191
+ assign_weights(
192
+ layer=block.conv_block[0],
193
+ weight=conv_vars[size]['Conv/weight'],
194
+ bias=conv_vars[size]['Conv/bias'],
195
+ mod_weight=conv_vars[size]['Conv/mod_weight'],
196
+ mod_bias=conv_vars[size]['Conv/mod_bias'],
197
+ noise_strength=conv_vars[size]['Conv/noise_strength'],
198
+ )
199
+ else:
200
+ assign_weights(
201
+ layer=block.conv_block[0],
202
+ weight=conv_vars[size]['Conv0_up/weight'],
203
+ bias=conv_vars[size]['Conv0_up/bias'],
204
+ mod_weight=conv_vars[size]['Conv0_up/mod_weight'],
205
+ mod_bias=conv_vars[size]['Conv0_up/mod_bias'],
206
+ noise_strength=conv_vars[size]['Conv0_up/noise_strength'],
207
+ transposed=True
208
+ )
209
+ assign_weights(
210
+ layer=block.conv_block[1],
211
+ weight=conv_vars[size]['Conv1/weight'],
212
+ bias=conv_vars[size]['Conv1/bias'],
213
+ mod_weight=conv_vars[size]['Conv1/mod_weight'],
214
+ mod_bias=conv_vars[size]['Conv1/mod_bias'],
215
+ noise_strength=conv_vars[size]['Conv1/noise_strength'],
216
+ )
217
+ if 'Skip/weight' in conv_vars[size]:
218
+ block.projection.weight.data.copy_(torch.from_numpy(
219
+ conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
220
+ to_RGB = G_synthesis.to_data_layers[i]
221
+ if to_RGB is not None:
222
+ to_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['ToRGB/bias']))
223
+ to_RGB.layer.weight.data.copy_(torch.from_numpy(
224
+ conv_vars[size]['ToRGB/weight']).permute((3, 2, 0, 1)).contiguous())
225
+ to_RGB.layer.dense.bias.data.copy_(
226
+ torch.from_numpy(conv_vars[size]['ToRGB/mod_bias'] + 1))
227
+ to_RGB.layer.dense.layer.weight.data.copy_(
228
+ torch.from_numpy(conv_vars[size]['ToRGB/mod_weight'].T).contiguous())
229
+ if not tf_state.static_kwargs.get('randomize_noise', True):
230
+ G_synthesis.static_noise(noise_tensors=noise_tensors)
231
+ return G_synthesis
232
+
233
+ if model_type == 'D_stylegan2' or model_type == 'D_main':
234
+ output_vars = {}
235
+ conv_vars = {}
236
+ for var_name, var in tf_state.variables:
237
+ if var_name.startswith('Output'):
238
+ output_vars[var_name.replace('Output/', '')] = var
239
+ else:
240
+ layer_size = int(re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
241
+ if layer_size not in conv_vars:
242
+ conv_vars[layer_size] = {}
243
+ var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
244
+ conv_vars[layer_size][var_name] = var
245
+ kwargs = convert_kwargs(
246
+ static_kwargs=tf_state.static_kwargs,
247
+ kwargs_mapping={
248
+ 'nonlinearity': 'activation',
249
+ 'resample_filter': ['conv_filter', 'skip_filter'],
250
+ 'mbstd_group_size': 'mbstd_group_size'
251
+ }
252
+ )
253
+ kwargs.skip = False
254
+ kwargs.resnet = True
255
+ kwargs.channels = []
256
+ for size in sorted(conv_vars.keys(), reverse=True):
257
+ if size == 4:
258
+ if 'FromRGB/weight' in conv_vars[size]:
259
+ kwargs.skip = True
260
+ kwargs.resnet = False
261
+ kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
262
+ kwargs.dense_hidden = conv_vars[size]['Dense0/bias'].shape[0]
263
+ else:
264
+ kwargs.channels.append(conv_vars[size]['Conv0/bias'].shape[0])
265
+ if 'FromRGB/weight' in conv_vars[size]:
266
+ kwargs.data_channels = conv_vars[size]['FromRGB/weight'].shape[-2]
267
+ output_size = output_vars['bias'].shape[0]
268
+ if output_size > 1:
269
+ kwargs.label_size = output_size
270
+ D = stylegan2.models.Discriminator(**kwargs)
271
+ def assign_weights(layer, weight, bias):
272
+ layer.bias.data.copy_(torch.from_numpy(bias))
273
+ layer.layer.weight.data.copy_(
274
+ torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous())
275
+ conv_blocks = D.conv_blocks
276
+ for i, size in enumerate(sorted(conv_vars.keys())):
277
+ block = conv_blocks[-i - 1]
278
+ if size == 4:
279
+ assign_weights(
280
+ layer=block[-1].conv_block[0],
281
+ weight=conv_vars[size]['Conv/weight'],
282
+ bias=conv_vars[size]['Conv/bias'],
283
+ )
284
+ else:
285
+ assign_weights(
286
+ layer=block.conv_block[0],
287
+ weight=conv_vars[size]['Conv0/weight'],
288
+ bias=conv_vars[size]['Conv0/bias'],
289
+ )
290
+ assign_weights(
291
+ layer=block.conv_block[1],
292
+ weight=conv_vars[size]['Conv1_down/weight'],
293
+ bias=conv_vars[size]['Conv1_down/bias'],
294
+ )
295
+ if 'Skip/weight' in conv_vars[size]:
296
+ block.projection.weight.data.copy_(torch.from_numpy(
297
+ conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
298
+ from_RGB = D.from_data_layers[-i - 1]
299
+ if from_RGB is not None:
300
+ from_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['FromRGB/bias']))
301
+ from_RGB.layer.weight.data.copy_(torch.from_numpy(
302
+ conv_vars[size]['FromRGB/weight']).permute((3, 2, 0, 1)).contiguous())
303
+ return D
304
+
305
+
306
+ def get_arg_parser():
307
+ parser = argparse.ArgumentParser(
308
+ description='Convert tensorflow stylegan2 model to pytorch.',
309
+ epilog='Pretrained models that can be downloaded:\n{}'.format(
310
+ '\n'.join(pretrained_model_urls.keys()))
311
+ )
312
+
313
+ parser.add_argument(
314
+ '-i',
315
+ '--input',
316
+ help='File path to pickled tensorflow models.',
317
+ type=str,
318
+ default=None,
319
+ )
320
+
321
+ parser.add_argument(
322
+ '-d',
323
+ '--download',
324
+ help='Download the specified pretrained model. Use --help for info on available models.',
325
+ type=str,
326
+ default=None,
327
+ )
328
+
329
+ parser.add_argument(
330
+ '-o',
331
+ '--output',
332
+ help='One or more output file paths. Alternatively a directory path ' + \
333
+ 'where all models will be saved. Default: current directory',
334
+ type=str,
335
+ nargs='*',
336
+ default=['.'],
337
+ )
338
+
339
+ return parser
340
+
341
+
342
+ def main():
343
+ args = get_arg_parser().parse_args()
344
+ assert bool(args.input) != bool(args.download), \
345
+ 'Incorrect input format. Can only take either one ' + \
346
+ 'input filepath to a pickled tensorflow model or ' + \
347
+ 'a model name to download, but not both at the same ' + \
348
+ 'time or none at all.'
349
+ if args.input:
350
+ unpickled = load_tf_models_file(args.input)
351
+ else:
352
+ assert args.download in pretrained_model_urls.keys(), \
353
+ 'Unknown model {}. Use --help for list of models.'.format(args.download)
354
+ unpickled = load_tf_models_url(pretrained_model_urls[args.download])
355
+ if not isinstance(unpickled, (tuple, list)):
356
+ unpickled = [unpickled]
357
+ print('Converting tensorflow models and saving them...')
358
+ converted = [convert_from_tf(tf_state) for tf_state in unpickled]
359
+ if len(args.output) == 1 and (os.path.isdir(args.output[0]) or not os.path.splitext(args.output[0])[-1]):
360
+ if not os.path.exists(args.output[0]):
361
+ os.makedirs(args.output[0])
362
+ for tf_state, torch_model in zip(unpickled, converted):
363
+ torch_model.save(os.path.join(args.output[0], tf_state['name'] + '.pth'))
364
+ else:
365
+ assert len(args.output) == len(converted), 'Found {} models '.format(len(converted)) + \
366
+ 'in pickled file but only {} output paths were given.'.format(len(args.output))
367
+ for out_path, torch_model in zip(args.output, converted):
368
+ torch_model.save(out_path)
369
+ print('Done!')
370
+
371
+
372
+ if __name__ == '__main__':
373
+ main()
run_generator.py ADDED
@@ -0,0 +1,361 @@
1
+ import warnings
2
+ import argparse
3
+ import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+
8
+ import stylegan2
9
+ from stylegan2 import utils
10
+
11
+ #----------------------------------------------------------------------------
12
+
13
+ _description = """StyleGAN2 generator.
14
+ Run 'python %(prog)s <subcommand> --help' for subcommand help."""
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ _examples = """examples:
19
+ # Train a network or convert a pretrained one.
20
+ # Example of converting pretrained ffhq model:
21
+ python run_convert_from_tf.py --download ffhq-config-f --output G.pth D.pth Gs.pth
22
+
23
+ # Generate ffhq uncurated images (matches paper Figure 12)
24
+ python %(prog)s generate_images --network=Gs.pth --seeds=6600-6625 --truncation_psi=0.5
25
+
26
+ # Generate ffhq curated images (matches paper Figure 11)
27
+ python %(prog)s generate_images --network=Gs.pth --seeds=66,230,389,1518 --truncation_psi=1.0
28
+
29
+ # Example of converting pretrained car model:
30
+ python run_convert_from_tf.py --download car-config-f --output G_car.pth D_car.pth Gs_car.pth
31
+
32
+ # Generate uncurated car images (matches paper Figure 12)
33
+ python %(prog)s generate_images --network=Gs_car.pth --seeds=6000-6025 --truncation_psi=0.5
34
+
35
+ # Generate style mixing example (matches style mixing video clip)
36
+ python %(prog)s style_mixing_example --network=Gs.pth --row_seeds=85,100,75,458,1500 --col_seeds=55,821,1789,293 --truncation_psi=1.0
37
+ """
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ def _add_shared_arguments(parser):
42
+
43
+ parser.add_argument(
44
+ '--network',
45
+ help='Network file path',
46
+ required=True,
47
+ metavar='FILE'
48
+ )
49
+
50
+ parser.add_argument(
51
+ '--output',
52
+ help='Root directory for run results. Default: %(default)s',
53
+ type=str,
54
+ default='./results',
55
+ metavar='DIR'
56
+ )
57
+
58
+ parser.add_argument(
59
+ '--pixel_min',
60
+ help='Minimum of the value range of pixels in generated images. ' + \
61
+ 'Default: %(default)s',
62
+ default=-1,
63
+ type=float,
64
+ metavar='VALUE'
65
+ )
66
+
67
+ parser.add_argument(
68
+ '--pixel_max',
69
+ help='Maximum of the value range of pixels in generated images. ' + \
70
+ 'Default: %(default)s',
71
+ default=1,
72
+ type=float,
73
+ metavar='VALUE'
74
+ )
75
+
76
+ parser.add_argument(
77
+ '--gpu',
78
+ help='CUDA device indices (given as separate ' + \
79
+ 'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
80
+ type=int,
81
+ default=[],
82
+ nargs='*',
83
+ metavar='INDEX'
84
+ )
85
+
86
+ parser.add_argument(
87
+ '--truncation_psi',
88
+ help='Truncation psi. Default: %(default)s',
89
+ type=float,
90
+ default=0.5,
91
+ metavar='VALUE'
92
+ )
93
+
94
+
95
+ def get_arg_parser():
96
+ parser = argparse.ArgumentParser(
97
+ description=_description,
98
+ epilog=_examples,
99
+ formatter_class=argparse.RawDescriptionHelpFormatter
100
+ )
101
+
102
+ range_desc = 'NOTE: This is a single argument, where list ' + \
103
+ 'elements are separated by "," and ranges are defined as "a-b". Only integers are allowed.'
104
+
105
+ subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
106
+
107
+ generate_images_parser = subparsers.add_parser(
108
+ 'generate_images', help='Generate images')
109
+
110
+ generate_images_parser.add_argument(
111
+ '--batch_size',
112
+ help='Batch size for generator. Default: %(default)s',
113
+ type=int,
114
+ default=1,
115
+ metavar='VALUE'
116
+ )
117
+
118
+ generate_images_parser.add_argument(
119
+ '--seeds',
120
+ help='List of random seeds for generating images. ' + range_desc,
121
+ type=utils.range_type,
122
+ required=True,
123
+ metavar='RANGE'
124
+ )
125
+
126
+ _add_shared_arguments(generate_images_parser)
127
+
128
+ style_mixing_example_parser = subparsers.add_parser(
129
+ 'style_mixing_example', help='Generate style mixing video')
130
+
131
+ style_mixing_example_parser.add_argument(
132
+ '--row_seeds',
133
+ help='List of random seeds for image rows. ' + range_desc,
134
+ type=utils.range_type,
135
+ required=True,
136
+ metavar='RANGE'
137
+ )
138
+
139
+ style_mixing_example_parser.add_argument(
140
+ '--col_seeds',
141
+ help='List of random seeds for image columns. ' + range_desc,
142
+ type=utils.range_type,
143
+ required=True,
144
+ metavar='RANGE'
145
+ )
146
+
147
+ style_mixing_example_parser.add_argument(
148
+ '--style_layers',
149
+ help='Indices of layers to mix style for. ' + \
150
+ 'Default: %(default)s. ' + range_desc,
151
+ type=utils.range_type,
152
+ default='0-6',
153
+ metavar='RANGE'
154
+ )
155
+
156
+ style_mixing_example_parser.add_argument(
157
+ '--grid',
158
+ help='Save a grid as well of the style mix. Default: %(default)s',
159
+ type=utils.bool_type,
160
+ default=True,
161
+ const=True,
162
+ nargs='?',
163
+ metavar='BOOL'
164
+ )
165
+
166
+ _add_shared_arguments(style_mixing_example_parser)
167
+
168
+ return parser
169
+
170
+ #----------------------------------------------------------------------------
171
+
172
+ def style_mixing_example(G, args):
173
+ assert max(args.style_layers) < len(G), \
174
+ 'Style layer indices can not be larger than ' + \
175
+ 'number of style layers ({}) of the generator.'.format(len(G))
176
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
177
+ if device.index is not None:
178
+ torch.cuda.set_device(device.index)
179
+ if len(args.gpu) > 1:
180
+ warnings.warn('Multi GPU is not available for style mixing example. Using device {}'.format(device))
181
+ G.to(device)
182
+ G.static_noise()
183
+ latent_size, label_size = G.latent_size, G.label_size
184
+ G_mapping, G_synthesis = G.G_mapping, G.G_synthesis
185
+
186
+ all_seeds = list(set(args.row_seeds + args.col_seeds))
187
+ all_z = torch.stack([torch.from_numpy(np.random.RandomState(seed).randn(latent_size)) for seed in all_seeds])
188
+ all_z = all_z.to(device=device, dtype=torch.float32)
189
+ if label_size:
190
+ labels = torch.zeros(len(all_z), dtype=torch.int64, device=device)
191
+ else:
192
+ labels = None
193
+
194
+ print('Generating disentangled latents...')
195
+ with torch.no_grad():
196
+ all_w = G_mapping(latents=all_z, labels=labels)
197
+
198
+ all_w = all_w.unsqueeze(1).repeat(1, len(G_synthesis), 1)
199
+
200
+ w_avg = G.dlatent_avg
201
+
202
+ if args.truncation_psi != 1:
203
+ all_w = w_avg + args.truncation_psi * (all_w - w_avg)
204
+
205
+ w_dict = {seed: w for seed, w in zip(all_seeds, all_w)}
206
+
207
+ all_images = []
208
+
209
+ progress = utils.ProgressWriter(len(all_w))
210
+ progress.write('Generating images...', step=False)
211
+
212
+ with torch.no_grad():
213
+ for w in all_w:
214
+ all_images.append(G_synthesis(w.unsqueeze(0)))
215
+ progress.step()
216
+
217
+ progress.write('Done!', step=False)
218
+ progress.close()
219
+
220
+ all_images = torch.cat(all_images, dim=0)
221
+
222
+ image_dict = {(seed, seed): image for seed, image in zip(all_seeds, all_images)}
223
+
224
+ progress = utils.ProgressWriter(len(args.row_seeds) * len(args.col_seeds))
225
+ progress.write('Generating style-mixed images...', step=False)
226
+
227
+ for row_seed in args.row_seeds:
228
+ for col_seed in args.col_seeds:
229
+ w = w_dict[row_seed].clone()
230
+ w[args.style_layers] = w_dict[col_seed][args.style_layers]
231
+ with torch.no_grad():
232
+ image_dict[(row_seed, col_seed)] = G_synthesis(w.unsqueeze(0)).squeeze(0)
233
+ progress.step()
234
+
235
+ progress.write('Done!', step=False)
236
+ progress.close()
237
+
238
+ progress = utils.ProgressWriter(len(image_dict))
239
+ progress.write('Saving images...', step=False)
240
+
241
+ for (row_seed, col_seed), image in list(image_dict.items()):
242
+ image = utils.tensor_to_PIL(
243
+ image, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
244
+ image_dict[(row_seed, col_seed)] = image
245
+ image.save(os.path.join(args.output, '%d-%d.png' % (row_seed, col_seed)))
246
+ progress.step()
247
+
248
+ progress.write('Done!', step=False)
249
+ progress.close()
250
+
251
+ if args.grid:
252
+ print('\n\nSaving style-mixed grid...')
253
+ H, W = all_images.size()[2:]
254
+ canvas = Image.new(
255
+ 'RGB', (W * (len(args.col_seeds) + 1), H * (len(args.row_seeds) + 1)), 'black')
256
+ for row_idx, row_seed in enumerate([None] + args.row_seeds):
257
+ for col_idx, col_seed in enumerate([None] + args.col_seeds):
258
+ if row_seed is None and col_seed is None:
259
+ continue
260
+ key = (row_seed, col_seed)
261
+ if row_seed is None:
262
+ key = (col_seed, col_seed)
263
+ if col_seed is None:
264
+ key = (row_seed, row_seed)
265
+ canvas.paste(image_dict[key], (W * col_idx, H * row_idx))
266
+ canvas.save(os.path.join(args.output, 'grid.png'))
267
+ print('Done!')
268
+
269
+ #----------------------------------------------------------------------------
270
+
271
+ def generate_images(G, args):
272
+ latent_size, label_size = G.latent_size, G.label_size
273
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
274
+ if device.index is not None:
275
+ torch.cuda.set_device(device.index)
276
+ G.to(device)
277
+ if args.truncation_psi != 1:
278
+ G.set_truncation(truncation_psi=args.truncation_psi)
279
+ if len(args.gpu) > 1:
280
+ warnings.warn(
281
+ 'Noise can not be randomized based on the seed ' + \
282
+ 'when using more than 1 GPU device. Noise will ' + \
283
+ 'now be randomized from default random state.'
284
+ )
285
+ G.random_noise()
286
+ G = torch.nn.DataParallel(G, device_ids=args.gpu)
287
+ else:
288
+ noise_reference = G.static_noise()
289
+
290
+ def get_batch(seeds):
291
+ latents = []
292
+ labels = []
293
+ if len(args.gpu) <= 1:
294
+ noise_tensors = [[] for _ in noise_reference]
295
+ for seed in seeds:
296
+ rnd = np.random.RandomState(seed)
297
+ latents.append(torch.from_numpy(rnd.randn(latent_size)))
298
+ if len(args.gpu) <= 1:
299
+ for i, ref in enumerate(noise_reference):
300
+ noise_tensors[i].append(torch.from_numpy(rnd.randn(*ref.size()[1:])))
301
+ if label_size:
302
+ labels.append(torch.tensor([rnd.randint(0, label_size)]))
303
+ latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
304
+ if labels:
305
+ labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
306
+ else:
307
+ labels = None
308
+ if len(args.gpu) <= 1:
309
+ noise_tensors = [
310
+ torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
311
+ for noise in noise_tensors
312
+ ]
313
+ else:
314
+ noise_tensors = None
315
+ return latents, labels, noise_tensors
316
+
317
+ progress = utils.ProgressWriter(len(args.seeds))
318
+ progress.write('Generating images...', step=False)
319
+
320
+ for i in range(0, len(args.seeds), args.batch_size):
321
+ latents, labels, noise_tensors = get_batch(args.seeds[i: i + args.batch_size])
322
+ if noise_tensors is not None:
323
+ G.static_noise(noise_tensors=noise_tensors)
324
+ with torch.no_grad():
325
+ generated = G(latents, labels=labels)
326
+ images = utils.tensor_to_PIL(
327
+ generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
328
+ for seed, img in zip(args.seeds[i: i + args.batch_size], images):
329
+ img.save(os.path.join(args.output, 'seed%04d.png' % seed))
330
+ progress.step()
331
+
332
+ progress.write('Done!', step=False)
333
+ progress.close()
334
+
335
+ #----------------------------------------------------------------------------
336
+
337
+ def main():
338
+ args = get_arg_parser().parse_args()
339
+ assert args.command, 'Missing subcommand.'
340
+ assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
341
+ '--output argument should specify a directory, not a file.'
342
+ if not os.path.exists(args.output):
343
+ os.makedirs(args.output)
344
+
345
+ G = stylegan2.models.load(args.network)
346
+ G.eval()
347
+
348
+ assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
349
+ 'stylegan2.models.Generator. Found {}.'.format(type(G))
350
+
351
+ if args.command == 'generate_images':
352
+ generate_images(G, args)
353
+ elif args.command == 'style_mixing_example':
354
+ style_mixing_example(G, args)
355
+ else:
356
+ raise TypeError('Unknown command {}'.format(args.command))
357
+
358
+ #----------------------------------------------------------------------------
359
+
360
+ if __name__ == '__main__':
361
+ main()
run_metrics.py ADDED
@@ -0,0 +1,338 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+
7
+ import stylegan2
8
+ from stylegan2 import utils
9
+
10
+ #----------------------------------------------------------------------------
11
+
12
+ _description = """Metrics evaluation.
13
+ Run 'python %(prog)s <subcommand> --help' for subcommand help."""
14
+
15
+ #----------------------------------------------------------------------------
16
+
17
+ _examples = """examples:
18
+ # Train a network or convert a pretrained one. In this example we first convert a pretrained one.
19
+ python run_convert_from_tf.py --download ffhq-config-f --output G.pth D.pth Gs.pth
20
+
21
+ # Calculate FID
22
+ python %(prog)s fid --network=Gs.pth --data_dir=path/to/image_folder
23
+
24
+ # Calculate PPL
25
+ python %(prog)s ppl --network=Gs.pth
26
+ """
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ def _add_shared_arguments(parser):
31
+
32
+ parser.add_argument(
33
+ '--network',
34
+ help='Network file path',
35
+ required=True,
36
+ metavar='FILE'
37
+ )
38
+
39
+ parser.add_argument(
40
+ '--num_samples',
41
+ type=int,
42
+ help='Number of samples to gather for evaluating ' + \
43
+ 'this metric. Default: %(default)s',
44
+ default=50000,
45
+ metavar='VALUE'
46
+ )
47
+
48
+ parser.add_argument(
49
+ '--size',
50
+ type=int,
51
+ help='Rescale images so that this is the size of their ' + \
52
+ 'smallest side in pixels. Default: Unscaled',
53
+ default=None,
54
+ metavar='VALUE'
55
+ )
56
+
57
+ parser.add_argument(
58
+ '--batch_size',
59
+ help='Batch size for generator. Default: %(default)s',
60
+ type=int,
61
+ default=1,
62
+ metavar='VALUE'
63
+ )
64
+
65
+ parser.add_argument(
66
+ '--output',
67
+ help='Root directory for run results. Default: %(default)s',
68
+ type=str,
69
+ default='./results',
70
+ metavar='DIR'
71
+ )
72
+
73
+ parser.add_argument(
74
+ '--pixel_min',
75
+ help='Minimum of the value range of pixels in generated images. ' + \
76
+ 'Default: %(default)s',
77
+ default=-1,
78
+ type=float,
79
+ metavar='VALUE'
80
+ )
81
+
82
+ parser.add_argument(
83
+ '--pixel_max',
84
+ help='Maximum of the value range of pixels in generated images. ' + \
85
+ 'Default: %(default)s',
86
+ default=1,
87
+ type=float,
88
+ metavar='VALUE'
89
+ )
90
+
91
+ parser.add_argument(
92
+ '--gpu',
93
+ help='CUDA device indices (given as separate ' + \
94
+ 'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
95
+ type=int,
96
+ default=[],
97
+ nargs='*',
98
+ metavar='INDEX'
99
+ )
100
+
101
+
102
+ def get_arg_parser():
103
+ parser = argparse.ArgumentParser(
104
+ description=_description,
105
+ epilog=_examples,
106
+ formatter_class=argparse.RawDescriptionHelpFormatter
107
+ )
108
+
109
+ subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
110
+
111
+ fid_parser = subparsers.add_parser('fid', help='Calculate FID')
112
+
113
+ fid_parser.add_argument(
114
+ '--data_dir',
115
+ help='Dataset root directory',
116
+ required=True,
117
+ metavar='DIR'
118
+ )
119
+
120
+ fid_parser.add_argument(
121
+ '--reals_batch_size',
122
+ help='Batch size for gathering statistics of reals. Default: %(default)s',
123
+ type=int,
124
+ default=1,
125
+ metavar='VALUE'
126
+ )
127
+
128
+ fid_parser.add_argument(
129
+ '--reals_data_workers',
130
+ help='Data workers for fetching real data samples. Default: %(default)s',
131
+ type=int,
132
+ default=4,
133
+ metavar='VALUE'
134
+ )
135
+
136
+ fid_parser.add_argument(
137
+ '--truncation_psi',
138
+ help='Truncation psi. Default: %(default)s',
139
+ type=float,
140
+ default=1.0,
141
+ metavar='VALUE'
142
+ )
143
+
144
+ _add_shared_arguments(fid_parser)
145
+
146
+ ppl_parser = subparsers.add_parser('ppl', help='Calculate PPL')
147
+
148
+ ppl_parser.add_argument(
149
+ '--epsilon',
150
+ type=float,
151
+ help='Perturbation value. Default: %(default)s',
152
+ default=1e-4,
153
+ metavar='VALUE'
154
+ )
155
+
156
+ ppl_parser.add_argument(
157
+ '--use_dlatent',
158
+ type=utils.bool_type,
159
+ help='Measure on perturbations of disentangled latents ' + \
160
+ 'instead of raw latents. Default: %(default)s',
161
+ default=True,
162
+ const=True,
163
+ nargs='?',
164
+ metavar='BOOL'
165
+ )
166
+
167
+ ppl_parser.add_argument(
168
+ '--full_sampling',
169
+ type=utils.bool_type,
170
+ help='Measure on random interpolation between two inputs ' + \
171
+ 'instead of directly on one input. Default: %(default)s',
172
+ default=False,
173
+ const=True,
174
+ nargs='?',
175
+ metavar='BOOL'
176
+ )
177
+
178
+ parser.add_argument(
179
+ '--ppl_ffhq_crop',
180
+ help='Crop images evaluated for PPL with crop values ' + \
181
+ 'for FFHQ. Default: False',
182
+ type=utils.bool_type,
183
+ const=True,
184
+ nargs='?',
185
+ default=False,
186
+ metavar='BOOL'
187
+ )
188
+
189
+ _add_shared_arguments(ppl_parser)
190
+
191
+ return parser
192
+
193
+ #----------------------------------------------------------------------------
194
+
195
+ def _report_metric(value, name, args):
196
+ fpath = os.path.join(args.output, 'metrics.json')
197
+ metrics = {}
198
+ if os.path.exists(fpath):
199
+ with open(fpath, 'r') as fp:
200
+ try:
201
+ metrics = json.load(fp)
202
+ except Exception:
203
+ pass
204
+ metrics[name] = value
205
+ with open(fpath, 'w') as fp:
206
+ json.dump(metrics, fp)
207
+ print('\n\nMetric evaluated!:')
208
+ print('{}: {}'.format(name, value))
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ def eval_fid(G, prior_generator, args):
213
+ assert args.data_dir, '--data_dir has to be specified.'
214
+ dataset = utils.ImageFolder(
215
+ args.data_dir,
216
+ pixel_min=args.pixel_min,
217
+ pixel_max=args.pixel_max
218
+ )
219
+ assert len(dataset), 'No images found at {}'.format(args.data_dir)
220
+
221
+ inception = stylegan2.external_models.inception.InceptionV3FeatureExtractor(
222
+ pixel_min=args.pixel_min, pixel_max=args.pixel_max)
223
+
224
+ if len(args.gpu) > 1:
225
+ inception = torch.nn.DataParallel(inception, device_ids=args.gpu)
226
+
227
+ args.reals_batch_size = max(args.reals_batch_size, len(args.gpu))
228
+
229
+ fid = stylegan2.metrics.fid.FID(
230
+ G=G,
231
+ prior_generator=prior_generator,
232
+ dataset=dataset,
233
+ num_samples=args.num_samples,
234
+ fid_model=inception,
235
+ fid_size=args.size,
236
+ truncation_psi=args.truncation_psi,
237
+ reals_batch_size=args.reals_batch_size,
238
+ reals_data_workers=args.reals_data_workers
239
+ )
240
+
241
+ value = fid.evaluate()
242
+
243
+ name = 'FID'
244
+ if args.size:
245
+ name += '({})'.format(args.size)
246
+ if args.truncation_psi != 1:
247
+ name += 'trunc{}'.format(args.truncation_psi)
248
+ name += ':{}k'.format(args.num_samples // 1000)
249
+
250
+ _report_metric(value, name, args)
251
+
252
+ #----------------------------------------------------------------------------
253
+
254
+ def eval_ppl(G, prior_generator, args):
255
+
256
+ lpips = stylegan2.external_models.lpips.LPIPS_VGG16(
257
+ pixel_min=args.pixel_min, pixel_max=args.pixel_max)
258
+
259
+ if len(args.gpu) > 1:
260
+ lpips = torch.nn.DataParallel(lpips, device_ids=args.gpu)
261
+
262
+ crop = None
263
+ if args.ppl_ffhq_crop:
264
+ crop = stylegan2.metrics.ppl.PPL.FFHQ_CROP
265
+
266
+ ppl = stylegan2.metrics.ppl.PPL(
267
+ G=G,
268
+ prior_generator=prior_generator,
269
+ num_samples=args.num_samples,
270
+ epsilon=args.epsilon,
271
+ use_dlatent=args.use_dlatent,
272
+ full_sampling=args.full_sampling,
273
+ crop=crop,
274
+ lpips_model=lpips,
275
+ lpips_size=args.size,
276
+ )
277
+
278
+ value = ppl.evaluate()
279
+
280
+ name = 'PPL'
281
+ if args.size:
282
+ name += '({})'.format(args.size)
283
+ if args.use_dlatent:
284
+ name += 'W'
285
+ else:
286
+ name += 'Z'
287
+ if args.full_sampling:
288
+ name += '-full'
289
+ else:
290
+ name += '-end'
291
+ name += ':{}k'.format(args.num_samples // 1000)
292
+
293
+ _report_metric(value, name, args)
294
+
295
+ #----------------------------------------------------------------------------
296
+
297
+ def main():
298
+ args = get_arg_parser().parse_args()
299
+ assert args.command, 'Missing subcommand.'
300
+ assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
301
+ '--output argument should specify a directory, not a file.'
302
+ if not os.path.exists(args.output):
303
+ os.makedirs(args.output)
304
+
305
+ G = stylegan2.models.load(args.network)
306
+ assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
307
+ 'stylegan2.models.Generator. Found {}.'.format(type(G))
308
+
309
+ latent_size, label_size = G.latent_size, G.label_size
310
+
311
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
312
+ if device.index is not None:
313
+ torch.cuda.set_device(device.index)
314
+
315
+ G.to(device).eval().requires_grad_(False)
316
+
317
+ if len(args.gpu) > 1:
318
+ G = torch.nn.DataParallel(G, device_ids=args.gpu)
319
+
320
+ args.batch_size = max(args.batch_size, len(args.gpu))
321
+
322
+ prior_generator = utils.PriorGenerator(
323
+ latent_size=latent_size,
324
+ label_size=label_size,
325
+ batch_size=args.batch_size,
326
+ device=device
327
+ )
328
+
329
+ if args.command == 'fid':
330
+ eval_fid(G, prior_generator, args)
331
+ elif args.command == 'ppl':
332
+ eval_ppl(G, prior_generator, args)
333
+ else:
334
+ raise TypeError('Unknown command {}'.format(args.command))
335
+
336
+
337
+ if __name__ == '__main__':
338
+ main()
run_projector.py ADDED
@@ -0,0 +1,401 @@
1
+ import warnings
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+
6
+ import stylegan2
7
+ from stylegan2 import utils
8
+
9
+ #----------------------------------------------------------------------------
10
+
11
+ _description = """StyleGAN2 projector.
12
+ Run 'python %(prog)s <subcommand> --help' for subcommand help."""
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ _examples = """examples:
17
+ # Train a network or convert a pretrained one.
18
+ # Example of converting pretrained ffhq model:
19
+ python run_convert_from_tf.py --download ffhq-config-f --output G.pth D.pth Gs.pth
20
+
21
+ # Project generated images
22
+ python %(prog)s project_generated_images --network=Gs.pth --seeds=0,1,5
23
+
24
+ # Project real images
25
+ python %(prog)s project_real_images --network=Gs.pth --data_dir=path/to/image_folder
26
+
27
+ """
28
+
29
+ #----------------------------------------------------------------------------
30
+
31
+ def _add_shared_arguments(parser):
32
+
33
+ parser.add_argument(
34
+ '--network',
35
+ help='Network file path',
36
+ required=True,
37
+ metavar='FILE'
38
+ )
39
+
40
+ parser.add_argument(
41
+ '--num_steps',
42
+ type=int,
43
+ help='Number of steps to use for projection. ' + \
44
+ 'Default: %(default)s',
45
+ default=1000,
46
+ metavar='VALUE'
47
+ )
48
+
49
+ parser.add_argument(
50
+ '--batch_size',
51
+ help='Batch size. Default: %(default)s',
52
+ type=int,
53
+ default=1,
54
+ metavar='VALUE'
55
+ )
56
+
57
+ parser.add_argument(
58
+ '--label',
59
+ help='Label to use for dlatent statistics gathering ' + \
60
+ '(should be integer index of class). Default: no label.',
61
+ type=int,
62
+ default=None,
63
+ metavar='CLASS_INDEX'
64
+ )
65
+
66
+ parser.add_argument(
67
+ '--initial_learning_rate',
68
+ help='Initial learning rate of projection. Default: %(default)s',
69
+ default=0.1,
70
+ type=float,
71
+ metavar='VALUE'
72
+ )
73
+
74
+ parser.add_argument(
75
+ '--initial_noise_factor',
76
+ help='Initial noise factor of projection. Default: %(default)s',
77
+ default=0.05,
78
+ type=float,
79
+ metavar='VALUE'
80
+ )
81
+
82
+ parser.add_argument(
83
+ '--lr_rampdown_length',
84
+ help='Learning rate rampdown length for projection. ' + \
85
+ 'Should be in range [0, 1]. Default: %(default)s',
86
+ default=0.25,
87
+ type=float,
88
+ metavar='VALUE'
89
+ )
90
+
91
+ parser.add_argument(
92
+ '--lr_rampup_length',
93
+ help='Learning rate rampup length for projection. ' + \
94
+ 'Should be in range [0, 1]. Default: %(default)s',
95
+ default=0.05,
96
+ type=float,
97
+ metavar='VALUE'
98
+ )
99
+
100
+ parser.add_argument(
101
+ '--noise_ramp_length',
102
+ help='Noise ramp length for projection. ' + \
103
+ 'Should be in range [0, 1]. Default: %(default)s',
104
+ default=0.75,
105
+ type=float,
106
+ metavar='VALUE'
107
+ )
108
+
109
+ parser.add_argument(
110
+ '--regularize_noise_weight',
111
+ help='The weight for noise regularization. Default: %(default)s',
112
+ default=1e5,
113
+ type=float,
114
+ metavar='VALUE'
115
+ )
116
+
117
+ parser.add_argument(
118
+ '--output',
119
+ help='Root directory for run results. Default: %(default)s',
120
+ type=str,
121
+ default='./results',
122
+ metavar='DIR'
123
+ )
124
+
125
+ parser.add_argument(
126
+ '--num_snapshots',
127
+ help='Number of snapshots. Default: %(default)s',
128
+ type=int,
129
+ default=5,
130
+ metavar='VALUE'
131
+ )
132
+
133
+ parser.add_argument(
134
+ '--pixel_min',
135
+ help='Minimum of the value range of pixels in generated images. ' + \
136
+ 'Default: %(default)s',
137
+ default=-1,
138
+ type=float,
139
+ metavar='VALUE'
140
+ )
141
+
142
+ parser.add_argument(
143
+ '--pixel_max',
144
+ help='Maximum of the value range of pixels in generated images. ' + \
145
+ 'Default: %(default)s',
146
+ default=1,
147
+ type=float,
148
+ metavar='VALUE'
149
+ )
150
+
151
+ parser.add_argument(
152
+ '--gpu',
153
+ help='CUDA device indices (given as separate ' + \
154
+ 'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
155
+ type=int,
156
+ default=[],
157
+ nargs='*',
158
+ metavar='INDEX'
159
+ )
160
+
161
+ #----------------------------------------------------------------------------
162
+
163
+ def get_arg_parser():
164
+ parser = argparse.ArgumentParser(
165
+ description=_description,
166
+ epilog=_examples,
167
+ formatter_class=argparse.RawDescriptionHelpFormatter
168
+ )
169
+
170
+ range_desc = 'NOTE: This is a single argument, where list ' + \
171
+ 'elements are separated by "," and ranges are defined as "a-b". ' + \
172
+ 'Only integers are allowed.'
173
+
174
+ subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
175
+
176
+ project_generated_images_parser = subparsers.add_parser(
177
+ 'project_generated_images', help='Project generated images')
178
+
179
+ project_generated_images_parser.add_argument(
180
+ '--seeds',
181
+ help='List of random seeds for generating images. ' + \
182
+ 'Default: 66,230,389,1518. ' + range_desc,
183
+ type=utils.range_type,
184
+ default=[66, 230, 389, 1518],
185
+ metavar='RANGE'
186
+ )
187
+
188
+ project_generated_images_parser.add_argument(
189
+ '--truncation_psi',
190
+ help='Truncation psi. Default: %(default)s',
191
+ type=float,
192
+ default=1.0,
193
+ metavar='VALUE'
194
+ )
195
+
196
+ _add_shared_arguments(project_generated_images_parser)
197
+
198
+ project_real_images_parser = subparsers.add_parser(
199
+ 'project_real_images', help='Project real images')
200
+
201
+ project_real_images_parser.add_argument(
202
+ '--data_dir',
203
+ help='Dataset root directory',
204
+ type=str,
205
+ required=True,
206
+ metavar='DIR'
207
+ )
208
+
209
+ project_real_images_parser.add_argument(
210
+ '--seed',
211
+ help='When there are more images available than ' + \
212
+ 'the number that is going to be projected this ' + \
213
+ 'seed is used for picking samples. Default: %(default)s',
214
+ type=int,
215
+ default=1234,
216
+ metavar='VALUE'
217
+ )
218
+
219
+ project_real_images_parser.add_argument(
220
+ '--num_images',
221
+ type=int,
222
+ help='Number of images to project. Default: %(default)s',
223
+ default=3,
224
+ metavar='VALUE'
225
+ )
226
+
227
+ _add_shared_arguments(project_real_images_parser)
228
+
229
+ return parser
230
+
231
+ #----------------------------------------------------------------------------
232
+
233
+ def project_images(G, images, name_prefix, args):
234
+
235
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
236
+ if device.index is not None:
237
+ torch.cuda.set_device(device.index)
238
+ if len(args.gpu) > 1:
239
+ warnings.warn(
240
+ 'Multi GPU is not available for projection. ' + \
241
+ 'Using device {}'.format(device)
242
+ )
243
+ G = utils.unwrap_module(G).to(device)
244
+
245
+ lpips_model = stylegan2.external_models.lpips.LPIPS_VGG16(
246
+ pixel_min=args.pixel_min, pixel_max=args.pixel_max)
247
+
248
+ proj = stylegan2.project.Projector(
249
+ G=G,
250
+ dlatent_avg_samples=10000,
251
+ dlatent_avg_label=args.label,
252
+ dlatent_device=device,
253
+ dlatent_batch_size=1024,
254
+ lpips_model=lpips_model,
255
+ lpips_size=256
256
+ )
257
+
258
+ for i in range(0, len(images), args.batch_size):
259
+ target = images[i: i + args.batch_size]
260
+ proj.start(
261
+ target=target,
262
+ num_steps=args.num_steps,
263
+ initial_learning_rate=args.initial_learning_rate,
264
+ initial_noise_factor=args.initial_noise_factor,
265
+ lr_rampdown_length=args.lr_rampdown_length,
266
+ lr_rampup_length=args.lr_rampup_length,
267
+ noise_ramp_length=args.noise_ramp_length,
268
+ regularize_noise_weight=args.regularize_noise_weight,
269
+ verbose=True,
270
+ verbose_prefix='Projecting image(s) {}/{}'.format(
271
+ i + len(target), len(images))
272
+ )
273
+ snapshot_steps = set(
274
+ args.num_steps - np.linspace(
275
+ 0, args.num_steps, args.num_snapshots, endpoint=False, dtype=int))
276
+ for k, image in enumerate(
277
+ utils.tensor_to_PIL(target, pixel_min=args.pixel_min, pixel_max=args.pixel_max)):
278
+ image.save(os.path.join(args.output, name_prefix[i + k] + 'target.png'))
279
+ for j in range(args.num_steps):
280
+ proj.step()
281
+ if j in snapshot_steps:
282
+ generated = utils.tensor_to_PIL(
283
+ proj.generate(), pixel_min=args.pixel_min, pixel_max=args.pixel_max)
284
+ for k, image in enumerate(generated):
285
+ image.save(os.path.join(
286
+ args.output, name_prefix[i + k] + 'step%04d.png' % (j + 1)))
287
+
288
+ #----------------------------------------------------------------------------
289
+
290
+ def project_generated_images(G, args):
291
+ latent_size, label_size = G.latent_size, G.label_size
292
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
293
+ if device.index is not None:
294
+ torch.cuda.set_device(device.index)
295
+ G.to(device)
296
+ if len(args.gpu) > 1:
297
+ warnings.warn(
298
+ 'Noise can not be randomized based on the seed ' + \
299
+ 'when using more than 1 GPU device. Noise will ' + \
300
+ 'now be randomized from default random state.'
301
+ )
302
+ G.random_noise()
303
+ G = torch.nn.DataParallel(G, device_ids=args.gpu)
304
+ else:
305
+ noise_reference = G.static_noise()
306
+
307
+ def get_batch(seeds):
308
+ latents = []
309
+ labels = []
310
+ if len(args.gpu) <= 1:
311
+ noise_tensors = [[] for _ in noise_reference]
312
+ for seed in seeds:
313
+ rnd = np.random.RandomState(seed)
314
+ latents.append(torch.from_numpy(rnd.randn(latent_size)))
315
+ if len(args.gpu) <= 1:
316
+ for i, ref in enumerate(noise_reference):
317
+ noise_tensors[i].append(
318
+ torch.from_numpy(rnd.randn(*ref.size()[1:])))
319
+ if label_size:
320
+ labels.append(torch.tensor([rnd.randint(0, label_size)]))
321
+ latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
322
+ if labels:
323
+ labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
324
+ else:
325
+ labels = None
326
+ if len(args.gpu) <= 1:
327
+ noise_tensors = [
328
+ torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
329
+ for noise in noise_tensors
330
+ ]
331
+ else:
332
+ noise_tensors = None
333
+ return latents, labels, noise_tensors
334
+
335
+ images = []
336
+
337
+ progress = utils.ProgressWriter(len(args.seeds))
338
+ progress.write('Generating images...', step=False)
339
+
340
+ for i in range(0, len(args.seeds), args.batch_size):
341
+ latents, labels, noise_tensors = get_batch(args.seeds[i: i + args.batch_size])
342
+ if noise_tensors is not None:
343
+ G.static_noise(noise_tensors=noise_tensors)
344
+ with torch.no_grad():
345
+ images.append(G(latents, labels=labels))
346
+ progress.step()
347
+
348
+ images = torch.cat(images, dim=0)
349
+
350
+ progress.write('Done!', step=False)
351
+ progress.close()
352
+
353
+ name_prefix = ['seed%04d-' % seed for seed in args.seeds]
354
+ project_images(G, images, name_prefix, args)
355
+
356
+ #----------------------------------------------------------------------------
357
+
358
+ def project_real_images(G, args):
359
+ device = torch.device(args.gpu[0] if args.gpu else 'cpu')
360
+ print('Loading images from "%s"...' % args.data_dir)
361
+ dataset = utils.ImageFolder(
362
+ args.data_dir, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
363
+
364
+ rnd = np.random.RandomState(args.seed)
365
+ indices = rnd.choice(
366
+ len(dataset), size=min(args.num_images, len(dataset)), replace=False)
367
+ images = []
368
+ for i in indices:
369
+ data = dataset[i]
370
+ if isinstance(data, (tuple, list)):
371
+ data = data[0]
372
+ images.append(data)
373
+ images = torch.stack(images).to(device)
374
+ name_prefix = ['image%04d-' % i for i in indices]
375
+ print('Done!')
376
+ project_images(G, images, name_prefix, args)
377
+
378
+ #----------------------------------------------------------------------------
379
+
380
+ def main():
381
+ args = get_arg_parser().parse_args()
382
+ assert args.command, 'Missing subcommand.'
383
+ assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
384
+ '--output argument should specify a directory, not a file.'
385
+ if not os.path.exists(args.output):
386
+ os.makedirs(args.output)
387
+
388
+ G = stylegan2.models.load(args.network)
389
+ assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
390
+ 'stylegan2.models.Generator. Found {}.'.format(type(G))
391
+
392
+ if args.command == 'project_generated_images':
393
+ project_generated_images(G, args)
394
+ elif args.command == 'project_real_images':
395
+ project_real_images(G, args)
396
+ else:
397
+ raise TypeError('Unknown command {}'.format(args.command))
398
+
399
+
400
+ if __name__ == '__main__':
401
+ main()
run_training.py ADDED
@@ -0,0 +1,1002 @@
1
+ import warnings
2
+ import os
3
+ import torch
4
+ from torch import multiprocessing as mp
5
+ import stylegan2
6
+ from stylegan2 import utils
7
+ from stylegan2.external_models import inception, lpips
8
+ from stylegan2.metrics import fid, ppl
9
+
10
+ #----------------------------------------------------------------------------
11
+
12
+ def get_arg_parser():
13
+ parser = utils.ConfigArgumentParser()
14
+
15
+ parser.add_argument(
16
+ '--output',
17
+ help='Output directory for model weights.',
18
+ type=str,
19
+ default=None,
20
+ metavar='DIR'
21
+ )
22
+
23
+ #----------------------------------------------------------------------------
24
+ # Model options
25
+
26
+ parser.add_argument(
27
+ '--channels',
28
+ help='Specify the channels for each layer (can be overridden for individual ' + \
29
+ 'networks with "--g_channels" and "--d_channels"). ' + \
30
+ 'Default: %(default)s',
31
+ nargs='*',
32
+ type=int,
33
+ default=[32, 32, 64, 128, 256, 512, 512, 512, 512],
34
+ metavar='CHANNELS'
35
+ )
36
+
37
+ parser.add_argument(
38
+ '--latent',
39
+ help='Size of the prior (noise vector). Default: %(default)s',
40
+ type=int,
41
+ default=512,
42
+ metavar='VALUE'
43
+ )
44
+
45
+ parser.add_argument(
46
+ '--label',
47
+ help='Number of unique labels. Unused if not specified.',
48
+ type=int,
49
+ default=0,
50
+ metavar='VALUE'
51
+ )
52
+
53
+ parser.add_argument(
54
+ '--base_shape',
55
+ help='Data shape of first layer in generator or ' + \
56
+ 'last layer in discriminator. Default: %(default)s',
57
+ nargs=2,
58
+ type=int,
59
+ default=(4, 4),
60
+ metavar='SIZE'
61
+ )
62
+
63
+ parser.add_argument(
64
+ '--kernel_size',
65
+ help='Size of conv kernel. Default: %(default)s',
66
+ type=int,
67
+ default=3,
68
+ metavar='SIZE'
69
+ )
70
+
71
+ parser.add_argument(
72
+ '--pad_once',
73
+ help='Pad filtered convs only once before filter instead ' + \
74
+ 'of twice. Default: %(default)s',
75
+ type=utils.bool_type,
76
+ const=True,
77
+ nargs='?',
78
+ default=True,
79
+ metavar='BOOL'
80
+ )
81
+
82
+ parser.add_argument(
83
+ '--pad_mode',
84
+ help='Padding mode for conv layers. Default: %(default)s',
85
+ type=str,
86
+ default='constant',
87
+ metavar='MODE'
88
+ )
89
+
90
+ parser.add_argument(
91
+ '--pad_constant',
92
+ help='Padding constant for conv layers when `pad_mode` is ' + \
93
+ '\'constant\'. Default: %(default)s',
94
+ type=float,
95
+ default=0,
96
+ metavar='VALUE'
97
+ )
98
+
99
+ parser.add_argument(
100
+ '--filter_pad_mode',
101
+ help='Padding mode for filter layers. Default: %(default)s',
102
+ type=str,
103
+ default='constant',
104
+ metavar='MODE'
105
+ )
106
+
107
+ parser.add_argument(
108
+ '--filter_pad_constant',
109
+ help='Padding constant for filter layers when `filter_pad_mode` ' + \
110
+ 'is \'constant\'. Default: %(default)s',
111
+ type=float,
112
+ default=0,
113
+ metavar='VALUE'
114
+ )
115
+
116
+ parser.add_argument(
117
+ '--filter',
118
+ help='Filter to use whenever FIR is applied. Default: %(default)s',
119
+ nargs='*',
120
+ type=float,
121
+ default=[1, 3, 3, 1],
122
+ metavar='VALUE'
123
+ )
124
+
125
+ parser.add_argument(
126
+ '--weight_scale',
127
+ help='Use weight scaling for equalized learning rate. Default: %(default)s',
128
+ type=utils.bool_type,
129
+ const=True,
130
+ nargs='?',
131
+ default=True,
132
+ metavar='BOOL'
133
+ )
134
+
135
+ #----------------------------------------------------------------------------
136
+ # Generator options
137
+
138
+ parser.add_argument(
139
+ '--g_file',
140
+ help='Load a generator model from a file instead of constructing a new one. Disabled unless a file is specified.',
141
+ type=str,
142
+ default=None,
143
+ metavar='FILE'
144
+ )
145
+
146
+ parser.add_argument(
147
+ '--g_channels',
148
+ help='Instead of the values of "--channels", ' + \
149
+ 'use these for the generator.',
150
+ nargs='*',
151
+ type=int,
152
+ default=[],
153
+ metavar='CHANNELS'
154
+ )
155
+
156
+ parser.add_argument(
157
+ '--g_skip',
158
+ help='Use skip connections for the generator. Default: %(default)s',
159
+ type=utils.bool_type,
160
+ const=True,
161
+ nargs='?',
162
+ default=True,
163
+ metavar='BOOL'
164
+ )
165
+
166
+ parser.add_argument(
167
+ '--g_resnet',
168
+ help='Use resnet connections for the generator. Default: %(default)s',
169
+ type=utils.bool_type,
170
+ const=True,
171
+ nargs='?',
172
+ default=False,
173
+ metavar='BOOL'
174
+ )
175
+
176
+ parser.add_argument(
177
+ '--g_conv_block_size',
178
+ help='Number of layers in a conv block in the generator. Default: %(default)s',
179
+ type=int,
180
+ default=2,
181
+ metavar='VALUE'
182
+ )
183
+
184
+ parser.add_argument(
185
+ '--g_normalize',
186
+ help='Normalize conv features for generator. Default: %(default)s',
187
+ type=utils.bool_type,
188
+ const=True,
189
+ nargs='?',
190
+ default=True,
191
+ metavar='BOOL'
192
+ )
193
+
194
+ parser.add_argument(
195
+ '--g_fused_conv',
196
+ help='Fuse conv & upsample into a transposed ' + \
197
+ 'conv for the generator. Default: %(default)s',
198
+ type=utils.bool_type,
199
+ const=True,
200
+ nargs='?',
201
+ default=True,
202
+ metavar='BOOL'
203
+ )
204
+
205
+ parser.add_argument(
206
+ '--g_activation',
207
+ help='The non-linear activation function for ' + \
208
+ 'the generator. Default: %(default)s',
209
+ default='leaky:0.2',
210
+ type=str,
211
+ metavar='ACTIVATION'
212
+ )
213
+
214
+ parser.add_argument(
215
+ '--g_conv_resample_mode',
216
+ help='Resample mode for upsampling conv ' + \
217
+ 'layers for generator. Default: %(default)s',
218
+ type=str,
219
+ default='FIR',
220
+ metavar='MODE'
221
+ )
222
+
223
+ parser.add_argument(
224
+ '--g_skip_resample_mode',
225
+ help='Resample mode for skip connection ' + \
226
+ 'upsamples for the generator. Default: %(default)s',
227
+ type=str,
228
+ default='FIR',
229
+ metavar='MODE'
230
+ )
231
+
232
+ parser.add_argument(
233
+ '--g_lr',
234
+ help='The learning rate for the generator. Default: %(default)s',
235
+ default=2e-3,
236
+ type=float,
237
+ metavar='VALUE'
238
+ )
239
+
240
+ parser.add_argument(
241
+ '--g_betas',
242
+ help='Beta values for the generator Adam optimizer. Default: %(default)s',
243
+ type=float,
244
+ nargs=2,
245
+ default=(0, 0.99),
246
+ metavar='VALUE'
247
+ )
248
+
249
+ parser.add_argument(
250
+ '--g_loss',
251
+ help='Loss function for the generator. Default: %(default)s',
252
+ default='logistic_ns',
253
+ type=str,
254
+ metavar='LOSS'
255
+ )
256
+
257
+ parser.add_argument(
258
+ '--g_reg',
259
+ help='Regularization function for the generator with an optional weight (:?). Default: %(default)s',
260
+ default='pathreg:2',
261
+ type=str,
262
+ metavar='REG'
263
+ )
264
+
265
+ parser.add_argument(
266
+ '--g_reg_interval',
267
+ help='Interval at which to regularize the generator. Default: %(default)s',
268
+ default=4,
269
+ type=int,
270
+ metavar='INTERVAL'
271
+ )
272
+
273
+ parser.add_argument(
274
+ '--g_iter',
275
+ help='Number of generator iterations per training iteration. Default: %(default)s',
276
+ default=1,
277
+ type=int,
278
+ metavar='ITER'
279
+ )
280
+
281
+ parser.add_argument(
282
+ '--style_mix',
283
+ help='The probability of passing more than one ' + \
284
+ 'latent to the generator. Default: %(default)s',
285
+ type=float,
286
+ default=0.9,
287
+ metavar='PROBABILITY'
288
+ )
289
+
290
+ parser.add_argument(
291
+ '--latent_mapping_layers',
292
+ help='The number of layers of the latent mapping network. Default: %(default)s',
293
+ type=int,
294
+ default=8,
295
+ metavar='LAYERS'
296
+ )
297
+
298
+ parser.add_argument(
299
+ '--latent_mapping_lr_mul',
300
+ help='The learning rate multiplier for the latent ' + \
301
+ 'mapping network. Default: %(default)s',
302
+ type=float,
303
+ default=0.01,
304
+ metavar='LR_MUL'
305
+ )
306
+
307
+ parser.add_argument(
308
+ '--normalize_latent',
309
+ help='Normalize latent inputs. Default: %(default)s',
310
+ type=utils.bool_type,
311
+ const=True,
312
+ nargs='?',
313
+ default=True,
314
+ metavar='BOOL'
315
+ )
316
+
317
+ parser.add_argument(
318
+ '--modulate_rgb',
319
+ help='Modulate RGB layers (use style for output ' + \
320
+ 'layers of generator). Default: %(default)s',
321
+ type=utils.bool_type,
322
+ const=True,
323
+ nargs='?',
324
+ default=True,
325
+ metavar='BOOL'
326
+ )
327
+
328
+ #----------------------------------------------------------------------------
329
+ # Discriminator options
330
+
331
+ parser.add_argument(
332
+ '--d_file',
333
+ help='Load a discriminator model from a file instead of constructing a new one. Disabled unless a file is specified.',
334
+ type=str,
335
+ default=None,
336
+ metavar='FILE'
337
+ )
338
+
339
+ parser.add_argument(
340
+ '--d_channels',
341
+ help='Instead of the values of "--channels", ' + \
342
+ 'use these for the discriminator.',
343
+ nargs='*',
344
+ type=int,
345
+ default=[],
346
+ metavar='CHANNELS'
347
+ )
348
+
349
+ parser.add_argument(
350
+ '--d_skip',
351
+ help='Use skip connections for the discriminator. Default: %(default)s',
352
+ type=utils.bool_type,
353
+ const=True,
354
+ nargs='?',
355
+ default=False,
356
+ metavar='BOOL'
357
+ )
358
+
359
+ parser.add_argument(
360
+ '--d_resnet',
361
+ help='Use resnet connections for the discriminator. Default: %(default)s',
362
+ type=utils.bool_type,
363
+ const=True,
364
+ nargs='?',
365
+ default=True,
366
+ metavar='BOOL'
367
+ )
368
+
369
+ parser.add_argument(
370
+ '--d_conv_block_size',
371
+ help='Number of layers in a conv block in the discriminator. Default: %(default)s',
372
+ type=int,
373
+ default=2,
374
+ metavar='VALUE'
375
+ )
376
+
377
+ parser.add_argument(
378
+ '--d_fused_conv',
379
+ help='Fuse conv & downsample into a strided ' + \
380
+ 'conv for the discriminator. Default: %(default)s',
381
+ type=utils.bool_type,
382
+ const=True,
383
+ nargs='?',
384
+ default=True,
385
+ metavar='BOOL'
386
+ )
387
+
388
+ parser.add_argument(
389
+ '--group_size',
390
+ help='Size of the groups in batch std layer. Default: %(default)s',
391
+ type=int,
392
+ default=4,
393
+ metavar='VALUE'
394
+ )
395
+
396
+ parser.add_argument(
397
+ '--d_activation',
398
+ help='The non-linear activation function for the discriminator. Default: %(default)s',
399
+ default='leaky:0.2',
400
+ type=str,
401
+ metavar='ACTIVATION'
402
+ )
403
+
404
+ parser.add_argument(
405
+ '--d_conv_resample_mode',
406
+ help='Resample mode for downsampling conv ' + \
407
+ 'layers for discriminator. Default: %(default)s',
408
+ type=str,
409
+ default='FIR',
410
+ metavar='MODE'
411
+ )
412
+
413
+ parser.add_argument(
414
+ '--d_skip_resample_mode',
415
+ help='Resample mode for skip connection ' + \
416
+ 'downsamples for the discriminator. Default: %(default)s',
417
+ type=str,
418
+ default='FIR',
419
+ metavar='MODE'
420
+ )
421
+
422
+ parser.add_argument(
423
+ '--d_loss',
424
+ help='Loss function for the discriminator. Default: %(default)s',
425
+ default='logistic',
426
+ type=str,
427
+ metavar='LOSS'
428
+ )
429
+
430
+ parser.add_argument(
431
+ '--d_reg',
432
+ help='Regularization function for the discriminator ' + \
433
+ 'with an optional weight (:?). Default: %(default)s',
434
+ default='r1:10',
435
+ type=str,
436
+ metavar='REG'
437
+ )
438
+
439
+ parser.add_argument(
440
+ '--d_reg_interval',
441
+ help='Interval at which to regularize the discriminator. Default: %(default)s',
442
+ default=16,
443
+ type=int,
444
+ metavar='INTERVAL'
445
+ )
446
+
447
+ parser.add_argument(
448
+ '--d_iter',
449
+ help='Number of discriminator iterations per training iteration. Default: %(default)s',
450
+ default=1,
451
+ type=int,
452
+ metavar='ITER'
453
+ )
454
+
455
+ parser.add_argument(
456
+ '--d_lr',
457
+ help='The learning rate for the discriminator. Default: %(default)s',
458
+ default=2e-3,
459
+ type=float,
460
+ metavar='VALUE'
461
+ )
462
+
463
+ parser.add_argument(
464
+ '--d_betas',
465
+ help='Beta values for the discriminator Adam optimizer. Default: %(default)s',
466
+ type=float,
467
+ nargs=2,
468
+ default=(0, 0.99),
469
+ metavar='VALUE'
470
+ )
471
+
472
+ #----------------------------------------------------------------------------
473
+ # Training options
474
+
475
+ parser.add_argument(
476
+ '--iterations',
477
+ help='Number of iterations to train for. Default: %(default)s',
478
+ type=int,
479
+ default=1000000,
480
+ metavar='ITERATIONS'
481
+ )
482
+
483
+ parser.add_argument(
484
+ '--gpu',
485
+ help='The cuda device(s) to use. Example: "--gpu 0 1" will train ' + \
486
+ 'on GPU 0 and GPU 1. Default: Only use CPU',
487
+ type=int,
488
+ default=[],
489
+ nargs='*',
490
+ metavar='DEVICE_ID'
491
+ )
492
+
493
+ parser.add_argument(
494
+ '--distributed',
495
+ help='When more than one gpu device is passed, automatically ' + \
496
+ 'start one process for each device and give it the correct ' + \
497
+ 'distributed args (rank, world_size etc). Disable this if ' + \
498
+ 'you want training to be performed with only one process ' + \
499
+ 'using the DataParallel module. Default: %(default)s',
500
+ type=utils.bool_type,
501
+ const=True,
502
+ nargs='?',
503
+ default=True,
504
+ metavar='BOOL'
505
+ )
506
+
507
+ parser.add_argument(
508
+ '--rank',
509
+ help='Rank for distributed training.',
510
+ type=int,
511
+ default=None,
512
+ )
513
+
514
+ parser.add_argument(
515
+ '--world_size',
516
+ help='World size for distributed training.',
517
+ type=int,
518
+ default=None,
519
+ )
520
+
521
+ parser.add_argument(
522
+ '--master_addr',
523
+ help='Address for distributed training.',
524
+ type=str,
525
+ default=None,
526
+ )
527
+
528
+ parser.add_argument(
529
+ '--master_port',
530
+ help='Port for distributed training.',
531
+ type=str,
532
+ default=None,
533
+ )
534
+
535
+ parser.add_argument(
536
+ '--batch_size',
537
+ help='Size of each batch. Default: %(default)s',
538
+ default=32,
539
+ type=int,
540
+ metavar='VALUE'
541
+ )
542
+
543
+ parser.add_argument(
544
+ '--device_batch_size',
545
+ help='Maximum number of items to fit on a single device at a time. Default: %(default)s',
546
+ default=4,
547
+ type=int,
548
+ metavar='VALUE'
549
+ )
550
+
551
+ parser.add_argument(
552
+ '--g_reg_batch_size',
553
+ help='Size of each batch used to regularize the generator. Default: %(default)s',
554
+ default=16,
555
+ type=int,
556
+ metavar='VALUE'
557
+ )
558
+
559
+ parser.add_argument(
560
+ '--g_reg_device_batch_size',
561
+ help='Maximum number of items to fit on a single device when ' + \
562
+ 'regularizing the generator. Default: %(default)s',
563
+ default=2,
564
+ type=int,
565
+ metavar='VALUE'
566
+ )
567
+
568
+ parser.add_argument(
569
+ '--half',
570
+ help='Use mixed precision training. Default: %(default)s',
571
+ type=utils.bool_type,
572
+ const=True,
573
+ nargs='?',
574
+ default=False,
575
+ metavar='BOOL'
576
+ )
577
+
578
+ parser.add_argument(
579
+ '--resume',
580
+ help='Resume from the latest saved checkpoint in the checkpoint_dir. ' + \
581
+ 'This loads all previous training settings except for the dataset options, ' + \
582
+ 'device args (--gpu ...) and distributed training args (--rank, --world_size e.t.c) ' + \
583
+ 'as well as metrics and logging.',
584
+ type=utils.bool_type,
585
+ const=True,
586
+ nargs='?',
587
+ default=False,
588
+ metavar='BOOL'
589
+ )
590
+
591
+ #----------------------------------------------------------------------------
592
+ # Extra metric options
593
+
594
+ parser.add_argument(
595
+ '--fid_interval',
596
+ help='If specified, evaluate the FID metric with this interval.',
597
+ default=None,
598
+ type=int,
599
+ metavar='INTERVAL'
600
+ )
601
+
602
+ parser.add_argument(
603
+ '--ppl_interval',
604
+ help='If specified, evaluate the PPL metric with this interval.',
605
+ default=None,
606
+ type=int,
607
+ metavar='INTERVAL'
608
+ )
609
+
610
+ parser.add_argument(
611
+ '--ppl_ffhq_crop',
612
+ help='Crop images evaluated for PPL with crop values for FFHQ. Default: %(default)s',
613
+ type=utils.bool_type,
614
+ const=True,
615
+ nargs='?',
616
+ default=False,
617
+ metavar='BOOL'
618
+ )
619
+
620
+ #----------------------------------------------------------------------------
621
+ # Data options
622
+
623
+ parser.add_argument(
624
+ '--pixel_min',
625
+ help='Minimum of the value range of pixels in generated images. Default: %(default)s',
626
+ default=-1,
627
+ type=float,
628
+ metavar='VALUE'
629
+ )
630
+
631
+ parser.add_argument(
632
+ '--pixel_max',
633
+ help='Maximum of the value range of pixels in generated images. Default: %(default)s',
634
+ default=1,
635
+ type=float,
636
+ metavar='VALUE'
637
+ )
638
+
639
+ parser.add_argument(
640
+ '--data_channels',
641
+ help='Number of channels in the data. Default: 3 (RGB)',
642
+ default=3,
643
+ type=int,
644
+ choices=[1, 3],
645
+ metavar='CHANNELS'
646
+ )
647
+
648
+ parser.add_argument(
649
+ '--data_dir',
650
+ help='The root directory of the dataset. This argument is required!',
651
+ type=str,
652
+ default=None
653
+ )
654
+
655
+ parser.add_argument(
656
+ '--data_resize',
657
+ help='Resize data to fit input size of discriminator. Default: %(default)s',
658
+ type=utils.bool_type,
659
+ const=True,
660
+ nargs='?',
661
+ default=False,
662
+ metavar='BOOL'
663
+ )
664
+
665
+ parser.add_argument(
666
+ '--mirror_augment',
667
+ help='Use random horizontal flipping for data images. Default: %(default)s',
668
+ type=utils.bool_type,
669
+ const=True,
670
+ nargs='?',
671
+ default=False,
672
+ metavar='BOOL'
673
+ )
674
+
675
+ parser.add_argument(
676
+ '--data_workers',
677
+ help='Number of worker processes that handle data loading. Default: %(default)s',
678
+ default=4,
679
+ type=int,
680
+ metavar='WORKERS'
681
+ )
682
+
683
+ #----------------------------------------------------------------------------
684
+ # Logging options
685
+
686
+ parser.add_argument(
687
+ '--checkpoint_dir',
688
+ help='If specified, save checkpoints to this directory.',
689
+ default=None,
690
+ type=str,
691
+ metavar='DIR'
692
+ )
693
+
694
+ parser.add_argument(
695
+ '--checkpoint_interval',
696
+ help='Save checkpoints with this interval. Default: %(default)s',
697
+ default=10000,
698
+ type=int,
699
+ metavar='INTERVAL'
700
+ )
701
+
702
+ parser.add_argument(
703
+ '--tensorboard_log_dir',
704
+ help='Log to this tensorboard directory if specified.',
705
+ default=None,
706
+ type=str,
707
+ metavar='DIR'
708
+ )
709
+
710
+ parser.add_argument(
711
+ '--tensorboard_image_interval',
712
+ help='Log images to tensorboard with this interval if specified.',
713
+ default=None,
714
+ type=int,
715
+ metavar='INTERVAL'
716
+ )
717
+
718
+ parser.add_argument(
719
+ '--tensorboard_image_size',
720
+ help='Size of images logged to tensorboard. Default: %(default)s',
721
+ default=256,
722
+ type=int,
723
+ metavar='VALUE'
724
+ )
725
+
726
+ return parser
727
+
728
+ #----------------------------------------------------------------------------
729
+
730
+ def get_dataset(args):
731
+ assert args.data_dir, '--data_dir has to be specified.'
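+ # The training resolution follows from the model config: every channel entry adds
+ # one resample step, so e.g. --base_shape 4 4 with 9 channel values gives
+ # 4 * 2 ** 8 = 1024 pixels per side.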
732
+ height, width = [
733
+ shape * 2 ** (len(args.d_channels or args.channels) - 1)
734
+ for shape in args.base_shape
735
+ ]
736
+ dataset = utils.ImageFolder(
737
+ args.data_dir,
738
+ mirror=args.mirror_augment,
739
+ pixel_min=args.pixel_min,
740
+ pixel_max=args.pixel_max,
741
+ height=height,
742
+ width=width,
743
+ resize=args.data_resize,
744
+ grayscale=args.data_channels == 1
745
+ )
746
+ assert len(dataset), 'No images found at {}'.format(args.data_dir)
747
+ return dataset
748
+
749
+ #----------------------------------------------------------------------------
750
+
751
+ def get_models(args):
752
+ common_kwargs = dict(
753
+ data_channels=args.data_channels,
754
+ base_shape=args.base_shape,
755
+ conv_filter=args.filter,
756
+ skip_filter=args.filter,
757
+ kernel_size=args.kernel_size,
758
+ conv_pad_mode=args.pad_mode,
759
+ conv_pad_constant=args.pad_constant,
760
+ filter_pad_mode=args.filter_pad_mode,
761
+ filter_pad_constant=args.filter_pad_constant,
762
+ pad_once=args.pad_once,
763
+ weight_scale=args.weight_scale
764
+ )
765
+
766
+ if args.g_file:
767
+ G = stylegan2.models.load(args.g_file)
768
+ assert isinstance(G, stylegan2.models.Generator), \
769
+ '`--g_file` should specify a generator model, found {}'.format(type(G))
770
+ else:
771
+
772
+ G_M = stylegan2.models.GeneratorMapping(
773
+ latent_size=args.latent,
774
+ label_size=args.label,
775
+ num_layers=args.latent_mapping_layers,
776
+ hidden=args.latent,
777
+ activation=args.g_activation,
778
+ normalize_input=args.normalize_latent,
779
+ lr_mul=args.latent_mapping_lr_mul,
780
+ weight_scale=args.weight_scale
781
+ )
782
+
783
+ G_S = stylegan2.models.GeneratorSynthesis(
784
+ channels=args.g_channels or args.channels,
785
+ latent_size=args.latent,
786
+ demodulate=args.g_normalize,
787
+ modulate_data_out=args.modulate_rgb,
788
+ conv_block_size=args.g_conv_block_size,
789
+ activation=args.g_activation,
790
+ conv_resample_mode=args.g_conv_resample_mode,
791
+ skip_resample_mode=args.g_skip_resample_mode,
792
+ resnet=args.g_resnet,
793
+ skip=args.g_skip,
794
+ fused_resample=args.g_fused_conv,
795
+ **common_kwargs
796
+ )
797
+
798
+ G = stylegan2.models.Generator(G_mapping=G_M, G_synthesis=G_S)
799
+
800
+ if args.d_file:
801
+ D = stylegan2.models.load(args.d_file)
802
+ assert isinstance(D, stylegan2.models.Discriminator), \
803
+ '`--d_file` should specify a discriminator model, found {}'.format(type(D))
804
+ else:
805
+ D = stylegan2.models.Discriminator(
806
+ channels=args.d_channels or args.channels,
807
+ label_size=args.label,
808
+ conv_block_size=args.d_conv_block_size,
809
+ activation=args.d_activation,
810
+ conv_resample_mode=args.d_conv_resample_mode,
811
+ skip_resample_mode=args.d_skip_resample_mode,
812
+ mbstd_group_size=args.group_size,
813
+ resnet=args.d_resnet,
814
+ skip=args.d_skip,
815
+ fused_resample=args.d_fused_conv,
816
+ **common_kwargs
817
+ )
818
+ assert len(G.G_synthesis.channels) == len(D.channels), \
819
+ 'While the number of channels for each layer can ' + \
820
+ 'differ between generator and discriminator, the ' + \
821
+ 'number of layers has to be the same. Received ' + \
822
+ '{} generator layers and {} discriminator layers.'.format(
823
+ len(G.G_synthesis.channels), len(D.channels))
824
+
825
+ return G, D
826
+
827
+ #----------------------------------------------------------------------------
828
+
829
+ def get_trainer(args):
830
+ dataset = get_dataset(args)
831
+ if args.resume and stylegan2.train._find_checkpoint(args.checkpoint_dir):
832
+ trainer = stylegan2.train.Trainer.load_checkpoint(
833
+ args.checkpoint_dir,
834
+ dataset,
835
+ device=args.gpu,
836
+ rank=args.rank,
837
+ world_size=args.world_size,
838
+ master_addr=args.master_addr,
839
+ master_port=args.master_port,
840
+ tensorboard_log_dir=args.tensorboard_log_dir
841
+ )
842
+ else:
843
+ G, D = get_models(args)
844
+ trainer = stylegan2.train.Trainer(
845
+ G=G,
846
+ D=D,
847
+ latent_size=args.latent,
848
+ dataset=dataset,
849
+ device=args.gpu,
850
+ batch_size=args.batch_size,
851
+ device_batch_size=args.device_batch_size,
852
+ label_size=args.label,
853
+ data_workers=args.data_workers,
854
+ G_loss=args.g_loss,
855
+ D_loss=args.d_loss,
856
+ G_reg=args.g_reg,
857
+ G_reg_interval=args.g_reg_interval,
858
+ G_opt_kwargs={'lr': args.g_lr, 'betas': args.g_betas},
859
+ G_reg_batch_size=args.g_reg_batch_size,
860
+ G_reg_device_batch_size=args.g_reg_device_batch_size,
861
+ D_reg=args.d_reg,
862
+ D_reg_interval=args.d_reg_interval,
863
+ D_opt_kwargs={'lr': args.d_lr, 'betas': args.d_betas},
864
+ style_mix_prob=args.style_mix,
865
+ G_iter=args.g_iter,
866
+ D_iter=args.d_iter,
867
+ tensorboard_log_dir=args.tensorboard_log_dir,
868
+ checkpoint_dir=args.checkpoint_dir,
869
+ checkpoint_interval=args.checkpoint_interval,
870
+ half=args.half,
871
+ rank=args.rank,
872
+ world_size=args.world_size,
873
+ master_addr=args.master_addr,
874
+ master_port=args.master_port
875
+ )
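+ # Metrics are only registered on the main process (rank 0 or non-distributed
+ # runs) so they are not evaluated once per worker.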
876
+ if args.fid_interval and not args.rank:
877
+ fid_model = inception.InceptionV3FeatureExtractor(
878
+ pixel_min=args.pixel_min, pixel_max=args.pixel_max)
879
+ trainer.register_metric(
880
+ name='FID (299x299)',
881
+ eval_fn=fid.FID(
882
+ trainer.Gs,
883
+ trainer.prior_generator,
884
+ dataset=dataset,
885
+ fid_model=fid_model,
886
+ fid_size=299,
887
+ reals_batch_size=64
888
+ ),
889
+ interval=args.fid_interval
890
+ )
891
+ trainer.register_metric(
892
+ name='FID',
893
+ eval_fn=fid.FID(
894
+ trainer.Gs,
895
+ trainer.prior_generator,
896
+ dataset=dataset,
897
+ fid_model=fid_model,
898
+ fid_size=None
899
+ ),
900
+ interval=args.fid_interval
901
+ )
902
+ if args.ppl_interval and not args.rank:
903
+ lpips_model = lpips.LPIPS_VGG16(
904
+ pixel_min=args.pixel_min, pixel_max=args.pixel_max)
905
+ crop = None
906
+ if args.ppl_ffhq_crop:
907
+ crop = ppl.PPL.FFHQ_CROP
908
+ trainer.register_metric(
909
+ name='PPL_end',
910
+ eval_fn=ppl.PPL(
911
+ trainer.Gs,
912
+ trainer.prior_generator,
913
+ full_sampling=False,
914
+ crop=crop,
915
+ lpips_model=lpips_model,
916
+ lpips_size=256
917
+ ),
918
+ interval=args.ppl_interval
919
+ )
920
+ trainer.register_metric(
921
+ name='PPL_full',
922
+ eval_fn=ppl.PPL(
923
+ trainer.Gs,
924
+ trainer.prior_generator,
925
+ full_sampling=True,
926
+ crop=crop,
927
+ lpips_model=lpips_model,
928
+ lpips_size=256
929
+ ),
930
+ interval=args.ppl_interval
931
+ )
932
+ if args.tensorboard_image_interval:
933
+ for static in [True, False]:
934
+ for trunc in [0.5, 0.7, 1.0]:
935
+ if static:
936
+ name = 'static'
937
+ else:
938
+ name = 'random'
939
+ name += '/trunc_{:.1f}'.format(trunc)
940
+ trainer.add_tensorboard_image_logging(
941
+ name=name,
942
+ num_images=4,
943
+ interval=args.tensorboard_image_interval,
944
+ resize=args.tensorboard_image_size,
945
+ seed=1234567890 if static else None,
946
+ truncation_psi=trunc,
947
+ pixel_min=args.pixel_min,
948
+ pixel_max=args.pixel_max
949
+ )
950
+ return trainer
951
+
952
+ #----------------------------------------------------------------------------
953
+
954
+ def run(args):
955
+ if not args.rank:
956
+ if not (args.checkpoint_dir or args.output):
957
+ warnings.warn(
958
+ 'Neither an output path nor a checkpoint dir has been ' + \
959
+ 'given. Weights from this training run will never ' + \
960
+ 'be saved.'
961
+ )
962
+ if args.output:
963
+ assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
964
+ '--output argument should specify a directory, not a file.'
965
+ trainer = get_trainer(args)
966
+ trainer.train(iterations=args.iterations)
967
+ if not args.rank and args.output:
968
+ print('Saving models to {}'.format(args.output))
969
+ if not os.path.exists(args.output):
970
+ os.makedirs(args.output)
971
+ for model_name in ['G', 'D', 'Gs']:
972
+ getattr(trainer, model_name).save(
973
+ os.path.join(args.output, model_name + '.pth'))
974
+
975
+ #----------------------------------------------------------------------------
976
+
977
+ def run_distributed(rank, args):
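+ # Each spawned process receives its index as `rank`, is assigned exactly one of
+ # the requested GPUs, and falls back to localhost defaults for the rendezvous
+ # address and port.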
978
+ args.rank = rank
979
+ args.world_size = len(args.gpu)
980
+ args.gpu = args.gpu[rank]
981
+ args.master_addr = args.master_addr or '127.0.0.1'
982
+ args.master_port = args.master_port or '23456'
983
+ run(args)
984
+
985
+ #----------------------------------------------------------------------------
986
+
987
+ def main():
988
+ parser = get_arg_parser()
989
+ args = parser.parse_args()
990
+ if len(args.gpu) > 1 and args.distributed:
991
+ assert args.rank is None and args.world_size is None, \
992
+ 'When --distributed is enabled (default) the rank and ' + \
993
+ 'world size can not be given as this is set up automatically. ' + \
994
+ 'Use --distributed 0 to disable automatic setup of distributed training.'
995
+ mp.spawn(run_distributed, nprocs=len(args.gpu), args=(args,))
996
+ else:
997
+ run(args)
998
+
999
+ #----------------------------------------------------------------------------
1000
+
1001
+ if __name__ == '__main__':
1002
+ main()
stylegan2/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from . import external_models
2
+ from . import metrics
3
+ from . import models
4
+ from . import project
5
+ from . import train
stylegan2/external_models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import inception
2
+ from . import lpips
stylegan2/external_models/inception.py ADDED
@@ -0,0 +1,276 @@
1
+ """
2
+ Code adapted from https://github.com/mseitzer/pytorch-fid/
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+ Unless required by applicable law or agreed to in writing, software
9
+ distributed under the License is distributed on an "AS IS" BASIS,
10
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ See the License for the specific language governing permissions and
12
+ limitations under the License.
13
+ """
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torchvision import models
18
+
19
+ try:
20
+ from torchvision.models.utils import load_state_dict_from_url
21
+ except ImportError:
22
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
23
+
24
+ # Inception weights ported to Pytorch from
25
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
26
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
27
+
28
+
29
+ class InceptionV3FeatureExtractor(nn.Module):
30
+ """Pretrained InceptionV3 network returning feature maps"""
31
+
32
+ # Index of default block of inception to return,
33
+ # corresponds to output of final average pooling
34
+ DEFAULT_BLOCK_INDEX = 3
35
+
36
+ # Maps feature dimensionality to their output blocks indices
37
+ BLOCK_INDEX_BY_DIM = {
38
+ 64: 0, # First max pooling features
39
+ 192: 1, # Second max pooling features
40
+ 768: 2, # Pre-aux classifier features
41
+ 2048: 3 # Final average pooling features
42
+ }
43
+
44
+ def __init__(self,
45
+ output_block=DEFAULT_BLOCK_INDEX,
46
+ pixel_min=-1,
47
+ pixel_max=1):
48
+ """
49
+ Build pretrained InceptionV3
50
+ Arguments:
51
+ output_block (int): Index of block to return features of.
52
+ Possible values are:
53
+ - 0: corresponds to output of first max pooling
54
+ - 1: corresponds to output of second max pooling
55
+ - 2: corresponds to output which is fed to aux classifier
56
+ - 3: corresponds to output of final average pooling
57
+ pixel_min (float): Min value for inputs. Default value is -1.
58
+ pixel_max (float): Max value for inputs. Default value is 1.
59
+ """
60
+ super(InceptionV3FeatureExtractor, self).__init__()
61
+
62
+ assert 0 <= output_block <= 3, '`output_block` can only be ' + \
63
+ '0 <= `output_block` <= 3.'
64
+
65
+ inception = fid_inception_v3()
66
+
67
+ blocks = []
68
+
69
+ # Block 0: input to maxpool1
70
+ block0 = [
71
+ inception.Conv2d_1a_3x3,
72
+ inception.Conv2d_2a_3x3,
73
+ inception.Conv2d_2b_3x3,
74
+ nn.MaxPool2d(kernel_size=3, stride=2)
75
+ ]
76
+ blocks.append(nn.Sequential(*block0))
77
+
78
+ # Block 1: maxpool1 to maxpool2
79
+ if output_block >= 1:
80
+ block1 = [
81
+ inception.Conv2d_3b_1x1,
82
+ inception.Conv2d_4a_3x3,
83
+ nn.MaxPool2d(kernel_size=3, stride=2)
84
+ ]
85
+ blocks.append(nn.Sequential(*block1))
86
+
87
+ # Block 2: maxpool2 to aux classifier
88
+ if output_block >= 2:
89
+ block2 = [
90
+ inception.Mixed_5b,
91
+ inception.Mixed_5c,
92
+ inception.Mixed_5d,
93
+ inception.Mixed_6a,
94
+ inception.Mixed_6b,
95
+ inception.Mixed_6c,
96
+ inception.Mixed_6d,
97
+ inception.Mixed_6e,
98
+ ]
99
+ blocks.append(nn.Sequential(*block2))
100
+
101
+ # Block 3: aux classifier to final avgpool
102
+ if output_block >= 3:
103
+ block3 = [
104
+ inception.Mixed_7a,
105
+ inception.Mixed_7b,
106
+ inception.Mixed_7c,
107
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
108
+ ]
109
+ blocks.append(nn.Sequential(*block3))
110
+
111
+ self.main = nn.Sequential(*blocks)
112
+ self.pixel_min = pixel_min
113
+ self.pixel_max = pixel_max
114
+ self.requires_grad_(False)
115
+ self.eval()
116
+
117
+ def _scale(self, x):
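+ # Linearly remap inputs from [pixel_min, pixel_max] to the [-1, 1] range the
+ # FID Inception weights expect; this is a no-op for the default range.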
118
+ if self.pixel_min != -1 or self.pixel_max != 1:
119
+ x = (2*x - self.pixel_min - self.pixel_max) \
120
+ / (self.pixel_max - self.pixel_min)
121
+ return x
122
+
123
+ def forward(self, input):
124
+ """
125
+ Get Inception feature maps.
126
+ Arguments:
127
+ input (torch.Tensor)
128
+ Returns:
129
+ feature_maps (torch.Tensor)
130
+ """
131
+ return self.main(self._scale(input))
132
+
133
+
134
+ def fid_inception_v3():
135
+ """Build pretrained Inception model for FID computation
136
+ The Inception model for FID computation uses a different set of weights
137
+ and has a slightly different structure than torchvision's Inception.
138
+ This method first constructs torchvision's Inception and then patches the
139
+ necessary parts that are different in the FID Inception model.
140
+ """
141
+ inception = models.inception_v3(num_classes=1008,
142
+ aux_logits=False,
143
+ pretrained=False)
144
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
145
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
146
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
147
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
148
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
149
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
150
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
151
+ inception.Mixed_7b = FIDInceptionE_1(1280)
152
+ inception.Mixed_7c = FIDInceptionE_2(2048)
153
+
154
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
155
+ inception.load_state_dict(state_dict)
156
+ return inception
157
+
158
+
159
+ class FIDInceptionA(models.inception.InceptionA):
160
+ """InceptionA block patched for FID computation"""
161
+ def __init__(self, in_channels, pool_features):
162
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
163
+
164
+ def forward(self, x):
165
+ branch1x1 = self.branch1x1(x)
166
+
167
+ branch5x5 = self.branch5x5_1(x)
168
+ branch5x5 = self.branch5x5_2(branch5x5)
169
+
170
+ branch3x3dbl = self.branch3x3dbl_1(x)
171
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
172
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
173
+
174
+ # Patch: Tensorflow's average pool does not use the padded zeros in
175
+ # its average calculation
176
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
177
+ count_include_pad=False)
178
+ branch_pool = self.branch_pool(branch_pool)
179
+
180
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
181
+ return torch.cat(outputs, 1)
182
+
183
+
184
+ class FIDInceptionC(models.inception.InceptionC):
185
+ """InceptionC block patched for FID computation"""
186
+ def __init__(self, in_channels, channels_7x7):
187
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
188
+
189
+ def forward(self, x):
190
+ branch1x1 = self.branch1x1(x)
191
+
192
+ branch7x7 = self.branch7x7_1(x)
193
+ branch7x7 = self.branch7x7_2(branch7x7)
194
+ branch7x7 = self.branch7x7_3(branch7x7)
195
+
196
+ branch7x7dbl = self.branch7x7dbl_1(x)
197
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
198
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
199
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
200
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
201
+
202
+ # Patch: Tensorflow's average pool does not use the padded zeros in
203
+ # its average calculation
204
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
205
+ count_include_pad=False)
206
+ branch_pool = self.branch_pool(branch_pool)
207
+
208
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
209
+ return torch.cat(outputs, 1)
210
+
211
+
212
+ class FIDInceptionE_1(models.inception.InceptionE):
213
+ """First InceptionE block patched for FID computation"""
214
+ def __init__(self, in_channels):
215
+ super(FIDInceptionE_1, self).__init__(in_channels)
216
+
217
+ def forward(self, x):
218
+ branch1x1 = self.branch1x1(x)
219
+
220
+ branch3x3 = self.branch3x3_1(x)
221
+ branch3x3 = [
222
+ self.branch3x3_2a(branch3x3),
223
+ self.branch3x3_2b(branch3x3),
224
+ ]
225
+ branch3x3 = torch.cat(branch3x3, 1)
226
+
227
+ branch3x3dbl = self.branch3x3dbl_1(x)
228
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
229
+ branch3x3dbl = [
230
+ self.branch3x3dbl_3a(branch3x3dbl),
231
+ self.branch3x3dbl_3b(branch3x3dbl),
232
+ ]
233
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
234
+
235
+ # Patch: Tensorflow's average pool does not use the padded zeros in
236
+ # its average calculation
237
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
238
+ count_include_pad=False)
239
+ branch_pool = self.branch_pool(branch_pool)
240
+
241
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
242
+ return torch.cat(outputs, 1)
243
+
244
+
245
+ class FIDInceptionE_2(models.inception.InceptionE):
246
+ """Second InceptionE block patched for FID computation"""
247
+ def __init__(self, in_channels):
248
+ super(FIDInceptionE_2, self).__init__(in_channels)
249
+
250
+ def forward(self, x):
251
+ branch1x1 = self.branch1x1(x)
252
+
253
+ branch3x3 = self.branch3x3_1(x)
254
+ branch3x3 = [
255
+ self.branch3x3_2a(branch3x3),
256
+ self.branch3x3_2b(branch3x3),
257
+ ]
258
+ branch3x3 = torch.cat(branch3x3, 1)
259
+
260
+ branch3x3dbl = self.branch3x3dbl_1(x)
261
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
262
+ branch3x3dbl = [
263
+ self.branch3x3dbl_3a(branch3x3dbl),
264
+ self.branch3x3dbl_3b(branch3x3dbl),
265
+ ]
266
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
267
+
268
+ # Patch: The FID Inception model uses max pooling instead of average
269
+ # pooling. This is likely an error in this specific Inception
270
+ # implementation, as other Inception models use average pooling here
271
+ # (which matches the description in the paper).
272
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
273
+ branch_pool = self.branch_pool(branch_pool)
274
+
275
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276
+ return torch.cat(outputs, 1)
stylegan2/external_models/lpips.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Code adapted from https://github.com/richzhang/PerceptualSimilarity
3
+
4
+ Original License:
5
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
6
+ All rights reserved.
7
+
8
+ Redistribution and use in source and binary forms, with or without
9
+ modification, are permitted provided that the following conditions are met:
10
+
11
+ * Redistributions of source code must retain the above copyright notice, this
12
+ list of conditions and the following disclaimer.
13
+
14
+ * Redistributions in binary form must reproduce the above copyright notice,
15
+ this list of conditions and the following disclaimer in the documentation
16
+ and/or other materials provided with the distribution.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+ """
29
+ import torch
30
+ from torch import nn
31
+ import torchvision
32
+
33
+
34
+ class LPIPS_VGG16(nn.Module):
35
+ _FEATURE_IDX = [0, 4, 9, 16, 23, 30]
36
+ _LINEAR_WEIGHTS_URL = 'https://github.com/richzhang/PerceptualSimilarity' + \
37
+ '/blob/master/lpips/weights/v0.1/vgg.pth?raw=true'
38
+
39
+ def __init__(self, pixel_min=-1, pixel_max=1):
40
+ super(LPIPS_VGG16, self).__init__()
41
+ features = torchvision.models.vgg16(pretrained=True).features
42
+ self.slices = nn.ModuleList()
43
+ linear_weights = torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL)
44
+ for i in range(1, len(self._FEATURE_IDX)):
45
+ idx_range = range(self._FEATURE_IDX[i - 1], self._FEATURE_IDX[i])
46
+ self.slices.append(nn.Sequential(*[features[j] for j in idx_range]))
47
+ self.linear_layers = nn.ModuleList()
48
+ for weight in linear_weights.values():
49
+ weight = weight.view(1, -1)
50
+ linear = nn.Linear(weight.size(1), 1, bias=False)
51
+ linear.weight.data.copy_(weight)
52
+ self.linear_layers.append(linear)
53
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188]).view(1, -1, 1, 1))
54
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450]).view(1, -1, 1, 1))
55
+ self.pixel_min = pixel_min
56
+ self.pixel_max = pixel_max
57
+ self.requires_grad_(False)
58
+ self.eval()
59
+
60
+ def _scale(self, x):
61
+ if self.pixel_min != -1 or self.pixel_max != 1:
62
+ x = (2*x - self.pixel_min - self.pixel_max) \
63
+ / (self.pixel_max - self.pixel_min)
64
+ return (x - self.shift) / self.scale
65
+
66
+ @staticmethod
67
+ def _normalize_tensor(feature_maps, eps=1e-8):
68
+ rnorm = torch.rsqrt(torch.sum(feature_maps ** 2, dim=1, keepdim=True) + eps)
69
+ return feature_maps * rnorm
70
+
71
+ def forward(self, x0, x1, eps=1e-8):
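+ # LPIPS distance: unit-normalize the feature maps of both inputs at each VGG16
+ # slice, average their squared difference spatially, weight it with the
+ # pretrained linear layer and sum the per-slice terms.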
72
+ x0, x1 = self._scale(x0), self._scale(x1)
73
+ dist = 0
74
+ for slice, linear in zip(self.slices, self.linear_layers):
75
+ x0, x1 = slice(x0), slice(x1)
76
+ _x0, _x1 = self._normalize_tensor(x0, eps), self._normalize_tensor(x1, eps)
77
+ dist += linear(torch.mean((_x0 - _x1) ** 2, dim=[-1, -2]))
78
+ return dist.view(-1)
stylegan2/loss_fns.py ADDED
@@ -0,0 +1,347 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ from . import utils
6
+
7
+
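+ # All loss functions in this module return a (loss, reg) tuple; either element
+ # may be None when that term is not computed for the given call.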
8
+ def _grad(input, output, retain_graph):
9
+ # https://discuss.pytorch.org/t/gradient-penalty-loss-with-modified-weights/64910
10
+ # Currently not possible to not
11
+ # retain graph for regularization losses.
12
+ # Ugly hack is to always set it to True.
13
+ retain_graph = True
14
+ grads = torch.autograd.grad(
15
+ output.sum(),
16
+ input,
17
+ only_inputs=True,
18
+ retain_graph=retain_graph,
19
+ create_graph=True
20
+ )
21
+ return grads[0]
22
+
23
+
24
+ def _grad_pen(input, output, gamma, constraint=1, onesided=False, retain_graph=True):
25
+ grad = _grad(input, output, retain_graph=retain_graph)
26
+ grad = grad.view(grad.size(0), -1)
27
+ grad_norm = grad.norm(2, dim=1)
28
+ if onesided:
29
+ gp = torch.clamp(grad_norm - constraint, min=0)
30
+ else:
31
+ gp = (grad_norm - constraint) ** 2
32
+ return gamma * gp.mean()
33
+
34
+
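+ # Gradient regularizer used for R1/R2: returns (gamma / 2) * E[||grad D||^2],
+ # where the gradient is taken w.r.t. the discriminator input.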
35
+ def _grad_reg(input, output, gamma, retain_graph=True):
36
+ grad = _grad(input, output, retain_graph=retain_graph)
37
+ grad = grad.view(grad.size(0), -1)
38
+ gr = (grad ** 2).sum(1)
39
+ return (0.5 * gamma) * gr.mean()
40
+
41
+
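+ # Path length regularizer: estimates the Jacobian norm of the generator output
+ # w.r.t. the disentangled latents with a random projection, tracks its running
+ # average in `pl_avg` and penalizes deviations from that average.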
42
+ def _pathreg(dlatents, fakes, pl_avg, pl_decay, gamma, retain_graph=True):
43
+ retain_graph = True
44
+ pl_noise = torch.empty_like(fakes).normal_().div_(np.sqrt(np.prod(fakes.size()[2:])))
45
+ pl_grad = _grad(dlatents, torch.sum(pl_noise * fakes), retain_graph=retain_graph)
46
+ pl_length = torch.sqrt(torch.mean(torch.sum(pl_grad ** 2, dim=2), dim=1))
47
+ with torch.no_grad():
48
+ pl_avg.add_(pl_decay * (torch.mean(pl_length) - pl_avg))
49
+ return gamma * torch.mean((pl_length - pl_avg) ** 2)
50
+
51
+
52
+ #----------------------------------------------------------------------------
53
+ # Logistic loss from the paper
54
+ # "Generative Adversarial Nets", Goodfellow et al. 2014
55
+
56
+
57
+ def G_logistic(G,
58
+ D,
59
+ latents,
60
+ latent_labels=None,
61
+ *args,
62
+ **kwargs):
63
+ fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
64
+ loss = - F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
65
+ reg = None
66
+ return loss, reg
67
+
68
+
69
+ def G_logistic_ns(G,
70
+ D,
71
+ latents,
72
+ latent_labels=None,
73
+ *args,
74
+ **kwargs):
75
+ fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
76
+ loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
77
+ reg = None
78
+ return loss, reg
79
+
80
+
81
+ def D_logistic(G,
82
+ D,
83
+ latents,
84
+ reals,
85
+ latent_labels=None,
86
+ real_labels=None,
87
+ *args,
88
+ **kwargs):
89
+ assert (latent_labels is None) == (real_labels is None)
90
+ with torch.no_grad():
91
+ fakes = G(latents, labels=latent_labels)
92
+ real_scores = D(reals, labels=real_labels).float()
93
+ fake_scores = D(fakes, labels=latent_labels).float()
94
+ real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
95
+ fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
96
+ loss = real_loss + fake_loss
97
+ reg = None
98
+ return loss, reg
99
+
100
+
101
+ #----------------------------------------------------------------------------
102
+ # R1 and R2 regularizers from the paper
103
+ # "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
104
+
105
+
106
+ def D_r1(D,
107
+ reals,
108
+ real_labels=None,
109
+ gamma=10,
110
+ *args,
111
+ **kwargs):
112
+ loss = None
113
+ reg = None
114
+ if gamma:
115
+ reals.requires_grad_(True)
116
+ real_scores = D(reals, labels=real_labels)
117
+ reg = _grad_reg(
118
+ input=reals, output=real_scores, gamma=gamma, retain_graph=False).float()
119
+ return loss, reg
120
+
121
+
122
+ def D_r2(D,
123
+ G,
124
+ latents,
125
+ latent_labels=None,
126
+ gamma=10,
127
+ *args,
128
+ **kwargs):
129
+ loss = None
130
+ reg = None
131
+ if gamma:
132
+ with torch.no_grad():
133
+ fakes = G(latents, labels=latent_labels)
134
+ fakes.requires_grad_(True)
135
+ fake_scores = D(fakes, labels=latent_labels)
136
+ reg = _grad_reg(
137
+ input=fakes, output=fake_scores, gamma=gamma, retain_graph=False).float()
138
+ return loss, reg
139
+
140
+
141
+ def D_logistic_r1(G,
142
+ D,
143
+ latents,
144
+ reals,
145
+ latent_labels=None,
146
+ real_labels=None,
147
+ gamma=10,
148
+ *args,
149
+ **kwargs):
150
+ assert (latent_labels is None) == (real_labels is None)
151
+ with torch.no_grad():
152
+ fakes = G(latents, labels=latent_labels)
153
+ if gamma:
154
+ reals.requires_grad_(True)
155
+ real_scores = D(reals, labels=real_labels).float()
156
+ fake_scores = D(fakes, labels=latent_labels).float()
157
+ real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
158
+ fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
159
+ loss = real_loss + fake_loss
160
+ reg = None
161
+ if gamma:
162
+ reg = _grad_reg(
163
+ input=reals, output=real_scores, gamma=gamma, retain_graph=True).float()
164
+ return loss, reg
165
+
166
+
167
+ def D_logistic_r2(G,
168
+ D,
169
+ latents,
170
+ reals,
171
+ latent_labels=None,
172
+ real_labels=None,
173
+ gamma=10,
174
+ *args,
175
+ **kwargs):
176
+ assert (latent_labels is None) == (real_labels is None)
177
+ with torch.no_grad():
178
+ fakes = G(latents, labels=latent_labels)
179
+ if gamma:
180
+ fakes.requires_grad_(True)
181
+ real_scores = D(reals, labels=real_labels).float()
182
+ fake_scores = D(fakes, labels=latent_labels).float()
183
+ real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
184
+ fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
185
+ loss = real_loss + fake_loss
186
+ reg = None
187
+ if gamma:
188
+ reg = _grad_reg(
189
+ input=fakes, output=fake_scores, gamma=gamma, retain_graph=True).float()
190
+ return loss, reg
191
+
192
+
193
+ #----------------------------------------------------------------------------
194
+ # Non-saturating logistic loss with path length regularizer from the paper
195
+ # "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
196
+
197
+
198
+ def G_pathreg(G,
199
+ latents,
200
+ pl_avg,
201
+ latent_labels=None,
202
+ pl_decay=0.01,
203
+ gamma=2,
204
+ *args,
205
+ **kwargs):
206
+ loss = None
207
+ reg = None
208
+ if gamma:
209
+ fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True, mapping_grad=False)
210
+ reg = _pathreg(
211
+ dlatents=dlatents,
212
+ fakes=fakes,
213
+ pl_avg=pl_avg,
214
+ pl_decay=pl_decay,
215
+ gamma=gamma,
216
+ retain_graph=False
217
+ ).float()
218
+ return loss, reg
219
+
220
+
221
+ def G_logistic_ns_pathreg(G,
222
+ D,
223
+ latents,
224
+ pl_avg,
225
+ latent_labels=None,
226
+ pl_decay=0.01,
227
+ gamma=2,
228
+ *args,
229
+ **kwargs):
230
+ fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True)
231
+ fake_scores = D(fakes, labels=latent_labels).float()
232
+ loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
233
+ reg = None
234
+ if gamma:
235
+ reg = _pathreg(
236
+ dlatents=dlatents,
237
+ fakes=fakes,
238
+ pl_avg=pl_avg,
239
+ pl_decay=pl_decay,
240
+ gamma=gamma,
241
+ retain_graph=True
242
+ ).float()
243
+ return loss, reg
244
+
245
+
246
+ #----------------------------------------------------------------------------
247
+ # WGAN loss from the paper
248
+ # "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017
249
+
250
+
251
+ def G_wgan(G,
252
+ D,
253
+ latents,
254
+ latent_labels=None,
255
+ *args,
256
+ **kwargs):
257
+ fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
258
+ loss = -fake_scores.mean()
259
+ reg = None
260
+ return loss, reg
261
+
262
+
263
+ def D_wgan(G,
264
+ D,
265
+ latents,
266
+ reals,
267
+ latent_labels=None,
268
+ real_labels=None,
269
+ drift_gamma=0.001,
270
+ *args,
271
+ **kwargs):
272
+ assert (latent_labels is None) == (real_labels is None)
273
+ with torch.no_grad():
274
+ fakes = G(latents, labels=latent_labels)
275
+ real_scores = D(reals, labels=real_labels).float()
276
+ fake_scores = D(fakes, labels=latent_labels).float()
277
+ loss = fake_scores.mean() - real_scores.mean()
278
+ if drift_gamma:
279
+ loss += drift_gamma * torch.mean(real_scores ** 2)
280
+ reg = None
281
+ return loss, reg
282
+
283
+
284
+ #----------------------------------------------------------------------------
285
+ # WGAN-GP loss from the paper
286
+ # "Improved Training of Wasserstein GANs", Gulrajani et al. 2017
287
+
288
+
289
+ def D_gp(G,
290
+ D,
291
+ latents,
292
+ reals,
293
+ latent_labels=None,
294
+ real_labels=None,
295
+ gamma=0,
296
+ constraint=1,
297
+ *args,
298
+ **kwargs):
299
+ loss = None
300
+ reg = None
301
+ if gamma:
302
+ assert (latent_labels is None) == (real_labels is None)
303
+ with torch.no_grad():
304
+ fakes = G(latents, labels=latent_labels)
305
+ assert reals.size() == fakes.size()
306
+ if latent_labels is not None:
307
+ assert torch.equal(latent_labels, real_labels)
308
+ alpha = torch.empty(reals.size(0), device=reals.device).uniform_()
309
+ alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
310
+ interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
311
+ interp_scores = D(interp, labels=latent_labels)
312
+ reg = _grad_pen(
313
+ input=interp, output=interp_scores, gamma=gamma, retain_graph=False).float()
314
+ return loss, reg
315
+
316
+
317
+ def D_wgan_gp(G,
318
+ D,
319
+ latents,
320
+ reals,
321
+ latent_labels=None,
322
+ real_labels=None,
323
+ gamma=0,
324
+ drift_gamma=0.001,
325
+ constraint=1,
326
+ *args,
327
+ **kwargs):
328
+ assert (latent_labels is None) == (real_labels is None)
329
+ with torch.no_grad():
330
+ fakes = G(latents, labels=latent_labels)
331
+ real_scores = D(reals, labels=real_labels).float()
332
+ fake_scores = D(fakes, labels=latent_labels).float()
333
+ loss = fake_scores.mean() - real_scores.mean()
334
+ if drift_gamma:
335
+ loss += drift_gamma * torch.mean(real_scores ** 2)
336
+ reg = None
337
+ if gamma:
338
+ assert reals.size() == fakes.size()
339
+ if latent_labels is not None:
340
+ assert torch.equal(latent_labels, real_labels)
341
+ alpha = torch.empty(reals.size(0), device=reals.device).uniform_()
342
+ alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
343
+ interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
344
+ interp_scores = D(interp, labels=latent_labels)
345
+ reg = _grad_pen(
346
+ input=interp, output=interp_scores, gamma=gamma, retain_graph=True).float()
347
+ return loss, reg
stylegan2/metrics/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import fid
2
+ from . import ppl
stylegan2/metrics/fid.py ADDED
@@ -0,0 +1,210 @@
1
+ import warnings
2
+ import numbers
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch.nn import functional as F
7
+
8
+ from .. import models, utils
9
+ from ..external_models import inception
10
+
11
+
12
+ class _TruncatedDataset:
13
+ """
14
+ Truncates a dataset, making only part of it accessible
15
+ by `torch.utils.data.DataLoader`.
16
+ """
17
+
18
+ def __init__(self, dataset, max_len):
19
+ self.dataset = dataset
20
+ self.max_len = max_len
21
+
22
+ def __len__(self):
23
+ return min(len(self.dataset), self.max_len)
24
+
25
+ def __getitem__(self, index):
26
+ return self.dataset[index]
27
+
28
+
29
+ class FID:
30
+ """
31
+ This class evaluates the FID metric of a generator.
32
+ Arguments:
33
+ G (Generator)
34
+ prior_generator (PriorGenerator)
35
+ dataset (indexable)
36
+ device (int, str, torch.device, optional): The device
37
+ to use for calculations. By default, the same device
38
+ is chosen as the parameters in `generator` reside on.
39
+ num_samples (int): Number of samples of reals and fakes
40
+ to gather statistics for, which are used for calculating
41
+ the metric. Default value is 50 000.
42
+ fid_model (nn.Module): A model that returns feature maps
43
+ of shape (batch_size, features, *). Default value
44
+ is InceptionV3.
45
+ fid_size (int, optional): Resize any data fed to `fid_model` by scaling
46
+ the data so that its smallest side is the same size as this
47
+ argument.
48
+ truncation_psi (float, optional): Truncation of the generator
49
+ when evaluating.
50
+ truncation_cutoff (int, optional): Cutoff for truncation when
51
+ evaluating.
52
+ reals_batch_size (int, optional): Batch size to use for real
53
+ samples statistics gathering.
54
+ reals_data_workers (int, optional): Number of workers fetching
55
+ the real data samples. Default value is 0.
56
+ verbose (bool): Write progress of gathering statistics for reals
57
+ to stdout. Default value is True.
58
+ """
59
+ def __init__(self,
60
+ G,
61
+ prior_generator,
62
+ dataset,
63
+ device=None,
64
+ num_samples=50000,
65
+ fid_model=None,
66
+ fid_size=None,
67
+ truncation_psi=None,
68
+ truncation_cutoff=None,
69
+ reals_batch_size=None,
70
+ reals_data_workers=0,
71
+ verbose=True):
72
+ device_ids = []
73
+ if isinstance(G, torch.nn.DataParallel):
74
+ device_ids = G.device_ids
75
+ G = utils.unwrap_module(G)
76
+ assert isinstance(G, models.Generator)
77
+ assert isinstance(prior_generator, utils.PriorGenerator)
78
+ if device is None:
79
+ device = next(G.parameters()).device
80
+ else:
81
+ device = torch.device(device)
82
+ assert torch.device(prior_generator.device) == device, \
83
+ 'Prior generator device ({}) '.format(torch.device(prior_generator.device)) + \
84
+ 'is not the same as the specified (or inferred from the model) ' + \
85
+ 'device ({}) for the FID evaluation.'.format(device)
86
+ G.eval().to(device)
87
+ if device_ids:
88
+ G = torch.nn.DataParallel(G, device_ids=device_ids)
89
+ self.G = G
90
+ self.prior_generator = prior_generator
91
+ self.device = device
92
+ self.num_samples = num_samples
93
+ self.batch_size = self.prior_generator.batch_size
94
+ if fid_model is None:
95
+ warnings.warn(
96
+ 'Using default fid model metric based on Inception V3. ' + \
97
+ 'This metric will only work on image data where values are in ' + \
98
+ 'the range [-1, 1], please specify another module if you want ' + \
99
+ 'to use other kinds of data formats.'
100
+ )
101
+ fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
102
+ if device_ids:
103
+ fid_model = torch.nn.DataParallel(fid_model, device_ids)
104
+ self.fid_model = fid_model.eval().to(device)
105
+ self.fid_size = fid_size
106
+
107
+ dataset = _TruncatedDataset(dataset, self.num_samples)
108
+ dataloader = torch.utils.data.DataLoader(
109
+ dataset,
110
+ batch_size=reals_batch_size or self.batch_size,
111
+ num_workers=reals_data_workers
112
+ )
113
+ features = []
114
+ self.labels = []
115
+
116
+ if verbose:
117
+ progress = utils.ProgressWriter(
118
+ np.ceil(self.num_samples / (reals_batch_size or self.batch_size)))
119
+ progress.write('FID: Gathering statistics for reals...', step=False)
120
+
121
+ for batch in dataloader:
122
+ data = batch
123
+ if isinstance(batch, (tuple, list)):
124
+ data = batch[0]
125
+ if len(batch) > 1:
126
+ self.labels.append(batch[1])
127
+ data = self._scale_for_fid(data).to(self.device)
128
+ with torch.no_grad():
129
+ batch_features = self.fid_model(data)
130
+ batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
131
+ features.append(batch_features.cpu())
132
+ progress.step()
133
+
134
+ if verbose:
135
+ progress.write('FID: Statistics for reals gathered!', step=False)
136
+ progress.close()
137
+
138
+ features = torch.cat(features, dim=0).numpy()
139
+
140
+ self.mu_real = np.mean(features, axis=0)
141
+ self.sigma_real = np.cov(features, rowvar=False)
142
+ self.truncation_psi = truncation_psi
143
+ self.truncation_cutoff = truncation_cutoff
144
+
145
+ def _scale_for_fid(self, data):
146
+ if not self.fid_size:
147
+ return data
148
+ scale_factor = self.fid_size / min(data.size()[2:])
149
+ if scale_factor == 1:
150
+ return data
151
+ mode = 'nearest'
152
+ if scale_factor < 1:
153
+ mode = 'area'
154
+ return F.interpolate(data, scale_factor=scale_factor, mode=mode)
155
+
156
+ def __call__(self, *args, **kwargs):
157
+ return self.evaluate(*args, **kwargs)
158
+
159
+ def evaluate(self, verbose=True):
160
+ """
161
+ Evaluate the FID.
162
+ Arguments:
163
+ verbose (bool): Write progress to stdout.
164
+ Default value is True.
165
+ Returns:
166
+ fid (float): Metric value.
167
+ """
168
+ utils.unwrap_module(self.G).set_truncation(
169
+ truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
170
+ self.G.eval()
171
+ features = []
172
+
173
+ if verbose:
174
+ progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
175
+ progress.write('FID: Gathering statistics for fakes...', step=False)
176
+
177
+ remaining = self.num_samples
178
+ for i in range(0, self.num_samples, self.batch_size):
179
+
180
+ latents, latent_labels = self.prior_generator(
181
+ batch_size=min(self.batch_size, remaining))
182
+ if latent_labels is not None and self.labels:
183
+ latent_labels = self.labels[i].to(self.device)
184
+ length = min(len(latents), len(latent_labels))
185
+ latents, latent_labels = latents[:length], latent_labels[:length]
186
+
187
+ with torch.no_grad():
188
+ fakes = self.G(latents, labels=latent_labels)
189
+
190
+ with torch.no_grad():
191
+ batch_features = self.fid_model(fakes)
192
+ batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
193
+ features.append(batch_features.cpu())
194
+
195
+ remaining -= len(latents)
196
+ progress.step()
197
+
198
+ if verbose:
199
+ progress.write('FID: Statistics for fakes gathered!', step=False)
200
+ progress.close()
201
+
202
+ features = torch.cat(features, dim=0).numpy()
203
+
204
+ mu_fake = np.mean(features, axis=0)
205
+ sigma_fake = np.cov(features, rowvar=False)
206
+
207
+ m = np.square(mu_fake - self.mu_real).sum()
208
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False)
209
+ dist = m + np.trace(sigma_fake + self.sigma_real - 2*s)
210
+ return float(np.real(dist))
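A hedged usage sketch of the FID class above; G (a trained stylegan2.models.Generator), prior (a stylegan2.utils.PriorGenerator on the same device) and dataset (an indexable dataset of images in [-1, 1]) are assumed to be set up elsewhere:

from stylegan2.metrics import fid

# Gathering real statistics happens in the constructor, fake statistics in evaluate().
metric = fid.FID(G, prior, dataset, num_samples=10000)
score = metric.evaluate(verbose=True)
print('FID: {:.2f}'.format(score))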
stylegan2/metrics/ppl.py ADDED
@@ -0,0 +1,229 @@
1
+ import warnings
2
+ import numbers
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ from .. import models, utils
8
+ from ..external_models import lpips
9
+
10
+
11
+ class PPL:
12
+ """
13
+ This class evaluates the PPL metric of a generator.
14
+ Arguments:
15
+ G (Generator)
16
+ prior_generator (PriorGenerator)
17
+ device (int, str, torch.device, optional): The device
18
+ to use for calculations. By default, the same device
19
+ is chosen as the parameters in `generator` reside on.
20
+ num_samples (int): Number of samples of reals and fakes
21
+ to gather statistics for, which are used for calculating
22
+ the metric. Default value is 50 000.
23
+ epsilon (float): Perturbation value. Default value is 1e-4.
24
+ use_dlatent (bool): Measure PPL against the dlatents instead
25
+ of the latents. Default value is True.
26
+ full_sampling (bool): Measure on a random interpolation between
27
+ two inputs. Default value is False.
28
+ crop (float, list, optional): Crop values that should be in the
29
+ range [0, 1] with 1 representing the entire data length.
30
+ If single value this will be the amount cropped from all
31
+ sides of the data. If a list of same length as number of
32
+ data dimensions, each crop is mirrored to both sides of
33
+ each respective dimension. If the length is 2 * number
34
+ of dimensions the crop values for the start and end of
35
+ a dimension may be different.
36
+ Example 1:
37
+ We have 1d data of length 10. We want to crop 1
38
+ from the start and end of the data. We then need
39
+ to use `crop=0.1` or `crop=[0.1]` or `crop=[0.1, 0.9]`.
40
+ Example 2:
41
+ We have 2d data (images) of size 10, 10 (height, width)
42
+ and we want to use only the top left quarter of the image
43
+ we would use `crop=[0, 0.5, 0, 0.5]`.
44
+ lpips_model (nn.Module): A model that returns the feature distance
45
+ between two inputs. Default value is the LPIPS VGG16 model.
46
+ lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
47
+ the data so that its smallest side is the same size as this
48
+ argument. Only has a default value of 256 if `lpips_model` is unspecified.
49
+ """
50
+ FFHQ_CROP = [1/8 * 3, 1/8 * 7, 1/8 * 2, 1/8 * 6]
51
+
52
+ def __init__(self,
53
+ G,
54
+ prior_generator,
55
+ device=None,
56
+ num_samples=50000,
57
+ epsilon=1e-4,
58
+ use_dlatent=True,
59
+ full_sampling=False,
60
+ crop=None,
61
+ lpips_model=None,
62
+ lpips_size=None):
63
+ device_ids = []
64
+ if isinstance(G, torch.nn.DataParallel):
65
+ device_ids = G.device_ids
66
+ G = utils.unwrap_module(G)
67
+ assert isinstance(G, models.Generator)
68
+ assert isinstance(prior_generator, utils.PriorGenerator)
69
+ if device is None:
70
+ device = next(G.parameters()).device
71
+ else:
72
+ device = torch.device(device)
73
+ assert torch.device(prior_generator.device) == device, \
74
+ 'Prior generator device ({}) '.format(torch.device(prior_generator.device)) + \
75
+ 'is not the same as the specified (or inferred from the model) ' + \
76
+ 'device ({}) for the PPL evaluation.'.format(device)
77
+ G.eval().to(device)
78
+ self.G_mapping = G.G_mapping
79
+ self.G_synthesis = G.G_synthesis
80
+ if device_ids:
81
+ self.G_mapping = torch.nn.DataParallel(self.G_mapping, device_ids=device_ids)
82
+ self.G_synthesis = torch.nn.DataParallel(self.G_synthesis, device_ids=device_ids)
83
+ self.prior_generator = prior_generator
84
+ self.device = device
85
+ self.num_samples = num_samples
86
+ self.epsilon = epsilon
87
+ self.use_dlatent = use_dlatent
88
+ self.full_sampling = full_sampling
89
+ self.crop = crop
90
+ self.batch_size = self.prior_generator.batch_size
91
+ if lpips_model is None:
92
+ warnings.warn(
93
+ 'Using default LPIPS distance metric based on VGG 16. ' + \
94
+ 'This metric will only work on image data where values are in ' + \
95
+ 'the range [-1, 1], please specify an lpips module if you want ' + \
96
+ 'to use other kinds of data formats.'
97
+ )
98
+ lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
99
+ if device_ids:
100
+ lpips_model = torch.nn.DataParallel(lpips_model, device_ids=device_ids)
101
+ lpips_size = lpips_size or 256
102
+ self.lpips_model = lpips_model.eval().to(device)
103
+ self.lpips_size = lpips_size
104
+
105
+ def _scale_for_lpips(self, data):
106
+ if not self.lpips_size:
107
+ return data
108
+ scale_factor = self.lpips_size / min(data.size()[2:])
109
+ if scale_factor == 1:
110
+ return data
111
+ mode = 'nearest'
112
+ if scale_factor < 1:
113
+ mode = 'area'
114
+ return F.interpolate(data, scale_factor=scale_factor, mode=mode)
115
+
116
+ def crop_data(self, data):
117
+ if not self.crop:
118
+ return data
119
+ dim = data.dim() - 2
120
+ if isinstance(self.crop, numbers.Number):
121
+ self.crop = [self.crop]
122
+ else:
123
+ self.crop = list(self.crop)
124
+ if len(self.crop) == 1:
125
+ self.crop = [self.crop[0], (1 if self.crop[0] < 1 else size) - self.crop[0]] * dim
126
+ if len(self.crop) == dim:
127
+ crop = self.crop
128
+ self.crop = []
129
+ for value in crop:
130
+ self.crop += [value, (1 if value < 1 else size) - value]
131
+ assert len(self.crop) == 2 * dim, 'Crop values has to be ' + \
132
+ 'a single value or a sequence of values of the same ' + \
133
+ 'size as number of dimensions of the data or twice of that.'
134
+ pre_index = [Ellipsis]
135
+ post_index = [slice(None, None, None) for _ in range(dim)]
136
+ for i in range(0, 2 * dim, 2):
137
+ j = i // 2
138
+ size = data.size(2 + j)
139
+ crop_min, crop_max = self.crop[i:i + 2]
140
+ if crop_max < 1:
141
+ crop_min, crop_max = crop_min * size, crop_max * size
142
+ crop_min, crop_max = max(0, int(crop_min)), min(size, int(crop_max))
143
+ dim_index = post_index.copy()
144
+ dim_index[j] = slice(crop_min, crop_max, None)
145
+ data = data[pre_index + dim_index]
146
+ return data
147
+
148
+ def prep_latents(self, latents):
149
+ if self.full_sampling:
150
+ lerp = utils.slerp
151
+ if self.use_dlatent:
152
+ lerp = utils.lerp
153
+ latents_a, latents_b = latents[:self.batch_size], latents[self.batch_size:]
154
+ latents = lerp(
155
+ latents_a,
156
+ latents_b,
157
+ torch.rand(
158
+ latents_a.size()[:-1],
159
+ dtype=latents_a.dtype,
160
+ device=latents_a.device
161
+ ).unsqueeze(-1)
162
+ )
163
+ return torch.cat([latents, latents + self.epsilon], dim=0)
164
+
165
+ def __call__(self, *args, **kwargs):
166
+ return self.evaluate(*args, **kwargs)
167
+
168
+ def evaluate(self, verbose=True):
169
+ """
170
+ Evaluate the PPL.
171
+ Arguments:
172
+ verbose (bool): Write progress to stdout.
173
+ Default value is True.
174
+ Returns:
175
+ ppl (float): Metric value.
176
+ """
177
+ distances = []
178
+ batch_size = self.batch_size
179
+ if self.full_sampling:
180
+ batch_size = 2 * batch_size
181
+
182
+ if verbose:
183
+ progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
184
+ progress.write('PPL: Evaluating metric...', step=False)
185
+
186
+ for _ in range(0, self.num_samples, self.batch_size):
187
+ utils.unwrap_module(self.G_synthesis).static_noise()
188
+
189
+ latents, latent_labels = self.prior_generator(batch_size=batch_size)
190
+ if latent_labels is not None and self.full_sampling:
191
+ # Labels should be the same for the first and second half of latents
192
+ latent_labels = latent_labels.view(2, -1)[0].repeat(2)
193
+
194
+ if self.use_dlatent:
195
+ with torch.no_grad():
196
+ dlatents = self.G_mapping(latents=latents, labels=latent_labels)
197
+ dlatents = self.prep_latents(dlatents)
198
+ else:
199
+ latents = self.prep_latents(latents)
200
+ with torch.no_grad():
201
+ dlatents = self.G_mapping(latents=latents, labels=latent_labels)
202
+
203
+ dlatents = dlatents.unsqueeze(1).repeat(1, len(utils.unwrap_module(self.G_synthesis)), 1)
204
+
205
+ with torch.no_grad():
206
+ output = self.G_synthesis(dlatents)
207
+
208
+ output = self.crop_data(output)
209
+ output = self._scale_for_lpips(output)
210
+
211
+ output_a, output_b = output[:self.batch_size], output[self.batch_size:]
212
+
213
+ with torch.no_grad():
214
+ dist = self.lpips_model(output_a, output_b)
215
+
216
+ distances.append(dist.cpu() * (1 / self.epsilon ** 2))
217
+
218
+ if verbose:
219
+ progress.step()
220
+
221
+ if verbose:
222
+ progress.write('PPL: Evaluated!', step=False)
223
+ progress.close()
224
+
225
+ distances = torch.cat(distances, dim=0).numpy()
226
+ lo = np.percentile(distances, 1, interpolation='lower')
227
+ hi = np.percentile(distances, 99, interpolation='higher')
228
+ filtered_distances = np.extract(np.logical_and(lo <= distances, distances <= hi), distances)
229
+ return float(np.mean(filtered_distances))
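A hedged usage sketch mirroring the FID example; G and prior are again assumed to be a trained Generator and a matching PriorGenerator, and the FFHQ_CROP preset defined above restricts the LPIPS comparison to the face region:

from stylegan2.metrics import ppl

metric = ppl.PPL(G, prior, num_samples=10000,
                 use_dlatent=True, full_sampling=False,
                 crop=ppl.PPL.FFHQ_CROP)
score = metric.evaluate(verbose=True)
print('PPL: {:.2f}'.format(score))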
stylegan2/models.py ADDED
@@ -0,0 +1,1230 @@
1
+ import copy
2
+ from collections import OrderedDict
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+
7
+ from . import modules, utils
8
+
9
+
10
+ class _BaseModel(nn.Module):
11
+ """
12
+ Adds some base functionality to models that inherit this class.
13
+ """
14
+
15
+ def __init__(self):
16
+ super(_BaseModel, self).__setattr__('kwargs', {})
17
+ super(_BaseModel, self).__setattr__('_defaults', {})
18
+ super(_BaseModel, self).__init__()
19
+
20
+ def _update_kwargs(self, **kwargs):
21
+ """
22
+ Update the current keyword arguments. Overrides any
23
+ default values set.
24
+ Arguments:
25
+ **kwargs: Keyword arguments
26
+ """
27
+ self.kwargs.update(**kwargs)
28
+
29
+ def _update_default_kwargs(self, **defaults):
30
+ """
31
+ Update the default values for keyword arguments.
32
+ Arguments:
33
+ **defaults: Keyword arguments
34
+ """
35
+ self._defaults.update(**defaults)
36
+
37
+ def __getattr__(self, name):
38
+ """
39
+ Try to get the keyword argument for this attribute.
40
+ If no keyword argument of this name exists, try to
41
+ get the attribute directly from this object instead.
42
+ Arguments:
43
+ name (str): Name of keyword argument or attribute.
44
+ Returns:
45
+ value
46
+ """
47
+ try:
48
+ return self.__getattribute__('kwargs')[name]
49
+ except KeyError:
50
+ try:
51
+ return self.__getattribute__('_defaults')[name]
52
+ except KeyError:
53
+ return super(_BaseModel, self).__getattr__(name)
54
+
55
+ def __setattr__(self, name, value):
56
+ """
57
+ Try to set the keyword argument for this attribute.
58
+ If no keyword argument of this name exists, set
59
+ the attribute directly for this object instead.
60
+ Arguments:
61
+ name (str): Name of keyword argument or attribute.
62
+ value
63
+ """
64
+ if name != '__dict__' and (name in self.kwargs or name in self._defaults):
65
+ self.kwargs[name] = value
66
+ else:
67
+ super(_BaseModel, self).__setattr__(name, value)
68
+
69
+ def __delattr__(self, name):
70
+ """
71
+ Try to delete the keyword argument for this attribute.
72
+ If no keyword argument of this name exists, delete
73
+ the attribute of this object instead.
74
+ Arguments:
75
+ name (str): Name of keyword argument or attribute.
76
+ """
77
+ deleted = False
78
+ if name in self.kwargs:
79
+ del self.kwargs[name]
80
+ deleted = True
81
+ if name in self._defaults:
82
+ del self._defaults[name]
83
+ deleted = True
84
+ if not deleted:
85
+ super(_BaseModel, self).__delattr__(name)
86
+
87
+ def clone(self):
88
+ """
89
+ Create a copy of this model.
90
+ Returns:
91
+ model_copy (nn.Module)
92
+ """
93
+ return copy.deepcopy(self)
94
+
95
+ def _get_state_dict(self):
96
+ """
97
+ Delegate function for getting the state dict.
98
+ Should be overridden if state dict has to be
99
+ fetched in abnormal way.
100
+ """
101
+ return self.state_dict()
102
+
103
+ def _set_state_dict(self, state_dict):
104
+ """
105
+ Delegate function for loading the state dict.
106
+ Should be overridden if state dict has to be
107
+ loaded in abnormal way.
108
+ """
109
+ self.load_state_dict(state_dict)
110
+
111
+ def _serialize(self, half=False):
112
+ """
113
+ Turn model arguments and weights into
114
+ a dict that can safely be pickled and unpickled.
115
+ Arguments:
116
+ half (bool): Save weights in half precision.
117
+ Default value is False.
118
+ """
119
+ state_dict = self._get_state_dict()
120
+ for key in state_dict.keys():
121
+ values = state_dict[key].cpu()
122
+ if torch.is_floating_point(values):
123
+ if half:
124
+ values = values.half()
125
+ else:
126
+ values = values.float()
127
+ state_dict[key] = values
128
+ return {
129
+ 'name': self.__class__.__name__,
130
+ 'kwargs': self.kwargs,
131
+ 'state_dict': state_dict
132
+ }
133
+
134
+ @classmethod
135
+ def load(cls, fpath, map_location='cpu'):
136
+ """
137
+ Load a model of this class.
138
+ Arguments:
139
+ fpath (str): File path of saved model.
140
+ map_location (str, int, torch.device): Weights and
141
+ buffers will be loaded into this device.
142
+ Default value is 'cpu'.
143
+ """
144
+ model = load(fpath, map_location=map_location)
145
+ assert isinstance(model, cls), 'Trying to load a `{}` '.format(type(model)) + \
146
+ 'model from {}.load()'.format(cls.__name__)
147
+ return model
148
+
149
+ def save(self, fpath, half=False):
150
+ """
151
+ Save this model.
152
+ Arguments:
153
+ fpath (str): File path of save location.
154
+ half (bool): Save weights in half precision.
155
+ Default value is False.
156
+ """
157
+ torch.save(self._serialize(half=half), fpath)
158
+
159
+
160
+ def _deserialize(state):
161
+ """
162
+ Load a model from its serialized state.
163
+ Arguments:
164
+ state (dict)
165
+ Returns:
166
+ model (nn.Module): Model that inherits `_BaseModel`.
167
+ """
168
+ state = state.copy()
169
+ name = state.pop('name')
170
+ if name not in globals():
171
+ raise NameError('Class {} is not defined.'.format(name))
172
+ kwargs = state.pop('kwargs')
173
+ state_dict = state.pop('state_dict')
174
+ # Assume every other entry in the state is a serialized
175
+ # keyword argument.
176
+ for key in list(state.keys()):
177
+ kwargs[key] = _deserialize(state.pop(key))
178
+ model = globals()[name](**kwargs)
179
+ model._set_state_dict(state_dict)
180
+ return model
181
+
182
+
183
+ def load(fpath, map_location='cpu'):
184
+ """
185
+ Load a model.
186
+ Arguments:
187
+ fpath (str): File path of saved model.
188
+ map_location (str, int, torch.device): Weights and
189
+ buffers will be loaded into this device.
190
+ Default value is 'cpu'.
191
+ Returns:
192
+ model (nn.Module): Model that inherits `_BaseModel`.
193
+ """
194
+ if map_location is not None:
195
+ map_location = torch.device(map_location)
196
+ return _deserialize(torch.load(fpath, map_location=map_location))
197
+
198
+
199
+ def save(model, fpath, half=False):
200
+ """
201
+ Save a model.
202
+ Arguments:
203
+ model (nn.Module): Wrapped or unwrapped module
204
+ that inherits `_BaseModel`.
205
+ fpath (str): File path of save location.
206
+ half (bool): Save weights in half precision.
207
+ Default value is False.
208
+ """
209
+ utils.unwrap_module(model).save(fpath, half=half)
210
+
211
+
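A short sketch of the round trip these helpers provide: a model is serialized as its class name, constructor kwargs and state dict, so load() can rebuild the correct class by name. G is an assumed Generator instance and the file name is hypothetical:

from stylegan2 import models

models.save(G, 'generator.pth', half=True)   # store weights in half precision
G_restored = models.load('generator.pth', map_location='cpu')
assert isinstance(G_restored, models.Generator)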
212
+ class Generator(_BaseModel):
213
+ """
214
+ A wrapper class for the latent mapping model
215
+ and synthesis (generator) model.
216
+ Keyword Arguments:
217
+ G_mapping (GeneratorMapping)
218
+ G_synthesis (GeneratorSynthesis)
219
+ dlatent_avg_beta (float): The beta value
220
+ of the exponential moving average
221
+ of the dlatents. This statistic
222
+ is used for truncation of dlatents.
223
+ Default value is 0.995
224
+ """
225
+
226
+ def __init__(self, *, G_mapping, G_synthesis, **kwargs):
227
+ super(Generator, self).__init__()
228
+ self._update_default_kwargs(
229
+ dlatent_avg_beta=0.995
230
+ )
231
+ self._update_kwargs(**kwargs)
232
+
233
+ assert isinstance(G_mapping, GeneratorMapping), \
234
+ '`G_mapping` has to be an instance of `model.GeneratorMapping`'
235
+ assert isinstance(G_synthesis, GeneratorSynthesis), \
236
+ '`G_synthesis` has to be an instance of `model.GeneratorSynthesis`'
237
+ self.G_mapping = G_mapping
238
+ self.G_synthesis = G_synthesis
239
+ self.register_buffer('dlatent_avg', torch.zeros(self.G_mapping.latent_size))
240
+ self.set_truncation()
241
+
242
+ @property
243
+ def latent_size(self):
244
+ return self.G_mapping.latent_size
245
+
246
+ @property
247
+ def label_size(self):
248
+ return self.G_mapping.label_size
249
+
250
+ def _get_state_dict(self):
251
+ state_dict = OrderedDict()
252
+ self._save_to_state_dict(destination=state_dict, prefix='', keep_vars=False)
253
+ return state_dict
254
+
255
+ def _set_state_dict(self, state_dict):
256
+ self.load_state_dict(state_dict, strict=False)
257
+
258
+ def _serialize(self, half=False):
259
+ state = super(Generator, self)._serialize(half=half)
260
+ for name in ['G_mapping', 'G_synthesis']:
261
+ state[name] = getattr(self, name)._serialize(half=half)
262
+ return state
263
+
264
+ def set_truncation(self, truncation_psi=None, truncation_cutoff=None):
265
+ """
266
+ Set the truncation of dlatents before they are passed to the
267
+ synthesis model.
268
+ Arguments:
269
+ truncation_psi (float): Beta value of linear interpolation between
270
+ the average dlatent and the current dlatent. 0 -> 100% average,
271
+ 1 -> 0% average.
272
+ truncation_cutoff (int, optional): Truncation is only used up until
273
+ this affine layer index.
274
+ """
275
+ layer_psi = None
276
+ if truncation_psi is not None and truncation_psi != 1 and truncation_cutoff != 0:
277
+ layer_psi = torch.ones(len(self.G_synthesis))
278
+ if truncation_cutoff is None:
279
+ layer_psi *= truncation_psi
280
+ else:
281
+ layer_psi_mask = torch.arange(len(layer_psi)) < truncation_cutoff
282
+ layer_psi[layer_psi_mask] *= truncation_psi
283
+ layer_psi = layer_psi.view(1, -1, 1)
284
+ layer_psi = layer_psi.to(self.dlatent_avg)
285
+ self.register_buffer('layer_psi', layer_psi)
286
+
287
+ def random_noise(self):
288
+ """
289
+ Set noise of synthesis model to be random for every
290
+ input.
291
+ """
292
+ self.G_synthesis.random_noise()
293
+
294
+ def static_noise(self, trainable=False, noise_tensors=None):
295
+ """
296
+ Set up injected noise to be fixed (alternatively trainable).
297
+ Get the fixed noise tensors (or parameters).
298
+ Arguments:
299
+ trainable (bool): Make noise trainable and return
300
+ parameters instead of normal tensors.
301
+ noise_tensors (list, optional): List of tensors to use as static noise.
302
+ Has to be same length as number of noise injection layers.
303
+ Returns:
304
+ noise_tensors (list): List of the noise tensors (or parameters).
305
+ """
306
+ return self.G_synthesis.static_noise(trainable=trainable, noise_tensors=noise_tensors)
307
+
308
+ def __len__(self):
309
+ """
310
+ Get the number of affine (style) layers of the synthesis model.
311
+ """
312
+ return len(self.G_synthesis)
313
+
314
+ def truncate(self, dlatents):
315
+ """
316
+ Truncate the dlatents.
317
+ Arguments:
318
+ dlatents (torch.Tensor)
319
+ Returns:
320
+ truncated_dlatents (torch.Tensor)
321
+ """
322
+ if self.layer_psi is not None:
323
+ dlatents = utils.lerp(self.dlatent_avg, dlatents, self.layer_psi)
324
+ return dlatents
325
+
326
+ def forward(self,
327
+ latents=None,
328
+ labels=None,
329
+ dlatents=None,
330
+ return_dlatents=False,
331
+ mapping_grad=True,
332
+ latent_to_layer_idx=None):
333
+ """
334
+ Synthesize some data from latent inputs. The latents
335
+ can have an extra optional dimension, where latents
336
+ from this dimension will be distributed to the different
337
+ affine layers of the synthesis model. The distribution
338
+ is a index to index mapping if the amount of latents
339
+ is the same as the number of affine layers. Otherwise,
340
+ latents are distributed consecutively for a random
341
+ number of layers before the next latent is used for
342
+ some random amount of following layers. If the size
343
+ of this extra dimension is 1 or it does not exist,
344
+ the same latent is passed to every affine layer.
345
+
346
+ Latents are first mapped to disentangled latents (`dlatents`)
347
+ and are then optionally truncated (if model is in eval mode
348
+ and truncation options have been set.) Set up truncation by
349
+ calling `set_truncation()`.
350
+ Arguments:
351
+ latents (torch.Tensor): The latent values of shape
352
+ (batch_size, N, num_features) where N is an
353
+ optional dimension. This argument is not required
354
+ if `dlatents` is passed.
355
+ labels (optional): A sequence of labels, one for
356
+ each index in the batch dimension of the input.
357
+ dlatents (torch.Tensor, optional): Skip the latent
358
+ mapping model and feed these dlatents straight
359
+ to the synthesis model. The same type of distribution
360
+ to affine layers as is described in this function
361
+ description is also used for dlatents.
362
+ NOTE: Explicitly passing dlatents to this function
363
+ will stop them from being truncated. If required,
364
+ do this manually by calling the `truncate()` function
365
+ of this model.
366
+ return_dlatents (bool): Return not only the synthesized
367
+ data, but also the dlatents. The dlatents tensor
368
+ will also have its `requires_grad` set to True
369
+ before being passed to the synthesis model for
370
+ use with pathlength regularization during training.
371
+ This requires training to be enabled (`thismodel.train()`).
372
+ Default value is False.
373
+ mapping_grad (bool): Let gradients be calculated when passing
374
+ latents through the latent mapping model. Should be
375
+ set to False when only optimising the synthesiser parameters.
376
+ Default value is True.
377
+ latent_to_layer_idx (list, tuple, optional): A manual mapping
378
+ of the latent vectors to the affine layers of this network.
379
+ Each position in this sequence maps the affine layer of the
380
+ same index to an index of the latents. The latents should
381
+ have a shape of (batch_size, N, num_features) and this argument
382
+ should be a list of the same length as number of affine layers
383
+ in this model (can be found by calling len(thismodel)) with values
384
+ in the range [0, N - 1]. Without this argument, latents are distributed
385
+ according to this function description.
386
+ """
387
+ # Keep track of number of latents for each batch index.
388
+ num_latents = 1
389
+
390
+ # Keep track of if dlatent truncation is enabled or disabled.
391
+ truncate = False
392
+
393
+ if dlatents is None:
394
+ # Calculate dlatents
395
+
396
+ # dlatent truncation enabled as dlatents were not explicitly given
397
+ truncate = True
398
+
399
+ assert latents is not None, 'Either the `latents` ' + \
400
+ 'or the `dlatents` argument is required.'
401
+ if labels is not None:
402
+ if not torch.is_tensor(labels):
403
+ labels = torch.tensor(labels, dtype=torch.int64)
404
+
405
+ # If latents are passed with the layer dimension we need
406
+ # to flatten it to shape (N, latent_size) before passing
407
+ # it to the latent mapping model.
408
+ if latents.dim() == 3:
409
+ num_latents = latents.size(1)
410
+ latents = latents.view(-1, latents.size(-1))
411
+ # Labels need to repeated for the extra dimension of latents.
412
+ if labels is not None:
413
+ labels = labels.unsqueeze(1).repeat(1, num_latents).view(-1)
414
+
415
+ # Don't allow this operation to create a computation graph for
416
+ # backprop unless specified. This is useful for pathreg as it
417
+ # only regularizes the parameters of the synthesiser and not
418
+ # the latent mapping model.
419
+ with torch.set_grad_enabled(mapping_grad):
420
+ dlatents = self.G_mapping(latents=latents, labels=labels)
421
+ else:
422
+ if dlatents.dim() == 3:
423
+ num_latents = dlatents.size(1)
424
+
425
+ # Now we expand/repeat the number of latents per batch index until it is
426
+ # the same number as affine layers in our synthesis model.
427
+ dlatents = dlatents.view(-1, num_latents, dlatents.size(-1))
428
+ if num_latents == 1:
429
+ dlatents = dlatents.expand(
430
+ dlatents.size(0), len(self), dlatents.size(2))
431
+ elif num_latents != len(self):
432
+ assert dlatents.size(1) <= len(self), \
433
+ 'More latents ({}) than number '.format(dlatents.size(1)) + \
434
+ 'of generator layers ({}) received.'.format(len(self))
435
+ if not latent_to_layer_idx:
436
+ # Lets randomly distribute the latents to
437
+ # ranges of layers (each latent is assigned
438
+ # to a random number of consecutive layers).
439
+ cutoffs = np.random.choice(
440
+ np.arange(1, len(self)),
441
+ dlatents.size(1) - 1,
442
+ replace=False
443
+ )
444
+ cutoffs = [0] + sorted(cutoffs.tolist()) + [len(self)]
445
+ dlatents = [
446
+ dlatents[:, i].unsqueeze(1).expand(
447
+ -1, cutoffs[i + 1] - cutoffs[i], dlatents.size(2))
448
+ for i in range(dlatents.size(1))
449
+ ]
450
+ dlatents = torch.cat(dlatents, dim=1)
451
+ else:
452
+ # Assign latents as specified by argument
453
+ assert len(latent_to_layer_idx) == len(self), \
454
+ 'The latent index to layer index mapping does ' + \
455
+ 'not have the same number of elements ' + \
456
+ '({}) as the number of '.format(len(latent_to_layer_idx)) + \
457
+ 'generator layers ({})'.format(len(self))
458
+ dlatents = dlatents[:, latent_to_layer_idx]
459
+
460
+ # Update moving average of dlatents when training
461
+ if self.training and self.dlatent_avg_beta != 1:
462
+ with torch.no_grad():
463
+ batch_dlatent_avg = dlatents[:, 0].mean(dim=0)
464
+ self.dlatent_avg = utils.lerp(
465
+ batch_dlatent_avg, self.dlatent_avg, self.dlatent_avg_beta)
466
+
467
+ # Truncation is only applied when dlatents are not explicitly
468
+ # given and the model is in evaluation mode.
469
+ if truncate and not self.training:
470
+ dlatents = self.truncate(dlatents)
471
+
472
+ # One of the reasons we might want to return the dlatents is for
473
+ # pathreg, in which case the dlatents need to require gradients
474
+ # before being passed to the synthesiser. This should only be
475
+ # the case when the model is in training mode.
476
+ if return_dlatents and self.training:
477
+ dlatents.requires_grad_(True)
478
+
479
+ synth = self.G_synthesis(latents=dlatents)
480
+ if return_dlatents:
481
+ return synth, dlatents
482
+ return synth
483
+
484
+
485
+ # Base class for the parameterized models. This is used as parent
486
+ # class to reduce duplicate code and documentation for shared arguments.
487
+ class _BaseParameterizedModel(_BaseModel):
488
+ """
489
+ activation (str, callable, nn.Module): The non-linear
490
+ activation function to use.
491
+ Default value is leaky relu with a slope of 0.2.
492
+ lr_mul (float): The learning rate multiplier for this
493
+ model. When loading weights of previously trained
494
+ networks, this value has to be the same as when
495
+ the network was trained for the outputs to not
496
+ change (as this is used to scale the weights).
497
+ Default value depends on model type and can
498
+ be found in the original paper for StyleGAN.
499
+ weight_scale (bool): Use weight scaling for
500
+ equalized learning rate. Default value
501
+ is True.
502
+ eps (float): Epsilon value added for numerical stability.
503
+ Default value is 1e-8."""
504
+
505
+ def __init__(self, **kwargs):
506
+ super(_BaseParameterizedModel, self).__init__()
507
+ self._update_default_kwargs(
508
+ activation='lrelu:0.2',
509
+ lr_mul=1,
510
+ weight_scale=True,
511
+ eps=1e-8
512
+ )
513
+ self._update_kwargs(**kwargs)
514
+
515
+
516
+ class GeneratorMapping(_BaseParameterizedModel):
517
+ """
518
+ Latent mapping model, handles the
519
+ transformation of latents into disentangled
520
+ latents.
521
+ Keyword Arguments:
522
+ latent_size (int): The size of the latent vectors.
523
+ This will also be the size of the disentangled
524
+ latent vectors.
525
+ Default value is 512.
526
+ label_size (int, optional): The number of different
527
+ possible labels. Use for label conditioning of
528
+ the GAN. Unused by default.
529
+ out_size (int, optional): The size of the disentangled
530
+ latents output by this model. If not specified,
531
+ the outputs will have the same size as the input
532
+ latents.
533
+ num_layers (int): Number of dense layers in this
534
+ model. Default value is 8.
535
+ hidden (int, optional): Number of hidden features of layers.
536
+ If unspecified, this is the same size as the latents.
537
+ normalize_input (bool): Normalize the input of this
538
+ model. Default value is True."""
539
+ __doc__ += _BaseParameterizedModel.__doc__
540
+
541
+ def __init__(self, **kwargs):
542
+ super(GeneratorMapping, self).__init__()
543
+ self._update_default_kwargs(
544
+ latent_size=512,
545
+ label_size=0,
546
+ out_size=None,
547
+ num_layers=8,
548
+ hidden=None,
549
+ normalize_input=True,
550
+ lr_mul=0.01,
551
+ )
552
+ self._update_kwargs(**kwargs)
553
+
554
+ # Find in and out features of first dense layer
555
+ in_features = self.latent_size
556
+ out_features = self.hidden or self.latent_size
557
+
558
+ # Each class label has its own embedded vector representation.
559
+ self.embedding = None
560
+ if self.label_size:
561
+ self.embedding = nn.Embedding(self.label_size, self.latent_size)
562
+ # The input is now the latents concatenated with
563
+ # the label embeddings.
564
+ in_features += self.latent_size
565
+ dense_layers = []
566
+ for i in range(self.num_layers):
567
+ if i == self.num_layers - 1:
568
+ # Set out features for last dense layer
569
+ out_features = self.out_size or self.latent_size
570
+ dense_layers.append(
571
+ modules.BiasActivationWrapper(
572
+ layer=modules.DenseLayer(
573
+ in_features=in_features,
574
+ out_features=out_features,
575
+ lr_mul=self.lr_mul,
576
+ weight_scale=self.weight_scale,
577
+ gain=1
578
+ ),
579
+ features=out_features,
580
+ use_bias=True,
581
+ activation=self.activation,
582
+ bias_init=0,
583
+ lr_mul=self.lr_mul,
584
+ weight_scale=self.weight_scale
585
+ )
586
+ )
587
+ in_features = out_features
588
+ self.main = nn.Sequential(*dense_layers)
589
+
590
+ def forward(self, latents, labels=None):
591
+ """
592
+ Get the disentangled latents from the input latents
593
+ and optional labels.
594
+ Arguments:
595
+ latents (torch.Tensor): Tensor of shape (batch_size, latent_size).
596
+ labels (torch.Tensor, optional): Labels for conditioning of latents
597
+ if there are any.
598
+ Returns:
599
+ dlatents (torch.Tensor): Disentangled latents of same shape as
600
+ `latents` argument.
601
+ """
602
+ assert latents.dim() == 2 and latents.size(-1) == self.latent_size, \
603
+ 'Incorrect input shape. Should be ' + \
604
+ '(batch_size, {}) '.format(self.latent_size) + \
605
+ 'but received {}'.format(tuple(latents.size()))
606
+ x = latents
607
+ if labels is not None:
608
+ assert self.embedding is not None, \
609
+ 'No embedding layer found, please ' + \
610
+ 'specify the number of possible labels ' + \
611
+ 'in the constructor of this class if ' + \
612
+ 'using labels.'
613
+ assert len(labels) == len(latents), \
614
+ 'Received different number of labels ' + \
615
+ '({}) and latents ({}).'.format(len(labels), len(latents))
616
+ if not torch.is_tensor(labels):
617
+ labels = torch.tensor(labels, dtype=torch.int64)
618
+ assert labels.dtype == torch.int64, \
619
+ 'Labels should be integer values ' + \
620
+ 'of dtype torch.in64 (long)'
621
+ y = self.embedding(labels)
622
+ x = torch.cat([x, y], dim=-1)
623
+ else:
624
+ assert self.embedding is None, 'Missing input labels.'
625
+ if self.normalize_input:
626
+ x = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
627
+ return self.main(x)
628
+
629
+
630
+ # Base class for the synthesising and discriminating models. This is used as parent
631
+ # class to reduce duplicate code and documentation for shared arguments.
632
+ class _BaseAdverserialModel(_BaseParameterizedModel):
633
+ """
634
+ data_channels (int): Number of channels of the data.
635
+ Default value is 3.
636
+ base_shape (list, tuple): This is the shape of the feature
637
+ activations when it is most compact and still has the
638
+ same number of dims as the data. This is one of the
639
+ arguments that controls what shape the data will be.
640
+ the value of each size in the shape is going to double
641
+ in size for number of `channels` - 1.
642
+ Example:
643
+ `data_channels=3`
644
+ `base_shape=(4, 2)`
645
+ and 9 `channels` in total will give us a shape of
646
+ (3, 4 * 2^(9 - 1), 2 * 2^(9 - 1)) which is the
647
+ same as (3, 1024, 512).
648
+ Default value is (4, 4).
649
+ channels (int, list, optional): The channels of each block
650
+ of layers. If int, this many channel values will be
651
+ created with sensible default values optimal for image
652
+ synthesis. If list, the number of blocks in this model
653
+ will be the same as the number of channels in the list.
654
+ Default value is the int value 9 which will create the
655
+ following channels: [32, 32, 64, 128, 256, 512, 512, 512, 512].
656
+ These are the channel values used in the stylegan2 paper for
657
+ their FFHQ-trained face generation network.
658
+ If channels is given as a list it should be in the order:
659
+ Generator: last layer -> first layer
660
+ Discriminator: first layer -> last layer
661
+ resnet (bool): Use resnet connections.
662
+ Defaults:
663
+ Generator: False
664
+ Discriminator: True
665
+ skip (bool): Use skip connections for data.
666
+ Defaults:
667
+ Generator: True
668
+ Discriminator: False
669
+ fused_resample (bool): Fuse any up- or downsampling that
670
+ is paired with a convolutional layer into a strided
671
+ convolution (transposed if upsampling was used).
672
+ Default value is True.
673
+ conv_resample_mode (str): The resample mode of up- or
674
+ downsampling layers. If `fused_resample=True` only
675
+ 'FIR' and 'none' can be used. Else, 'FIR' or anything
676
+ that can be passed to torch.nn.functional.interpolate
677
+ is a valid mode (and 'max' but only for downsampling
678
+ operations). Default value is 'FIR'.
679
+ conv_filter (int, list): The filter to use if
680
+ `conv_resample_mode='FIR'`. If int, a low
681
+ pass filter of this size will be used. If list,
682
+ the filter is explicitly specified. If the filter
683
+ is of a single dimension it will be expanded to
684
+ the number of dimensions of the data. Default
685
+ value is a low pass filter of [1, 3, 3, 1].
686
+ skip_resample_mode (str): If `skip=True`, this
687
+ mode is used for the resamplings of skip
688
+ connections of different sizes. Same possible
689
+ values as `conv_filter` (except 'none', which
690
+ can not be used). Default value is 'FIR'.
691
+ skip_filter (int, list): Same description as
692
+ `conv_filter` but for skip connections.
693
+ Only used if `skip_resample_mode='FIR'` and
694
+ `skip=True`. Default value is a low pass
695
+ filter of [1, 3, 3, 1].
696
+ kernel_size (int): The size of the convolutional kernels.
697
+ Default value is 3.
698
+ conv_pad_mode (str): The padding mode for convolutional
699
+ layers. Has to be one of 'constant', 'reflect',
700
+ 'replicate' or 'circular'. Default value is
701
+ 'constant'.
702
+ conv_pad_constant (float): The value to use for conv
703
+ padding if `conv_pad_mode='constant'`. Default
704
+ value is 0.
705
+ filter_pad_mode (str): The padding mode for FIR
706
+ filters. Same possible values as `conv_pad_mode`.
707
+ Default value is 'constant'.
708
+ filter_pad_constant (float): The value to use for FIR
709
+ padding if `filter_pad_mode='constant'`. Default
710
+ value is 0.
711
+ pad_once (bool): If FIR filter is used in conjunction with a
712
+ conv layer, do all the padding for both convolution and
713
+ FIR in the FIR layer instead of once per layer.
714
+ Default value is True.
715
+ conv_block_size (int): The number of conv layers in
716
+ each conv block. Default value is 2."""
717
+ __doc__ += _BaseParameterizedModel.__doc__
718
+
719
+ def __init__(self, **kwargs):
720
+ super(_BaseAdverserialModel, self).__init__()
721
+ self._update_default_kwargs(
722
+ data_channels=3,
723
+ base_shape=(4, 4),
724
+ channels=9,
725
+ resnet=False,
726
+ skip=False,
727
+ fused_resample=True,
728
+ conv_resample_mode='FIR',
729
+ conv_filter=[1, 3, 3, 1],
730
+ skip_resample_mode='FIR',
731
+ skip_filter=[1, 3, 3, 1],
732
+ kernel_size=3,
733
+ conv_pad_mode='constant',
734
+ conv_pad_constant=0,
735
+ filter_pad_mode='constant',
736
+ filter_pad_constant=0,
737
+ pad_once=True,
738
+ conv_block_size=2,
739
+ )
740
+ self._update_kwargs(**kwargs)
741
+
742
+ self.dim = len(self.base_shape)
743
+ assert 1 <= self.dim <= 3, '`base_shape` can only have 1, 2 or 3 dimensions.'
744
+ if isinstance(self.channels, int):
745
+ # Create the specified number of channel values with sensible
746
+ # sizes (these values do well for image synthesis).
747
+ num_channels = self.channels
748
+ self.channels = [min(32 * 2 ** i, 512) for i in range(min(8, num_channels))]
749
+ if len(self.channels) < num_channels:
750
+ self.channels = [32] * (num_channels - len(self.channels)) + self.channels
751
+
752
+
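A small worked example of the shape bookkeeping described in the docstring above: with the default integer value for `channels`, the channel list is built exactly as in the constructor, and each `base_shape` entry doubles once per additional block:

num_channels = 9
channels = [min(32 * 2 ** i, 512) for i in range(min(8, num_channels))]
channels = [32] * (num_channels - len(channels)) + channels
# -> [32, 32, 64, 128, 256, 512, 512, 512, 512]

base_shape = (4, 4)
data_shape = tuple(s * 2 ** (num_channels - 1) for s in base_shape)
# -> (1024, 1024), i.e. generated data of shape (data_channels, 1024, 1024)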
753
+ class GeneratorSynthesis(_BaseAdverserialModel):
754
+ """
755
+ The synthesis model that takes latents and synthesises
756
+ some data.
757
+ Keyword Arguments:
758
+ latent_size (int): The size of the latent vectors.
759
+ This will also be the size of the disentangled
760
+ latent vectors.
761
+ Default value is 512.
762
+ demodulate (bool): Normalize feature outputs from conv
763
+ layers. Default value is True.
764
+ modulate_data_out (bool): Apply style to the data output
765
+ layers. These layers are projections of the feature
766
+ maps into the space of the data. Default value is True.
767
+ noise (bool): Add noise after each conv style layer.
768
+ Default value is True."""
769
+ __doc__ += _BaseAdverserialModel.__doc__
770
+
771
+ def __init__(self, **kwargs):
772
+ super(GeneratorSynthesis, self).__init__()
773
+ self._update_default_kwargs(
774
+ latent_size=512,
775
+ demodulate=True,
776
+ modulate_data_out=True,
777
+ noise=True,
778
+ resnet=False,
779
+ skip=True
780
+ )
781
+ self._update_kwargs(**kwargs)
782
+
783
+ # The constant input of the model has no activations
784
+ # normalization, it is just passed straight to the first
785
+ # layer of the model.
786
+ self.const = torch.nn.Parameter(
787
+ torch.empty(self.channels[-1], *self.base_shape).normal_()
788
+ )
789
+ conv_block_kwargs = dict(
790
+ latent_size=self.latent_size,
791
+ demodulate=self.demodulate,
792
+ resnet=self.resnet,
793
+ up=True,
794
+ num_layers=self.conv_block_size,
795
+ filter=self.conv_filter,
796
+ activation=self.activation,
797
+ mode=self.conv_resample_mode,
798
+ fused=self.fused_resample,
799
+ kernel_size=self.kernel_size,
800
+ pad_mode=self.conv_pad_mode,
801
+ pad_constant=self.conv_pad_constant,
802
+ filter_pad_mode=self.filter_pad_mode,
803
+ filter_pad_constant=self.filter_pad_constant,
804
+ pad_once=self.pad_once,
805
+ noise=self.noise,
806
+ lr_mul=self.lr_mul,
807
+ weight_scale=self.weight_scale,
808
+ gain=1,
809
+ dim=self.dim,
810
+ eps=self.eps
811
+ )
812
+ self.conv_blocks = nn.ModuleList()
813
+
814
+ # The first convolutional layer is slightly different
815
+ # from the following convolutional blocks but can still
816
+ # be represented as a convolutional block if we change
817
+ # some of its arguments.
818
+ self.conv_blocks.append(
819
+ modules.GeneratorConvBlock(
820
+ **{
821
+ **conv_block_kwargs,
822
+ 'in_channels': self.channels[-1],
823
+ 'out_channels': self.channels[-1],
824
+ 'resnet': False,
825
+ 'up': False,
826
+ 'num_layers': 1
827
+ }
828
+ )
829
+ )
830
+
831
+ # The rest of the convolutional blocks all look the same
832
+ # except for number of input and output channels
833
+ for i in range(1, len(self.channels)):
834
+ self.conv_blocks.append(
835
+ modules.GeneratorConvBlock(
836
+ in_channels=self.channels[-i],
837
+ out_channels=self.channels[-i - 1],
838
+ **conv_block_kwargs
839
+ )
840
+ )
841
+
842
+ # If not using the skip architecture, only one
843
+ # layer will project the feature maps into
844
+ # the space of the data (from the activations of
845
+ # the last convolutional block). If using the skip
846
+ # architecture, every block will have its
847
+ # own projection layer instead.
848
+ self.to_data_layers = nn.ModuleList()
849
+ for i in range(1, len(self.channels) + 1):
850
+ to_data = None
851
+ if i == len(self.channels) or self.skip:
852
+ to_data = modules.BiasActivationWrapper(
853
+ layer=modules.ConvLayer(
854
+ **{
855
+ **conv_block_kwargs,
856
+ 'in_channels': self.channels[-i],
857
+ 'out_channels': self.data_channels,
858
+ 'modulate': self.modulate_data_out,
859
+ 'demodulate': False,
860
+ 'kernel_size': 1
861
+ }
862
+ ),
863
+ **{
864
+ **conv_block_kwargs,
865
+ 'features': self.data_channels,
866
+ 'use_bias': True,
867
+ 'activation': 'linear',
868
+ 'bias_init': 0
869
+ }
870
+ )
871
+ self.to_data_layers.append(to_data)
872
+
873
+ # When the skip architecture is used we need to
874
+ # upsample data outputs of previous convolutional
875
+ # blocks so that it can be added to the data output
876
+ # of the current convolutional block.
877
+ self.upsample = None
878
+ if self.skip:
879
+ self.upsample = modules.Upsample(
880
+ mode=self.skip_resample_mode,
881
+ filter=self.skip_filter,
882
+ filter_pad_mode=self.filter_pad_mode,
883
+ filter_pad_constant=self.filter_pad_constant,
884
+ gain=1,
885
+ dim=self.dim
886
+ )
887
+
888
+ # Calculate the number of latents required
889
+ # in the input.
890
+ self._num_latents = 1 + self.conv_block_size * (len(self.channels) - 1)
891
+ # Only the final data output layer uses
892
+ # its own latent input when being modulated.
893
+ # The other data output layers recycles latents
894
+ # from the next convolutional block.
895
+ if self.modulate_data_out:
896
+ self._num_latents += 1
897
+
898
+ def __len__(self):
899
+ """
900
+ Get the number of affine (style) layers of this model.
901
+ """
902
+ return self._num_latents
903
+
904
+ def random_noise(self):
905
+ """
906
+ Set injected noise to be random for each new input.
907
+ """
908
+ for module in self.modules():
909
+ if isinstance(module, modules.NoiseInjectionWrapper):
910
+ module.random_noise()
911
+
912
+ def static_noise(self, trainable=False, noise_tensors=None):
913
+ """
914
+ Set up injected noise to be fixed (alternatively trainable).
915
+ Get the fixed noise tensors (or parameters).
916
+ Arguments:
917
+ trainable (bool): Make noise trainable and return
918
+ parameters instead of normal tensors.
919
+ noise_tensors (list, optional): List of tensors to use as static noise.
920
+ Has to be same length as number of noise injection layers.
921
+ Returns:
922
+ noise_tensors (list): List of the noise tensors (or parameters).
923
+ """
924
+ rtn_tensors = []
925
+
926
+ if not self.noise:
927
+ return rtn_tensors
928
+
929
+ for module in self.modules():
930
+ if isinstance(module, modules.NoiseInjectionWrapper):
931
+ has_noise_shape = module.has_noise_shape()
932
+ device = module.weight.device
933
+ dtype = module.weight.dtype
934
+ break
935
+
936
+ # If the noise layers don't have the shape that the noise should be,
937
+ # we first need to pass some data through the network once for
938
+ # these layers to record the shape. To create noise tensors
939
+ # we need to know what size they should be.
940
+ if not has_noise_shape:
941
+ with torch.no_grad():
942
+ self(torch.zeros(
943
+ 1, len(self), self.latent_size, device=device, dtype=dtype))
944
+
945
+ i = 0
946
+ for block in self.conv_blocks:
947
+ for layer in block.conv_block:
948
+ for module in layer.modules():
949
+ if isinstance(module, modules.NoiseInjectionWrapper):
950
+ noise_tensor = None
951
+ if noise_tensors is not None:
952
+ if i < len(noise_tensors):
953
+ noise_tensor = noise_tensors[i]
954
+ i += 1
955
+ else:
956
+ rtn_tensors.append(None)
957
+ continue
958
+ rtn_tensors.append(
959
+ module.static_noise(trainable=trainable, noise_tensor=noise_tensor))
960
+
961
+ if noise_tensors is not None:
962
+ assert len(rtn_tensors) == len(noise_tensors), \
963
+ 'Got a list of {} '.format(len(noise_tensors)) + \
964
+ 'noise tensors but there are ' + \
965
+ '{} noise layers in this model'.format(len(rtn_tensors))
966
+
967
+ return rtn_tensors
968
+
969
+ def forward(self, latents):
970
+ """
971
+ Synthesise some data from input latents.
972
+ Arguments:
973
+ latents (torch.Tensor): Latent vectors of shape
974
+ (batch_size, num_affine_layers, latent_size)
975
+ where num_affine_layers is the value returned
976
+ by __len__() of this class.
977
+ Returns:
978
+ synthesised (torch.Tensor): Synthesised data.
979
+ """
980
+ assert latents.dim() == 3 and latents.size(1) == len(self), \
981
+ 'Input mismatch, expected latents of shape ' + \
982
+ '(batch_size, {}, latent_size) '.format(len(self)) + \
983
+ 'but got {}.'.format(tuple(latents.size()))
984
+ # Declare our feature activations variable
985
+ # and give it the value of our const parameter with
986
+ # an added batch dimension.
987
+ x = self.const.unsqueeze(0)
988
+ # Declare our data (output) variable
989
+ y = None
990
+ # Start counting style layers used. This is used for specifying
991
+ # which latents should be passed to the current block in the loop.
992
+ layer_idx = 0
993
+ for block, to_data in zip(self.conv_blocks, self.to_data_layers):
994
+ # Get the latents for the style layers in this block.
995
+ block_latents = latents[:, layer_idx:layer_idx + len(block)]
996
+
997
+ x = block(input=x, latents=block_latents)
998
+
999
+ layer_idx += len(block)
1000
+
1001
+ # Upsample the data output of the previous block to fit
1002
+ # the data output size of this block so that they can
1003
+ # be added together. Only performed for 'skip' architectures.
1004
+ if self.upsample is not None and layer_idx < len(self):
1005
+ if y is not None:
1006
+ y = self.upsample(y)
1007
+
1008
+ # Combine the data output of this block with any previous
1009
+ # blocks outputs if using 'skip' architecture, else only
1010
+ # perform this operation for the very last block outputs.
1011
+ if to_data is not None:
1012
+ t = to_data(input=x, latent=latents[:, layer_idx])
1013
+ y = t if y is None else y + t
1014
+ return y
1015
+
1016
+
1017
+ class Discriminator(_BaseAdverserialModel):
1018
+ """
1019
+ The discriminator scores data inputs.
1020
+ Keyword Arguments:
1021
+ label_size (int, optional): The number of different
1022
+ possible labels. Use for label conditioning of
1023
+ the GAN. The discriminator will calculate scores
1024
+ for each possible label and only returns the score
1025
+ from the label passed with the input data. If no
1026
+ labels are used, only one score is calculated.
1027
+ Disabled by default.
1028
+ mbstd_group_size (int): Group size for minibatch std
1029
+ before the final conv layer. A value of 0 indicates
1030
+ not to use minibatch std, and a value of -1 indicates
1031
+ that the group should be over the entire batch.
1032
+ This is used for increasing variety of the outputs of
1033
+ the generator. Default value is 4.
1034
+ NOTE: Scores for the same data may vary depending
1035
+ on batch size when using a value of -1.
1036
+ NOTE: If a value > 0 is given, every input batch
1037
+ must have a size evenly divisible by this value.
1038
+ dense_hidden (int, optional): The number of hidden features
1039
+ of the first dense layer. By default, this is the same as
1040
+ the number of channels in the final conv layer."""
1041
+ __doc__ += _BaseAdverserialModel.__doc__
1042
+
1043
+ def __init__(self, **kwargs):
1044
+ super(Discriminator, self).__init__()
1045
+ self._update_default_kwargs(
1046
+ label_size=0,
1047
+ mbstd_group_size=4,
1048
+ dense_hidden=None,
1049
+ resnet=True,
1050
+ skip=False
1051
+ )
1052
+ self._update_kwargs(**kwargs)
1053
+
1054
+ conv_block_kwargs = dict(
1055
+ resnet=self.resnet,
1056
+ down=True,
1057
+ num_layers=self.conv_block_size,
1058
+ filter=self.conv_filter,
1059
+ activation=self.activation,
1060
+ mode=self.conv_resample_mode,
1061
+ fused=self.fused_resample,
1062
+ kernel_size=self.kernel_size,
1063
+ pad_mode=self.conv_pad_mode,
1064
+ pad_constant=self.conv_pad_constant,
1065
+ filter_pad_mode=self.filter_pad_mode,
1066
+ filter_pad_constant=self.filter_pad_constant,
1067
+ pad_once=self.pad_once,
1068
+ noise=False,
1069
+ lr_mul=self.lr_mul,
1070
+ weight_scale=self.weight_scale,
1071
+ gain=1,
1072
+ dim=self.dim,
1073
+ eps=self.eps
1074
+ )
1075
+ self.conv_blocks = nn.ModuleList()
1076
+
1077
+ # All but the last of the convolutional blocks look the same
1078
+ # except for number of input and output channels
1079
+ for i in range(len(self.channels) - 1):
1080
+ self.conv_blocks.append(
1081
+ modules.DiscriminatorConvBlock(
1082
+ in_channels=self.channels[i],
1083
+ out_channels=self.channels[i + 1],
1084
+ **conv_block_kwargs
1085
+ )
1086
+ )
1087
+
1088
+ # The final convolutional layer is slightly different
1089
+ # from the previous convolutional blocks but can still
1090
+ # be represented as a convolutional block if we change
1091
+ # some of its arguments and optionally add a minibatch
1092
+ # std layer before it.
1093
+ final_conv_block = []
1094
+ if self.mbstd_group_size:
1095
+ final_conv_block.append(
1096
+ modules.MinibatchStd(
1097
+ group_size=self.mbstd_group_size,
1098
+ eps=self.eps
1099
+ )
1100
+ )
1101
+ final_conv_block.append(
1102
+ modules.DiscriminatorConvBlock(
1103
+ **{
1104
+ **conv_block_kwargs,
1105
+ 'in_channels': self.channels[-1] + (1 if self.mbstd_group_size else 0),
1106
+ 'out_channels': self.channels[-1],
1107
+ 'resnet': False,
1108
+ 'down': False,
1109
+ 'num_layers': 1
1110
+ },
1111
+ )
1112
+ )
1113
+ self.conv_blocks.append(nn.Sequential(*final_conv_block))
1114
+
1115
+ # If not using the skip architecture, only one
1116
+ # layer will project the data into feature maps.
1117
+ # This would be performed only for the input data at
1118
+ # the first block.
1119
+ # If using the skip architecture, every block will
1120
+ # have its own projection layer instead.
1121
+ self.from_data_layers = nn.ModuleList()
1122
+ for i in range(len(self.channels)):
1123
+ from_data = None
1124
+ if i == 0 or self.skip:
1125
+ from_data = modules.BiasActivationWrapper(
1126
+ layer=modules.ConvLayer(
1127
+ **{
1128
+ **conv_block_kwargs,
1129
+ 'in_channels': self.data_channels,
1130
+ 'out_channels': self.channels[i],
1131
+ 'modulate': False,
1132
+ 'demodulate': False,
1133
+ 'kernel_size': 1
1134
+ }
1135
+ ),
1136
+ **{
1137
+ **conv_block_kwargs,
1138
+ 'features': self.channels[i],
1139
+ 'use_bias': True,
1140
+ 'activation': self.activation,
1141
+ 'bias_init': 0
1142
+ }
1143
+ )
1144
+ self.from_data_layers.append(from_data)
1145
+
1146
+ # When the skip architecture is used we need to
1147
+ # downsample the data input so that it has the same
1148
+ # size as the feature maps of each block so that it
1149
+ # can be projected and added to these feature maps.
1150
+ self.downsample = None
1151
+ if self.skip:
1152
+ self.downsample = modules.Downsample(
1153
+ mode=self.skip_resample_mode,
1154
+ filter=self.skip_filter,
1155
+ filter_pad_mode=self.filter_pad_mode,
1156
+ filter_pad_constant=self.filter_pad_constant,
1157
+ gain=1,
1158
+ dim=self.dim
1159
+ )
1160
+
1161
+ # The final layers are two dense layers that map
1162
+ # the features into score logits. If labels are
1163
+ # used, we instead output one score for each possible
1164
+ # class of the labels and then return the score for the
1165
+ # labeled class.
1166
+ dense_layers = []
1167
+ in_features = self.channels[-1] * np.prod(self.base_shape)
1168
+ out_features = self.dense_hidden or self.channels[-1]
1169
+ activation = self.activation
1170
+ for _ in range(2):
1171
+ dense_layers.append(
1172
+ modules.BiasActivationWrapper(
1173
+ layer=modules.DenseLayer(
1174
+ in_features=in_features,
1175
+ out_features=out_features,
1176
+ lr_mul=self.lr_mul,
1177
+ weight_scale=self.weight_scale,
1178
+ gain=1,
1179
+ ),
1180
+ features=out_features,
1181
+ activation=activation,
1182
+ use_bias=True,
1183
+ bias_init=0,
1184
+ lr_mul=self.lr_mul,
1185
+ weight_scale=self.weight_scale
1186
+ )
1187
+ )
1188
+ in_features = out_features
1189
+ out_features = max(1, self.label_size)
1190
+ activation = 'linear'
1191
+ self.dense = nn.Sequential(*dense_layers)
1192
+
1193
+ def forward(self, input, labels=None):
1194
+ """
1195
+ Takes some data and optionally its labels and
1196
+ produces one score logit per data input.
1197
+ Arguments:
1198
+ input (torch.Tensor)
1199
+ labels (torch.Tensor, list, optional)
1200
+ Returns:
1201
+ score_logits (torch.Tensor)
1202
+ """
1203
+ # Declare our feature activations variable.
1204
+ x = None
1205
+ # Declare our data (input) variable
1206
+ y = input
1207
+ for i, (block, from_data) in enumerate(zip(self.conv_blocks, self.from_data_layers)):
1208
+ # Combine the data input of this block with any previous
1209
+ # block output if using 'skip' architecture, else only
1210
+ # perform this operation as a way to create inputs for
1211
+ # the first block.
1212
+ if from_data is not None:
1213
+ t = from_data(y)
1214
+ x = t if x is None else x + t
1215
+
1216
+ x = block(input=x)
1217
+
1218
+ # Downsample the data input of this block to fit
1219
+ # the feature size of the output of this block so that they can
1220
+ # be added together. Only performed for 'skip' architectures.
1221
+ if self.downsample is not None and i != len(self.conv_blocks) - 1:
1222
+ y = self.downsample(y)
1223
+ # Calculate scores
1224
+ x = x.view(x.size(0), -1)
1225
+ x = self.dense(x)
1226
+ if labels is not None:
1227
+ # Use advanced indexing to fetch only the score of the
1228
+ # class labels.
1229
+ x = x[torch.arange(x.size(0)), labels].unsqueeze(-1)
1230
+ return x
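# Hypothetical usage sketch (only `torch` assumed): how the advanced indexing
# above picks one score logit per sample when labels are given. `scores` here
# stands in for the (batch_size, label_size) output of the dense layers.
import torch

scores = torch.randn(4, 10)              # 4 samples, 10 possible labels
labels = torch.tensor([0, 3, 9, 1])      # one class label per sample
picked = scores[torch.arange(4), labels].unsqueeze(-1)
print(picked.shape)                      # torch.Size([4, 1])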
stylegan2/modules.py ADDED
@@ -0,0 +1,1601 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def get_activation(activation):
8
+ """
9
+ Get the module for a specific activation function and its gain if
10
+ it can be calculated.
11
+ Arguments:
12
+ activation (str, callable, nn.Module): String representing the activation.
13
+ Returns:
14
+ activation_module (torch.nn.Module): The module representing
15
+ the activation function.
16
+ gain (float): The gain value. Defaults to 1 if it can not be calculated.
17
+ """
18
+ if isinstance(activation, nn.Module) or callable(activation):
19
+ return activation, 1.
20
+ if isinstance(activation, str):
21
+ activation = activation.lower()
22
+ if activation in [None, 'linear']:
23
+ return nn.Identity(), 1.
24
+ lrelu_strings = ('leaky', 'leakyrelu', 'leaky_relu', 'leaky relu', 'lrelu')
25
+ if activation.startswith(lrelu_strings):
26
+ for l_s in lrelu_strings:
27
+ activation = activation.replace(l_s, '')
28
+ slope = ''.join(
29
+ char for char in activation if char.isdigit() or char == '.')
30
+ slope = float(slope) if slope else 0.01
31
+ return nn.LeakyReLU(slope), np.sqrt(2) # close enough to true gain
32
+ elif activation.startswith('swish'):
33
+ return Swish(affine=activation != 'swish'), np.sqrt(2)
34
+ elif activation in ['relu']:
35
+ return nn.ReLU(), np.sqrt(2)
36
+ elif activation in ['elu']:
37
+ return nn.ELU(), 1.
38
+ elif activation in ['prelu']:
39
+ return nn.PReLU(), np.sqrt(2)
40
+ elif activation in ['rrelu', 'randomrelu']:
41
+ return nn.RReLU(), np.sqrt(2)
42
+ elif activation in ['selu']:
43
+ return nn.SELU(), 1.
44
+ elif activation in ['softplus']:
45
+ return nn.Softplus(), 1
46
+ elif activation in ['softsign']:
47
+ return nn.Softsign(), 1 # unsure about this gain
48
+ elif activation in ['sigmoid', 'logistic']:
49
+ return nn.Sigmoid(), 1.
50
+ elif activation in ['tanh']:
51
+ return nn.Tanh(), 1.
52
+ else:
53
+ raise ValueError(
54
+ 'Activation "{}" not available.'.format(activation)
55
+ )
56
+
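# Hypothetical usage sketch (assumes the definitions above): activation strings
# are parsed case-insensitively; for leaky relu the trailing digits become the
# negative slope and the returned gain is sqrt(2).
act, gain = get_activation('leaky:0.2')
print(act)    # LeakyReLU(negative_slope=0.2)
print(gain)   # ~1.414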
57
+
58
+ class Swish(nn.Module):
59
+ """
60
+ Performs the 'Swish' non-linear activation function.
61
+ https://arxiv.org/pdf/1710.05941.pdf
62
+ Arguments:
63
+ affine (bool): Multiply the input to sigmoid
64
+ with a learnable scale. Default value is False.
65
+ """
66
+ def __init__(self, affine=False):
67
+ super(Swish, self).__init__()
68
+ if affine:
69
+ self.beta = nn.Parameter(torch.tensor([1.]))
70
+ self.affine = affine
71
+
72
+ def forward(self, input, *args, **kwargs):
73
+ """
74
+ Apply the swish non-linear activation function
75
+ and return the results.
76
+ Arguments:
77
+ input (torch.Tensor)
78
+ Returns:
79
+ output (torch.Tensor)
80
+ """
81
+ x = input
82
+ if self.affine:
83
+ x *= self.beta
84
+ return x * torch.sigmoid(x)
85
+
86
+
87
+ def _get_weight_and_coef(shape, lr_mul=1, weight_scale=True, gain=1, fill=None):
88
+ """
89
+ Get an initialized weight (as an nn.Parameter) and its runtime coefficient.
90
+ Arguments:
91
+ shape (tuple, list): Shape of weight tensor.
92
+ lr_mul (float): The learning rate multiplier for
93
+ this weight. Default value is 1.
94
+ weight_scale (bool): Use weight scaling for equalized
95
+ learning rate. Default value is True.
96
+ gain (float): The gain of the weight. Default value is 1.
97
+ fill (float, optional): Instead of initializing the weight
98
+ with scaled normally distributed values, fill it with
99
+ this value. Useful for bias weights.
100
+ Returns:
101
+ weight (nn.Parameter)
102
+ """
103
+ fan_in = np.prod(shape[1:])
104
+ he_std = gain / np.sqrt(fan_in)
105
+
106
+ if weight_scale:
107
+ init_std = 1 / lr_mul
108
+ runtime_coef = he_std * lr_mul
109
+ else:
110
+ init_std = he_std / lr_mul
111
+ runtime_coef = lr_mul
112
+
113
+ weight = torch.empty(*shape)
114
+ if fill is None:
115
+ weight.normal_(0, init_std)
116
+ else:
117
+ weight.fill_(fill)
118
+ return nn.Parameter(weight), runtime_coef
119
+
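# Hypothetical usage sketch (assumes the definitions above): with weight
# scaling enabled the weight is stored with unit std and the He constant is
# applied at runtime instead, which gives the equalized learning rate.
w, coef = _get_weight_and_coef(shape=[512, 512], lr_mul=1, weight_scale=True)
print(float(w.std()))   # roughly 1.0
print(coef)             # he_std * lr_mul = 1 / sqrt(512), about 0.044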
120
+
121
+ def _apply_conv(input, *args, transpose=False, **kwargs):
122
+ """
123
+ Perform a 1d, 2d or 3d convolution with specified
124
+ positional and keyword arguments. Which type of
125
+ convolution that is used depends on shape of data.
126
+ convolution is used depends on the shape of the data.
127
+ input (torch.Tensor): The input data for the
128
+ convolution.
129
+ *args: Positional arguments for the convolution.
130
+ Keyword Arguments:
131
+ transpose (bool): Transpose the convolution.
132
+ Default value is False
133
+ **kwargs: Keyword arguments for the convolution.
134
+ """
135
+ dim = input.dim() - 2
136
+ conv_fn = getattr(
137
+ F, 'conv{}{}d'.format('_transpose' if transpose else '', dim))
138
+ return conv_fn(input=input, *args, **kwargs)
139
+
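# Hypothetical usage sketch (assumes the definitions above): the helper picks
# conv1d/conv2d/conv3d (or the transposed variant) from the number of data dims.
x = torch.randn(1, 3, 32, 32)                       # 2 spatial dims -> F.conv2d
w = torch.randn(8, 3, 3, 3)
print(_apply_conv(x, weight=w, padding=1).shape)    # torch.Size([1, 8, 32, 32])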
140
+
141
+ def _setup_mod_weight_for_t_conv(weight, in_channels, out_channels):
142
+ """
143
+ Reshape a modulated conv weight for use with a transposed convolution.
144
+ Arguments:
145
+ weight (torch.Tensor)
146
+ in_channels (int)
147
+ out_channels (int)
148
+ Returns:
149
+ reshaped_weight (torch.Tensor)
150
+ """
151
+ # [BO]I*k -> BOI*k
152
+ weight = weight.view(
153
+ -1,
154
+ out_channels,
155
+ in_channels,
156
+ *weight.size()[2:]
157
+ )
158
+ # BOI*k -> BIO*k
159
+ weight = weight.transpose(1, 2)
160
+ # BIO*k -> [BI]O*k
161
+ weight = weight.reshape(
162
+ -1,
163
+ out_channels,
164
+ *weight.size()[3:]
165
+ )
166
+ return weight
167
+
168
+
169
+ def _setup_filter_kernel(filter_kernel, gain=1, up_factor=1, dim=2):
170
+ """
171
+ Set up a filter kernel and return it as a tensor.
172
+ Arguments:
173
+ filter_kernel (int, list, torch.tensor, None): The filter kernel
174
+ values to use. If this value is an int, a binomial filter of
175
+ this size is created. If a sequence with a single axis is used,
176
+ it will be expanded to the number of `dims` specified. If value
177
+ is None, a filter of values [1, 1] is used.
178
+ gain (float): Gain of the filter kernel. Default value is 1.
179
+ up_factor (int): Scale factor. Should only be given for upscaling filters.
180
+ Default value is 1.
181
+ dim (int): Number of dimensions of data. Default value is 2.
182
+ Returns:
183
+ filter_kernel_tensor (torch.Tensor)
184
+ """
185
+ filter_kernel = filter_kernel or 2
186
+ if isinstance(filter_kernel, (int, float)):
187
+ def binomial(n, k):
188
+ if k in [1, n]:
189
+ return 1
190
+ return np.math.factorial(n) / (np.math.factorial(k) * np.math.factorial(n - k))
191
+ filter_kernel = [binomial(filter_kernel, k) for k in range(1, filter_kernel + 1)]
192
+ if not torch.is_tensor(filter_kernel):
193
+ filter_kernel = torch.tensor(filter_kernel)
194
+ filter_kernel = filter_kernel.float()
195
+ if filter_kernel.dim() == 1:
196
+ _filter_kernel = filter_kernel.unsqueeze(0)
197
+ while filter_kernel.dim() < dim:
198
+ filter_kernel = torch.matmul(
199
+ filter_kernel.unsqueeze(-1), _filter_kernel)
200
+ assert all(filter_kernel.size(0) == s for s in filter_kernel.size())
201
+ filter_kernel /= filter_kernel.sum()
202
+ filter_kernel *= gain * up_factor ** 2
203
+ return filter_kernel.float()
204
+
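# Hypothetical usage sketch (assumes the definitions above): a 1d filter is
# expanded to 2d by an outer product, normalized to sum to 1 and then scaled
# by gain * up_factor ** 2 (so this upsampling filter sums to 4).
k = _setup_filter_kernel([1, 3, 3, 1], gain=1, up_factor=2, dim=2)
print(k.shape)         # torch.Size([4, 4])
print(float(k.sum()))  # 4.0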
205
+
206
+ def _get_layer(layer_class, kwargs, wrap=False, noise=False):
207
+ """
208
+ Create a layer and wrap it in optional
209
+ noise and/or bias/activation layers.
210
+ Arguments:
211
+ layer_class: The class of the layer to construct.
212
+ kwargs (dict): The keyword arguments to use for constructing
213
+ the layer and optionally the bias/activation layer.
214
+ wrap (bool): Wrap the layer in an bias/activation layer and
215
+ optionally a noise injection layer. Default value is False.
216
+ noise (bool): Inject noise before the bias/activation wrapper.
217
+ This can only be done when `wrap=True`. Default value is False.
218
+ """
219
+ layer = layer_class(**kwargs)
220
+ if wrap:
221
+ if noise:
222
+ layer = NoiseInjectionWrapper(layer)
223
+ layer = BiasActivationWrapper(layer, **kwargs)
224
+ return layer
225
+
226
+
227
+ class BiasActivationWrapper(nn.Module):
228
+ """
229
+ Wrap a module to add bias and non-linear activation
230
+ to any outputs of that module.
231
+ Arguments:
232
+ layer (nn.Module): The module to wrap.
233
+ features (int, optional): The number of features
234
+ of the output of the `layer`. This argument
235
+ has to be specified if `use_bias=True`.
236
+ use_bias (bool): Add bias to the output.
237
+ Default value is True.
238
+ activation (str, nn.Module, callable, optional):
239
+ non-linear activation function to use.
240
+ Unused if not specified.
241
+ bias_init (float): Value to initialize bias
242
+ weight with. Default value is 0.
243
+ lr_mul (float): Learning rate multiplier of
244
+ the bias weight. Weights are scaled by
245
+ this value. Default value is 1.
246
+ weight_scale (bool): Scale weights for
247
+ equalized learning rate.
248
+ Default value is True.
249
+ """
250
+ def __init__(self,
251
+ layer,
252
+ features=None,
253
+ use_bias=True,
254
+ activation='linear',
255
+ bias_init=0,
256
+ lr_mul=1,
257
+ weight_scale=True,
258
+ *args,
259
+ **kwargs):
260
+ super(BiasActivationWrapper, self).__init__()
261
+ self.layer = layer
262
+ bias = None
263
+ bias_coef = None
264
+ if use_bias:
265
+ assert features, '`features` is required when using bias.'
266
+ bias, bias_coef = _get_weight_and_coef(
267
+ shape=[features],
268
+ lr_mul=lr_mul,
269
+ weight_scale=False,
270
+ fill=bias_init
271
+ )
272
+ self.register_parameter('bias', bias)
273
+ self.bias_coef = bias_coef
274
+ self.act, self.gain = get_activation(activation)
275
+
276
+ def forward(self, *args, **kwargs):
277
+ """
278
+ Forward all positional and keyword arguments
279
+ to the layer wrapped by this module and add
280
+ bias (if set) and run through non-linear activation
281
+ function (if set).
282
+ Arguments:
283
+ *args (positional arguments)
284
+ **kwargs (keyword arguments)
285
+ Returns:
286
+ output (torch.Tensor)
287
+ """
288
+ x = self.layer(*args, **kwargs)
289
+ if self.bias is not None:
290
+ bias = self.bias.view(1, -1, *[1] * (x.dim() - 2))
291
+ if self.bias_coef != 1:
292
+ bias = self.bias_coef * bias
293
+ x += bias
294
+ x = self.act(x)
295
+ if self.gain != 1:
296
+ x *= self.gain
297
+ return x
298
+
299
+ def extra_repr(self):
300
+ return 'bias={}'.format(self.bias is not None)
301
+
302
+
303
+ class NoiseInjectionWrapper(nn.Module):
304
+ """
305
+ Wrap a module to add noise scaled by a
306
+ learnable parameter to any outputs of the
307
+ wrapped module.
308
+ Noise is randomized for each output but can
309
+ be set to static noise by calling `static_noise()`
310
+ of this object. This can only be done once data
311
+ of this object. This can only be done after data
312
+ the shape of the static noise to create is known.
313
+ Check if the shape is known by calling `has_noise_shape()`.
314
+ Arguments:
315
+ layer (nn.Module): The module to wrap.
316
+ same_over_batch (bool): Repeat the same
317
+ noise values over the entire batch
318
+ instead of creating separate noise
319
+ values for each entry in the batch.
320
+ Default value is True.
321
+ """
322
+
323
+ def __init__(self, layer, same_over_batch=True):
324
+ super(NoiseInjectionWrapper, self).__init__()
325
+ self.layer = layer
326
+ self.weight = torch.nn.Parameter(torch.zeros(1))
327
+ self.register_buffer('noise_storage', None)
328
+ self.same_over_batch = same_over_batch
329
+ self.random_noise()
330
+
331
+ def has_noise_shape(self):
332
+ """
333
+ If this module has had data passed through it
334
+ the noise shape is known and this function returns
335
+ True. Else False.
336
+ Returns:
337
+ noise_shape_known (bool)
338
+ """
339
+ return self.noise_storage is not None
340
+
341
+ def random_noise(self):
342
+ """
343
+ Randomize noise for each
344
+ new output.
345
+ """
346
+ self._fixed_noise = False
347
+ if isinstance(self.noise_storage, nn.Parameter):
348
+ noise_storage = self.noise_storage
349
+ del self.noise_storage
350
+ self.register_buffer('noise_storage', noise_storage.data)
351
+
352
+ def static_noise(self, trainable=False, noise_tensor=None):
353
+ """
354
+ Set up static noise that can optionally be a trainable
355
+ parameter. Static noise does not change between inputs
356
+ unless the user has altered its values. Returns the tensor
357
+ object that stores the static noise.
358
+ Arguments:
359
+ trainable (bool): Wrap the static noise tensor in
360
+ nn.Parameter to make it trainable. The returned
361
+ tensor will be wrapped.
362
+ noise_tensor (torch.Tensor, optional): A predefined
363
+ static noise tensor. If not passed, one will be
364
+ created.
365
+ """
366
+ assert self.has_noise_shape(), \
367
+ 'Noise shape is unknown'
368
+ if noise_tensor is None:
369
+ noise_tensor = self.noise_storage
370
+ else:
371
+ noise_tensor = noise_tensor.to(self.weight)
372
+ if trainable and not isinstance(noise_tensor, nn.Parameter):
373
+ noise_tensor = nn.Parameter(noise_tensor)
374
+ if isinstance(self.noise_storage, nn.Parameter) and not trainable:
375
+ del self.noise_storage
376
+ self.register_buffer('noise_storage', noise_tensor)
377
+ else:
378
+ self.noise_storage = noise_tensor
379
+ self._fixed_noise = True
380
+ return noise_tensor
381
+
382
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
383
+ r"""Saves module state to `destination` dictionary, containing a state
384
+ of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`.
385
+
386
+ Overridden to ignore the noise storage buffer.
387
+
388
+ Arguments:
389
+ destination (dict): a dict where state will be stored
390
+ prefix (str): the prefix for parameters and buffers used in this
391
+ module
392
+ """
393
+ for name, param in self._parameters.items():
394
+ if name != 'noise_storage' and param is not None:
395
+ destination[prefix + name] = param if keep_vars else param.data
396
+ for name, buf in self._buffers.items():
397
+ if name != 'noise_storage' and buf is not None:
398
+ destination[prefix + name] = buf if keep_vars else buf.data
399
+
400
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
401
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
402
+ this module, but not its descendants. This is called on every submodule
403
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
404
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
405
+ For state dicts without metadata, :attr:`local_metadata` is empty.
406
+ Overridden to ignore noise storage buffer.
407
+ """
408
+ key = prefix + 'noise_storage'
409
+ if key in state_dict:
410
+ del state_dict[key]
411
+ return super(NoiseInjectionWrapper, self)._load_from_state_dict(
412
+ state_dict, prefix, *args, **kwargs)
413
+
414
+ def forward(self, *args, **kwargs):
415
+ """
416
+ Forward all positional and keyword arguments
417
+ to the layer wrapped by this module and add
418
+ noise to its outputs before returning them.
419
+ Arguments:
420
+ *args (positional arguments)
421
+ **kwargs (keyword arguments)
422
+ Returns:
423
+ output (torch.Tensor)
424
+ """
425
+ x = self.layer(*args, **kwargs)
426
+ noise_shape = list(x.size())
427
+ noise_shape[1] = 1
428
+ if self.same_over_batch:
429
+ noise_shape[0] = 1
430
+ if self.noise_storage is None or list(self.noise_storage.size()) != noise_shape:
431
+ if not self._fixed_noise:
432
+ self.noise_storage = torch.empty(
433
+ *noise_shape,
434
+ dtype=self.weight.dtype,
435
+ device=self.weight.device
436
+ )
437
+ else:
438
+ assert list(self.noise_storage.size()[2:]) == noise_shape[2:], \
439
+ 'A data size {} has been encountered, '.format(x.size()[2:]) + \
440
+ 'the static noise previously set up does ' + \
441
+ 'not match this size {}'.format(self.noise_storage.size()[2:])
442
+ assert self.noise_storage.size(0) == 1 or self.noise_storage.size(0) == x.size(0), \
443
+ 'Static noise batch size mismatch! ' + \
444
+ 'Noise batch size: {}, '.format(self.noise_storage.size(0)) + \
445
+ 'input batch size: {}'.format(x.size(0))
446
+ assert self.noise_storage.size(1) == 1 or self.noise_storage.size(1) == x.size(1), \
447
+ 'Static noise channel size mismatch! ' + \
448
+ 'Noise channel size: {}, '.format(self.noise_storage.size(1)) + \
449
+ 'input channel size: {}'.format(x.size(1))
450
+ if not self._fixed_noise:
451
+ self.noise_storage.normal_()
452
+ x += self.weight * self.noise_storage
453
+ return x
454
+
455
+ def extra_repr(self):
456
+ return 'static_noise={}'.format(self._fixed_noise)
457
+
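# Hypothetical usage sketch (assumes the definitions above): the noise shape is
# only known after one forward pass, so run data through the wrapper before
# switching from random to static noise. Wrapping nn.Conv2d is just an example.
layer = NoiseInjectionWrapper(nn.Conv2d(3, 8, 3, padding=1))
out = layer(torch.randn(2, 3, 16, 16))   # randomized noise, shape now known
assert layer.has_noise_shape()
noise = layer.static_noise()             # freeze the current noise values
print(noise.shape)                       # torch.Size([1, 1, 16, 16])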
458
+
459
+ class FilterLayer(nn.Module):
460
+ """
461
+ Apply a filter by using convolution.
462
+ Arguments:
463
+ filter_kernel (torch.Tensor): The filter kernel to use.
464
+ Should be of shape `dims * (k,)` where `k` is the
465
+ kernel size and `dims` is the number of data dimensions
466
+ (excluding batch and channel dimension).
467
+ stride (int): The stride of the convolution.
468
+ pad0 (int): Amount to pad start of each data dimension.
469
+ Default value is 0.
470
+ pad1 (int): Amount to pad end of each data dimension.
471
+ Default value is 0.
472
+ pad_mode (str): The padding mode. Default value is 'constant'.
473
+ pad_constant (float): The constant value to pad with if
474
+ `pad_mode='constant'`. Default value is 0.
475
+ """
476
+ def __init__(self,
477
+ filter_kernel,
478
+ stride=1,
479
+ pad0=0,
480
+ pad1=0,
481
+ pad_mode='constant',
482
+ pad_constant=0,
483
+ *args,
484
+ **kwargs):
485
+ super(FilterLayer, self).__init__()
486
+ dim = filter_kernel.dim()
487
+ filter_kernel = filter_kernel.view(1, 1, *filter_kernel.size())
488
+ self.register_buffer('filter_kernel', filter_kernel)
489
+ self.stride = stride
490
+ if pad0 == pad1 and (pad0 == 0 or pad_mode == 'constant' and pad_constant == 0):
491
+ self.fused_pad = True
492
+ self.padding = pad0
493
+ else:
494
+ self.fused_pad = False
495
+ self.padding = [pad0, pad1] * dim
496
+ self.pad_mode = pad_mode
497
+ self.pad_constant = pad_constant
498
+
499
+ def forward(self, input, **kwargs):
500
+ """
501
+ Pad the input and run the filter over it
502
+ before returning the new values.
503
+ Arguments:
504
+ input (torch.Tensor)
505
+ Returns:
506
+ output (torch.Tensor)
507
+ """
508
+ x = input
509
+ conv_kwargs = dict(
510
+ weight=self.filter_kernel.repeat(
511
+ input.size(1), *[1] * (self.filter_kernel.dim() - 1)),
512
+ stride=self.stride,
513
+ groups=input.size(1),
514
+ )
515
+ if self.fused_pad:
516
+ conv_kwargs.update(padding=self.padding)
517
+ else:
518
+ x = F.pad(x, self.padding, mode=self.pad_mode, value=self.pad_constant)
519
+ return _apply_conv(
520
+ input=x,
521
+ transpose=False,
522
+ **conv_kwargs
523
+ )
524
+
525
+ def extra_repr(self):
526
+ return 'filter_size={}, stride={}'.format(
527
+ tuple(self.filter_kernel.size()[2:]), self.stride)
528
+
529
+
530
+ class Upsample(nn.Module):
531
+ """
532
+ Performs upsampling without learnable parameters that doubles
533
+ the size of the data.
534
+ Arguments:
535
+ mode (str): 'FIR' or one of the valid modes
536
+ that can be passed to torch.nn.functional.interpolate().
537
+ filter (int, list, tensor): Filter to use if `mode='FIR'`.
538
+ Default value is a lowpass filter of values [1, 3, 3, 1].
539
+ filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
540
+ See `FilterLayer` docstring for more info.
541
+ filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
542
+ See `FilterLayer` docstring for more info.
543
+ gain (float): If `mode='FIR'`, this is used with the filter.
544
+ See `FilterLayer` docstring for more info.
545
+ dim (int): Dims of data (excluding batch and channel dimensions).
546
+ Default value is 2.
547
+ """
548
+
549
+ def __init__(self,
550
+ mode='FIR',
551
+ filter=[1, 3, 3, 1],
552
+ filter_pad_mode='constant',
553
+ filter_pad_constant=0,
554
+ gain=1,
555
+ dim=2,
556
+ *args,
557
+ **kwargs):
558
+ super(Upsample, self).__init__()
559
+ assert mode != 'max', 'mode \'max\' can only be used for downsampling.'
560
+ if mode == 'FIR':
561
+ if filter is None:
562
+ filter = [1, 1]
563
+ filter_kernel = _setup_filter_kernel(
564
+ filter_kernel=filter,
565
+ gain=gain,
566
+ up_factor=2,
567
+ dim=dim
568
+ )
569
+ pad = filter_kernel.size(-1) - 1
570
+ self.filter = FilterLayer(
571
+ filter_kernel=filter_kernel,
572
+ pad0=(pad + 1) // 2 + 1,
573
+ pad1=pad // 2,
574
+ pad_mode=filter_pad_mode,
575
+ pad_constant=filter_pad_constant
576
+ )
577
+ self.register_buffer('weight', torch.ones(*[1 for _ in range(dim + 2)]))
578
+ self.mode = mode
579
+
580
+ def forward(self, input, **kwargs):
581
+ """
582
+ Upsample inputs.
583
+ Arguments:
584
+ input (torch.Tensor)
585
+ Returns:
586
+ output (torch.Tensor)
587
+ """
588
+ if self.mode == 'FIR':
589
+ x = _apply_conv(
590
+ input=input,
591
+ weight=self.weight.expand(input.size(1), *self.weight.size()[1:]),
592
+ groups=input.size(1),
593
+ stride=2,
594
+ transpose=True
595
+ )
596
+ x = self.filter(x)
597
+ else:
598
+ interp_kwargs = dict(scale_factor=2, mode=self.mode)
599
+ if 'linear' in self.mode or 'cubic' in self.mode:
600
+ interp_kwargs.update(align_corners=False)
601
+ x = F.interpolate(input, **interp_kwargs)
602
+ return x
603
+
604
+ def extra_repr(self):
605
+ return 'resample_mode={}'.format(self.mode)
606
+
607
+
608
+ class Downsample(nn.Module):
609
+ """
610
+ Performs downsampling without learnable parameters that
611
+ reduces the size of the data by half.
612
+ Arguments:
613
+ mode (str): 'FIR', 'max' or one of the valid modes
614
+ that can be passed to torch.nn.functional.interpolate().
615
+ filter (int, list, tensor): Filter to use if `mode='FIR'`.
616
+ Default value is a lowpass filter of values [1, 3, 3, 1].
617
+ filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
618
+ See `FilterLayer` docstring for more info.
619
+ filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
620
+ See `FilterLayer` docstring for more info.
621
+ gain (float): If `mode='FIR'`, this is used with the filter.
622
+ See `FilterLayer` docstring for more info.
623
+ dim (int): Dims of data (excluding batch and channel dimensions).
624
+ Default value is 2.
625
+ """
626
+
627
+ def __init__(self,
628
+ mode='FIR',
629
+ filter=[1, 3, 3, 1],
630
+ filter_pad_mode='constant',
631
+ filter_pad_constant=0,
632
+ gain=1,
633
+ dim=2,
634
+ *args,
635
+ **kwargs):
636
+ super(Downsample, self).__init__()
637
+ if mode == 'FIR':
638
+ if filter is None:
639
+ filter = [1, 1]
640
+ filter_kernel = _setup_filter_kernel(
641
+ filter_kernel=filter,
642
+ gain=gain,
643
+ up_factor=1,
644
+ dim=dim
645
+ )
646
+ pad = filter_kernel.size(-1) - 2
647
+ pad0 = pad // 2
648
+ pad1 = pad - pad0
649
+ self.filter = FilterLayer(
650
+ filter_kernel=filter_kernel,
651
+ stride=2,
652
+ pad0=pad0,
653
+ pad1=pad1,
654
+ pad_mode=filter_pad_mode,
655
+ pad_constant=filter_pad_constant
656
+ )
657
+ self.mode = mode
658
+
659
+ def forward(self, input, **kwargs):
660
+ """
661
+ Downsample inputs to half their size.
662
+ Arguments:
663
+ input (torch.Tensor)
664
+ Returns:
665
+ output (torch.Tensor)
666
+ """
667
+ if self.mode == 'FIR':
668
+ x = self.filter(input)
669
+ elif self.mode == 'max':
670
+ return getattr(F, 'max_pool{}d'.format(input.dim() - 2))(input)
671
+ else:
672
+ x = F.interpolate(input, scale_factor=0.5, mode=self.mode)
673
+ return x
674
+
675
+ def extra_repr(self):
676
+ return 'resample_mode={}'.format(self.mode)
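# Hypothetical usage sketch (assumes the definitions above): the FIR-based
# resampling layers double or halve the spatial size with no learnable weights.
x = torch.randn(1, 8, 16, 16)
print(Upsample(mode='FIR')(x).shape)     # torch.Size([1, 8, 32, 32])
print(Downsample(mode='FIR')(x).shape)   # torch.Size([1, 8, 8, 8])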
677
+
678
+
679
+ class MinibatchStd(nn.Module):
680
+ """
681
+ Adds the average std of each data point over a
682
+ slice of the minibatch to that slice as a new
683
+ feature map. This gives an output with one extra
684
+ channel.
685
+ Arguments:
686
+ group_size (int): Number of entries in each slice
687
+ of the batch. If <= 0, the entire batch is used.
688
+ Default value is 4.
689
+ eps (float): Epsilon value added for numerical stability.
690
+ Default value is 1e-8.
691
+ """
692
+ def __init__(self, group_size=4, eps=1e-8, *args, **kwargs):
693
+ super(MinibatchStd, self).__init__()
694
+ if group_size is None or group_size <= 0:
695
+ # Entire batch as group size
696
+ group_size = 0
697
+ assert group_size != 1, 'Can not use 1 as minibatch std group size.'
698
+ self.group_size = group_size
699
+ self.eps = eps
700
+
701
+ def forward(self, input, **kwargs):
702
+ """
703
+ Add a new feature map to the input containing the average
704
+ standard deviation for each slice.
705
+ Arguments:
706
+ input (torch.Tensor)
707
+ Returns:
708
+ output (torch.Tensor)
709
+ """
710
+ group_size = self.group_size or input.size(0)
711
+ assert input.size(0) >= group_size, \
712
+ 'Can not use a smaller batch size ' + \
713
+ '({}) than the specified '.format(input.size(0)) + \
714
+ 'group size ({}) '.format(group_size) + \
715
+ 'of this minibatch std layer.'
716
+ assert input.size(0) % group_size == 0, \
717
+ 'Can not use a batch of a size ' + \
718
+ '({}) that is not '.format(input.size(0)) + \
719
+ 'evenly divisible by the group size ({})'.format(group_size)
720
+ x = input
721
+
722
+ # B = batch size, C = num channels
723
+ # *s = the size dimensions (height, width for images)
724
+
725
+ # BC*s -> G[B/G]C*s
726
+ y = input.view(group_size, -1, *input.size()[1:])
727
+ # For numerical stability when training with mixed precision
728
+ y = y.float()
729
+ # G[B/G]C*s
730
+ y -= y.mean(dim=0, keepdim=True)
731
+ # [B/G]C*s
732
+ y = torch.mean(y ** 2, dim=0)
733
+ # [B/G]C*s
734
+ y = torch.sqrt(y + self.eps)
735
+ # [B/G]
736
+ y = torch.mean(y.view(y.size(0), -1), dim=-1)
737
+ # [B/G]1*1
738
+ y = y.view(-1, *[1] * (input.dim() - 1))
739
+ # Cast back to input dtype
740
+ y = y.to(x)
741
+ # B1*1
742
+ y = y.repeat(group_size, *[1] * (y.dim() - 1))
743
+ # B1*s
744
+ y = y.expand(y.size(0), 1, *x.size()[2:])
745
+ # B[C+1]*s
746
+ x = torch.cat([x, y], dim=1)
747
+ return x
748
+
749
+ def extra_repr(self):
750
+ return 'group_size={}'.format(self.group_size or '-1')
751
+
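# Hypothetical usage sketch (assumes the definitions above): one extra feature
# map holding the average std over each group of 4 batch entries is appended,
# so only the channel dimension grows.
mbstd = MinibatchStd(group_size=4)
print(mbstd(torch.randn(8, 16, 4, 4)).shape)   # torch.Size([8, 17, 4, 4])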
752
+
753
+ class DenseLayer(nn.Module):
754
+ """
755
+ A fully connected layer.
756
+ NOTE: No bias is applied in this layer.
757
+ Arguments:
758
+ in_features (int): Number of input features.
759
+ out_features (int): Number of output features.
760
+ lr_mul (float): Learning rate multiplier of
761
+ the weight. Weights are scaled by
762
+ this value. Default value is 1.
763
+ weight_scale (bool): Scale weights for
764
+ equalized learning rate.
765
+ Default value is True.
766
+ gain (float): The gain of this layer. Default value is 1.
767
+ """
768
+ def __init__(self,
769
+ in_features,
770
+ out_features,
771
+ lr_mul=1,
772
+ weight_scale=True,
773
+ gain=1,
774
+ *args,
775
+ **kwargs):
776
+ super(DenseLayer, self).__init__()
777
+ weight, weight_coef = _get_weight_and_coef(
778
+ shape=[out_features, in_features],
779
+ lr_mul=lr_mul,
780
+ weight_scale=weight_scale,
781
+ gain=gain
782
+ )
783
+ self.register_parameter('weight', weight)
784
+ self.weight_coef = weight_coef
785
+
786
+ def forward(self, input, **kwargs):
787
+ """
788
+ Perform a matrix multiplication of the weight
789
+ of this layer and the input.
790
+ Arguments:
791
+ input (torch.Tensor)
792
+ Returns:
793
+ output (torch.Tensor)
794
+ """
795
+ weight = self.weight
796
+ if self.weight_coef != 1:
797
+ weight = self.weight_coef * weight
798
+ return input.matmul(weight.t())
799
+
800
+ def extra_repr(self):
801
+ return 'in_features={}, out_features={}'.format(
802
+ self.weight.size(1), self.weight.size(0))
803
+
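# Hypothetical usage sketch (assumes the definitions above): DenseLayer applies
# no bias itself, so it is normally combined with BiasActivationWrapper, as is
# done elsewhere in this file.
dense = BiasActivationWrapper(
    layer=DenseLayer(in_features=512, out_features=256),
    features=256,
    use_bias=True,
    activation='leaky:0.2'
)
print(dense(torch.randn(4, 512)).shape)   # torch.Size([4, 256])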
804
+
805
+ class ConvLayer(nn.Module):
806
+ """
807
+ A convolutional layer that can have its outputs
808
+ modulated (style mod). It can also normalize outputs.
809
+ These operations are done by modifying the convolutional
810
+ kernel weight and employing grouped convolutions for
811
+ efficiency.
812
+ NOTE: No bias is applied in this layer.
813
+ NOTE: Amount of padding used is the same as 'SAME'
814
+ argument in tensorflow for conv padding.
815
+ Arguments:
816
+ in_channels (int): Number of input channels.
817
+ out_channels (int): Number of output channels.
818
+ latent_size (int, optional): The size of the
819
+ latents to use for modulating this convolution.
820
+ Only required when `modulate=True`.
821
+ modulate (bool): Applies a "style" to the outputs
822
+ of the layer. The style is given by a latent
823
+ vector passed with the input to this layer.
824
+ A dense layer is added that projects the
825
+ values of the latent into scales for the
826
+ data channels.
827
+ Default value is False.
828
+ demodulate (bool): Normalize std of outputs.
829
+ Can only be set to True when `modulate=True`.
830
+ Default value is False.
831
+ kernel_size (int): The size of the kernel.
832
+ Default value is 3.
833
+ pad_mode (str): The padding mode. Default value is 'constant'.
834
+ pad_constant (float): The constant value to pad with if
835
+ `pad_mode='constant'`. Default value is 0.
836
+ lr_mul (float): Learning rate multiplier of
837
+ the weight. Weights are scaled by
838
+ this value. Default value is 1.
839
+ weight_scale (bool): Scale weights for
840
+ equalized learning rate.
841
+ Default value is True.
842
+ gain (float): The gain of this layer. Default value is 1.
843
+ dim (int): Dims of data (excluding batch and channel dimensions).
844
+ Default value is 2.
845
+ eps (float): Epsilon value added for numerical stability.
846
+ Default value is 1e-8.
847
+ """
848
+ def __init__(self,
849
+ in_channels,
850
+ out_channels,
851
+ latent_size=None,
852
+ modulate=False,
853
+ demodulate=False,
854
+ kernel_size=3,
855
+ pad_mode='constant',
856
+ pad_constant=0,
857
+ lr_mul=1,
858
+ weight_scale=True,
859
+ gain=1,
860
+ dim=2,
861
+ eps=1e-8,
862
+ *args,
863
+ **kwargs):
864
+ super(ConvLayer, self).__init__()
865
+ assert modulate or not demodulate, '`demodulate=True` can ' + \
866
+ 'only be used when `modulate=True`'
867
+ if modulate:
868
+ assert latent_size is not None, 'When using `modulate=True`, ' + \
869
+ '`latent_size` has to be specified.'
870
+ kernel_shape = [out_channels, in_channels] + dim * [kernel_size]
871
+ weight, weight_coef = _get_weight_and_coef(
872
+ shape=kernel_shape,
873
+ lr_mul=lr_mul,
874
+ weight_scale=weight_scale,
875
+ gain=gain
876
+ )
877
+ self.register_parameter('weight', weight)
878
+ self.weight_coef = weight_coef
879
+ if modulate:
880
+ self.dense = BiasActivationWrapper(
881
+ layer=DenseLayer(
882
+ in_features=latent_size,
883
+ out_features=in_channels,
884
+ lr_mul=lr_mul,
885
+ weight_scale=weight_scale,
886
+ gain=1
887
+ ),
888
+ features=in_channels,
889
+ use_bias=True,
890
+ activation='linear',
891
+ bias_init=1,
892
+ lr_mul=lr_mul,
893
+ weight_scale=weight_scale,
894
+ )
895
+ self.dense_reshape = [-1, 1, in_channels] + dim * [1]
896
+ self.dmod_reshape = [-1, out_channels, 1] + dim * [1]
897
+ pad = (kernel_size - 1)
898
+ pad0 = pad - pad // 2
899
+ pad1 = pad - pad0
900
+ if pad0 == pad1 and (pad0 == 0 or pad_mode == 'constant' and pad_constant == 0):
901
+ self.fused_pad = True
902
+ self.padding = pad0
903
+ else:
904
+ self.fused_pad = False
905
+ self.padding = [pad0, pad1] * dim
906
+ self.pad_mode = pad_mode
907
+ self.pad_constant = pad_constant
908
+ self.in_channels = in_channels
909
+ self.out_channels = out_channels
910
+ self.latent_size = latent_size
911
+ self.modulate = modulate
912
+ self.demodulate = demodulate
913
+ self.kernel_size = kernel_size
914
+ self.lr_mul = lr_mul
915
+ self.weight_scale = weight_scale
916
+ self.gain = gain
917
+ self.dim = dim
918
+ self.eps = eps
919
+
920
+ def forward_mod(self, input, latent, weight, **kwargs):
921
+ """
922
+ Run the forward operation with modulation.
923
+ Automatically called from `forward()` if modulation
924
+ is enabled.
925
+ """
926
+ assert latent is not None, 'A latent vector is ' + \
927
+ 'required for the forward pass of a modulated conv layer.'
928
+
929
+ # B = batch size, C = num channels
930
+ # *s = the size dimensions, example: (height, width) for images
931
+ # *k = sizes of the convolutional kernel excluding in and out channel dimensions.
932
+ # *1 = multiple dimensions of size 1, with number of dimensions depending on data format.
933
+ # O = num output channels, I = num input channels
934
+
935
+ # BI
936
+ style_mod = self.dense(input=latent)
937
+ # B1I*1
938
+ style_mod = style_mod.view(*self.dense_reshape)
939
+ # 1OI*k
940
+ weight = weight.unsqueeze(0)
941
+ # (1OI*k)x(B1I*1) -> BOI*k
942
+ weight = weight * style_mod
943
+ if self.demodulate:
944
+ # BO
945
+ dmod = torch.rsqrt(
946
+ torch.sum(
947
+ weight.view(
948
+ weight.size(0),
949
+ weight.size(1),
950
+ -1
951
+ ) ** 2,
952
+ dim=-1
953
+ ) + self.eps
954
+ )
955
+ # BO1*1
956
+ dmod = dmod.view(*self.dmod_reshape)
957
+ # (BOI*k)x(BO1*1) -> BOI*k
958
+ weight = weight * dmod
959
+ # BI*s -> 1[BI]*s
960
+ x = input.view(1, -1, *input.size()[2:])
961
+ # BOI*k -> [BO]I*k
962
+ weight = weight.view(-1, *weight.size()[2:])
963
+ # 1[BO]*s
964
+ x = self._process(input=x, weight=weight, groups=input.size(0))
965
+ # 1[BO]*s -> BO*s
966
+ x = x.view(-1, self.out_channels, *x.size()[2:])
967
+ return x
968
+
969
+ def forward(self, input, latent=None, **kwargs):
970
+ """
971
+ Convolve the input.
972
+ Arguments:
973
+ input (torch.Tensor)
974
+ latents (torch.Tensor, optional)
975
+ Returns:
976
+ output (torch.Tensor)
977
+ """
978
+ weight = self.weight
979
+ if self.weight_coef != 1:
980
+ weight = self.weight_coef * weight
981
+ if self.modulate:
982
+ return self.forward_mod(input=input, latent=latent, weight=weight)
983
+ return self._process(input=input, weight=weight)
984
+
985
+ def _process(self, input, weight, **kwargs):
986
+ """
987
+ Pad input and convolve it returning the result.
988
+ """
989
+ x = input
990
+ if self.fused_pad:
991
+ kwargs.update(padding=self.padding)
992
+ else:
993
+ x = F.pad(x, self.padding, mode=self.pad_mode, value=self.pad_constant)
994
+ return _apply_conv(input=x, weight=weight, transpose=False, **kwargs)
995
+
996
+ def extra_repr(self):
997
+ string = 'in_channels={}, out_channels={}'.format(
998
+ self.weight.size(1), self.weight.size(0))
999
+ string += ', modulate={}, demodulate={}'.format(
1000
+ self.modulate, self.demodulate)
1001
+ return string
1002
+
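# Hypothetical usage sketch (assumes the definitions above): a modulated (and
# demodulated) conv layer takes one latent per sample and scales its kernel
# per sample through a grouped convolution.
conv = ConvLayer(
    in_channels=512,
    out_channels=512,
    latent_size=512,
    modulate=True,
    demodulate=True
)
out = conv(input=torch.randn(2, 512, 4, 4), latent=torch.randn(2, 512))
print(out.shape)   # torch.Size([2, 512, 4, 4])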
1003
+
1004
+ class ConvUpLayer(ConvLayer):
1005
+ """
1006
+ A convolutional upsampling layer that doubles the size of inputs.
1007
+ Extends the functionality of the `ConvLayer` class.
1008
+ Arguments:
1009
+ Same arguments as the `ConvLayer` class.
1010
+ Class Specific Keyword Arguments:
1011
+ fused (bool): Fuse the upsampling operation with the
1012
+ convolution, turning this layer into a strided transposed
1013
+ convolution. Default value is True.
1014
+ mode (str): Resample mode, can only be 'FIR' or 'none' if the operation
1015
+ is fused, otherwise it can also be one of the valid modes
1016
+ that can be passed to torch.nn.functional.interpolate().
1017
+ filter (int, list, tensor): Filter to use if `mode='FIR'`.
1018
+ Default value is a lowpass filter of values [1, 3, 3, 1].
1019
+ filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
1020
+ See `FilterLayer` docstring for more info.
1021
+ filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
1022
+ See `FilterLayer` docstring for more info.
1023
+ pad_once (bool): If FIR filter is used, do all the padding for
1024
+ both the convolution and the FIR filter in the FIR layer instead of separately in each layer.
1025
+ Default value is True.
1026
+ """
1027
+
1028
+ def __init__(self,
1029
+ *args,
1030
+ fused=True,
1031
+ mode='FIR',
1032
+ filter=[1, 3, 3, 1],
1033
+ filter_pad_mode='constant',
1034
+ filter_pad_constant=0,
1035
+ pad_once=True,
1036
+ **kwargs):
1037
+ super(ConvUpLayer, self).__init__(*args, **kwargs)
1038
+ if fused:
1039
+ assert mode in ['FIR', 'none'], \
1040
+ 'Fused conv upsample can only use ' + \
1041
+ '\'FIR\' or \'none\' for resampling ' + \
1042
+ '(`mode` argument).'
1043
+ self.padding = np.ceil(self.kernel_size / 2 - 1)
1044
+ self.output_padding = 2 * (self.padding + 1) - self.kernel_size
1045
+ if not self.modulate:
1046
+ # pre-prepare weights only once instead of every forward pass
1047
+ self.weight = nn.Parameter(self.weight.data.transpose(0, 1).contiguous())
1048
+ self.filter = None
1049
+ if mode == 'FIR':
1050
+ filter_kernel = _setup_filter_kernel(
1051
+ filter_kernel=filter,
1052
+ gain=self.gain,
1053
+ up_factor=2,
1054
+ dim=self.dim
1055
+ )
1056
+ if pad_once:
1057
+ self.padding = 0
1058
+ self.output_padding = 0
1059
+ pad = (filter_kernel.size(-1) - 2) - (self.kernel_size - 1)
1060
+ pad0 = (pad + 1) // 2 + 1
1061
+ pad1 = pad // 2 + 1
1062
+ else:
1063
+ pad = (filter_kernel.size(-1) - 1)
1064
+ pad0 = pad // 2
1065
+ pad1 = pad - pad0
1066
+ self.filter = FilterLayer(
1067
+ filter_kernel=filter_kernel,
1068
+ pad0=pad0,
1069
+ pad1=pad1,
1070
+ pad_mode=filter_pad_mode,
1071
+ pad_constant=filter_pad_constant
1072
+ )
1073
+ else:
1074
+ assert mode != 'none', '\'none\' can not be used as ' + \
1075
+ 'sampling `mode` when `fused=False` as upsampling ' + \
1076
+ 'has to be performed separately from the conv layer.'
1077
+ self.upsample = Upsample(
1078
+ mode=mode,
1079
+ filter=filter,
1080
+ filter_pad_mode=filter_pad_mode,
1081
+ filter_pad_constant=filter_pad_constant,
1082
+ channels=self.in_channels,
1083
+ gain=self.gain,
1084
+ dim=self.dim
1085
+ )
1086
+ self.fused = fused
1087
+ self.mode = mode
1088
+
1089
+ def _process(self, input, weight, **kwargs):
1090
+ """
1091
+ Apply resampling (if enabled) and convolution.
1092
+ """
1093
+ x = input
1094
+ if self.fused:
1095
+ if self.modulate:
1096
+ weight = _setup_mod_weight_for_t_conv(
1097
+ weight=weight,
1098
+ in_channels=self.in_channels,
1099
+ out_channels=self.out_channels
1100
+ )
1101
+ pad_out = False
1102
+ if self.pad_mode == 'constant' and self.pad_constant == 0:
1103
+ if self.filter is not None or not self.pad_once:
1104
+ kwargs.update(
1105
+ padding=self.padding,
1106
+ output_padding=self.output_padding,
1107
+ )
1108
+ elif self.filter is None:
1109
+ if self.padding:
1110
+ x = F.pad(
1111
+ x,
1112
+ [self.padding] * 2 * self.dim,
1113
+ mode=self.pad_mode,
1114
+ value=self.pad_constant
1115
+ )
1116
+ pad_out = self.output_padding != 0
1117
+ kwargs.update(stride=2)
1118
+ x = _apply_conv(
1119
+ input=x,
1120
+ weight=weight,
1121
+ transpose=True,
1122
+ **kwargs
1123
+ )
1124
+ if pad_out:
1125
+ x = F.pad(
1126
+ x,
1127
+ [self.output_padding, 0] * self.dim,
1128
+ mode=self.pad_mode,
1129
+ value=self.pad_constant
1130
+ )
1131
+ if self.filter is not None:
1132
+ x = self.filter(x)
1133
+ else:
1134
+ x = super(ConvUpLayer, self)._process(
1135
+ input=self.upsample(input=x),
1136
+ weight=weight,
1137
+ **kwargs
1138
+ )
1139
+ return x
1140
+
1141
+ def extra_repr(self):
1142
+ string = super(ConvUpLayer, self).extra_repr()
1143
+ string += ', fused={}, resample_mode={}'.format(
1144
+ self.fused, self.mode)
1145
+ return string
1146
+
1147
+
1148
+ class ConvDownLayer(ConvLayer):
1149
+ """
1150
+ A convolutional downsampling layer that halves the size of inputs.
1151
+ Extends the functionality of the `ConvLayer` class.
1152
+ Arguments:
1153
+ Same arguments as the `ConvLayer` class.
1154
+ Class Specific Keyword Arguments:
1155
+ fused (bool): Fuse the downsampling operation with the
1156
+ convolution, turning this layer into a strided convolution.
1157
+ Default value is True.
1158
+ mode (str): Resample mode, can only be 'FIR' or 'none' if the operation
1159
+ is fused, otherwise it can also be 'max' or one of the valid modes
1160
+ that can be passed to torch.nn.functional.interpolate().
1161
+ filter (int, list, tensor): Filter to use if `mode='FIR'`.
1162
+ Default value is a lowpass filter of values [1, 3, 3, 1].
1163
+ filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
1164
+ See `FilterLayer` docstring for more info.
1165
+ filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
1166
+ See `FilterLayer` docstring for more info.
1167
+ pad_once (bool): If FIR filter is used, do all the padding for
1168
+ both the convolution and the FIR filter in the FIR layer instead of separately in each layer.
1169
+ Default value is True.
1170
+ """
1171
+
1172
+ def __init__(self,
1173
+ *args,
1174
+ fused=True,
1175
+ mode='FIR',
1176
+ filter=[1, 3, 3, 1],
1177
+ filter_pad_mode='constant',
1178
+ filter_pad_constant=0,
1179
+ pad_once=True,
1180
+ **kwargs):
1181
+ super(ConvDownLayer, self).__init__(*args, **kwargs)
1182
+ if fused:
1183
+ assert mode in ['FIR', 'none'], \
1184
+ 'Fused conv downsample can only use ' + \
1185
+ '\'FIR\' or \'none\' for resampling ' + \
1186
+ '(`mode` argument).'
1187
+ pad = self.kernel_size - 2
1188
+ pad0 = pad // 2
1189
+ pad1 = pad - pad0
1190
+ if pad0 == pad1 and (pad0 == 0 or self.pad_mode == 'constant' and self.pad_constant == 0):
1191
+ self.fused_pad = True
1192
+ self.padding = pad0
1193
+ else:
1194
+ self.fused_pad = False
1195
+ self.padding = [pad0, pad1] * self.dim
1196
+ self.filter = None
1197
+ if mode == 'FIR':
1198
+ filter_kernel = _setup_filter_kernel(
1199
+ filter_kernel=filter,
1200
+ gain=self.gain,
1201
+ up_factor=1,
1202
+ dim=self.dim
1203
+ )
1204
+ if pad_once:
1205
+ self.fused_pad = True
1206
+ self.padding = 0
1207
+ pad = (filter_kernel.size(-1) - 2) + (self.kernel_size - 1)
1208
+ pad0 = (pad + 1) // 2
1209
+ pad1 = pad // 2
1210
+ else:
1211
+ pad = (filter_kernel.size(-1) - 1)
1212
+ pad0 = pad // 2
1213
+ pad1 = pad - pad0
1214
+ self.filter = FilterLayer(
1215
+ filter_kernel=filter_kernel,
1216
+ pad0=pad0,
1217
+ pad1=pad1,
1218
+ pad_mode=filter_pad_mode,
1219
+ pad_constant=filter_pad_constant
1220
+ )
1221
+ self.pad_once = pad_once
1222
+ else:
1223
+ assert mode != 'none', '\'none\' can not be used as ' + \
1224
+ 'sampling `mode` when `fused=False` as downsampling ' + \
1225
+ 'has to be performed separately from the conv layer.'
1226
+ self.downsample = Downsample(
1227
+ mode=mode,
1228
+ filter=filter,
1229
+ filter_pad_mode=filter_pad_mode,
1230
+ filter_pad_constant=filter_pad_constant,
1231
+ channels=self.in_channels,
1232
+ gain=self.gain,
1233
+ dim=self.dim
1234
+ )
1235
+ self.fused = fused
1236
+ self.mode = mode
1237
+
1238
+ def _process(self, input, weight, **kwargs):
1239
+ """
1240
+ Apply resampling (if enabled) and convolution.
1241
+ """
1242
+ x = input
1243
+ if self.fused:
1244
+ kwargs.update(stride=2)
1245
+ if self.filter is not None:
1246
+ x = self.filter(input=x)
1247
+ else:
1248
+ x = self.downsample(input=x)
1249
+ x = super(ConvDownLayer, self)._process(
1250
+ input=x,
1251
+ weight=weight,
1252
+ **kwargs
1253
+ )
1254
+ return x
1255
+
1256
+ def extra_repr(self):
1257
+ string = super(ConvDownLayer, self).extra_repr()
1258
+ string += ', fused={}, resample_mode={}'.format(
1259
+ self.fused, self.mode)
1260
+ return string
1261
+
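# Hypothetical usage sketch (assumes the definitions above): the fused
# resampling conv layers change the spatial size by a factor of two while
# convolving. Without `modulate=True` no latent is needed.
x = torch.randn(1, 16, 8, 8)
up_conv = ConvUpLayer(in_channels=16, out_channels=32, kernel_size=3)
down_conv = ConvDownLayer(in_channels=16, out_channels=32, kernel_size=3)
print(up_conv(input=x).shape)     # torch.Size([1, 32, 16, 16])
print(down_conv(input=x).shape)   # torch.Size([1, 32, 4, 4])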
1262
+
1263
+ class GeneratorConvBlock(nn.Module):
1264
+ """
1265
+ A convblock for the synthesiser model.
1266
+ Arguments:
1267
+ in_channels (int): Number of input channels.
1268
+ out_channels (int): Number of output channels.
1269
+ latent_size (int): The size of the latent vectors.
1270
+ demodulate (bool): Normalize feature outputs from conv
1271
+ layers. Default value is True.
1272
+ resnet (bool): Use residual connections. Default value is
1273
+ False.
1274
+ up (bool): Upsample the data to twice its size. This is
1275
+ performed in the first layer of the block. Default
1276
+ value is False.
1277
+ num_layers (int): Number of convolutional layers of this
1278
+ block. Default value is 2.
1279
+ filter (int, list): The filter to use if
1280
+ `up=True` and `mode='FIR'`. If int, a low
1281
+ pass filter of this size will be used. If list,
1282
+ the filter is explicitly specified. If the filter
1283
+ is of a single dimension it will be expanded to
1284
+ the number of dimensions of the data. Default
1285
+ value is a low pass filter of [1, 3, 3, 1].
1286
+ activation (str, callable, nn.Module): The non-linear
1287
+ activation function to use.
1288
+ Default value is leaky relu with a slope of 0.2.
1289
+ mode (str): The resample mode of upsampling layers.
1290
+ Only used when `up=True`. If `fused=True`, only 'FIR'
1291
+ and 'none' can be used. Else, anything that can
1292
+ be passed to torch.nn.functional.interpolate is
1293
+ a valid mode. Default value is 'FIR'.
1294
+ fused (bool): If `up=True`, fuse the upsample operation
1295
+ and the first convolutional layer into a transposed
1296
+ convolutional layer.
1297
+ kernel_size (int): Size of the convolutional kernel.
1298
+ Default value is 3.
1299
+ pad_mode (str): The padding mode for convolutional
1300
+ layers. Has to be one of 'constant', 'reflect',
1301
+ 'replicate' or 'circular'. Default value is
1302
+ 'constant'.
1303
+ pad_constant (float): The value to use for conv
1304
+ padding if `conv_pad_mode='constant'`. Default
1305
+ value is 0.
1306
+ filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
1307
+ Otherwise works the same as `pad_mode`.
1308
+ filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
1309
+ Otherwise works the same as `pad_constant`
1310
+ pad_once (bool): If FIR filter is used, do all the padding for
1311
+ both the convolution and the FIR filter in the FIR layer instead of separately in each layer.
1312
+ Default value is True.
1313
+ use_bias (bool): Add bias to layer outputs. Default value is True.
1314
+ noise (bool): Add noise to the output of each layer. Default value
1315
+ is True.
1316
+ lr_mul (float): The learning rate multiplier for this
1317
+ block. When loading weights of previously trained
1318
+ networks, this value has to be the same as when
1319
+ the network was trained for the outputs to not
1320
+ change (as this is used to scale the weights).
1321
+ Default value is 1.
1322
+ weight_scale (bool): Use weight scaling for
1323
+ equalized learning rate. Default value
1324
+ is True.
1325
+ eps (float): Epsilon value added for numerical stability.
1326
+ Default value is 1e-8.
1327
+ """
1328
+ def __init__(self,
1329
+ in_channels,
1330
+ out_channels,
1331
+ latent_size,
1332
+ demodulate=True,
1333
+ resnet=False,
1334
+ up=False,
1335
+ num_layers=2,
1336
+ filter=[1, 3, 3, 1],
1337
+ activation='leaky:0.2',
1338
+ mode='FIR',
1339
+ fused=True,
1340
+ kernel_size=3,
1341
+ pad_mode='constant',
1342
+ pad_constant=0,
1343
+ filter_pad_mode='constant',
1344
+ filter_pad_constant=0,
1345
+ pad_once=True,
1346
+ use_bias=True,
1347
+ noise=True,
1348
+ lr_mul=1,
1349
+ weight_scale=True,
1350
+ gain=1,
1351
+ dim=2,
1352
+ eps=1e-8,
1353
+ *args,
1354
+ **kwargs):
1355
+ super(GeneratorConvBlock, self).__init__()
1356
+ layer_kwargs = locals()
1357
+ layer_kwargs.pop('self')
1358
+ layer_kwargs.pop('__class__')
1359
+ layer_kwargs.update(
1360
+ features=out_channels,
1361
+ modulate=True,
1362
+ )
1363
+
1364
+ assert num_layers > 0
1365
+ assert 1 <= dim <= 3, '`dim` can only be 1, 2 or 3.'
1366
+ if up:
1367
+ available_sampling = ['FIR']
1368
+ if fused:
1369
+ available_sampling.append('none')
1370
+ else:
1371
+ available_sampling.append('nearest')
1372
+ if dim == 1:
1373
+ available_sampling.append('linear')
1374
+ elif dim == 2:
1375
+ available_sampling.append('bilinear')
1376
+ available_sampling.append('bicubic')
1377
+ else:
1378
+ available_sampling.append('trilinear')
1379
+ assert mode in available_sampling, \
1380
+ '`mode` {} '.format(mode) + \
1381
+ 'is not one of the available sample ' + \
1382
+ 'modes {}.'.format(available_sampling)
1383
+
1384
+ self.conv_block = nn.ModuleList()
1385
+
1386
+ while len(self.conv_block) < num_layers:
1387
+ use_up = up and not self.conv_block
1388
+ self.conv_block.append(_get_layer(
1389
+ ConvUpLayer if use_up else ConvLayer, layer_kwargs, wrap=True, noise=noise))
1390
+ layer_kwargs.update(in_channels=out_channels)
1391
+
1392
+ self.projection = None
1393
+ if resnet:
1394
+ projection_kwargs = {
1395
+ **layer_kwargs,
1396
+ 'in_channels': in_channels,
1397
+ 'kernel_size': 1,
1398
+ 'modulate': False,
1399
+ 'demodulate': False
1400
+ }
1401
+ self.projection = _get_layer(
1402
+ ConvUpLayer if up else ConvLayer, projection_kwargs, wrap=False)
1403
+
1404
+ self.res_scale = 1 / np.sqrt(2)
1405
+
1406
+ def __len__(self):
1407
+ """
1408
+ Get the number of conv layers in this block.
1409
+ """
1410
+ return len(self.conv_block)
1411
+
1412
+ def forward(self, input, latents=None, **kwargs):
1413
+ """
1414
+ Run some input through this block and return the output.
1415
+ Arguments:
1416
+ input (torch.Tensor)
1417
+ latents (torch.Tensor)
1418
+ Returns:
1419
+ output (torch.Tensor)
1420
+ """
1421
+ if latents.dim() == 2:
1422
+ latents = latents.unsqueeze(1)
1423
+ if latents.size(1) == 1:
1424
+ latents = latents.repeat(1, len(self), 1)
1425
+ assert latents.size(1) == len(self), \
1426
+ 'Number of latent inputs ' + \
1427
+ '({}) does not match '.format(latents.size(1)) + \
1428
+ 'number of conv layers ' + \
1429
+ '({}) in block.'.format(len(self))
1430
+ x = input
1431
+ for i, layer in enumerate(self.conv_block):
1432
+ x = layer(input=x, latent=latents[:, i])
1433
+ if self.projection is not None:
1434
+ x += self.projection(input=input)
1435
+ x *= self.res_scale
1436
+ return x
1437
+
1438
+
1439
+ class DiscriminatorConvBlock(nn.Module):
1440
+ """
1441
+ A convblock for the discriminator model.
1442
+ Arguments:
1443
+ in_channels (int): Number of input channels.
1444
+ out_channels (int): Number of output channels.
1445
+ demodulate (bool): Normalize feature outputs from conv
1446
+ layers. Default value is True.
1447
+ resnet (bool): Use residual connections. Default value is
1448
+ False.
1449
+ down (bool): Downsample the data to half its size. This is
1450
+ performed in the last layer of the block. Default
1451
+ value is False.
1452
+ num_layers (int): Number of convolutional layers of this
1453
+ block. Default value is 2.
1454
+ filter (int, list): The filter to use if
1455
+ `down=True` and `mode='FIR'`. If int, a low
1456
+ pass filter of this size will be used. If list,
1457
+ the filter is explicitly specified. If the filter
1458
+ is of a single dimension it will be expanded to
1459
+ the number of dimensions of the data. Default
1460
+ value is a low pass filter of [1, 3, 3, 1].
1461
+ activation (str, callable, nn.Module): The non-linear
1462
+ activation function to use.
1463
+ Default value is leaky relu with a slope of 0.2.
1464
+ mode (str): The resample mode of downsampling layers.
1465
+ Only used when `down=True`. If `fused=True` only 'FIR'
1466
+ and 'none' can be used. Else, 'max' or anything that can
1467
+ be passed to torch.nn.functional.interpolate is
1468
+ a valid mode. Default value is 'FIR'.
1469
+ fused (bool): If `down=True`, fuse the downsample operation
1470
+ and the last convolutional layer into a strided
1471
+ convolutional layer.
1472
+ kernel_size (int): Size of the convolutional kernel.
1473
+ Default value is 3.
1474
+ pad_mode (str): The padding mode for convolutional
1475
+ layers. Has to be one of 'constant', 'reflect',
1476
+ 'replicate' or 'circular'. Default value is
1477
+ 'constant'.
1478
+ pad_constant (float): The value to use for conv
1479
+ padding if `pad_mode='constant'`. Default
1480
+ value is 0.
1481
+ filter_pad_mode (str): The padding mode used for the FIR filter when `mode='FIR'`.
1482
+ Takes the same values as `pad_mode`.
1483
+ filter_pad_constant (float): The padding constant used for the FIR filter when
1484
+ `filter_pad_mode='constant'`. Works the same as `pad_constant`.
1485
+ pad_once (bool): If FIR filter is used, do all the padding for
1486
+ both convolution and FIR in the FIR layer instead of once per layer.
1487
+ Default value is True.
1488
+ use_bias (bool): Add bias to layer outputs. Default value is True.
1489
+ lr_mul (float): The learning rate multiplier for this
1490
+ block. When loading weights of previously trained
1491
+ networks, this value has to be the same as when
1492
+ the network was trained for the outputs to not
1493
+ change (as this is used to scale the weights).
1494
+ Default value is 1.
1495
+ weight_scale (bool): Use weight scaling for
1496
+ equalized learning rate. Default value
1497
+ is True.
1498
+ """
1499
+ def __init__(self,
1500
+ in_channels,
1501
+ out_channels,
1502
+ resnet=False,
1503
+ down=False,
1504
+ num_layers=2,
1505
+ filter=[1, 3, 3, 1],
1506
+ activation='leaky:0.2',
1507
+ mode='FIR',
1508
+ fused=True,
1509
+ kernel_size=3,
1510
+ pad_mode='constant',
1511
+ pad_constant=0,
1512
+ filter_pad_mode='constant',
1513
+ filter_pad_constant=0,
1514
+ pad_once=True,
1515
+ use_bias=True,
1516
+ lr_mul=1,
1517
+ weight_scale=True,
1518
+ gain=1,
1519
+ dim=2,
1520
+ *args,
1521
+ **kwargs):
1522
+ super(DiscriminatorConvBlock, self).__init__()
1523
+ layer_kwargs = locals()
1524
+ layer_kwargs.pop('self')
1525
+ layer_kwargs.pop('__class__')
1526
+ layer_kwargs.update(
1527
+ out_channels=in_channels,
1528
+ features=in_channels,
1529
+ modulate=False,
1530
+ demodulate=False
1531
+ )
1532
+
1533
+ assert num_layers > 0
1534
+ assert 1 <= dim <= 3, '`dim` can only be 1, 2 or 3.'
1535
+ if down:
1536
+ available_sampling = ['FIR']
1537
+ if fused:
1538
+ available_sampling.append('none')
1539
+ else:
1540
+ available_sampling.append('max')
1541
+ available_sampling.append('area')
1542
+ available_sampling.append('nearest')
1543
+ if dim == 1:
1544
+ available_sampling.append('linear')
1545
+ elif dim == 2:
1546
+ available_sampling.append('bilinear')
1547
+ available_sampling.append('bicubic')
1548
+ else:
1549
+ available_sampling.append('trilinear')
1550
+ assert mode in available_sampling, \
1551
+ '`mode` {} '.format(mode) + \
1552
+ 'is not one of the available sample ' + \
1553
+ 'modes {}'.format(available_sampling)
1554
+
1555
+ self.conv_block = nn.ModuleList()
1556
+
1557
+ while len(self.conv_block) < num_layers:
1558
+ if len(self.conv_block) == num_layers - 1:
1559
+ layer_kwargs.update(
1560
+ out_channels=out_channels,
1561
+ features=out_channels
1562
+ )
1563
+ use_down = down and len(self.conv_block) == num_layers - 1
1564
+ self.conv_block.append(_get_layer(
1565
+ ConvDownLayer if use_down else ConvLayer, layer_kwargs, wrap=True, noise=False))
1566
+
1567
+ self.projection = None
1568
+ if resnet:
1569
+ projection_kwargs = {
1570
+ **layer_kwargs,
1571
+ 'in_channels': in_channels,
1572
+ 'kernel_size': 1,
1573
+ 'modulate': False,
1574
+ 'demodulate': False
1575
+ }
1576
+ self.projection = _get_layer(
1577
+ ConvDownLayer if down else ConvLayer, projection_kwargs, wrap=False)
1578
+
1579
+ self.res_scale = 1 / np.sqrt(2)
1580
+
1581
+ def __len__(self):
1582
+ """
1583
+ Get the number of conv layers in this block.
1584
+ """
1585
+ return len(self.conv_block)
1586
+
1587
+ def forward(self, input, **kwargs):
1588
+ """
1589
+ Run some input through this block and return the output.
1590
+ Arguments:
1591
+ input (torch.Tensor)
1592
+ Returns:
1593
+ output (torch.Tensor)
1594
+ """
1595
+ x = input
1596
+ for layer in self.conv_block:
1597
+ x = layer(input=x)
1598
+ if self.projection is not None:
1599
+ x += self.projection(input=input)
1600
+ x *= self.res_scale
1601
+ return x
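A minimal sketch of how these conv blocks are driven, added for orientation only (it is not part of the committed file; it assumes the classes above live in stylegan2/models.py and that `up=True` doubles the spatial size):

import torch
from stylegan2.models import GeneratorConvBlock  # assumed import path

# Two conv layers by default; with up=True the first layer also upsamples.
block = GeneratorConvBlock(in_channels=512, out_channels=256,
                           latent_size=512, up=True, resnet=True)
x = torch.randn(4, 512, 8, 8)   # NCHW feature maps
w = torch.randn(4, 512)         # one dlatent vector, broadcast to both conv layers
y = block(input=x, latents=w)   # expected shape: (4, 256, 16, 16)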
stylegan2/project.py ADDED
@@ -0,0 +1,304 @@
1
+ import warnings
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from . import models, utils
8
+ from .external_models import lpips
9
+
10
+
11
+ class Projector(nn.Module):
12
+ """
13
+ Projects data to latent space and noise tensors.
14
+ Arguments:
15
+ G (Generator)
16
+ dlatent_avg_samples (int): Number of dlatent samples
17
+ to collect to find the mean and std.
18
+ Default value is 10 000.
19
+ dlatent_avg_label (int, torch.Tensor, optional): The label to
20
+ use when gathering dlatent statistics.
21
+ dlatent_device (int, str, torch.device, optional): Device to use
22
+ for gathering statistics of dlatents. By default uses
23
+ the same device as parameters of `G` reside on.
24
+ dlatent_batch_size (int): The batch size to sample
25
+ dlatents with. Default value is 1024.
26
+ lpips_model (nn.Module): A model that returns the feature distance
27
+ between two inputs. Default value is the LPIPS VGG16 model.
28
+ lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
29
+ the data so that its smallest side is the same size as this
30
+ argument. Only has a default value of 256 if `lpips_model` is unspecified.
31
+ verbose (bool): Write progress of dlatent statistics gathering to stdout.
32
+ Default value is True.
33
+ """
34
+ def __init__(self,
35
+ G,
36
+ dlatent_avg_samples=10000,
37
+ dlatent_avg_label=None,
38
+ dlatent_device=None,
39
+ dlatent_batch_size=1024,
40
+ lpips_model=None,
41
+ lpips_size=None,
42
+ verbose=True):
43
+ super(Projector, self).__init__()
44
+ assert isinstance(G, models.Generator)
45
+ G.eval().requires_grad_(False)
46
+
47
+ self.G_synthesis = G.G_synthesis
48
+
49
+ G_mapping = G.G_mapping
50
+
51
+ dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)
52
+
53
+ if dlatent_device is None:
54
+ dlatent_device = next(G_mapping.parameters()).device
55
+ else:
56
+ dlatent_device = torch.device(dlatent_device)
57
+
58
+ G_mapping.to(dlatent_device)
59
+
60
+ latents = torch.empty(
61
+ dlatent_avg_samples, G_mapping.latent_size).normal_()
62
+ dlatents = []
63
+
64
+ labels = None
65
+ if dlatent_avg_label is not None:
66
+ labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size)
67
+
68
+ if verbose:
69
+ progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size))
70
+ progress.write('Gathering dlatents...', step=False)
71
+
72
+ for i in range(0, dlatent_avg_samples, dlatent_batch_size):
73
+ batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
74
+ batch_labels = None
75
+ if labels is not None:
76
+ batch_labels = labels[:len(batch_latents)]
77
+ with torch.no_grad():
78
+ dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
79
+ if verbose:
80
+ progress.step()
81
+
82
+ if verbose:
83
+ progress.write('Done!', step=False)
84
+ progress.close()
85
+
86
+ dlatents = torch.cat(dlatents, dim=0)
87
+
88
+ self.register_buffer(
89
+ '_dlatent_avg',
90
+ dlatents.mean(dim=0).view(1, 1, -1)
91
+ )
92
+ self.register_buffer(
93
+ '_dlatent_std',
94
+ torch.sqrt(
95
+ torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
96
+ ).view(1, 1, 1)
97
+ )
98
+
99
+ if lpips_model is None:
100
+ warnings.warn(
101
+ 'Using default LPIPS distance metric based on VGG 16. ' + \
102
+ 'This metric will only work on image data where values are in ' + \
103
+ 'the range [-1, 1], please specify an lpips module if you want ' + \
104
+ 'to use other kinds of data formats.'
105
+ )
106
+ lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
107
+ lpips_size = 256
108
+ self.lpips_model = lpips_model.eval().requires_grad_(False)
109
+ self.lpips_size = lpips_size
110
+
111
+ self.to(dlatent_device)
112
+
113
+ def _scale_for_lpips(self, data):
114
+ if not self.lpips_size:
115
+ return data
116
+ scale_factor = self.lpips_size / min(data.size()[2:])
117
+ if scale_factor == 1:
118
+ return data
119
+ mode = 'nearest'
120
+ if scale_factor < 1:
121
+ mode = 'area'
122
+ return F.interpolate(data, scale_factor=scale_factor, mode=mode)
123
+
124
+ def _check_job(self):
125
+ assert self._job is not None, 'Call `start()` first to set up target.'
126
+ # device of dlatent param will not change with the rest of the models
127
+ # and buffers of this class as it was never registered as a buffer or
128
+ # parameter. Same goes for optimizer. Make sure it is on the correct device.
129
+ if self._job.dlatent_param.device != self._dlatent_avg.device:
130
+ self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
131
+ self._job.opt.load_state_dict(
132
+ utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])
133
+
134
+ def generate(self):
135
+ """
136
+ Generate an output with the current dlatent and noise values.
137
+ Returns:
138
+ output (torch.Tensor)
139
+ """
140
+ self._check_job()
141
+ with torch.no_grad():
142
+ return self.G_synthesis(self._job.dlatent_param)
143
+
144
+ def get_dlatent(self):
145
+ """
146
+ Get a copy of the current dlatent values.
147
+ Returns:
148
+ dlatents (torch.Tensor)
149
+ """
150
+ self._check_job()
151
+ return self._job.dlatent_param.data.clone()
152
+
153
+ def get_noise(self):
154
+ """
155
+ Get a copy of the current noise values.
156
+ Returns:
157
+ noise_tensors (list)
158
+ """
159
+ self._check_job()
160
+ return [noise.data.clone() for noise in self._job.noise_params]
161
+
162
+ def start(self,
163
+ target,
164
+ num_steps=1000,
165
+ initial_learning_rate=0.1,
166
+ initial_noise_factor=0.05,
167
+ lr_rampdown_length=0.25,
168
+ lr_rampup_length=0.05,
169
+ noise_ramp_length=0.75,
170
+ regularize_noise_weight=1e5,
171
+ verbose=True,
172
+ verbose_prefix=''):
173
+ """
174
+ Set up a target and its projection parameters.
175
+ Arguments:
176
+ target (torch.Tensor): The data target. This should
177
+ already be preprocessed (scaled to correct value range).
178
+ num_steps (int): Number of optimization steps. Default
179
+ value is 1000.
180
+ initial_learning_rate (float): Default value is 0.1.
181
+ initial_noise_factor (float): Default value is 0.05.
182
+ lr_rampdown_length (float): Default value is 0.25.
183
+ lr_rampup_length (float): Default value is 0.05.
184
+ noise_ramp_length (float): Default value is 0.75.
185
+ regularize_noise_weight (float): Default value is 1e5.
186
+ verbose (bool): Write progress to stdout every time
187
+ `step()` is called.
188
+ verbose_prefix (str, optional): This is written before
189
+ any other output to stdout.
190
+ """
191
+ if target.dim() == self.G_synthesis.dim + 1:
192
+ target = target.unsqueeze(0)
193
+ assert target.dim() == self.G_synthesis.dim + 2, \
194
+ 'Number of dimensions of target data is incorrect.'
195
+
196
+ target = target.to(self._dlatent_avg)
197
+ target_scaled = self._scale_for_lpips(target)
198
+
199
+ dlatent_param = nn.Parameter(
200
+ self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
201
+ noise_params = self.G_synthesis.static_noise(trainable=True)
202
+ params = [dlatent_param] + noise_params
203
+
204
+ opt = torch.optim.Adam(params)
205
+
206
+ noise_tensor = torch.empty_like(dlatent_param)
207
+
208
+ if verbose:
209
+ progress = utils.ProgressWriter(num_steps)
210
+ value_tracker = utils.ValueTracker()
211
+
212
+ self._job = utils.AttributeDict(**locals())
213
+ self._job.current_step = 0
214
+
215
+ def step(self, steps=1):
216
+ """
217
+ Take a projection step.
218
+ Arguments:
219
+ steps (int): Number of steps to take. If this
220
+ exceeds the remaining steps of the projection
221
+ that amount of steps is taken instead. Default
222
+ value is 1.
223
+ """
224
+ self._check_job()
225
+
226
+ remaining_steps = self._job.num_steps - self._job.current_step
227
+ if not remaining_steps > 0:
228
+ warnings.warn(
229
+ 'Trying to take a projection step after the ' + \
230
+ 'final projection iteration has been completed.'
231
+ )
232
+ if steps < 0:
233
+ steps = remaining_steps
234
+ steps = min(remaining_steps, steps)
235
+
236
+ if not steps > 0:
237
+ return
238
+
239
+ for _ in range(steps):
240
+
241
+ if self._job.current_step >= self._job.num_steps:
242
+ break
243
+
244
+ # Hyperparameters.
245
+ t = self._job.current_step / self._job.num_steps
246
+ noise_strength = self._dlatent_std * self._job.initial_noise_factor \
247
+ * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
248
+ lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
249
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
250
+ lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
251
+ learning_rate = self._job.initial_learning_rate * lr_ramp
252
+
253
+ for param_group in self._job.opt.param_groups:
254
+ param_group['lr'] = learning_rate
255
+
256
+ dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
257
+
258
+ output = self.G_synthesis(dlatents)
259
+ assert output.size() == self._job.target.size(), \
260
+ 'target size {} does not fit output size {} of generator'.format(
261
+ self._job.target.size(), output.size())
262
+
263
+ output_scaled = self._scale_for_lpips(output)
264
+
265
+ # Main loss: LPIPS distance of output and target
266
+ lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled))
267
+
268
+ # Calculate noise regularization loss
269
+ reg_loss = 0
270
+ for p in self._job.noise_params:
271
+ size = min(p.size()[2:])
272
+ dim = p.dim() - 2
273
+ while True:
274
+ reg_loss += torch.mean(
275
+ (p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
276
+ if size <= 8:
277
+ break
278
+ p = F.interpolate(p, scale_factor=0.5, mode='area')
279
+ size = size // 2
280
+
281
+ # Combine loss, backward and update params
282
+ loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
283
+ self._job.opt.zero_grad()
284
+ loss.backward()
285
+ self._job.opt.step()
286
+
287
+ # Normalize noise values
288
+ for p in self._job.noise_params:
289
+ with torch.no_grad():
290
+ p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
291
+ p_rstd = torch.rsqrt(
292
+ torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
293
+ p.data = (p.data - p_mean) * p_rstd
294
+
295
+ self._job.current_step += 1
296
+
297
+ if self._job.verbose:
298
+ self._job.value_tracker.add('loss', float(loss))
299
+ self._job.value_tracker.add('lpips_distance', float(lpips_distance))
300
+ self._job.value_tracker.add('noise_reg', float(reg_loss))
301
+ self._job.value_tracker.add('lr', learning_rate, beta=0)
302
+ self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker))
303
+ if self._job.current_step >= self._job.num_steps:
304
+ self._job.progress.close()
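A minimal usage sketch of the projection API above, added for orientation only (not part of the committed file; it assumes a trained 2D generator saved with stylegan2.models.save and a target that already matches the generator's output size and [-1, 1] pixel range):

import torch
import stylegan2
from stylegan2.project import Projector

G = stylegan2.models.load('G.pth')    # hypothetical checkpoint path
proj = Projector(G)

# Stand-in for a real preprocessed image; it must match the generator's output size.
target = torch.randn(1, 3, 1024, 1024).clamp_(-1, 1)
proj.start(target, num_steps=1000)
for _ in range(100):
    proj.step(steps=10)               # optimize dlatents and noise in chunks

dlatents = proj.get_dlatent()         # shape (1, num_conv_layers, latent_size)
noise = proj.get_noise()              # list of optimized noise tensors
result = proj.generate()              # image synthesized from the projection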
stylegan2/train.py ADDED
@@ -0,0 +1,1013 @@
1
+ import warnings
2
+ import functools
3
+ import os
4
+ import time
5
+ import sys
6
+ import json
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.tensorboard
10
+ from torch import nn
11
+ import torchvision
12
+ try:
13
+ import apex
14
+ from apex import amp
15
+ except ImportError:
16
+ pass
17
+
18
+ from . import models, utils, loss_fns
19
+
20
+
21
+ class Trainer:
22
+ """
23
+ Class that handles training and logging for stylegan2.
24
+ For distributed training, the arguments `rank`, `world_size`,
25
+ `master_addr`, `master_port` can all be given as environment variables
26
+ (the only difference is that the keys should be uppercase).
27
+ Environment variables if available will override any python
28
+ value for the same argument.
29
+ Arguments:
30
+ G (Generator): The generator model.
31
+ D (Discriminator): The discriminator model.
32
+ latent_size (int): The size of the latent inputs.
33
+ dataset (indexable object): The dataset. Has to implement
34
+ '__getitem__' and '__len__'. If `label_size` > 0, this
35
+ dataset object has to return both a data entry and its
36
+ label when calling '__getitem__'.
37
+ device (str, int, list, torch.device): The device to run training on.
38
+ Can be a list of integers for parallel training in the same
39
+ process. Parallel training can also be achieved by spawning
40
+ separate processes and using the `rank` argument for each
41
+ process. In that case, only one device should be specified
42
+ per process.
43
+ Gs (Generator, optional): A generator copy with the current
44
+ moving average of the training generator. If not specified,
45
+ a copy of the generator is made for the moving average of
46
+ weights.
47
+ Gs_beta (float): The beta value for the moving average weights.
48
+ Default value is 1 / (2 ^(32 / 10000)).
49
+ Gs_device (str, int, torch.device, optional): The device to store
50
+ the moving average weights on. If using a different device
51
+ than what is specified for the `device` argument, updating
52
+ the moving average weights will take longer as the data
53
+ will have to be transferred over different devices. If
54
+ this argument is not specified, the same device is used
55
+ as specified in the `device` argument.
56
+ batch_size (int): The total batch size to average gradients
57
+ over. This should be the combined batch size of all used
58
+ devices (it is later divided by world size for distributed
59
+ training).
60
+ Example: We want to average gradients over 32 data
61
+ entries. To do this we just set `batch_size=32`.
62
+ Even if we train on 8 GPUs we still use the same
63
+ batch size (each GPU will take 4 data entries per
64
+ batch).
65
+ Default value is 32.
66
+ device_batch_size (int): The number of data entries that can
67
+ fit on the specified device at a time.
68
+ Example: We want to average gradients over 32 data
69
+ entries. To do this we just set `batch_size=32`.
70
+ However, our device can only handle a batch of
71
+ 4 at a time before running out of memory. We
72
+ therefore set `device_batch_size=4`. With a
73
+ single device (no distributed training), each
74
+ batch is split into 32 / 4 parts and gradients
75
+ are averaged over all these parts.
76
+ Default value is 4.
77
+ label_size (int, optional): Number of possible class labels.
78
+ This is required for conditioning the GAN with labels.
79
+ If not specified it is assumed that no labels are used.
80
+ data_workers (int): The number of spawned processes that
81
+ handle data loading. Default value is 4.
82
+ G_loss (str, callable): The loss function to use
83
+ for the generator. If string, it can be one of the
84
+ following: 'logistic', 'logistic_ns' or 'wgan'.
85
+ If not a string, the callable has to follow
86
+ the format of functions found in `stylegan2.loss_fns`.
87
+ Default value is 'logistic_ns' (non-saturating logistic).
88
+ D_loss (str, callable): The loss function to use
89
+ for the discriminator. If string, it can be one of the
90
+ following: 'logistic' or 'wgan'.
91
+ If not a string, same restriction follows as for `G_loss`.
92
+ Default value is 'logistic'.
93
+ G_reg (str, callable, None): The regularizer function to use
94
+ for the generator. If string, it can only be 'pathreg'
95
+ (pathlength regularization). A weight for the regularizer
96
+ can be passed after the string name like the following:
97
+ G_reg='pathreg:5'
98
+ This will assign a weight of 5 to the regularization loss.
99
+ If set to None, no generator regularization is performed.
100
+ Default value is 'pathreg:2'.
101
+ G_reg_interval (int): The interval at which to regularize the
102
+ generator. If set to 0, the regularization and loss gradients
103
+ are combined in a single optimization step every iteration.
104
+ If set to 1, the gradients for the regularization and loss
105
+ are used separately for two optimization steps. Any value
106
+ higher than 1 indicates that regularization should only
107
+ be performed at this interval (lazy regularization).
108
+ Default value is 4.
109
+ G_opt_class (str, class): The optimizer class for the generator.
110
+ Default value is 'Adam'.
111
+ G_opt_kwargs (dict): Keyword arguments for the generator optimizer
112
+ constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
113
+ G_reg_batch_size (int): Same as `batch_size` but only for
114
+ the regularization loss of the generator. Default value
115
+ is 16.
116
+ G_reg_device_batch_size (int): Same as `device_batch_size`
117
+ but only for the regularization loss of the generator.
118
+ Default value is 2.
119
+ D_reg (str, callable, None): The regularizer function to use
120
+ for the discriminator. If string, the following values
121
+ can be used: 'r1', 'r2', 'gp'. See doc for `G_reg` for
122
+ rest of info on regularizer format.
123
+ Default value is 'r1:10'.
124
+ D_reg_interval (int): Same as `G_reg_interval` but for the
125
+ discriminator. Default value is 16.
126
+ D_opt_class (str, class): The optimizer class for the discriminator.
127
+ Default value is 'Adam'.
128
+ D_opt_kwargs (dict): Keyword arguments for the discriminator optimizer
129
+ constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
130
+ style_mix_prob (float): The probability of passing 2 latents instead of 1
131
+ to the generator during training. Default value is 0.9.
132
+ G_iter (int): Number of generator iterations for every full training
133
+ iteration. Default value is 1.
134
+ D_iter (int): Number of discriminator iterations for every full training
135
+ iteration. Default value is 1.
136
+ pl_avg (float, torch.Tensor): The average pathlength starting value for
137
+ pathlength regularization of the generator. Default value is 0.
138
+ tensorboard_log_dir (str, optional): A path to a directory to log training values
139
+ in for tensorboard. Only used without distributed training or when
140
+ distributed training is enabled and the rank of this trainer is 0.
141
+ checkpoint_dir (str, optional): A path to a directory to save training
142
+ checkpoints to. If not specified, no checkpoints are automatically
143
+ saved during training.
144
+ checkpoint_interval (int): The interval at which to save training checkpoints.
145
+ Default value is 10000.
146
+ seen (int): The number of previously trained iterations. Used for logging.
147
+ Default value is 0.
148
+ half (bool): Use mixed precision training. Default value is False.
149
+ rank (int, optional): If set, use distributed training. Expects that
150
+ this object has been constructed with the same arguments except
151
+ for `rank` in different processes.
152
+ world_size (int, optional): If using distributed training, this specifies
153
+ the number of nodes in the training.
154
+ master_addr (str): The master address for distributed training.
155
+ Default value is '127.0.0.1'.
156
+ master_port (str): The master port for distributed training.
157
+ Default value is '23456'.
158
+ """
159
+
160
+ def __init__(self,
161
+ G,
162
+ D,
163
+ latent_size,
164
+ dataset,
165
+ device,
166
+ Gs=None,
167
+ Gs_beta=0.5 ** (32 / 10000),
168
+ Gs_device=None,
169
+ batch_size=32,
170
+ device_batch_size=4,
171
+ label_size=0,
172
+ data_workers=4,
173
+ G_loss='logistic_ns',
174
+ D_loss='logistic',
175
+ G_reg='pathreg:2',
176
+ G_reg_interval=4,
177
+ G_opt_class='Adam',
178
+ G_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
179
+ G_reg_batch_size=None,
180
+ G_reg_device_batch_size=None,
181
+ D_reg='r1:10',
182
+ D_reg_interval=16,
183
+ D_opt_class='Adam',
184
+ D_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
185
+ style_mix_prob=0.9,
186
+ G_iter=1,
187
+ D_iter=1,
188
+ pl_avg=0.,
189
+ tensorboard_log_dir=None,
190
+ checkpoint_dir=None,
191
+ checkpoint_interval=10000,
192
+ seen=0,
193
+ half=False,
194
+ rank=None,
195
+ world_size=None,
196
+ master_addr='127.0.0.1',
197
+ master_port='23456'):
198
+ assert not isinstance(G, nn.parallel.DistributedDataParallel) and \
199
+ not isinstance(D, nn.parallel.DistributedDataParallel), \
200
+ 'Encountered a model wrapped in `DistributedDataParallel`. ' + \
201
+ 'Distributed parallelism is handled by this class and can ' + \
202
+ 'not be initialized before.'
203
+ # We store the training settings in a dict that can be saved as a json file.
204
+ kwargs = locals()
205
+ # First we remove the arguments that can not be turned into json.
206
+ kwargs.pop('self')
207
+ kwargs.pop('G')
208
+ kwargs.pop('D')
209
+ kwargs.pop('Gs')
210
+ kwargs.pop('dataset')
211
+ # Some arguments may have to be turned into strings to be compatible with json.
212
+ kwargs.update(pl_avg=float(pl_avg))
213
+ if isinstance(device, torch.device):
214
+ kwargs.update(device=str(device))
215
+ if isinstance(Gs_device, torch.device):
216
+ kwargs.update(Gs_device=str(Gs_device))
217
+ self.kwargs = kwargs
218
+
219
+ if device or device == 0:
220
+ if isinstance(device, (tuple, list)):
221
+ self.device = torch.device(device[0])
222
+ else:
223
+ self.device = torch.device(device)
224
+ else:
225
+ self.device = torch.device('cpu')
226
+ if self.device.index is not None:
227
+ torch.cuda.set_device(self.device.index)
228
+ else:
229
+ assert not half, 'Mixed precision training only available ' + \
230
+ 'for CUDA devices.'
231
+ # Set up the models
232
+ self.G = G.train().to(self.device)
233
+ self.D = D.train().to(self.device)
234
+ if isinstance(device, (tuple, list)) and len(device) > 1:
235
+ assert all(isinstance(dev, int) for dev in device), \
236
+ 'Multiple devices have to be specified as a list ' + \
237
+ 'or tuple of integers corresponding to device indices.'
238
+ # TODO: Look into bug with torch.autograd.grad and nn.DataParallel
239
+ # In the meanwhile just prohibit its use together.
240
+ assert G_reg is None and D_reg is None, 'Regularization ' + \
241
+ 'currently not supported for multi-gpu training in single process. ' + \
242
+ 'Please use distributed training with one device per process instead.'
243
+
244
+ device_batch_size *= len(device)
245
+ def to_data_parallel(model):
246
+ if not isinstance(model, nn.DataParallel):
247
+ return nn.DataParallel(model, device_ids=device)
248
+ return model
249
+ self.G = to_data_parallel(self.G)
250
+ self.D = to_data_parallel(self.D)
251
+
252
+ # Default generator reg batch size is the global batch size
253
+ # unless it has been specified otherwise.
254
+ G_reg_batch_size = G_reg_batch_size or batch_size
255
+ G_reg_device_batch_size = G_reg_device_batch_size or device_batch_size
256
+
257
+ # Set up distributed training
258
+ rank = os.environ.get('RANK', rank)
259
+ if rank is not None:
260
+ rank = int(rank)
261
+ addr = os.environ.get('MASTER_ADDR', master_addr)
262
+ port = os.environ.get('MASTER_PORT', master_port)
263
+ world_size = os.environ.get('WORLD_SIZE', world_size)
264
+ assert world_size is not None, 'Distributed training ' + \
265
+ 'requires specifying world size.'
266
+ world_size = int(world_size)
267
+ assert self.device.index is not None, \
268
+ 'Distributed training is only supported for CUDA.'
269
+ assert batch_size % world_size == 0, 'Batch size has to be ' + \
270
+ 'evenly divisible by world size.'
271
+ assert G_reg_batch_size % world_size == 0, 'G reg batch size has to be ' + \
272
+ 'evenly divisible by world size.'
273
+ batch_size = batch_size // world_size
274
+ G_reg_batch_size = G_reg_batch_size // world_size
275
+ init_method = 'tcp://{}:{}'.format(addr, port)
276
+ torch.distributed.init_process_group(
277
+ backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
278
+ else:
279
+ world_size = 1
280
+ self.rank = rank
281
+ self.world_size = world_size
282
+
283
+ # Set up variable to keep track of moving average of path lengths
284
+ self.pl_avg = torch.tensor(
285
+ pl_avg, dtype=torch.float16 if half else torch.float32, device=self.device)
286
+
287
+ # Broadcast parameters from rank 0 if running distributed
288
+ self._sync_distributed(G=self.G, D=self.D, broadcast_weights=True)
289
+
290
+ # Set up moving average of generator
291
+ # Only for non-distributed training or
292
+ # if rank is 0
293
+ if not self.rank:
294
+ # Values for `rank`: None -> not distributed, 0 -> distributed and 'main' node
295
+ self.Gs = Gs
296
+ if not isinstance(Gs, utils.MovingAverageModule):
297
+ self.Gs = utils.MovingAverageModule(
298
+ from_module=self.G,
299
+ to_module=Gs,
300
+ param_beta=Gs_beta,
301
+ device=self.device if Gs_device is None else Gs_device
302
+ )
303
+ else:
304
+ self.Gs = None
305
+
306
+ # Set up loss and regularization functions
307
+ self.G_loss = get_loss_fn('G', G_loss)
308
+ self.D_loss = get_loss_fn('D', D_loss)
309
+ self.G_reg = get_reg_fn('G', G_reg, pl_avg=self.pl_avg)
310
+ self.D_reg = get_reg_fn('D', D_reg)
311
+ self.G_reg_interval = G_reg_interval
312
+ self.D_reg_interval = D_reg_interval
313
+ self.G_iter = G_iter
314
+ self.D_iter = D_iter
315
+
316
+ # Set up optimizers (adjust hyperparameters if lazy regularization is active)
317
+ self.G_opt = build_opt(self.G, G_opt_class, G_opt_kwargs, self.G_reg, self.G_reg_interval)
318
+ self.D_opt = build_opt(self.D, D_opt_class, D_opt_kwargs, self.D_reg, self.D_reg_interval)
319
+
320
+ # Set up mixed precision training
321
+ if half:
322
+ assert 'apex' in sys.modules, 'Can not run mixed precision ' + \
323
+ 'training (`half=True`) without the apex module.'
324
+ (self.G, self.D), (self.G_opt, self.D_opt) = amp.initialize(
325
+ [self.G, self.D], [self.G_opt, self.D_opt], opt_level='O1')
326
+ self.half = half
327
+
328
+ # Data
329
+ sampler = None
330
+ if self.rank is not None:
331
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
332
+ self.dataloader = torch.utils.data.DataLoader(
333
+ dataset,
334
+ batch_size=device_batch_size,
335
+ num_workers=data_workers,
336
+ shuffle=sampler is None,
337
+ pin_memory=self.device.index is not None,
338
+ drop_last=True,
339
+ sampler=sampler
340
+ )
341
+ self.dataloader_iter = None
342
+ self.prior_generator = utils.PriorGenerator(
343
+ latent_size=latent_size,
344
+ label_size=label_size,
345
+ batch_size=device_batch_size,
346
+ device=self.device
347
+ )
348
+ assert batch_size % device_batch_size == 0, \
349
+ 'Batch size has to be evenly divisible by the product of ' + \
350
+ 'device batch size and world size.'
351
+ self.subdivisions = batch_size // device_batch_size
352
+ assert G_reg_batch_size % G_reg_device_batch_size == 0, \
353
+ 'G reg batch size has to be evenly divisible by the product of ' + \
354
+ 'G reg device batch size and world size.'
355
+ self.G_reg_subdivisions = G_reg_batch_size // G_reg_device_batch_size
356
+ self.G_reg_device_batch_size = G_reg_device_batch_size
357
+
358
+ self.tb_writer = None
359
+ if tensorboard_log_dir and not self.rank:
360
+ self.tb_writer = torch.utils.tensorboard.SummaryWriter(tensorboard_log_dir)
361
+
362
+ self.label_size = label_size
363
+ self.style_mix_prob = style_mix_prob
364
+ self.checkpoint_dir = checkpoint_dir
365
+ self.checkpoint_interval = checkpoint_interval
366
+ self.seen = seen
367
+ self.metrics = {}
368
+ self.callbacks = []
369
+
370
+ def _get_batch(self):
371
+ """
372
+ Fetch a batch and its labels. If no labels are
373
+ available the returned labels will be `None`.
374
+ Returns:
375
+ data
376
+ labels
377
+ """
378
+ if self.dataloader_iter is None:
379
+ self.dataloader_iter = iter(self.dataloader)
380
+ try:
381
+ batch = next(self.dataloader_iter)
382
+ except StopIteration:
383
+ self.dataloader_iter = None
384
+ return self._get_batch()
385
+ if isinstance(batch, (tuple, list)):
386
+ if len(batch) > 1:
387
+ data, label = batch[:2]
388
+ else:
389
+ data, label = batch[0], None
390
+ else:
391
+ data, label = batch, None
392
+ if not self.label_size:
393
+ label = None
394
+ if torch.is_tensor(data):
395
+ data = data.to(self.device)
396
+ if torch.is_tensor(label):
397
+ label = label.to(self.device)
398
+ return data, label
399
+
400
+ def _sync_distributed(self, G=None, D=None, broadcast_weights=False):
401
+ """
402
+ Sync the gradients (and alternatively the weights) of
403
+ the specified networks over the distributed training
404
+ nodes. Varying buffers are broadcasted from rank 0.
405
+ If distributed training is not enabled, no action
406
+ is taken and this is a no-op function.
407
+ Arguments:
408
+ G (Generator, optional)
409
+ D (Discriminator, optional)
410
+ broadcast_weights (bool): Broadcast the weights from
411
+ node of rank 0 to all other ranks. Default
412
+ value is False.
413
+ """
414
+ if self.rank is None:
415
+ return
416
+ for net in [G, D]:
417
+ if net is None:
418
+ continue
419
+ for p in net.parameters():
420
+ if p.grad is not None:
421
+ torch.distributed.all_reduce(p.grad, async_op=True)
422
+ if broadcast_weights:
423
+ torch.distributed.broadcast(p.data, src=0, async_op=True)
424
+ if G is not None:
425
+ if G.dlatent_avg is not None:
426
+ torch.distributed.broadcast(G.dlatent_avg, src=0, async_op=True)
427
+ if self.pl_avg is not None:
428
+ torch.distributed.broadcast(self.pl_avg, src=0, async_op=True)
429
+ if G is not None or D is not None:
430
+ torch.distributed.barrier(async_op=False)
431
+
432
+ def _backward(self, loss, opt, mul=1, subdivisions=None):
433
+ """
434
+ Reduce loss by world size and subdivisions before
435
+ calling backward for the loss. Loss scaling is
436
+ performed when mixed precision training is
437
+ enabled.
438
+ Arguments:
439
+ loss (torch.Tensor)
440
+ opt (torch.optim.Optimizer)
441
+ mul (float): Loss weight. Default value is 1.
442
+ subdivisions (int, optional): The number of
443
+ subdivisions to divide by. If this is
444
+ not specified, the subdivisions from
445
+ the specified batch and device size
446
+ at construction are used.
447
+ Returns:
448
+ loss (torch.Tensor): The loss scaled by mul
449
+ and subdivisions but not by world size.
450
+ """
451
+ if loss is None:
452
+ return 0
453
+ mul /= subdivisions or self.subdivisions
454
+ mul /= self.world_size or 1
455
+ if mul != 1:
456
+ loss *= mul
457
+ if self.half:
458
+ with amp.scale_loss(loss, opt) as scaled_loss:
459
+ scaled_loss.backward()
460
+ else:
461
+ loss.backward()
462
+ # get the scalar only
463
+ return loss.item() * (self.world_size or 1)
464
+
465
+ def train(self, iterations, callbacks=None, verbose=True):
466
+ """
467
+ Train the models for a specific number of iterations.
468
+ Arguments:
469
+ iterations (int): Number of iterations to train for.
470
+ callbacks (callable, list, optional): One
471
+ or more callbacks to call at the end of each
472
+ iteration. The function is given the total
473
+ number of batches that have been processed since
474
+ this trainer object was initialized (not reset
475
+ when loading a saved checkpoint).
476
+ Default value is None (unused).
477
+ verbose (bool): Write progress to stdout.
478
+ Default value is True.
479
+ """
480
+ evaluated_metrics = {}
481
+ if self.rank:
482
+ verbose=False
483
+ if verbose:
484
+ progress = utils.ProgressWriter(iterations)
485
+ value_tracker = utils.ValueTracker()
486
+ for _ in range(iterations):
487
+ # Figure out if G and/or D should be
488
+ # regularized this iteration
489
+ G_reg = self.G_reg is not None
490
+ if self.G_reg_interval and G_reg:
491
+ G_reg = self.seen % self.G_reg_interval == 0
492
+ D_reg = self.D_reg is not None
493
+ if self.D_reg_interval and D_reg:
494
+ D_reg = self.seen % self.D_reg_interval == 0
495
+
496
+ # -----| Train G |----- #
497
+
498
+ # Disable gradients for D while training G
499
+ self.D.requires_grad_(False)
500
+
501
+ for _ in range(self.G_iter):
502
+ self.G_opt.zero_grad()
503
+
504
+ G_loss = 0
505
+ for i in range(self.subdivisions):
506
+ latents, latent_labels = self.prior_generator(
507
+ multi_latent_prob=self.style_mix_prob)
508
+ loss, _ = self.G_loss(
509
+ G=self.G,
510
+ D=self.D,
511
+ latents=latents,
512
+ latent_labels=latent_labels
513
+ )
514
+ G_loss += self._backward(loss, self.G_opt)
515
+
516
+ if G_reg:
517
+ if self.G_reg_interval:
518
+ # For lazy regularization, even if the interval
519
+ # is set to 1, the optimization step is taken
520
+ # before the gradients of the regularization are gathered.
521
+ self._sync_distributed(G=self.G)
522
+ self.G_opt.step()
523
+ self.G_opt.zero_grad()
524
+ G_reg_loss = 0
525
+ # Pathreg is expensive to compute which
526
+ # is why G regularization has its own settings
527
+ # for subdivisions and batch size.
528
+ for i in range(self.G_reg_subdivisions):
529
+ latents, latent_labels = self.prior_generator(
530
+ batch_size=self.G_reg_device_batch_size,
531
+ multi_latent_prob=self.style_mix_prob
532
+ )
533
+ _, reg_loss = self.G_reg(
534
+ G=self.G,
535
+ latents=latents,
536
+ latent_labels=latent_labels
537
+ )
538
+ G_reg_loss += self._backward(
539
+ reg_loss,
540
+ self.G_opt, mul=self.G_reg_interval or 1,
541
+ subdivisions=self.G_reg_subdivisions
542
+ )
543
+ self._sync_distributed(G=self.G)
544
+ self.G_opt.step()
545
+ # Update moving average of weights after
546
+ # each G training subiteration
547
+ if self.Gs is not None:
548
+ self.Gs.update()
549
+
550
+ # Re-enable gradients for D
551
+ self.D.requires_grad_(True)
552
+
553
+ # -----| Train D |----- #
554
+
555
+ # Disable gradients for G while training D
556
+ self.G.requires_grad_(False)
557
+
558
+ for _ in range(self.D_iter):
559
+ self.D_opt.zero_grad()
560
+
561
+ D_loss = 0
562
+ for i in range(self.subdivisions):
563
+ latents, latent_labels = self.prior_generator(
564
+ multi_latent_prob=self.style_mix_prob)
565
+ reals, real_labels = self._get_batch()
566
+ loss, _ = self.D_loss(
567
+ G=self.G,
568
+ D=self.D,
569
+ latents=latents,
570
+ latent_labels=latent_labels,
571
+ reals=reals,
572
+ real_labels=real_labels
573
+ )
574
+ D_loss += self._backward(loss, self.D_opt)
575
+
576
+ if D_reg:
577
+ if self.D_reg_interval:
578
+ # For lazy regularization, even if the interval
579
+ # is set to 1, the optimization step is taken
580
+ # before the gradients of the regularization are gathered.
581
+ self._sync_distributed(D=self.D)
582
+ self.D_opt.step()
583
+ self.D_opt.zero_grad()
584
+ D_reg_loss = 0
585
+ for i in range(self.subdivisions):
586
+ latents, latent_labels = self.prior_generator(
587
+ multi_latent_prob=self.style_mix_prob)
588
+ reals, real_labels = self._get_batch()
589
+ _, reg_loss = self.D_reg(
590
+ G=self.G,
591
+ D=self.D,
592
+ latents=latents,
593
+ latent_labels=latent_labels,
594
+ reals=reals,
595
+ real_labels=real_labels
596
+ )
597
+ D_reg_loss += self._backward(
598
+ reg_loss, self.D_opt, mul=self.D_reg_interval or 1)
599
+ self._sync_distributed(D=self.D)
600
+ self.D_opt.step()
601
+
602
+ # Re-enable grads for G
603
+ self.G.requires_grad_(True)
604
+
605
+ if self.tb_writer is not None or verbose:
606
+ # In case verbose is true and tensorboard logging enabled
607
+ # we calculate grad norm here to only do it once as well
608
+ # as making sure we do it before any metrics that may
609
+ # possibly zero the grads.
610
+ G_grad_norm = utils.get_grad_norm_from_optimizer(self.G_opt)
611
+ D_grad_norm = utils.get_grad_norm_from_optimizer(self.D_opt)
612
+
613
+ for name, metric in self.metrics.items():
614
+ if not metric['interval'] or self.seen % metric['interval'] == 0:
615
+ evaluated_metrics[name] = metric['eval_fn']()
616
+
617
+ # Printing and logging
618
+
619
+ # Tensorboard logging
620
+ if self.tb_writer is not None:
621
+ self.tb_writer.add_scalar('Loss/G_loss', G_loss, self.seen)
622
+ if G_reg:
623
+ self.tb_writer.add_scalar('Loss/G_reg', G_reg_loss, self.seen)
624
+ self.tb_writer.add_scalar('Grad_norm/G_reg', G_grad_norm, self.seen)
625
+ self.tb_writer.add_scalar('Params/pl_avg', self.pl_avg, self.seen)
626
+ else:
627
+ self.tb_writer.add_scalar('Grad_norm/G_loss', G_grad_norm, self.seen)
628
+ self.tb_writer.add_scalar('Loss/D_loss', D_loss, self.seen)
629
+ if D_reg:
630
+ self.tb_writer.add_scalar('Loss/D_reg', D_reg_loss, self.seen)
631
+ self.tb_writer.add_scalar('Grad_norm/D_reg', D_grad_norm, self.seen)
632
+ else:
633
+ self.tb_writer.add_scalar('Grad_norm/D_loss', D_grad_norm, self.seen)
634
+ for name, value in evaluated_metrics.items():
635
+ self.tb_writer.add_scalar('Metrics/{}'.format(name), value, self.seen)
636
+
637
+ # Printing
638
+ if verbose:
639
+ value_tracker.add('seen', self.seen + 1, beta=0)
640
+ value_tracker.add('G_lr', self.G_opt.param_groups[0]['lr'], beta=0)
641
+ value_tracker.add('G_loss', G_loss)
642
+ if G_reg:
643
+ value_tracker.add('G_reg', G_reg_loss)
644
+ value_tracker.add('G_reg_grad_norm', G_grad_norm)
645
+ value_tracker.add('pl_avg', self.pl_avg, beta=0)
646
+ else:
647
+ value_tracker.add('G_loss_grad_norm', G_grad_norm)
648
+ value_tracker.add('D_lr', self.D_opt.param_groups[0]['lr'], beta=0)
649
+ value_tracker.add('D_loss', D_loss)
650
+ if D_reg:
651
+ value_tracker.add('D_reg', D_reg_loss)
652
+ value_tracker.add('D_reg_grad_norm', D_grad_norm)
653
+ else:
654
+ value_tracker.add('D_loss_grad_norm', D_grad_norm)
655
+ for name, value in evaluated_metrics.items():
656
+ value_tracker.add(name, value, beta=0)
657
+ progress.write(str(value_tracker))
658
+
659
+ # Callback
660
+ for callback in utils.to_list(callbacks) + self.callbacks:
661
+ callback(self.seen)
662
+
663
+ self.seen += 1
664
+
665
+ # clear cache
666
+ torch.cuda.empty_cache()
667
+ # Handle checkpointing
668
+ if not self.rank and self.checkpoint_dir and self.checkpoint_interval:
669
+ if self.seen % self.checkpoint_interval == 0:
670
+ checkpoint_path = os.path.join(
671
+ self.checkpoint_dir,
672
+ '{}_{}'.format(self.seen, time.strftime('%Y-%m-%d_%H-%M-%S'))
673
+ )
674
+ self.save_checkpoint(checkpoint_path)
675
+
676
+ if verbose:
677
+ progress.close()
678
+
679
+ def register_metric(self, name, eval_fn, interval):
680
+ """
681
+ Add a metric. This will be evaluated every `interval`
682
+ training iteration. Used by tensorboard and progress
683
+ updates written to stdout while training.
684
+ Arguments:
685
+ name (str): A name for the metric. If a metric with
686
+ this name already exists it will be overwritten.
687
+ eval_fn (callable): A function that evaluates the metric
688
+ and returns a python number.
689
+ interval (int): The interval to evaluate at.
690
+ """
691
+ self.metrics[name] = {'eval_fn': eval_fn, 'interval': interval}
692
+
693
+ def remove_metric(self, name):
694
+ """
695
+ Remove a metric that was previously registered.
696
+ Arguments:
697
+ name (str): Name of the metric.
698
+ """
699
+ if name in self.metrics:
700
+ del self.metrics[name]
701
+ else:
702
+ warnings.warn(
703
+ 'Attempting to remove metric {} '.format(name) + \
704
+ 'which does not exist.'
705
+ )
706
+
707
+ def generate_images(self,
708
+ num_images,
709
+ seed=None,
710
+ truncation_psi=None,
711
+ truncation_cutoff=None,
712
+ label=None,
713
+ pixel_min=-1,
714
+ pixel_max=1):
715
+ """
716
+ Generate some images with the generator and transform them into PIL
717
+ images and return them as a list.
718
+ Arguments:
719
+ num_images (int): Number of images to generate.
720
+ seed (int, optional): The seed for the random generation
721
+ of input latent values.
722
+ truncation_psi (float): See stylegan2.model.Generator.set_truncation()
723
+ Default value is None.
724
+ truncation_cutoff (int): See stylegan2.model.Generator.set_truncation()
725
+ label (int, list, optional): Label to condition all generated images with
726
+ or multiple labels, one for each generated image.
727
+ pixel_min (float): The min value in the pixel range of the generator.
728
+ Default value is -1.
729
+ pixel_max (float): The max value in the pixel range of the generator.
730
+ Default value is 1.
731
+ Returns:
732
+ images (list): List of PIL images.
733
+ """
734
+ if seed is None:
735
+ seed = int(10000 * time.time())
736
+ latents, latent_labels = self.prior_generator(num_images, seed=seed)
737
+ if label:
738
+ assert latent_labels is not None, 'Can not specify label when no labels ' + \
739
+ 'are used by this model.'
740
+ label = utils.to_list(label)
741
+ assert all(isinstance(l, int) for l in label), '`label` can only consist of ' + \
742
+ 'one or more python integers.'
743
+ assert len(label) == 1 or len(label) == num_images, '`label` can either ' + \
744
+ 'specify one label to use for all images or a list of labels of the ' + \
745
+ 'same length as number of images. Received {} labels '.format(len(label)) + \
746
+ 'but {} images are to be generated.'.format(num_images)
747
+ if len(label) == 1:
748
+ latent_labels.fill_(label[0])
749
+ else:
750
+ latent_labels = torch.tensor(label).to(latent_labels)
751
+ self.Gs.set_truncation(
752
+ truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
753
+ with torch.no_grad():
754
+ generated = self.Gs(latents=latents, labels=latent_labels)
755
+ assert generated.dim() - 2 == 2, 'Can only generate images when using a ' + \
756
+ 'network built for 2-dimensional data.'
757
+ assert generated.dim() == 4, 'Only generators that produce 2d data ' + \
758
+ 'can be used to generate images.'
759
+ return utils.tensor_to_PIL(generated, pixel_min=pixel_min, pixel_max=pixel_max)
760
+
761
+ def log_images_tensorboard(self, images, name, resize=256):
762
+ """
763
+ Log a list of images to tensorboard by first turning
764
+ them into a grid. Can not be performed if rank > 0
765
+ or tensorboard_log_dir was not given at construction.
766
+ Arguments:
767
+ images (list): List of PIL images.
768
+ name (str): The name to log images for.
769
+ resize (int, tuple): The height and width to use for
770
+ each image in the grid. Default value is 256.
771
+ """
772
+ assert self.tb_writer is not None, \
773
+ 'No tensorboard log dir was specified ' + \
774
+ 'when constructing this object.'
775
+ image = utils.stack_images_PIL(images, individual_img_size=resize)
776
+ image = torchvision.transforms.ToTensor()(image)
777
+ self.tb_writer.add_image(name, image, self.seen)
778
+
779
+ def add_tensorboard_image_logging(self,
780
+ name,
781
+ interval,
782
+ num_images,
783
+ resize=256,
784
+ seed=None,
785
+ truncation_psi=None,
786
+ truncation_cutoff=None,
787
+ label=None,
788
+ pixel_min=-1,
789
+ pixel_max=1):
790
+ """
791
+ Set up tensorboard logging of generated images to be performed
792
+ at a certain training interval. If distributed training is set up
793
+ and this object does not have the rank 0, no logging will be performed
794
+ by this object.
795
+ All arguments except the ones mentioned below have their description
796
+ in the docstring of `generate_images()` and `log_images_tensorboard()`.
797
+ Arguments:
798
+ interval (int): The interval at which to log generated images.
799
+ """
800
+ if self.rank:
801
+ return
802
+ def callback(seen):
803
+ if seen % interval == 0:
804
+ images = self.generate_images(
805
+ num_images=num_images,
806
+ seed=seed,
807
+ truncation_psi=truncation_psi,
808
+ truncation_cutoff=truncation_cutoff,
809
+ label=label,
810
+ pixel_min=pixel_min,
811
+ pixel_max=pixel_max
812
+ )
813
+ self.log_images_tensorboard(
814
+ images=images,
815
+ name=name,
816
+ resize=resize
817
+ )
818
+ self.callbacks.append(callback)
819
+
820
+ def save_checkpoint(self, dir_path):
821
+ """
822
+ Save the current state of this trainer as a checkpoint.
823
+ NOTE: The dataset can not be serialized and saved so this
824
+ has to be reconstructed and given when loading this checkpoint.
825
+ Arguments:
826
+ dir_path (str): The checkpoint path.
827
+ """
828
+ if not os.path.exists(dir_path):
829
+ os.makedirs(dir_path)
830
+ else:
831
+ assert os.path.isdir(dir_path), '`dir_path` points to a file.'
832
+ kwargs = self.kwargs.copy()
833
+ # Update arguments that may have changed since construction
834
+ kwargs.update(
835
+ seen=self.seen,
836
+ pl_avg=float(self.pl_avg)
837
+ )
838
+ with open(os.path.join(dir_path, 'kwargs.json'), 'w') as fp:
839
+ json.dump(kwargs, fp)
840
+ torch.save(self.G_opt.state_dict(), os.path.join(dir_path, 'G_opt.pth'))
841
+ torch.save(self.D_opt.state_dict(), os.path.join(dir_path, 'D_opt.pth'))
842
+ models.save(self.G, os.path.join(dir_path, 'G.pth'))
843
+ models.save(self.D, os.path.join(dir_path, 'D.pth'))
844
+ if self.Gs is not None:
845
+ models.save(self.Gs, os.path.join(dir_path, 'Gs.pth'))
846
+
847
+ @classmethod
848
+ def load_checkpoint(cls, checkpoint_path, dataset, **kwargs):
849
+ """
850
+ Load a checkpoint into a new Trainer object and return that
851
+ object. If the path specified points at a folder containing
852
+ multiple checkpoints, the latest one will be used.
853
+ The dataset can not be serialized and saved so it is required
854
+ to be explicitly given when loading a checkpoint.
855
+ Arguments:
856
+ checkpoint_path (str): Path to a checkpoint or to a folder
857
+ containing one or more checkpoints.
858
+ dataset (indexable): The dataset to use.
859
+ **kwargs (keyword arguments): Any other arguments to override
860
+ the ones saved in the checkpoint. Useful for when training
861
+ is continued on a different device or when distributed training
862
+ is changed.
863
+ """
864
+ checkpoint_path = _find_checkpoint(checkpoint_path)
865
+ _is_checkpoint(checkpoint_path, enforce=True)
866
+ with open(os.path.join(checkpoint_path, 'kwargs.json'), 'r') as fp:
867
+ loaded_kwargs = json.load(fp)
868
+ loaded_kwargs.update(**kwargs)
869
+ device = torch.device('cpu')
870
+ if isinstance(loaded_kwargs['device'], (list, tuple)):
871
+ device = torch.device(loaded_kwargs['device'][0])
872
+ for name in ['G', 'D']:
873
+ fpath = os.path.join(checkpoint_path, name + '.pth')
874
+ loaded_kwargs[name] = models.load(fpath, map_location=device)
875
+ if os.path.exists(os.path.join(checkpoint_path, 'Gs.pth')):
876
+ loaded_kwargs['Gs'] = models.load(
877
+ os.path.join(checkpoint_path, 'Gs.pth'),
878
+ map_location=device if loaded_kwargs['Gs_device'] is None \
879
+ else torch.device(loaded_kwargs['Gs_device'])
880
+ )
881
+ obj = cls(dataset=dataset, **loaded_kwargs)
882
+ for name in ['G_opt', 'D_opt']:
883
+ fpath = os.path.join(checkpoint_path, name + '.pth')
884
+ state_dict = torch.load(fpath, map_location=device)
885
+ getattr(obj, name).load_state_dict(state_dict)
886
+ return obj
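A condensed sketch of how this Trainer is typically driven, added for orientation only (not part of the committed file; the model constructors are elided because they are defined in stylegan2.models, and `dataset` stands for any indexable dataset that returns image tensors in [-1, 1]):

from stylegan2 import models, train

G = models.Generator(...)        # elided: build or load a generator
D = models.Discriminator(...)    # elided: build or load a discriminator
trainer = train.Trainer(
    G, D, latent_size=512, dataset=dataset, device=0,
    batch_size=32, device_batch_size=4,
    tensorboard_log_dir='logs', checkpoint_dir='checkpoints')
trainer.add_tensorboard_image_logging('generated', interval=1000, num_images=4)
trainer.train(iterations=100000)

# Training can later be resumed from the newest checkpoint in the directory:
trainer = train.Trainer.load_checkpoint('checkpoints', dataset, device=0)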
887
+
888
+
889
+ #----------------------------------------------------------------------------
890
+ # Checkpoint helper functions
891
+
892
+
893
+ def _is_checkpoint(dir_path, enforce=False):
894
+ if not dir_path:
895
+ if enforce:
896
+ raise ValueError('Not a checkpoint.')
897
+ return False
898
+ if not os.path.exists(dir_path):
899
+ if enforce:
900
+ raise FileNotFoundError('{} could not be found.'.format(dir_path))
901
+ return False
902
+ if not os.path.isdir(dir_path):
903
+ if enforce:
904
+ raise NotADirectoryError('{} is not a directory.'.format(dir_path))
905
+ return False
906
+ fnames = os.listdir(dir_path)
907
+ for fname in ['G.pth', 'D.pth', 'G_opt.pth', 'D_opt.pth', 'kwargs.json']:
908
+ if fname not in fnames:
909
+ if enforce:
910
+ raise FileNotFoundError(
911
+ 'Could not find {} in {}.'.format(fname, dir_path))
912
+ return False
913
+ return True
914
+
915
+
916
+ def _find_checkpoint(dir_path):
917
+ if not dir_path:
918
+ return None
919
+ if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
920
+ return None
921
+ if _is_checkpoint(dir_path):
922
+ return dir_path
923
+ checkpoint_names = []
924
+ for name in os.listdir(dir_path):
925
+ if _is_checkpoint(os.path.join(dir_path, name)):
926
+ checkpoint_names.append(name)
927
+ if not checkpoint_names:
928
+ return None
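+ # Checkpoint folders are expected to be named '<iteration>_<timestamp>';
+ # the helpers below sort by iteration and then (stably) by timestamp, and
+ # the last entry is treated as the latest checkpoint.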
929
+ def get_iteration(name):
930
+ return int(name.split('_')[0])
931
+ def get_timestamp(name):
932
+ return '_'.join(name.split('_')[1:])
933
+ # Python sort is stable, meaning that this sort operation
934
+ # will guarantee that the order of values after the first
935
+ # sort will stay for a set of values that have the same
936
+ # key value.
937
+ checkpoint_names = sorted(
938
+ sorted(checkpoint_names, key=get_iteration), key=get_timestamp)
939
+ return os.path.join(dir_path, checkpoint_names[-1])
940
+
941
+
942
+ #----------------------------------------------------------------------------
943
+ # Optimizer builder
944
+
945
+
946
+ def build_opt(net, opt_class, opt_kwargs, reg, reg_interval):
947
+ opt_kwargs['lr'] = opt_kwargs.get('lr', 1e-3)
948
+ if reg not in [None, False] and reg_interval:
949
+ mb_ratio = reg_interval / (reg_interval + 1.)
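+ # Lazy regularization: the regularizer only runs every `reg_interval`
+ # minibatches, so the optimizer hyperparameters are rescaled to keep the
+ # effective update comparable. E.g. reg_interval=16 gives mb_ratio = 16/17
+ # ~= 0.94, turning lr=0.002 into ~0.00188 and beta2=0.99 into 0.99**0.94 ~= 0.991.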
950
+ opt_kwargs['lr'] *= mb_ratio
951
+ if 'momentum' in opt_kwargs:
952
+ opt_kwargs['momentum'] = opt_kwargs['momentum'] ** mb_ratio
953
+ if 'betas' in opt_kwargs:
954
+ betas = opt_kwargs['betas']
955
+ opt_kwargs['betas'] = (betas[0] ** mb_ratio, betas[1] ** mb_ratio)
956
+ if isinstance(opt_class, str):
957
+ opt_class = getattr(torch.optim, opt_class.title())
958
+ return opt_class(net.parameters(), **opt_kwargs)
959
+
960
+
961
+ #----------------------------------------------------------------------------
962
+ # Reg and loss function fetchers
963
+
964
+
965
+ _LOSS_FNS = {
966
+ 'G': {
967
+ 'logistic': loss_fns.G_logistic,
968
+ 'logistic_ns': loss_fns.G_logistic_ns,
969
+ 'wgan': loss_fns.G_wgan
970
+ },
971
+ 'D': {
972
+ 'logistic': loss_fns.D_logistic,
973
+ 'wgan': loss_fns.D_wgan
974
+ }
975
+ }
976
+ def get_loss_fn(net, loss):
977
+ if callable(loss):
978
+ return loss
979
+ net = net.upper()
980
+ assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
981
+ loss = loss.lower()
982
+ for name in _LOSS_FNS[net].keys():
983
+ if loss == name:
984
+ return _LOSS_FNS[net][name]
985
+ raise ValueError('Unknown {} loss {}'.format(net, loss))
986
+
987
+
988
+ _REG_FNS = {
989
+ 'G': {
990
+ 'pathreg': loss_fns.G_pathreg
991
+ },
992
+ 'D': {
993
+ 'r1': loss_fns.D_r1,
994
+ 'r2': loss_fns.D_r2,
995
+ 'gp': loss_fns.D_gp,
996
+ }
997
+ }
998
+ def get_reg_fn(net, reg, **kwargs):
999
+ if reg is None:
1000
+ return None
1001
+ if callable(reg):
1002
+ return functools.partial(reg, **kwargs)
1003
+ net = net.upper()
1004
+ assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
1005
+ reg = reg.lower()
1006
+ gamma = None
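+ # The regularizer name may carry its gamma weight: any digits (and '.') left
+ # after stripping the name are parsed, e.g. 'r1-10' resolves to the r1 penalty
+ # with gamma=10.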
1007
+ for name in _REG_FNS[net].keys():
1008
+ if reg.startswith(name):
1009
+ gamma_chars = [c for c in reg.replace(name, '') if c.isdigit() or c == '.']
1010
+ if gamma_chars:
1011
+ kwargs.update(gamma=float(''.join(gamma_chars)))
1012
+ return functools.partial(_REG_FNS[net][name], **kwargs)
1013
+ raise ValueError('Unknown regularizer {}'.format(reg))
stylegan2/utils.py ADDED
@@ -0,0 +1,726 @@
1
+ import time
2
+ import numbers
3
+ import re
4
+ import sys
5
+ import collections
6
+ import argparse
7
+ import yaml
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ import torchvision
14
+ try:
15
+ import tqdm
16
+ except ImportError:
17
+ pass
18
+ try:
19
+ from IPython.display import display as notebook_display
20
+ from IPython.display import clear_output as notebook_clear
21
+ except ImportError:
22
+ pass
23
+
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Miscellaneous utils
27
+
28
+
29
+ class AttributeDict(dict):
30
+ """
31
+ Dict where values can be accessed using attribute syntax.
32
+ Same as "EasyDict" in the NVIDIA stylegan git repository.
33
+ """
34
+
35
+ def __getattr__(self, name):
36
+ try:
37
+ return self[name]
38
+ except KeyError:
39
+ raise AttributeError(name)
40
+
41
+ def __setattr__(self, name, value):
42
+ self[name] = value
43
+
44
+ def __delattr__(self, name):
45
+ del self[name]
46
+
47
+ def __getstate__(self):
48
+ return dict(**self)
49
+
50
+ def __setstate__(self, state):
51
+ self.update(**state)
52
+
53
+ def __repr__(self):
54
+ return '{}({})'.format(
55
+ self.__class__.__name__,
56
+ ', '.join('{}={}'.format(key, value) for key, value in self.items())
57
+ )
58
+
59
+ @classmethod
60
+ def convert_dict_recursive(cls, obj):
61
+ if isinstance(obj, dict):
62
+ for key in list(obj.keys()):
63
+ obj[key] = cls.convert_dict_recursive(obj[key])
64
+ if not isinstance(obj, cls):
65
+ return cls(**obj)
66
+ return obj
67
+
68
+
69
+ class Timer:
70
+
71
+ def __init__(self):
72
+ self.reset()
73
+
74
+ def __enter__(self):
75
+ self._t0 = time.time()
76
+
77
+ def __exit__(self, *args):
78
+ self._t += time.time() - self._t0
79
+
80
+ def value(self):
81
+ return self._t
82
+
83
+ def reset(self):
84
+ self._t = 0
85
+
86
+ def __str__(self):
87
+ """
88
+ Get a string representation of the recorded time.
89
+ Returns:
90
+ time_as_string (str)
91
+ """
92
+ value = self.value()
93
+ if not value or value >= 100:
94
+ return '{} s'.format(int(value))
95
+ elif value >= 1:
96
+ return '{:.3g} s'.format(value)
97
+ elif value >= 1e-3:
98
+ return '{:.3g} ms'.format(value * 1e+3)
99
+ elif value >= 1e-6:
100
+ return '{:.3g} us'.format(value * 1e+6)
101
+ elif value >= 1e-9:
102
+ return '{:.3g} ns'.format(value * 1e+9)
103
+ else:
104
+ return '{:.2E} s'.format(value)
105
+
106
+
107
+ def to_list(values):
108
+ if values is None:
109
+ return []
110
+ if isinstance(values, tuple):
111
+ return list(values)
112
+ if not isinstance(values, list):
113
+ return [values]
114
+ return values
115
+
116
+
117
+ def lerp(a, b, beta):
118
+ if isinstance(beta, numbers.Number):
119
+ if beta == 1:
120
+ return b
121
+ elif beta == 0:
122
+ return a
123
+ if torch.is_tensor(a) and a.dtype == torch.float32:
124
+ # torch lerp only available for fp32
125
+ return torch.lerp(a, b, beta)
126
+ # More numerically stable than a + beta * (b - a)
127
+ return (1 - beta) * a + beta * b
128
+
129
+
130
+ def _normalize(v):
131
+ return v * torch.rsqrt(torch.sum(v ** 2, dim=-1, keepdim=True))
132
+
133
+
134
+ def slerp(a, b, beta):
135
+ assert a.size() == b.size(), 'Size mismatch between ' + \
136
+ 'slerp arguments, received {} and {}'.format(a.size(), b.size())
137
+ if not torch.is_tensor(beta):
138
+ beta = torch.tensor(beta).to(a)
139
+ a = _normalize(a)
140
+ b = _normalize(b)
141
+ d = torch.sum(a * b, axis=-1, keepdim=True)
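+ # d is the cosine of the angle between the normalized endpoints; the result
+ # lies at angle beta * acos(d) along the great circle from a towards b,
+ # with c the unit component of b orthogonal to a.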
142
+ p = beta * torch.acos(d)
143
+ c = _normalize(b - d * a)
144
+ d = a * torch.cos(p) + c * torch.sin(p)
145
+ return _normalize(d)
146
+
147
+
148
+ #----------------------------------------------------------------------------
149
+ # Command line utils
150
+
151
+
152
+ def _parse_configs(configs):
153
+ kwargs = {}
154
+ for config in configs:
155
+ with open(config, 'r') as fp:
156
+ kwargs.update(yaml.safe_load(fp))
157
+ return kwargs
158
+
159
+
160
+ class ConfigArgumentParser(argparse.ArgumentParser):
161
+
162
+ _CONFIG_ARG_KEY = '_configs'
163
+
164
+ def __init__(self, *args, **kwargs):
165
+ super(ConfigArgumentParser, self).__init__(*args, **kwargs)
166
+ self.add_argument(
167
+ self._CONFIG_ARG_KEY,
168
+ nargs='*',
169
+ help='Any yaml-style config file whose values will override the defaults of this argument parser.',
170
+ type=str
171
+ )
172
+
173
+ def parse_args(self, args=None):
174
+ config_args = _parse_configs(
175
+ getattr(
176
+ super(ConfigArgumentParser, self).parse_args(args),
177
+ self._CONFIG_ARG_KEY
178
+ )
179
+ )
180
+ self.set_defaults(**config_args)
181
+ return super(ConfigArgumentParser, self).parse_args(args)
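+ # Usage sketch (illustration only): yaml files passed as the positional
+ # '_configs' argument are read first and their key/value pairs override the
+ # parser defaults, e.g. a file with the line "batch_size: 32" would change
+ # the default of a (hypothetical) --batch_size argument before parsing.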
182
+
183
+
184
+ def bool_type(v):
185
+ if isinstance(v, bool):
186
+ return v
187
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
188
+ return True
189
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
190
+ return False
191
+ else:
192
+ raise argparse.ArgumentTypeError('Boolean value expected.')
193
+
194
+
195
+ def range_type(s):
196
+ """
197
+ Accept either a comma separated list of numbers
198
+ 'a,b,c' or a range 'a-c' and return as a list of ints.
199
+ """
200
+ range_re = re.compile(r'^(\d+)-(\d+)$')
201
+ m = range_re.match(s)
202
+ if m:
203
+ return range(int(m.group(1)), int(m.group(2))+1)
204
+ vals = s.split(',')
205
+ return [int(x) for x in vals]
206
+
207
+
208
+ #----------------------------------------------------------------------------
209
+ # Dataset and generation of latents
210
+
211
+
212
+ class ResizeTransform:
213
+
214
+ def __init__(self, height, width, resize=True, mode='bicubic'):
215
+ if resize:
216
+ assert height and width, 'Height and width have to be given ' + \
217
+ 'when resizing data.'
218
+ self.height = height
219
+ self.width = width
220
+ self.resize = resize
221
+ self.mode = mode
222
+
223
+ def __call__(self, tensor):
224
+ if self.height and self.width:
225
+ if tensor.size(1) != self.height or tensor.size(2) != self.width:
226
+ if self.resize:
227
+ kwargs = {}
228
+ if 'cubic' in self.mode or 'linear' in self.mode:
229
+ kwargs.update(align_corners=False)
230
+ tensor = F.interpolate(
231
+ tensor.unsqueeze(0),
232
+ size=(self.height, self.width),
233
+ mode=self.mode,
234
+ **kwargs
235
+ ).squeeze(0)
236
+ else:
237
+ raise ValueError(
238
+ 'Data shape incorrect, expected ({},{}) '.format(self.width, self.height) + \
239
+ 'but got ({},{}) (width, height)'.format(tensor.size(2), tensor.size(1))
240
+ )
241
+ return tensor
242
+
243
+
244
+ def _PIL_RGB_loader(path):
245
+ return Image.open(path).convert('RGB')
246
+
247
+
248
+ def _PIL_grayscale_loader(path):
249
+ return Image.open(path).convert('L')
250
+
251
+
252
+ class ImageFolder(torchvision.datasets.ImageFolder):
253
+
254
+ def __init__(self,
255
+ *args,
256
+ mirror=False,
257
+ pixel_min=-1,
258
+ pixel_max=1,
259
+ height=None,
260
+ width=None,
261
+ resize=False,
262
+ resize_mode='bicubic',
263
+ grayscale=False,
264
+ **kwargs):
265
+ super(ImageFolder, self).__init__(
266
+ *args,
267
+ loader=_PIL_grayscale_loader if grayscale else _PIL_RGB_loader,
268
+ **kwargs
269
+ )
270
+ transforms = []
271
+ if mirror:
272
+ transforms.append(torchvision.transforms.RandomHorizontalFlip())
273
+ transforms.append(torchvision.transforms.ToTensor())
274
+ transforms.append(
275
+ torchvision.transforms.Normalize(
276
+ mean=[-(pixel_min / (pixel_max - pixel_min))],
277
+ std=[1. / (pixel_max - pixel_min)]
278
+ )
279
+ )
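+ # The Normalize above rescales ToTensor's [0, 1] output to
+ # [pixel_min, pixel_max]: (x - mean) / std == x * (pixel_max - pixel_min) + pixel_min,
+ # so the defaults (-1, 1) map x to 2x - 1.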
280
+ transforms.append(ResizeTransform(
281
+ height=height, width=width, resize=resize, mode=resize_mode))
282
+ self.transform = torchvision.transforms.Compose(transforms)
283
+
284
+ def _find_classes(self, *args, **kwargs):
285
+ classes, class_to_idx = super(ImageFolder, self)._find_classes(*args, **kwargs)
286
+ if not classes:
287
+ classes = ['']
288
+ class_to_idx = {'': 0}
289
+ return classes, class_to_idx
290
+
291
+
292
+ class PriorGenerator:
293
+
294
+ def __init__(self, latent_size, label_size, batch_size, device):
295
+ self.latent_size = latent_size
296
+ self.label_size = label_size
297
+ self.batch_size = batch_size
298
+ self.device = device
299
+
300
+ def __iter__(self):
301
+ return self
302
+
303
+ def __next__(self):
304
+ return self()
305
+
306
+ def __call__(self, batch_size=None, multi_latent_prob=0, seed=None):
307
+ if batch_size is None:
308
+ batch_size = self.batch_size
309
+ shape = [batch_size, self.latent_size]
310
+ if multi_latent_prob:
311
+ if seed is not None:
312
+ np.random.seed(seed)
313
+ if np.random.uniform() < multi_latent_prob:
314
+ shape = [batch_size, 2, self.latent_size]
315
+ if seed is not None:
316
+ torch.manual_seed(seed)
317
+ latents = torch.empty(*shape, device=self.device).normal_()
318
+ labels = None
319
+ if self.label_size:
320
+ label_shape = [batch_size]
321
+ labels = torch.randint(0, self.label_size, label_shape, device=self.device)
322
+ return latents, labels
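+ # Usage sketch (illustration only; the sizes are assumptions):
+ # prior = PriorGenerator(latent_size=512, label_size=0, batch_size=8, device='cpu')
+ # latents, labels = prior(multi_latent_prob=0.9)  # labels is None when label_size is 0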
323
+
324
+
325
+ #----------------------------------------------------------------------------
326
+ # Training utils
327
+
328
+
329
+ class MovingAverageModule:
330
+
331
+ def __init__(self,
332
+ from_module,
333
+ to_module=None,
334
+ param_beta=0.995,
335
+ buffer_beta=0,
336
+ device=None):
337
+ from_module = unwrap_module(from_module)
338
+ to_module = unwrap_module(to_module)
339
+ if device is None:
340
+ module = from_module
341
+ if to_module is not None:
342
+ module = to_module
343
+ device = next(module.parameters()).device
344
+ else:
345
+ device = torch.device(device)
346
+ self.from_module = from_module
347
+ if to_module is None:
348
+ self.module = from_module.clone().to(device)
349
+ else:
350
+ assert type(to_module) == type(from_module), \
351
+ 'Mismatch between type of source and target module.'
352
+ assert set(self._get_named_parameters(to_module).keys()) \
353
+ == set(self._get_named_parameters(from_module).keys()), \
354
+ 'Mismatch between parameters of source and target module.'
355
+ assert set(self._get_named_buffers(to_module).keys()) \
356
+ == set(self._get_named_buffers(from_module).keys()), \
357
+ 'Mismatch between buffers of source and target module.'
358
+ self.module = to_module.to(device)
359
+ self.module.eval().requires_grad_(False)
360
+ self.param_beta = param_beta
361
+ self.buffer_beta = buffer_beta
362
+ self.device = device
363
+
364
+ def __getattr__(self, name):
365
+ try:
366
+ return super(object, self).__getattr__(name)
367
+ except AttributeError:
368
+ return getattr(self.module, name)
369
+
370
+ def update(self):
371
+ self._update_data(
372
+ from_data=self._get_named_parameters(self.from_module),
373
+ to_data=self._get_named_parameters(self.module),
374
+ beta=self.param_beta
375
+ )
376
+ self._update_data(
377
+ from_data=self._get_named_buffers(self.from_module),
378
+ to_data=self._get_named_buffers(self.module),
379
+ beta=self.buffer_beta
380
+ )
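+ # The running copy is kept as an exponential moving average of the live
+ # module: to <- lerp(from, to, beta). With the default param_beta=0.995 the
+ # parameters decay slowly toward the source, while buffers (buffer_beta=0)
+ # are copied over directly.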
381
+
382
+ @staticmethod
383
+ def _update_data(from_data, to_data, beta):
384
+ for name in from_data.keys():
385
+ if name not in to_data:
386
+ continue
387
+ fr, to = from_data[name], to_data[name]
388
+ with torch.no_grad():
389
+ if beta == 0:
390
+ to.data.copy_(fr.data.to(to.data))
391
+ elif beta < 1:
392
+ to.data.copy_(lerp(fr.data.to(to.data), to.data, beta))
393
+
394
+ @staticmethod
395
+ def _get_named_parameters(module):
396
+ return {name: value for name, value in module.named_parameters()}
397
+
398
+ @staticmethod
399
+ def _get_named_buffers(module):
400
+ return {name: value for name, value in module.named_buffers()}
401
+
402
+ def __call__(self, *args, **kwargs):
403
+ return self.forward(*args, **kwargs)
404
+
405
+ def forward(self, *args, **kwargs):
406
+ self.module.eval()
407
+ args, args_in_device = move_to_device(args, self.device)
408
+ kwargs, kwargs_in_device = move_to_device(kwargs, self.device)
409
+ in_device = None
410
+ if args_in_device is not None:
411
+ in_device = args_in_device
412
+ if kwargs_in_device is not None:
413
+ in_device = kwargs_in_device
414
+ out = self.module(*args, **kwargs)
415
+ if in_device is not None:
416
+ out, _ = move_to_device(out, in_device)
417
+ return out
418
+
419
+
420
+ def move_to_device(value, device):
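+ # Recursively move tensors (possibly nested in lists, tuples or dicts) to
+ # `device`, returning both the moved value and the device the tensors came
+ # from so that outputs can later be moved back.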
421
+ if torch.is_tensor(value):
422
+ return value.to(device), value.device
423
+ orig_device = None
424
+ if isinstance(value, (tuple, list)):
425
+ values = []
426
+ for val in value:
427
+ _val, orig_device = move_to_device(val, device)
428
+ values.append(_val)
429
+ return type(value)(values), orig_device
430
+ if isinstance(value, dict):
431
+ if isinstance(value, collections.OrderedDict):
432
+ values = collections.OrderedDict()
433
+ else:
434
+ values = {}
435
+ for key, val in value.items():
436
+ _val, orig_device = move_to_device(val, device)
437
+ values[key] = _val
438
+ return values, orig_device
439
+ return value, orig_device
440
+
441
+
442
+ _WRAPPER_CLASSES = (MovingAverageModule, nn.DataParallel, nn.parallel.DistributedDataParallel)
443
+ def unwrap_module(module):
444
+ if isinstance(module, _WRAPPER_CLASSES):
445
+ return module.module
446
+ return module
447
+
448
+
449
+ def get_grad_norm_from_optimizer(optimizer, norm_type=2):
450
+ """
451
+ Get the gradient norm for some parameters contained in an optimizer.
452
+ Arguments:
453
+ optimizer (torch.optim.Optimizer)
454
+ norm_type (int): Type of norm. Default value is 2.
455
+ Returns:
456
+ norm (float)
457
+ """
458
+ total_norm = 0
459
+ if optimizer is not None:
460
+ for param_group in optimizer.param_groups:
461
+ for p in param_group['params']:
462
+ if p.grad is not None:
463
+ with torch.no_grad():
464
+ param_norm = p.grad.data.norm(norm_type)
465
+ total_norm += param_norm ** norm_type
466
+ total_norm = total_norm ** (1. / norm_type)
467
+ return total_norm.item()
468
+
469
+
470
+ #----------------------------------------------------------------------------
471
+ # printing and logging utils
472
+
473
+
474
+ class ValueTracker:
475
+
476
+ def __init__(self, beta=0.95):
477
+ self.beta = beta
478
+ self.values = {}
479
+
480
+ def add(self, name, value, beta=None):
481
+ if torch.is_tensor(value):
482
+ value = value.item()
483
+ if beta is None:
484
+ beta = self.beta
485
+ if name not in self.values:
486
+ self.values[name] = value
487
+ else:
488
+ self.values[name] = lerp(value, self.values[name], beta)
489
+
490
+ def __getitem__(self, key):
491
+ return self.values[key]
492
+
493
+ def __str__(self):
494
+ string = ''
495
+ for i, name in enumerate(sorted(self.values.keys())):
496
+ if i and i % 3 == 0:
497
+ string += '\n'
498
+ elif string:
499
+ string += ', '
500
+ format_string = '{}: {}'
501
+ if isinstance(self.values[name], float):
502
+ format_string = '{}: {:.4g}'
503
+ string += format_string.format(name, self.values[name])
504
+ return string
505
+
506
+
507
+ def is_notebook():
508
+ """
509
+ Check if code is running from jupyter notebook.
510
+ Returns:
511
+ notebook (bool): True if running from jupyter notebook,
512
+ else False.
513
+ """
514
+ try:
515
+ __IPYTHON__
516
+ return True
517
+ except NameError:
518
+ return False
519
+
520
+
521
+ def _progress_bar(count, total):
522
+ """
523
+ Get a simple one-line string representing a progress bar.
524
+ Arguments:
525
+ count (int): Current count. Starts at 0.
526
+ total (int): Total count.
527
+ Returns:
528
+ pbar_string (str): The string progress bar.
529
+ """
530
+ bar_len = 60
531
+ filled_len = int(round(bar_len * (count + 1) / float(total)))
532
+ bar = '=' * filled_len + '-' * (bar_len - filled_len)
533
+ return '[{}] {}/{}'.format(bar, count + 1, total)
534
+
535
+
536
+ class ProgressWriter:
537
+ """
538
+ Handles writing output and displaying a progress bar. Automatically
539
+ adjusts for notebooks. Supports outputting text
540
+ that is compatible with the progressbar (in notebooks the text is
541
+ refreshed instead of printed).
542
+ Arguments:
543
+ length (int, optional): Total length of the progressbar.
544
+ Default value is None.
545
+ progress_bar (bool, optional): Display a progressbar.
546
+ Default value is True.
547
+ clear (bool, optional): If running from a notebook, clear
548
+ the current cell's output. Default value is False.
549
+ """
550
+ def __init__(self, length=None, progress_bar=True, clear=False):
551
+ if is_notebook() and clear:
552
+ notebook_clear()
553
+
554
+ if length is not None:
555
+ length = int(length)
556
+ self.length = length
557
+ self.count = 0
558
+
559
+ self._simple_pbar = False
560
+ if progress_bar and 'tqdm' not in sys.modules:
561
+ self._simple_pbar = True
562
+
563
+ progress_bar = progress_bar and 'tqdm' in sys.modules
564
+
565
+ self._progress_bar = None
566
+ if progress_bar:
567
+ pbar = tqdm.tqdm
568
+ if is_notebook():
569
+ pbar = tqdm.tqdm_notebook
570
+ if length is not None:
571
+ self._progress_bar = pbar(total=length, file=sys.stdout)
572
+ else:
573
+ self._progress_bar = pbar(file=sys.stdout)
574
+
575
+ if is_notebook():
576
+ self._writer = notebook_display(
577
+ _StrRepr(''),
578
+ display_id=time.asctime()
579
+ )
580
+ else:
581
+ if progress_bar:
582
+ self._writer = self._progress_bar
583
+ else:
584
+ self._writer = sys.stdout
585
+
586
+ def write(self, *lines, step=True):
587
+ """
588
+ Output values to stdout (or a display object if called from notebook).
589
+ Arguments:
590
+ *lines: The lines to write (positional arguments).
591
+ step (bool): Update the progressbar if present.
592
+ Default value is True.
593
+ """
594
+ string = '\n'.join(str(line) for line in lines if line and line.strip())
595
+ if self._simple_pbar:
596
+ string = _progress_bar(self.count, self.length) + '\n' + string
597
+ if is_notebook():
598
+ self._writer.update(_StrRepr(string))
599
+ else:
600
+ self._writer.write('\n\n' + string)
601
+ if hasattr(self._writer, 'flush'):
602
+ self._writer.flush()
603
+ if step:
604
+ self.step()
605
+
606
+ def step(self):
607
+ """
608
+ Update the progressbar if present.
609
+ """
610
+ self.count += 1
611
+ if self._progress_bar is not None:
612
+ self._progress_bar.update()
613
+
614
+ def __iter__(self):
615
+ return self
616
+
617
+ def __next__(self):
618
+ return next(self.rnge)
619
+
620
+ def close(self):
621
+ if hasattr(self._writer, 'close'):
622
+ can_close = True
623
+ try:
624
+ can_close = self._writer != sys.stdout and self._writer != sys.stderr
625
+ except AttributeError:
626
+ pass
627
+ if can_close:
628
+ self._writer.close()
629
+ if hasattr(self._progress_bar, 'close'):
630
+ self._progress_bar.close()
631
+
632
+ def __del__(self):
633
+ self.close()
634
+
635
+
636
+ class _StrRepr:
637
+ """
638
+ A wrapper for strings that returns the string
639
+ on repr() calls. Used by notebooks.
640
+ """
641
+ def __init__(self, string):
642
+ self.string = string
643
+
644
+ def __repr__(self):
645
+ return self.string
646
+
647
+
648
+ #----------------------------------------------------------------------------
649
+ # image utils
650
+
651
+
652
+ def tensor_to_PIL(image_tensor, pixel_min=-1, pixel_max=1):
653
+ image_tensor = image_tensor.cpu()
654
+ if pixel_min != 0 or pixel_max != 1:
655
+ image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
656
+ image_tensor.clamp_(min=0, max=1)
657
+ to_pil = torchvision.transforms.functional.to_pil_image
658
+ if image_tensor.dim() == 4:
659
+ return [to_pil(img) for img in image_tensor]
660
+ return to_pil(image_tensor)
661
+
662
+
663
+ def PIL_to_tensor(image, pixel_min=-1, pixel_max=1):
664
+ to_tensor = torchvision.transforms.functional.to_tensor
665
+ if isinstance(image, (list, tuple)):
666
+ image_tensor = torch.stack([to_tensor(img) for img in image])
667
+ else:
668
+ image_tensor = to_tensor(image)
669
+ if pixel_min != 0 or pixel_max != 1:
670
+ image_tensor = image_tensor * (pixel_max - pixel_min) + pixel_min
671
+ return image_tensor
672
+
673
+
674
+ def stack_images_PIL(imgs, shape=None, individual_img_size=None):
675
+ """
676
+ Concatenate multiple images into a grid within a single image.
677
+ Arguments:
678
+ imgs (Sequence of PIL.Image): Input images.
679
+ shape (list, tuple, int, optional): Shape of the grid. Should consist
680
+ of two values, (width, height). If an integer value is passed it
681
+ is used for both width and height. If no value is passed the shape
682
+ is inferred from the number of images. Default value is None.
683
+ individual_img_size (list, tuple, int, optional): The size of the
684
+ images being concatenated. Default value is None.
685
+ Returns:
686
+ canvas (PIL.Image): Image containing input images in a grid.
687
+ """
688
+ assert len(imgs) > 0, 'No images received.'
689
+ if shape is None:
690
+ size = int(np.ceil(np.sqrt(len(imgs))))
691
+ shape = [int(np.ceil(len(imgs) / size)), size]
692
+ else:
693
+ if isinstance(shape, numbers.Number):
694
+ shape = 2 * [shape]
695
+ assert len(shape) == 2, 'Shape should specify (width, height).'
696
+
697
+ if individual_img_size is None:
698
+ for i in range(len(imgs) - 1):
699
+ assert imgs[i].size == imgs[i + 1].size, \
700
+ 'Images are of different sizes, please specify a ' + \
701
+ 'size (width, height). Found sizes:\n' + \
702
+ ', '.join(str(img.size) for img in imgs)
703
+ individual_img_size = imgs[0].size
704
+ else:
705
+ if not isinstance(individual_img_size, (tuple, list)):
706
+ individual_img_size = 2 * (individual_img_size,)
707
+ individual_img_size = tuple(individual_img_size)
708
+ for i in range(len(imgs)):
709
+ if imgs[i].size != individual_img_size:
710
+ imgs[i] = imgs[i].resize(individual_img_size)
711
+
712
+ width, height = individual_img_size
713
+ width, height = int(width), int(height)
714
+ canvas = Image.new(
715
+ 'RGB',
716
+ (shape[0] * width, shape[1] * height),
717
+ (0, 0, 0, 0)
718
+ )
719
+ imgs = imgs.copy()
720
+ for h_i in range(shape[1]):
721
+ for w_i in range(shape[0]):
722
+ if len(imgs) > 0:
723
+ img = imgs.pop(0).convert('RGB')
724
+ offset = (w_i * width, h_i * height)
725
+ canvas.paste(img, offset)
726
+ return canvas
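+ # Usage sketch (illustration only; `fake_batch` is an assumed generator
+ # output in [-1, 1]):
+ # pil_imgs = tensor_to_PIL(fake_batch, pixel_min=-1, pixel_max=1)
+ # grid = stack_images_PIL(pil_imgs, shape=(4, 2))
+ # grid.save('samples.png')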