Fork adriansahlman's stylegan2_pytorch
- .gitignore +8 -0
- LICENSE.txt +7 -0
- requirements.txt +9 -0
- run_convert_from_tf.py +373 -0
- run_generator.py +361 -0
- run_metrics.py +338 -0
- run_projector.py +401 -0
- run_training.py +1002 -0
- stylegan2/__init__.py +5 -0
- stylegan2/external_models/__init__.py +2 -0
- stylegan2/external_models/inception.py +276 -0
- stylegan2/external_models/lpips.py +78 -0
- stylegan2/loss_fns.py +347 -0
- stylegan2/metrics/__init__.py +2 -0
- stylegan2/metrics/fid.py +210 -0
- stylegan2/metrics/ppl.py +229 -0
- stylegan2/models.py +1230 -0
- stylegan2/modules.py +1601 -0
- stylegan2/project.py +304 -0
- stylegan2/train.py +1013 -0
- stylegan2/utils.py +726 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
__pycache__/
*.py[cod]
*$py.class

# Jupyter Notebook
.ipynb_checkpoints

results/
LICENSE.txt
ADDED
@@ -0,0 +1,7 @@
Copyright 2020 Adrian Sahlman

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
requirements.txt
ADDED
@@ -0,0 +1,9 @@
numpy
pillow
pyyaml
requests
scipy
tensorboard
torch
torchvision
tqdm
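Note (not part of the commit): these are ordinary pip packages, so running pip install -r requirements.txt in a Python 3 environment with a suitable torch build is the usual way to set up the scripts below.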
run_convert_from_tf.py
ADDED
@@ -0,0 +1,373 @@
import os
import re
import pickle
import argparse
import io
import requests
import torch
import stylegan2
from stylegan2 import utils


pretrained_model_urls = {
    'car-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl',
    'car-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl',
    'cat-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl',
    'church-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl',
    'ffhq-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
    'ffhq-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
    'horse-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl',
    'car-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl',
    'car-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl',
    'car-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl',
    'car-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl',
    'car-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl',
    'car-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl',
    'car-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl',
    'car-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl',
    'car-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl',
    'ffhq-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl',
    'ffhq-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl',
    'ffhq-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl',
    'ffhq-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl',
    'ffhq-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl',
    'ffhq-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl',
    'ffhq-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl',
    'ffhq-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl',
    'ffhq-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
}


class Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return utils.AttributeDict
        return super(Unpickler, self).find_class(module, name)


def load_tf_models_file(fpath):
    with open(fpath, 'rb') as fp:
        return Unpickler(fp).load()


def load_tf_models_url(url):
    print('Downloading file {}...'.format(url))
    with requests.Session() as session:
        with session.get(url) as ret:
            fp = io.BytesIO(ret.content)
            return Unpickler(fp).load()


def convert_kwargs(static_kwargs, kwargs_mapping):
    kwargs = utils.AttributeDict()
    for key, value in static_kwargs.items():
        if key in kwargs_mapping:
            if value == 'lrelu':
                value = 'leaky:0.2'
            for k in utils.to_list(kwargs_mapping[key]):
                kwargs[k] = value
    return kwargs


_PERMITTED_MODELS = ['G_main', 'G_mapping', 'G_synthesis_stylegan2', 'D_stylegan2', 'D_main', 'G_synthesis']
def convert_from_tf(tf_state):
    tf_state = utils.AttributeDict.convert_dict_recursive(tf_state)
    model_type = tf_state.build_func_name
    assert model_type in _PERMITTED_MODELS, \
        'Found model type {}. '.format(model_type) + \
        'Allowed model types are: {}'.format(_PERMITTED_MODELS)

    if model_type == 'G_main':
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'dlatent_avg_beta': 'dlatent_avg_beta'
            }
        )
        kwargs.G_mapping = convert_from_tf(tf_state.components.mapping)
        kwargs.G_synthesis = convert_from_tf(tf_state.components.synthesis)
        G = stylegan2.models.Generator(**kwargs)
        for name, var in tf_state.variables:
            if name == 'dlatent_avg':
                G.dlatent_avg.data.copy_(torch.from_numpy(var))
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'truncation_psi': 'truncation_psi',
                'truncation_cutoff': 'truncation_cutoff',
                'truncation_psi_val': 'truncation_psi',
                'truncation_cutoff_val': 'truncation_cutoff'
            }
        )
        G.set_truncation(**kwargs)
        return G

    if model_type == 'G_mapping':
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'mapping_nonlinearity': 'activation',
                'normalize_latents': 'normalize_input',
                'mapping_lr_mul': 'lr_mul'
            }
        )
        kwargs.num_layers = sum(
            1 for var_name, _ in tf_state.variables
            if re.match('Dense[0-9]+/weight', var_name)
        )
        for var_name, var in tf_state.variables:
            if var_name == 'LabelConcat/weight':
                kwargs.label_size = var.shape[0]
            if var_name == 'Dense0/weight':
                kwargs.latent_size = var.shape[0]
                kwargs.hidden = var.shape[1]
            if var_name == 'Dense{}/bias'.format(kwargs.num_layers - 1):
                kwargs.out_size = var.shape[0]
        G_mapping = stylegan2.models.GeneratorMapping(**kwargs)
        for var_name, var in tf_state.variables:
            if re.match('Dense[0-9]+/[a-zA-Z]*', var_name):
                layer_idx = int(re.search('Dense(\d+)/[a-zA-Z]*', var_name).groups()[0])
                if var_name.endswith('weight'):
                    G_mapping.main[layer_idx].layer.weight.data.copy_(
                        torch.from_numpy(var.T).contiguous())
                elif var_name.endswith('bias'):
                    G_mapping.main[layer_idx].bias.data.copy_(torch.from_numpy(var))
            if var_name == 'LabelConcat/weight':
                G_mapping.embedding.weight.data.copy_(torch.from_numpy(var))
        return G_mapping

    if model_type == 'G_synthesis_stylegan2' or model_type == 'G_synthesis':
        assert tf_state.static_kwargs.get('fused_modconv', True), \
            'Can not load TF networks that use `fused_modconv=False`'
        noise_tensors = []
        conv_vars = {}
        for var_name, var in tf_state.variables:
            if var_name.startswith('noise'):
                noise_tensors.append(torch.from_numpy(var))
            else:
                layer_size = int(re.search('(\d+)x[0-9]+/*', var_name).groups()[0])
                if layer_size not in conv_vars:
                    conv_vars[layer_size] = {}
                var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
                conv_vars[layer_size][var_name] = var
        noise_tensors = sorted(noise_tensors, key=lambda x: x.size(-1))
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'nonlinearity': 'activation',
                'resample_filter': ['conv_filter', 'skip_filter']
            }
        )
        kwargs.skip = False
        kwargs.resnet = True
        kwargs.channels = []
        for size in sorted(conv_vars.keys(), reverse=True):
            if size == 4:
                if 'ToRGB/weight' in conv_vars[size]:
                    kwargs.skip = True
                    kwargs.resnet = False
                kwargs.latent_size = conv_vars[size]['Conv/mod_weight'].shape[0]
                kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
            else:
                kwargs.channels.append(conv_vars[size]['Conv1/bias'].shape[0])
            if 'ToRGB/bias' in conv_vars[size]:
                kwargs.data_channels = conv_vars[size]['ToRGB/bias'].shape[0]
        G_synthesis = stylegan2.models.GeneratorSynthesis(**kwargs)
        G_synthesis.const.data.copy_(torch.from_numpy(conv_vars[4]['Const/const']).squeeze(0))
        def assign_weights(layer, weight, bias, mod_weight, mod_bias, noise_strength, transposed=False):
            layer.bias.data.copy_(torch.from_numpy(bias))
            layer.layer.weight.data.copy_(torch.tensor(noise_strength))
            layer.layer.layer.dense.layer.weight.data.copy_(
                torch.from_numpy(mod_weight.T).contiguous())
            layer.layer.layer.dense.bias.data.copy_(torch.from_numpy(mod_bias + 1))
            weight = torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous()
            if transposed:
                weight = weight.flip(dims=[2,3])
            layer.layer.layer.weight.data.copy_(weight)
        conv_blocks = G_synthesis.conv_blocks
        for i, size in enumerate(sorted(conv_vars.keys())):
            block = conv_blocks[i]
            if size == 4:
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv/weight'],
                    bias=conv_vars[size]['Conv/bias'],
                    mod_weight=conv_vars[size]['Conv/mod_weight'],
                    mod_bias=conv_vars[size]['Conv/mod_bias'],
                    noise_strength=conv_vars[size]['Conv/noise_strength'],
                )
            else:
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv0_up/weight'],
                    bias=conv_vars[size]['Conv0_up/bias'],
                    mod_weight=conv_vars[size]['Conv0_up/mod_weight'],
                    mod_bias=conv_vars[size]['Conv0_up/mod_bias'],
                    noise_strength=conv_vars[size]['Conv0_up/noise_strength'],
                    transposed=True
                )
                assign_weights(
                    layer=block.conv_block[1],
                    weight=conv_vars[size]['Conv1/weight'],
                    bias=conv_vars[size]['Conv1/bias'],
                    mod_weight=conv_vars[size]['Conv1/mod_weight'],
                    mod_bias=conv_vars[size]['Conv1/mod_bias'],
                    noise_strength=conv_vars[size]['Conv1/noise_strength'],
                )
                if 'Skip/weight' in conv_vars[size]:
                    block.projection.weight.data.copy_(torch.from_numpy(
                        conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
            to_RGB = G_synthesis.to_data_layers[i]
            if to_RGB is not None:
                to_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['ToRGB/bias']))
                to_RGB.layer.weight.data.copy_(torch.from_numpy(
                    conv_vars[size]['ToRGB/weight']).permute((3, 2, 0, 1)).contiguous())
                to_RGB.layer.dense.bias.data.copy_(
                    torch.from_numpy(conv_vars[size]['ToRGB/mod_bias'] + 1))
                to_RGB.layer.dense.layer.weight.data.copy_(
                    torch.from_numpy(conv_vars[size]['ToRGB/mod_weight'].T).contiguous())
        if not tf_state.static_kwargs.get('randomize_noise', True):
            G_synthesis.static_noise(noise_tensors=noise_tensors)
        return G_synthesis

    if model_type == 'D_stylegan2' or model_type == 'D_main':
        output_vars = {}
        conv_vars = {}
        for var_name, var in tf_state.variables:
            if var_name.startswith('Output'):
                output_vars[var_name.replace('Output/', '')] = var
            else:
                layer_size = int(re.search('(\d+)x[0-9]+/*', var_name).groups()[0])
                if layer_size not in conv_vars:
                    conv_vars[layer_size] = {}
                var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
                conv_vars[layer_size][var_name] = var
        kwargs = convert_kwargs(
            static_kwargs=tf_state.static_kwargs,
            kwargs_mapping={
                'nonlinearity': 'activation',
                'resample_filter': ['conv_filter', 'skip_filter'],
                'mbstd_group_size': 'mbstd_group_size'
            }
        )
        kwargs.skip = False
        kwargs.resnet = True
        kwargs.channels = []
        for size in sorted(conv_vars.keys(), reverse=True):
            if size == 4:
                if 'FromRGB/weight' in conv_vars[size]:
                    kwargs.skip = True
                    kwargs.resnet = False
                kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
                kwargs.dense_hidden = conv_vars[size]['Dense0/bias'].shape[0]
            else:
                kwargs.channels.append(conv_vars[size]['Conv0/bias'].shape[0])
            if 'FromRGB/weight' in conv_vars[size]:
                kwargs.data_channels = conv_vars[size]['FromRGB/weight'].shape[-2]
        output_size = output_vars['bias'].shape[0]
        if output_size > 1:
            kwargs.label_size = output_size
        D = stylegan2.models.Discriminator(**kwargs)
        def assign_weights(layer, weight, bias):
            layer.bias.data.copy_(torch.from_numpy(bias))
            layer.layer.weight.data.copy_(
                torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous())
        conv_blocks = D.conv_blocks
        for i, size in enumerate(sorted(conv_vars.keys())):
            block = conv_blocks[-i - 1]
            if size == 4:
                assign_weights(
                    layer=block[-1].conv_block[0],
                    weight=conv_vars[size]['Conv/weight'],
                    bias=conv_vars[size]['Conv/bias'],
                )
            else:
                assign_weights(
                    layer=block.conv_block[0],
                    weight=conv_vars[size]['Conv0/weight'],
                    bias=conv_vars[size]['Conv0/bias'],
                )
                assign_weights(
                    layer=block.conv_block[1],
                    weight=conv_vars[size]['Conv1_down/weight'],
                    bias=conv_vars[size]['Conv1_down/bias'],
                )
                if 'Skip/weight' in conv_vars[size]:
                    block.projection.weight.data.copy_(torch.from_numpy(
                        conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
            from_RGB = D.from_data_layers[-i - 1]
            if from_RGB is not None:
                from_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['FromRGB/bias']))
                from_RGB.layer.weight.data.copy_(torch.from_numpy(
                    conv_vars[size]['FromRGB/weight']).permute((3, 2, 0, 1)).contiguous())
        return D


def get_arg_parser():
    parser = argparse.ArgumentParser(
        description='Convert tensorflow stylegan2 model to pytorch.',
        epilog='Pretrained models that can be downloaded:\n{}'.format(
            '\n'.join(pretrained_model_urls.keys()))
    )

    parser.add_argument(
        '-i',
        '--input',
        help='File path to pickled tensorflow models.',
        type=str,
        default=None,
    )

    parser.add_argument(
        '-d',
        '--download',
        help='Download the specified pretrained model. Use --help for info on available models.',
        type=str,
        default=None,
    )

    parser.add_argument(
        '-o',
        '--output',
        help='One or more output file paths. Alternatively a directory path ' + \
            'where all models will be saved. Default: current directory',
        type=str,
        nargs='*',
        default=['.'],
    )

    return parser


def main():
    args = get_arg_parser().parse_args()
    assert bool(args.input) != bool(args.download), \
        'Incorrect input format. Can only take either one ' + \
        'input filepath to a pickled tensorflow model or ' + \
        'a model name to download, but not both at the same ' + \
        'time or none at all.'
    if args.input:
        unpickled = load_tf_models_file(args.input)
    else:
        assert args.download in pretrained_model_urls.keys(), \
            'Unknown model {}. Use --help for list of models.'.format(args.download)
        unpickled = load_tf_models_url(pretrained_model_urls[args.download])
    if not isinstance(unpickled, (tuple, list)):
        unpickled = [unpickled]
    print('Converting tensorflow models and saving them...')
    converted = [convert_from_tf(tf_state) for tf_state in unpickled]
    if len(args.output) == 1 and (os.path.isdir(args.output[0]) or not os.path.splitext(args.output[0])[-1]):
        if not os.path.exists(args.output[0]):
            os.makedirs(args.output[0])
        for tf_state, torch_model in zip(unpickled, converted):
            torch_model.save(os.path.join(args.output[0], tf_state['name'] + '.pth'))
    else:
        assert len(args.output) == len(converted), 'Found {} models '.format(len(converted)) + \
            'in pickled file but only {} output paths were given.'.format(len(args.output))
        for out_path, torch_model in zip(args.output, converted):
            torch_model.save(out_path)
    print('Done!')


if __name__ == '__main__':
    main()
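Usage note (not part of the commit): the run scripts below embed examples for this converter in their help epilogs. A typical invocation downloads and converts the pretrained FFHQ config-f networks; G.pth, D.pth and Gs.pth are just the example output names used there:

  python run_convert_from_tf.py --download ffhq-config-f --output G.pth D.pth Gs.pth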
run_generator.py
ADDED
@@ -0,0 +1,361 @@
import warnings
import argparse
import os
from PIL import Image
import numpy as np
import torch

import stylegan2
from stylegan2 import utils

#----------------------------------------------------------------------------

_description = """StyleGAN2 generator.
Run 'python %(prog)s <subcommand> --help' for subcommand help."""

#----------------------------------------------------------------------------

_examples = """examples:
  # Train a network or convert a pretrained one.
  # Example of converting pretrained ffhq model:
  python run_convert_from_tf --download ffhq-config-f --output G.pth D.pth Gs.pth

  # Generate ffhq uncurated images (matches paper Figure 12)
  python %(prog)s generate_images --network=Gs.pth --seeds=6600-6625 --truncation_psi=0.5

  # Generate ffhq curated images (matches paper Figure 11)
  python %(prog)s generate_images --network=Gs.pth --seeds=66,230,389,1518 --truncation_psi=1.0

  # Example of converting pretrained car model:
  python run_convert_from_tf --download car-config-f --output G_car.pth D_car.pth Gs_car.pth

  # Generate uncurated car images (matches paper Figure 12)
  python %(prog)s generate_images --network=Gs_car.pth --seeds=6000-6025 --truncation_psi=0.5

  # Generate style mixing example (matches style mixing video clip)
  python %(prog)s style_mixing_example --network=Gs.pth --row_seeds=85,100,75,458,1500 --col_seeds=55,821,1789,293 --truncation_psi=1.0
"""

#----------------------------------------------------------------------------

def _add_shared_arguments(parser):

    parser.add_argument(
        '--network',
        help='Network file path',
        required=True,
        metavar='FILE'
    )

    parser.add_argument(
        '--output',
        help='Root directory for run results. Default: %(default)s',
        type=str,
        default='./results',
        metavar='DIR'
    )

    parser.add_argument(
        '--pixel_min',
        help='Minumum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=-1,
        type=float,
        metavar='VALUE'
    )

    parser.add_argument(
        '--pixel_max',
        help='Maximum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=1,
        type=float,
        metavar='VALUE'
    )

    parser.add_argument(
        '--gpu',
        help='CUDA device indices (given as separate ' + \
            'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
        type=int,
        default=[],
        nargs='*',
        metavar='INDEX'
    )

    parser.add_argument(
        '--truncation_psi',
        help='Truncation psi. Default: %(default)s',
        type=float,
        default=0.5,
        metavar='VALUE'
    )


def get_arg_parser():
    parser = argparse.ArgumentParser(
        description=_description,
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    range_desc = 'NOTE: This is a single argument, where list ' + \
        'elements are separated by "," and ranges are defined as "a-b". Only integers are allowed.'

    subparsers = parser.add_subparsers(help='Sub-commands', dest='command')

    generate_images_parser = subparsers.add_parser(
        'generate_images', help='Generate images')

    generate_images_parser.add_argument(
        '--batch_size',
        help='Batch size for generator. Default: %(default)s',
        type=int,
        default=1,
        metavar='VALUE'
    )

    generate_images_parser.add_argument(
        '--seeds',
        help='List of random seeds for generating images. ' + range_desc,
        type=utils.range_type,
        required=True,
        metavar='RANGE'
    )

    _add_shared_arguments(generate_images_parser)

    style_mixing_example_parser = subparsers.add_parser(
        'style_mixing_example', help='Generate style mixing video')

    style_mixing_example_parser.add_argument(
        '--row_seeds',
        help='List of random seeds for image rows. ' + range_desc,
        type=utils.range_type,
        required=True,
        metavar='RANGE'
    )

    style_mixing_example_parser.add_argument(
        '--col_seeds',
        help='List of random seeds for image columns. ' + range_desc,
        type=utils.range_type,
        required=True,
        metavar='RANGE'
    )

    style_mixing_example_parser.add_argument(
        '--style_layers',
        help='Indices of layers to mix style for. ' + \
            'Default: %(default)s. ' + range_desc,
        type=utils.range_type,
        default='0-6',
        metavar='RANGE'
    )

    style_mixing_example_parser.add_argument(
        '--grid',
        help='Save a grid as well of the style mix. Default: %(default)s',
        type=utils.bool_type,
        default=True,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    _add_shared_arguments(style_mixing_example_parser)

    return parser

#----------------------------------------------------------------------------

def style_mixing_example(G, args):
    assert max(args.style_layers) < len(G), \
        'Style layer indices can not be larger than ' + \
        'number of style layers ({}) of the generator.'.format(len(G))
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    if len(args.gpu) > 1:
        warnings.warn('Multi GPU is not available for style mixing example. Using device {}'.format(device))
    G.to(device)
    G.static_noise()
    latent_size, label_size = G.latent_size, G.label_size
    G_mapping, G_synthesis = G.G_mapping, G.G_synthesis

    all_seeds = list(set(args.row_seeds + args.col_seeds))
    all_z = torch.stack([torch.from_numpy(np.random.RandomState(seed).randn(latent_size)) for seed in all_seeds])
    all_z = all_z.to(device=device, dtype=torch.float32)
    if label_size:
        labels = torch.zeros(len(all_z), dtype=torch.int64, device=device)
    else:
        labels = None

    print('Generating disentangled latents...')
    with torch.no_grad():
        all_w = G_mapping(latents=all_z, labels=labels)

    all_w = all_w.unsqueeze(1).repeat(1, len(G_synthesis), 1)

    w_avg = G.dlatent_avg

    if args.truncation_psi != 1:
        all_w = w_avg + args.truncation_psi * (all_w - w_avg)

    w_dict = {seed: w for seed, w in zip(all_seeds, all_w)}

    all_images = []

    progress = utils.ProgressWriter(len(all_w))
    progress.write('Generating images...', step=False)

    with torch.no_grad():
        for w in all_w:
            all_images.append(G_synthesis(w.unsqueeze(0)))
            progress.step()

    progress.write('Done!', step=False)
    progress.close()

    all_images = torch.cat(all_images, dim=0)

    image_dict = {(seed, seed): image for seed, image in zip(all_seeds, all_images)}

    progress = utils.ProgressWriter(len(args.row_seeds) * len(args.col_seeds))
    progress.write('Generating style-mixed images...', step=False)

    for row_seed in args.row_seeds:
        for col_seed in args.col_seeds:
            w = w_dict[row_seed].clone()
            w[args.style_layers] = w_dict[col_seed][args.style_layers]
            with torch.no_grad():
                image_dict[(row_seed, col_seed)] = G_synthesis(w.unsqueeze(0)).squeeze(0)
            progress.step()

    progress.write('Done!', step=False)
    progress.close()

    progress = utils.ProgressWriter(len(image_dict))
    progress.write('Saving images...', step=False)

    for (row_seed, col_seed), image in list(image_dict.items()):
        image = utils.tensor_to_PIL(
            image, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        image_dict[(row_seed, col_seed)] = image
        image.save(os.path.join(args.output, '%d-%d.png' % (row_seed, col_seed)))
        progress.step()

    progress.write('Done!', step=False)
    progress.close()

    if args.grid:
        print('\n\nSaving style-mixed grid...')
        H, W = all_images.size()[2:]
        canvas = Image.new(
            'RGB', (W * (len(args.col_seeds) + 1), H * (len(args.row_seeds) + 1)), 'black')
        for row_idx, row_seed in enumerate([None] + args.row_seeds):
            for col_idx, col_seed in enumerate([None] + args.col_seeds):
                if row_seed is None and col_seed is None:
                    continue
                key = (row_seed, col_seed)
                if row_seed is None:
                    key = (col_seed, col_seed)
                if col_seed is None:
                    key = (row_seed, row_seed)
                canvas.paste(image_dict[key], (W * col_idx, H * row_idx))
        canvas.save(os.path.join(args.output, 'grid.png'))
        print('Done!')

#----------------------------------------------------------------------------

def generate_images(G, args):
    latent_size, label_size = G.latent_size, G.label_size
    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)
    G.to(device)
    if args.truncation_psi != 1:
        G.set_truncation(truncation_psi=args.truncation_psi)
    if len(args.gpu) > 1:
        warnings.warn(
            'Noise can not be randomized based on the seed ' + \
            'when using more than 1 GPU device. Noise will ' + \
            'now be randomized from default random state.'
        )
        G.random_noise()
        G = torch.nn.DataParallel(G, device_ids=args.gpu)
    else:
        noise_reference = G.static_noise()

    def get_batch(seeds):
        latents = []
        labels = []
        if len(args.gpu) <= 1:
            noise_tensors = [[] for _ in noise_reference]
        for seed in seeds:
            rnd = np.random.RandomState(seed)
            latents.append(torch.from_numpy(rnd.randn(latent_size)))
            if len(args.gpu) <= 1:
                for i, ref in enumerate(noise_reference):
                    noise_tensors[i].append(torch.from_numpy(rnd.randn(*ref.size()[1:])))
            if label_size:
                labels.append(torch.tensor([rnd.randint(0, label_size)]))
        latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
        if labels:
            labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
        else:
            labels = None
        if len(args.gpu) <= 1:
            noise_tensors = [
                torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
                for noise in noise_tensors
            ]
        else:
            noise_tensors = None
        return latents, labels, noise_tensors

    progress = utils.ProgressWriter(len(args.seeds))
    progress.write('Generating images...', step=False)

    for i in range(0, len(args.seeds), args.batch_size):
        latents, labels, noise_tensors = get_batch(args.seeds[i: i + args.batch_size])
        if noise_tensors is not None:
            G.static_noise(noise_tensors=noise_tensors)
        with torch.no_grad():
            generated = G(latents, labels=labels)
        images = utils.tensor_to_PIL(
            generated, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        for seed, img in zip(args.seeds[i: i + args.batch_size], images):
            img.save(os.path.join(args.output, 'seed%04d.png' % seed))
            progress.step()

    progress.write('Done!', step=False)
    progress.close()

#----------------------------------------------------------------------------

def main():
    args = get_arg_parser().parse_args()
    assert args.command, 'Missing subcommand.'
    assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
        '--output argument should specify a directory, not a file.'
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    G = stylegan2.models.load(args.network)
    G.eval()

    assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
        'stylegan2.models.Generator. Found {}.'.format(type(G))

    if args.command == 'generate_images':
        generate_images(G, args)
    elif args.command == 'style_mixing_example':
        style_mixing_example(G, args)
    else:
        raise TypeError('Unkown command {}'.format(args.command))

#----------------------------------------------------------------------------

if __name__ == '__main__':
    main()
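For programmatic use (not part of the commit), the calls made by main() and generate_images() above can be strung together directly. A minimal sketch, assuming a converted generator saved as Gs.pth by run_convert_from_tf.py:

  # Minimal sketch: load a converted generator and render one image,
  # mirroring the single-GPU path of run_generator.py.
  import numpy as np
  import torch
  import stylegan2
  from stylegan2 import utils

  G = stylegan2.models.load('Gs.pth')              # assumed converter output path
  assert isinstance(G, stylegan2.models.Generator)
  G.eval()
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  G.to(device)
  G.set_truncation(truncation_psi=0.5)             # same default as generate_images
  G.static_noise()                                 # fixed noise buffers, as in the single-GPU path

  rnd = np.random.RandomState(6600)                # one of the seeds from the built-in examples
  latents = torch.from_numpy(rnd.randn(G.latent_size)).unsqueeze(0)
  latents = latents.to(device=device, dtype=torch.float32)
  with torch.no_grad():
      generated = G(latents, labels=None)
  img = utils.tensor_to_PIL(generated, pixel_min=-1, pixel_max=1)[0]
  img.save('seed6600.png')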
run_metrics.py
ADDED
@@ -0,0 +1,338 @@
import os
import json
import argparse
import numpy as np
import torch

import stylegan2
from stylegan2 import utils

#----------------------------------------------------------------------------

_description = """Metrics evaluation.
Run 'python %(prog)s <subcommand> --help' for subcommand help."""

#----------------------------------------------------------------------------

_examples = """examples:
  # Train a network or convert a pretrained one. In this example we first convert a pretrained one.
  python run_convert_from_tf --download ffhq-config-f --output G.pth D.pth Gs.pth

  # Project generated images
  python %(prog)s project_generated_images --network=Gs.pth --seeds=0,1,5

  # Project real images
  python %(prog)s project_real_images --network=Gs.pth --data-dir=path/to/image_folder
"""

#----------------------------------------------------------------------------

def _add_shared_arguments(parser):

    parser.add_argument(
        '--network',
        help='Network file path',
        required=True,
        metavar='FILE'
    )

    parser.add_argument(
        '--num_samples',
        type=int,
        help='Number of samples to gather for evaluating ' + \
            'this metric. Default: %(default)s',
        default=50000,
        metavar='VALUE'
    )

    parser.add_argument(
        '--size',
        type=int,
        help='Rescale images so that this is the size of their ' + \
            'smallest side in pixels. Default: Unscaled',
        default=None,
        metavar='VALUE'
    )

    parser.add_argument(
        '--batch_size',
        help='Batch size for generator. Default: %(default)s',
        type=int,
        default=1,
        metavar='VALUE'
    )

    parser.add_argument(
        '--output',
        help='Root directory for run results. Default: %(default)s',
        type=str,
        default='./results',
        metavar='DIR'
    )

    parser.add_argument(
        '--pixel_min',
        help='Minumum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=-1,
        type=float,
        metavar='VALUE'
    )

    parser.add_argument(
        '--pixel_max',
        help='Maximum of the value range of pixels in generated images. ' + \
            'Default: %(default)s',
        default=1,
        type=float,
        metavar='VALUE'
    )

    parser.add_argument(
        '--gpu',
        help='CUDA device indices (given as separate ' + \
            'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
        type=int,
        default=[],
        nargs='*',
        metavar='INDEX'
    )


def get_arg_parser():
    parser = argparse.ArgumentParser(
        description=_description,
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    subparsers = parser.add_subparsers(help='Sub-commands', dest='command')

    fid_parser = subparsers.add_parser('fid', help='Calculate FID')

    fid_parser.add_argument(
        '--data_dir',
        help='Dataset root directory',
        required=True,
        metavar='DIR'
    )

    fid_parser.add_argument(
        '--reals_batch_size',
        help='Batch size for gathering statistics of reals. Default: %(default)s',
        type=int,
        default=1,
        metavar='VALUE'
    )

    fid_parser.add_argument(
        '--reals_data_workers',
        help='Data workers for fetching real data samples. Default: %(default)s',
        type=int,
        default=4,
        metavar='VALUE'
    )

    fid_parser.add_argument(
        '--truncation_psi',
        help='Truncation psi. Default: %(default)s',
        type=float,
        default=1.0,
        metavar='VALUE'
    )

    _add_shared_arguments(fid_parser)

    ppl_parser = subparsers.add_parser('ppl', help='Calculate PPL')

    ppl_parser.add_argument(
        '--epsilon',
        type=float,
        help='Perturbation value. Default: %(default)s',
        default=1e-4,
        metavar='VALUE'
    )

    ppl_parser.add_argument(
        '--use_dlatent',
        type=utils.bool_type,
        help='Measure on perturbations of disentangled latents ' + \
            'instead of raw latents. Default: %(default)s',
        default=True,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    ppl_parser.add_argument(
        '--full_sampling',
        type=utils.bool_type,
        help='Measure on random interpolation between two inputs ' + \
            'instead of directly on one input. Default: %(default)s',
        default=False,
        const=True,
        nargs='?',
        metavar='BOOL'
    )

    parser.add_argument(
        '--ppl_ffhq_crop',
        help='Crop images evaluated for PPL with crop values ' + \
            'for FFHQ. Default: False',
        type=utils.bool_type,
        const=True,
        nargs='?',
        default=False,
        metavar='BOOL'
    )

    _add_shared_arguments(ppl_parser)

    return parser

#----------------------------------------------------------------------------

def _report_metric(value, name, args):
    fpath = os.path.join(args.output, 'metrics.json')
    metrics = {}
    if os.path.exists(fpath):
        with open(fpath, 'r') as fp:
            try:
                metrics = json.load(fp)
            except Exception:
                pass
    metrics[name] = value
    with open(fpath, 'w') as fp:
        json.dump(metrics, fp)
    print('\n\nMetric evaluated!:')
    print('{}: {}'.format(name, value))

#----------------------------------------------------------------------------

def eval_fid(G, prior_generator, args):
    assert args.data_dir, '--data_dir has to be specified.'
    dataset = utils.ImageFolder(
        args.data_dir,
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max
    )
    assert len(dataset), 'No images found at {}'.format(args.data_dir)

    inception = stylegan2.external_models.inception.InceptionV3FeatureExtractor(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    if len(args.gpu) > 1:
        inception = torch.nn.DataParallel(inception, device_ids=args.gpu)

    args.reals_batch_size = max(args.reals_batch_size, len(args.gpu))

    fid = stylegan2.metrics.fid.FID(
        G=G,
        prior_generator=prior_generator,
        dataset=dataset,
        num_samples=args.num_samples,
        fid_model=inception,
        fid_size=args.size,
        truncation_psi=args.truncation_psi,
        reals_batch_size=args.reals_batch_size,
        reals_data_workers=args.reals_data_workers
    )

    value = fid.evaluate()

    name = 'FID'
    if args.size:
        name += '({})'.format(args.size)
    if args.truncation_psi != 1:
        name += 'trunc{}'.format(args.truncation_psi)
    name += ':{}k'.format(args.num_samples // 1000)

    _report_metric(value, name, args)

#----------------------------------------------------------------------------

def eval_ppl(G, prior_generator, args):

    lpips = stylegan2.external_models.lpips.LPIPS_VGG16(
        pixel_min=args.pixel_min, pixel_max=args.pixel_max)

    if len(args.gpu) > 1:
        lpips = torch.nn.DataParallel(lpips, device_ids=args.gpu)

    crop = None
    if args.ppl_ffhq_crop:
        crop = stylegan2.metrics.ppl.PPL.FFHQ_CROP

    ppl = stylegan2.metrics.ppl.PPL(
        G=G,
        prior_generator=prior_generator,
        num_samples=args.num_samples,
        epsilon=args.epsilon,
        use_dlatent=args.use_dlatent,
        full_sampling=args.full_sampling,
        crop=crop,
        lpips_model=lpips,
        lpips_size=args.size,
    )

    value = ppl.evaluate()

    name = 'PPL'
    if args.size:
        name += '({})'.format(args.size)
    if args.use_dlatent:
        name += 'W'
    else:
        name += 'Z'
    if args.full_sampling:
        name += '-full'
    else:
        name += '-end'
    name += ':{}k'.format(args.num_samples // 1000)

    _report_metric(value, name, args)

#----------------------------------------------------------------------------

def main():
    args = get_arg_parser().parse_args()
    assert args.command, 'Missing subcommand.'
    assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
        '--output argument should specify a directory, not a file.'
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    G = stylegan2.models.load(args.network)
    assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
        'stylegan2.models.Generator. Found {}.'.format(type(G))

    latent_size, label_size = G.latent_size, G.label_size

    device = torch.device(args.gpu[0] if args.gpu else 'cpu')
    if device.index is not None:
        torch.cuda.set_device(device.index)

    G.to(device).eval().requires_grad_(False)

    if len(args.gpu) > 1:
        G = torch.nn.DataParallel(G, device_ids=args.gpu)

    args.batch_size = max(args.batch_size, len(args.gpu))

    prior_generator = utils.PriorGenerator(
        latent_size=latent_size,
        label_size=label_size,
        batch_size=args.batch_size,
        device=device
    )

    if args.command == 'fid':
        eval_fid(G, prior_generator, args)
    elif args.command == 'ppl':
        eval_ppl(G, prior_generator, args)
    else:
        raise TypeError('Unkown command {}'.format(args.command))


if __name__ == '__main__':
    main()
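Usage note (not part of the commit): based on the argument definitions above, the two metrics are run as subcommands. Gs.pth and the dataset path below are placeholders for a converted generator and an image folder:

  python run_metrics.py fid --network=Gs.pth --data_dir=path/to/image_folder --batch_size 8 --gpu 0
  python run_metrics.py ppl --network=Gs.pth --num_samples 50000 --use_dlatent True --gpu 0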
run_projector.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import stylegan2
|
7 |
+
from stylegan2 import utils
|
8 |
+
|
9 |
+
#----------------------------------------------------------------------------
|
10 |
+
|
11 |
+
_description = """StyleGAN2 projector.
|
12 |
+
Run 'python %(prog)s <subcommand> --help' for subcommand help."""
|
13 |
+
|
14 |
+
#----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
_examples = """examples:
|
17 |
+
# Train a network or convert a pretrained one.
|
18 |
+
# Example of converting pretrained ffhq model:
|
19 |
+
python run_convert_from_tf --download ffhq-config-f --output G.pth D.pth Gs.pth
|
20 |
+
|
21 |
+
# Project generated images
|
22 |
+
python %(prog)s project_generated_images --network=Gs.pth --seeds=0,1,5
|
23 |
+
|
24 |
+
# Project real images
|
25 |
+
python %(prog)s project_real_images --network=Gs.pth --data-dir=path/to/image_folder
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
#----------------------------------------------------------------------------
|
30 |
+
|
31 |
+
def _add_shared_arguments(parser):
|
32 |
+
|
33 |
+
parser.add_argument(
|
34 |
+
'--network',
|
35 |
+
help='Network file path',
|
36 |
+
required=True,
|
37 |
+
metavar='FILE'
|
38 |
+
)
|
39 |
+
|
40 |
+
parser.add_argument(
|
41 |
+
'--num_steps',
|
42 |
+
type=int,
|
43 |
+
help='Number of steps to use for projection. ' + \
|
44 |
+
'Default: %(default)s',
|
45 |
+
default=1000,
|
46 |
+
metavar='VALUE'
|
47 |
+
)
|
48 |
+
|
49 |
+
parser.add_argument(
|
50 |
+
'--batch_size',
|
51 |
+
help='Batch size. Default: %(default)s',
|
52 |
+
type=int,
|
53 |
+
default=1,
|
54 |
+
metavar='VALUE'
|
55 |
+
)
|
56 |
+
|
57 |
+
parser.add_argument(
|
58 |
+
'--label',
|
59 |
+
help='Label to use for dlatent statistics gathering ' + \
|
60 |
+
'(should be integer index of class). Default: no label.',
|
61 |
+
type=int,
|
62 |
+
default=None,
|
63 |
+
metavar='CLASS_INDEX'
|
64 |
+
)
|
65 |
+
|
66 |
+
parser.add_argument(
|
67 |
+
'--initial_learning_rate',
|
68 |
+
help='Initial learning rate of projection. Default: %(default)s',
|
69 |
+
default=0.1,
|
70 |
+
type=float,
|
71 |
+
metavar='VALUE'
|
72 |
+
)
|
73 |
+
|
74 |
+
parser.add_argument(
|
75 |
+
'--initial_noise_factor',
|
76 |
+
help='Initial noise factor of projection. Default: %(default)s',
|
77 |
+
default=0.05,
|
78 |
+
type=float,
|
79 |
+
metavar='VALUE'
|
80 |
+
)
|
81 |
+
|
82 |
+
parser.add_argument(
|
83 |
+
'--lr_rampdown_length',
|
84 |
+
help='Learning rate rampdown length for projection. ' + \
|
85 |
+
'Should be in range [0, 1]. Default: %(default)s',
|
86 |
+
default=0.25,
|
87 |
+
type=float,
|
88 |
+
metavar='VALUE'
|
89 |
+
)
|
90 |
+
|
91 |
+
parser.add_argument(
|
92 |
+
'--lr_rampup_length',
|
93 |
+
help='Learning rate rampup length for projection. ' + \
|
94 |
+
'Should be in range [0, 1]. Default: %(default)s',
|
95 |
+
default=0.05,
|
96 |
+
type=float,
|
97 |
+
metavar='VALUE'
|
98 |
+
)
|
99 |
+
|
100 |
+
parser.add_argument(
|
101 |
+
'--noise_ramp_length',
|
102 |
+
help='Learning rate rampdown length for projection. ' + \
|
103 |
+
'Should be in range [0, 1]. Default: %(default)s',
|
104 |
+
default=0.75,
|
105 |
+
type=float,
|
106 |
+
metavar='VALUE'
|
107 |
+
)
|
108 |
+
|
109 |
+
parser.add_argument(
|
110 |
+
'--regularize_noise_weight',
|
111 |
+
help='The weight for noise regularization. Default: %(default)s',
|
112 |
+
default=1e5,
|
113 |
+
type=float,
|
114 |
+
metavar='VALUE'
|
115 |
+
)
|
116 |
+
|
117 |
+
parser.add_argument(
|
118 |
+
'--output',
|
119 |
+
help='Root directory for run results. Default: %(default)s',
|
120 |
+
type=str,
|
121 |
+
default='./results',
|
122 |
+
metavar='DIR'
|
123 |
+
)
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
'--num_snapshots',
|
127 |
+
help='Number of snapshots. Default: %(default)s',
|
128 |
+
type=int,
|
129 |
+
default=5,
|
130 |
+
metavar='VALUE'
|
131 |
+
)
|
132 |
+
|
133 |
+
parser.add_argument(
|
134 |
+
'--pixel_min',
|
135 |
+
help='Minumum of the value range of pixels in generated images. ' + \
|
136 |
+
'Default: %(default)s',
|
137 |
+
default=-1,
|
138 |
+
type=float,
|
139 |
+
metavar='VALUE'
|
140 |
+
)
|
141 |
+
|
142 |
+
parser.add_argument(
|
143 |
+
'--pixel_max',
|
144 |
+
help='Maximum of the value range of pixels in generated images. ' + \
|
145 |
+
'Default: %(default)s',
|
146 |
+
default=1,
|
147 |
+
type=float,
|
148 |
+
metavar='VALUE'
|
149 |
+
)
|
150 |
+
|
151 |
+
parser.add_argument(
|
152 |
+
'--gpu',
|
153 |
+
help='CUDA device indices (given as separate ' + \
|
154 |
+
'values if multiple, i.e. "--gpu 0 1"). Default: Use CPU',
|
155 |
+
type=int,
|
156 |
+
default=[],
|
157 |
+
nargs='*',
|
158 |
+
metavar='INDEX'
|
159 |
+
)
|
160 |
+
|
161 |
+
#----------------------------------------------------------------------------
|
162 |
+
|
163 |
+
def get_arg_parser():
|
164 |
+
parser = argparse.ArgumentParser(
|
165 |
+
description=_description,
|
166 |
+
epilog=_examples,
|
167 |
+
formatter_class=argparse.RawDescriptionHelpFormatter
|
168 |
+
)
|
169 |
+
|
170 |
+
range_desc = 'NOTE: This is a single argument, where list ' + \
|
171 |
+
'elements are separated by "," and ranges are defined as "a-b". ' + \
|
172 |
+
'Only integers are allowed.'
|
173 |
+
|
174 |
+
subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
|
175 |
+
|
176 |
+
project_generated_images_parser = subparsers.add_parser(
|
177 |
+
'project_generated_images', help='Project generated images')
|
178 |
+
|
179 |
+
project_generated_images_parser.add_argument(
|
180 |
+
'--seeds',
|
181 |
+
help='List of random seeds for generating images. ' + \
|
182 |
+
'Default: 66,230,389,1518. ' + range_desc,
|
183 |
+
type=utils.range_type,
|
184 |
+
default=[66, 230, 389, 1518],
|
185 |
+
metavar='RANGE'
|
186 |
+
)
|
187 |
+
|
188 |
+
project_generated_images_parser.add_argument(
|
189 |
+
'--truncation_psi',
|
190 |
+
help='Truncation psi. Default: %(default)s',
|
191 |
+
type=float,
|
192 |
+
default=1.0,
|
193 |
+
metavar='VALUE'
|
194 |
+
)
|
195 |
+
|
196 |
+
_add_shared_arguments(project_generated_images_parser)
|
197 |
+
|
198 |
+
project_real_images_parser = subparsers.add_parser(
|
199 |
+
'project_real_images', help='Project real images')
|
200 |
+
|
201 |
+
project_real_images_parser.add_argument(
|
202 |
+
'--data_dir',
|
203 |
+
help='Dataset root directory',
|
204 |
+
type=str,
|
205 |
+
required=True,
|
206 |
+
metavar='DIR'
|
207 |
+
)
|
208 |
+
|
209 |
+
project_real_images_parser.add_argument(
|
210 |
+
'--seed',
|
211 |
+
help='When there are more images available than ' + \
|
212 |
+
'the number that is going to be projected, this ' + \
|
213 |
+
'seed is used for picking samples. Default: %(default)s',
|
214 |
+
type=int,
|
215 |
+
default=1234,
|
216 |
+
metavar='VALUE'
|
217 |
+
)
|
218 |
+
|
219 |
+
project_real_images_parser.add_argument(
|
220 |
+
'--num_images',
|
221 |
+
type=int,
|
222 |
+
help='Number of images to project. Default: %(default)s',
|
223 |
+
default=3,
|
224 |
+
metavar='VALUE'
|
225 |
+
)
|
226 |
+
|
227 |
+
_add_shared_arguments(project_real_images_parser)
|
228 |
+
|
229 |
+
return parser
|
230 |
+
|
231 |
+
#----------------------------------------------------------------------------
|
232 |
+
|
233 |
+
def project_images(G, images, name_prefix, args):
|
234 |
+
|
235 |
+
device = torch.device(args.gpu[0] if args.gpu else 'cpu')
|
236 |
+
if device.index is not None:
|
237 |
+
torch.cuda.set_device(device.index)
|
238 |
+
if len(args.gpu) > 1:
|
239 |
+
warnings.warn(
|
240 |
+
'Multi-GPU is not available for projection. ' + \
|
241 |
+
'Using device {}'.format(device)
|
242 |
+
)
|
243 |
+
G = utils.unwrap_module(G).to(device)
|
244 |
+
|
245 |
+
lpips_model = stylegan2.external_models.lpips.LPIPS_VGG16(
|
246 |
+
pixel_min=args.pixel_min, pixel_max=args.pixel_max)
|
247 |
+
|
248 |
+
proj = stylegan2.project.Projector(
|
249 |
+
G=G,
|
250 |
+
dlatent_avg_samples=10000,
|
251 |
+
dlatent_avg_label=args.label,
|
252 |
+
dlatent_device=device,
|
253 |
+
dlatent_batch_size=1024,
|
254 |
+
lpips_model=lpips_model,
|
255 |
+
lpips_size=256
|
256 |
+
)
|
257 |
+
|
258 |
+
for i in range(0, len(images), args.batch_size):
|
259 |
+
target = images[i: i + args.batch_size]
|
260 |
+
proj.start(
|
261 |
+
target=target,
|
262 |
+
num_steps=args.num_steps,
|
263 |
+
initial_learning_rate=args.initial_learning_rate,
|
264 |
+
initial_noise_factor=args.initial_noise_factor,
|
265 |
+
lr_rampdown_length=args.lr_rampdown_length,
|
266 |
+
lr_rampup_length=args.lr_rampup_length,
|
267 |
+
noise_ramp_length=args.noise_ramp_length,
|
268 |
+
regularize_noise_weight=args.regularize_noise_weight,
|
269 |
+
verbose=True,
|
270 |
+
verbose_prefix='Projecting image(s) {}/{}'.format(
|
271 |
+
i + len(target), len(images))
|
272 |
+
)
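# The snapshot steps spread args.num_snapshots evenly across the projection run,
# counting back from the final step (e.g. num_steps=1000, num_snapshots=5
# gives the set {200, 400, 600, 800, 1000}).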
|
273 |
+
snapshot_steps = set(
|
274 |
+
args.num_steps - np.linspace(
|
275 |
+
0, args.num_steps, args.num_snapshots, endpoint=False, dtype=int))
|
276 |
+
for k, image in enumerate(
|
277 |
+
utils.tensor_to_PIL(target, pixel_min=args.pixel_min, pixel_max=args.pixel_max)):
|
278 |
+
image.save(os.path.join(args.output, name_prefix[i + k] + 'target.png'))
|
279 |
+
for j in range(args.num_steps):
|
280 |
+
proj.step()
|
281 |
+
if j in snapshot_steps:
|
282 |
+
generated = utils.tensor_to_PIL(
|
283 |
+
proj.generate(), pixel_min=args.pixel_min, pixel_max=args.pixel_max)
|
284 |
+
for k, image in enumerate(generated):
|
285 |
+
image.save(os.path.join(
|
286 |
+
args.output, name_prefix[i + k] + 'step%04d.png' % (j + 1)))
|
287 |
+
|
288 |
+
#----------------------------------------------------------------------------
|
289 |
+
|
290 |
+
def project_generated_images(G, args):
|
291 |
+
latent_size, label_size = G.latent_size, G.label_size
|
292 |
+
device = torch.device(args.gpu[0] if args.gpu else 'cpu')
|
293 |
+
if device.index is not None:
|
294 |
+
torch.cuda.set_device(device.index)
|
295 |
+
G.to(device)
|
296 |
+
if len(args.gpu) > 1:
|
297 |
+
warnings.warn(
|
298 |
+
'Noise cannot be randomized based on the seed ' + \
|
299 |
+
'when using more than 1 GPU device. Noise will ' + \
|
300 |
+
'now be randomized from default random state.'
|
301 |
+
)
|
302 |
+
G.random_noise()
|
303 |
+
G = torch.nn.DataParallel(G, device_ids=args.gpu)
|
304 |
+
else:
|
305 |
+
noise_reference = G.static_noise()
|
306 |
+
|
307 |
+
def get_batch(seeds):
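# Build one deterministic batch from the given seeds: one latent vector per seed,
# an optional class label when the model is conditional (label_size > 0), and
# per-layer noise tensors when running on a single device.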
|
308 |
+
latents = []
|
309 |
+
labels = []
|
310 |
+
if len(args.gpu) <= 1:
|
311 |
+
noise_tensors = [[] for _ in noise_reference]
|
312 |
+
for seed in seeds:
|
313 |
+
rnd = np.random.RandomState(seed)
|
314 |
+
latents.append(torch.from_numpy(rnd.randn(latent_size)))
|
315 |
+
if len(args.gpu) <= 1:
|
316 |
+
for i, ref in enumerate(noise_reference):
|
317 |
+
noise_tensors[i].append(
|
318 |
+
torch.from_numpy(rnd.randn(*ref.size()[1:])))
|
319 |
+
if label_size:
|
320 |
+
labels.append(torch.tensor([rnd.randint(0, label_size)]))
|
321 |
+
latents = torch.stack(latents, dim=0).to(device=device, dtype=torch.float32)
|
322 |
+
if labels:
|
323 |
+
labels = torch.cat(labels, dim=0).to(device=device, dtype=torch.int64)
|
324 |
+
else:
|
325 |
+
labels = None
|
326 |
+
if len(args.gpu) <= 1:
|
327 |
+
noise_tensors = [
|
328 |
+
torch.stack(noise, dim=0).to(device=device, dtype=torch.float32)
|
329 |
+
for noise in noise_tensors
|
330 |
+
]
|
331 |
+
else:
|
332 |
+
noise_tensors = None
|
333 |
+
return latents, labels, noise_tensors
|
334 |
+
|
335 |
+
images = []
|
336 |
+
|
337 |
+
progress = utils.ProgressWriter(len(args.seeds))
|
338 |
+
progress.write('Generating images...', step=False)
|
339 |
+
|
340 |
+
for i in range(0, len(args.seeds), args.batch_size):
|
341 |
+
latents, labels, noise_tensors = get_batch(args.seeds[i: i + args.batch_size])
|
342 |
+
if noise_tensors is not None:
|
343 |
+
G.static_noise(noise_tensors=noise_tensors)
|
344 |
+
with torch.no_grad():
|
345 |
+
images.append(G(latents, labels=labels))
|
346 |
+
progress.step()
|
347 |
+
|
348 |
+
images = torch.cat(images, dim=0)
|
349 |
+
|
350 |
+
progress.write('Done!', step=False)
|
351 |
+
progress.close()
|
352 |
+
|
353 |
+
name_prefix = ['seed%04d-' % seed for seed in args.seeds]
|
354 |
+
project_images(G, images, name_prefix, args)
|
355 |
+
|
356 |
+
#----------------------------------------------------------------------------
|
357 |
+
|
358 |
+
def project_real_images(G, args):
|
359 |
+
device = torch.device(args.gpu[0] if args.gpu else 'cpu')
|
360 |
+
print('Loading images from "%s"...' % args.data_dir)
|
361 |
+
dataset = utils.ImageFolder(
|
362 |
+
args.data_dir, pixel_min=args.pixel_min, pixel_max=args.pixel_max)
|
363 |
+
|
364 |
+
rnd = np.random.RandomState(args.seed)
|
365 |
+
indices = rnd.choice(
|
366 |
+
len(dataset), size=min(args.num_images, len(dataset)), replace=False)
|
367 |
+
images = []
|
368 |
+
for i in indices:
|
369 |
+
data = dataset[i]
|
370 |
+
if isinstance(data, (tuple, list)):
|
371 |
+
data = data[0]
|
372 |
+
images.append(data)
|
373 |
+
images = torch.stack(images).to(device)
|
374 |
+
name_prefix = ['image%04d-' % i for i in indices]
|
375 |
+
print('Done!')
|
376 |
+
project_images(G, images, name_prefix, args)
|
377 |
+
|
378 |
+
#----------------------------------------------------------------------------
|
379 |
+
|
380 |
+
def main():
|
381 |
+
args = get_arg_parser().parse_args()
|
382 |
+
assert args.command, 'Missing subcommand.'
|
383 |
+
assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
|
384 |
+
'--output argument should specify a directory, not a file.'
|
385 |
+
if not os.path.exists(args.output):
|
386 |
+
os.makedirs(args.output)
|
387 |
+
|
388 |
+
G = stylegan2.models.load(args.network)
|
389 |
+
assert isinstance(G, stylegan2.models.Generator), 'Model type has to be ' + \
|
390 |
+
'stylegan2.models.Generator. Found {}.'.format(type(G))
|
391 |
+
|
392 |
+
if args.command == 'project_generated_images':
|
393 |
+
project_generated_images(G, args)
|
394 |
+
elif args.command == 'project_real_images':
|
395 |
+
project_real_images(G, args)
|
396 |
+
else:
|
397 |
+
raise TypeError('Unknown command {}'.format(args.command))
|
398 |
+
|
399 |
+
|
400 |
+
if __name__ == '__main__':
|
401 |
+
main()
|
run_training.py
ADDED
@@ -0,0 +1,1002 @@
1 |
+
import warnings
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch import multiprocessing as mp
|
5 |
+
import stylegan2
|
6 |
+
from stylegan2 import utils
|
7 |
+
from stylegan2.external_models import inception, lpips
|
8 |
+
from stylegan2.metrics import fid, ppl
|
9 |
+
|
10 |
+
#----------------------------------------------------------------------------
|
11 |
+
|
12 |
+
def get_arg_parser():
|
13 |
+
parser = utils.ConfigArgumentParser()
|
14 |
+
|
15 |
+
parser.add_argument(
|
16 |
+
'--output',
|
17 |
+
help='Output directory for model weights.',
|
18 |
+
type=str,
|
19 |
+
default=None,
|
20 |
+
metavar='DIR'
|
21 |
+
)
|
22 |
+
|
23 |
+
#----------------------------------------------------------------------------
|
24 |
+
# Model options
|
25 |
+
|
26 |
+
parser.add_argument(
|
27 |
+
'--channels',
|
28 |
+
help='Specify the channels for each layer (can be overridden for individual ' + \
|
29 |
+
'networks with "--g_channels" and "--d_channels". ' + \
|
30 |
+
'Default: %(default)s',
|
31 |
+
nargs='*',
|
32 |
+
type=int,
|
33 |
+
default=[32, 32, 64, 128, 256, 512, 512, 512, 512],
|
34 |
+
metavar='CHANNELS'
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
'--latent',
|
39 |
+
help='Size of the prior (noise vector). Default: %(default)s',
|
40 |
+
type=int,
|
41 |
+
default=512,
|
42 |
+
metavar='VALUE'
|
43 |
+
)
|
44 |
+
|
45 |
+
parser.add_argument(
|
46 |
+
'--label',
|
47 |
+
help='Number of unique labels. Unused if not specified.',
|
48 |
+
type=int,
|
49 |
+
default=0,
|
50 |
+
metavar='VALUE'
|
51 |
+
)
|
52 |
+
|
53 |
+
parser.add_argument(
|
54 |
+
'--base_shape',
|
55 |
+
help='Data shape of first layer in generator or ' + \
|
56 |
+
'last layer in discriminator. Default: %(default)s',
|
57 |
+
nargs=2,
|
58 |
+
type=int,
|
59 |
+
default=(4, 4),
|
60 |
+
metavar='SIZE'
|
61 |
+
)
|
62 |
+
|
63 |
+
parser.add_argument(
|
64 |
+
'--kernel_size',
|
65 |
+
help='Size of conv kernel. Default: %(default)s',
|
66 |
+
type=int,
|
67 |
+
default=3,
|
68 |
+
metavar='SIZE'
|
69 |
+
)
|
70 |
+
|
71 |
+
parser.add_argument(
|
72 |
+
'--pad_once',
|
73 |
+
help='Pad filtered convs only once before filter instead ' + \
|
74 |
+
'of twice. Default: %(default)s',
|
75 |
+
type=utils.bool_type,
|
76 |
+
const=True,
|
77 |
+
nargs='?',
|
78 |
+
default=True,
|
79 |
+
metavar='BOOL'
|
80 |
+
)
|
81 |
+
|
82 |
+
parser.add_argument(
|
83 |
+
'--pad_mode',
|
84 |
+
help='Padding mode for conv layers. Default: %(default)s',
|
85 |
+
type=str,
|
86 |
+
default='constant',
|
87 |
+
metavar='MODE'
|
88 |
+
)
|
89 |
+
|
90 |
+
parser.add_argument(
|
91 |
+
'--pad_constant',
|
92 |
+
help='Padding constant for conv layers when `pad_mode` is ' + \
|
93 |
+
'\'constant\'. Default: %(default)s',
|
94 |
+
type=float,
|
95 |
+
default=0,
|
96 |
+
metavar='VALUE'
|
97 |
+
)
|
98 |
+
|
99 |
+
parser.add_argument(
|
100 |
+
'--filter_pad_mode',
|
101 |
+
help='Padding mode for filter layers. Default: %(default)s',
|
102 |
+
type=str,
|
103 |
+
default='constant',
|
104 |
+
metavar='MODE'
|
105 |
+
)
|
106 |
+
|
107 |
+
parser.add_argument(
|
108 |
+
'--filter_pad_constant',
|
109 |
+
help='Padding constant for filter layers when `filter_pad_mode` ' + \
|
110 |
+
'is \'constant\'. Default: %(default)s',
|
111 |
+
type=float,
|
112 |
+
default=0,
|
113 |
+
metavar='VALUE'
|
114 |
+
)
|
115 |
+
|
116 |
+
parser.add_argument(
|
117 |
+
'--filter',
|
118 |
+
help='Filter to use whenever FIR is applied. Default: %(default)s',
|
119 |
+
nargs='*',
|
120 |
+
type=float,
|
121 |
+
default=[1, 3, 3, 1],
|
122 |
+
metavar='VALUE'
|
123 |
+
)
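# The default filter [1, 3, 3, 1] is the separable binomial kernel that
# StyleGAN2 uses for FIR up- and downsampling.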
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
'--weight_scale',
|
127 |
+
help='Use weight scaling for equalized learning rate. Default: %(default)s',
|
128 |
+
type=utils.bool_type,
|
129 |
+
const=True,
|
130 |
+
nargs='?',
|
131 |
+
default=True,
|
132 |
+
metavar='BOOL'
|
133 |
+
)
|
134 |
+
|
135 |
+
#----------------------------------------------------------------------------
|
136 |
+
# Generator options
|
137 |
+
|
138 |
+
parser.add_argument(
|
139 |
+
'--g_file',
|
140 |
+
help='Load a generator model from a file instead of constructing a new one. Disabled unless a file is specified.',
|
141 |
+
type=str,
|
142 |
+
default=None,
|
143 |
+
metavar='FILE'
|
144 |
+
)
|
145 |
+
|
146 |
+
parser.add_argument(
|
147 |
+
'--g_channels',
|
148 |
+
help='Instead of the values of "--channels", ' + \
|
149 |
+
'use these for the generator.',
|
150 |
+
nargs='*',
|
151 |
+
type=int,
|
152 |
+
default=[],
|
153 |
+
metavar='CHANNELS'
|
154 |
+
)
|
155 |
+
|
156 |
+
parser.add_argument(
|
157 |
+
'--g_skip',
|
158 |
+
help='Use skip connections for the generator. Default: %(default)s',
|
159 |
+
type=utils.bool_type,
|
160 |
+
const=True,
|
161 |
+
nargs='?',
|
162 |
+
default=True,
|
163 |
+
metavar='BOOL'
|
164 |
+
)
|
165 |
+
|
166 |
+
parser.add_argument(
|
167 |
+
'--g_resnet',
|
168 |
+
help='Use resnet connections for the generator. Default: %(default)s',
|
169 |
+
type=utils.bool_type,
|
170 |
+
const=True,
|
171 |
+
nargs='?',
|
172 |
+
default=False,
|
173 |
+
metavar='BOOL'
|
174 |
+
)
|
175 |
+
|
176 |
+
parser.add_argument(
|
177 |
+
'--g_conv_block_size',
|
178 |
+
help='Number of layers in a conv block in the generator. Default: %(default)s',
|
179 |
+
type=int,
|
180 |
+
default=2,
|
181 |
+
metavar='VALUE'
|
182 |
+
)
|
183 |
+
|
184 |
+
parser.add_argument(
|
185 |
+
'--g_normalize',
|
186 |
+
help='Normalize conv features for generator. Default: %(default)s',
|
187 |
+
type=utils.bool_type,
|
188 |
+
const=True,
|
189 |
+
nargs='?',
|
190 |
+
default=True,
|
191 |
+
metavar='BOOL'
|
192 |
+
)
|
193 |
+
|
194 |
+
parser.add_argument(
|
195 |
+
'--g_fused_conv',
|
196 |
+
help='Fuse conv & upsample into a transposed ' + \
|
197 |
+
'conv for the generator. Default: %(default)s',
|
198 |
+
type=utils.bool_type,
|
199 |
+
const=True,
|
200 |
+
nargs='?',
|
201 |
+
default=True,
|
202 |
+
metavar='BOOL'
|
203 |
+
)
|
204 |
+
|
205 |
+
parser.add_argument(
|
206 |
+
'--g_activation',
|
207 |
+
help='The non-linear activation function for ' + \
|
208 |
+
'the generator. Default: %(default)s',
|
209 |
+
default='leaky:0.2',
|
210 |
+
type=str,
|
211 |
+
metavar='ACTIVATION'
|
212 |
+
)
|
213 |
+
|
214 |
+
parser.add_argument(
|
215 |
+
'--g_conv_resample_mode',
|
216 |
+
help='Resample mode for upsampling conv ' + \
|
217 |
+
'layers for generator. Default: %(default)s',
|
218 |
+
type=str,
|
219 |
+
default='FIR',
|
220 |
+
metavar='MODE'
|
221 |
+
)
|
222 |
+
|
223 |
+
parser.add_argument(
|
224 |
+
'--g_skip_resample_mode',
|
225 |
+
help='Resample mode for skip connection ' + \
|
226 |
+
'upsamples for the generator. Default: %(default)s',
|
227 |
+
type=str,
|
228 |
+
default='FIR',
|
229 |
+
metavar='MODE'
|
230 |
+
)
|
231 |
+
|
232 |
+
parser.add_argument(
|
233 |
+
'--g_lr',
|
234 |
+
help='The learning rate for the generator. Default: %(default)s',
|
235 |
+
default=2e-3,
|
236 |
+
type=float,
|
237 |
+
metavar='VALUE'
|
238 |
+
)
|
239 |
+
|
240 |
+
parser.add_argument(
|
241 |
+
'--g_betas',
|
242 |
+
help='Beta values for the generator Adam optimizer. Default: %(default)s',
|
243 |
+
type=float,
|
244 |
+
nargs=2,
|
245 |
+
default=(0, 0.99),
|
246 |
+
metavar='VALUE'
|
247 |
+
)
|
248 |
+
|
249 |
+
parser.add_argument(
|
250 |
+
'--g_loss',
|
251 |
+
help='Loss function for the generator. Default: %(default)s',
|
252 |
+
default='logistic_ns',
|
253 |
+
type=str,
|
254 |
+
metavar='LOSS'
|
255 |
+
)
|
256 |
+
|
257 |
+
parser.add_argument(
|
258 |
+
'--g_reg',
|
259 |
+
help='Regularization function for the generator, with an optional weight appended as ":<weight>" (e.g. "pathreg:2"). Default: %(default)s',
|
260 |
+
default='pathreg:2',
|
261 |
+
type=str,
|
262 |
+
metavar='REG'
|
263 |
+
)
|
264 |
+
|
265 |
+
parser.add_argument(
|
266 |
+
'--g_reg_interval',
|
267 |
+
help='Interval at which to regularize the generator. Default: %(default)s',
|
268 |
+
default=4,
|
269 |
+
type=int,
|
270 |
+
metavar='INTERVAL'
|
271 |
+
)
|
272 |
+
|
273 |
+
parser.add_argument(
|
274 |
+
'--g_iter',
|
275 |
+
help='Number of generator iterations per training iteration. Default: %(default)s',
|
276 |
+
default=1,
|
277 |
+
type=int,
|
278 |
+
metavar='ITER'
|
279 |
+
)
|
280 |
+
|
281 |
+
parser.add_argument(
|
282 |
+
'--style_mix',
|
283 |
+
help='The probability of passing more than one ' + \
|
284 |
+
'latent to the generator. Default: %(default)s',
|
285 |
+
type=float,
|
286 |
+
default=0.9,
|
287 |
+
metavar='PROBABILITY'
|
288 |
+
)
|
289 |
+
|
290 |
+
parser.add_argument(
|
291 |
+
'--latent_mapping_layers',
|
292 |
+
help='The number of layers of the latent mapping network. Default: %(default)s',
|
293 |
+
type=int,
|
294 |
+
default=8,
|
295 |
+
metavar='LAYERS'
|
296 |
+
)
|
297 |
+
|
298 |
+
parser.add_argument(
|
299 |
+
'--latent_mapping_lr_mul',
|
300 |
+
help='The learning rate multiplier for the latent ' + \
|
301 |
+
'mapping network. Default: %(default)s',
|
302 |
+
type=float,
|
303 |
+
default=0.01,
|
304 |
+
metavar='LR_MUL'
|
305 |
+
)
|
306 |
+
|
307 |
+
parser.add_argument(
|
308 |
+
'--normalize_latent',
|
309 |
+
help='Normalize latent inputs. Default: %(default)s',
|
310 |
+
type=utils.bool_type,
|
311 |
+
const=True,
|
312 |
+
nargs='?',
|
313 |
+
default=True,
|
314 |
+
metavar='BOOL'
|
315 |
+
)
|
316 |
+
|
317 |
+
parser.add_argument(
|
318 |
+
'--modulate_rgb',
|
319 |
+
help='Modulate RGB layers (use style for output ' + \
|
320 |
+
'layers of generator). Default: %(default)s',
|
321 |
+
type=utils.bool_type,
|
322 |
+
const=True,
|
323 |
+
nargs='?',
|
324 |
+
default=True,
|
325 |
+
metavar='BOOL'
|
326 |
+
)
|
327 |
+
|
328 |
+
#----------------------------------------------------------------------------
|
329 |
+
# Discriminator options
|
330 |
+
|
331 |
+
parser.add_argument(
|
332 |
+
'--d_file',
|
333 |
+
help='Load a discriminator model from a file instead of constructing a new one. Disabled unless a file is specified.',
|
334 |
+
type=str,
|
335 |
+
default=None,
|
336 |
+
metavar='FILE'
|
337 |
+
)
|
338 |
+
|
339 |
+
parser.add_argument(
|
340 |
+
'--d_channels',
|
341 |
+
help='Instead of the values of "--channels", ' + \
|
342 |
+
'use these for the discriminator.',
|
343 |
+
nargs='*',
|
344 |
+
type=int,
|
345 |
+
default=[],
|
346 |
+
metavar='CHANNELS'
|
347 |
+
)
|
348 |
+
|
349 |
+
parser.add_argument(
|
350 |
+
'--d_skip',
|
351 |
+
help='Use skip connections for the discriminator. Default: %(default)s',
|
352 |
+
type=utils.bool_type,
|
353 |
+
const=True,
|
354 |
+
nargs='?',
|
355 |
+
default=False,
|
356 |
+
metavar='BOOL'
|
357 |
+
)
|
358 |
+
|
359 |
+
parser.add_argument(
|
360 |
+
'--d_resnet',
|
361 |
+
help='Use resnet connections for the discriminator. Default: %(default)s',
|
362 |
+
type=utils.bool_type,
|
363 |
+
const=True,
|
364 |
+
nargs='?',
|
365 |
+
default=True,
|
366 |
+
metavar='BOOL'
|
367 |
+
)
|
368 |
+
|
369 |
+
parser.add_argument(
|
370 |
+
'--d_conv_block_size',
|
371 |
+
help='Number of layers in a conv block in the discriminator. Default: %(default)s',
|
372 |
+
type=int,
|
373 |
+
default=2,
|
374 |
+
metavar='VALUE'
|
375 |
+
)
|
376 |
+
|
377 |
+
parser.add_argument(
|
378 |
+
'--d_fused_conv',
|
379 |
+
help='Fuse conv & downsample into a strided ' + \
|
380 |
+
'conv for the discriminator. Default: %(default)s',
|
381 |
+
type=utils.bool_type,
|
382 |
+
const=True,
|
383 |
+
nargs='?',
|
384 |
+
default=True,
|
385 |
+
metavar='BOOL'
|
386 |
+
)
|
387 |
+
|
388 |
+
parser.add_argument(
|
389 |
+
'--group_size',
|
390 |
+
help='Size of the groups in batch std layer. Default: %(default)s',
|
391 |
+
type=int,
|
392 |
+
default=4,
|
393 |
+
metavar='VALUE'
|
394 |
+
)
|
395 |
+
|
396 |
+
parser.add_argument(
|
397 |
+
'--d_activation',
|
398 |
+
help='The non-linear activation function for the discriminator. Default: %(default)s',
|
399 |
+
default='leaky:0.2',
|
400 |
+
type=str,
|
401 |
+
metavar='ACTIVATION'
|
402 |
+
)
|
403 |
+
|
404 |
+
parser.add_argument(
|
405 |
+
'--d_conv_resample_mode',
|
406 |
+
help='Resample mode for downsampling conv ' + \
|
407 |
+
'layers for discriminator. Default: %(default)s',
|
408 |
+
type=str,
|
409 |
+
default='FIR',
|
410 |
+
metavar='MODE'
|
411 |
+
)
|
412 |
+
|
413 |
+
parser.add_argument(
|
414 |
+
'--d_skip_resample_mode',
|
415 |
+
help='Resample mode for skip connection ' + \
|
416 |
+
'downsamples for the discriminator. Default: %(default)s',
|
417 |
+
type=str,
|
418 |
+
default='FIR',
|
419 |
+
metavar='MODE'
|
420 |
+
)
|
421 |
+
|
422 |
+
parser.add_argument(
|
423 |
+
'--d_loss',
|
424 |
+
help='Loss function for the discriminator. Default: %(default)s',
|
425 |
+
default='logistic',
|
426 |
+
type=str,
|
427 |
+
metavar='LOSS'
|
428 |
+
)
|
429 |
+
|
430 |
+
parser.add_argument(
|
431 |
+
'--d_reg',
|
432 |
+
help='Regularization function for the discriminator ' + \
|
433 |
+
'with an optional weight appended as ":<weight>" (e.g. "r1:10"). Default: %(default)s',
|
434 |
+
default='r1:10',
|
435 |
+
type=str,
|
436 |
+
metavar='REG'
|
437 |
+
)
|
438 |
+
|
439 |
+
parser.add_argument(
|
440 |
+
'--d_reg_interval',
|
441 |
+
help='Interval at which to regularize the discriminator. Default: %(default)s',
|
442 |
+
default=16,
|
443 |
+
type=int,
|
444 |
+
metavar='INTERVAL'
|
445 |
+
)
|
446 |
+
|
447 |
+
parser.add_argument(
|
448 |
+
'--d_iter',
|
449 |
+
help='Number of discriminator iterations per training iteration. Default: %(default)s',
|
450 |
+
default=1,
|
451 |
+
type=int,
|
452 |
+
metavar='ITER'
|
453 |
+
)
|
454 |
+
|
455 |
+
parser.add_argument(
|
456 |
+
'--d_lr',
|
457 |
+
help='The learning rate for the discriminator. Default: %(default)s',
|
458 |
+
default=2e-3,
|
459 |
+
type=float,
|
460 |
+
metavar='VALUE'
|
461 |
+
)
|
462 |
+
|
463 |
+
parser.add_argument(
|
464 |
+
'--d_betas',
|
465 |
+
help='Beta values for the discriminator Adam optimizer. Default: %(default)s',
|
466 |
+
type=float,
|
467 |
+
nargs=2,
|
468 |
+
default=(0, 0.99),
|
469 |
+
metavar='VALUE'
|
470 |
+
)
|
471 |
+
|
472 |
+
#----------------------------------------------------------------------------
|
473 |
+
# Training options
|
474 |
+
|
475 |
+
parser.add_argument(
|
476 |
+
'--iterations',
|
477 |
+
help='Number of iterations to train for. Default: %(default)s',
|
478 |
+
type=int,
|
479 |
+
default=1000000,
|
480 |
+
metavar='ITERATIONS'
|
481 |
+
)
|
482 |
+
|
483 |
+
parser.add_argument(
|
484 |
+
'--gpu',
|
485 |
+
help='The cuda device(s) to use. Example: ""--gpu 0 1" will train ' + \
|
486 |
+
'on GPU 0 and GPU 1. Default: Only use CPU',
|
487 |
+
type=int,
|
488 |
+
default=[],
|
489 |
+
nargs='*',
|
490 |
+
metavar='DEVICE_ID'
|
491 |
+
)
|
492 |
+
|
493 |
+
parser.add_argument(
|
494 |
+
'--distributed',
|
495 |
+
help='When more than one gpu device is passed, automatically ' + \
|
496 |
+
'start one process for each device and give it the correct ' + \
|
497 |
+
'distributed args (rank, world_size etc). Disable this if ' + \
|
498 |
+
'you want training to be performed with only one process ' + \
|
499 |
+
'using the DataParallel module. Default: %(default)s',
|
500 |
+
type=utils.bool_type,
|
501 |
+
const=True,
|
502 |
+
nargs='?',
|
503 |
+
default=True,
|
504 |
+
metavar='BOOL'
|
505 |
+
)
|
506 |
+
|
507 |
+
parser.add_argument(
|
508 |
+
'--rank',
|
509 |
+
help='Rank for distributed training.',
|
510 |
+
type=int,
|
511 |
+
default=None,
|
512 |
+
)
|
513 |
+
|
514 |
+
parser.add_argument(
|
515 |
+
'--world_size',
|
516 |
+
help='World size for distributed training.',
|
517 |
+
type=int,
|
518 |
+
default=None,
|
519 |
+
)
|
520 |
+
|
521 |
+
parser.add_argument(
|
522 |
+
'--master_addr',
|
523 |
+
help='Address for distributed training.',
|
524 |
+
type=str,
|
525 |
+
default=None,
|
526 |
+
)
|
527 |
+
|
528 |
+
parser.add_argument(
|
529 |
+
'--master_port',
|
530 |
+
help='Port for distributed training.',
|
531 |
+
type=str,
|
532 |
+
default=None,
|
533 |
+
)
|
534 |
+
|
535 |
+
parser.add_argument(
|
536 |
+
'--batch_size',
|
537 |
+
help='Size of each batch. Default: %(default)s',
|
538 |
+
default=32,
|
539 |
+
type=int,
|
540 |
+
metavar='VALUE'
|
541 |
+
)
|
542 |
+
|
543 |
+
parser.add_argument(
|
544 |
+
'--device_batch_size',
|
545 |
+
help='Maximum number of items to fit on single device at a time. Default: %(default)s',
|
546 |
+
default=4,
|
547 |
+
type=int,
|
548 |
+
metavar='VALUE'
|
549 |
+
)
|
550 |
+
|
551 |
+
parser.add_argument(
|
552 |
+
'--g_reg_batch_size',
|
553 |
+
help='Size of each batch used to regularize the generator. Default: %(default)s',
|
554 |
+
default=16,
|
555 |
+
type=int,
|
556 |
+
metavar='VALUE'
|
557 |
+
)
|
558 |
+
|
559 |
+
parser.add_argument(
|
560 |
+
'--g_reg_device_batch_size',
|
561 |
+
help='Maximum number of items to fit on single device when ' + \
|
562 |
+
'regularizing the generator. Default: %(default)s',
|
563 |
+
default=2,
|
564 |
+
type=int,
|
565 |
+
metavar='VALUE'
|
566 |
+
)
|
567 |
+
|
568 |
+
parser.add_argument(
|
569 |
+
'--half',
|
570 |
+
help='Use mixed precision training. Default: %(default)s',
|
571 |
+
type=utils.bool_type,
|
572 |
+
const=True,
|
573 |
+
nargs='?',
|
574 |
+
default=False,
|
575 |
+
metavar='BOOL'
|
576 |
+
)
|
577 |
+
|
578 |
+
parser.add_argument(
|
579 |
+
'--resume',
|
580 |
+
help='Resume from the latest saved checkpoint in the checkpoint_dir. ' + \
|
581 |
+
'This loads all previous training settings except for the dataset options, ' + \
|
582 |
+
'device args (--gpu ...) and distributed training args (--rank, --world_size etc.) ' + \
|
583 |
+
'as well as metrics and logging.',
|
584 |
+
type=utils.bool_type,
|
585 |
+
const=True,
|
586 |
+
nargs='?',
|
587 |
+
default=False,
|
588 |
+
metavar='BOOL'
|
589 |
+
)
|
590 |
+
|
591 |
+
#----------------------------------------------------------------------------
|
592 |
+
# Extra metric options
|
593 |
+
|
594 |
+
parser.add_argument(
|
595 |
+
'--fid_interval',
|
596 |
+
help='If specified, evaluate the FID metric with this interval.',
|
597 |
+
default=None,
|
598 |
+
type=int,
|
599 |
+
metavar='INTERVAL'
|
600 |
+
)
|
601 |
+
|
602 |
+
parser.add_argument(
|
603 |
+
'--ppl_interval',
|
604 |
+
help='If specified, evaluate the PPL metric with this interval.',
|
605 |
+
default=None,
|
606 |
+
type=int,
|
607 |
+
metavar='INTERVAL'
|
608 |
+
)
|
609 |
+
|
610 |
+
parser.add_argument(
|
611 |
+
'--ppl_ffhq_crop',
|
612 |
+
help='Crop images evaluated for PPL with crop values for FFHQ. Default: %(default)s',
|
613 |
+
type=utils.bool_type,
|
614 |
+
const=True,
|
615 |
+
nargs='?',
|
616 |
+
default=False,
|
617 |
+
metavar='BOOL'
|
618 |
+
)
|
619 |
+
|
620 |
+
#----------------------------------------------------------------------------
|
621 |
+
# Data options
|
622 |
+
|
623 |
+
parser.add_argument(
|
624 |
+
'--pixel_min',
|
625 |
+
help='Minimum of the value range of pixels in generated images. Default: %(default)s',
|
626 |
+
default=-1,
|
627 |
+
type=float,
|
628 |
+
metavar='VALUE'
|
629 |
+
)
|
630 |
+
|
631 |
+
parser.add_argument(
|
632 |
+
'--pixel_max',
|
633 |
+
help='Maximum of the value range of pixels in generated images. Default: %(default)s',
|
634 |
+
default=1,
|
635 |
+
type=float,
|
636 |
+
metavar='VALUE'
|
637 |
+
)
|
638 |
+
|
639 |
+
parser.add_argument(
|
640 |
+
'--data_channels',
|
641 |
+
help='Number of channels in the data. Default: 3 (RGB)',
|
642 |
+
default=3,
|
643 |
+
type=int,
|
644 |
+
choices=[1, 3],
|
645 |
+
metavar='CHANNELS'
|
646 |
+
)
|
647 |
+
|
648 |
+
parser.add_argument(
|
649 |
+
'--data_dir',
|
650 |
+
help='The root directory of the dataset. This argument is required!',
|
651 |
+
type=str,
|
652 |
+
default=None
|
653 |
+
)
|
654 |
+
|
655 |
+
parser.add_argument(
|
656 |
+
'--data_resize',
|
657 |
+
help='Resize data to fit input size of discriminator. Default: %(default)s',
|
658 |
+
type=utils.bool_type,
|
659 |
+
const=True,
|
660 |
+
nargs='?',
|
661 |
+
default=False,
|
662 |
+
metavar='BOOL'
|
663 |
+
)
|
664 |
+
|
665 |
+
parser.add_argument(
|
666 |
+
'--mirror_augment',
|
667 |
+
help='Use random horizontal flipping for data images. Default: %(default)s',
|
668 |
+
type=utils.bool_type,
|
669 |
+
const=True,
|
670 |
+
nargs='?',
|
671 |
+
default=False,
|
672 |
+
metavar='BOOL'
|
673 |
+
)
|
674 |
+
|
675 |
+
parser.add_argument(
|
676 |
+
'--data_workers',
|
677 |
+
help='Number of worker processes that handle data loading. Default: %(default)s',
|
678 |
+
default=4,
|
679 |
+
type=int,
|
680 |
+
metavar='WORKERS'
|
681 |
+
)
|
682 |
+
|
683 |
+
#----------------------------------------------------------------------------
|
684 |
+
# Logging options
|
685 |
+
|
686 |
+
parser.add_argument(
|
687 |
+
'--checkpoint_dir',
|
688 |
+
help='If specified, save checkpoints to this directory.',
|
689 |
+
default=None,
|
690 |
+
type=str,
|
691 |
+
metavar='DIR'
|
692 |
+
)
|
693 |
+
|
694 |
+
parser.add_argument(
|
695 |
+
'--checkpoint_interval',
|
696 |
+
help='Save checkpoints with this interval. Default: %(default)s',
|
697 |
+
default=10000,
|
698 |
+
type=int,
|
699 |
+
metavar='INTERVAL'
|
700 |
+
)
|
701 |
+
|
702 |
+
parser.add_argument(
|
703 |
+
'--tensorboard_log_dir',
|
704 |
+
help='Log to this tensorboard directory if specified.',
|
705 |
+
default=None,
|
706 |
+
type=str,
|
707 |
+
metavar='DIR'
|
708 |
+
)
|
709 |
+
|
710 |
+
parser.add_argument(
|
711 |
+
'--tensorboard_image_interval',
|
712 |
+
help='Log images to tensorboard with this interval if specified.',
|
713 |
+
default=None,
|
714 |
+
type=int,
|
715 |
+
metavar='INTERVAL'
|
716 |
+
)
|
717 |
+
|
718 |
+
parser.add_argument(
|
719 |
+
'--tensorboard_image_size',
|
720 |
+
help='Size of images logged to tensorboard. Default: %(default)s',
|
721 |
+
default=256,
|
722 |
+
type=int,
|
723 |
+
metavar='VALUE'
|
724 |
+
)
|
725 |
+
|
726 |
+
return parser
|
727 |
+
|
728 |
+
#----------------------------------------------------------------------------
|
729 |
+
|
730 |
+
def get_dataset(args):
|
731 |
+
assert args.data_dir, '--data_dir has to be specified.'
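# Each entry in --channels (or --d_channels) adds one 2x resolution stage on top
# of --base_shape, so e.g. a base shape of (4, 4) with 9 channel entries expects
# 1024x1024 training images.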
|
732 |
+
height, width = [
|
733 |
+
shape * 2 ** (len(args.d_channels or args.channels) - 1)
|
734 |
+
for shape in args.base_shape
|
735 |
+
]
|
736 |
+
dataset = utils.ImageFolder(
|
737 |
+
args.data_dir,
|
738 |
+
mirror=args.mirror_augment,
|
739 |
+
pixel_min=args.pixel_min,
|
740 |
+
pixel_max=args.pixel_max,
|
741 |
+
height=height,
|
742 |
+
width=width,
|
743 |
+
resize=args.data_resize,
|
744 |
+
grayscale=args.data_channels == 1
|
745 |
+
)
|
746 |
+
assert len(dataset), 'No images found at {}'.format(args.data_dir)
|
747 |
+
return dataset
|
748 |
+
|
749 |
+
#----------------------------------------------------------------------------
|
750 |
+
|
751 |
+
def get_models(args):
|
752 |
+
common_kwargs = dict(
|
753 |
+
data_channels=args.data_channels,
|
754 |
+
base_shape=args.base_shape,
|
755 |
+
conv_filter=args.filter,
|
756 |
+
skip_filter=args.filter,
|
757 |
+
kernel_size=args.kernel_size,
|
758 |
+
conv_pad_mode=args.pad_mode,
|
759 |
+
conv_pad_constant=args.pad_constant,
|
760 |
+
filter_pad_mode=args.filter_pad_mode,
|
761 |
+
filter_pad_constant=args.filter_pad_constant,
|
762 |
+
pad_once=args.pad_once,
|
763 |
+
weight_scale=args.weight_scale
|
764 |
+
)
|
765 |
+
|
766 |
+
if args.g_file:
|
767 |
+
G = stylegan2.models.load(args.g_file)
|
768 |
+
assert isinstance(G, stylegan2.models.Generator), \
|
769 |
+
'`--g_file` should specify a generator model, found {}'.format(type(G))
|
770 |
+
else:
|
771 |
+
|
772 |
+
G_M = stylegan2.models.GeneratorMapping(
|
773 |
+
latent_size=args.latent,
|
774 |
+
label_size=args.label,
|
775 |
+
num_layers=args.latent_mapping_layers,
|
776 |
+
hidden=args.latent,
|
777 |
+
activation=args.g_activation,
|
778 |
+
normalize_input=args.normalize_latent,
|
779 |
+
lr_mul=args.latent_mapping_lr_mul,
|
780 |
+
weight_scale=args.weight_scale
|
781 |
+
)
|
782 |
+
|
783 |
+
G_S = stylegan2.models.GeneratorSynthesis(
|
784 |
+
channels=args.g_channels or args.channels,
|
785 |
+
latent_size=args.latent,
|
786 |
+
demodulate=args.g_normalize,
|
787 |
+
modulate_data_out=args.modulate_rgb,
|
788 |
+
conv_block_size=args.g_conv_block_size,
|
789 |
+
activation=args.g_activation,
|
790 |
+
conv_resample_mode=args.g_conv_resample_mode,
|
791 |
+
skip_resample_mode=args.g_skip_resample_mode,
|
792 |
+
resnet=args.g_resnet,
|
793 |
+
skip=args.g_skip,
|
794 |
+
fused_resample=args.g_fused_conv,
|
795 |
+
**common_kwargs
|
796 |
+
)
|
797 |
+
|
798 |
+
G = stylegan2.models.Generator(G_mapping=G_M, G_synthesis=G_S)
|
799 |
+
|
800 |
+
if args.d_file:
|
801 |
+
D = stylegan2.models.load(args.d_file)
|
802 |
+
assert isinstance(D, stylegan2.models.Discriminator), \
|
803 |
+
'`--d_file` should specify a discriminator model, found {}'.format(type(D))
|
804 |
+
else:
|
805 |
+
D = stylegan2.models.Discriminator(
|
806 |
+
channels=args.d_channels or args.channels,
|
807 |
+
label_size=args.label,
|
808 |
+
conv_block_size=args.d_conv_block_size,
|
809 |
+
activation=args.d_activation,
|
810 |
+
conv_resample_mode=args.d_conv_resample_mode,
|
811 |
+
skip_resample_mode=args.d_skip_resample_mode,
|
812 |
+
mbstd_group_size=args.group_size,
|
813 |
+
resnet=args.d_resnet,
|
814 |
+
skip=args.d_skip,
|
815 |
+
fused_resample=args.d_fused_conv,
|
816 |
+
**common_kwargs
|
817 |
+
)
|
818 |
+
assert len(G.G_synthesis.channels) == len(D.channels), \
|
819 |
+
'While the number of channels for each layer can ' + \
|
820 |
+
'differ between generator and discriminator, the ' + \
|
821 |
+
'number of layers have to be the same. Received ' + \
|
822 |
+
'{} generator layers and {} discriminator layers.'.format(
|
823 |
+
len(G.G_synthesis.channels), len(D.channels))
|
824 |
+
|
825 |
+
return G, D
|
826 |
+
|
827 |
+
#----------------------------------------------------------------------------
|
828 |
+
|
829 |
+
def get_trainer(args):
|
830 |
+
dataset = get_dataset(args)
|
831 |
+
if args.resume and stylegan2.train._find_checkpoint(args.checkpoint_dir):
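# Resume path: rebuild the trainer from the newest checkpoint in --checkpoint_dir,
# taking only the dataset, device, distributed and logging options from the
# current command line (see the --resume help text).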
|
832 |
+
trainer = stylegan2.train.Trainer.load_checkpoint(
|
833 |
+
args.checkpoint_dir,
|
834 |
+
dataset,
|
835 |
+
device=args.gpu,
|
836 |
+
rank=args.rank,
|
837 |
+
world_size=args.world_size,
|
838 |
+
master_addr=args.master_addr,
|
839 |
+
master_port=args.master_port,
|
840 |
+
tensorboard_log_dir=args.tensorboard_log_dir
|
841 |
+
)
|
842 |
+
else:
|
843 |
+
G, D = get_models(args)
|
844 |
+
trainer = stylegan2.train.Trainer(
|
845 |
+
G=G,
|
846 |
+
D=D,
|
847 |
+
latent_size=args.latent,
|
848 |
+
dataset=dataset,
|
849 |
+
device=args.gpu,
|
850 |
+
batch_size=args.batch_size,
|
851 |
+
device_batch_size=args.device_batch_size,
|
852 |
+
label_size=args.label,
|
853 |
+
data_workers=args.data_workers,
|
854 |
+
G_loss=args.g_loss,
|
855 |
+
D_loss=args.d_loss,
|
856 |
+
G_reg=args.g_reg,
|
857 |
+
G_reg_interval=args.g_reg_interval,
|
858 |
+
G_opt_kwargs={'lr': args.g_lr, 'betas': args.g_betas},
|
859 |
+
G_reg_batch_size=args.g_reg_batch_size,
|
860 |
+
G_reg_device_batch_size=args.g_reg_device_batch_size,
|
861 |
+
D_reg=args.d_reg,
|
862 |
+
D_reg_interval=args.d_reg_interval,
|
863 |
+
D_opt_kwargs={'lr': args.d_lr, 'betas': args.d_betas},
|
864 |
+
style_mix_prob=args.style_mix,
|
865 |
+
G_iter=args.g_iter,
|
866 |
+
D_iter=args.d_iter,
|
867 |
+
tensorboard_log_dir=args.tensorboard_log_dir,
|
868 |
+
checkpoint_dir=args.checkpoint_dir,
|
869 |
+
checkpoint_interval=args.checkpoint_interval,
|
870 |
+
half=args.half,
|
871 |
+
rank=args.rank,
|
872 |
+
world_size=args.world_size,
|
873 |
+
master_addr=args.master_addr,
|
874 |
+
master_port=args.master_port
|
875 |
+
)
|
876 |
+
if args.fid_interval and not args.rank:
|
877 |
+
fid_model = inception.InceptionV3FeatureExtractor(
|
878 |
+
pixel_min=args.pixel_min, pixel_max=args.pixel_max)
|
879 |
+
trainer.register_metric(
|
880 |
+
name='FID (299x299)',
|
881 |
+
eval_fn=fid.FID(
|
882 |
+
trainer.Gs,
|
883 |
+
trainer.prior_generator,
|
884 |
+
dataset=dataset,
|
885 |
+
fid_model=fid_model,
|
886 |
+
fid_size=299,
|
887 |
+
reals_batch_size=64
|
888 |
+
),
|
889 |
+
interval=args.fid_interval
|
890 |
+
)
|
891 |
+
trainer.register_metric(
|
892 |
+
name='FID',
|
893 |
+
eval_fn=fid.FID(
|
894 |
+
trainer.Gs,
|
895 |
+
trainer.prior_generator,
|
896 |
+
dataset=dataset,
|
897 |
+
fid_model=fid_model,
|
898 |
+
fid_size=None
|
899 |
+
),
|
900 |
+
interval=args.fid_interval
|
901 |
+
)
|
902 |
+
if args.ppl_interval and not args.rank:
|
903 |
+
lpips_model = lpips.LPIPS_VGG16(
|
904 |
+
pixel_min=args.pixel_min, pixel_max=args.pixel_max)
|
905 |
+
crop = None
|
906 |
+
if args.ppl_ffhq_crop:
|
907 |
+
crop = ppl.PPL.FFHQ_CROP
|
908 |
+
trainer.register_metric(
|
909 |
+
name='PPL_end',
|
910 |
+
eval_fn=ppl.PPL(
|
911 |
+
trainer.Gs,
|
912 |
+
trainer.prior_generator,
|
913 |
+
full_sampling=False,
|
914 |
+
crop=crop,
|
915 |
+
lpips_model=lpips_model,
|
916 |
+
lpips_size=256
|
917 |
+
),
|
918 |
+
interval=args.ppl_interval
|
919 |
+
)
|
920 |
+
trainer.register_metric(
|
921 |
+
name='PPL_full',
|
922 |
+
eval_fn=ppl.PPL(
|
923 |
+
trainer.Gs,
|
924 |
+
trainer.prior_generator,
|
925 |
+
full_sampling=True,
|
926 |
+
crop=crop,
|
927 |
+
lpips_model=lpips_model,
|
928 |
+
lpips_size=256
|
929 |
+
),
|
930 |
+
interval=args.ppl_interval
|
931 |
+
)
|
932 |
+
if args.tensorboard_image_interval:
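# Logs six image grids in total: static-seed and random-seed batches at
# truncation psi 0.5, 0.7 and 1.0.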
|
933 |
+
for static in [True, False]:
|
934 |
+
for trunc in [0.5, 0.7, 1.0]:
|
935 |
+
if static:
|
936 |
+
name = 'static'
|
937 |
+
else:
|
938 |
+
name = 'random'
|
939 |
+
name += '/trunc_{:.1f}'.format(trunc)
|
940 |
+
trainer.add_tensorboard_image_logging(
|
941 |
+
name=name,
|
942 |
+
num_images=4,
|
943 |
+
interval=args.tensorboard_image_interval,
|
944 |
+
resize=args.tensorboard_image_size,
|
945 |
+
seed=1234567890 if static else None,
|
946 |
+
truncation_psi=trunc,
|
947 |
+
pixel_min=args.pixel_min,
|
948 |
+
pixel_max=args.pixel_max
|
949 |
+
)
|
950 |
+
return trainer
|
951 |
+
|
952 |
+
#----------------------------------------------------------------------------
|
953 |
+
|
954 |
+
def run(args):
|
955 |
+
if not args.rank:
|
956 |
+
if not (args.checkpoint_dir or args.output):
|
957 |
+
warnings.warn(
|
958 |
+
'Neither an output path nor a checkpoint dir has been ' + \
|
959 |
+
'given. Weights from this training run will never ' + \
|
960 |
+
'be saved.'
|
961 |
+
)
|
962 |
+
if args.output:
|
963 |
+
assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
|
964 |
+
'--output argument should specify a directory, not a file.'
|
965 |
+
trainer = get_trainer(args)
|
966 |
+
trainer.train(iterations=args.iterations)
|
967 |
+
if not args.rank and args.output:
|
968 |
+
print('Saving models to {}'.format(args.output))
|
969 |
+
if not os.path.exists(args.output):
|
970 |
+
os.makedirs(args.output)
|
971 |
+
for model_name in ['G', 'D', 'Gs']:
|
972 |
+
getattr(trainer, model_name).save(
|
973 |
+
os.path.join(args.output, model_name + '.pth'))
|
974 |
+
|
975 |
+
#----------------------------------------------------------------------------
|
976 |
+
|
977 |
+
def run_distributed(rank, args):
|
978 |
+
args.rank = rank
|
979 |
+
args.world_size = len(args.gpu)
|
980 |
+
args.gpu = args.gpu[rank]
|
981 |
+
args.master_addr = args.master_addr or '127.0.0.1'
|
982 |
+
args.master_port = args.master_port or '23456'
|
983 |
+
run(args)
|
984 |
+
|
985 |
+
#----------------------------------------------------------------------------
|
986 |
+
|
987 |
+
def main():
|
988 |
+
parser = get_arg_parser()
|
989 |
+
args = parser.parse_args()
|
990 |
+
if len(args.gpu) > 1 and args.distributed:
|
991 |
+
assert args.rank is None and args.world_size is None, \
|
992 |
+
'When --distributed is enabled (default) the rank and ' + \
|
993 |
+
'world size can not be given as this is set up automatically. ' + \
|
994 |
+
'Use --distributed 0 to disable automatic setup of distributed training.'
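# mp.spawn starts one training process per listed GPU; each process receives its
# index as the first argument to run_distributed, which uses it as the
# distributed rank.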
|
995 |
+
mp.spawn(run_distributed, nprocs=len(args.gpu), args=(args,))
|
996 |
+
else:
|
997 |
+
run(args)
|
998 |
+
|
999 |
+
#----------------------------------------------------------------------------
|
1000 |
+
|
1001 |
+
if __name__ == '__main__':
|
1002 |
+
main()
|
stylegan2/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
from . import external_models
|
2 |
+
from . import metrics
|
3 |
+
from . import models
|
4 |
+
from . import project
|
5 |
+
from . import train
|
stylegan2/external_models/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from . import inception
|
2 |
+
from . import lpips
|
stylegan2/external_models/inception.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
"""
|
2 |
+
Code adapted from https://github.com/mseitzer/pytorch-fid/
|
3 |
+
|
4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
you may not use this file except in compliance with the License.
|
6 |
+
You may obtain a copy of the License at
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
Unless required by applicable law or agreed to in writing, software
|
9 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
See the License for the specific language governing permissions and
|
12 |
+
limitations under the License.
|
13 |
+
"""
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torchvision import models
|
18 |
+
|
19 |
+
try:
|
20 |
+
from torchvision.models.utils import load_state_dict_from_url
|
21 |
+
except ImportError:
|
22 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
23 |
+
|
24 |
+
# Inception weights ported to Pytorch from
|
25 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
26 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
27 |
+
|
28 |
+
|
29 |
+
class InceptionV3FeatureExtractor(nn.Module):
|
30 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
31 |
+
|
32 |
+
# Index of default block of inception to return,
|
33 |
+
# corresponds to output of final average pooling
|
34 |
+
DEFAULT_BLOCK_INDEX = 3
|
35 |
+
|
36 |
+
# Maps feature dimensionality to their output blocks indices
|
37 |
+
BLOCK_INDEX_BY_DIM = {
|
38 |
+
64: 0, # First max pooling features
|
39 |
+
192: 1, # Second max pooling features
|
40 |
+
768: 2, # Pre-aux classifier features
|
41 |
+
2048: 3 # Final average pooling features
|
42 |
+
}
|
43 |
+
|
44 |
+
def __init__(self,
|
45 |
+
output_block=DEFAULT_BLOCK_INDEX,
|
46 |
+
pixel_min=-1,
|
47 |
+
pixel_max=1):
|
48 |
+
"""
|
49 |
+
Build pretrained InceptionV3
|
50 |
+
Arguments:
|
51 |
+
output_block (int): Index of block to return features of.
|
52 |
+
Possible values are:
|
53 |
+
- 0: corresponds to output of first max pooling
|
54 |
+
- 1: corresponds to output of second max pooling
|
55 |
+
- 2: corresponds to output which is fed to aux classifier
|
56 |
+
- 3: corresponds to output of final average pooling
|
57 |
+
pixel_min (float): Min value for inputs. Default value is -1.
|
58 |
+
pixel_max (float): Max value for inputs. Default value is 1.
|
59 |
+
"""
|
60 |
+
super(InceptionV3FeatureExtractor, self).__init__()
|
61 |
+
|
62 |
+
assert 0 <= output_block <= 3, '`output_block` can only be ' + \
|
63 |
+
'in the range 0 <= `output_block` <= 3.'
|
64 |
+
|
65 |
+
inception = fid_inception_v3()
|
66 |
+
|
67 |
+
blocks = []
|
68 |
+
|
69 |
+
# Block 0: input to maxpool1
|
70 |
+
block0 = [
|
71 |
+
inception.Conv2d_1a_3x3,
|
72 |
+
inception.Conv2d_2a_3x3,
|
73 |
+
inception.Conv2d_2b_3x3,
|
74 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
75 |
+
]
|
76 |
+
blocks.append(nn.Sequential(*block0))
|
77 |
+
|
78 |
+
# Block 1: maxpool1 to maxpool2
|
79 |
+
if output_block >= 1:
|
80 |
+
block1 = [
|
81 |
+
inception.Conv2d_3b_1x1,
|
82 |
+
inception.Conv2d_4a_3x3,
|
83 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
84 |
+
]
|
85 |
+
blocks.append(nn.Sequential(*block1))
|
86 |
+
|
87 |
+
# Block 2: maxpool2 to aux classifier
|
88 |
+
if output_block >= 2:
|
89 |
+
block2 = [
|
90 |
+
inception.Mixed_5b,
|
91 |
+
inception.Mixed_5c,
|
92 |
+
inception.Mixed_5d,
|
93 |
+
inception.Mixed_6a,
|
94 |
+
inception.Mixed_6b,
|
95 |
+
inception.Mixed_6c,
|
96 |
+
inception.Mixed_6d,
|
97 |
+
inception.Mixed_6e,
|
98 |
+
]
|
99 |
+
blocks.append(nn.Sequential(*block2))
|
100 |
+
|
101 |
+
# Block 3: aux classifier to final avgpool
|
102 |
+
if output_block >= 3:
|
103 |
+
block3 = [
|
104 |
+
inception.Mixed_7a,
|
105 |
+
inception.Mixed_7b,
|
106 |
+
inception.Mixed_7c,
|
107 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
108 |
+
]
|
109 |
+
blocks.append(nn.Sequential(*block3))
|
110 |
+
|
111 |
+
self.main = nn.Sequential(*blocks)
|
112 |
+
self.pixel_min = pixel_min
|
113 |
+
self.pixel_max = pixel_max
|
114 |
+
self.requires_grad_(False)
|
115 |
+
self.eval()
|
116 |
+
|
117 |
+
def _scale(self, x):
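# Linearly remap inputs from [pixel_min, pixel_max] to the [-1, 1] range this
# feature extractor expects; inputs already in [-1, 1] pass through unchanged.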
|
118 |
+
if self.pixel_min != -1 or self.pixel_max != 1:
|
119 |
+
x = (2*x - self.pixel_min - self.pixel_max) \
|
120 |
+
/ (self.pixel_max - self.pixel_min)
|
121 |
+
return x
|
122 |
+
|
123 |
+
def forward(self, input):
|
124 |
+
"""
|
125 |
+
Get Inception feature maps.
|
126 |
+
Arguments:
|
127 |
+
input (torch.Tensor)
|
128 |
+
Returns:
|
129 |
+
feature_maps (torch.Tensor)
|
130 |
+
"""
|
131 |
+
return self.main(input)
|
132 |
+
|
133 |
+
|
134 |
+
def fid_inception_v3():
|
135 |
+
"""Build pretrained Inception model for FID computation
|
136 |
+
The Inception model for FID computation uses a different set of weights
|
137 |
+
and has a slightly different structure than torchvision's Inception.
|
138 |
+
This method first constructs torchvision's Inception and then patches the
|
139 |
+
necessary parts that are different in the FID Inception model.
|
140 |
+
"""
|
141 |
+
inception = models.inception_v3(num_classes=1008,
|
142 |
+
aux_logits=False,
|
143 |
+
pretrained=False)
|
144 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
145 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
146 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
147 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
148 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
149 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
150 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
151 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
152 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
153 |
+
|
154 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
155 |
+
inception.load_state_dict(state_dict)
|
156 |
+
return inception
|
157 |
+
|
158 |
+
|
159 |
+
class FIDInceptionA(models.inception.InceptionA):
|
160 |
+
"""InceptionA block patched for FID computation"""
|
161 |
+
def __init__(self, in_channels, pool_features):
|
162 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
branch1x1 = self.branch1x1(x)
|
166 |
+
|
167 |
+
branch5x5 = self.branch5x5_1(x)
|
168 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
169 |
+
|
170 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
171 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
172 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
173 |
+
|
174 |
+
# Patch: Tensorflow's average pool does not use the padded zeros in
|
175 |
+
# its average calculation
|
176 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
177 |
+
count_include_pad=False)
|
178 |
+
branch_pool = self.branch_pool(branch_pool)
|
179 |
+
|
180 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
181 |
+
return torch.cat(outputs, 1)
|
182 |
+
|
183 |
+
|
184 |
+
class FIDInceptionC(models.inception.InceptionC):
|
185 |
+
"""InceptionC block patched for FID computation"""
|
186 |
+
def __init__(self, in_channels, channels_7x7):
|
187 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
branch1x1 = self.branch1x1(x)
|
191 |
+
|
192 |
+
branch7x7 = self.branch7x7_1(x)
|
193 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
194 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
195 |
+
|
196 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
197 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
198 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
199 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
200 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
201 |
+
|
202 |
+
# Patch: Tensorflow's average pool does not use the padded zeros in
|
203 |
+
# its average calculation
|
204 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
205 |
+
count_include_pad=False)
|
206 |
+
branch_pool = self.branch_pool(branch_pool)
|
207 |
+
|
208 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
209 |
+
return torch.cat(outputs, 1)
|
210 |
+
|
211 |
+
|
212 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
213 |
+
"""First InceptionE block patched for FID computation"""
|
214 |
+
def __init__(self, in_channels):
|
215 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
branch1x1 = self.branch1x1(x)
|
219 |
+
|
220 |
+
branch3x3 = self.branch3x3_1(x)
|
221 |
+
branch3x3 = [
|
222 |
+
self.branch3x3_2a(branch3x3),
|
223 |
+
self.branch3x3_2b(branch3x3),
|
224 |
+
]
|
225 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
226 |
+
|
227 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
228 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
229 |
+
branch3x3dbl = [
|
230 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
231 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
232 |
+
]
|
233 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
234 |
+
|
235 |
+
# Patch: Tensorflow's average pool does not use the padded zeros in
|
236 |
+
# its average calculation
|
237 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
238 |
+
count_include_pad=False)
|
239 |
+
branch_pool = self.branch_pool(branch_pool)
|
240 |
+
|
241 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
242 |
+
return torch.cat(outputs, 1)
|
243 |
+
|
244 |
+
|
245 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
246 |
+
"""Second InceptionE block patched for FID computation"""
|
247 |
+
def __init__(self, in_channels):
|
248 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
249 |
+
|
250 |
+
def forward(self, x):
|
251 |
+
branch1x1 = self.branch1x1(x)
|
252 |
+
|
253 |
+
branch3x3 = self.branch3x3_1(x)
|
254 |
+
branch3x3 = [
|
255 |
+
self.branch3x3_2a(branch3x3),
|
256 |
+
self.branch3x3_2b(branch3x3),
|
257 |
+
]
|
258 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
259 |
+
|
260 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
261 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
262 |
+
branch3x3dbl = [
|
263 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
264 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
265 |
+
]
|
266 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
267 |
+
|
268 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
269 |
+
# pooling. This is likely an error in this specific Inception
|
270 |
+
# implementation, as other Inception models use average pooling here
|
271 |
+
# (which matches the description in the paper).
|
272 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
273 |
+
branch_pool = self.branch_pool(branch_pool)
|
274 |
+
|
275 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
276 |
+
return torch.cat(outputs, 1)
|
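A quick way to see what the `count_include_pad=False` patch in the blocks above changes (a standalone sketch, not part of the diff): at the border of an all-ones image, the default average pool divides by the full 3x3 window including padded zeros, while the TF-style pool divides only by the valid pixels.

import torch
from torch.nn import functional as F

x = torch.ones(1, 1, 4, 4)
default_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
tf_style_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                             count_include_pad=False)
print(default_pool[0, 0, 0, 0])   # 4/9, padded zeros pulled the mean down
print(tf_style_pool[0, 0, 0, 0])  # 1.0, only the 4 valid pixels are averaged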
stylegan2/external_models/lpips.py
ADDED
@@ -0,0 +1,78 @@
"""
Code adapted from https://github.com/richzhang/PerceptualSimilarity

Original License:
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
from torch import nn
import torchvision


class LPIPS_VGG16(nn.Module):
    _FEATURE_IDX = [0, 4, 9, 16, 23, 30]
    _LINEAR_WEIGHTS_URL = 'https://github.com/richzhang/PerceptualSimilarity' + \
        '/blob/master/lpips/weights/v0.1/vgg.pth?raw=true'

    def __init__(self, pixel_min=-1, pixel_max=1):
        super(LPIPS_VGG16, self).__init__()
        features = torchvision.models.vgg16(pretrained=True).features
        self.slices = nn.ModuleList()
        linear_weights = torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL)
        for i in range(1, len(self._FEATURE_IDX)):
            idx_range = range(self._FEATURE_IDX[i - 1], self._FEATURE_IDX[i])
            self.slices.append(nn.Sequential(*[features[j] for j in idx_range]))
        self.linear_layers = nn.ModuleList()
        for weight in torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL).values():
            weight = weight.view(1, -1)
            linear = nn.Linear(weight.size(1), 1, bias=False)
            linear.weight.data.copy_(weight)
            self.linear_layers.append(linear)
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188]).view(1, -1, 1, 1))
        self.register_buffer('scale', torch.Tensor([.458, .448, .450]).view(1, -1, 1, 1))
        self.pixel_min = pixel_min
        self.pixel_max = pixel_max
        self.requires_grad_(False)
        self.eval()

    def _scale(self, x):
        if self.pixel_min != -1 or self.pixel_max != 1:
            x = (2*x - self.pixel_min - self.pixel_max) \
                / (self.pixel_max - self.pixel_min)
        return (x - self.shift) / self.scale

    @staticmethod
    def _normalize_tensor(feature_maps, eps=1e-8):
        rnorm = torch.rsqrt(torch.sum(feature_maps ** 2, dim=1, keepdim=True) + eps)
        return feature_maps * rnorm

    def forward(self, x0, x1, eps=1e-8):
        x0, x1 = self._scale(x0), self._scale(x1)
        dist = 0
        for slice, linear in zip(self.slices, self.linear_layers):
            x0, x1 = slice(x0), slice(x1)
            _x0, _x1 = self._normalize_tensor(x0, eps), self._normalize_tensor(x1, eps)
            dist += linear(torch.mean((_x0 - _x1) ** 2, dim=[-1, -2]))
        return dist.view(-1)
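A rough usage sketch of the module above (not part of the diff): it assumes the pretrained VGG16 weights and the linear LPIPS weights can be downloaded, and that inputs lie in the range given by `pixel_min`/`pixel_max`.

import torch

lpips_model = LPIPS_VGG16(pixel_min=-1, pixel_max=1)
img_a = torch.rand(4, 3, 256, 256) * 2 - 1  # batch of images in [-1, 1]
img_b = torch.rand(4, 3, 256, 256) * 2 - 1
with torch.no_grad():
    distances = lpips_model(img_a, img_b)   # one perceptual distance per pair
print(distances.shape)                      # torch.Size([4])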
stylegan2/loss_fns.py
ADDED
@@ -0,0 +1,347 @@
import numpy as np
import torch
from torch.nn import functional as F

from . import utils


def _grad(input, output, retain_graph):
    # https://discuss.pytorch.org/t/gradient-penalty-loss-with-modified-weights/64910
    # Currently not possible to not retain graph for regularization losses.
    # Ugly hack is to always set it to True.
    retain_graph = True
    grads = torch.autograd.grad(
        output.sum(),
        input,
        only_inputs=True,
        retain_graph=retain_graph,
        create_graph=True
    )
    return grads[0]


def _grad_pen(input, output, gamma, constraint=1, onesided=False, retain_graph=True):
    grad = _grad(input, output, retain_graph=retain_graph)
    grad = grad.view(grad.size(0), -1)
    grad_norm = grad.norm(2, dim=1)
    if onesided:
        gp = torch.max(0, grad_norm - constraint)
    else:
        gp = (grad_norm - constraint) ** 2
    return gamma * gp.mean()


def _grad_reg(input, output, gamma, retain_graph=True):
    grad = _grad(input, output, retain_graph=retain_graph)
    grad = grad.view(grad.size(0), -1)
    gr = (grad ** 2).sum(1)
    return (0.5 * gamma) * gr.mean()


def _pathreg(dlatents, fakes, pl_avg, pl_decay, gamma, retain_graph=True):
    retain_graph = True
    pl_noise = torch.empty_like(fakes).normal_().div_(np.sqrt(np.prod(fakes.size()[2:])))
    pl_grad = _grad(dlatents, torch.sum(pl_noise * fakes), retain_graph=retain_graph)
    pl_length = torch.sqrt(torch.mean(torch.sum(pl_grad ** 2, dim=2), dim=1))
    with torch.no_grad():
        pl_avg.add_(pl_decay * (torch.mean(pl_length) - pl_avg))
    return gamma * torch.mean((pl_length - pl_avg) ** 2)


#----------------------------------------------------------------------------
# Logistic loss from the paper
# "Generative Adversarial Nets", Goodfellow et al. 2014


def G_logistic(G, D, latents, latent_labels=None, *args, **kwargs):
    fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
    loss = - F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
    reg = None
    return loss, reg


def G_logistic_ns(G, D, latents, latent_labels=None, *args, **kwargs):
    fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
    loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
    reg = None
    return loss, reg


def D_logistic(G, D, latents, reals, latent_labels=None, real_labels=None,
               *args, **kwargs):
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
    fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
    loss = real_loss + fake_loss
    reg = None
    return loss, reg


#----------------------------------------------------------------------------
# R1 and R2 regularizers from the paper
# "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018


def D_r1(D, reals, real_labels=None, gamma=10, *args, **kwargs):
    loss = None
    reg = None
    if gamma:
        reals.requires_grad_(True)
        real_scores = D(reals, labels=real_labels)
        reg = _grad_reg(
            input=reals, output=real_scores, gamma=gamma, retain_graph=False).float()
    return loss, reg


def D_r2(D, G, latents, latent_labels=None, gamma=10, *args, **kwargs):
    loss = None
    reg = None
    if gamma:
        with torch.no_grad():
            fakes = G(latents, labels=latent_labels)
        fakes.requires_grad_(True)
        fake_scores = D(fakes, labels=latent_labels)
        reg = _grad_reg(
            input=fakes, output=fake_scores, gamma=gamma, retain_graph=False).float()
    return loss, reg


def D_logistic_r1(G, D, latents, reals, latent_labels=None, real_labels=None,
                  gamma=10, *args, **kwargs):
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    if gamma:
        reals.requires_grad_(True)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
    fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
    loss = real_loss + fake_loss
    reg = None
    if gamma:
        reg = _grad_reg(
            input=reals, output=real_scores, gamma=gamma, retain_graph=True).float()
    return loss, reg


def D_logistic_r2(G, D, latents, reals, latent_labels=None, real_labels=None,
                  gamma=10, *args, **kwargs):
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    if gamma:
        fakes.requires_grad_(True)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
    fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
    loss = real_loss + fake_loss
    reg = None
    if gamma:
        reg = _grad_reg(
            input=fakes, output=fake_scores, gamma=gamma, retain_graph=True).float()
    return loss, reg


#----------------------------------------------------------------------------
# Non-saturating logistic loss with path length regularizer from the paper
# "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019


def G_pathreg(G, latents, pl_avg, latent_labels=None, pl_decay=0.01, gamma=2,
              *args, **kwargs):
    loss = None
    reg = None
    if gamma:
        fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True, mapping_grad=False)
        reg = _pathreg(
            dlatents=dlatents,
            fakes=fakes,
            pl_avg=pl_avg,
            pl_decay=pl_decay,
            gamma=gamma,
            retain_graph=False
        ).float()
    return loss, reg


def G_logistic_ns_pathreg(G, D, latents, pl_avg, latent_labels=None,
                          pl_decay=0.01, gamma=2, *args, **kwargs):
    fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True)
    fake_scores = D(fakes, labels=latent_labels).float()
    loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
    reg = None
    if gamma:
        reg = _pathreg(
            dlatents=dlatents,
            fakes=fakes,
            pl_avg=pl_avg,
            pl_decay=pl_decay,
            gamma=gamma,
            retain_graph=True
        ).float()
    return loss, reg


#----------------------------------------------------------------------------
# WGAN loss from the paper
# "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017


def G_wgan(G, D, latents, latent_labels=None, *args, **kwargs):
    fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
    loss = -fake_scores.mean()
    reg = None
    return loss, reg


def D_wgan(G, D, latents, reals, latent_labels=None, real_labels=None,
           drift_gamma=0.001, *args, **kwargs):
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    loss = fake_scores.mean() - real_scores.mean()
    if drift_gamma:
        loss += drift_gamma * torch.mean(real_scores ** 2)
    reg = None
    return loss, reg


#----------------------------------------------------------------------------
# WGAN-GP loss from the paper
# "Improved Training of Wasserstein GANs", Gulrajani et al. 2017


def D_gp(G, D, latents, reals, latent_labels=None, real_labels=None,
         gamma=0, constraint=1, *args, **kwargs):
    loss = None
    reg = None
    if gamma:
        assert (latent_labels is None) == (real_labels is None)
        with torch.no_grad():
            fakes = G(latents, labels=latent_labels)
        assert reals.size() == fakes.size()
        if latent_labels:
            assert latent_labels == real_labels
        alpha = torch.empty(reals.size(0)).uniform_()
        alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
        interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
        interp_scores = D(interp, labels=latent_labels)
        reg = _grad_pen(
            input=interp, output=interp_scores, gamma=gamma, retain_graph=False).float()
    return loss, reg


def D_wgan_gp(G, D, latents, reals, latent_labels=None, real_labels=None,
              gamma=0, drift_gamma=0.001, constraint=1, *args, **kwargs):
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    loss = fake_scores.mean() - real_scores.mean()
    if drift_gamma:
        loss += drift_gamma * torch.mean(real_scores ** 2)
    reg = None
    if gamma:
        assert reals.size() == fakes.size()
        if latent_labels:
            assert latent_labels == real_labels
        alpha = torch.empty(reals.size(0)).uniform_()
        alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
        interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
        interp_scores = D(interp, labels=latent_labels)
        reg = _grad_pen(
            input=interp, output=interp_scores, gamma=gamma, retain_graph=True).float()
    return loss, reg
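A minimal runnable sketch of how the (loss, reg) pairs returned by these functions are combined (not part of the diff). The tiny G/D stand-ins below are hypothetical stubs that only mimic the call signature `model(input, labels=None)`; the real training loop (run_training.py / stylegan2/train.py) uses the full Generator/Discriminator models and applies the regularizers lazily every few steps rather than on every iteration, and the path length regularizer additionally needs a generator that supports `return_dlatents=True`.

import torch
from torch import nn

class _ToyG(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3 * 4 * 4)
    def forward(self, latents, labels=None):
        return self.fc(latents).view(-1, 3, 4, 4)

class _ToyD(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 4 * 4, 1)
    def forward(self, x, labels=None):
        return self.fc(x.flatten(1))

G, D = _ToyG(), _ToyD()
latents = torch.randn(2, 8)
reals = torch.randn(2, 3, 4, 4)

# Discriminator side: logistic loss plus R1 gradient penalty on the reals.
d_loss, d_reg = D_logistic_r1(G=G, D=D, latents=latents, reals=reals, gamma=10)
(d_loss + d_reg).backward()   # the discriminator optimizer step would follow

# Generator side: non-saturating logistic loss (no regularizer in this call).
g_loss, _ = G_logistic_ns(G=G, D=D, latents=latents)
g_loss.backward()             # the generator optimizer step would follow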
stylegan2/metrics/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import fid
from . import ppl
stylegan2/metrics/fid.py
ADDED
@@ -0,0 +1,210 @@
import warnings
import numbers
import numpy as np
import scipy
import torch
from torch.nn import functional as F

from .. import models, utils
from ..external_models import inception


class _TruncatedDataset:
    """
    Truncates a dataset, making only part of it accessible
    by `torch.utils.data.DataLoader`.
    """

    def __init__(self, dataset, max_len):
        self.dataset = dataset
        self.max_len = max_len

    def __len__(self):
        return min(len(self.dataset), self.max_len)

    def __getitem__(self, index):
        return self.dataset[index]


class FID:
    """
    This class evaluates the FID metric of a generator.
    Arguments:
        G (Generator)
        prior_generator (PriorGenerator)
        dataset (indexable)
        device (int, str, torch.device, optional): The device
            to use for calculations. By default, the same device
            is chosen as the parameters in `generator` reside on.
        num_samples (int): Number of samples of reals and fakes
            to gather statistics for which are used for calculating
            the metric. Default value is 50 000.
        fid_model (nn.Module): A model that returns feature maps
            of shape (batch_size, features, *). Default value
            is InceptionV3.
        fid_size (int, optional): Resize any data fed to `fid_model` by scaling
            the data so that its smallest side is the same size as this
            argument.
        truncation_psi (float, optional): Truncation of the generator
            when evaluating.
        truncation_cutoff (int, optional): Cutoff for truncation when
            evaluating.
        reals_batch_size (int, optional): Batch size to use for real
            samples statistics gathering.
        reals_data_workers (int, optional): Number of workers fetching
            the real data samples. Default value is 0.
        verbose (bool): Write progress of gathering statistics for reals
            to stdout. Default value is True.
    """
    def __init__(self, G, prior_generator, dataset, device=None,
                 num_samples=50000, fid_model=None, fid_size=None,
                 truncation_psi=None, truncation_cutoff=None,
                 reals_batch_size=None, reals_data_workers=0, verbose=True):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        assert torch.device(prior_generator.device) == device, \
            'Prior generator device ({}) '.format(torch.device(prior_generator)) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the FID evaluation.'.format(device)
        G.eval().to(device)
        if device_ids:
            G = torch.nn.DataParallel(G, device_ids=device_ids)
        self.G = G
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.batch_size = self.prior_generator.batch_size
        if fid_model is None:
            warnings.warn(
                'Using default fid model metric based on Inception V3. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify another module if you want ' + \
                'to use other kinds of data formats.'
            )
            fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
            if device_ids:
                fid_model = torch.nn.DataParallel(fid_model, device_ids)
        self.fid_model = fid_model.eval().to(device)
        self.fid_size = fid_size

        dataset = _TruncatedDataset(dataset, self.num_samples)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=reals_batch_size or self.batch_size,
            num_workers=reals_data_workers
        )
        features = []
        self.labels = []

        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(self.num_samples / (reals_batch_size or self.batch_size)))
            progress.write('FID: Gathering statistics for reals...', step=False)

        for batch in dataloader:
            data = batch
            if isinstance(batch, (tuple, list)):
                data = batch[0]
                if len(batch) > 1:
                    self.labels.append(batch[1])
            data = self._scale_for_fid(data).to(self.device)
            with torch.no_grad():
                batch_features = self.fid_model(data)
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())
            progress.step()

        if verbose:
            progress.write('FID: Statistics for reals gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        self.mu_real = np.mean(features, axis=0)
        self.sigma_real = np.cov(features, rowvar=False)
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff

    def _scale_for_fid(self, data):
        if not self.fid_size:
            return data
        scale_factor = self.fid_size / min(data.size()[2:])
        if scale_factor == 1:
            return data
        mode = 'nearest'
        if scale_factor < 1:
            mode = 'area'
        return F.interpolate(data, scale_factor=scale_factor, mode=mode)

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    def evaluate(self, verbose=True):
        """
        Evaluate the FID.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            fid (float): Metric value.
        """
        utils.unwrap_module(self.G).set_truncation(
            truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
        self.G.eval()
        features = []

        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('FID: Gathering statistics for fakes...', step=False)

        remaining = self.num_samples
        for i in range(0, self.num_samples, self.batch_size):

            latents, latent_labels = self.prior_generator(
                batch_size=min(self.batch_size, remaining))
            if latent_labels is not None and self.labels:
                latent_labels = self.labels[i].to(self.device)
                length = min(len(latents), len(latent_labels))
                latents, latent_labels = latents[:length], latent_labels[:length]

            with torch.no_grad():
                fakes = self.G(latents, labels=latent_labels)

            with torch.no_grad():
                batch_features = self.fid_model(fakes)
            batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
            features.append(batch_features.cpu())

            remaining -= len(latents)
            progress.step()

        if verbose:
            progress.write('FID: Statistics for fakes gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        mu_fake = np.mean(features, axis=0)
        sigma_fake = np.cov(features, rowvar=False)

        m = np.square(mu_fake - self.mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False)
        dist = m + np.trace(sigma_fake + self.sigma_real - 2*s)
        return float(np.real(dist))
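A standalone sketch of the final Frechet distance computation used at the end of evaluate(), d^2 = ||mu_f - mu_r||^2 + Tr(S_f + S_r - 2 (S_f S_r)^(1/2)), here run on random toy feature statistics rather than real Inception features (not part of the diff).

import numpy as np
import scipy.linalg

rng = np.random.RandomState(0)
feats_real = rng.randn(1000, 8)        # stand-in for real-image features
feats_fake = rng.randn(1000, 8) + 0.5  # same distribution, shifted mean

mu_r, sigma_r = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
mu_f, sigma_f = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)

m = np.square(mu_f - mu_r).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_f, sigma_r), disp=False)
fid_value = float(np.real(m + np.trace(sigma_f + sigma_r - 2 * s)))
print(fid_value)  # roughly 8 * 0.5^2 = 2 for this toy mean offset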
stylegan2/metrics/ppl.py
ADDED
@@ -0,0 +1,229 @@
import warnings
import numbers
import numpy as np
import torch
from torch.nn import functional as F

from .. import models, utils
from ..external_models import lpips


class PPL:
    """
    This class evaluates the PPL metric of a generator.
    Arguments:
        G (Generator)
        prior_generator (PriorGenerator)
        device (int, str, torch.device, optional): The device
            to use for calculations. By default, the same device
            is chosen as the parameters in `generator` reside on.
        num_samples (int): Number of samples of reals and fakes
            to gather statistics for which are used for calculating
            the metric. Default value is 50 000.
        epsilon (float): Perturbation value. Default value is 1e-4.
        use_dlatent (bool): Measure PPL against the dlatents instead
            of the latents. Default value is True.
        full_sampling (bool): Measure on a random interpolation between
            two inputs. Default value is False.
        crop (float, list, optional): Crop values that should be in the
            range [0, 1] with 1 representing the entire data length.
            If single value this will be the amount cropped from all
            sides of the data. If a list of same length as number of
            data dimensions, each crop is mirrored to both sides of
            each respective dimension. If the length is 2 * number
            of dimensions the crop values for the start and end of
            a dimension may be different.
            Example 1:
                We have 1d data of length 10. We want to crop 1
                from the start and end of the data. We then need
                to use `crop=0.1` or `crop=[0.1]` or `crop=[0.1, 0.9]`.
            Example 2:
                We have 2d data (images) of size 10, 10 (height, width)
                and we want to use only the top left quarter of the image;
                we would use `crop=[0, 0.5, 0, 0.5]`.
        lpips_model (nn.Module): A model that returns the feature distance
            between two inputs. Default value is the LPIPS VGG16 model.
        lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
            the data so that its smallest side is the same size as this
            argument. Only has a default value of 256 if `lpips_model` is unspecified.
    """
    FFHQ_CROP = [1/8 * 3, 1/8 * 7, 1/8 * 2, 1/8 * 6]

    def __init__(self, G, prior_generator, device=None, num_samples=50000,
                 epsilon=1e-4, use_dlatent=True, full_sampling=False,
                 crop=None, lpips_model=None, lpips_size=None):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        assert torch.device(prior_generator.device) == device, \
            'Prior generator device ({}) '.format(torch.device(prior_generator)) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the PPL evaluation.'.format(device)
        G.eval().to(device)
        self.G_mapping = G.G_mapping
        self.G_synthesis = G.G_synthesis
        if device_ids:
            self.G_mapping = torch.nn.DataParallel(self.G_mapping, device_ids=device_ids)
            self.G_synthesis = torch.nn.DataParallel(self.G_synthesis, device_ids=device_ids)
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.epsilon = epsilon
        self.use_dlatent = use_dlatent
        self.full_sampling = full_sampling
        self.crop = crop
        self.batch_size = self.prior_generator.batch_size
        if lpips_model is None:
            warnings.warn(
                'Using default LPIPS distance metric based on VGG 16. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify an lpips module if you want ' + \
                'to use other kinds of data formats.'
            )
            lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
            if device_ids:
                lpips_model = torch.nn.DataParallel(lpips_model, device_ids=device_ids)
            lpips_size = lpips_size or 256
        self.lpips_model = lpips_model.eval().to(device)
        self.lpips_size = lpips_size

    def _scale_for_lpips(self, data):
        if not self.lpips_size:
            return data
        scale_factor = self.lpips_size / min(data.size()[2:])
        if scale_factor == 1:
            return data
        mode = 'nearest'
        if scale_factor < 1:
            mode = 'area'
        return F.interpolate(data, scale_factor=scale_factor, mode=mode)

    def crop_data(self, data):
        if not self.crop:
            return data
        dim = data.dim() - 2
        if isinstance(self.crop, numbers.Number):
            self.crop = [self.crop]
        else:
            self.crop = list(self.crop)
        if len(self.crop) == 1:
            self.crop = [self.crop[0], (1 if self.crop[0] < 1 else size) - self.crop[0]] * dim
        if len(self.crop) == dim:
            crop = self.crop
            self.crop = []
            for value in crop:
                self.crop += [value, (1 if value < 1 else size) - value]
        assert len(self.crop) == 2 * dim, 'Crop values have to be ' + \
            'a single value or a sequence of values of the same ' + \
            'size as number of dimensions of the data or twice of that.'
        pre_index = [Ellipsis]
        post_index = [slice(None, None, None) for _ in range(dim)]
        for i in range(0, 2 * dim, 2):
            j = i // 2
            size = data.size(2 + j)
            crop_min, crop_max = self.crop[i:i + 2]
            if crop_max < 1:
                crop_min, crop_max = crop_min * size, crop_max * size
            crop_min, crop_max = max(0, int(crop_min)), min(size, int(crop_max))
            dim_index = post_index.copy()
            dim_index[j] = slice(crop_min, crop_max, None)
            data = data[pre_index + dim_index]
        return data

    def prep_latents(self, latents):
        if self.full_sampling:
            lerp = utils.slerp
            if self.use_dlatent:
                lerp = utils.lerp
            latents_a, latents_b = latents[:self.batch_size], latents[self.batch_size:]
            latents = lerp(
                latents_a,
                latents_b,
                torch.rand(
                    latents_a.size()[:-1],
                    dtype=latents_a.dtype,
                    device=latents_a.device
                ).unsqueeze(-1)
            )
        return torch.cat([latents, latents + self.epsilon], dim=0)

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    def evaluate(self, verbose=True):
        """
        Evaluate the PPL.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            ppl (float): Metric value.
        """
        distances = []
        batch_size = self.batch_size
        if self.full_sampling:
            batch_size = 2 * batch_size

        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('PPL: Evaluating metric...', step=False)

        for _ in range(0, self.num_samples, self.batch_size):
            utils.unwrap_module(self.G_synthesis).static_noise()

            latents, latent_labels = self.prior_generator(batch_size=batch_size)
            if latent_labels is not None and self.full_sampling:
                # Labels should be the same for the first and second half of latents
                latent_labels = latent_labels.view(2, -1)[0].repeat(2)

            if self.use_dlatent:
                with torch.no_grad():
                    dlatents = self.G_mapping(latents=latents, labels=latent_labels)
                dlatents = self.prep_latents(dlatents)
            else:
                latents = self.prep_latents(latents)
                with torch.no_grad():
                    dlatents = self.G_mapping(latents=latents, labels=latent_labels)

            dlatents = dlatents.unsqueeze(1).repeat(1, len(utils.unwrap_module(self.G_synthesis)), 1)

            with torch.no_grad():
                output = self.G_synthesis(dlatents)

            output = self.crop_data(output)
            output = self._scale_for_lpips(output)

            output_a, output_b = output[:self.batch_size], output[self.batch_size:]

            with torch.no_grad():
                dist = self.lpips_model(output_a, output_b)

            distances.append(dist.cpu() * (1 / self.epsilon ** 2))

            if verbose:
                progress.step()

        if verbose:
            progress.write('PPL: Evaluated!', step=False)
            progress.close()

        distances = torch.cat(distances, dim=0).numpy()
        lo = np.percentile(distances, 1, interpolation='lower')
        hi = np.percentile(distances, 99, interpolation='higher')
        filtered_distances = np.extract(np.logical_and(lo <= distances, distances <= hi), distances)
        return float(np.mean(filtered_distances))
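A toy sketch of the estimator's final step (not part of the diff): squared perceptual distances between epsilon-perturbed latent pairs are scaled by 1/epsilon^2, the top and bottom 1% are discarded, and the rest are averaged. The log-normal samples below are hypothetical stand-ins for real LPIPS distances, and plain np.percentile is used instead of the repo's explicit lower/higher interpolation modes.

import numpy as np

epsilon = 1e-4
rng = np.random.RandomState(0)
raw_sq_dists = rng.lognormal(mean=0.0, sigma=1.0, size=10000) * epsilon ** 2
distances = raw_sq_dists * (1 / epsilon ** 2)   # finite-difference path lengths

lo = np.percentile(distances, 1)
hi = np.percentile(distances, 99)
ppl = np.extract(np.logical_and(lo <= distances, distances <= hi), distances).mean()
print(float(ppl))  # outlier-filtered mean, as returned by PPL.evaluate()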
stylegan2/models.py
ADDED
@@ -0,0 +1,1230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from collections import OrderedDict
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from . import modules, utils
|
8 |
+
|
9 |
+
|
10 |
+
class _BaseModel(nn.Module):
|
11 |
+
"""
|
12 |
+
Adds some base functionality to models that inherit this class.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
super(_BaseModel, self).__setattr__('kwargs', {})
|
17 |
+
super(_BaseModel, self).__setattr__('_defaults', {})
|
18 |
+
super(_BaseModel, self).__init__()
|
19 |
+
|
20 |
+
def _update_kwargs(self, **kwargs):
|
21 |
+
"""
|
22 |
+
Update the current keyword arguments. Overrides any
|
23 |
+
default values set.
|
24 |
+
Arguments:
|
25 |
+
**kwargs: Keyword arguments
|
26 |
+
"""
|
27 |
+
self.kwargs.update(**kwargs)
|
28 |
+
|
29 |
+
def _update_default_kwargs(self, **defaults):
|
30 |
+
"""
|
31 |
+
Update the default values for keyword arguments.
|
32 |
+
Arguments:
|
33 |
+
**defaults: Keyword arguments
|
34 |
+
"""
|
35 |
+
self._defaults.update(**defaults)
|
36 |
+
|
37 |
+
def __getattr__(self, name):
|
38 |
+
"""
|
39 |
+
Try to get the keyword argument for this attribute.
|
40 |
+
If no keyword argument of this name exists, try to
|
41 |
+
get the attribute directly from this object instead.
|
42 |
+
Arguments:
|
43 |
+
name (str): Name of keyword argument or attribute.
|
44 |
+
Returns:
|
45 |
+
value
|
46 |
+
"""
|
47 |
+
try:
|
48 |
+
return self.__getattribute__('kwargs')[name]
|
49 |
+
except KeyError:
|
50 |
+
try:
|
51 |
+
return self.__getattribute__('_defaults')[name]
|
52 |
+
except KeyError:
|
53 |
+
return super(_BaseModel, self).__getattr__(name)
|
54 |
+
|
55 |
+
def __setattr__(self, name, value):
|
56 |
+
"""
|
57 |
+
Try to set the keyword argument for this attribute.
|
58 |
+
If no keyword argument of this name exists, set
|
59 |
+
the attribute directly for this object instead.
|
60 |
+
Arguments:
|
61 |
+
name (str): Name of keyword argument or attribute.
|
62 |
+
value
|
63 |
+
"""
|
64 |
+
if name != '__dict__' and (name in self.kwargs or name in self._defaults):
|
65 |
+
self.kwargs[name] = value
|
66 |
+
else:
|
67 |
+
super(_BaseModel, self).__setattr__(name, value)
|
68 |
+
|
69 |
+
def __delattr__(self, name):
|
70 |
+
"""
|
71 |
+
Try to delete the keyword argument for this attribute.
|
72 |
+
If no keyword argument of this name exists, delete
|
73 |
+
the attribute of this object instead.
|
74 |
+
Arguments:
|
75 |
+
name (str): Name of keyword argument or attribute.
|
76 |
+
"""
|
77 |
+
deleted = False
|
78 |
+
if name in self.kwargs:
|
79 |
+
del self.kwargs[name]
|
80 |
+
deleted = True
|
81 |
+
if name in self._defaults:
|
82 |
+
del self._defaults[name]
|
83 |
+
deleted = True
|
84 |
+
if not deleted:
|
85 |
+
super(_BaseModel, self).__delattr__(name)
|
86 |
+
|
87 |
+
def clone(self):
|
88 |
+
"""
|
89 |
+
Create a copy of this model.
|
90 |
+
Returns:
|
91 |
+
model_copy (nn.Module)
|
92 |
+
"""
|
93 |
+
return copy.deepcopy(self)
|
94 |
+
|
95 |
+
def _get_state_dict(self):
|
96 |
+
"""
|
97 |
+
Delegate function for getting the state dict.
|
98 |
+
Should be overridden if state dict has to be
|
99 |
+
fetched in abnormal way.
|
100 |
+
"""
|
101 |
+
return self.state_dict()
|
102 |
+
|
103 |
+
def _set_state_dict(self, state_dict):
|
104 |
+
"""
|
105 |
+
Delegate function for loading the state dict.
|
106 |
+
Should be overridden if state dict has to be
|
107 |
+
loaded in abnormal way.
|
108 |
+
"""
|
109 |
+
self.load_state_dict(state_dict)
|
110 |
+
|
111 |
+
def _serialize(self, half=False):
|
112 |
+
"""
|
113 |
+
Turn model arguments and weights into
|
114 |
+
a dict that can safely be pickled and unpickled.
|
115 |
+
Arguments:
|
116 |
+
half (bool): Save weights in half precision.
|
117 |
+
Default value is False.
|
118 |
+
"""
|
119 |
+
state_dict = self._get_state_dict()
|
120 |
+
for key in state_dict.keys():
|
121 |
+
values = state_dict[key].cpu()
|
122 |
+
if torch.is_floating_point(values):
|
123 |
+
if half:
|
124 |
+
values = values.half()
|
125 |
+
else:
|
126 |
+
values = values.float()
|
127 |
+
state_dict[key] = values
|
128 |
+
return {
|
129 |
+
'name': self.__class__.__name__,
|
130 |
+
'kwargs': self.kwargs,
|
131 |
+
'state_dict': state_dict
|
132 |
+
}
|
133 |
+
|
134 |
+
@classmethod
|
135 |
+
def load(cls, fpath, map_location='cpu'):
|
136 |
+
"""
|
137 |
+
Load a model of this class.
|
138 |
+
Arguments:
|
139 |
+
fpath (str): File path of saved model.
|
140 |
+
map_location (str, int, torch.device): Weights and
|
141 |
+
buffers will be loaded into this device.
|
142 |
+
Default value is 'cpu'.
|
143 |
+
"""
|
144 |
+
model = load(fpath, map_location=map_location)
|
145 |
+
assert isinstance(model, cls), 'Trying to load a `{}` '.format(type(model)) + \
|
146 |
+
'model from {}.load()'.format(cls.__name__)
|
147 |
+
return model
|
148 |
+
|
149 |
+
def save(self, fpath, half=False):
|
150 |
+
"""
|
151 |
+
Save this model.
|
152 |
+
Arguments:
|
153 |
+
fpath (str): File path of save location.
|
154 |
+
half (bool): Save weights in half precision.
|
155 |
+
Default value is False.
|
156 |
+
"""
|
157 |
+
torch.save(self._serialize(half=half), fpath)
|
158 |
+
|
159 |
+
|
160 |
+
def _deserialize(state):
|
161 |
+
"""
|
162 |
+
Load a model from its serialized state.
|
163 |
+
Arguments:
|
164 |
+
state (dict)
|
165 |
+
Returns:
|
166 |
+
model (nn.Module): Model that inherits `_BaseModel`.
|
167 |
+
"""
|
168 |
+
state = state.copy()
|
169 |
+
name = state.pop('name')
|
170 |
+
if name not in globals():
|
171 |
+
raise NameError('Class {} is not defined.'.format(state['name']))
|
172 |
+
kwargs = state.pop('kwargs')
|
173 |
+
state_dict = state.pop('state_dict')
|
174 |
+
# Assume every other entry in the state is a serialized
|
175 |
+
# keyword argument.
|
176 |
+
for key in list(state.keys()):
|
177 |
+
kwargs[key] = _deserialize(state.pop(key))
|
178 |
+
model = globals()[name](**kwargs)
|
179 |
+
model._set_state_dict(state_dict)
|
180 |
+
return model
|
181 |
+
|
182 |
+
|
183 |
+
def load(fpath, map_location='cpu'):
|
184 |
+
"""
|
185 |
+
Load a model.
|
186 |
+
Arguments:
|
187 |
+
fpath (str): File path of saved model.
|
188 |
+
map_location (str, int, torch.device): Weights and
|
189 |
+
buffers will be loaded into this device.
|
190 |
+
Default value is 'cpu'.
|
191 |
+
Returns:
|
192 |
+
model (nn.Module): Model that inherits `_BaseModel`.
|
193 |
+
"""
|
194 |
+
if map_location is not None:
|
195 |
+
map_location = torch.device(map_location)
|
196 |
+
return _deserialize(torch.load(fpath, map_location=map_location))
|
197 |
+
|
198 |
+
|
199 |
+
def save(model, fpath, half=False):
|
200 |
+
"""
|
201 |
+
Save a model.
|
202 |
+
Arguments:
|
203 |
+
model (nn.Module): Wrapped or unwrapped module
|
204 |
+
that inherits `_BaseModel`.
|
205 |
+
fpath (str): File path of save location.
|
206 |
+
half (bool): Save weights in half precision.
|
207 |
+
Default value is False.
|
208 |
+
"""
|
209 |
+
utils.unwrap_module(model).save(fpath, half=half)
|
210 |
+
|
211 |
+
|
212 |
+
class Generator(_BaseModel):
|
213 |
+
"""
|
214 |
+
A wrapper class for the latent mapping model
|
215 |
+
and synthesis (generator) model.
|
216 |
+
Keyword Arguments:
|
217 |
+
G_mapping (GeneratorMapping)
|
218 |
+
G_synthesis (GeneratorSynthesis)
|
219 |
+
dlatent_avg_beta (float): The beta value
|
220 |
+
of the exponential moving average
|
221 |
+
of the dlatents. This statistic
|
222 |
+
is used for truncation of dlatents.
|
223 |
+
Default value is 0.995
|
224 |
+
"""
|
225 |
+
|
226 |
+
def __init__(self, *, G_mapping, G_synthesis, **kwargs):
|
227 |
+
super(Generator, self).__init__()
|
228 |
+
self._update_default_kwargs(
|
229 |
+
dlatent_avg_beta=0.995
|
230 |
+
)
|
231 |
+
self._update_kwargs(**kwargs)
|
232 |
+
|
233 |
+
assert isinstance(G_mapping, GeneratorMapping), \
|
234 |
+
'`G_mapping` has to be an instance of `model.GeneratorMapping`'
|
235 |
+
assert isinstance(G_synthesis, GeneratorSynthesis), \
|
236 |
+
'`G_synthesis` has to be an instance of `model.GeneratorSynthesis`'
|
237 |
+
self.G_mapping = G_mapping
|
238 |
+
self.G_synthesis = G_synthesis
|
239 |
+
self.register_buffer('dlatent_avg', torch.zeros(self.G_mapping.latent_size))
|
240 |
+
self.set_truncation()
|
241 |
+
|
242 |
+
@property
|
243 |
+
def latent_size(self):
|
244 |
+
return self.G_mapping.latent_size
|
245 |
+
|
246 |
+
@property
|
247 |
+
def label_size(self):
|
248 |
+
return self.G_mapping.label_size
|
249 |
+
|
250 |
+
def _get_state_dict(self):
|
251 |
+
state_dict = OrderedDict()
|
252 |
+
self._save_to_state_dict(destination=state_dict, prefix='', keep_vars=False)
|
253 |
+
return state_dict
|
254 |
+
|
255 |
+
def _set_state_dict(self, state_dict):
|
256 |
+
self.load_state_dict(state_dict, strict=False)
|
257 |
+
|
258 |
+
def _serialize(self, half=False):
|
259 |
+
state = super(Generator, self)._serialize(half=half)
|
260 |
+
for name in ['G_mapping', 'G_synthesis']:
|
261 |
+
state[name] = getattr(self, name)._serialize(half=half)
|
262 |
+
return state
|
263 |
+
|
264 |
+
def set_truncation(self, truncation_psi=None, truncation_cutoff=None):
|
265 |
+
"""
|
266 |
+
Set the truncation of dlatents before they are passed to the
|
267 |
+
synthesis model.
|
268 |
+
Arguments:
|
269 |
+
truncation_psi (float): Beta value of linear interpolation between
|
270 |
+
the average dlatent and the current dlatent. 0 -> 100% average,
|
271 |
+
1 -> 0% average.
|
272 |
+
truncation_cutoff (int, optional): Truncation is only used up until
|
273 |
+
this affine layer index.
|
274 |
+
"""
|
275 |
+
layer_psi = None
|
276 |
+
if truncation_psi is not None and truncation_psi != 1 and truncation_cutoff != 0:
|
277 |
+
layer_psi = torch.ones(len(self.G_synthesis))
|
278 |
+
if truncation_cutoff is None:
|
279 |
+
layer_psi *= truncation_psi
|
280 |
+
else:
|
281 |
+
layer_psi_mask = torch.arange(len(layer_psi)) < truncation_cutoff
|
282 |
+
layer_psi[layer_psi_mask] *= truncation_psi
|
283 |
+
layer_psi = layer_psi.view(1, -1, 1)
|
284 |
+
layer_psi = layer_psi.to(self.dlatent_avg)
|
285 |
+
self.register_buffer('layer_psi', layer_psi)
|
286 |
+
|
287 |
+
def random_noise(self):
|
288 |
+
"""
|
289 |
+
Set noise of synthesis model to be random for every
|
290 |
+
input.
|
291 |
+
"""
|
292 |
+
self.G_synthesis.random_noise()
|
293 |
+
|
294 |
+
def static_noise(self, trainable=False, noise_tensors=None):
|
295 |
+
"""
|
296 |
+
Set up injected noise to be fixed (alternatively trainable).
|
297 |
+
Get the fixed noise tensors (or parameters).
|
298 |
+
Arguments:
|
299 |
+
trainable (bool): Make noise trainable and return
|
300 |
+
parameters instead of normal tensors.
|
301 |
+
noise_tensors (list, optional): List of tensors to use as static noise.
|
302 |
+
Has to be same length as number of noise injection layers.
|
303 |
+
Returns:
|
304 |
+
noise_tensors (list): List of the noise tensors (or parameters).
|
305 |
+
"""
|
306 |
+
return self.G_synthesis.static_noise(trainable=trainable, noise_tensors=noise_tensors)
|
307 |
+
|
308 |
+
def __len__(self):
|
309 |
+
"""
|
310 |
+
Get the number of affine (style) layers of the synthesis model.
|
311 |
+
"""
|
312 |
+
return len(self.G_synthesis)
|
313 |
+
|
314 |
+
def truncate(self, dlatents):
|
315 |
+
"""
|
316 |
+
Truncate the dlatents.
|
317 |
+
Arguments:
|
318 |
+
dlatents (torch.Tensor)
|
319 |
+
Returns:
|
320 |
+
truncated_dlatents (torch.Tensor)
|
321 |
+
"""
|
322 |
+
if self.layer_psi is not None:
|
323 |
+
dlatents = utils.lerp(self.dlatent_avg, dlatents, self.layer_psi)
|
324 |
+
return dlatents
|
325 |
+
|
326 |
+
def forward(self,
|
327 |
+
latents=None,
|
328 |
+
labels=None,
|
329 |
+
dlatents=None,
|
330 |
+
return_dlatents=False,
|
331 |
+
mapping_grad=True,
|
332 |
+
latent_to_layer_idx=None):
|
333 |
+
"""
|
334 |
+
Synthesize some data from latent inputs. The latents
|
335 |
+
can have an extra optional dimension, where latents
|
336 |
+
from this dimension will be distributed to the different
|
337 |
+
affine layers of the synthesis model. The distribution
|
338 |
+
is a index to index mapping if the amount of latents
|
339 |
+
is the same as the number of affine layers. Otherwise,
|
340 |
+
latents are distributed consecutively for a random
|
341 |
+
number of layers before the next latent is used for
|
342 |
+
some random amount of following layers. If the size
|
343 |
+
of this extra dimension is 1 or it does not exist,
|
344 |
+
the same latent is passed to every affine layer.
|
345 |
+
|
346 |
+
Latents are first mapped to disentangled latents (`dlatents`)
|
347 |
+
and are then optionally truncated (if model is in eval mode
|
348 |
+
and truncation options have been set.) Set up truncation by
|
349 |
+
calling `set_truncation()`.
|
350 |
+
Arguments:
|
351 |
+
latents (torch.Tensor): The latent values of shape
|
352 |
+
(batch_size, N, num_features) where N is an
|
353 |
+
optional dimension. This argument is not required
|
354 |
+
if `dlatents` is passed.
|
355 |
+
labels (optional): A sequence of labels, one for
|
356 |
+
each index in the batch dimension of the input.
|
357 |
+
dlatents (torch.Tensor, optional): Skip the latent
|
358 |
+
mapping model and feed these dlatents straight
|
359 |
+
to the synthesis model. The same type of distribution
|
360 |
+
to affine layers as is described in this function
|
361 |
+
description is also used for dlatents.
|
362 |
+
NOTE: Explicitly passing dlatents to this function
|
363 |
+
will stop them from being truncated. If required,
|
364 |
+
do this manually by calling the `truncate()` function
|
365 |
+
of this model.
|
366 |
+
return_dlatents (bool): Return not only the synthesized
|
367 |
+
data, but also the dlatents. The dlatents tensor
|
368 |
+
will also have its `requires_grad` set to True
|
369 |
+
before being passed to the synthesis model for
|
370 |
+
use with pathlength regularization during training.
|
371 |
+
This requires training to be enabled (`thismodel.train()`).
|
372 |
+
Default value is False.
|
373 |
+
mapping_grad (bool): Let gradients be calculated when passing
|
374 |
+
latents through the latent mapping model. Should be
|
375 |
+
set to False when only optimising the synthesiser parameters.
|
376 |
+
Default value is True.
|
377 |
+
latent_to_layer_idx (list, tuple, optional): A manual mapping
|
378 |
+
of the latent vectors to the affine layers of this network.
|
379 |
+
Each position in this sequence maps the affine layer of the
|
380 |
+
same index to an index of the latents. The latents should
|
381 |
+
have a shape of (batch_size, N, num_features) and this argument
|
382 |
+
should be a list of the same length as number of affine layers
|
383 |
+
in this model (can be found by calling len(thismodel)) with values
|
384 |
+
in the range [0, N - 1]. Without this argument, latents are distributed
|
385 |
+
according to this function description.
|
386 |
+
"""
|
387 |
+
# Keep track of number of latents for each batch index.
|
388 |
+
num_latents = 1
|
389 |
+
|
390 |
+
# Keep track of if dlatent truncation is enabled or disabled.
|
391 |
+
truncate = False
|
392 |
+
|
393 |
+
if dlatents is None:
|
394 |
+
# Calculate dlatents
|
395 |
+
|
396 |
+
# dlatent truncation enabled as dlatents were not explicitly given
|
397 |
+
truncate = True
|
398 |
+
|
399 |
+
assert latents is not None, 'Either the `latents` ' + \
|
400 |
+
'or the `dlatents` argument is required.'
|
401 |
+
if labels is not None:
|
402 |
+
if not torch.is_tensor(labels):
|
403 |
+
labels = torch.tensor(labels, dtype=torch.int64)
|
404 |
+
|
405 |
+
# If latents are passed with the layer dimension we need
|
406 |
+
# to flatten it to shape (N, latent_size) before passing
|
407 |
+
# it to the latent mapping model.
|
408 |
+
if latents.dim() == 3:
|
409 |
+
num_latents = latents.size(1)
|
410 |
+
latents = latents.view(-1, latents.size(-1))
|
411 |
+
# Labels need to be repeated for the extra dimension of latents.
|
412 |
+
if labels is not None:
|
413 |
+
labels = labels.unsqueeze(1).repeat(1, num_latents).view(-1)
|
414 |
+
|
415 |
+
# Don't allow this operation to create a computation graph for
|
416 |
+
# backprop unless specified. This is useful for pathreg as it
|
417 |
+
# only regularizes the parameters of the synthesiser and not
|
418 |
+
# the latent mapping model.
|
419 |
+
with torch.set_grad_enabled(mapping_grad):
|
420 |
+
dlatents = self.G_mapping(latents=latents, labels=labels)
|
421 |
+
else:
|
422 |
+
if dlatents.dim() == 3:
|
423 |
+
num_latents = dlatents.size(1)
|
424 |
+
|
425 |
+
# Now we expand/repeat the number of latents per batch index until it is
|
426 |
+
# the same number as affine layers in our synthesis model.
|
427 |
+
dlatents = dlatents.view(-1, num_latents, dlatents.size(-1))
|
428 |
+
if num_latents == 1:
|
429 |
+
dlatents = dlatents.expand(
|
430 |
+
dlatents.size(0), len(self), dlatents.size(2))
|
431 |
+
elif num_latents != len(self):
|
432 |
+
assert dlatents.size(1) <= len(self), \
|
433 |
+
'More latents ({}) than number '.format(dlatents.size(1)) + \
|
434 |
+
'of generator layers ({}) received.'.format(len(self))
|
435 |
+
if not latent_to_layer_idx:
|
436 |
+
# Lets randomly distribute the latents to
|
437 |
+
# ranges of layers (each latent is assigned
|
438 |
+
# to a random number of consecutive layers).
|
439 |
+
cutoffs = np.random.choice(
|
440 |
+
np.arange(1, len(self)),
|
441 |
+
dlatents.size(1) - 1,
|
442 |
+
replace=False
|
443 |
+
)
|
444 |
+
cutoffs = [0] + sorted(cutoffs.tolist()) + [len(self)]
|
445 |
+
dlatents = [
|
446 |
+
dlatents[:, i].unsqueeze(1).expand(
|
447 |
+
-1, cutoffs[i + 1] - cutoffs[i], dlatents.size(2))
|
448 |
+
for i in range(dlatents.size(1))
|
449 |
+
]
|
450 |
+
dlatents = torch.cat(dlatents, dim=1)
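# Worked example: with len(self) == 18 affine layers and two input latents, a
# single random cutoff c in [1, 17] is drawn; the first latent is broadcast to
# layers [0, c) and the second to layers [c, 18), i.e. style mixing.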
|
451 |
+
else:
|
452 |
+
# Assign latents as specified by argument
|
453 |
+
assert len(latent_to_layer_idx) == len(self), \
|
454 |
+
'The latent index to layer index mapping does ' + \
|
455 |
+
'not have the same number of elements ' + \
|
456 |
+
'({}) as the number of '.format(len(latent_to_layer_idx)) + \
|
457 |
+
'generator layers ({})'.format(len(self))
|
458 |
+
dlatents = dlatents[:, latent_to_layer_idx]
|
459 |
+
|
460 |
+
# Update moving average of dlatents when training
|
461 |
+
if self.training and self.dlatent_avg_beta != 1:
|
462 |
+
with torch.no_grad():
|
463 |
+
batch_dlatent_avg = dlatents[:, 0].mean(dim=0)
|
464 |
+
self.dlatent_avg = utils.lerp(
|
465 |
+
batch_dlatent_avg, self.dlatent_avg, self.dlatent_avg_beta)
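# Equivalently: dlatent_avg <- dlatent_avg_beta * dlatent_avg
#                              + (1 - dlatent_avg_beta) * batch_dlatent_avg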
|
466 |
+
|
467 |
+
# Truncation is only applied when dlatents are not explicitly
|
468 |
+
# given and the model is in evaluation mode.
|
469 |
+
if truncate and not self.training:
|
470 |
+
dlatents = self.truncate(dlatents)
|
471 |
+
|
472 |
+
# One of the reasons we might want to return the dlatents is for
|
473 |
+
# pathreg, in which case the dlatents need to require gradients
|
474 |
+
# before being passed to the synthesiser. This should only be
|
475 |
+
# the case when the model is in training mode.
|
476 |
+
if return_dlatents and self.training:
|
477 |
+
dlatents.requires_grad_(True)
|
478 |
+
|
479 |
+
synth = self.G_synthesis(latents=dlatents)
|
480 |
+
if return_dlatents:
|
481 |
+
return synth, dlatents
|
482 |
+
return synth
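# Illustrative usage sketch (assumes an instance `G` of this generator class has
# already been constructed; shapes are hypothetical):
latents = torch.randn(4, 512)                 # one latent per sample, broadcast to all layers
images = G(latents=latents)
mixed = G(latents=torch.randn(4, 2, 512))     # two latents per sample -> random style mixing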
|
483 |
+
|
484 |
+
|
485 |
+
# Base class for the parameterized models. This is used as parent
|
486 |
+
# class to reduce duplicate code and documentation for shared arguments.
|
487 |
+
class _BaseParameterizedModel(_BaseModel):
|
488 |
+
"""
|
489 |
+
activation (str, callable, nn.Module): The non-linear
|
490 |
+
activation function to use.
|
491 |
+
Default value is leaky relu with a slope of 0.2.
|
492 |
+
lr_mul (float): The learning rate multiplier for this
|
493 |
+
model. When loading weights of previously trained
|
494 |
+
networks, this value has to be the same as when
|
495 |
+
the network was trained for the outputs to not
|
496 |
+
change (as this is used to scale the weights).
|
497 |
+
Default value depends on model type and can
|
498 |
+
be found in the original paper for StyleGAN.
|
499 |
+
weight_scale (bool): Use weight scaling for
|
500 |
+
equalized learning rate. Default value
|
501 |
+
is True.
|
502 |
+
eps (float): Epsilon value added for numerical stability.
|
503 |
+
Default value is 1e-8."""
|
504 |
+
|
505 |
+
def __init__(self, **kwargs):
|
506 |
+
super(_BaseParameterizedModel, self).__init__()
|
507 |
+
self._update_default_kwargs(
|
508 |
+
activation='lrelu:0.2',
|
509 |
+
lr_mul=1,
|
510 |
+
weight_scale=True,
|
511 |
+
eps=1e-8
|
512 |
+
)
|
513 |
+
self._update_kwargs(**kwargs)
|
514 |
+
|
515 |
+
|
516 |
+
class GeneratorMapping(_BaseParameterizedModel):
|
517 |
+
"""
|
518 |
+
Latent mapping model, handles the
|
519 |
+
transformation of latents into disentangled
|
520 |
+
latents.
|
521 |
+
Keyword Arguments:
|
522 |
+
latent_size (int): The size of the latent vectors.
|
523 |
+
This will also be the size of the disentangled
|
524 |
+
latent vectors.
|
525 |
+
Default value is 512.
|
526 |
+
label_size (int, optional): The number of different
|
527 |
+
possible labels. Use for label conditioning of
|
528 |
+
the GAN. Unused by default.
|
529 |
+
out_size (int, optional): The size of the disentangled
|
530 |
+
latents output by this model. If not specified,
|
531 |
+
the outputs will have the same size as the input
|
532 |
+
latents.
|
533 |
+
num_layers (int): Number of dense layers in this
|
534 |
+
model. Default value is 8.
|
535 |
+
hidden (int, optional): Number of hidden features of layers.
|
536 |
+
If unspecified, this is the same size as the latents.
|
537 |
+
normalize_input (bool): Normalize the input of this
|
538 |
+
model. Default value is True."""
|
539 |
+
__doc__ += _BaseParameterizedModel.__doc__
|
540 |
+
|
541 |
+
def __init__(self, **kwargs):
|
542 |
+
super(GeneratorMapping, self).__init__()
|
543 |
+
self._update_default_kwargs(
|
544 |
+
latent_size=512,
|
545 |
+
label_size=0,
|
546 |
+
out_size=None,
|
547 |
+
num_layers=8,
|
548 |
+
hidden=None,
|
549 |
+
normalize_input=True,
|
550 |
+
lr_mul=0.01,
|
551 |
+
)
|
552 |
+
self._update_kwargs(**kwargs)
|
553 |
+
|
554 |
+
# Find in and out features of first dense layer
|
555 |
+
in_features = self.latent_size
|
556 |
+
out_features = self.hidden or self.latent_size
|
557 |
+
|
558 |
+
# Each class label has its own embedded vector representation.
|
559 |
+
self.embedding = None
|
560 |
+
if self.label_size:
|
561 |
+
self.embedding = nn.Embedding(self.label_size, self.latent_size)
|
562 |
+
# The input is now the latents concatenated with
|
563 |
+
# the label embeddings.
|
564 |
+
in_features += self.latent_size
|
565 |
+
dense_layers = []
|
566 |
+
for i in range(self.num_layers):
|
567 |
+
if i == self.num_layers - 1:
|
568 |
+
# Set out features for last dense layer
|
569 |
+
out_features = self.out_size or self.latent_size
|
570 |
+
dense_layers.append(
|
571 |
+
modules.BiasActivationWrapper(
|
572 |
+
layer=modules.DenseLayer(
|
573 |
+
in_features=in_features,
|
574 |
+
out_features=out_features,
|
575 |
+
lr_mul=self.lr_mul,
|
576 |
+
weight_scale=self.weight_scale,
|
577 |
+
gain=1
|
578 |
+
),
|
579 |
+
features=out_features,
|
580 |
+
use_bias=True,
|
581 |
+
activation=self.activation,
|
582 |
+
bias_init=0,
|
583 |
+
lr_mul=self.lr_mul,
|
584 |
+
weight_scale=self.weight_scale
|
585 |
+
)
|
586 |
+
)
|
587 |
+
in_features = out_features
|
588 |
+
self.main = nn.Sequential(*dense_layers)
|
589 |
+
|
590 |
+
def forward(self, latents, labels=None):
|
591 |
+
"""
|
592 |
+
Get the disentangled latents from the input latents
|
593 |
+
and optional labels.
|
594 |
+
Arguments:
|
595 |
+
latents (torch.Tensor): Tensor of shape (batch_size, latent_size).
|
596 |
+
labels (torch.Tensor, optional): Labels for conditioning of latents
|
597 |
+
if there are any.
|
598 |
+
Returns:
|
599 |
+
dlatents (torch.Tensor): Disentangled latents of same shape as
|
600 |
+
`latents` argument.
|
601 |
+
"""
|
602 |
+
assert latents.dim() == 2 and latents.size(-1) == self.latent_size, \
|
603 |
+
'Incorrect input shape. Should be ' + \
|
604 |
+
'(batch_size, {}) '.format(self.latent_size) + \
|
605 |
+
'but received {}'.format(tuple(latents.size()))
|
606 |
+
x = latents
|
607 |
+
if labels is not None:
|
608 |
+
assert self.embedding is not None, \
|
609 |
+
'No embedding layer found, please ' + \
|
610 |
+
'specify the number of possible labels ' + \
|
611 |
+
'in the constructor of this class if ' + \
|
612 |
+
'using labels.'
|
613 |
+
assert len(labels) == len(latents), \
|
614 |
+
'Received different number of labels ' + \
|
615 |
+
'({}) and latents ({}).'.format(len(labels), len(latents))
|
616 |
+
if not torch.is_tensor(labels):
|
617 |
+
labels = torch.tensor(labels, dtype=torch.int64)
|
618 |
+
assert labels.dtype == torch.int64, \
|
619 |
+
'Labels should be integer values ' + \
|
620 |
+
'of dtype torch.int64 (long)'
|
621 |
+
y = self.embedding(labels)
|
622 |
+
x = torch.cat([x, y], dim=-1)
|
623 |
+
else:
|
624 |
+
assert self.embedding is None, 'Missing input labels.'
|
625 |
+
if self.normalize_input:
|
626 |
+
x = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
|
627 |
+
return self.main(x)
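# Illustrative sketch using the default keyword arguments of this class:
mapping = GeneratorMapping(latent_size=512, num_layers=8)
z = torch.randn(16, 512)          # latents drawn from a standard normal
w = mapping(latents=z)            # disentangled latents, shape (16, 512)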
|
628 |
+
|
629 |
+
|
630 |
+
# Base class for the synthesising and discriminating models. This is used as parent
|
631 |
+
# class to reduce duplicate code and documentation for shared arguments.
|
632 |
+
class _BaseAdverserialModel(_BaseParameterizedModel):
|
633 |
+
"""
|
634 |
+
data_channels (int): Number of channels of the data.
|
635 |
+
Default value is 3.
|
636 |
+
base_shape (list, tuple): This is the shape of the feature
|
637 |
+
activations when it is most compact and still has the
|
638 |
+
same number of dims as the data. This is one of the
|
639 |
+
arguments that controls what shape the data will be.
|
640 |
+
The value of each dimension in the shape is going to double
|
641 |
+
in size (`channels` - 1) times.
|
642 |
+
Example:
|
643 |
+
`data_channels=3`
|
644 |
+
`base_shape=(4, 2)`
|
645 |
+
and 9 `channels` in total will give us a shape of
|
646 |
+
(3, 4 * 2^(9 - 1), 2 * 2^(9 - 1)) which is the
|
647 |
+
same as (3, 1024, 512).
|
648 |
+
Default value is (4, 4).
|
649 |
+
channels (int, list, optional): The channels of each block
|
650 |
+
of layers. If int, this many channel values will be
|
651 |
+
created with sensible default values optimal for image
|
652 |
+
synthesis. If list, the number of blocks in this model
|
653 |
+
will be the same as the number of channels in the list.
|
654 |
+
Default value is the int value 9 which will create the
|
655 |
+
following channels: [32, 32, 64, 128, 256, 512, 512, 512, 512].
|
656 |
+
These are the channel values used in the stylegan2 paper for
|
657 |
+
their FFHQ-trained face generation network.
|
658 |
+
If channels is given as a list it should be in the order:
|
659 |
+
Generator: last layer -> first layer
|
660 |
+
Discriminator: first layer -> last layer
|
661 |
+
resnet (bool): Use resnet connections.
|
662 |
+
Defaults:
|
663 |
+
Generator: False
|
664 |
+
Discriminator: True
|
665 |
+
skip (bool): Use skip connections for data.
|
666 |
+
Defaults:
|
667 |
+
Generator: True
|
668 |
+
Discriminator: False
|
669 |
+
fused_resample (bool): Fuse any up- or downsampling that
|
670 |
+
is paired with a convolutional layer into a strided
|
671 |
+
convolution (transposed if upsampling was used).
|
672 |
+
Default value is True.
|
673 |
+
conv_resample_mode (str): The resample mode of up- or
|
674 |
+
downsampling layers. If `fused_resample=True` only
|
675 |
+
'FIR' and 'none' can be used. Else, 'FIR' or anything
|
676 |
+
that can be passed to torch.nn.functional.interpolate
|
677 |
+
is a valid mode (and 'max' but only for downsampling
|
678 |
+
operations). Default value is 'FIR'.
|
679 |
+
conv_filter (int, list): The filter to use if
|
680 |
+
`conv_resample_mode='FIR'`. If int, a low
|
681 |
+
pass filter of this size will be used. If list,
|
682 |
+
the filter is explicitly specified. If the filter
|
683 |
+
is of a single dimension it will be expanded to
|
684 |
+
the number of dimensions of the data. Default
|
685 |
+
value is a low pass filter of [1, 3, 3, 1].
|
686 |
+
skip_resample_mode (str): If `skip=True`, this
|
687 |
+
mode is used for the resamplings of skip
|
688 |
+
connections of different sizes. Same possible
|
689 |
+
values as `conv_resample_mode` (except 'none', which
|
690 |
+
can not be used). Default value is 'FIR'.
|
691 |
+
skip_filter (int, list): Same description as
|
692 |
+
`conv_filter` but for skip connections.
|
693 |
+
Only used if `skip_resample_mode='FIR'` and
|
694 |
+
`skip=True`. Default value is a low pass
|
695 |
+
filter of [1, 3, 3, 1].
|
696 |
+
kernel_size (int): The size of the convolutional kernels.
|
697 |
+
Default value is 3.
|
698 |
+
conv_pad_mode (str): The padding mode for convolutional
|
699 |
+
layers. Has to be one of 'constant', 'reflect',
|
700 |
+
'replicate' or 'circular'. Default value is
|
701 |
+
'constant'.
|
702 |
+
conv_pad_constant (float): The value to use for conv
|
703 |
+
padding if `conv_pad_mode='constant'`. Default
|
704 |
+
value is 0.
|
705 |
+
filter_pad_mode (str): The padding mode for FIR
|
706 |
+
filters. Same possible values as `conv_pad_mode`.
|
707 |
+
Default value is 'constant'.
|
708 |
+
filter_pad_constant (float): The value to use for FIR
|
709 |
+
padding if `filter_pad_mode='constant'`. Default
|
710 |
+
value is 0.
|
711 |
+
pad_once (bool): If FIR filter is used in conjunction with a
|
712 |
+
conv layer, do all the padding for both convolution and
|
713 |
+
FIR in the FIR layer instead of once per layer.
|
714 |
+
Default value is True.
|
715 |
+
conv_block_size (int): The number of conv layers in
|
716 |
+
each conv block. Default value is 2."""
|
717 |
+
__doc__ += _BaseParameterizedModel.__doc__
|
718 |
+
|
719 |
+
def __init__(self, **kwargs):
|
720 |
+
super(_BaseAdverserialModel, self).__init__()
|
721 |
+
self._update_default_kwargs(
|
722 |
+
data_channels=3,
|
723 |
+
base_shape=(4, 4),
|
724 |
+
channels=9,
|
725 |
+
resnet=False,
|
726 |
+
skip=False,
|
727 |
+
fused_resample=True,
|
728 |
+
conv_resample_mode='FIR',
|
729 |
+
conv_filter=[1, 3, 3, 1],
|
730 |
+
skip_resample_mode='FIR',
|
731 |
+
skip_filter=[1, 3, 3, 1],
|
732 |
+
kernel_size=3,
|
733 |
+
conv_pad_mode='constant',
|
734 |
+
conv_pad_constant=0,
|
735 |
+
filter_pad_mode='constant',
|
736 |
+
filter_pad_constant=0,
|
737 |
+
pad_once=True,
|
738 |
+
conv_block_size=2,
|
739 |
+
)
|
740 |
+
self._update_kwargs(**kwargs)
|
741 |
+
|
742 |
+
self.dim = len(self.base_shape)
|
743 |
+
assert 1 <= self.dim <= 3, '`base_shape` can only have 1, 2 or 3 dimensions.'
|
744 |
+
if isinstance(self.channels, int):
|
745 |
+
# Create the specified number of channel values with sensible
|
746 |
+
# sizes (these values do well for image synthesis).
|
747 |
+
num_channels = self.channels
|
748 |
+
self.channels = [min(32 * 2 ** i, 512) for i in range(min(8, num_channels))]
|
749 |
+
if len(self.channels) < num_channels:
|
750 |
+
self.channels = [32] * (num_channels - len(self.channels)) + self.channels
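# For example, the default `channels=9` expands to
# [32, 32, 64, 128, 256, 512, 512, 512, 512].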
|
751 |
+
|
752 |
+
|
753 |
+
class GeneratorSynthesis(_BaseAdverserialModel):
|
754 |
+
"""
|
755 |
+
The synthesis model that takes latents and synthesises
|
756 |
+
some data.
|
757 |
+
Keyword Arguments:
|
758 |
+
latent_size (int): The size of the latent vectors.
|
759 |
+
This will also be the size of the disentangled
|
760 |
+
latent vectors.
|
761 |
+
Default value is 512.
|
762 |
+
demodulate (bool): Normalize feature outputs from conv
|
763 |
+
layers. Default value is True.
|
764 |
+
modulate_data_out (bool): Apply style to the data output
|
765 |
+
layers. These layers are projections of the feature
|
766 |
+
maps into the space of the data. Default value is True.
|
767 |
+
noise (bool): Add noise after each conv style layer.
|
768 |
+
Default value is True."""
|
769 |
+
__doc__ += _BaseAdverserialModel.__doc__
|
770 |
+
|
771 |
+
def __init__(self, **kwargs):
|
772 |
+
super(GeneratorSynthesis, self).__init__()
|
773 |
+
self._update_default_kwargs(
|
774 |
+
latent_size=512,
|
775 |
+
demodulate=True,
|
776 |
+
modulate_data_out=True,
|
777 |
+
noise=True,
|
778 |
+
resnet=False,
|
779 |
+
skip=True
|
780 |
+
)
|
781 |
+
self._update_kwargs(**kwargs)
|
782 |
+
|
783 |
+
# The constant input of the model has no activations
|
784 |
+
# normalization, it is just passed straight to the first
|
785 |
+
# layer of the model.
|
786 |
+
self.const = torch.nn.Parameter(
|
787 |
+
torch.empty(self.channels[-1], *self.base_shape).normal_()
|
788 |
+
)
|
789 |
+
conv_block_kwargs = dict(
|
790 |
+
latent_size=self.latent_size,
|
791 |
+
demodulate=self.demodulate,
|
792 |
+
resnet=self.resnet,
|
793 |
+
up=True,
|
794 |
+
num_layers=self.conv_block_size,
|
795 |
+
filter=self.conv_filter,
|
796 |
+
activation=self.activation,
|
797 |
+
mode=self.conv_resample_mode,
|
798 |
+
fused=self.fused_resample,
|
799 |
+
kernel_size=self.kernel_size,
|
800 |
+
pad_mode=self.conv_pad_mode,
|
801 |
+
pad_constant=self.conv_pad_constant,
|
802 |
+
filter_pad_mode=self.filter_pad_mode,
|
803 |
+
filter_pad_constant=self.filter_pad_constant,
|
804 |
+
pad_once=self.pad_once,
|
805 |
+
noise=self.noise,
|
806 |
+
lr_mul=self.lr_mul,
|
807 |
+
weight_scale=self.weight_scale,
|
808 |
+
gain=1,
|
809 |
+
dim=self.dim,
|
810 |
+
eps=self.eps
|
811 |
+
)
|
812 |
+
self.conv_blocks = nn.ModuleList()
|
813 |
+
|
814 |
+
# The first convolutional layer is slightly different
|
815 |
+
# from the following convolutional blocks but can still
|
816 |
+
# be represented as a convolutional block if we change
|
817 |
+
# some of its arguments.
|
818 |
+
self.conv_blocks.append(
|
819 |
+
modules.GeneratorConvBlock(
|
820 |
+
**{
|
821 |
+
**conv_block_kwargs,
|
822 |
+
'in_channels': self.channels[-1],
|
823 |
+
'out_channels': self.channels[-1],
|
824 |
+
'resnet': False,
|
825 |
+
'up': False,
|
826 |
+
'num_layers': 1
|
827 |
+
}
|
828 |
+
)
|
829 |
+
)
|
830 |
+
|
831 |
+
# The rest of the convolutional blocks all look the same
|
832 |
+
# except for number of input and output channels
|
833 |
+
for i in range(1, len(self.channels)):
|
834 |
+
self.conv_blocks.append(
|
835 |
+
modules.GeneratorConvBlock(
|
836 |
+
in_channels=self.channels[-i],
|
837 |
+
out_channels=self.channels[-i - 1],
|
838 |
+
**conv_block_kwargs
|
839 |
+
)
|
840 |
+
)
|
841 |
+
|
842 |
+
# If not using the skip architecture, only one
|
843 |
+
# layer will project the feature maps into
|
844 |
+
# the space of the data (from the activations of
|
845 |
+
# the last convolutional block). If using the skip
|
846 |
+
# architecture, every block will have its
|
847 |
+
# own projection layer instead.
|
848 |
+
self.to_data_layers = nn.ModuleList()
|
849 |
+
for i in range(1, len(self.channels) + 1):
|
850 |
+
to_data = None
|
851 |
+
if i == len(self.channels) or self.skip:
|
852 |
+
to_data = modules.BiasActivationWrapper(
|
853 |
+
layer=modules.ConvLayer(
|
854 |
+
**{
|
855 |
+
**conv_block_kwargs,
|
856 |
+
'in_channels': self.channels[-i],
|
857 |
+
'out_channels': self.data_channels,
|
858 |
+
'modulate': self.modulate_data_out,
|
859 |
+
'demodulate': False,
|
860 |
+
'kernel_size': 1
|
861 |
+
}
|
862 |
+
),
|
863 |
+
**{
|
864 |
+
**conv_block_kwargs,
|
865 |
+
'features': self.data_channels,
|
866 |
+
'use_bias': True,
|
867 |
+
'activation': 'linear',
|
868 |
+
'bias_init': 0
|
869 |
+
}
|
870 |
+
)
|
871 |
+
self.to_data_layers.append(to_data)
|
872 |
+
|
873 |
+
# When the skip architecture is used we need to
|
874 |
+
# upsample data outputs of previous convolutional
|
875 |
+
# blocks so that it can be added to the data output
|
876 |
+
# of the current convolutional block.
|
877 |
+
self.upsample = None
|
878 |
+
if self.skip:
|
879 |
+
self.upsample = modules.Upsample(
|
880 |
+
mode=self.skip_resample_mode,
|
881 |
+
filter=self.skip_filter,
|
882 |
+
filter_pad_mode=self.filter_pad_mode,
|
883 |
+
filter_pad_constant=self.filter_pad_constant,
|
884 |
+
gain=1,
|
885 |
+
dim=self.dim
|
886 |
+
)
|
887 |
+
|
888 |
+
# Calculate the number of latents required
|
889 |
+
# in the input.
|
890 |
+
self._num_latents = 1 + self.conv_block_size * (len(self.channels) - 1)
|
891 |
+
# Only the final data output layer uses
|
892 |
+
# its own latent input when being modulated.
|
893 |
+
# The other data output layers recycle latents
|
894 |
+
# from the next convolutional block.
|
895 |
+
if self.modulate_data_out:
|
896 |
+
self._num_latents += 1
|
897 |
+
|
898 |
+
def __len__(self):
|
899 |
+
"""
|
900 |
+
Get the number of affine (style) layers of this model.
|
901 |
+
"""
|
902 |
+
return self._num_latents
|
903 |
+
|
904 |
+
def random_noise(self):
|
905 |
+
"""
|
906 |
+
Set injected noise to be random for each new input.
|
907 |
+
"""
|
908 |
+
for module in self.modules():
|
909 |
+
if isinstance(module, modules.NoiseInjectionWrapper):
|
910 |
+
module.random_noise()
|
911 |
+
|
912 |
+
def static_noise(self, trainable=False, noise_tensors=None):
|
913 |
+
"""
|
914 |
+
Set up injected noise to be fixed (alternatively trainable).
|
915 |
+
Get the fixed noise tensors (or parameters).
|
916 |
+
Arguments:
|
917 |
+
trainable (bool): Make noise trainable and return
|
918 |
+
parameters instead of normal tensors.
|
919 |
+
noise_tensors (list, optional): List of tensors to use as static noise.
|
920 |
+
Has to be same length as number of noise injection layers.
|
921 |
+
Returns:
|
922 |
+
noise_tensors (list): List of the noise tensors (or parameters).
|
923 |
+
"""
|
924 |
+
rtn_tensors = []
|
925 |
+
|
926 |
+
if not self.noise:
|
927 |
+
return rtn_tensors
|
928 |
+
|
929 |
+
for module in self.modules():
|
930 |
+
if isinstance(module, modules.NoiseInjectionWrapper):
|
931 |
+
has_noise_shape = module.has_noise_shape()
|
932 |
+
device = module.weight.device
|
933 |
+
dtype = module.weight.dtype
|
934 |
+
break
|
935 |
+
|
936 |
+
# If the noise layers don't yet know the shape that the noise should have,
|
937 |
+
# we first need to pass some data through the network once for
|
938 |
+
# these layers to record the shape. To create noise tensors
|
939 |
+
# we need to know what size they should be.
|
940 |
+
if not has_noise_shape:
|
941 |
+
with torch.no_grad():
|
942 |
+
self(torch.zeros(
|
943 |
+
1, len(self), self.latent_size, device=device, dtype=dtype))
|
944 |
+
|
945 |
+
i = 0
|
946 |
+
for block in self.conv_blocks:
|
947 |
+
for layer in block.conv_block:
|
948 |
+
for module in layer.modules():
|
949 |
+
if isinstance(module, modules.NoiseInjectionWrapper):
|
950 |
+
noise_tensor = None
|
951 |
+
if noise_tensors is not None:
|
952 |
+
if i < len(noise_tensors):
|
953 |
+
noise_tensor = noise_tensors[i]
|
954 |
+
i += 1
|
955 |
+
else:
|
956 |
+
rtn_tensors.append(None)
|
957 |
+
continue
|
958 |
+
rtn_tensors.append(
|
959 |
+
module.static_noise(trainable=trainable, noise_tensor=noise_tensor))
|
960 |
+
|
961 |
+
if noise_tensors is not None:
|
962 |
+
assert len(rtn_tensors) == len(noise_tensors), \
|
963 |
+
'Got a list of {} '.format(len(noise_tensors)) + \
|
964 |
+
'noise tensors but there are ' + \
|
965 |
+
'{} noise layers in this model'.format(len(rtn_tensors))
|
966 |
+
|
967 |
+
return rtn_tensors
|
968 |
+
|
969 |
+
def forward(self, latents):
|
970 |
+
"""
|
971 |
+
Synthesise some data from input latents.
|
972 |
+
Arguments:
|
973 |
+
latents (torch.Tensor): Latent vectors of shape
|
974 |
+
(batch_size, num_affine_layers, latent_size)
|
975 |
+
where num_affine_layers is the value returned
|
976 |
+
by __len__() of this class.
|
977 |
+
Returns:
|
978 |
+
synthesised (torch.Tensor): Synthesised data.
|
979 |
+
"""
|
980 |
+
assert latents.dim() == 3 and latents.size(1) == len(self), \
|
981 |
+
'Input mismatch, expected latents of shape ' + \
|
982 |
+
'(batch_size, {}, latent_size) '.format(len(self)) + \
|
983 |
+
'but got {}.'.format(tuple(latents.size()))
|
984 |
+
# Declare our feature activations variable
|
985 |
+
# and give it the value of our const parameter with
|
986 |
+
# an added batch dimension.
|
987 |
+
x = self.const.unsqueeze(0)
|
988 |
+
# Declare our data (output) variable
|
989 |
+
y = None
|
990 |
+
# Start counting style layers used. This is used for specifying
|
991 |
+
# which latents should be passed to the current block in the loop.
|
992 |
+
layer_idx = 0
|
993 |
+
for block, to_data in zip(self.conv_blocks, self.to_data_layers):
|
994 |
+
# Get the latents for the style layers in this block.
|
995 |
+
block_latents = latents[:, layer_idx:layer_idx + len(block)]
|
996 |
+
|
997 |
+
x = block(input=x, latents=block_latents)
|
998 |
+
|
999 |
+
layer_idx += len(block)
|
1000 |
+
|
1001 |
+
# Upsample the data output of the previous block to fit
|
1002 |
+
# the data output size of this block so that they can
|
1003 |
+
# be added together. Only performed for 'skip' architectures.
|
1004 |
+
if self.upsample is not None and layer_idx < len(self):
|
1005 |
+
if y is not None:
|
1006 |
+
y = self.upsample(y)
|
1007 |
+
|
1008 |
+
# Combine the data output of this block with any previous
|
1009 |
+
# blocks outputs if using 'skip' architecture, else only
|
1010 |
+
# perform this operation for the very last block outputs.
|
1011 |
+
if to_data is not None:
|
1012 |
+
t = to_data(input=x, latent=latents[:, layer_idx])
|
1013 |
+
y = t if y is None else y + t
|
1014 |
+
return y
|
1015 |
+
|
1016 |
+
|
1017 |
+
class Discriminator(_BaseAdverserialModel):
|
1018 |
+
"""
|
1019 |
+
The discriminator scores data inputs.
|
1020 |
+
Keyword Arguments:
|
1021 |
+
label_size (int, optional): The number of different
|
1022 |
+
possible labels. Use for label conditioning of
|
1023 |
+
the GAN. The discriminator will calculate scores
|
1024 |
+
for each possible label and only returns the score
|
1025 |
+
from the label passed with the input data. If no
|
1026 |
+
labels are used, only one score is calculated.
|
1027 |
+
Disabled by default.
|
1028 |
+
mbstd_group_size (int): Group size for minibatch std
|
1029 |
+
before the final conv layer. A value of 0 indicates
|
1030 |
+
not to use minibatch std, and a value of -1 indicates
|
1031 |
+
that the group should be over the entire batch.
|
1032 |
+
This is used for increasing variety of the outputs of
|
1033 |
+
the generator. Default value is 4.
|
1034 |
+
NOTE: Scores for the same data may vary depending
|
1035 |
+
on batch size when using a value of -1.
|
1036 |
+
NOTE: If a value > 0 is given, every input batch
|
1037 |
+
must have a size evenly divisible by this value.
|
1038 |
+
dense_hidden (int, optional): The number of hidden features
|
1039 |
+
of the first dense layer. By default, this is the same as
|
1040 |
+
the number of channels in the final conv layer."""
|
1041 |
+
__doc__ += _BaseAdverserialModel.__doc__
|
1042 |
+
|
1043 |
+
def __init__(self, **kwargs):
|
1044 |
+
super(Discriminator, self).__init__()
|
1045 |
+
self._update_default_kwargs(
|
1046 |
+
label_size=0,
|
1047 |
+
mbstd_group_size=4,
|
1048 |
+
dense_hidden=None,
|
1049 |
+
resnet=True,
|
1050 |
+
skip=False
|
1051 |
+
)
|
1052 |
+
self._update_kwargs(**kwargs)
|
1053 |
+
|
1054 |
+
conv_block_kwargs = dict(
|
1055 |
+
resnet=self.resnet,
|
1056 |
+
down=True,
|
1057 |
+
num_layers=self.conv_block_size,
|
1058 |
+
filter=self.conv_filter,
|
1059 |
+
activation=self.activation,
|
1060 |
+
mode=self.conv_resample_mode,
|
1061 |
+
fused=self.fused_resample,
|
1062 |
+
kernel_size=self.kernel_size,
|
1063 |
+
pad_mode=self.conv_pad_mode,
|
1064 |
+
pad_constant=self.conv_pad_constant,
|
1065 |
+
filter_pad_mode=self.filter_pad_mode,
|
1066 |
+
filter_pad_constant=self.filter_pad_constant,
|
1067 |
+
pad_once=self.pad_once,
|
1068 |
+
noise=False,
|
1069 |
+
lr_mul=self.lr_mul,
|
1070 |
+
weight_scale=self.weight_scale,
|
1071 |
+
gain=1,
|
1072 |
+
dim=self.dim,
|
1073 |
+
eps=self.eps
|
1074 |
+
)
|
1075 |
+
self.conv_blocks = nn.ModuleList()
|
1076 |
+
|
1077 |
+
# All but the last of the convolutional blocks look the same
|
1078 |
+
# except for number of input and output channels
|
1079 |
+
for i in range(len(self.channels) - 1):
|
1080 |
+
self.conv_blocks.append(
|
1081 |
+
modules.DiscriminatorConvBlock(
|
1082 |
+
in_channels=self.channels[i],
|
1083 |
+
out_channels=self.channels[i + 1],
|
1084 |
+
**conv_block_kwargs
|
1085 |
+
)
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
# The final convolutional layer is slightly different
|
1089 |
+
# from the previous convolutional blocks but can still
|
1090 |
+
# be represented as a convolutional block if we change
|
1091 |
+
# some of its arguments and optionally add a minibatch
|
1092 |
+
# std layer before it.
|
1093 |
+
final_conv_block = []
|
1094 |
+
if self.mbstd_group_size:
|
1095 |
+
final_conv_block.append(
|
1096 |
+
modules.MinibatchStd(
|
1097 |
+
group_size=self.mbstd_group_size,
|
1098 |
+
eps=self.eps
|
1099 |
+
)
|
1100 |
+
)
|
1101 |
+
final_conv_block.append(
|
1102 |
+
modules.DiscriminatorConvBlock(
|
1103 |
+
**{
|
1104 |
+
**conv_block_kwargs,
|
1105 |
+
'in_channels': self.channels[-1] + (1 if self.mbstd_group_size else 0),
|
1106 |
+
'out_channels': self.channels[-1],
|
1107 |
+
'resnet': False,
|
1108 |
+
'down': False,
|
1109 |
+
'num_layers': 1
|
1110 |
+
},
|
1111 |
+
)
|
1112 |
+
)
|
1113 |
+
self.conv_blocks.append(nn.Sequential(*final_conv_block))
|
1114 |
+
|
1115 |
+
# If not using the skip architecture, only one
|
1116 |
+
# layer will project the data into feature maps.
|
1117 |
+
# This would be performed only for the input data at
|
1118 |
+
# the first block.
|
1119 |
+
# If using the skip architecture, every block will
|
1120 |
+
# have its own projection layer instead.
|
1121 |
+
self.from_data_layers = nn.ModuleList()
|
1122 |
+
for i in range(len(self.channels)):
|
1123 |
+
from_data = None
|
1124 |
+
if i == 0 or self.skip:
|
1125 |
+
from_data = modules.BiasActivationWrapper(
|
1126 |
+
layer=modules.ConvLayer(
|
1127 |
+
**{
|
1128 |
+
**conv_block_kwargs,
|
1129 |
+
'in_channels': self.data_channels,
|
1130 |
+
'out_channels': self.channels[i],
|
1131 |
+
'modulate': False,
|
1132 |
+
'demodulate': False,
|
1133 |
+
'kernel_size': 1
|
1134 |
+
}
|
1135 |
+
),
|
1136 |
+
**{
|
1137 |
+
**conv_block_kwargs,
|
1138 |
+
'features': self.channels[i],
|
1139 |
+
'use_bias': True,
|
1140 |
+
'activation': self.activation,
|
1141 |
+
'bias_init': 0
|
1142 |
+
}
|
1143 |
+
)
|
1144 |
+
self.from_data_layers.append(from_data)
|
1145 |
+
|
1146 |
+
# When the skip architecture is used we need to
|
1147 |
+
# downsample the data input so that it has the same
|
1148 |
+
# size as the feature maps of each block so that it
|
1149 |
+
# can be projected and added to these feature maps.
|
1150 |
+
self.downsample = None
|
1151 |
+
if self.skip:
|
1152 |
+
self.downsample = modules.Downsample(
|
1153 |
+
mode=self.skip_resample_mode,
|
1154 |
+
filter=self.skip_filter,
|
1155 |
+
filter_pad_mode=self.filter_pad_mode,
|
1156 |
+
filter_pad_constant=self.filter_pad_constant,
|
1157 |
+
gain=1,
|
1158 |
+
dim=self.dim
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
# The final layers are two dense layers that map
|
1162 |
+
# the features into score logits. If labels are
|
1163 |
+
# used, we instead output one score for each possible
|
1164 |
+
# class of the labels and then return the score for the
|
1165 |
+
# labeled class.
|
1166 |
+
dense_layers = []
|
1167 |
+
in_features = self.channels[-1] * np.prod(self.base_shape)
|
1168 |
+
out_features = self.dense_hidden or self.channels[-1]
|
1169 |
+
activation = self.activation
|
1170 |
+
for _ in range(2):
|
1171 |
+
dense_layers.append(
|
1172 |
+
modules.BiasActivationWrapper(
|
1173 |
+
layer=modules.DenseLayer(
|
1174 |
+
in_features=in_features,
|
1175 |
+
out_features=out_features,
|
1176 |
+
lr_mul=self.lr_mul,
|
1177 |
+
weight_scale=self.weight_scale,
|
1178 |
+
gain=1,
|
1179 |
+
),
|
1180 |
+
features=out_features,
|
1181 |
+
activation=activation,
|
1182 |
+
use_bias=True,
|
1183 |
+
bias_init=0,
|
1184 |
+
lr_mul=self.lr_mul,
|
1185 |
+
weight_scale=self.weight_scale
|
1186 |
+
)
|
1187 |
+
)
|
1188 |
+
in_features = out_features
|
1189 |
+
out_features = max(1, self.label_size)
|
1190 |
+
activation = 'linear'
|
1191 |
+
self.dense = nn.Sequential(*dense_layers)
|
1192 |
+
|
1193 |
+
def forward(self, input, labels=None):
|
1194 |
+
"""
|
1195 |
+
Takes some data and optionally its labels and
|
1196 |
+
produces one score logit per data input.
|
1197 |
+
Arguments:
|
1198 |
+
input (torch.Tensor)
|
1199 |
+
labels (torch.Tensor, list, optional)
|
1200 |
+
Returns:
|
1201 |
+
score_logits (torch.Tensor)
|
1202 |
+
"""
|
1203 |
+
# Declare our feature activations variable.
|
1204 |
+
x = None
|
1205 |
+
# Declare our data (input) variable
|
1206 |
+
y = input
|
1207 |
+
for i, (block, from_data) in enumerate(zip(self.conv_blocks, self.from_data_layers)):
|
1208 |
+
# Combine the data input of this block with any previous
|
1209 |
+
# block output if using 'skip' architecture, else only
|
1210 |
+
# perform this operation as a way to create inputs for
|
1211 |
+
# the first block.
|
1212 |
+
if from_data is not None:
|
1213 |
+
t = from_data(y)
|
1214 |
+
x = t if x is None else x + t
|
1215 |
+
|
1216 |
+
x = block(input=x)
|
1217 |
+
|
1218 |
+
# Downsample the data input of this block to fit
|
1219 |
+
# the feature size of the output of this block so that they can
|
1220 |
+
# be added together. Only performed for 'skip' architectures.
|
1221 |
+
if self.downsample is not None and i != len(self.conv_blocks) - 1:
|
1222 |
+
y = self.downsample(y)
|
1223 |
+
# Calculate scores
|
1224 |
+
x = x.view(x.size(0), -1)
|
1225 |
+
x = self.dense(x)
|
1226 |
+
if labels is not None:
|
1227 |
+
# Use advanced indexing to fetch only the score of the
|
1228 |
+
# class labels.
|
1229 |
+
x = x[torch.arange(x.size(0)), labels].unsqueeze(-1)
|
1230 |
+
return x
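# Illustrative usage sketch (assumes a discriminator `D` of this class constructed
# for 256x256 RGB inputs, e.g. channels=7 with the default base_shape, and label_size=10):
scores = D(torch.randn(8, 3, 256, 256), labels=torch.randint(0, 10, (8,)))
print(scores.shape)               # torch.Size([8, 1]) -- one score logit per input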
|
stylegan2/modules.py
ADDED
@@ -0,0 +1,1601 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def get_activation(activation):
|
8 |
+
"""
|
9 |
+
Get the module for a specific activation function and its gain if
|
10 |
+
it can be calculated.
|
11 |
+
Arguments:
|
12 |
+
activation (str, callable, nn.Module): String representing the activation.
|
13 |
+
Returns:
|
14 |
+
activation_module (torch.nn.Module): The module representing
|
15 |
+
the activation function.
|
16 |
+
gain (float): The gain value. Defaults to 1 if it can not be calculated.
|
17 |
+
"""
|
18 |
+
if isinstance(activation, nn.Module) or callable(activation):
|
19 |
+
return activation, 1.
|
20 |
+
if isinstance(activation, str):
|
21 |
+
activation = activation.lower()
|
22 |
+
if activation in [None, 'linear']:
|
23 |
+
return nn.Identity(), 1.
|
24 |
+
lrelu_strings = ('leaky', 'leakyrelu', 'leaky_relu', 'leaky relu', 'lrelu')
|
25 |
+
if activation.startswith(lrelu_strings):
|
26 |
+
for l_s in lrelu_strings:
|
27 |
+
activation = activation.replace(l_s, '')
|
28 |
+
slope = ''.join(
|
29 |
+
char for char in activation if char.isdigit() or char == '.')
|
30 |
+
slope = float(slope) if slope else 0.01
|
31 |
+
return nn.LeakyReLU(slope), np.sqrt(2) # close enough to true gain
|
32 |
+
elif activation.startswith('swish'):
|
33 |
+
return Swish(affine=activation != 'swish'), np.sqrt(2)
|
34 |
+
elif activation in ['relu']:
|
35 |
+
return nn.ReLU(), np.sqrt(2)
|
36 |
+
elif activation in ['elu']:
|
37 |
+
return nn.ELU(), 1.
|
38 |
+
elif activation in ['prelu']:
|
39 |
+
return nn.PReLU(), np.sqrt(2)
|
40 |
+
elif activation in ['rrelu', 'randomrelu']:
|
41 |
+
return nn.RReLU(), np.sqrt(2)
|
42 |
+
elif activation in ['selu']:
|
43 |
+
return nn.SELU(), 1.
|
44 |
+
elif activation in ['softplus']:
|
45 |
+
return nn.Softplus(), 1
|
46 |
+
elif activation in ['softsign']:
|
47 |
+
return nn.Softsign(), 1 # unsure about this gain
|
48 |
+
elif activation in ['sigmoid', 'logistic']:
|
49 |
+
return nn.Sigmoid(), 1.
|
50 |
+
elif activation in ['tanh']:
|
51 |
+
return nn.Tanh(), 1.
|
52 |
+
else:
|
53 |
+
raise ValueError(
|
54 |
+
'Activation "{}" not available.'.format(activation)
|
55 |
+
)
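# Illustrative sketch: parsing the activation strings used elsewhere in this file.
act, gain = get_activation('lrelu:0.2')   # -> (nn.LeakyReLU(0.2), sqrt(2))
act, gain = get_activation('linear')      # -> (nn.Identity(), 1.0)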
|
56 |
+
|
57 |
+
|
58 |
+
class Swish(nn.Module):
|
59 |
+
"""
|
60 |
+
Performs the 'Swish' non-linear activation function.
|
61 |
+
https://arxiv.org/pdf/1710.05941.pdf
|
62 |
+
Arguments:
|
63 |
+
affine (bool): Multiply the input to sigmoid
|
64 |
+
with a learnable scale. Default value is False.
|
65 |
+
"""
|
66 |
+
def __init__(self, affine=False):
|
67 |
+
super(Swish, self).__init__()
|
68 |
+
if affine:
|
69 |
+
self.beta = nn.Parameter(torch.tensor([1.]))
|
70 |
+
self.affine = affine
|
71 |
+
|
72 |
+
def forward(self, input, *args, **kwargs):
|
73 |
+
"""
|
74 |
+
Apply the swish non-linear activation function
|
75 |
+
and return the results.
|
76 |
+
Arguments:
|
77 |
+
input (torch.Tensor)
|
78 |
+
Returns:
|
79 |
+
output (torch.Tensor)
|
80 |
+
"""
|
81 |
+
x = input
|
82 |
+
if self.affine:
|
83 |
+
x *= self.beta
|
84 |
+
return x * torch.sigmoid(x)
|
85 |
+
|
86 |
+
|
87 |
+
def _get_weight_and_coef(shape, lr_mul=1, weight_scale=True, gain=1, fill=None):
|
88 |
+
"""
|
89 |
+
Get an initialized weight and its runtime coefficient as an nn.Parameter tensor.
|
90 |
+
Arguments:
|
91 |
+
shape (tuple, list): Shape of weight tensor.
|
92 |
+
lr_mul (float): The learning rate multiplier for
|
93 |
+
this weight. Default value is 1.
|
94 |
+
weight_scale (bool): Use weight scaling for equalized
|
95 |
+
learning rate. Default value is True.
|
96 |
+
gain (float): The gain of the weight. Default value is 1.
|
97 |
+
fill (float, optional): Instead of initializing the weight
|
98 |
+
with scaled normally distributed values, fill it with
|
99 |
+
this value. Useful for bias weights.
|
100 |
+
Returns:
|
101 |
+
weight (nn.Parameter)
runtime_coef (float)
|
102 |
+
"""
|
103 |
+
fan_in = np.prod(shape[1:])
|
104 |
+
he_std = gain / np.sqrt(fan_in)
|
105 |
+
|
106 |
+
if weight_scale:
|
107 |
+
init_std = 1 / lr_mul
|
108 |
+
runtime_coef = he_std * lr_mul
|
109 |
+
else:
|
110 |
+
init_std = he_std / lr_mul
|
111 |
+
runtime_coef = lr_mul
|
112 |
+
|
113 |
+
weight = torch.empty(*shape)
|
114 |
+
if fill is None:
|
115 |
+
weight.normal_(0, init_std)
|
116 |
+
else:
|
117 |
+
weight.fill_(fill)
|
118 |
+
return nn.Parameter(weight), runtime_coef
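# Illustrative sketch: with weight scaling enabled the stored weight has unit
# variance and the He-initialisation constant is applied at runtime instead.
weight, coef = _get_weight_and_coef(shape=[512, 512], weight_scale=True)
scaled = coef * weight            # what a layer would actually use in its forward pass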
|
119 |
+
|
120 |
+
|
121 |
+
def _apply_conv(input, *args, transpose=False, **kwargs):
|
122 |
+
"""
|
123 |
+
Perform a 1d, 2d or 3d convolution with specified
|
124 |
+
positional and keyword arguments. Which type of
|
125 |
+
convolution that is used depends on shape of data.
|
126 |
+
Arguments:
|
127 |
+
input (torch.Tensor): The input data for the
|
128 |
+
convolution.
|
129 |
+
*args: Positional arguments for the convolution.
|
130 |
+
Keyword Arguments:
|
131 |
+
transpose (bool): Transpose the convolution.
|
132 |
+
Default value is False
|
133 |
+
**kwargs: Keyword arguments for the convolution.
|
134 |
+
"""
|
135 |
+
dim = input.dim() - 2
|
136 |
+
conv_fn = getattr(
|
137 |
+
F, 'conv{}{}d'.format('_transpose' if transpose else '', dim))
|
138 |
+
return conv_fn(input=input, *args, **kwargs)
|
139 |
+
|
140 |
+
|
141 |
+
def _setup_mod_weight_for_t_conv(weight, in_channels, out_channels):
|
142 |
+
"""
|
143 |
+
Reshape a modulated conv weight for use with a transposed convolution.
|
144 |
+
Arguments:
|
145 |
+
weight (torch.Tensor)
|
146 |
+
in_channels (int)
|
147 |
+
out_channels (int)
|
148 |
+
Returns:
|
149 |
+
reshaped_weight (torch.Tensor)
|
150 |
+
"""
|
151 |
+
# [BO]I*k -> BOI*k
|
152 |
+
weight = weight.view(
|
153 |
+
-1,
|
154 |
+
out_channels,
|
155 |
+
in_channels,
|
156 |
+
*weight.size()[2:]
|
157 |
+
)
|
158 |
+
# BOI*k -> BIO*k
|
159 |
+
weight = weight.transpose(1, 2)
|
160 |
+
# BIO*k -> [BI]O*k
|
161 |
+
weight = weight.reshape(
|
162 |
+
-1,
|
163 |
+
out_channels,
|
164 |
+
*weight.size()[3:]
|
165 |
+
)
|
166 |
+
return weight
|
167 |
+
|
168 |
+
|
169 |
+
def _setup_filter_kernel(filter_kernel, gain=1, up_factor=1, dim=2):
|
170 |
+
"""
|
171 |
+
Set up a filter kernel and return it as a tensor.
|
172 |
+
Arguments:
|
173 |
+
filter_kernel (int, list, torch.tensor, None): The filter kernel
|
174 |
+
values to use. If this value is an int, a binomial filter of
|
175 |
+
this size is created. If a sequence with a single axis is used,
|
176 |
+
it will be expanded to the number of `dims` specified. If value
|
177 |
+
is None, a filter of values [1, 1] is used.
|
178 |
+
gain (float): Gain of the filter kernel. Default value is 1.
|
179 |
+
up_factor (int): Scale factor. Should only be given for upscaling filters.
|
180 |
+
Default value is 1.
|
181 |
+
dim (int): Number of dimensions of data. Default value is 2.
|
182 |
+
Returns:
|
183 |
+
filter_kernel_tensor (torch.Tensor)
|
184 |
+
"""
|
185 |
+
filter_kernel = filter_kernel or 2
|
186 |
+
if isinstance(filter_kernel, (int, float)):
|
187 |
+
def binomial(n, k):
|
188 |
+
if k in [1, n]:
|
189 |
+
return 1
|
190 |
+
return np.math.factorial(n) / (np.math.factorial(k) * np.math.factorial(n - k))
|
191 |
+
filter_kernel = [binomial(filter_kernel, k) for k in range(1, filter_kernel + 1)]
|
192 |
+
if not torch.is_tensor(filter_kernel):
|
193 |
+
filter_kernel = torch.tensor(filter_kernel)
|
194 |
+
filter_kernel = filter_kernel.float()
|
195 |
+
if filter_kernel.dim() == 1:
|
196 |
+
_filter_kernel = filter_kernel.unsqueeze(0)
|
197 |
+
while filter_kernel.dim() < dim:
|
198 |
+
filter_kernel = torch.matmul(
|
199 |
+
filter_kernel.unsqueeze(-1), _filter_kernel)
|
200 |
+
assert all(filter_kernel.size(0) == s for s in filter_kernel.size())
|
201 |
+
filter_kernel /= filter_kernel.sum()
|
202 |
+
filter_kernel *= gain * up_factor ** 2
|
203 |
+
return filter_kernel.float()
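# Illustrative sketch: the default 1-D filter [1, 3, 3, 1] is expanded to a 2-D
# separable low-pass kernel (its outer product) and normalised to sum to 1.
kernel = _setup_filter_kernel([1, 3, 3, 1], gain=1, up_factor=1, dim=2)
print(kernel.shape)               # torch.Size([4, 4])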
|
204 |
+
|
205 |
+
|
206 |
+
def _get_layer(layer_class, kwargs, wrap=False, noise=False):
|
207 |
+
"""
|
208 |
+
Create a layer and wrap it in optional
|
209 |
+
noise and/or bias/activation layers.
|
210 |
+
Arguments:
|
211 |
+
layer_class: The class of the layer to construct.
|
212 |
+
kwargs (dict): The keyword arguments to use for constructing
|
213 |
+
the layer and optionally the bias/activation layer.
|
214 |
+
wrap (bool): Wrap the layer in a bias/activation layer and
|
215 |
+
optionally a noise injection layer. Default value is False.
|
216 |
+
noise (bool): Inject noise before the bias/activation wrapper.
|
217 |
+
This can only be done when `wrap=True`. Default value is False.
|
218 |
+
"""
|
219 |
+
layer = layer_class(**kwargs)
|
220 |
+
if wrap:
|
221 |
+
if noise:
|
222 |
+
layer = NoiseInjectionWrapper(layer)
|
223 |
+
layer = BiasActivationWrapper(layer, **kwargs)
|
224 |
+
return layer
|
225 |
+
|
226 |
+
|
227 |
+
class BiasActivationWrapper(nn.Module):
|
228 |
+
"""
|
229 |
+
Wrap a module to add bias and non-linear activation
|
230 |
+
to any outputs of that module.
|
231 |
+
Arguments:
|
232 |
+
layer (nn.Module): The module to wrap.
|
233 |
+
features (int, optional): The number of features
|
234 |
+
of the output of the `layer`. This argument
|
235 |
+
has to be specified if `use_bias=True`.
|
236 |
+
use_bias (bool): Add bias to the output.
|
237 |
+
Default value is True.
|
238 |
+
activation (str, nn.Module, callable, optional):
|
239 |
+
non-linear activation function to use.
|
240 |
+
Unused if not specified.
|
241 |
+
bias_init (float): Value to initialize bias
|
242 |
+
weight with. Default value is 0.
|
243 |
+
lr_mul (float): Learning rate multiplier of
|
244 |
+
the bias weight. Weights are scaled by
|
245 |
+
this value. Default value is 1.
|
246 |
+
weight_scale (bool): Scale weights for
|
247 |
+
equalized learning rate.
|
248 |
+
Default value is True.
|
249 |
+
"""
|
250 |
+
def __init__(self,
|
251 |
+
layer,
|
252 |
+
features=None,
|
253 |
+
use_bias=True,
|
254 |
+
activation='linear',
|
255 |
+
bias_init=0,
|
256 |
+
lr_mul=1,
|
257 |
+
weight_scale=True,
|
258 |
+
*args,
|
259 |
+
**kwargs):
|
260 |
+
super(BiasActivationWrapper, self).__init__()
|
261 |
+
self.layer = layer
|
262 |
+
bias = None
|
263 |
+
bias_coef = None
|
264 |
+
if use_bias:
|
265 |
+
assert features, '`features` is required when using bias.'
|
266 |
+
bias, bias_coef = _get_weight_and_coef(
|
267 |
+
shape=[features],
|
268 |
+
lr_mul=lr_mul,
|
269 |
+
weight_scale=False,
|
270 |
+
fill=bias_init
|
271 |
+
)
|
272 |
+
self.register_parameter('bias', bias)
|
273 |
+
self.bias_coef = bias_coef
|
274 |
+
self.act, self.gain = get_activation(activation)
|
275 |
+
|
276 |
+
def forward(self, *args, **kwargs):
|
277 |
+
"""
|
278 |
+
Forward all positional and keyword arguments
|
279 |
+
to the layer wrapped by this module and add
|
280 |
+
bias (if set) and run through non-linear activation
|
281 |
+
function (if set).
|
282 |
+
Arguments:
|
283 |
+
*args (positional arguments)
|
284 |
+
**kwargs (keyword arguments)
|
285 |
+
Returns:
|
286 |
+
output (torch.Tensor)
|
287 |
+
"""
|
288 |
+
x = self.layer(*args, **kwargs)
|
289 |
+
if self.bias is not None:
|
290 |
+
bias = self.bias.view(1, -1, *[1] * (x.dim() - 2))
|
291 |
+
if self.bias_coef != 1:
|
292 |
+
bias = self.bias_coef * bias
|
293 |
+
x += bias
|
294 |
+
x = self.act(x)
|
295 |
+
if self.gain != 1:
|
296 |
+
x *= self.gain
|
297 |
+
return x
|
298 |
+
|
299 |
+
def extra_repr(self):
|
300 |
+
return 'bias={}'.format(self.bias is not None)
|
301 |
+
|
302 |
+
|
303 |
+
class NoiseInjectionWrapper(nn.Module):
|
304 |
+
"""
|
305 |
+
Wrap a module to add noise scaled by a
|
306 |
+
learnable parameter to any outputs of the
|
307 |
+
wrapped module.
|
308 |
+
Noise is randomized for each output but can
|
309 |
+
be set to static noise by calling `static_noise()`
|
310 |
+
of this object. This can only be done once data
|
311 |
+
has passed through this layer at least once so that
|
312 |
+
the shape of the static noise to create is known.
|
313 |
+
Check if the shape is known by calling `has_noise_shape()`.
|
314 |
+
Arguments:
|
315 |
+
layer (nn.Module): The module to wrap.
|
316 |
+
same_over_batch (bool): Repeat the same
|
317 |
+
noise values over the entire batch
|
318 |
+
instead of creating separate noise
|
319 |
+
values for each entry in the batch.
|
320 |
+
Default value is True.
|
321 |
+
"""
|
322 |
+
|
323 |
+
def __init__(self, layer, same_over_batch=True):
|
324 |
+
super(NoiseInjectionWrapper, self).__init__()
|
325 |
+
self.layer = layer
|
326 |
+
self.weight = torch.nn.Parameter(torch.zeros(1))
|
327 |
+
self.register_buffer('noise_storage', None)
|
328 |
+
self.same_over_batch = same_over_batch
|
329 |
+
self.random_noise()
|
330 |
+
|
331 |
+
def has_noise_shape(self):
|
332 |
+
"""
|
333 |
+
If this module has had data passed through it
|
334 |
+
the noise shape is known and this function returns
|
335 |
+
True. Else False.
|
336 |
+
Returns:
|
337 |
+
noise_shape_known (bool)
|
338 |
+
"""
|
339 |
+
return self.noise_storage is not None
|
340 |
+
|
341 |
+
def random_noise(self):
|
342 |
+
"""
|
343 |
+
Randomize noise for each
|
344 |
+
new output.
|
345 |
+
"""
|
346 |
+
self._fixed_noise = False
|
347 |
+
if isinstance(self.noise_storage, nn.Parameter):
|
348 |
+
noise_storage = self.noise_storage
|
349 |
+
del self.noise_storage
|
350 |
+
self.register_buffer('noise_storage', noise_storage.data)
|
351 |
+
|
352 |
+
def static_noise(self, trainable=False, noise_tensor=None):
|
353 |
+
"""
|
354 |
+
Set up static noise that can optionally be a trainable
|
355 |
+
parameter. Static noise does not change between inputs
|
356 |
+
unless the user has altered its values. Returns the tensor
|
357 |
+
object that stores the static noise.
|
358 |
+
Arguments:
|
359 |
+
trainable (bool): Wrap the static noise tensor in
|
360 |
+
nn.Parameter to make it trainable. The returned
|
361 |
+
tensor will be wrapped.
|
362 |
+
noise_tensor (torch.Tensor, optional): A predefined
|
363 |
+
static noise tensor. If not passed, one will be
|
364 |
+
created.
|
365 |
+
"""
|
366 |
+
assert self.has_noise_shape(), \
|
367 |
+
'Noise shape is unknown'
|
368 |
+
if noise_tensor is None:
|
369 |
+
noise_tensor = self.noise_storage
|
370 |
+
else:
|
371 |
+
noise_tensor = noise_tensor.to(self.weight)
|
372 |
+
if trainable and not isinstance(noise_tensor, nn.Parameter):
|
373 |
+
noise_tensor = nn.Parameter(noise_tensor)
|
374 |
+
if isinstance(self.noise_storage, nn.Parameter) and not trainable:
|
375 |
+
del self.noise_storage
|
376 |
+
self.register_buffer('noise_storage', noise_tensor)
|
377 |
+
else:
|
378 |
+
self.noise_storage = noise_tensor
|
379 |
+
self._fixed_noise = True
|
380 |
+
return noise_tensor
|
381 |
+
|
382 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
383 |
+
r"""Saves module state to `destination` dictionary, containing a state
|
384 |
+
of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`.
|
385 |
+
|
386 |
+
Overridden to ignore the noise storage buffer.
|
387 |
+
|
388 |
+
Arguments:
|
389 |
+
destination (dict): a dict where state will be stored
|
390 |
+
prefix (str): the prefix for parameters and buffers used in this
|
391 |
+
module
|
392 |
+
"""
|
393 |
+
for name, param in self._parameters.items():
|
394 |
+
if name != 'noise_storage' and param is not None:
|
395 |
+
destination[prefix + name] = param if keep_vars else param.data
|
396 |
+
for name, buf in self._buffers.items():
|
397 |
+
if name != 'noise_storage' and buf is not None:
|
398 |
+
destination[prefix + name] = buf if keep_vars else buf.data
|
399 |
+
|
400 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
401 |
+
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
402 |
+
this module, but not its descendants. This is called on every submodule
|
403 |
+
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
404 |
+
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
405 |
+
For state dicts without metadata, :attr:`local_metadata` is empty.
|
406 |
+
Overridden to ignore noise storage buffer.
|
407 |
+
"""
|
408 |
+
key = prefix + 'noise_storage'
|
409 |
+
if key in state_dict:
|
410 |
+
del state_dict[key]
|
411 |
+
return super(NoiseInjectionWrapper, self)._load_from_state_dict(
|
412 |
+
state_dict, prefix, *args, **kwargs)
|
413 |
+
|
414 |
+
def forward(self, *args, **kwargs):
|
415 |
+
"""
|
416 |
+
Forward all positional and keyword arguments
|
417 |
+
to the layer wrapped by this module and add
|
418 |
+
noise to its outputs before returning them.
|
419 |
+
Arguments:
|
420 |
+
*args (positional arguments)
|
421 |
+
**kwargs (keyword arguments)
|
422 |
+
Returns:
|
423 |
+
output (torch.Tensor)
|
424 |
+
"""
|
425 |
+
x = self.layer(*args, **kwargs)
|
426 |
+
noise_shape = list(x.size())
|
427 |
+
noise_shape[1] = 1
|
428 |
+
if self.same_over_batch:
|
429 |
+
noise_shape[0] = 1
|
430 |
+
if self.noise_storage is None or list(self.noise_storage.size()) != noise_shape:
|
431 |
+
if not self._fixed_noise:
|
432 |
+
self.noise_storage = torch.empty(
|
433 |
+
*noise_shape,
|
434 |
+
dtype=self.weight.dtype,
|
435 |
+
device=self.weight.device
|
436 |
+
)
|
437 |
+
else:
|
438 |
+
assert list(self.noise_storage.size()[2:]) == noise_shape[2:], \
|
439 |
+
'A data size {} has been encountered, '.format(x.size()[2:]) + \
|
440 |
+
'the static noise previously set up does ' + \
|
441 |
+
'not match this size {}'.format(self.noise_storage.size()[2:])
|
442 |
+
assert self.noise_storage.size(0) == 1 or self.noise_storage.size(0) == x.size(0), \
|
443 |
+
'Static noise batch size mismatch! ' + \
|
444 |
+
'Noise batch size: {}, '.format(self.noise_storage.size(0)) + \
|
445 |
+
'input batch size: {}'.format(x.size(0))
|
446 |
+
assert self.noise_storage.size(1) == 1 or self.noise_storage.size(1) == x.size(1), \
|
447 |
+
'Static noise channel size mismatch! ' + \
|
448 |
+
'Noise channel size: {}, '.format(self.noise_storage.size(1)) + \
|
449 |
+
'input channel size: {}'.format(x.size(1))
|
450 |
+
if not self._fixed_noise:
|
451 |
+
self.noise_storage.normal_()
|
452 |
+
x += self.weight * self.noise_storage
|
453 |
+
return x
|
454 |
+
|
455 |
+
def extra_repr(self):
|
456 |
+
return 'static_noise={}'.format(self._fixed_noise)
|
457 |
+
|
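A short sketch of the static-noise workflow described in the docstring above; the wrapped conv layer and the input shape are arbitrary choices for illustration.

import torch
from torch import nn
from stylegan2.modules import NoiseInjectionWrapper

noisy = NoiseInjectionWrapper(nn.Conv2d(8, 8, kernel_size=3, padding=1))
x = torch.randn(2, 8, 16, 16)
_ = noisy(x)                      # first pass records the noise shape
assert noisy.has_noise_shape()
noise = noisy.static_noise()      # freeze noise; the returned tensor can be edited in place
noise.zero_()                     # e.g. silence the noise entirely
y = noisy(x)                      # deterministic from here on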
458 |
+
|
459 |
+
class FilterLayer(nn.Module):
|
460 |
+
"""
|
461 |
+
Apply a filter by using convolution.
|
462 |
+
Arguments:
|
463 |
+
filter_kernel (torch.Tensor): The filter kernel to use.
|
464 |
+
Should be of shape `dims * (k,)` where `k` is the
|
465 |
+
kernel size and `dims` is the number of data dimensions
|
466 |
+
(excluding batch and channel dimension).
|
467 |
+
stride (int): The stride of the convolution.
|
468 |
+
pad0 (int): Amount to pad start of each data dimension.
|
469 |
+
Default value is 0.
|
470 |
+
pad1 (int): Amount to pad end of each data dimension.
|
471 |
+
Default value is 0.
|
472 |
+
pad_mode (str): The padding mode. Default value is 'constant'.
|
473 |
+
pad_constant (float): The constant value to pad with if
|
474 |
+
`pad_mode='constant'`. Default value is 0.
|
475 |
+
"""
|
476 |
+
def __init__(self,
|
477 |
+
filter_kernel,
|
478 |
+
stride=1,
|
479 |
+
pad0=0,
|
480 |
+
pad1=0,
|
481 |
+
pad_mode='constant',
|
482 |
+
pad_constant=0,
|
483 |
+
*args,
|
484 |
+
**kwargs):
|
485 |
+
super(FilterLayer, self).__init__()
|
486 |
+
dim = filter_kernel.dim()
|
487 |
+
filter_kernel = filter_kernel.view(1, 1, *filter_kernel.size())
|
488 |
+
self.register_buffer('filter_kernel', filter_kernel)
|
489 |
+
self.stride = stride
|
490 |
+
if pad0 == pad1 and (pad0 == 0 or pad_mode == 'constant' and pad_constant == 0):
|
491 |
+
self.fused_pad = True
|
492 |
+
self.padding = pad0
|
493 |
+
else:
|
494 |
+
self.fused_pad = False
|
495 |
+
self.padding = [pad0, pad1] * dim
|
496 |
+
self.pad_mode = pad_mode
|
497 |
+
self.pad_constant = pad_constant
|
498 |
+
|
499 |
+
def forward(self, input, **kwargs):
|
500 |
+
"""
|
501 |
+
Pad the input and run the filter over it
|
502 |
+
before returning the new values.
|
503 |
+
Arguments:
|
504 |
+
input (torch.Tensor)
|
505 |
+
Returns:
|
506 |
+
output (torch.Tensor)
|
507 |
+
"""
|
508 |
+
x = input
|
509 |
+
conv_kwargs = dict(
|
510 |
+
weight=self.filter_kernel.repeat(
|
511 |
+
input.size(1), *[1] * (self.filter_kernel.dim() - 1)),
|
512 |
+
stride=self.stride,
|
513 |
+
groups=input.size(1),
|
514 |
+
)
|
515 |
+
if self.fused_pad:
|
516 |
+
conv_kwargs.update(padding=self.padding)
|
517 |
+
else:
|
518 |
+
x = F.pad(x, self.padding, mode=self.pad_mode, value=self.pad_constant)
|
519 |
+
return _apply_conv(
|
520 |
+
input=x,
|
521 |
+
transpose=False,
|
522 |
+
**conv_kwargs
|
523 |
+
)
|
524 |
+
|
525 |
+
def extra_repr(self):
|
526 |
+
return 'filter_size={}, stride={}'.format(
|
527 |
+
tuple(self.filter_kernel.size()[2:]), self.stride)
|
528 |
+
|
529 |
+
|
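A small sketch of `FilterLayer` on its own, using a hand-made 2x2 box kernel instead of the kernels produced by `_setup_filter_kernel`; shapes are illustrative.

import torch
from stylegan2.modules import FilterLayer

kernel = torch.full((2, 2), 0.25)          # normalized 2x2 box blur
blur = FilterLayer(filter_kernel=kernel, stride=1, pad0=1, pad1=0)
x = torch.randn(2, 3, 16, 16)
y = blur(x)                                # depthwise filtering, spatial size preserved
print(y.shape)                             # torch.Size([2, 3, 16, 16])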
530 |
+
class Upsample(nn.Module):
|
531 |
+
"""
|
532 |
+
Performs upsampling without learnable parameters, doubling
|
533 |
+
the size of the data.
|
534 |
+
Arguments:
|
535 |
+
mode (str): 'FIR' or one of the valid modes
|
536 |
+
that can be passed to torch.nn.functional.interpolate().
|
537 |
+
filter (int, list, tensor): Filter to use if `mode='FIR'`.
|
538 |
+
Default value is a lowpass filter of values [1, 3, 3, 1].
|
539 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
540 |
+
See `FilterLayer` docstring for more info.
|
541 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
542 |
+
See `FilterLayer` docstring for more info.
|
543 |
+
gain (float): If `mode='FIR'`, this is used with the filter.
|
544 |
+
See `FilterLayer` docstring for more info.
|
545 |
+
dim (int): Dims of data (excluding batch and channel dimensions).
|
546 |
+
Default value is 2.
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(self,
|
550 |
+
mode='FIR',
|
551 |
+
filter=[1, 3, 3, 1],
|
552 |
+
filter_pad_mode='constant',
|
553 |
+
filter_pad_constant=0,
|
554 |
+
gain=1,
|
555 |
+
dim=2,
|
556 |
+
*args,
|
557 |
+
**kwargs):
|
558 |
+
super(Upsample, self).__init__()
|
559 |
+
assert mode != 'max', 'mode \'max\' can only be used for downsampling.'
|
560 |
+
if mode == 'FIR':
|
561 |
+
if filter is None:
|
562 |
+
filter = [1, 1]
|
563 |
+
filter_kernel = _setup_filter_kernel(
|
564 |
+
filter_kernel=filter,
|
565 |
+
gain=gain,
|
566 |
+
up_factor=2,
|
567 |
+
dim=dim
|
568 |
+
)
|
569 |
+
pad = filter_kernel.size(-1) - 1
|
570 |
+
self.filter = FilterLayer(
|
571 |
+
filter_kernel=filter_kernel,
|
572 |
+
pad0=(pad + 1) // 2 + 1,
|
573 |
+
pad1=pad // 2,
|
574 |
+
pad_mode=filter_pad_mode,
|
575 |
+
pad_constant=filter_pad_constant
|
576 |
+
)
|
577 |
+
self.register_buffer('weight', torch.ones(*[1 for _ in range(dim + 2)]))
|
578 |
+
self.mode = mode
|
579 |
+
|
580 |
+
def forward(self, input, **kwargs):
|
581 |
+
"""
|
582 |
+
Upsample inputs.
|
583 |
+
Arguments:
|
584 |
+
input (torch.Tensor)
|
585 |
+
Returns:
|
586 |
+
output (torch.Tensor)
|
587 |
+
"""
|
588 |
+
if self.mode == 'FIR':
|
589 |
+
x = _apply_conv(
|
590 |
+
input=input,
|
591 |
+
weight=self.weight.expand(input.size(1), *self.weight.size()[1:]),
|
592 |
+
groups=input.size(1),
|
593 |
+
stride=2,
|
594 |
+
transpose=True
|
595 |
+
)
|
596 |
+
x = self.filter(x)
|
597 |
+
else:
|
598 |
+
interp_kwargs = dict(scale_factor=2, mode=self.mode)
|
599 |
+
if 'linear' in self.mode or 'cubic' in self.mode:
|
600 |
+
interp_kwargs.update(align_corners=False)
|
601 |
+
x = F.interpolate(input, **interp_kwargs)
|
602 |
+
return x
|
603 |
+
|
604 |
+
def extra_repr(self):
|
605 |
+
return 'resample_mode={}'.format(self.mode)
|
606 |
+
|
607 |
+
|
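A quick sketch contrasting the two kinds of resampling this class supports; both double the spatial size (here from 32 to 64 pixels, an arbitrary example).

import torch
from stylegan2.modules import Upsample

x = torch.randn(1, 3, 32, 32)
up_fir = Upsample(mode='FIR')        # transposed conv followed by the low-pass filter
up_lin = Upsample(mode='bilinear')   # plain interpolation
print(up_fir(x).shape)               # torch.Size([1, 3, 64, 64])
print(up_lin(x).shape)               # torch.Size([1, 3, 64, 64])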
608 |
+
class Downsample(nn.Module):
|
609 |
+
"""
|
610 |
+
Performs downsampling without learnable parameters,
|
611 |
+
halving the size of the data.
|
612 |
+
Arguments:
|
613 |
+
mode (str): 'FIR', 'max' or one of the valid modes
|
614 |
+
that can be passed to torch.nn.functional.interpolate().
|
615 |
+
filter (int, list, tensor): Filter to use if `mode='FIR'`.
|
616 |
+
Default value is a lowpass filter of values [1, 3, 3, 1].
|
617 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
618 |
+
See `FilterLayer` docstring for more info.
|
619 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
620 |
+
See `FilterLayer` docstring for more info.
|
621 |
+
gain (float): If `mode='FIR'`, this is used with the filter.
|
622 |
+
See `FilterLayer` docstring for more info.
|
623 |
+
dim (int): Dims of data (excluding batch and channel dimensions).
|
624 |
+
Default value is 2.
|
625 |
+
"""
|
626 |
+
|
627 |
+
def __init__(self,
|
628 |
+
mode='FIR',
|
629 |
+
filter=[1, 3, 3, 1],
|
630 |
+
filter_pad_mode='constant',
|
631 |
+
filter_pad_constant=0,
|
632 |
+
gain=1,
|
633 |
+
dim=2,
|
634 |
+
*args,
|
635 |
+
**kwargs):
|
636 |
+
super(Downsample, self).__init__()
|
637 |
+
if mode == 'FIR':
|
638 |
+
if filter is None:
|
639 |
+
filter = [1, 1]
|
640 |
+
filter_kernel = _setup_filter_kernel(
|
641 |
+
filter_kernel=filter,
|
642 |
+
gain=gain,
|
643 |
+
up_factor=1,
|
644 |
+
dim=dim
|
645 |
+
)
|
646 |
+
pad = filter_kernel.size(-1) - 2
|
647 |
+
pad0 = pad // 2
|
648 |
+
pad1 = pad - pad0
|
649 |
+
self.filter = FilterLayer(
|
650 |
+
filter_kernel=filter_kernel,
|
651 |
+
stride=2,
|
652 |
+
pad0=pad0,
|
653 |
+
pad1=pad1,
|
654 |
+
pad_mode=filter_pad_mode,
|
655 |
+
pad_constant=filter_pad_constant
|
656 |
+
)
|
657 |
+
self.mode = mode
|
658 |
+
|
659 |
+
def forward(self, input, **kwargs):
|
660 |
+
"""
|
661 |
+
Downsample inputs to half its size.
|
662 |
+
Arguments:
|
663 |
+
input (torch.Tensor)
|
664 |
+
Returns:
|
665 |
+
output (torch.Tensor)
|
666 |
+
"""
|
667 |
+
if self.mode == 'FIR':
|
668 |
+
x = self.filter(input)
|
669 |
+
elif self.mode == 'max':
|
670 |
+
return getattr(F, 'max_pool{}d'.format(input.dim() - 2))(input, kernel_size=2)
|
671 |
+
else:
|
672 |
+
x = F.interpolate(input, scale_factor=0.5, mode=self.mode)
|
673 |
+
return x
|
674 |
+
|
675 |
+
def extra_repr(self):
|
676 |
+
return 'resample_mode={}'.format(self.mode)
|
677 |
+
|
678 |
+
|
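The mirror sketch for downsampling; again the input size is an arbitrary example and each mode halves it.

import torch
from stylegan2.modules import Downsample

x = torch.randn(1, 3, 32, 32)
down_fir = Downsample(mode='FIR')    # low-pass filter applied with stride 2
down_avg = Downsample(mode='area')   # plain interpolation with scale_factor=0.5
print(down_fir(x).shape)             # torch.Size([1, 3, 16, 16])
print(down_avg(x).shape)             # torch.Size([1, 3, 16, 16])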
679 |
+
class MinibatchStd(nn.Module):
|
680 |
+
"""
|
681 |
+
Adds the average std of each data point over a
|
682 |
+
slice of the minibatch to that slice as a new
|
683 |
+
feature map. This gives an output with one extra
|
684 |
+
channel.
|
685 |
+
Arguments:
|
686 |
+
group_size (int): Number of entries in each slice
|
687 |
+
of the batch. If <= 0, the entire batch is used.
|
688 |
+
Default value is 4.
|
689 |
+
eps (float): Epsilon value added for numerical stability.
|
690 |
+
Default value is 1e-8.
|
691 |
+
"""
|
692 |
+
def __init__(self, group_size=4, eps=1e-8, *args, **kwargs):
|
693 |
+
super(MinibatchStd, self).__init__()
|
694 |
+
if group_size is None or group_size <= 0:
|
695 |
+
# Entire batch as group size
|
696 |
+
group_size = 0
|
697 |
+
assert group_size != 1, 'Can not use 1 as minibatch std group size.'
|
698 |
+
self.group_size = group_size
|
699 |
+
self.eps = eps
|
700 |
+
|
701 |
+
def forward(self, input, **kwargs):
|
702 |
+
"""
|
703 |
+
Add a new feature map to the input containing the average
|
704 |
+
standard deviation for each slice.
|
705 |
+
Arguments:
|
706 |
+
input (torch.Tensor)
|
707 |
+
Returns:
|
708 |
+
output (torch.Tensor)
|
709 |
+
"""
|
710 |
+
group_size = self.group_size or input.size(0)
|
711 |
+
assert input.size(0) >= group_size, \
|
712 |
+
'Can not use a smaller batch size ' + \
|
713 |
+
'({}) than the specified '.format(input.size(0)) + \
|
714 |
+
'group size ({}) '.format(group_size) + \
|
715 |
+
'of this minibatch std layer.'
|
716 |
+
assert input.size(0) % group_size == 0, \
|
717 |
+
'Can not use a batch of a size ' + \
|
718 |
+
'({}) that is not '.format(input.size(0)) + \
|
719 |
+
'evenly divisible by the group size ({})'.format(group_size)
|
720 |
+
x = input
|
721 |
+
|
722 |
+
# B = batch size, C = num channels
|
723 |
+
# *s = the size dimensions (height, width for images)
|
724 |
+
|
725 |
+
# BC*s -> G[B/G]C*s
|
726 |
+
y = input.view(group_size, -1, *input.size()[1:])
|
727 |
+
# For numerical stability when training with mixed precision
|
728 |
+
y = y.float()
|
729 |
+
# G[B/G]C*s
|
730 |
+
y -= y.mean(dim=0, keepdim=True)
|
731 |
+
# [B/G]C*s
|
732 |
+
y = torch.mean(y ** 2, dim=0)
|
733 |
+
# [B/G]C*s
|
734 |
+
y = torch.sqrt(y + self.eps)
|
735 |
+
# [B/G]
|
736 |
+
y = torch.mean(y.view(y.size(0), -1), dim=-1)
|
737 |
+
# [B/G]1*1
|
738 |
+
y = y.view(-1, *[1] * (input.dim() - 1))
|
739 |
+
# Cast back to input dtype
|
740 |
+
y = y.to(x)
|
741 |
+
# B1*1
|
742 |
+
y = y.repeat(group_size, *[1] * (y.dim() - 1))
|
743 |
+
# B1*s
|
744 |
+
y = y.expand(y.size(0), 1, *x.size()[2:])
|
745 |
+
# B[C+1]*s
|
746 |
+
x = torch.cat([x, y], dim=1)
|
747 |
+
return x
|
748 |
+
|
749 |
+
def extra_repr(self):
|
750 |
+
return 'group_size={}'.format(self.group_size or '-1')
|
751 |
+
|
752 |
+
|
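A sketch showing the effect of the layer on tensor shapes: with the illustrative batch below, one std feature map is appended per sample, so the channel count grows by one.

import torch
from stylegan2.modules import MinibatchStd

mbstd = MinibatchStd(group_size=4)
x = torch.randn(8, 16, 4, 4)          # batch of 8, evenly divisible by the group size
y = mbstd(x)
print(y.shape)                        # torch.Size([8, 17, 4, 4])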
753 |
+
class DenseLayer(nn.Module):
|
754 |
+
"""
|
755 |
+
A fully connected layer.
|
756 |
+
NOTE: No bias is applied in this layer.
|
757 |
+
Arguments:
|
758 |
+
in_features (int): Number of input features.
|
759 |
+
out_features (int): Number of output features.
|
760 |
+
lr_mul (float): Learning rate multiplier of
|
761 |
+
the weight. Weights are scaled by
|
762 |
+
this value. Default value is 1.
|
763 |
+
weight_scale (bool): Scale weights for
|
764 |
+
equalized learning rate.
|
765 |
+
Default value is True.
|
766 |
+
gain (float): The gain of this layer. Default value is 1.
|
767 |
+
"""
|
768 |
+
def __init__(self,
|
769 |
+
in_features,
|
770 |
+
out_features,
|
771 |
+
lr_mul=1,
|
772 |
+
weight_scale=True,
|
773 |
+
gain=1,
|
774 |
+
*args,
|
775 |
+
**kwargs):
|
776 |
+
super(DenseLayer, self).__init__()
|
777 |
+
weight, weight_coef = _get_weight_and_coef(
|
778 |
+
shape=[out_features, in_features],
|
779 |
+
lr_mul=lr_mul,
|
780 |
+
weight_scale=weight_scale,
|
781 |
+
gain=gain
|
782 |
+
)
|
783 |
+
self.register_parameter('weight', weight)
|
784 |
+
self.weight_coef = weight_coef
|
785 |
+
|
786 |
+
def forward(self, input, **kwargs):
|
787 |
+
"""
|
788 |
+
Perform a matrix multiplication of the weight
|
789 |
+
of this layer and the input.
|
790 |
+
Arguments:
|
791 |
+
input (torch.Tensor)
|
792 |
+
Returns:
|
793 |
+
output (torch.Tensor)
|
794 |
+
"""
|
795 |
+
weight = self.weight
|
796 |
+
if self.weight_coef != 1:
|
797 |
+
weight = self.weight_coef * weight
|
798 |
+
return input.matmul(weight.t())
|
799 |
+
|
800 |
+
def extra_repr(self):
|
801 |
+
return 'in_features={}, out_features={}'.format(
|
802 |
+
self.weight.size(1), self.weight.size(0))
|
803 |
+
|
804 |
+
|
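A tiny sketch: the layer is a bias-free matmul whose weight is rescaled at runtime for equalized learning rate; bias and activation are added by `BiasActivationWrapper`, as in the modulated conv layer below. The feature sizes here are arbitrary.

import torch
from stylegan2.modules import DenseLayer

fc = DenseLayer(in_features=512, out_features=3)
z = torch.randn(4, 512)
print(fc(z).shape)   # torch.Size([4, 3])
print(fc)            # DenseLayer(in_features=512, out_features=3)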
805 |
+
class ConvLayer(nn.Module):
|
806 |
+
"""
|
807 |
+
A convolutional layer that can have its outputs
|
808 |
+
modulated (style mod). It can also normalize outputs.
|
809 |
+
These operations are done by modifying the convolutional
|
810 |
+
kernel weight and employing grouped convolutions for
|
811 |
+
efficiency.
|
812 |
+
NOTE: No bias is applied in this layer.
|
813 |
+
NOTE: The amount of padding used is the same as the 'SAME'
|
814 |
+
padding option in TensorFlow convolutions.
|
815 |
+
Arguments:
|
816 |
+
in_channels (int): Number of input channels.
|
817 |
+
out_channels (int): Number of output channels.
|
818 |
+
latent_size (int, optional): The size of the
|
819 |
+
latents to use for modulating this convolution.
|
820 |
+
Only required when `modulate=True`.
|
821 |
+
modulate (bool): Applies a "style" to the outputs
|
822 |
+
of the layer. The style is given by a latent
|
823 |
+
vector passed with the input to this layer.
|
824 |
+
A dense layer is added that projects the
|
825 |
+
values of the latent into scales for the
|
826 |
+
data channels.
|
827 |
+
Default value is False.
|
828 |
+
demodulate (bool): Normalize std of outputs.
|
829 |
+
Can only be set to True when `modulate=True`.
|
830 |
+
Default value is False.
|
831 |
+
kernel_size (int): The size of the kernel.
|
832 |
+
Default value is 3.
|
833 |
+
pad_mode (str): The padding mode. Default value is 'constant'.
|
834 |
+
pad_constant (float): The constant value to pad with if
|
835 |
+
`pad_mode='constant'`. Default value is 0.
|
836 |
+
lr_mul (float): Learning rate multiplier of
|
837 |
+
the weight. Weights are scaled by
|
838 |
+
this value. Default value is 1.
|
839 |
+
weight_scale (bool): Scale weights for
|
840 |
+
equalized learning rate.
|
841 |
+
Default value is True.
|
842 |
+
gain (float): The gain of this layer. Default value is 1.
|
843 |
+
dim (int): Dims of data (excluding batch and channel dimensions).
|
844 |
+
Default value is 2.
|
845 |
+
eps (float): Epsilon value added for numerical stability.
|
846 |
+
Default value is 1e-8.
|
847 |
+
"""
|
848 |
+
def __init__(self,
|
849 |
+
in_channels,
|
850 |
+
out_channels,
|
851 |
+
latent_size=None,
|
852 |
+
modulate=False,
|
853 |
+
demodulate=False,
|
854 |
+
kernel_size=3,
|
855 |
+
pad_mode='constant',
|
856 |
+
pad_constant=0,
|
857 |
+
lr_mul=1,
|
858 |
+
weight_scale=True,
|
859 |
+
gain=1,
|
860 |
+
dim=2,
|
861 |
+
eps=1e-8,
|
862 |
+
*args,
|
863 |
+
**kwargs):
|
864 |
+
super(ConvLayer, self).__init__()
|
865 |
+
assert modulate or not demodulate, '`demodulate=True` can ' + \
|
866 |
+
'only be used when `modulate=True`'
|
867 |
+
if modulate:
|
868 |
+
assert latent_size is not None, 'When using `modulate=True`, ' + \
|
869 |
+
'`latent_size` has to be specified.'
|
870 |
+
kernel_shape = [out_channels, in_channels] + dim * [kernel_size]
|
871 |
+
weight, weight_coef = _get_weight_and_coef(
|
872 |
+
shape=kernel_shape,
|
873 |
+
lr_mul=lr_mul,
|
874 |
+
weight_scale=weight_scale,
|
875 |
+
gain=gain
|
876 |
+
)
|
877 |
+
self.register_parameter('weight', weight)
|
878 |
+
self.weight_coef = weight_coef
|
879 |
+
if modulate:
|
880 |
+
self.dense = BiasActivationWrapper(
|
881 |
+
layer=DenseLayer(
|
882 |
+
in_features=latent_size,
|
883 |
+
out_features=in_channels,
|
884 |
+
lr_mul=lr_mul,
|
885 |
+
weight_scale=weight_scale,
|
886 |
+
gain=1
|
887 |
+
),
|
888 |
+
features=in_channels,
|
889 |
+
use_bias=True,
|
890 |
+
activation='linear',
|
891 |
+
bias_init=1,
|
892 |
+
lr_mul=lr_mul,
|
893 |
+
weight_scale=weight_scale,
|
894 |
+
)
|
895 |
+
self.dense_reshape = [-1, 1, in_channels] + dim * [1]
|
896 |
+
self.dmod_reshape = [-1, out_channels, 1] + dim * [1]
|
897 |
+
pad = (kernel_size - 1)
|
898 |
+
pad0 = pad - pad // 2
|
899 |
+
pad1 = pad - pad0
|
900 |
+
if pad0 == pad1 and (pad0 == 0 or pad_mode == 'constant' and pad_constant == 0):
|
901 |
+
self.fused_pad = True
|
902 |
+
self.padding = pad0
|
903 |
+
else:
|
904 |
+
self.fused_pad = False
|
905 |
+
self.padding = [pad0, pad1] * dim
|
906 |
+
self.pad_mode = pad_mode
|
907 |
+
self.pad_constant = pad_constant
|
908 |
+
self.in_channels = in_channels
|
909 |
+
self.out_channels = out_channels
|
910 |
+
self.latent_size = latent_size
|
911 |
+
self.modulate = modulate
|
912 |
+
self.demodulate = demodulate
|
913 |
+
self.kernel_size = kernel_size
|
914 |
+
self.lr_mul = lr_mul
|
915 |
+
self.weight_scale = weight_scale
|
916 |
+
self.gain = gain
|
917 |
+
self.dim = dim
|
918 |
+
self.eps = eps
|
919 |
+
|
920 |
+
def forward_mod(self, input, latent, weight, **kwargs):
|
921 |
+
"""
|
922 |
+
Run the forward operation with modulation.
|
923 |
+
Automatically called from `forward()` if modulation
|
924 |
+
is enabled.
|
925 |
+
"""
|
926 |
+
assert latent is not None, 'A latent vector is ' + \
|
927 |
+
'required for the forwad pass of a modulated conv layer.'
|
928 |
+
|
929 |
+
# B = batch size, C = num channels
|
930 |
+
# *s = the size dimensions, example: (height, width) for images
|
931 |
+
# *k = sizes of the convolutional kernel excluding in and out channel dimensions.
|
932 |
+
# *1 = multiple dimensions of size 1, with number of dimensions depending on data format.
|
933 |
+
# O = num output channels, I = num input channels
|
934 |
+
|
935 |
+
# BI
|
936 |
+
style_mod = self.dense(input=latent)
|
937 |
+
# B1I*1
|
938 |
+
style_mod = style_mod.view(*self.dense_reshape)
|
939 |
+
# 1OI*k
|
940 |
+
weight = weight.unsqueeze(0)
|
941 |
+
# (1OI*k)x(B1I*1) -> BOI*k
|
942 |
+
weight = weight * style_mod
|
943 |
+
if self.demodulate:
|
944 |
+
# BO
|
945 |
+
dmod = torch.rsqrt(
|
946 |
+
torch.sum(
|
947 |
+
weight.view(
|
948 |
+
weight.size(0),
|
949 |
+
weight.size(1),
|
950 |
+
-1
|
951 |
+
) ** 2,
|
952 |
+
dim=-1
|
953 |
+
) + self.eps
|
954 |
+
)
|
955 |
+
# BO1*1
|
956 |
+
dmod = dmod.view(*self.dmod_reshape)
|
957 |
+
# (BOI*k)x(BO1*1) -> BOI*k
|
958 |
+
weight = weight * dmod
|
959 |
+
# BI*s -> 1[BI]*s
|
960 |
+
x = input.view(1, -1, *input.size()[2:])
|
961 |
+
# BOI*k -> [BO]I*k
|
962 |
+
weight = weight.view(-1, *weight.size()[2:])
|
963 |
+
# 1[BO]*s
|
964 |
+
x = self._process(input=x, weight=weight, groups=input.size(0))
|
965 |
+
# 1[BO]*s -> BO*s
|
966 |
+
x = x.view(-1, self.out_channels, *x.size()[2:])
|
967 |
+
return x
|
968 |
+
|
969 |
+
def forward(self, input, latent=None, **kwargs):
|
970 |
+
"""
|
971 |
+
Convolve the input.
|
972 |
+
Arguments:
|
973 |
+
input (torch.Tensor)
|
974 |
+
latent (torch.Tensor, optional)
|
975 |
+
Returns:
|
976 |
+
output (torch.Tensor)
|
977 |
+
"""
|
978 |
+
weight = self.weight
|
979 |
+
if self.weight_coef != 1:
|
980 |
+
weight = self.weight_coef * weight
|
981 |
+
if self.modulate:
|
982 |
+
return self.forward_mod(input=input, latent=latent, weight=weight)
|
983 |
+
return self._process(input=input, weight=weight)
|
984 |
+
|
985 |
+
def _process(self, input, weight, **kwargs):
|
986 |
+
"""
|
987 |
+
Pad input and convolve it returning the result.
|
988 |
+
"""
|
989 |
+
x = input
|
990 |
+
if self.fused_pad:
|
991 |
+
kwargs.update(padding=self.padding)
|
992 |
+
else:
|
993 |
+
x = F.pad(x, self.padding, mode=self.pad_mode, value=self.pad_constant)
|
994 |
+
return _apply_conv(input=x, weight=weight, transpose=False, **kwargs)
|
995 |
+
|
996 |
+
def extra_repr(self):
|
997 |
+
string = 'in_channels={}, out_channels={}'.format(
|
998 |
+
self.weight.size(1), self.weight.size(0))
|
999 |
+
string += ', modulate={}, demodulate={}'.format(
|
1000 |
+
self.modulate, self.demodulate)
|
1001 |
+
return string
|
1002 |
+
|
1003 |
+
|
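A sketch of a style-modulated (and demodulated) convolution as used by the synthesis network; channel counts, latent size and input shape are illustrative. Each batch entry gets its own modulated kernel, which is why the forward pass above falls back to a grouped convolution.

import torch
from stylegan2.modules import ConvLayer

conv = ConvLayer(
    in_channels=32,
    out_channels=64,
    latent_size=512,
    modulate=True,
    demodulate=True,
    kernel_size=3,
)
x = torch.randn(4, 32, 16, 16)
w = torch.randn(4, 512)              # one (d)latent per batch entry
y = conv(input=x, latent=w)
print(y.shape)                       # torch.Size([4, 64, 16, 16])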
1004 |
+
class ConvUpLayer(ConvLayer):
|
1005 |
+
"""
|
1006 |
+
A convolutional upsampling layer that doubles the size of inputs.
|
1007 |
+
Extends the functionality of the `ConvLayer` class.
|
1008 |
+
Arguments:
|
1009 |
+
Same arguments as the `ConvLayer` class.
|
1010 |
+
Class Specific Keyword Arguments:
|
1011 |
+
fused (bool): Fuse the upsampling operation with the
|
1012 |
+
convolution, turning this layer into a strided transposed
|
1013 |
+
convolution. Default value is True.
|
1014 |
+
mode (str): Resample mode, can only be 'FIR' or 'none' if the operation
|
1015 |
+
is fused, otherwise it can also be one of the valid modes
|
1016 |
+
that can be passed to torch.nn.functional.interpolate().
|
1017 |
+
filter (int, list, tensor): Filter to use if `mode='FIR'`.
|
1018 |
+
Default value is a lowpass filter of values [1, 3, 3, 1].
|
1019 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
1020 |
+
See `FilterLayer` docstring for more info.
|
1021 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
1022 |
+
See `FilterLayer` docstring for more info.
|
1023 |
+
pad_once (bool): If FIR filter is used, do all the padding for
|
1024 |
+
both convolution and FIR in the FIR layer instead of once per layer.
|
1025 |
+
Default value is True.
|
1026 |
+
"""
|
1027 |
+
|
1028 |
+
def __init__(self,
|
1029 |
+
*args,
|
1030 |
+
fused=True,
|
1031 |
+
mode='FIR',
|
1032 |
+
filter=[1, 3, 3, 1],
|
1033 |
+
filter_pad_mode='constant',
|
1034 |
+
filter_pad_constant=0,
|
1035 |
+
pad_once=True,
|
1036 |
+
**kwargs):
|
1037 |
+
super(ConvUpLayer, self).__init__(*args, **kwargs)
|
1038 |
+
if fused:
|
1039 |
+
assert mode in ['FIR', 'none'], \
|
1040 |
+
'Fused conv upsample can only use ' + \
|
1041 |
+
'\'FIR\' or \'none\' for resampling ' + \
|
1042 |
+
'(`mode` argument).'
|
1043 |
+
self.padding = np.ceil(self.kernel_size / 2 - 1)
|
1044 |
+
self.output_padding = 2 * (self.padding + 1) - self.kernel_size
|
1045 |
+
if not self.modulate:
|
1046 |
+
# pre-prepare weights only once instead of every forward pass
|
1047 |
+
self.weight = nn.Parameter(self.weight.data.transpose(0, 1).contiguous())
|
1048 |
+
self.filter = None
|
1049 |
+
if mode == 'FIR':
|
1050 |
+
filter_kernel = _setup_filter_kernel(
|
1051 |
+
filter_kernel=filter,
|
1052 |
+
gain=self.gain,
|
1053 |
+
up_factor=2,
|
1054 |
+
dim=self.dim
|
1055 |
+
)
|
1056 |
+
if pad_once:
|
1057 |
+
self.padding = 0
|
1058 |
+
self.output_padding = 0
|
1059 |
+
pad = (filter_kernel.size(-1) - 2) - (self.kernel_size - 1)
|
1060 |
+
pad0 = (pad + 1) // 2 + 1
|
1061 |
+
pad1 = pad // 2 + 1
|
1062 |
+
else:
|
1063 |
+
pad = (filter_kernel.size(-1) - 1)
|
1064 |
+
pad0 = pad // 2
|
1065 |
+
pad1 = pad - pad0
|
1066 |
+
self.filter = FilterLayer(
|
1067 |
+
filter_kernel=filter_kernel,
|
1068 |
+
pad0=pad0,
|
1069 |
+
pad1=pad1,
|
1070 |
+
pad_mode=filter_pad_mode,
|
1071 |
+
pad_constant=filter_pad_constant
|
1072 |
+
)
|
1073 |
+
else:
|
1074 |
+
assert mode != 'none', '\'none\' can not be used as ' + \
|
1075 |
+
'sampling `mode` when `fused=False` as upsampling ' + \
|
1076 |
+
'has to be performed separately from the conv layer.'
|
1077 |
+
self.upsample = Upsample(
|
1078 |
+
mode=mode,
|
1079 |
+
filter=filter,
|
1080 |
+
filter_pad_mode=filter_pad_mode,
|
1081 |
+
filter_pad_constant=filter_pad_constant,
|
1082 |
+
channels=self.in_channels,
|
1083 |
+
gain=self.gain,
|
1084 |
+
dim=self.dim
|
1085 |
+
)
|
1086 |
+
self.fused = fused
|
1087 |
+
self.mode = mode
|
1088 |
+
|
1089 |
+
def _process(self, input, weight, **kwargs):
|
1090 |
+
"""
|
1091 |
+
Apply resampling (if enabled) and convolution.
|
1092 |
+
"""
|
1093 |
+
x = input
|
1094 |
+
if self.fused:
|
1095 |
+
if self.modulate:
|
1096 |
+
weight = _setup_mod_weight_for_t_conv(
|
1097 |
+
weight=weight,
|
1098 |
+
in_channels=self.in_channels,
|
1099 |
+
out_channels=self.out_channels
|
1100 |
+
)
|
1101 |
+
pad_out = False
|
1102 |
+
if self.pad_mode == 'constant' and self.pad_constant == 0:
|
1103 |
+
if self.filter is not None or not self.pad_once:
|
1104 |
+
kwargs.update(
|
1105 |
+
padding=self.padding,
|
1106 |
+
output_padding=self.output_padding,
|
1107 |
+
)
|
1108 |
+
elif self.filter is None:
|
1109 |
+
if self.padding:
|
1110 |
+
x = F.pad(
|
1111 |
+
x,
|
1112 |
+
[self.padding] * 2 * self.dim,
|
1113 |
+
mode=self.pad_mode,
|
1114 |
+
value=self.pad_constant
|
1115 |
+
)
|
1116 |
+
pad_out = self.output_padding != 0
|
1117 |
+
kwargs.update(stride=2)
|
1118 |
+
x = _apply_conv(
|
1119 |
+
input=x,
|
1120 |
+
weight=weight,
|
1121 |
+
transpose=True,
|
1122 |
+
**kwargs
|
1123 |
+
)
|
1124 |
+
if pad_out:
|
1125 |
+
x = F.pad(
|
1126 |
+
x,
|
1127 |
+
[self.output_padding, 0] * self.dim,
|
1128 |
+
mode=self.pad_mode,
|
1129 |
+
value=self.pad_constant
|
1130 |
+
)
|
1131 |
+
if self.filter is not None:
|
1132 |
+
x = self.filter(x)
|
1133 |
+
else:
|
1134 |
+
x = super(ConvUpLayer, self)._process(
|
1135 |
+
input=self.upsample(input=x),
|
1136 |
+
weight=weight,
|
1137 |
+
**kwargs
|
1138 |
+
)
|
1139 |
+
return x
|
1140 |
+
|
1141 |
+
def extra_repr(self):
|
1142 |
+
string = super(ConvUpLayer, self).extra_repr()
|
1143 |
+
string += ', fused={}, resample_mode={}'.format(
|
1144 |
+
self.fused, self.mode)
|
1145 |
+
return string
|
1146 |
+
|
1147 |
+
|
1148 |
+
class ConvDownLayer(ConvLayer):
|
1149 |
+
"""
|
1150 |
+
A convolutional downsampling layer that halves the size of inputs.
|
1151 |
+
Extends the functionality of the `ConvLayer` class.
|
1152 |
+
Arguments:
|
1153 |
+
Same arguments as the `ConvLayer` class.
|
1154 |
+
Class Specific Keyword Arguments:
|
1155 |
+
fused (bool): Fuse the downsampling operation with the
|
1156 |
+
convolution, turning this layer into a strided convolution.
|
1157 |
+
Default value is True.
|
1158 |
+
mode (str): Resample mode, can only be 'FIR' or 'none' if the operation
|
1159 |
+
is fused, otherwise it can also be 'max' or one of the valid modes
|
1160 |
+
that can be passed to torch.nn.functional.interpolate().
|
1161 |
+
filter (int, list, tensor): Filter to use if `mode='FIR'`.
|
1162 |
+
Default value is a lowpass filter of values [1, 3, 3, 1].
|
1163 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
1164 |
+
See `FilterLayer` docstring for more info.
|
1165 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
1166 |
+
See `FilterLayer` docstring for more info.
|
1167 |
+
pad_once (bool): If FIR filter is used, do all the padding for
|
1168 |
+
both convolution and FIR in the FIR layer instead of once per layer.
|
1169 |
+
Default value is True.
|
1170 |
+
"""
|
1171 |
+
|
1172 |
+
def __init__(self,
|
1173 |
+
*args,
|
1174 |
+
fused=True,
|
1175 |
+
mode='FIR',
|
1176 |
+
filter=[1, 3, 3, 1],
|
1177 |
+
filter_pad_mode='constant',
|
1178 |
+
filter_pad_constant=0,
|
1179 |
+
pad_once=True,
|
1180 |
+
**kwargs):
|
1181 |
+
super(ConvDownLayer, self).__init__(*args, **kwargs)
|
1182 |
+
if fused:
|
1183 |
+
assert mode in ['FIR', 'none'], \
|
1184 |
+
'Fused conv downsample can only use ' + \
|
1185 |
+
'\'FIR\' or \'none\' for resampling ' + \
|
1186 |
+
'(`mode` argument).'
|
1187 |
+
pad = self.kernel_size - 2
|
1188 |
+
pad0 = pad // 2
|
1189 |
+
pad1 = pad - pad0
|
1190 |
+
if pad0 == pad1 and (pad0 == 0 or self.pad_mode == 'constant' and self.pad_constant == 0):
|
1191 |
+
self.fused_pad = True
|
1192 |
+
self.padding = pad0
|
1193 |
+
else:
|
1194 |
+
self.fused_pad = False
|
1195 |
+
self.padding = [pad0, pad1] * self.dim
|
1196 |
+
self.filter = None
|
1197 |
+
if mode == 'FIR':
|
1198 |
+
filter_kernel = _setup_filter_kernel(
|
1199 |
+
filter_kernel=filter,
|
1200 |
+
gain=self.gain,
|
1201 |
+
up_factor=1,
|
1202 |
+
dim=self.dim
|
1203 |
+
)
|
1204 |
+
if pad_once:
|
1205 |
+
self.fused_pad = True
|
1206 |
+
self.padding = 0
|
1207 |
+
pad = (filter_kernel.size(-1) - 2) + (self.kernel_size - 1)
|
1208 |
+
pad0 = (pad + 1) // 2
|
1209 |
+
pad1 = pad // 2
|
1210 |
+
else:
|
1211 |
+
pad = (filter_kernel.size(-1) - 1)
|
1212 |
+
pad0 = pad // 2
|
1213 |
+
pad1 = pad - pad0
|
1214 |
+
self.filter = FilterLayer(
|
1215 |
+
filter_kernel=filter_kernel,
|
1216 |
+
pad0=pad0,
|
1217 |
+
pad1=pad1,
|
1218 |
+
pad_mode=filter_pad_mode,
|
1219 |
+
pad_constant=filter_pad_constant
|
1220 |
+
)
|
1221 |
+
self.pad_once = pad_once
|
1222 |
+
else:
|
1223 |
+
assert mode != 'none', '\'none\' can not be used as ' + \
|
1224 |
+
'sampling `mode` when `fused=False` as downsampling ' + \
|
1225 |
+
'has to be performed separately from the conv layer.'
|
1226 |
+
self.downsample = Downsample(
|
1227 |
+
mode=mode,
|
1228 |
+
filter=filter,
|
1229 |
+
pad_mode=filter_pad_mode,
|
1230 |
+
pad_constant=filter_pad_constant,
|
1231 |
+
channels=self.in_channels,
|
1232 |
+
gain=self.gain,
|
1233 |
+
dim=self.dim
|
1234 |
+
)
|
1235 |
+
self.fused = fused
|
1236 |
+
self.mode = mode
|
1237 |
+
|
1238 |
+
def _process(self, input, weight, **kwargs):
|
1239 |
+
"""
|
1240 |
+
Apply resampling (if enabled) and convolution.
|
1241 |
+
"""
|
1242 |
+
x = input
|
1243 |
+
if self.fused:
|
1244 |
+
kwargs.update(stride=2)
|
1245 |
+
if self.filter is not None:
|
1246 |
+
x = self.filter(input=x)
|
1247 |
+
else:
|
1248 |
+
x = self.downsample(input=x)
|
1249 |
+
x = super(ConvDownLayer, self)._process(
|
1250 |
+
input=x,
|
1251 |
+
weight=weight,
|
1252 |
+
**kwargs
|
1253 |
+
)
|
1254 |
+
return x
|
1255 |
+
|
1256 |
+
def extra_repr(self):
|
1257 |
+
string = super(ConvDownLayer, self).extra_repr()
|
1258 |
+
string += ', fused={}, resample_mode={}'.format(
|
1259 |
+
self.fused, self.mode)
|
1260 |
+
return string
|
1261 |
+
|
1262 |
+
|
1263 |
+
class GeneratorConvBlock(nn.Module):
|
1264 |
+
"""
|
1265 |
+
A conv block for the synthesiser model.
|
1266 |
+
Arguments:
|
1267 |
+
in_channels (int): Number of input channels.
|
1268 |
+
out_channels (int): Number of output channels.
|
1269 |
+
latent_size (int): The size of the latent vectors.
|
1270 |
+
demodulate (bool): Normalize feature outputs from conv
|
1271 |
+
layers. Default value is True.
|
1272 |
+
resnet (bool): Use residual connections. Default value is
|
1273 |
+
False.
|
1274 |
+
up (bool): Upsample the data to twice its size. This is
|
1275 |
+
performed in the first layer of the block. Default
|
1276 |
+
value is False.
|
1277 |
+
num_layers (int): Number of convolutional layers of this
|
1278 |
+
block. Default value is 2.
|
1279 |
+
filter (int, list): The filter to use if
|
1280 |
+
`up=True` and `mode='FIR'`. If int, a low
|
1281 |
+
pass filter of this size will be used. If list,
|
1282 |
+
the filter is explicitly specified. If the filter
|
1283 |
+
is of a single dimension it will be expanded to
|
1284 |
+
the number of dimensions of the data. Default
|
1285 |
+
value is a low pass filter of [1, 3, 3, 1].
|
1286 |
+
activation (str, callable, nn.Module): The non-linear
|
1287 |
+
activation function to use.
|
1288 |
+
Default value is leaky relu with a slope of 0.2.
|
1289 |
+
mode (str): The resample mode of upsampling layers.
|
1290 |
+
Only used when `up=True`. If `fused=True` only 'FIR'
|
1291 |
+
and 'none' can be used. Else, anything that can
|
1292 |
+
be passed to torch.nn.functional.interpolate is
|
1293 |
+
a valid mode. Default value is 'FIR'.
|
1294 |
+
fused (bool): If `up=True`, fuse the upsample operation
|
1295 |
+
and the first convolutional layer into a transposed
|
1296 |
+
convolutional layer.
|
1297 |
+
kernel_size (int): Size of the convolutional kernel.
|
1298 |
+
Default value is 3.
|
1299 |
+
pad_mode (str): The padding mode for convolutional
|
1300 |
+
layers. Has to be one of 'constant', 'reflect',
|
1301 |
+
'replicate' or 'circular'. Default value is
|
1302 |
+
'constant'.
|
1303 |
+
pad_constant (float): The value to use for conv
|
1304 |
+
padding if `conv_pad_mode='constant'`. Default
|
1305 |
+
value is 0.
|
1306 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
1307 |
+
Otherwise works the same as `pad_mode`.
|
1308 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
1309 |
+
Otherwise works the same as `pad_constant`
|
1310 |
+
pad_once (bool): If FIR filter is used, do all the padding for
|
1311 |
+
both convolution and FIR in the FIR layer instead of once per layer.
|
1312 |
+
Default value is True.
|
1313 |
+
use_bias (bool): Add bias to layer outputs. Default value is True.
|
1314 |
+
noise (bool): Add noise to the output of each layer. Default value
|
1315 |
+
is True.
|
1316 |
+
lr_mul (float): The learning rate multiplier for this
|
1317 |
+
block. When loading weights of previously trained
|
1318 |
+
networks, this value has to be the same as when
|
1319 |
+
the network was trained for the outputs to not
|
1320 |
+
change (as this is used to scale the weights).
|
1321 |
+
Default value is 1.
|
1322 |
+
weight_scale (bool): Use weight scaling for
|
1323 |
+
equalized learning rate. Default value
|
1324 |
+
is True.
|
1325 |
+
eps (float): Epsilon value added for numerical stability.
|
1326 |
+
Default value is 1e-8.
|
1327 |
+
"""
|
1328 |
+
def __init__(self,
|
1329 |
+
in_channels,
|
1330 |
+
out_channels,
|
1331 |
+
latent_size,
|
1332 |
+
demodulate=True,
|
1333 |
+
resnet=False,
|
1334 |
+
up=False,
|
1335 |
+
num_layers=2,
|
1336 |
+
filter=[1, 3, 3, 1],
|
1337 |
+
activation='leaky:0.2',
|
1338 |
+
mode='FIR',
|
1339 |
+
fused=True,
|
1340 |
+
kernel_size=3,
|
1341 |
+
pad_mode='constant',
|
1342 |
+
pad_constant=0,
|
1343 |
+
filter_pad_mode='constant',
|
1344 |
+
filter_pad_constant=0,
|
1345 |
+
pad_once=True,
|
1346 |
+
use_bias=True,
|
1347 |
+
noise=True,
|
1348 |
+
lr_mul=1,
|
1349 |
+
weight_scale=True,
|
1350 |
+
gain=1,
|
1351 |
+
dim=2,
|
1352 |
+
eps=1e-8,
|
1353 |
+
*args,
|
1354 |
+
**kwargs):
|
1355 |
+
super(GeneratorConvBlock, self).__init__()
|
1356 |
+
layer_kwargs = locals()
|
1357 |
+
layer_kwargs.pop('self')
|
1358 |
+
layer_kwargs.pop('__class__')
|
1359 |
+
layer_kwargs.update(
|
1360 |
+
features=out_channels,
|
1361 |
+
modulate=True,
|
1362 |
+
)
|
1363 |
+
|
1364 |
+
assert num_layers > 0
|
1365 |
+
assert 1 <= dim <= 3, '`dim` can only be 1, 2 or 3.'
|
1366 |
+
if up:
|
1367 |
+
available_sampling = ['FIR']
|
1368 |
+
if fused:
|
1369 |
+
available_sampling.append('none')
|
1370 |
+
else:
|
1371 |
+
available_sampling.append('nearest')
|
1372 |
+
if dim == 1:
|
1373 |
+
available_sampling.append('linear')
|
1374 |
+
elif dim == 2:
|
1375 |
+
available_sampling.append('bilinear')
|
1376 |
+
available_sampling.append('bicubic')
|
1377 |
+
else:
|
1378 |
+
available_sampling.append('trilinear')
|
1379 |
+
assert mode in available_sampling, \
|
1380 |
+
'`mode` {} '.format(mode) + \
|
1381 |
+
'is not one of the available sample ' + \
|
1382 |
+
'modes {}.'.format(available_sampling)
|
1383 |
+
|
1384 |
+
self.conv_block = nn.ModuleList()
|
1385 |
+
|
1386 |
+
while len(self.conv_block) < num_layers:
|
1387 |
+
use_up = up and not self.conv_block
|
1388 |
+
self.conv_block.append(_get_layer(
|
1389 |
+
ConvUpLayer if use_up else ConvLayer, layer_kwargs, wrap=True, noise=noise))
|
1390 |
+
layer_kwargs.update(in_channels=out_channels)
|
1391 |
+
|
1392 |
+
self.projection = None
|
1393 |
+
if resnet:
|
1394 |
+
projection_kwargs = {
|
1395 |
+
**layer_kwargs,
|
1396 |
+
'in_channels': in_channels,
|
1397 |
+
'kernel_size': 1,
|
1398 |
+
'modulate': False,
|
1399 |
+
'demodulate': False
|
1400 |
+
}
|
1401 |
+
self.projection = _get_layer(
|
1402 |
+
ConvUpLayer if up else ConvLayer, projection_kwargs, wrap=False)
|
1403 |
+
|
1404 |
+
self.res_scale = 1 / np.sqrt(2)
|
1405 |
+
|
1406 |
+
def __len__(self):
|
1407 |
+
"""
|
1408 |
+
Get the number of conv layers in this block.
|
1409 |
+
"""
|
1410 |
+
return len(self.conv_block)
|
1411 |
+
|
1412 |
+
def forward(self, input, latents=None, **kwargs):
|
1413 |
+
"""
|
1414 |
+
Run some input through this block and return the output.
|
1415 |
+
Arguments:
|
1416 |
+
input (torch.Tensor)
|
1417 |
+
latents (torch.Tensor)
|
1418 |
+
Returns:
|
1419 |
+
output (torch.Tensor)
|
1420 |
+
"""
|
1421 |
+
if latents.dim() == 2:
|
1422 |
+
latents = latents.unsqueeze(1)
|
1423 |
+
if latents.size(1) == 1:
|
1424 |
+
latents = latents.repeat(1, len(self), 1)
|
1425 |
+
assert latents.size(1) == len(self), \
|
1426 |
+
'Number of latent inputs ' + \
|
1427 |
+
'({}) does not match '.format(latents.size(1)) + \
|
1428 |
+
'number of conv layers ' + \
|
1429 |
+
'({}) in block.'.format(len(self))
|
1430 |
+
x = input
|
1431 |
+
for i, layer in enumerate(self.conv_block):
|
1432 |
+
x = layer(input=x, latent=latents[:, i])
|
1433 |
+
if self.projection is not None:
|
1434 |
+
x += self.projection(input=input)
|
1435 |
+
x *= self.res_scale
|
1436 |
+
return x
|
1437 |
+
|
1438 |
+
|
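A sketch of an upsampling generator block; the channel counts and latent size are illustrative. Latents are expected with one vector per conv layer (shape batch x num_layers x latent_size), or a single vector that gets repeated over the layers.

import torch
from stylegan2.modules import GeneratorConvBlock

block = GeneratorConvBlock(
    in_channels=64,
    out_channels=32,
    latent_size=512,
    up=True,          # the first layer doubles the spatial size
    num_layers=2,
)
x = torch.randn(2, 64, 8, 8)
latents = torch.randn(2, 2, 512)      # one latent per conv layer in the block
y = block(input=x, latents=latents)
print(y.shape)                        # torch.Size([2, 32, 16, 16])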
1439 |
+
class DiscriminatorConvBlock(nn.Module):
|
1440 |
+
"""
|
1441 |
+
A conv block for the discriminator model.
|
1442 |
+
Arguments:
|
1443 |
+
in_channels (int): Number of input channels.
|
1444 |
+
out_channels (int): Number of output channels.
|
1445 |
+
demodulate (bool): Normalize feature outputs from conv
|
1446 |
+
layers. Default value is True.
|
1447 |
+
resnet (bool): Use residual connections. Default value is
|
1448 |
+
False.
|
1449 |
+
down (bool): Downsample the data to half its size. This is
|
1450 |
+
performed in the last layer of the block. Default
|
1451 |
+
value is False.
|
1452 |
+
num_layers (int): Number of convolutional layers of this
|
1453 |
+
block. Default value is 2.
|
1454 |
+
filter (int, list): The filter to use if
|
1455 |
+
`down=True` and `mode='FIR'`. If int, a low
|
1456 |
+
pass filter of this size will be used. If list,
|
1457 |
+
the filter is explicitly specified. If the filter
|
1458 |
+
is of a single dimension it will be expanded to
|
1459 |
+
the number of dimensions of the data. Default
|
1460 |
+
value is a low pass filter of [1, 3, 3, 1].
|
1461 |
+
activation (str, callable, nn.Module): The non-linear
|
1462 |
+
activation function to use.
|
1463 |
+
Default value is leaky relu with a slope of 0.2.
|
1464 |
+
mode (str): The resample mode of downsampling layers.
|
1465 |
+
Only used when `down=True`. If `fused=True` only 'FIR'
|
1466 |
+
and 'none' can be used. Else, 'max' or anything that can
|
1467 |
+
be passed to torch.nn.functional.interpolate is
|
1468 |
+
a valid mode. Default value is 'FIR'.
|
1469 |
+
fused (bool): If `down=True`, fuse the downsample operation
|
1470 |
+
and the last convolutional layer into a strided
|
1471 |
+
convolutional layer.
|
1472 |
+
kernel_size (int): Size of the convolutional kernel.
|
1473 |
+
Default value is 3.
|
1474 |
+
pad_mode (str): The padding mode for convolutional
|
1475 |
+
layers. Has to be one of 'constant', 'reflect',
|
1476 |
+
'replicate' or 'circular'. Default value is
|
1477 |
+
'constant'.
|
1478 |
+
pad_constant (float): The value to use for conv
|
1479 |
+
padding if `conv_pad_mode='constant'`. Default
|
1480 |
+
value is 0.
|
1481 |
+
filter_pad_mode (str): If `mode='FIR'`, this is used with the filter.
|
1482 |
+
Otherwise works the same as `pad_mode`.
|
1483 |
+
filter_pad_constant (float): If `mode='FIR'`, this is used with the filter.
|
1484 |
+
Otherwise works the same as `pad_constant`
|
1485 |
+
pad_once (bool): If FIR filter is used, do all the padding for
|
1486 |
+
both convolution and FIR in the FIR layer instead of once per layer.
|
1487 |
+
Default value is True.
|
1488 |
+
use_bias (bool): Add bias to layer outputs. Default value is True.
|
1489 |
+
lr_mul (float): The learning rate multiplier for this
|
1490 |
+
block. When loading weights of previously trained
|
1491 |
+
networks, this value has to be the same as when
|
1492 |
+
the network was trained for the outputs to not
|
1493 |
+
change (as this is used to scale the weights).
|
1494 |
+
Default value is 1.
|
1495 |
+
weight_scale (bool): Use weight scaling for
|
1496 |
+
equalized learning rate. Default value
|
1497 |
+
is True.
|
1498 |
+
"""
|
1499 |
+
def __init__(self,
|
1500 |
+
in_channels,
|
1501 |
+
out_channels,
|
1502 |
+
resnet=False,
|
1503 |
+
down=False,
|
1504 |
+
num_layers=2,
|
1505 |
+
filter=[1, 3, 3, 1],
|
1506 |
+
activation='leaky:0.2',
|
1507 |
+
mode='FIR',
|
1508 |
+
fused=True,
|
1509 |
+
kernel_size=3,
|
1510 |
+
pad_mode='constant',
|
1511 |
+
pad_constant=0,
|
1512 |
+
filter_pad_mode='constant',
|
1513 |
+
filter_pad_constant=0,
|
1514 |
+
pad_once=True,
|
1515 |
+
use_bias=True,
|
1516 |
+
lr_mul=1,
|
1517 |
+
weight_scale=True,
|
1518 |
+
gain=1,
|
1519 |
+
dim=2,
|
1520 |
+
*args,
|
1521 |
+
**kwargs):
|
1522 |
+
super(DiscriminatorConvBlock, self).__init__()
|
1523 |
+
layer_kwargs = locals()
|
1524 |
+
layer_kwargs.pop('self')
|
1525 |
+
layer_kwargs.pop('__class__')
|
1526 |
+
layer_kwargs.update(
|
1527 |
+
out_channels=in_channels,
|
1528 |
+
features=in_channels,
|
1529 |
+
modulate=False,
|
1530 |
+
demodulate=False
|
1531 |
+
)
|
1532 |
+
|
1533 |
+
assert num_layers > 0
|
1534 |
+
assert 1 <= dim <= 3, '`dim` can only be 1, 2 or 3.'
|
1535 |
+
if down:
|
1536 |
+
available_sampling = ['FIR']
|
1537 |
+
if fused:
|
1538 |
+
available_sampling.append('none')
|
1539 |
+
else:
|
1540 |
+
available_sampling.append('max')
|
1541 |
+
available_sampling.append('area')
|
1542 |
+
available_sampling.append('nearest')
|
1543 |
+
if dim == 1:
|
1544 |
+
available_sampling.append('linear')
|
1545 |
+
elif dim == 2:
|
1546 |
+
available_sampling.append('bilinear')
|
1547 |
+
available_sampling.append('bicubic')
|
1548 |
+
else:
|
1549 |
+
available_sampling.append('trilinear')
|
1550 |
+
assert mode in available_sampling, \
|
1551 |
+
'`mode` {} '.format(mode) + \
|
1552 |
+
'is not one of the available sample ' + \
|
1553 |
+
'modes {}'.format(available_sampling)
|
1554 |
+
|
1555 |
+
self.conv_block = nn.ModuleList()
|
1556 |
+
|
1557 |
+
while len(self.conv_block) < num_layers:
|
1558 |
+
if len(self.conv_block) == num_layers - 1:
|
1559 |
+
layer_kwargs.update(
|
1560 |
+
out_channels=out_channels,
|
1561 |
+
features=out_channels
|
1562 |
+
)
|
1563 |
+
use_down = down and len(self.conv_block) == num_layers - 1
|
1564 |
+
self.conv_block.append(_get_layer(
|
1565 |
+
ConvDownLayer if use_down else ConvLayer, layer_kwargs, wrap=True, noise=False))
|
1566 |
+
|
1567 |
+
self.projection = None
|
1568 |
+
if resnet:
|
1569 |
+
projection_kwargs = {
|
1570 |
+
**layer_kwargs,
|
1571 |
+
'in_channels': in_channels,
|
1572 |
+
'kernel_size': 1,
|
1573 |
+
'modulate': False,
|
1574 |
+
'demodulate': False
|
1575 |
+
}
|
1576 |
+
self.projection = _get_layer(
|
1577 |
+
ConvDownLayer if down else ConvLayer, projection_kwargs, wrap=False)
|
1578 |
+
|
1579 |
+
self.res_scale = 1 / np.sqrt(2)
|
1580 |
+
|
1581 |
+
def __len__(self):
|
1582 |
+
"""
|
1583 |
+
Get the number of conv layers in this block.
|
1584 |
+
"""
|
1585 |
+
return len(self.conv_block)
|
1586 |
+
|
1587 |
+
def forward(self, input, **kwargs):
|
1588 |
+
"""
|
1589 |
+
Run some input through this block and return the output.
|
1590 |
+
Arguments:
|
1591 |
+
input (torch.Tensor)
|
1592 |
+
Returns:
|
1593 |
+
output (torch.Tensor)
|
1594 |
+
"""
|
1595 |
+
x = input
|
1596 |
+
for layer in self.conv_block:
|
1597 |
+
x = layer(input=x)
|
1598 |
+
if self.projection is not None:
|
1599 |
+
x += self.projection(input=input)
|
1600 |
+
x *= self.res_scale
|
1601 |
+
return x
|
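And the discriminator-side counterpart, again with illustrative sizes: no latents are involved, and with `down=True` the last conv layer halves the spatial size (the residual projection is downsampled to match).

import torch
from stylegan2.modules import DiscriminatorConvBlock

block = DiscriminatorConvBlock(
    in_channels=32,
    out_channels=64,
    down=True,
    resnet=True,
)
x = torch.randn(2, 32, 16, 16)
print(block(input=x).shape)   # torch.Size([2, 64, 8, 8])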
stylegan2/project.py
ADDED
@@ -0,0 +1,304 @@
1 |
+
import warnings
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from . import models, utils
|
8 |
+
from .external_models import lpips
|
9 |
+
|
10 |
+
|
11 |
+
class Projector(nn.Module):
|
12 |
+
"""
|
13 |
+
Projects data to latent space and noise tensors.
|
14 |
+
Arguments:
|
15 |
+
G (Generator)
|
16 |
+
dlatent_avg_samples (int): Number of dlatent samples
|
17 |
+
to collect to find the mean and std.
|
18 |
+
Default value is 10 000.
|
19 |
+
dlatent_avg_label (int, torch.Tensor, optional): The label to
|
20 |
+
use when gathering dlatent statistics.
|
21 |
+
dlatent_device (int, str, torch.device, optional): Device to use
|
22 |
+
for gathering statistics of dlatents. By default uses
|
23 |
+
the same device as parameters of `G` reside on.
|
24 |
+
dlatent_batch_size (int): The batch size to sample
|
25 |
+
dlatents with. Default value is 1024.
|
26 |
+
lpips_model (nn.Module): A model that returns the feature distance
|
27 |
+
between two inputs. Default value is the LPIPS VGG16 model.
|
28 |
+
lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
|
29 |
+
the data so that its smallest side is the same size as this
|
30 |
+
argument. Only has a default value of 256 if `lpips_model` is unspecified.
|
31 |
+
verbose (bool): Write progress of dlatent statistics gathering to stdout.
|
32 |
+
Default value is True.
|
33 |
+
"""
|
34 |
+
def __init__(self,
|
35 |
+
G,
|
36 |
+
dlatent_avg_samples=10000,
|
37 |
+
dlatent_avg_label=None,
|
38 |
+
dlatent_device=None,
|
39 |
+
dlatent_batch_size=1024,
|
40 |
+
lpips_model=None,
|
41 |
+
lpips_size=None,
|
42 |
+
verbose=True):
|
43 |
+
super(Projector, self).__init__()
|
44 |
+
assert isinstance(G, models.Generator)
|
45 |
+
G.eval().requires_grad_(False)
|
46 |
+
|
47 |
+
self.G_synthesis = G.G_synthesis
|
48 |
+
|
49 |
+
G_mapping = G.G_mapping
|
50 |
+
|
51 |
+
dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)
|
52 |
+
|
53 |
+
if dlatent_device is None:
|
54 |
+
dlatent_device = next(G_mapping.parameters()).device
|
55 |
+
else:
|
56 |
+
dlatent_device = torch.device(dlatent_device)
|
57 |
+
|
58 |
+
G_mapping.to(dlatent_device)
|
59 |
+
|
60 |
+
latents = torch.empty(
|
61 |
+
dlatent_avg_samples, G_mapping.latent_size).normal_()
|
62 |
+
dlatents = []
|
63 |
+
|
64 |
+
labels = None
|
65 |
+
if dlatent_avg_label is not None:
|
66 |
+
labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size)
|
67 |
+
|
68 |
+
if verbose:
|
69 |
+
progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size))
|
70 |
+
progress.write('Gathering dlatents...', step=False)
|
71 |
+
|
72 |
+
for i in range(0, dlatent_avg_samples, dlatent_batch_size):
|
73 |
+
batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
|
74 |
+
batch_labels = None
|
75 |
+
if labels is not None:
|
76 |
+
batch_labels = labels[:len(batch_latents)]
|
77 |
+
with torch.no_grad():
|
78 |
+
dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
|
79 |
+
if verbose:
|
80 |
+
progress.step()
|
81 |
+
|
82 |
+
if verbose:
|
83 |
+
progress.write('Done!', step=False)
|
84 |
+
progress.close()
|
85 |
+
|
86 |
+
dlatents = torch.cat(dlatents, dim=0)
|
87 |
+
|
88 |
+
self.register_buffer(
|
89 |
+
'_dlatent_avg',
|
90 |
+
dlatents.mean(dim=0).view(1, 1, -1)
|
91 |
+
)
|
92 |
+
self.register_buffer(
|
93 |
+
'_dlatent_std',
|
94 |
+
torch.sqrt(
|
95 |
+
torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
|
96 |
+
).view(1, 1, 1)
|
97 |
+
)
|
98 |
+
|
99 |
+
if lpips_model is None:
|
100 |
+
warnings.warn(
|
101 |
+
'Using default LPIPS distance metric based on VGG 16. ' + \
|
102 |
+
'This metric will only work on image data where values are in ' + \
|
103 |
+
'the range [-1, 1], please specify an lpips module if you want ' + \
|
104 |
+
'to use other kinds of data formats.'
|
105 |
+
)
|
106 |
+
lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
|
107 |
+
lpips_size = 256
|
108 |
+
self.lpips_model = lpips_model.eval().requires_grad_(False)
|
109 |
+
self.lpips_size = lpips_size
|
110 |
+
|
111 |
+
self.to(dlatent_device)
|
112 |
+
|
113 |
+
def _scale_for_lpips(self, data):
|
114 |
+
if not self.lpips_size:
|
115 |
+
return data
|
116 |
+
scale_factor = self.lpips_size / min(data.size()[2:])
|
117 |
+
if scale_factor == 1:
|
118 |
+
return data
|
119 |
+
mode = 'nearest'
|
120 |
+
if scale_factor < 1:
|
121 |
+
mode = 'area'
|
122 |
+
return F.interpolate(data, scale_factor=scale_factor, mode=mode)
|
123 |
+
|
124 |
+
def _check_job(self):
|
125 |
+
assert self._job is not None, 'Call `start()` first to set up target.'
|
126 |
+
# device of dlatent param will not change with the rest of the models
|
127 |
+
# and buffers of this class as it was never registered as a buffer or
|
128 |
+
# parameter. Same goes for optimizer. Make sure it is on the correct device.
|
129 |
+
if self._job.dlatent_param.device != self._dlatent_avg.device:
|
130 |
+
self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
|
131 |
+
self._job.opt.load_state_dict(
|
132 |
+
utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])
|
133 |
+
|
134 |
+
def generate(self):
|
135 |
+
"""
|
136 |
+
Generate an output with the current dlatent and noise values.
|
137 |
+
Returns:
|
138 |
+
output (torch.Tensor)
|
139 |
+
"""
|
140 |
+
self._check_job()
|
141 |
+
with torch.no_grad():
|
142 |
+
return self.G_synthesis(self._job.dlatent_param)
|
143 |
+
|
144 |
+
def get_dlatent(self):
|
145 |
+
"""
|
146 |
+
Get a copy of the current dlatent values.
|
147 |
+
Returns:
|
148 |
+
dlatents (torch.Tensor)
|
149 |
+
"""
|
150 |
+
self._check_job()
|
151 |
+
return self._job.dlatent_param.data.clone()
|
152 |
+
|
153 |
+
def get_noise(self):
|
154 |
+
"""
|
155 |
+
Get a copy of the current noise values.
|
156 |
+
Returns:
|
157 |
+
noise_tensors (list)
|
158 |
+
"""
|
159 |
+
self._check_job()
|
160 |
+
return [noise.data.clone() for noise in self._job.noise_params]
|
161 |
+
|
162 |
+
def start(self,
|
163 |
+
target,
|
164 |
+
num_steps=1000,
|
165 |
+
initial_learning_rate=0.1,
|
166 |
+
initial_noise_factor=0.05,
|
167 |
+
lr_rampdown_length=0.25,
|
168 |
+
lr_rampup_length=0.05,
|
169 |
+
noise_ramp_length=0.75,
|
170 |
+
regularize_noise_weight=1e5,
|
171 |
+
verbose=True,
|
172 |
+
verbose_prefix=''):
|
173 |
+
"""
|
174 |
+
Set up a target and its projection parameters.
|
175 |
+
Arguments:
|
176 |
+
target (torch.Tensor): The data target. This should
|
177 |
+
already be preprocessed (scaled to correct value range).
|
178 |
+
num_steps (int): Number of optimization steps. Default
|
179 |
+
value is 1000.
|
180 |
+
initial_learning_rate (float): Default value is 0.1.
|
181 |
+
initial_noise_factor (float): Default value is 0.05.
|
182 |
+
lr_rampdown_length (float): Default value is 0.25.
|
183 |
+
lr_rampup_length (float): Default value is 0.05.
|
184 |
+
noise_ramp_length (float): Default value is 0.75.
|
185 |
+
regularize_noise_weight (float): Default value is 1e5.
|
186 |
+
verbose (bool): Write progress to stdout every time
|
187 |
+
`step()` is called.
|
188 |
+
verbose_prefix (str, optional): This is written before
|
189 |
+
any other output to stdout.
|
190 |
+
"""
|
191 |
+
if target.dim() == self.G_synthesis.dim + 1:
|
192 |
+
target = target.unsqueeze(0)
|
193 |
+
assert target.dim() == self.G_synthesis.dim + 2, \
|
194 |
+
'Number of dimensions of target data is incorrect.'
|
195 |
+
|
196 |
+
target = target.to(self._dlatent_avg)
|
197 |
+
target_scaled = self._scale_for_lpips(target)
|
198 |
+
|
199 |
+
dlatent_param = nn.Parameter(
|
200 |
+
self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
|
201 |
+
noise_params = self.G_synthesis.static_noise(trainable=True)
|
202 |
+
params = [dlatent_param] + noise_params
|
203 |
+
|
204 |
+
opt = torch.optim.Adam(params)
|
205 |
+
|
206 |
+
noise_tensor = torch.empty_like(dlatent_param)
|
207 |
+
|
208 |
+
if verbose:
|
209 |
+
progress = utils.ProgressWriter(num_steps)
|
210 |
+
value_tracker = utils.ValueTracker()
|
211 |
+
|
212 |
+
self._job = utils.AttributeDict(**locals())
|
213 |
+
self._job.current_step = 0
|
214 |
+
|
215 |
+
def step(self, steps=1):
|
216 |
+
"""
|
217 |
+
Take a projection step.
|
218 |
+
Arguments:
|
219 |
+
steps (int): Number of steps to take. If this
|
220 |
+
exceeds the remaining steps of the projection,
|
221 |
+
that amount of steps is taken instead. Default
|
222 |
+
value is 1.
|
223 |
+
"""
|
224 |
+
self._check_job()
|
225 |
+
|
226 |
+
remaining_steps = self._job.num_steps - self._job.current_step
|
227 |
+
if not remaining_steps > 0:
|
228 |
+
warnings.warn(
|
229 |
+
'Trying to take a projection step after the ' + \
|
230 |
+
'final projection iteration has been completed.'
|
231 |
+
)
|
232 |
+
if steps < 0:
|
233 |
+
steps = remaining_steps
|
234 |
+
steps = min(remaining_steps, steps)
|
235 |
+
|
236 |
+
if not steps > 0:
|
237 |
+
return
|
238 |
+
|
239 |
+
for _ in range(steps):
|
240 |
+
|
241 |
+
if self._job.current_step >= self._job.num_steps:
|
242 |
+
break
|
243 |
+
|
244 |
+
# Hyperparameters.
|
245 |
+
t = self._job.current_step / self._job.num_steps
|
246 |
+
noise_strength = self._dlatent_std * self._job.initial_noise_factor \
|
247 |
+
* max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
|
248 |
+
lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
|
249 |
+
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
|
250 |
+
lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
|
251 |
+
learning_rate = self._job.initial_learning_rate * lr_ramp
|
252 |
+
|
253 |
+
for param_group in self._job.opt.param_groups:
|
254 |
+
param_group['lr'] = learning_rate
|
255 |
+
|
256 |
+
dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
|
257 |
+
|
258 |
+
output = self.G_synthesis(dlatents)
|
259 |
+
assert output.size() == self._job.target.size(), \
|
260 |
+
'target size {} does not fit output size {} of generator'.format(
|
261 |
+
self._job.target.size(), output.size())
|
262 |
+
|
263 |
+
output_scaled = self._scale_for_lpips(output)
|
264 |
+
|
265 |
+
# Main loss: LPIPS distance of output and target
|
266 |
+
lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled))
|
267 |
+
|
268 |
+
# Calculate noise regularization loss
|
269 |
+
reg_loss = 0
|
270 |
+
for p in self._job.noise_params:
|
271 |
+
size = min(p.size()[2:])
|
272 |
+
dim = p.dim() - 2
|
273 |
+
while True:
|
274 |
+
reg_loss += torch.mean(
|
275 |
+
(p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
|
276 |
+
if size <= 8:
|
277 |
+
break
|
278 |
+
p = F.interpolate(p, scale_factor=0.5, mode='area')
|
279 |
+
size = size // 2
|
280 |
+
|
281 |
+
# Combine loss, backward and update params
|
282 |
+
loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
|
283 |
+
self._job.opt.zero_grad()
|
284 |
+
loss.backward()
|
285 |
+
self._job.opt.step()
|
286 |
+
|
287 |
+
# Normalize noise values
|
288 |
+
for p in self._job.noise_params:
|
289 |
+
with torch.no_grad():
|
290 |
+
p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
|
291 |
+
p_rstd = torch.rsqrt(
|
292 |
+
torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
|
293 |
+
p.data = (p.data - p_mean) * p_rstd
|
294 |
+
|
295 |
+
self._job.current_step += 1
|
296 |
+
|
297 |
+
if self._job.verbose:
|
298 |
+
self._job.value_tracker.add('loss', float(loss))
|
299 |
+
self._job.value_tracker.add('lpips_distance', float(lpips_distance))
|
300 |
+
self._job.value_tracker.add('noise_reg', float(reg_loss))
|
301 |
+
self._job.value_tracker.add('lr', learning_rate, beta=0)
|
302 |
+
self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker))
|
303 |
+
if self._job.current_step >= self._job.num_steps:
|
304 |
+
self._job.progress.close()
|
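The hyperparameter schedule inside `step()` above is self-contained. As a reading aid, here is a minimal standalone sketch of the same learning-rate ramp (cosine ramp-down over the last `lr_rampdown_length` fraction of the projection, linear ramp-up over the first `lr_rampup_length` fraction); the function name and defaults are chosen here for illustration and are not part of the file above.

import numpy as np

def projection_lr(step, num_steps, initial_lr=0.1,
                  lr_rampdown_length=0.25, lr_rampup_length=0.05):
    # Fraction of the projection completed so far.
    t = step / num_steps
    # Cosine ramp-down towards the end of the projection.
    lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    # Linear ramp-up during the first steps.
    lr_ramp *= min(1.0, t / lr_rampup_length)
    return initial_lr * lr_ramp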
stylegan2/train.py
ADDED
@@ -0,0 +1,1013 @@
1 |
+
import warnings
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.utils.tensorboard
|
10 |
+
from torch import nn
|
11 |
+
import torchvision
|
12 |
+
try:
|
13 |
+
import apex
|
14 |
+
from apex import amp
|
15 |
+
except ImportError:
|
16 |
+
pass
|
17 |
+
|
18 |
+
from . import models, utils, loss_fns
|
19 |
+
|
20 |
+
|
21 |
+
class Trainer:
|
22 |
+
"""
|
23 |
+
Class that handles training and logging for stylegan2.
|
24 |
+
For distributed training, the arguments `rank`, `world_size`,
|
25 |
+
`master_addr`, `master_port` can all be given as environment variables
|
26 |
+
(only difference is that the keys should be capital cased).
|
27 |
+
Environment variables if available will override any python
|
28 |
+
value for the same argument.
|
29 |
+
Arguments:
|
30 |
+
G (Generator): The generator model.
|
31 |
+
D (Discriminator): The discriminator model.
|
32 |
+
latent_size (int): The size of the latent inputs.
|
33 |
+
dataset (indexable object): The dataset. Has to implement
|
34 |
+
'__getitem__' and '__len__'. If `label_size` > 0, this
|
35 |
+
dataset object has to return both a data entry and its
|
36 |
+
label when calling '__getitem__'.
|
37 |
+
device (str, int, list, torch.device): The device to run training on.
|
38 |
+
Can be a list of integers for parallel training in the same
|
39 |
+
process. Parallel training can also be achieved by spawning
|
40 |
+
separate processes and using the `rank` argument for each
|
41 |
+
process. In that case, only one device should be specified
|
42 |
+
per process.
|
43 |
+
Gs (Generator, optional): A generator copy with the current
|
44 |
+
moving average of the training generator. If not specified,
|
45 |
+
a copy of the generator is made for the moving average of
|
46 |
+
weights.
|
47 |
+
Gs_beta (float): The beta value for the moving average weights.
|
48 |
+
Default value is 1 / (2 ^(32 / 10000)).
|
49 |
+
Gs_device (str, int, torch.device, optional): The device to store
|
50 |
+
the moving average weights on. If using a different device
|
51 |
+
than what is specified for the `device` argument, updating
|
52 |
+
the moving average weights will take longer as the data
|
53 |
+
will have to be transferred over different devices. If
|
54 |
+
this argument is not specified, the same device is used
|
55 |
+
as specified in the `device` argument.
|
56 |
+
batch_size (int): The total batch size to average gradients
|
57 |
+
over. This should be the combined batch size of all used
|
58 |
+
devices (it is later divided by world size for distributed
|
59 |
+
training).
|
60 |
+
Example: We want to average gradients over 32 data
|
61 |
+
entries. To do this we just set `batch_size=32`.
|
62 |
+
Even if we train on 8 GPUs we still use the same
|
63 |
+
batch size (each GPU will take 4 data entries per
|
64 |
+
batch).
|
65 |
+
Default value is 32.
|
66 |
+
device_batch_size (int): The number of data entries that can
|
67 |
+
fit on the specified device at a time.
|
68 |
+
Example: We want to average gradients over 32 data
|
69 |
+
entries. To do this we just set `batch_size=32`.
|
70 |
+
However, our device can only handle a batch of
|
71 |
+
4 at a time before running out of memory. We
|
72 |
+
therefor set `device_batch_size=4`. With a
|
73 |
+
single device (no distributed training), each
|
74 |
+
batch is split into 32 / 4 parts and gradients
|
75 |
+
are averaged over all these parts.
|
76 |
+
Default value is 4.
|
77 |
+
label_size (int, optional): Number of possible class labels.
|
78 |
+
This is required for conditioning the GAN with labels.
|
79 |
+
If not specified it is assumed that no labels are used.
|
80 |
+
data_workers (int): The number of spawned processes that
|
81 |
+
handle data loading. Default value is 4.
|
82 |
+
G_loss (str, callable): The loss function to use
|
83 |
+
for the generator. If string, it can be one of the
|
84 |
+
following: 'logistic', 'logistic_ns' or 'wgan'.
|
85 |
+
If not a string, the callable has to follow
|
86 |
+
the format of functions found in `stylegan2.loss_fns`.
|
87 |
+
Default value is 'logistic_ns' (non-saturating logistic).
|
88 |
+
D_loss (str, callable): The loss function to use
|
89 |
+
for the discriminator. If string, it can be one of the
|
90 |
+
following: 'logistic' or 'wgan'.
|
91 |
+
If not a string, the same restrictions as for `G_loss` apply.
|
92 |
+
Default value is 'logistic'.
|
93 |
+
G_reg (str, callable, None): The regularizer function to use
|
94 |
+
for the generator. If string, it can only be 'pathreg'
|
95 |
+
(pathlength regularization). A weight for the regularizer
|
96 |
+
can be passed after the string name like the following:
|
97 |
+
G_reg='pathreg:5'
|
98 |
+
This will assign a weight of 5 to the regularization loss.
|
99 |
+
If set to None, no generator regularization is performed.
|
100 |
+
Default value is 'pathreg:2'.
|
101 |
+
G_reg_interval (int): The interval at which to regularize the
|
102 |
+
generator. If set to 0, the regularization and loss gradients
|
103 |
+
are combined in a single optimization step every iteration.
|
104 |
+
If set to 1, the gradients for the regularization and loss
|
105 |
+
are used separately for two optimization steps. Any value
|
106 |
+
higher than 1 indicates that regularization should only
|
107 |
+
be performed at this interval (lazy regularization).
|
108 |
+
Default value is 4.
|
109 |
+
G_opt_class (str, class): The optimizer class for the generator.
|
110 |
+
Default value is 'Adam'.
|
111 |
+
G_opt_kwargs (dict): Keyword arguments for the generator optimizer
|
112 |
+
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
|
113 |
+
G_reg_batch_size (int): Same as `batch_size` but only for
|
114 |
+
the regularization loss of the generator. Default value
|
115 |
+
is 16.
|
116 |
+
G_reg_device_batch_size (int): Same as `device_batch_size`
|
117 |
+
but only for the regularization loss of the generator.
|
118 |
+
Default value is 2.
|
119 |
+
D_reg (str, callable, None): The regularizer function to use
|
120 |
+
for the discriminator. If string, the following values
|
121 |
+
can be used: 'r1', 'r2', 'gp'. See doc for `G_reg` for
|
122 |
+
rest of info on regularizer format.
|
123 |
+
Default value is 'r1:10'.
|
124 |
+
D_reg_interval (int): Same as `G_reg_interval` but for the
|
125 |
+
discriminator. Default value is 16.
|
126 |
+
D_opt_class (str, class): The optimizer class for the discriminator.
|
127 |
+
Default value is 'Adam'.
|
128 |
+
D_opt_kwargs (dict): Keyword arguments for the discriminator optimizer
|
129 |
+
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
|
130 |
+
style_mix_prob (float): The probability of passing 2 latents instead of 1
|
131 |
+
to the generator during training. Default value is 0.9.
|
132 |
+
G_iter (int): Number of generator iterations for every full training
|
133 |
+
iteration. Default value is 1.
|
134 |
+
D_iter (int): Number of discriminator iterations for every full training
|
135 |
+
iteration. Default value is 1.
|
136 |
+
pl_avg (float, torch.Tensor): The average pathlength starting value for
|
137 |
+
pathlength regularization of the generator. Default value is 0.
|
138 |
+
tensorboard_log_dir (str, optional): A path to a directory to log training values
|
139 |
+
in for tensorboard. Only used without distributed training or when
|
140 |
+
distributed training is enabled and the rank of this trainer is 0.
|
141 |
+
checkpoint_dir (str, optional): A path to a directory to save training
|
142 |
+
checkpoints to. If not specified, not checkpoints are automatically
|
143 |
+
saved during training.
|
144 |
+
checkpoint_interval (int): The interval at which to save training checkpoints.
|
145 |
+
Default value is 10000.
|
146 |
+
seen (int): The number of previously trained iterations. Used for logging.
|
147 |
+
Default value is 0.
|
148 |
+
half (bool): Use mixed precision training. Default value is False.
|
149 |
+
rank (int, optional): If set, use distributed training. Expects that
|
150 |
+
this object has been constructed with the same arguments except
|
151 |
+
for `rank` in different processes.
|
152 |
+
world_size (int, optional): If using distributed training, this specifies
|
153 |
+
the number of nodes in the training.
|
154 |
+
master_addr (str): The master address for distributed training.
|
155 |
+
Default value is '127.0.0.1'.
|
156 |
+
master_port (str): The master port for distributed training.
|
157 |
+
Default value is '23456'.
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(self,
|
161 |
+
G,
|
162 |
+
D,
|
163 |
+
latent_size,
|
164 |
+
dataset,
|
165 |
+
device,
|
166 |
+
Gs=None,
|
167 |
+
Gs_beta=0.5 ** (32 / 10000),
|
168 |
+
Gs_device=None,
|
169 |
+
batch_size=32,
|
170 |
+
device_batch_size=4,
|
171 |
+
label_size=0,
|
172 |
+
data_workers=4,
|
173 |
+
G_loss='logistic_ns',
|
174 |
+
D_loss='logistic',
|
175 |
+
G_reg='pathreg:2',
|
176 |
+
G_reg_interval=4,
|
177 |
+
G_opt_class='Adam',
|
178 |
+
G_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
|
179 |
+
G_reg_batch_size=None,
|
180 |
+
G_reg_device_batch_size=None,
|
181 |
+
D_reg='r1:10',
|
182 |
+
D_reg_interval=16,
|
183 |
+
D_opt_class='Adam',
|
184 |
+
D_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
|
185 |
+
style_mix_prob=0.9,
|
186 |
+
G_iter=1,
|
187 |
+
D_iter=1,
|
188 |
+
pl_avg=0.,
|
189 |
+
tensorboard_log_dir=None,
|
190 |
+
checkpoint_dir=None,
|
191 |
+
checkpoint_interval=10000,
|
192 |
+
seen=0,
|
193 |
+
half=False,
|
194 |
+
rank=None,
|
195 |
+
world_size=None,
|
196 |
+
master_addr='127.0.0.1',
|
197 |
+
master_port='23456'):
|
198 |
+
assert not isinstance(G, nn.parallel.DistributedDataParallel) and \
|
199 |
+
not isinstance(D, nn.parallel.DistributedDataParallel), \
|
200 |
+
'Encountered a model wrapped in `DistributedDataParallel`. ' + \
|
201 |
+
'Distributed parallelism is handled by this class and can ' + \
|
202 |
+
'not be initialized before.'
|
203 |
+
# We store the training settings in a dict that can be saved as a json file.
|
204 |
+
kwargs = locals()
|
205 |
+
# First we remove the arguments that can not be turned into json.
|
206 |
+
kwargs.pop('self')
|
207 |
+
kwargs.pop('G')
|
208 |
+
kwargs.pop('D')
|
209 |
+
kwargs.pop('Gs')
|
210 |
+
kwargs.pop('dataset')
|
211 |
+
# Some arguments may have to be turned into strings to be compatible with json.
|
212 |
+
kwargs.update(pl_avg=float(pl_avg))
|
213 |
+
if isinstance(device, torch.device):
|
214 |
+
kwargs.update(device=str(device))
|
215 |
+
if isinstance(Gs_device, torch.device):
|
216 |
+
kwargs.update(Gs_device=str(Gs_device))
|
217 |
+
self.kwargs = kwargs
|
218 |
+
|
219 |
+
if device or device == 0:
|
220 |
+
if isinstance(device, (tuple, list)):
|
221 |
+
self.device = torch.device(device[0])
|
222 |
+
else:
|
223 |
+
self.device = torch.device(device)
|
224 |
+
else:
|
225 |
+
self.device = torch.device('cpu')
|
226 |
+
if self.device.index is not None:
|
227 |
+
torch.cuda.set_device(self.device.index)
|
228 |
+
else:
|
229 |
+
assert not half, 'Mixed precision training only available ' + \
|
230 |
+
'for CUDA devices.'
|
231 |
+
# Set up the models
|
232 |
+
self.G = G.train().to(self.device)
|
233 |
+
self.D = D.train().to(self.device)
|
234 |
+
if isinstance(device, (tuple, list)) and len(device) > 1:
|
235 |
+
assert all(isinstance(dev, int) for dev in device), \
|
236 |
+
'Multiple devices have to be specified as a list ' + \
|
237 |
+
'or tuple of integers corresponding to device indices.'
|
238 |
+
# TODO: Look into bug with torch.autograd.grad and nn.DataParallel
|
239 |
+
# In the meantime, just prohibit using them together.
|
240 |
+
assert G_reg is None and D_reg is None, 'Regularization ' + \
|
241 |
+
'currently not supported for multi-gpu training in single process. ' + \
|
242 |
+
'Please use distributed training with one device per process instead.'
|
243 |
+
|
244 |
+
device_batch_size *= len(device)
|
245 |
+
def to_data_parallel(model):
|
246 |
+
if not isinstance(model, nn.DataParallel):
|
247 |
+
return nn.DataParallel(model, device_ids=device)
|
248 |
+
return model
|
249 |
+
self.G = to_data_parallel(self.G)
|
250 |
+
self.D = to_data_parallel(self.D)
|
251 |
+
|
252 |
+
# Default generator reg batch size is the global batch size
|
253 |
+
# unless it has been specified otherwise.
|
254 |
+
G_reg_batch_size = G_reg_batch_size or batch_size
|
255 |
+
G_reg_device_batch_size = G_reg_device_batch_size or device_batch_size
|
256 |
+
|
257 |
+
# Set up distributed training
|
258 |
+
rank = os.environ.get('RANK', rank)
|
259 |
+
if rank is not None:
|
260 |
+
rank = int(rank)
|
261 |
+
addr = os.environ.get('MASTER_ADDR', master_addr)
|
262 |
+
port = os.environ.get('MASTER_PORT', master_port)
|
263 |
+
world_size = os.environ.get('WORLD_SIZE', world_size)
|
264 |
+
assert world_size is not None, 'Distributed training ' + \
|
265 |
+
'requires specifying world size.'
|
266 |
+
world_size = int(world_size)
|
267 |
+
assert self.device.index is not None, \
|
268 |
+
'Distributed training is only supported for CUDA.'
|
269 |
+
assert batch_size % world_size == 0, 'Batch size has to be ' + \
|
270 |
+
'evenly divisible by world size.'
|
271 |
+
assert G_reg_batch_size % world_size == 0, 'G reg batch size has to be ' + \
|
272 |
+
'evenly divisible by world size.'
|
273 |
+
batch_size = batch_size // world_size
|
274 |
+
G_reg_batch_size = G_reg_batch_size // world_size
|
275 |
+
init_method = 'tcp://{}:{}'.format(addr, port)
|
276 |
+
torch.distributed.init_process_group(
|
277 |
+
backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
|
278 |
+
else:
|
279 |
+
world_size = 1
|
280 |
+
self.rank = rank
|
281 |
+
self.world_size = world_size
|
282 |
+
|
283 |
+
# Set up variable to keep track of moving average of path lengths
|
284 |
+
self.pl_avg = torch.tensor(
|
285 |
+
pl_avg, dtype=torch.float16 if half else torch.float32, device=self.device)
|
286 |
+
|
287 |
+
# Broadcast parameters from rank 0 if running distributed
|
288 |
+
self._sync_distributed(G=self.G, D=self.D, broadcast_weights=True)
|
289 |
+
|
290 |
+
# Set up moving average of generator
|
291 |
+
# Only for non-distributed training or
|
292 |
+
# if rank is 0
|
293 |
+
if not self.rank:
|
294 |
+
# Values for `rank`: None -> not distributed, 0 -> distributed and 'main' node
|
295 |
+
self.Gs = Gs
|
296 |
+
if not isinstance(Gs, utils.MovingAverageModule):
|
297 |
+
self.Gs = utils.MovingAverageModule(
|
298 |
+
from_module=self.G,
|
299 |
+
to_module=Gs,
|
300 |
+
param_beta=Gs_beta,
|
301 |
+
device=self.device if Gs_device is None else Gs_device
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
self.Gs = None
|
305 |
+
|
306 |
+
# Set up loss and regularization functions
|
307 |
+
self.G_loss = get_loss_fn('G', G_loss)
|
308 |
+
self.D_loss = get_loss_fn('D', D_loss)
|
309 |
+
self.G_reg = get_reg_fn('G', G_reg, pl_avg=self.pl_avg)
|
310 |
+
self.D_reg = get_reg_fn('D', D_reg)
|
311 |
+
self.G_reg_interval = G_reg_interval
|
312 |
+
self.D_reg_interval = D_reg_interval
|
313 |
+
self.G_iter = G_iter
|
314 |
+
self.D_iter = D_iter
|
315 |
+
|
316 |
+
# Set up optimizers (adjust hyperparameters if lazy regularization is active)
|
317 |
+
self.G_opt = build_opt(self.G, G_opt_class, G_opt_kwargs, self.G_reg, self.G_reg_interval)
|
318 |
+
self.D_opt = build_opt(self.D, D_opt_class, D_opt_kwargs, self.D_reg, self.D_reg_interval)
|
319 |
+
|
320 |
+
# Set up mixed precision training
|
321 |
+
if half:
|
322 |
+
assert 'apex' in sys.modules, 'Can not run mixed precision ' + \
|
323 |
+
'training (`half=True`) without the apex module.'
|
324 |
+
(self.G, self.D), (self.G_opt, self.D_opt) = amp.initialize(
|
325 |
+
[self.G, self.D], [self.G_opt, self.D_opt], opt_level='O1')
|
326 |
+
self.half = half
|
327 |
+
|
328 |
+
# Data
|
329 |
+
sampler = None
|
330 |
+
if self.rank is not None:
|
331 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
|
332 |
+
self.dataloader = torch.utils.data.DataLoader(
|
333 |
+
dataset,
|
334 |
+
batch_size=device_batch_size,
|
335 |
+
num_workers=data_workers,
|
336 |
+
shuffle=sampler is None,
|
337 |
+
pin_memory=self.device.index is not None,
|
338 |
+
drop_last=True,
|
339 |
+
sampler=sampler
|
340 |
+
)
|
341 |
+
self.dataloader_iter = None
|
342 |
+
self.prior_generator = utils.PriorGenerator(
|
343 |
+
latent_size=latent_size,
|
344 |
+
label_size=label_size,
|
345 |
+
batch_size=device_batch_size,
|
346 |
+
device=self.device
|
347 |
+
)
|
348 |
+
assert batch_size % device_batch_size == 0, \
|
349 |
+
'Batch size has to be evenly divisible by the product of ' + \
|
350 |
+
'device batch size and world size.'
|
351 |
+
self.subdivisions = batch_size // device_batch_size
|
352 |
+
assert G_reg_batch_size % G_reg_device_batch_size == 0, \
|
353 |
+
'G reg batch size has to be evenly divisible by the product of ' + \
|
354 |
+
'G reg device batch size and world size.'
|
355 |
+
self.G_reg_subdivisions = G_reg_batch_size // G_reg_device_batch_size
|
356 |
+
self.G_reg_device_batch_size = G_reg_device_batch_size
|
357 |
+
|
358 |
+
self.tb_writer = None
|
359 |
+
if tensorboard_log_dir and not self.rank:
|
360 |
+
self.tb_writer = torch.utils.tensorboard.SummaryWriter(tensorboard_log_dir)
|
361 |
+
|
362 |
+
self.label_size = label_size
|
363 |
+
self.style_mix_prob = style_mix_prob
|
364 |
+
self.checkpoint_dir = checkpoint_dir
|
365 |
+
self.checkpoint_interval = checkpoint_interval
|
366 |
+
self.seen = seen
|
367 |
+
self.metrics = {}
|
368 |
+
self.callbacks = []
|
369 |
+
|
370 |
+
def _get_batch(self):
|
371 |
+
"""
|
372 |
+
Fetch a batch and its labels. If no labels are
|
373 |
+
available the returned labels will be `None`.
|
374 |
+
Returns:
|
375 |
+
data
|
376 |
+
labels
|
377 |
+
"""
|
378 |
+
if self.dataloader_iter is None:
|
379 |
+
self.dataloader_iter = iter(self.dataloader)
|
380 |
+
try:
|
381 |
+
batch = next(self.dataloader_iter)
|
382 |
+
except StopIteration:
|
383 |
+
self.dataloader_iter = None
|
384 |
+
return self._get_batch()
|
385 |
+
if isinstance(batch, (tuple, list)):
|
386 |
+
if len(batch) > 1:
|
387 |
+
data, label = batch[:2]
|
388 |
+
else:
|
389 |
+
data, label = batch[0], None
|
390 |
+
else:
|
391 |
+
data, label = batch, None
|
392 |
+
if not self.label_size:
|
393 |
+
label = None
|
394 |
+
if torch.is_tensor(data):
|
395 |
+
data = data.to(self.device)
|
396 |
+
if torch.is_tensor(label):
|
397 |
+
label = label.to(self.device)
|
398 |
+
return data, label
|
399 |
+
|
400 |
+
def _sync_distributed(self, G=None, D=None, broadcast_weights=False):
|
401 |
+
"""
|
402 |
+
Sync the gradients (and optionally the weights) of
|
403 |
+
the specified networks over the distributed training
|
404 |
+
nodes. Varying buffers are broadcasted from rank 0.
|
405 |
+
If distributed training is not enabled, no action
|
406 |
+
is taken and this is a no-op function.
|
407 |
+
Arguments:
|
408 |
+
G (Generator, optional)
|
409 |
+
D (Discriminator, optional)
|
410 |
+
broadcast_weights (bool): Broadcast the weights from
|
411 |
+
node of rank 0 to all other ranks. Default
|
412 |
+
value is False.
|
413 |
+
"""
|
414 |
+
if self.rank is None:
|
415 |
+
return
|
416 |
+
for net in [G, D]:
|
417 |
+
if net is None:
|
418 |
+
continue
|
419 |
+
for p in net.parameters():
|
420 |
+
if p.grad is not None:
|
421 |
+
torch.distributed.all_reduce(p.grad, async_op=True)
|
422 |
+
if broadcast_weights:
|
423 |
+
torch.distributed.broadcast(p.data, src=0, async_op=True)
|
424 |
+
if G is not None:
|
425 |
+
if G.dlatent_avg is not None:
|
426 |
+
torch.distributed.broadcast(G.dlatent_avg, src=0, async_op=True)
|
427 |
+
if self.pl_avg is not None:
|
428 |
+
torch.distributed.broadcast(self.pl_avg, src=0, async_op=True)
|
429 |
+
if G is not None or D is not None:
|
430 |
+
torch.distributed.barrier(async_op=False)
|
431 |
+
|
432 |
+
def _backward(self, loss, opt, mul=1, subdivisions=None):
|
433 |
+
"""
|
434 |
+
Reduce loss by world size and subdivisions before
|
435 |
+
calling backward for the loss. Loss scaling is
|
436 |
+
performed when mixed precision training is
|
437 |
+
enabled.
|
438 |
+
Arguments:
|
439 |
+
loss (torch.Tensor)
|
440 |
+
opt (torch.optim.Optimizer)
|
441 |
+
mul (float): Loss weight. Default value is 1.
|
442 |
+
subdivisions (int, optional): The number of
|
443 |
+
subdivisions to divide by. If this is
|
444 |
+
not specified, the subdvisions from
|
445 |
+
the specified batch and device size
|
446 |
+
at construction is used.
|
447 |
+
Returns:
|
448 |
+
loss (torch.Tensor): The loss scaled by mul
|
449 |
+
and subdivisions but not by world size.
|
450 |
+
"""
|
451 |
+
if loss is None:
|
452 |
+
return 0
|
453 |
+
mul /= subdivisions or self.subdivisions
|
454 |
+
mul /= self.world_size or 1
|
455 |
+
if mul != 1:
|
456 |
+
loss *= mul
|
457 |
+
if self.half:
|
458 |
+
with amp.scale_loss(loss, opt) as scaled_loss:
|
459 |
+
scaled_loss.backward()
|
460 |
+
else:
|
461 |
+
loss.backward()
|
462 |
+
# Get the scalar only
|
463 |
+
return loss.item() * (self.world_size or 1)
|
464 |
+
|
465 |
+
def train(self, iterations, callbacks=None, verbose=True):
|
466 |
+
"""
|
467 |
+
Train the models for a specific number of iterations.
|
468 |
+
Arguments:
|
469 |
+
iterations (int): Number of iterations to train for.
|
470 |
+
callbacks (callable, list, optional): One
|
471 |
+
or more callbacks to call at the end of each
|
472 |
+
iteration. The function is given the total
|
473 |
+
number of batches that have been processed since
|
474 |
+
this trainer object was initialized (not reset
|
475 |
+
when loading a saved checkpoint).
|
476 |
+
Default value is None (unused).
|
477 |
+
verbose (bool): Write progress to stdout.
|
478 |
+
Default value is True.
|
479 |
+
"""
|
480 |
+
evaluated_metrics = {}
|
481 |
+
if self.rank:
|
482 |
+
verbose = False
|
483 |
+
if verbose:
|
484 |
+
progress = utils.ProgressWriter(iterations)
|
485 |
+
value_tracker = utils.ValueTracker()
|
486 |
+
for _ in range(iterations):
|
487 |
+
# Figure out if G and/or D should be
|
488 |
+
# regularized this iteration
|
489 |
+
G_reg = self.G_reg is not None
|
490 |
+
if self.G_reg_interval and G_reg:
|
491 |
+
G_reg = self.seen % self.G_reg_interval == 0
|
492 |
+
D_reg = self.D_reg is not None
|
493 |
+
if self.D_reg_interval and D_reg:
|
494 |
+
D_reg = self.seen % self.D_reg_interval == 0
|
495 |
+
|
496 |
+
# -----| Train G |----- #
|
497 |
+
|
498 |
+
# Disable gradients for D while training G
|
499 |
+
self.D.requires_grad_(False)
|
500 |
+
|
501 |
+
for _ in range(self.G_iter):
|
502 |
+
self.G_opt.zero_grad()
|
503 |
+
|
504 |
+
G_loss = 0
|
505 |
+
for i in range(self.subdivisions):
|
506 |
+
latents, latent_labels = self.prior_generator(
|
507 |
+
multi_latent_prob=self.style_mix_prob)
|
508 |
+
loss, _ = self.G_loss(
|
509 |
+
G=self.G,
|
510 |
+
D=self.D,
|
511 |
+
latents=latents,
|
512 |
+
latent_labels=latent_labels
|
513 |
+
)
|
514 |
+
G_loss += self._backward(loss, self.G_opt)
|
515 |
+
|
516 |
+
if G_reg:
|
517 |
+
if self.G_reg_interval:
|
518 |
+
# For lazy regularization, even if the interval
|
519 |
+
# is set to 1, the optimization step is taken
|
520 |
+
# before the gradients of the regularization are gathered.
|
521 |
+
self._sync_distributed(G=self.G)
|
522 |
+
self.G_opt.step()
|
523 |
+
self.G_opt.zero_grad()
|
524 |
+
G_reg_loss = 0
|
525 |
+
# Pathreg is expensive to compute which
|
526 |
+
# is why G regularization has its own settings
|
527 |
+
# for subdivisions and batch size.
|
528 |
+
for i in range(self.G_reg_subdivisions):
|
529 |
+
latents, latent_labels = self.prior_generator(
|
530 |
+
batch_size=self.G_reg_device_batch_size,
|
531 |
+
multi_latent_prob=self.style_mix_prob
|
532 |
+
)
|
533 |
+
_, reg_loss = self.G_reg(
|
534 |
+
G=self.G,
|
535 |
+
latents=latents,
|
536 |
+
latent_labels=latent_labels
|
537 |
+
)
|
538 |
+
G_reg_loss += self._backward(
|
539 |
+
reg_loss,
|
540 |
+
self.G_opt, mul=self.G_reg_interval or 1,
|
541 |
+
subdivisions=self.G_reg_subdivisions
|
542 |
+
)
|
543 |
+
self._sync_distributed(G=self.G)
|
544 |
+
self.G_opt.step()
|
545 |
+
# Update moving average of weights after
|
546 |
+
# each G training subiteration
|
547 |
+
if self.Gs is not None:
|
548 |
+
self.Gs.update()
|
549 |
+
|
550 |
+
# Re-enable gradients for D
|
551 |
+
self.D.requires_grad_(True)
|
552 |
+
|
553 |
+
# -----| Train D |----- #
|
554 |
+
|
555 |
+
# Disable gradients for G while training D
|
556 |
+
self.G.requires_grad_(False)
|
557 |
+
|
558 |
+
for _ in range(self.D_iter):
|
559 |
+
self.D_opt.zero_grad()
|
560 |
+
|
561 |
+
D_loss = 0
|
562 |
+
for i in range(self.subdivisions):
|
563 |
+
latents, latent_labels = self.prior_generator(
|
564 |
+
multi_latent_prob=self.style_mix_prob)
|
565 |
+
reals, real_labels = self._get_batch()
|
566 |
+
loss, _ = self.D_loss(
|
567 |
+
G=self.G,
|
568 |
+
D=self.D,
|
569 |
+
latents=latents,
|
570 |
+
latent_labels=latent_labels,
|
571 |
+
reals=reals,
|
572 |
+
real_labels=real_labels
|
573 |
+
)
|
574 |
+
D_loss += self._backward(loss, self.D_opt)
|
575 |
+
|
576 |
+
if D_reg:
|
577 |
+
if self.D_reg_interval:
|
578 |
+
# For lazy regularization, even if the interval
|
579 |
+
# is set to 1, the optimization step is taken
|
580 |
+
# before the gradients of the regularization are gathered.
|
581 |
+
self._sync_distributed(D=self.D)
|
582 |
+
self.D_opt.step()
|
583 |
+
self.D_opt.zero_grad()
|
584 |
+
D_reg_loss = 0
|
585 |
+
for i in range(self.subdivisions):
|
586 |
+
latents, latent_labels = self.prior_generator(
|
587 |
+
multi_latent_prob=self.style_mix_prob)
|
588 |
+
reals, real_labels = self._get_batch()
|
589 |
+
_, reg_loss = self.D_reg(
|
590 |
+
G=self.G,
|
591 |
+
D=self.D,
|
592 |
+
latents=latents,
|
593 |
+
latent_labels=latent_labels,
|
594 |
+
reals=reals,
|
595 |
+
real_labels=real_labels
|
596 |
+
)
|
597 |
+
D_reg_loss += self._backward(
|
598 |
+
reg_loss, self.D_opt, mul=self.D_reg_interval or 1)
|
599 |
+
self._sync_distributed(D=self.D)
|
600 |
+
self.D_opt.step()
|
601 |
+
|
602 |
+
# Re-enable grads for G
|
603 |
+
self.G.requires_grad_(True)
|
604 |
+
|
605 |
+
if self.tb_writer is not None or verbose:
|
606 |
+
# In case verbose is true and tensorboard logging enabled
|
607 |
+
# we calculate grad norm here to only do it once as well
|
608 |
+
# as making sure we do it before any metrics that may
|
609 |
+
# possibly zero the grads.
|
610 |
+
G_grad_norm = utils.get_grad_norm_from_optimizer(self.G_opt)
|
611 |
+
D_grad_norm = utils.get_grad_norm_from_optimizer(self.D_opt)
|
612 |
+
|
613 |
+
for name, metric in self.metrics.items():
|
614 |
+
if not metric['interval'] or self.seen % metric['interval'] == 0:
|
615 |
+
evaluated_metrics[name] = metric['eval_fn']()
|
616 |
+
|
617 |
+
# Printing and logging
|
618 |
+
|
619 |
+
# Tensorboard logging
|
620 |
+
if self.tb_writer is not None:
|
621 |
+
self.tb_writer.add_scalar('Loss/G_loss', G_loss, self.seen)
|
622 |
+
if G_reg:
|
623 |
+
self.tb_writer.add_scalar('Loss/G_reg', G_reg_loss, self.seen)
|
624 |
+
self.tb_writer.add_scalar('Grad_norm/G_reg', G_grad_norm, self.seen)
|
625 |
+
self.tb_writer.add_scalar('Params/pl_avg', self.pl_avg, self.seen)
|
626 |
+
else:
|
627 |
+
self.tb_writer.add_scalar('Grad_norm/G_loss', G_grad_norm, self.seen)
|
628 |
+
self.tb_writer.add_scalar('Loss/D_loss', D_loss, self.seen)
|
629 |
+
if D_reg:
|
630 |
+
self.tb_writer.add_scalar('Loss/D_reg', D_reg_loss, self.seen)
|
631 |
+
self.tb_writer.add_scalar('Grad_norm/D_reg', D_grad_norm, self.seen)
|
632 |
+
else:
|
633 |
+
self.tb_writer.add_scalar('Grad_norm/D_loss', D_grad_norm, self.seen)
|
634 |
+
for name, value in evaluated_metrics.items():
|
635 |
+
self.tb_writer.add_scalar('Metrics/{}'.format(name), value, self.seen)
|
636 |
+
|
637 |
+
# Printing
|
638 |
+
if verbose:
|
639 |
+
value_tracker.add('seen', self.seen + 1, beta=0)
|
640 |
+
value_tracker.add('G_lr', self.G_opt.param_groups[0]['lr'], beta=0)
|
641 |
+
value_tracker.add('G_loss', G_loss)
|
642 |
+
if G_reg:
|
643 |
+
value_tracker.add('G_reg', G_reg_loss)
|
644 |
+
value_tracker.add('G_reg_grad_norm', G_grad_norm)
|
645 |
+
value_tracker.add('pl_avg', self.pl_avg, beta=0)
|
646 |
+
else:
|
647 |
+
value_tracker.add('G_loss_grad_norm', G_grad_norm)
|
648 |
+
value_tracker.add('D_lr', self.D_opt.param_groups[0]['lr'], beta=0)
|
649 |
+
value_tracker.add('D_loss', D_loss)
|
650 |
+
if D_reg:
|
651 |
+
value_tracker.add('D_reg', D_reg_loss)
|
652 |
+
value_tracker.add('D_reg_grad_norm', D_grad_norm)
|
653 |
+
else:
|
654 |
+
value_tracker.add('D_loss_grad_norm', D_grad_norm)
|
655 |
+
for name, value in evaluated_metrics.items():
|
656 |
+
value_tracker.add(name, value, beta=0)
|
657 |
+
progress.write(str(value_tracker))
|
658 |
+
|
659 |
+
# Callback
|
660 |
+
for callback in utils.to_list(callbacks) + self.callbacks:
|
661 |
+
callback(self.seen)
|
662 |
+
|
663 |
+
self.seen += 1
|
664 |
+
|
665 |
+
# clear cache
|
666 |
+
torch.cuda.empty_cache()
|
667 |
+
# Handle checkpointing
|
668 |
+
if not self.rank and self.checkpoint_dir and self.checkpoint_interval:
|
669 |
+
if self.seen % self.checkpoint_interval == 0:
|
670 |
+
checkpoint_path = os.path.join(
|
671 |
+
self.checkpoint_dir,
|
672 |
+
'{}_{}'.format(self.seen, time.strftime('%Y-%m-%d_%H-%M-%S'))
|
673 |
+
)
|
674 |
+
self.save_checkpoint(checkpoint_path)
|
675 |
+
|
676 |
+
if verbose:
|
677 |
+
progress.close()
|
678 |
+
|
679 |
+
def register_metric(self, name, eval_fn, interval):
|
680 |
+
"""
|
681 |
+
Add a metric. This will be evaluated every `interval`
|
682 |
+
training iteration. Used by tensorboard and progress
|
683 |
+
updates written to stdout while training.
|
684 |
+
Arguments:
|
685 |
+
name (str): A name for the metric. If a metric with
|
686 |
+
this name already exists it will be overwritten.
|
687 |
+
eval_fn (callable): A function that evaluates the metric
|
688 |
+
and returns a python number.
|
689 |
+
interval (int): The interval to evaluate at.
|
690 |
+
"""
|
691 |
+
self.metrics[name] = {'eval_fn': eval_fn, 'interval': interval}
|
692 |
+
|
693 |
+
def remove_metric(self, name):
|
694 |
+
"""
|
695 |
+
Remove a metric that was previously registered.
|
696 |
+
Arguments:
|
697 |
+
name (str): Name of the metric.
|
698 |
+
"""
|
699 |
+
if name in self.metrics:
|
700 |
+
del self.metrics[name]
|
701 |
+
else:
|
702 |
+
warnings.warn(
|
703 |
+
'Attempting to remove metric {} '.format(name) + \
|
704 |
+
'which does not exist.'
|
705 |
+
)
|
706 |
+
|
707 |
+
def generate_images(self,
|
708 |
+
num_images,
|
709 |
+
seed=None,
|
710 |
+
truncation_psi=None,
|
711 |
+
truncation_cutoff=None,
|
712 |
+
label=None,
|
713 |
+
pixel_min=-1,
|
714 |
+
pixel_max=1):
|
715 |
+
"""
|
716 |
+
Generate some images with the generator and transform them into PIL
|
717 |
+
images and return them as a list.
|
718 |
+
Arguments:
|
719 |
+
num_images (int): Number of images to generate.
|
720 |
+
seed (int, optional): The seed for the random generation
|
721 |
+
of input latent values.
|
722 |
+
truncation_psi (float): See stylegan2.models.Generator.set_truncation()
|
723 |
+
Default value is None.
|
724 |
+
truncation_cutoff (int): See stylegan2.models.Generator.set_truncation()
|
725 |
+
label (int, list, optional): Label to condition all generated images with
|
726 |
+
or multiple labels, one for each generated image.
|
727 |
+
pixel_min (float): The min value in the pixel range of the generator.
|
728 |
+
Default value is -1.
|
729 |
+
pixel_max (float): The max value in the pixel range of the generator.
|
730 |
+
Default value is 1.
|
731 |
+
Returns:
|
732 |
+
images (list): List of PIL images.
|
733 |
+
"""
|
734 |
+
if seed is None:
|
735 |
+
seed = int(10000 * time.time())
|
736 |
+
latents, latent_labels = self.prior_generator(num_images, seed=seed)
|
737 |
+
if label:
|
738 |
+
assert latent_labels is not None, 'Can not specify label when no labels ' + \
|
739 |
+
'are used by this model.'
|
740 |
+
label = utils.to_list(label)
|
741 |
+
assert all(isinstance(l, int) for l in label), '`label` can only consist of ' + \
|
742 |
+
'one or more python integers.'
|
743 |
+
assert len(label) == 1 or len(label) == num_images, '`label` can either ' + \
|
744 |
+
'specify one label to use for all images or a list of labels of the ' + \
|
745 |
+
'same length as number of images. Received {} labels '.format(len(label)) + \
|
746 |
+
'but {} images are to be generated.'.format(num_images)
|
747 |
+
if len(label) == 1:
|
748 |
+
latent_labels.fill_(label[0])
|
749 |
+
else:
|
750 |
+
latent_labels = torch.tensor(label).to(latent_labels)
|
751 |
+
self.Gs.set_truncation(
|
752 |
+
truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
|
753 |
+
with torch.no_grad():
|
754 |
+
generated = self.Gs(latents=latents, labels=latent_labels)
|
755 |
+
assert generated.dim() - 2 == 2, 'Can only generate images when using a ' + \
|
756 |
+
'network built for 2-dimensional data.'
|
757 |
+
assert generated.dim() == 4, 'Only generators that produce 2d data ' + \
|
758 |
+
'can be used to generate images.'
|
759 |
+
return utils.tensor_to_PIL(generated, pixel_min=pixel_min, pixel_max=pixel_max)
|
760 |
+
|
761 |
+
def log_images_tensorboard(self, images, name, resize=256):
|
762 |
+
"""
|
763 |
+
Log a list of images to tensorboard by first turning
|
764 |
+
them into a grid. Can not be performed if rank > 0
|
765 |
+
or tensorboard_log_dir was not given at construction.
|
766 |
+
Arguments:
|
767 |
+
images (list): List of PIL images.
|
768 |
+
name (str): The name to log images for.
|
769 |
+
resize (int, tuple): The height and width to use for
|
770 |
+
each image in the grid. Default value is 256.
|
771 |
+
"""
|
772 |
+
assert self.tb_writer is not None, \
|
773 |
+
'No tensorboard log dir was specified ' + \
|
774 |
+
'when constructing this object.'
|
775 |
+
image = utils.stack_images_PIL(images, individual_img_size=resize)
|
776 |
+
image = torchvision.transforms.ToTensor()(image)
|
777 |
+
self.tb_writer.add_image(name, image, self.seen)
|
778 |
+
|
779 |
+
def add_tensorboard_image_logging(self,
|
780 |
+
name,
|
781 |
+
interval,
|
782 |
+
num_images,
|
783 |
+
resize=256,
|
784 |
+
seed=None,
|
785 |
+
truncation_psi=None,
|
786 |
+
truncation_cutoff=None,
|
787 |
+
label=None,
|
788 |
+
pixel_min=-1,
|
789 |
+
pixel_max=1):
|
790 |
+
"""
|
791 |
+
Set up tensorboard logging of generated images to be performed
|
792 |
+
at a certain training interval. If distributed training is set up
|
793 |
+
and this object does not have the rank 0, no logging will be performed
|
794 |
+
by this object.
|
795 |
+
All arguments except the ones mentioned below have their description
|
796 |
+
in the docstring of `generate_images()` and `log_images_tensorboard()`.
|
797 |
+
Arguments:
|
798 |
+
interval (int): The interval at which to log generated images.
|
799 |
+
"""
|
800 |
+
if self.rank:
|
801 |
+
return
|
802 |
+
def callback(seen):
|
803 |
+
if seen % interval == 0:
|
804 |
+
images = self.generate_images(
|
805 |
+
num_images=num_images,
|
806 |
+
seed=seed,
|
807 |
+
truncation_psi=truncation_psi,
|
808 |
+
truncation_cutoff=truncation_cutoff,
|
809 |
+
label=label,
|
810 |
+
pixel_min=pixel_min,
|
811 |
+
pixel_max=pixel_max
|
812 |
+
)
|
813 |
+
self.log_images_tensorboard(
|
814 |
+
images=images,
|
815 |
+
name=name,
|
816 |
+
resize=resize
|
817 |
+
)
|
818 |
+
self.callbacks.append(callback)
|
819 |
+
|
820 |
+
def save_checkpoint(self, dir_path):
|
821 |
+
"""
|
822 |
+
Save the current state of this trainer as a checkpoint.
|
823 |
+
NOTE: The dataset can not be serialized and saved so this
|
824 |
+
has to be reconstructed and given when loading this checkpoint.
|
825 |
+
Arguments:
|
826 |
+
dir_path (str): The checkpoint path.
|
827 |
+
"""
|
828 |
+
if not os.path.exists(dir_path):
|
829 |
+
os.makedirs(dir_path)
|
830 |
+
else:
|
831 |
+
assert os.path.isdir(dir_path), '`dir_path` points to a file.'
|
832 |
+
kwargs = self.kwargs.copy()
|
833 |
+
# Update arguments that may have changed since construction
|
834 |
+
kwargs.update(
|
835 |
+
seen=self.seen,
|
836 |
+
pl_avg=float(self.pl_avg)
|
837 |
+
)
|
838 |
+
with open(os.path.join(dir_path, 'kwargs.json'), 'w') as fp:
|
839 |
+
json.dump(kwargs, fp)
|
840 |
+
torch.save(self.G_opt.state_dict(), os.path.join(dir_path, 'G_opt.pth'))
|
841 |
+
torch.save(self.D_opt.state_dict(), os.path.join(dir_path, 'D_opt.pth'))
|
842 |
+
models.save(self.G, os.path.join(dir_path, 'G.pth'))
|
843 |
+
models.save(self.D, os.path.join(dir_path, 'D.pth'))
|
844 |
+
if self.Gs is not None:
|
845 |
+
models.save(self.Gs, os.path.join(dir_path, 'Gs.pth'))
|
846 |
+
|
847 |
+
@classmethod
|
848 |
+
def load_checkpoint(cls, checkpoint_path, dataset, **kwargs):
|
849 |
+
"""
|
850 |
+
Load a checkpoint into a new Trainer object and return that
|
851 |
+
object. If the path specified points at a folder containing
|
852 |
+
multiple checkpoints, the latest one will be used.
|
853 |
+
The dataset can not be serialized and saved so it is required
|
854 |
+
to be explicitly given when loading a checkpoint.
|
855 |
+
Arguments:
|
856 |
+
checkpoint_path (str): Path to a checkpoint or to a folder
|
857 |
+
containing one or more checkpoints.
|
858 |
+
dataset (indexable): The dataset to use.
|
859 |
+
**kwargs (keyword arguments): Any other arguments to override
|
860 |
+
the ones saved in the checkpoint. Useful for when training
|
861 |
+
is continued on a different device or when distributed training
|
862 |
+
is changed.
|
863 |
+
"""
|
864 |
+
checkpoint_path = _find_checkpoint(checkpoint_path)
|
865 |
+
_is_checkpoint(checkpoint_path, enforce=True)
|
866 |
+
with open(os.path.join(checkpoint_path, 'kwargs.json'), 'r') as fp:
|
867 |
+
loaded_kwargs = json.load(fp)
|
868 |
+
loaded_kwargs.update(**kwargs)
|
869 |
+
device = torch.device('cpu')
|
870 |
+
if isinstance(loaded_kwargs['device'], (list, tuple)):
|
871 |
+
device = torch.device(loaded_kwargs['device'][0])
|
872 |
+
for name in ['G', 'D']:
|
873 |
+
fpath = os.path.join(checkpoint_path, name + '.pth')
|
874 |
+
loaded_kwargs[name] = models.load(fpath, map_location=device)
|
875 |
+
if os.path.exists(os.path.join(checkpoint_path, 'Gs.pth')):
|
876 |
+
loaded_kwargs['Gs'] = models.load(
|
877 |
+
os.path.join(checkpoint_path, 'Gs.pth'),
|
878 |
+
map_location=device if loaded_kwargs['Gs_device'] is None \
|
879 |
+
else torch.device(loaded_kwargs['Gs_device'])
|
880 |
+
)
|
881 |
+
obj = cls(dataset=dataset, **loaded_kwargs)
|
882 |
+
for name in ['G_opt', 'D_opt']:
|
883 |
+
fpath = os.path.join(checkpoint_path, name + '.pth')
|
884 |
+
state_dict = torch.load(fpath, map_location=device)
|
885 |
+
getattr(obj, name).load_state_dict(state_dict)
|
886 |
+
return obj
|
887 |
+
|
888 |
+
|
889 |
+
#----------------------------------------------------------------------------
|
890 |
+
# Checkpoint helper functions
|
891 |
+
|
892 |
+
|
893 |
+
def _is_checkpoint(dir_path, enforce=False):
|
894 |
+
if not dir_path:
|
895 |
+
if enforce:
|
896 |
+
raise ValueError('Not a checkpoint.')
|
897 |
+
return False
|
898 |
+
if not os.path.exists(dir_path):
|
899 |
+
if enforce:
|
900 |
+
raise FileNotFoundError('{} could not be found.'.format(dir_path))
|
901 |
+
return False
|
902 |
+
if not os.path.isdir(dir_path):
|
903 |
+
if enforce:
|
904 |
+
raise NotADirectoryError('{} is not a directory.'.format(dir_path))
|
905 |
+
return False
|
906 |
+
fnames = os.listdir(dir_path)
|
907 |
+
for fname in ['G.pth', 'D.pth', 'G_opt.pth', 'D_opt.pth', 'kwargs.json']:
|
908 |
+
if fname not in fnames:
|
909 |
+
if enforce:
|
910 |
+
raise FileNotFoundError(
|
911 |
+
'Could not find {} in {}.'.format(fname, dir_path))
|
912 |
+
return False
|
913 |
+
return True
|
914 |
+
|
915 |
+
|
916 |
+
def _find_checkpoint(dir_path):
|
917 |
+
if not dir_path:
|
918 |
+
return None
|
919 |
+
if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
|
920 |
+
return None
|
921 |
+
if _is_checkpoint(dir_path):
|
922 |
+
return dir_path
|
923 |
+
checkpoint_names = []
|
924 |
+
for name in os.listdir(dir_path):
|
925 |
+
if _is_checkpoint(os.path.join(dir_path, name)):
|
926 |
+
checkpoint_names.append(name)
|
927 |
+
if not checkpoint_names:
|
928 |
+
return None
|
929 |
+
def get_iteration(name):
|
930 |
+
return int(name.split('_')[0])
|
931 |
+
def get_timestamp(name):
|
932 |
+
return '_'.join(name.split('_')[1:])
|
933 |
+
# Python sort is stable, meaning that this sort operation
|
934 |
+
# will guarantee that the order of values after the first
|
935 |
+
# sort will stay for a set of values that have the same
|
936 |
+
# key value.
|
937 |
+
checkpoint_names = sorted(
|
938 |
+
sorted(checkpoint_names, key=get_iteration), key=get_timestamp)
|
939 |
+
return os.path.join(dir_path, checkpoint_names[-1])
|
940 |
+
|
941 |
+
|
942 |
+
#----------------------------------------------------------------------------
|
943 |
+
# Optimizer builder
|
944 |
+
|
945 |
+
|
946 |
+
def build_opt(net, opt_class, opt_kwargs, reg, reg_interval):
|
947 |
+
opt_kwargs['lr'] = opt_kwargs.get('lr', 1e-3)
|
948 |
+
if reg not in [None, False] and reg_interval:
|
949 |
+
mb_ratio = reg_interval / (reg_interval + 1.)
|
950 |
+
opt_kwargs['lr'] *= mb_ratio
|
951 |
+
if 'momentum' in opt_kwargs:
|
952 |
+
opt_kwargs['momentum'] = opt_kwargs['momentum'] ** mb_ratio
|
953 |
+
if 'betas' in opt_kwargs:
|
954 |
+
betas = opt_kwargs['betas']
|
955 |
+
opt_kwargs['betas'] = (betas[0] ** mb_ratio, betas[1] ** mb_ratio)
|
956 |
+
if isinstance(opt_class, str):
|
957 |
+
opt_class = getattr(torch.optim, opt_class.title())
|
958 |
+
return opt_class(net.parameters(), **opt_kwargs)
|
959 |
+
|
960 |
+
|
961 |
+
#----------------------------------------------------------------------------
|
962 |
+
# Reg and loss function fetchers
|
963 |
+
|
964 |
+
|
965 |
+
_LOSS_FNS = {
|
966 |
+
'G': {
|
967 |
+
'logistic': loss_fns.G_logistic,
|
968 |
+
'logistic_ns': loss_fns.G_logistic_ns,
|
969 |
+
'wgan': loss_fns.G_wgan
|
970 |
+
},
|
971 |
+
'D': {
|
972 |
+
'logistic': loss_fns.D_logistic,
|
973 |
+
'wgan': loss_fns.D_wgan
|
974 |
+
}
|
975 |
+
}
|
976 |
+
def get_loss_fn(net, loss):
|
977 |
+
if callable(loss):
|
978 |
+
return loss
|
979 |
+
net = net.upper()
|
980 |
+
assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
|
981 |
+
loss = loss.lower()
|
982 |
+
for name in _LOSS_FNS[net].keys():
|
983 |
+
if loss == name:
|
984 |
+
return _LOSS_FNS[net][name]
|
985 |
+
raise ValueError('Unknown {} loss {}'.format(net, loss))
|
986 |
+
|
987 |
+
|
988 |
+
_REG_FNS = {
|
989 |
+
'G': {
|
990 |
+
'pathreg': loss_fns.G_pathreg
|
991 |
+
},
|
992 |
+
'D': {
|
993 |
+
'r1': loss_fns.D_r1,
|
994 |
+
'r2': loss_fns.D_r2,
|
995 |
+
'gp': loss_fns.D_gp,
|
996 |
+
}
|
997 |
+
}
|
998 |
+
def get_reg_fn(net, reg, **kwargs):
|
999 |
+
if reg is None:
|
1000 |
+
return None
|
1001 |
+
if callable(reg):
|
1002 |
+
return functools.partial(reg, **kwargs)
|
1003 |
+
net = net.upper()
|
1004 |
+
assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
|
1005 |
+
reg = reg.lower()
|
1006 |
+
gamma = None
|
1007 |
+
for name in _REG_FNS[net].keys():
|
1008 |
+
if reg.startswith(name):
|
1009 |
+
gamma_chars = [c for c in reg.replace(name, '') if c.isdigit() or c == '.']
|
1010 |
+
if gamma_chars:
|
1011 |
+
kwargs.update(gamma=float(''.join(gamma_chars)))
|
1012 |
+
return functools.partial(_REG_FNS[net][name], **kwargs)
|
1013 |
+
raise ValueError('Unknown regularizer {}'.format(reg))
|
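For orientation, here is a minimal single-process usage sketch of the `Trainer` defined above, using only arguments and methods that appear in this file. The model and dataset construction are placeholders and `latent_size=512` is just an illustrative value; treat this as a sketch under those assumptions, not the project's canonical training script.

import torch
from stylegan2 import models, train

# Load (or construct) the generator and discriminator.
G = models.load('G.pth')
D = models.load('D.pth')

# Placeholder: any indexable dataset returning image tensors
# (and labels when label_size > 0).
dataset = ...

trainer = train.Trainer(
    G=G,
    D=D,
    latent_size=512,              # illustrative value
    dataset=dataset,
    device=0,                     # single CUDA device index
    batch_size=32,
    device_batch_size=4,
    tensorboard_log_dir='runs/stylegan2',
    checkpoint_dir='checkpoints',
    checkpoint_interval=10000,
)
trainer.train(iterations=100000, verbose=True)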
stylegan2/utils.py
ADDED
@@ -0,0 +1,726 @@
1 |
+
import time
|
2 |
+
import numbers
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
import collections
|
6 |
+
import argparse
|
7 |
+
import yaml
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
import torchvision
|
14 |
+
try:
|
15 |
+
import tqdm
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
try:
|
19 |
+
from IPython.display import display as notebook_display
|
20 |
+
from IPython.display import clear_output as notebook_clear
|
21 |
+
except ImportError:
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
# Miscellaneous utils
|
27 |
+
|
28 |
+
|
29 |
+
class AttributeDict(dict):
|
30 |
+
"""
|
31 |
+
Dict where values can be accessed using attribute syntax.
|
32 |
+
Same as "EasyDict" in the NVIDIA stylegan git repository.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __getattr__(self, name):
|
36 |
+
try:
|
37 |
+
return self[name]
|
38 |
+
except KeyError:
|
39 |
+
raise AttributeError(name)
|
40 |
+
|
41 |
+
def __setattr__(self, name, value):
|
42 |
+
self[name] = value
|
43 |
+
|
44 |
+
def __delattr__(self, name):
|
45 |
+
del self[name]
|
46 |
+
|
47 |
+
def __getstate__(self):
|
48 |
+
return dict(**self)
|
49 |
+
|
50 |
+
def __setstate__(self, state):
|
51 |
+
self.update(**state)
|
52 |
+
|
53 |
+
def __repr__(self):
|
54 |
+
return '{}({})'.format(
|
55 |
+
self.__class__.__name__,
|
56 |
+
', '.join('{}={}'.format(key, value) for key, value in self.items())
|
57 |
+
)
|
58 |
+
|
59 |
+
@classmethod
|
60 |
+
def convert_dict_recursive(cls, obj):
|
61 |
+
if isinstance(obj, dict):
|
62 |
+
for key in list(obj.keys()):
|
63 |
+
obj[key] = cls.convert_dict_recursive(obj[key])
|
64 |
+
if not isinstance(obj, cls):
|
65 |
+
return cls(**obj)
|
66 |
+
return obj
|
67 |
+
|
68 |
+
|
69 |
+
class Timer:
|
70 |
+
|
71 |
+
def __init__(self):
|
72 |
+
self.reset()
|
73 |
+
|
74 |
+
def __enter__(self):
|
75 |
+
self._t0 = time.time()
|
76 |
+
|
77 |
+
def __exit__(self, *args):
|
78 |
+
self._t += time.time() - self._t0
|
79 |
+
|
80 |
+
def value(self):
|
81 |
+
return self._t
|
82 |
+
|
83 |
+
def reset(self):
|
84 |
+
self._t = 0
|
85 |
+
|
86 |
+
def __str__(self):
|
87 |
+
"""
|
88 |
+
Get a string representation of the recorded time.
|
89 |
+
Returns:
|
90 |
+
time_as_string (str)
|
91 |
+
"""
|
92 |
+
value = self.value()
|
93 |
+
if not value or value >= 100:
|
94 |
+
return '{} s'.format(int(value))
|
95 |
+
elif value >= 1:
|
96 |
+
return '{:.3g} s'.format(value)
|
97 |
+
elif value >= 1e-3:
|
98 |
+
return '{:.3g} ms'.format(value * 1e+3)
|
99 |
+
elif value >= 1e-6:
|
100 |
+
return '{:.3g} us'.format(value * 1e+6)
|
101 |
+
elif value >= 1e-9:
|
102 |
+
return '{:.3g} ns'.format(value * 1e+9)
|
103 |
+
else:
|
104 |
+
return '{:.2E} s'.format(value)
|
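# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# Timer accumulates the time spent inside `with` blocks and picks a unit when printed:
from stylegan2 import utils

timer = utils.Timer()
with timer:
    sum(i * i for i in range(100000))   # some work to measure
with timer:
    sum(i * i for i in range(100000))   # a second block adds to the same total
print(timer.value())                    # total seconds as a float
print(str(timer))                       # e.g. '12.3 ms'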
105 |
+
|
106 |
+
|
107 |
+
def to_list(values):
|
108 |
+
if values is None:
|
109 |
+
return []
|
110 |
+
if isinstance(values, tuple):
|
111 |
+
return list(values)
|
112 |
+
if not isinstance(values, list):
|
113 |
+
return [values]
|
114 |
+
return values
|
115 |
+
|
116 |
+
|
117 |
+
def lerp(a, b, beta):
|
118 |
+
if isinstance(beta, numbers.Number):
|
119 |
+
if beta == 1:
|
120 |
+
return b
|
121 |
+
elif beta == 0:
|
122 |
+
return a
|
123 |
+
if torch.is_tensor(a) and a.dtype == torch.float32:
|
124 |
+
# torch lerp only available for fp32
|
125 |
+
return torch.lerp(a, b, beta)
|
126 |
+
# More numerically stable than a + beta * (b - a)
|
127 |
+
return (1 - beta) * a + beta * b
|
128 |
+
|
129 |
+
|
130 |
+
def _normalize(v):
|
131 |
+
return v * torch.rsqrt(torch.sum(v ** 2, dim=-1, keepdim=True))
|
132 |
+
|
133 |
+
|
134 |
+
def slerp(a, b, beta):
|
135 |
+
assert a.size() == b.size(), 'Size mismatch between ' + \
|
136 |
+
'slerp arguments, received {} and {}'.format(a.size(), b.size())
|
137 |
+
if not torch.is_tensor(beta):
|
138 |
+
beta = torch.tensor(beta).to(a)
|
139 |
+
a = _normalize(a)
|
140 |
+
b = _normalize(b)
|
141 |
+
d = torch.sum(a * b, axis=-1, keepdim=True)
|
142 |
+
p = beta * torch.acos(d)  # angle between a and b, scaled by the interpolation factor
|
143 |
+
c = _normalize(b - d * a)
|
144 |
+
d = a * torch.cos(p) + c * torch.sin(p)
|
145 |
+
return _normalize(d)
|
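# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# lerp interpolates linearly, slerp along the unit sphere (useful for latent vectors):
import torch
from stylegan2 import utils

a, b = torch.zeros(4), torch.ones(4)
print(utils.lerp(a, b, 0.25))           # tensor of 0.25s (uses torch.lerp for float32)

z0, z1 = torch.randn(512), torch.randn(512)
z_mid = utils.slerp(z0, z1, 0.5)        # unit-norm vector halfway along the arc
print(z_mid.norm())                     # approximately 1.0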
146 |
+
|
147 |
+
|
148 |
+
#----------------------------------------------------------------------------
|
149 |
+
# Command line utils
|
150 |
+
|
151 |
+
|
152 |
+
def _parse_configs(configs):
|
153 |
+
kwargs = {}
|
154 |
+
for config in configs:
|
155 |
+
with open(config, 'r') as fp:
|
156 |
+
kwargs.update(yaml.safe_load(fp))
|
157 |
+
return kwargs
|
158 |
+
|
159 |
+
|
160 |
+
class ConfigArgumentParser(argparse.ArgumentParser):
|
161 |
+
|
162 |
+
_CONFIG_ARG_KEY = '_configs'
|
163 |
+
|
164 |
+
def __init__(self, *args, **kwargs):
|
165 |
+
super(ConfigArgumentParser, self).__init__(*args, **kwargs)
|
166 |
+
self.add_argument(
|
167 |
+
self._CONFIG_ARG_KEY,
|
168 |
+
nargs='*',
|
169 |
+
help='Any yaml-style config file whose values will override the defaults of this argument parser.',
|
170 |
+
type=str
|
171 |
+
)
|
172 |
+
|
173 |
+
def parse_args(self, args=None):
|
174 |
+
config_args = _parse_configs(
|
175 |
+
getattr(
|
176 |
+
super(ConfigArgumentParser, self).parse_args(args),
|
177 |
+
self._CONFIG_ARG_KEY
|
178 |
+
)
|
179 |
+
)
|
180 |
+
self.set_defaults(**config_args)
|
181 |
+
return super(ConfigArgumentParser, self).parse_args(args)
|
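# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable
# and that 'train_config.yaml' is a hypothetical config file containing e.g. "lr: 0.001".
from stylegan2 import utils

parser = utils.ConfigArgumentParser()
parser.add_argument('--lr', type=float, default=0.002)
args = parser.parse_args(['train_config.yaml'])
print(args.lr)   # 0.001 -- the yaml value overrides the parser default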
182 |
+
|
183 |
+
|
184 |
+
def bool_type(v):
|
185 |
+
if isinstance(v, bool):
|
186 |
+
return v
|
187 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
188 |
+
return True
|
189 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
190 |
+
return False
|
191 |
+
else:
|
192 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
193 |
+
|
194 |
+
|
195 |
+
def range_type(s):
|
196 |
+
"""
|
197 |
+
Accept either a comma separated list of numbers
|
198 |
+
'a,b,c' or a range 'a-c' and return as a list of ints.
|
199 |
+
"""
|
200 |
+
range_re = re.compile(r'^(\d+)-(\d+)$')
|
201 |
+
m = range_re.match(s)
|
202 |
+
if m:
|
203 |
+
return range(int(m.group(1)), int(m.group(2))+1)
|
204 |
+
vals = s.split(',')
|
205 |
+
return [int(x) for x in vals]
|
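# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# bool_type and range_type are meant as argparse `type=` callables:
from stylegan2 import utils

print(utils.bool_type('yes'))            # True
print(utils.range_type('3,5,7'))         # [3, 5, 7]
print(list(utils.range_type('0-4')))     # [0, 1, 2, 3, 4] (inclusive range)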
206 |
+
|
207 |
+
|
208 |
+
#----------------------------------------------------------------------------
|
209 |
+
# Dataset and generation of latents
|
210 |
+
|
211 |
+
|
212 |
+
class ResizeTransform:
|
213 |
+
|
214 |
+
def __init__(self, height, width, resize=True, mode='bicubic'):
|
215 |
+
if resize:
|
216 |
+
assert height and width, 'Height and width have to be given ' + \
|
217 |
+
'when resizing data.'
|
218 |
+
self.height = height
|
219 |
+
self.width = width
|
220 |
+
self.resize = resize
|
221 |
+
self.mode = mode
|
222 |
+
|
223 |
+
def __call__(self, tensor):
|
224 |
+
if self.height and self.width:
|
225 |
+
if tensor.size(1) != self.height or tensor.size(2) != self.width:
|
226 |
+
if self.resize:
|
227 |
+
kwargs = {}
|
228 |
+
if 'cubic' in self.mode or 'linear' in self.mode:
|
229 |
+
kwargs.update(align_corners=False)
|
230 |
+
tensor = F.interpolate(
|
231 |
+
tensor.unsqueeze(0),
|
232 |
+
size=(self.height, self.width),
|
233 |
+
mode=self.mode,
|
234 |
+
**kwargs
|
235 |
+
).squeeze(0)
|
236 |
+
else:
|
237 |
+
raise ValueError(
|
238 |
+
'Data shape incorrect, expected ({},{}) '.format(self.width, self.height) + \
|
239 |
+
'but got ({},{}) (width, height)'.format(tensor.size(2), tensor.size(1))
|
240 |
+
)
|
241 |
+
return tensor
|
242 |
+
|
243 |
+
|
244 |
+
def _PIL_RGB_loader(path):
|
245 |
+
return Image.open(path).convert('RGB')
|
246 |
+
|
247 |
+
|
248 |
+
def _PIL_grayscale_loader(path):
|
249 |
+
return Image.open(path).convert('L')
|
250 |
+
|
251 |
+
|
252 |
+
class ImageFolder(torchvision.datasets.ImageFolder):
|
253 |
+
|
254 |
+
def __init__(self,
|
255 |
+
*args,
|
256 |
+
mirror=False,
|
257 |
+
pixel_min=-1,
|
258 |
+
pixel_max=1,
|
259 |
+
height=None,
|
260 |
+
width=None,
|
261 |
+
resize=False,
|
262 |
+
resize_mode='bicubic',
|
263 |
+
grayscale=False,
|
264 |
+
**kwargs):
|
265 |
+
super(ImageFolder, self).__init__(
|
266 |
+
*args,
|
267 |
+
loader=_PIL_grayscale_loader if grayscale else _PIL_RGB_loader,
|
268 |
+
**kwargs
|
269 |
+
)
|
270 |
+
transforms = []
|
271 |
+
if mirror:
|
272 |
+
transforms.append(torchvision.transforms.RandomHorizontalFlip())
|
273 |
+
transforms.append(torchvision.transforms.ToTensor())
|
274 |
+
transforms.append(
|
275 |
+
torchvision.transforms.Normalize(
|
276 |
+
mean=[-(pixel_min / (pixel_max - pixel_min))],
|
277 |
+
std=[1. / (pixel_max - pixel_min)]
|
278 |
+
)
|
279 |
+
)
|
280 |
+
transforms.append(ResizeTransform(
|
281 |
+
height=height, width=width, resize=resize, mode=resize_mode))
|
282 |
+
self.transform = torchvision.transforms.Compose(transforms)
|
283 |
+
|
284 |
+
def _find_classes(self, *args, **kwargs):
|
285 |
+
classes, class_to_idx = super(ImageFolder, self)._find_classes(*args, **kwargs)
|
286 |
+
if not classes:
|
287 |
+
classes = ['']
|
288 |
+
class_to_idx = {'': 0}
|
289 |
+
return classes, class_to_idx
|
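# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable
# and that 'path/to/images' is a hypothetical folder of training images.
import torch
from stylegan2 import utils

dataset = utils.ImageFolder(
    'path/to/images', mirror=True, pixel_min=-1, pixel_max=1,
    height=256, width=256, resize=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
images, labels = next(iter(loader))      # images scaled to [-1, 1], shape (8, 3, 256, 256)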
290 |
+
|
291 |
+
|
292 |
+
class PriorGenerator:
|
293 |
+
|
294 |
+
def __init__(self, latent_size, label_size, batch_size, device):
|
295 |
+
self.latent_size = latent_size
|
296 |
+
self.label_size = label_size
|
297 |
+
self.batch_size = batch_size
|
298 |
+
self.device = device
|
299 |
+
|
300 |
+
def __iter__(self):
|
301 |
+
return self
|
302 |
+
|
303 |
+
def __next__(self):
|
304 |
+
return self()
|
305 |
+
|
306 |
+
def __call__(self, batch_size=None, multi_latent_prob=0, seed=None):
|
307 |
+
if batch_size is None:
|
308 |
+
batch_size = self.batch_size
|
309 |
+
shape = [batch_size, self.latent_size]
|
310 |
+
if multi_latent_prob:
|
311 |
+
if seed is not None:
|
312 |
+
np.random.seed(seed)
|
313 |
+
if np.random.uniform() < multi_latent_prob:
|
314 |
+
shape = [batch_size, 2, self.latent_size]
|
315 |
+
if seed is not None:
|
316 |
+
torch.manual_seed(seed)
|
317 |
+
latents = torch.empty(*shape, device=self.device).normal_()
|
318 |
+
labels = None
|
319 |
+
if self.label_size:
|
320 |
+
label_shape = [batch_size]
|
321 |
+
labels = torch.randint(0, self.label_size, label_shape, device=self.device)
|
322 |
+
return latents, labels
|
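# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# PriorGenerator draws standard-normal latents (and optional labels) on a given device:
from stylegan2 import utils

prior = utils.PriorGenerator(latent_size=512, label_size=0, batch_size=4, device='cpu')
latents, labels = prior()                     # latents: (4, 512), labels: None
latents, labels = prior(batch_size=2, multi_latent_prob=0.5, seed=0)
print(latents.shape)                          # (2, 512), or (2, 2, 512) when mixing two latents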
323 |
+
|
324 |
+
|
325 |
+
#----------------------------------------------------------------------------
|
326 |
+
# Training utils
|
327 |
+
|
328 |
+
|
329 |
+
class MovingAverageModule:
|
330 |
+
|
331 |
+
def __init__(self,
|
332 |
+
from_module,
|
333 |
+
to_module=None,
|
334 |
+
param_beta=0.995,
|
335 |
+
buffer_beta=0,
|
336 |
+
device=None):
|
337 |
+
from_module = unwrap_module(from_module)
|
338 |
+
to_module = unwrap_module(to_module)
|
339 |
+
if device is None:
|
340 |
+
module = from_module
|
341 |
+
if to_module is not None:
|
342 |
+
module = to_module
|
343 |
+
device = next(module.parameters()).device
|
344 |
+
else:
|
345 |
+
device = torch.device(device)
|
346 |
+
self.from_module = from_module
|
347 |
+
if to_module is None:
|
348 |
+
self.module = from_module.clone().to(device)
|
349 |
+
else:
|
350 |
+
assert type(to_module) == type(from_module), \
|
351 |
+
'Mismatch between type of source and target module.'
|
352 |
+
assert set(self._get_named_parameters(to_module).keys()) \
|
353 |
+
== set(self._get_named_parameters(from_module).keys()), \
|
354 |
+
'Mismatch between parameters of source and target module.'
|
355 |
+
assert set(self._get_named_buffers(to_module).keys()) \
|
356 |
+
== set(self._get_named_buffers(from_module).keys()), \
|
357 |
+
'Mismatch between buffers of source and target module.'
|
358 |
+
self.module = to_module.to(device)
|
359 |
+
self.module.eval().requires_grad_(False)
|
360 |
+
self.param_beta = param_beta
|
361 |
+
self.buffer_beta = buffer_beta
|
362 |
+
self.device = device
|
363 |
+
|
364 |
+
def __getattr__(self, name):
|
365 |
+
try:
|
366 |
+
return super(object, self).__getattr__(name)
|
367 |
+
except AttributeError:
|
368 |
+
return getattr(self.module, name)
|
369 |
+
|
370 |
+
def update(self):
|
371 |
+
self._update_data(
|
372 |
+
from_data=self._get_named_parameters(self.from_module),
|
373 |
+
to_data=self._get_named_parameters(self.module),
|
374 |
+
beta=self.param_beta
|
375 |
+
)
|
376 |
+
self._update_data(
|
377 |
+
from_data=self._get_named_buffers(self.from_module),
|
378 |
+
to_data=self._get_named_buffers(self.module),
|
379 |
+
beta=self.buffer_beta
|
380 |
+
)
|
381 |
+
|
382 |
+
@staticmethod
|
383 |
+
def _update_data(from_data, to_data, beta):
|
384 |
+
for name in from_data.keys():
|
385 |
+
if name not in to_data:
|
386 |
+
continue
|
387 |
+
fr, to = from_data[name], to_data[name]
|
388 |
+
with torch.no_grad():
|
389 |
+
if beta == 0:
|
390 |
+
to.data.copy_(fr.data.to(to.data))
|
391 |
+
elif beta < 1:
|
392 |
+
to.data.copy_(lerp(fr.data.to(to.data), to.data, beta))
|
393 |
+
|
394 |
+
@staticmethod
|
395 |
+
def _get_named_parameters(module):
|
396 |
+
return {name: value for name, value in module.named_parameters()}
|
397 |
+
|
398 |
+
@staticmethod
|
399 |
+
def _get_named_buffers(module):
|
400 |
+
return {name: value for name, value in module.named_buffers()}
|
401 |
+
|
402 |
+
def __call__(self, *args, **kwargs):
|
403 |
+
return self.forward(*args, **kwargs)
|
404 |
+
|
405 |
+
def forward(self, *args, **kwargs):
|
406 |
+
self.module.eval()
|
407 |
+
args, args_in_device = move_to_device(args, self.device)
|
408 |
+
kwargs, kwargs_in_device = move_to_device(kwargs, self.device)
|
409 |
+
in_device = None
|
410 |
+
if args_in_device is not None:
|
411 |
+
in_device = args_in_device
|
412 |
+
if kwargs_in_device is not None:
|
413 |
+
in_device = kwargs_in_device
|
414 |
+
out = self.module(*args, **kwargs)
|
415 |
+
if in_device is not None:
|
416 |
+
out, _ = move_to_device(out, in_device)
|
417 |
+
return out
|
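# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# nn.Linear stands in for the generator; a deepcopy is passed as to_module because a plain
# nn.Module has no clone() method (the call used when to_module is omitted):
import copy
import torch
from torch import nn
from stylegan2 import utils

G = nn.Linear(512, 512)
G_ema = utils.MovingAverageModule(G, to_module=copy.deepcopy(G),
                                  param_beta=0.999, device='cpu')
# ...after each optimizer step in the training loop:
G_ema.update()                               # blend G's parameters into the averaged copy
out = G_ema(torch.randn(1, 512))             # runs the averaged module in eval mode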
418 |
+
|
419 |
+
|
420 |
+
def move_to_device(value, device):
|
421 |
+
if torch.is_tensor(value):
|
422 |
+
return value.to(device), value.device
|
423 |
+
orig_device = None
|
424 |
+
if isinstance(value, (tuple, list)):
|
425 |
+
values = []
|
426 |
+
for val in value:
|
427 |
+
_val, orig_device = move_to_device(val, device)
|
428 |
+
values.append(_val)
|
429 |
+
return type(value)(values), orig_device
|
430 |
+
if isinstance(value, dict):
|
431 |
+
if isinstance(value, collections.OrderedDict):
|
432 |
+
values = collections.OrderedDict()
|
433 |
+
else:
|
434 |
+
values = {}
|
435 |
+
for key, val in value.items():
|
436 |
+
_val, orig_device = move_to_device(val, device)
|
437 |
+
values[key] = _val
|
438 |
+
return values, orig_device
|
439 |
+
return value, orig_device
|
440 |
+
|
441 |
+
|
442 |
+
_WRAPPER_CLASSES = (MovingAverageModule, nn.DataParallel, nn.parallel.DistributedDataParallel)
|
443 |
+
def unwrap_module(module):
|
444 |
+
if isinstance(module, _WRAPPER_CLASSES):
|
445 |
+
return module.module
|
446 |
+
return module
|
447 |
+
|
448 |
+
|
449 |
+
def get_grad_norm_from_optimizer(optimizer, norm_type=2):
|
450 |
+
"""
|
451 |
+
Get the gradient norm for some parameters contained in an optimizer.
|
452 |
+
Arguments:
|
453 |
+
optimizer (torch.optim.Optimizer)
|
454 |
+
norm_type (int): Type of norm. Default value is 2.
|
455 |
+
Returns:
|
456 |
+
norm (float)
|
457 |
+
"""
|
458 |
+
total_norm = 0
|
459 |
+
if optimizer is not None:
|
460 |
+
for param_group in optimizer.param_groups:
|
461 |
+
for p in param_group['params']:
|
462 |
+
if p.grad is not None:
|
463 |
+
with torch.no_grad():
|
464 |
+
param_norm = p.grad.data.norm(norm_type)
|
465 |
+
total_norm += param_norm ** norm_type
|
466 |
+
total_norm = total_norm ** (1. / norm_type)
|
467 |
+
return total_norm.item()
|
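# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
import torch
from torch import nn
from stylegan2 import utils

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 8)).sum()
loss.backward()
print(utils.get_grad_norm_from_optimizer(optimizer))   # 2-norm of all gradients, as a float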
468 |
+
|
469 |
+
|
470 |
+
#----------------------------------------------------------------------------
|
471 |
+
# printing and logging utils
|
472 |
+
|
473 |
+
|
474 |
+
class ValueTracker:
|
475 |
+
|
476 |
+
def __init__(self, beta=0.95):
|
477 |
+
self.beta = beta
|
478 |
+
self.values = {}
|
479 |
+
|
480 |
+
def add(self, name, value, beta=None):
|
481 |
+
if torch.is_tensor(value):
|
482 |
+
value = value.item()
|
483 |
+
if beta is None:
|
484 |
+
beta = self.beta
|
485 |
+
if name not in self.values:
|
486 |
+
self.values[name] = value
|
487 |
+
else:
|
488 |
+
self.values[name] = lerp(value, self.values[name], beta)
|
489 |
+
|
490 |
+
def __getitem__(self, key):
|
491 |
+
return self.values[key]
|
492 |
+
|
493 |
+
def __str__(self):
|
494 |
+
string = ''
|
495 |
+
for i, name in enumerate(sorted(self.values.keys())):
|
496 |
+
if i and i % 3 == 0:
|
497 |
+
string += '\n'
|
498 |
+
elif string:
|
499 |
+
string += ', '
|
500 |
+
format_string = '{}: {}'
|
501 |
+
if isinstance(self.values[name], float):
|
502 |
+
format_string = '{}: {:.4g}'
|
503 |
+
string += format_string.format(name, self.values[name])
|
504 |
+
return string
|
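# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# ValueTracker keeps an exponential moving average per named value for logging:
from stylegan2 import utils

tracker = utils.ValueTracker(beta=0.95)
tracker.add('G_loss', 0.8)
tracker.add('G_loss', 0.6)      # smoothed: 0.95 * 0.8 + 0.05 * 0.6 = 0.79
print(tracker['G_loss'])        # 0.79
print(tracker)                  # 'G_loss: 0.79'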
505 |
+
|
506 |
+
|
507 |
+
def is_notebook():
|
508 |
+
"""
|
509 |
+
Check if code is running from jupyter notebook.
|
510 |
+
Returns:
|
511 |
+
notebook (bool): True if running from jupyter notebook,
|
512 |
+
else False.
|
513 |
+
"""
|
514 |
+
try:
|
515 |
+
__IPYTHON__
|
516 |
+
return True
|
517 |
+
except NameError:
|
518 |
+
return False
|
519 |
+
|
520 |
+
|
521 |
+
def _progress_bar(count, total):
|
522 |
+
"""
|
523 |
+
Get a simple one-line string representing a progress bar.
|
524 |
+
Arguments:
|
525 |
+
count (int): Current count. Starts at 0.
|
526 |
+
total (int): Total count.
|
527 |
+
Returns:
|
528 |
+
pbar_string (str): The string progress bar.
|
529 |
+
"""
|
530 |
+
bar_len = 60
|
531 |
+
filled_len = int(round(bar_len * (count + 1) / float(total)))
|
532 |
+
bar = '=' * filled_len + '-' * (bar_len - filled_len)
|
533 |
+
return '[{}] {}/{}'.format(bar, count + 1, total)
|
534 |
+
|
535 |
+
|
536 |
+
class ProgressWriter:
|
537 |
+
"""
|
538 |
+
Handles writing output and displaying a progress bar. Automatically
|
539 |
+
adjusts for notebooks. Supports outputting text
|
540 |
+
that is compatible with the progressbar (in notebooks the text is
|
541 |
+
refreshed instead of printed).
|
542 |
+
Arguments:
|
543 |
+
length (int, optional): Total length of the progressbar.
|
544 |
+
Default value is None.
|
545 |
+
progress_bar (bool, optional): Display a progressbar.
|
546 |
+
Default value is True.
|
547 |
+
clear (bool, optional): If running from a notebook, clear
|
548 |
+
the current cell's output. Default value is False.
|
549 |
+
"""
|
550 |
+
def __init__(self, length=None, progress_bar=True, clear=False):
|
551 |
+
if is_notebook() and clear:
|
552 |
+
notebook_clear()
|
553 |
+
|
554 |
+
if length is not None:
|
555 |
+
length = int(length)
|
556 |
+
self.length = length
|
557 |
+
self.count = 0
|
558 |
+
|
559 |
+
self._simple_pbar = False
|
560 |
+
if progress_bar and 'tqdm' not in sys.modules:
|
561 |
+
self._simple_pbar = True
|
562 |
+
|
563 |
+
progress_bar = progress_bar and 'tqdm' in sys.modules
|
564 |
+
|
565 |
+
self._progress_bar = None
|
566 |
+
if progress_bar:
|
567 |
+
pbar = tqdm.tqdm
|
568 |
+
if is_notebook():
|
569 |
+
pbar = tqdm.tqdm_notebook
|
570 |
+
if length is not None:
|
571 |
+
self._progress_bar = pbar(total=length, file=sys.stdout)
|
572 |
+
else:
|
573 |
+
self._progress_bar = pbar(file=sys.stdout)
|
574 |
+
|
575 |
+
if is_notebook():
|
576 |
+
self._writer = notebook_display(
|
577 |
+
_StrRepr(''),
|
578 |
+
display_id=time.asctime()
|
579 |
+
)
|
580 |
+
else:
|
581 |
+
if progress_bar:
|
582 |
+
self._writer = self._progress_bar
|
583 |
+
else:
|
584 |
+
self._writer = sys.stdout
|
585 |
+
|
586 |
+
def write(self, *lines, step=True):
|
587 |
+
"""
|
588 |
+
Output values to stdout (or a display object if called from notebook).
|
589 |
+
Arguments:
|
590 |
+
*lines: The lines to write (positional arguments).
|
591 |
+
step (bool): Update the progressbar if present.
|
592 |
+
Default value is True.
|
593 |
+
"""
|
594 |
+
string = '\n'.join(str(line) for line in lines if line and line.strip())
|
595 |
+
if self._simple_pbar:
|
596 |
+
string = _progress_bar(self.count, self.length) + '\n' + string
|
597 |
+
if is_notebook():
|
598 |
+
self._writer.update(_StrRepr(string))
|
599 |
+
else:
|
600 |
+
self._writer.write('\n\n' + string)
|
601 |
+
if hasattr(self._writer, 'flush'):
|
602 |
+
self._writer.flush()
|
603 |
+
if step:
|
604 |
+
self.step()
|
605 |
+
|
606 |
+
def step(self):
|
607 |
+
"""
|
608 |
+
Update the progressbar if present.
|
609 |
+
"""
|
610 |
+
self.count += 1
|
611 |
+
if self._progress_bar is not None:
|
612 |
+
self._progress_bar.update()
|
613 |
+
|
614 |
+
def __iter__(self):
|
615 |
+
return self
|
616 |
+
|
617 |
+
def __next__(self):
|
618 |
+
return next(self.rnge)
|
619 |
+
|
620 |
+
def close(self):
|
621 |
+
if hasattr(self._writer, 'close'):
|
622 |
+
can_close = True
|
623 |
+
try:
|
624 |
+
can_close = self._writer != sys.stdout and self._writer != sys.stderr
|
625 |
+
except AttributeError:
|
626 |
+
pass
|
627 |
+
if can_close:
|
628 |
+
self._writer.close()
|
629 |
+
if hasattr(self._progress_bar, 'close'):
|
630 |
+
self._progress_bar.close()
|
631 |
+
|
632 |
+
def __del__(self):
|
633 |
+
self.close()
|
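# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# ProgressWriter prints status lines together with a progress bar (tqdm or a plain fallback):
from stylegan2 import utils

writer = utils.ProgressWriter(length=100)
for step in range(100):
    # ...one unit of work per iteration...
    writer.write('step: {}'.format(step))   # also advances the progress bar
writer.close()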
634 |
+
|
635 |
+
|
636 |
+
class _StrRepr:
|
637 |
+
"""
|
638 |
+
A wrapper for strings that returns the string
|
639 |
+
on repr() calls. Used by notebooks.
|
640 |
+
"""
|
641 |
+
def __init__(self, string):
|
642 |
+
self.string = string
|
643 |
+
|
644 |
+
def __repr__(self):
|
645 |
+
return self.string
|
646 |
+
|
647 |
+
|
648 |
+
#----------------------------------------------------------------------------
|
649 |
+
# image utils
|
650 |
+
|
651 |
+
|
652 |
+
def tensor_to_PIL(image_tensor, pixel_min=-1, pixel_max=1):
|
653 |
+
image_tensor = image_tensor.cpu()
|
654 |
+
if pixel_min != 0 or pixel_max != 1:
|
655 |
+
image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
|
656 |
+
image_tensor.clamp_(min=0, max=1)
|
657 |
+
to_pil = torchvision.transforms.functional.to_pil_image
|
658 |
+
if image_tensor.dim() == 4:
|
659 |
+
return [to_pil(img) for img in image_tensor]
|
660 |
+
return to_pil(image_tensor)
|
661 |
+
|
662 |
+
|
663 |
+
def PIL_to_tensor(image, pixel_min=-1, pixel_max=1):
|
664 |
+
to_tensor = torchvision.transforms.functional.to_tensor
|
665 |
+
if isinstance(image, (list, tuple)):
|
666 |
+
image_tensor = torch.stack([to_tensor(img) for img in image])
|
667 |
+
else:
|
668 |
+
image_tensor = to_tensor(image)
|
669 |
+
if pixel_min != 0 or pixel_max != 1:
|
670 |
+
image_tensor = image_tensor * (pixel_max - pixel_min) + pixel_min
|
671 |
+
return image_tensor
|
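# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable.
# Round trip between the [-1, 1] tensor range used by the models and PIL images:
import torch
from stylegan2 import utils

img_tensor = torch.rand(3, 64, 64) * 2 - 1     # fake image in [-1, 1], CHW layout
pil_img = utils.tensor_to_PIL(img_tensor)      # PIL.Image, values mapped to [0, 255]
restored = utils.PIL_to_tensor(pil_img)        # back to a tensor in roughly [-1, 1]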
672 |
+
|
673 |
+
|
674 |
+
def stack_images_PIL(imgs, shape=None, individual_img_size=None):
|
675 |
+
"""
|
676 |
+
Concatenate multiple images into a grid within a single image.
|
677 |
+
Arguments:
|
678 |
+
imgs (Sequence of PIL.Image): Input images.
|
679 |
+
shape (list, tuple, int, optional): Shape of the grid. Should consist
|
680 |
+
of two values, (width, height). If an integer value is passed it
|
681 |
+
is used for both width and height. If no value is passed the shape
|
682 |
+
is infered from the number of images. Default value is None.
|
683 |
+
individual_img_size (list, tuple, int, optional): The size of the
|
684 |
+
images being concatenated. Default value is None.
|
685 |
+
Returns:
|
686 |
+
canvas (PIL.Image): Image containing input images in a grid.
|
687 |
+
"""
|
688 |
+
assert len(imgs) > 0, 'No images received.'
|
689 |
+
if shape is None:
|
690 |
+
size = int(np.ceil(np.sqrt(len(imgs))))
|
691 |
+
shape = [int(np.ceil(len(imgs) / size)), size]
|
692 |
+
else:
|
693 |
+
if isinstance(shape, numbers.Number):
|
694 |
+
shape = 2 * [shape]
|
695 |
+
assert len(shape) == 2, 'Shape should specify (width, height).'
|
696 |
+
|
697 |
+
if individual_img_size is None:
|
698 |
+
for i in range(len(imgs) - 1):
|
699 |
+
assert imgs[i].size == imgs[i + 1].size, \
|
700 |
+
'Images are of different sizes, please specify a ' + \
|
701 |
+
'size (width, height). Found sizes:\n' + \
|
702 |
+
', '.join(str(img.size) for img in imgs)
|
703 |
+
individual_img_size = imgs[0].size
|
704 |
+
else:
|
705 |
+
if not isinstance(individual_img_size, (tuple, list)):
|
706 |
+
individual_img_size = 2 * (individual_img_size,)
|
707 |
+
individual_img_size = tuple(individual_img_size)
|
708 |
+
for i in range(len(imgs)):
|
709 |
+
if imgs[i].size != individual_img_size:
|
710 |
+
imgs[i] = imgs[i].resize(individual_img_size)
|
711 |
+
|
712 |
+
width, height = individual_img_size
|
713 |
+
width, height = int(width), int(height)
|
714 |
+
canvas = Image.new(
|
715 |
+
'RGB',
|
716 |
+
(shape[0] * width, shape[1] * height),
|
717 |
+
(0, 0, 0, 0)
|
718 |
+
)
|
719 |
+
imgs = imgs.copy()
|
720 |
+
for h_i in range(shape[1]):
|
721 |
+
for w_i in range(shape[0]):
|
722 |
+
if len(imgs) > 0:
|
723 |
+
img = imgs.pop(0).convert('RGB')
|
724 |
+
offset = (w_i * width, h_i * height)
|
725 |
+
canvas.paste(img, offset)
|
726 |
+
return canvas
|
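# Illustrative usage sketch (not part of the file); assumes stylegan2.utils is importable
# and that 'grid.png' is a hypothetical output path.
# Tile several images into one grid image:
import torch
from stylegan2 import utils

imgs = utils.tensor_to_PIL(torch.rand(6, 3, 64, 64))   # list of 6 PIL images
grid = utils.stack_images_PIL(imgs, shape=(3, 2))      # 3 columns, 2 rows
grid.save('grid.png')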