Spaces:
Build error
Build error
File size: 17,575 Bytes
8f87579 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 |
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Network architectures used in the ProGAN paper."""
import numpy as np
import tensorflow as tf
# NOTE: Do not import any application-specific modules here!
# Specify all network parameters as kwargs.
#----------------------------------------------------------------------------
def lerp(a, b, t):
    """Linearly interpolate between `a` (at t=0) and `b` (at t=1).

    `t` is not clamped, so values outside [0, 1] extrapolate.
    """
    delta = b - a
    return a + delta * t
def lerp_clip(a, b, t):
    """Interpolate between `a` and `b` with `t` clamped to [0, 1].

    The clamp is done with tf.clip_by_value, so the result is `a` for
    t <= 0 and `b` for t >= 1.
    """
    clamped_t = tf.clip_by_value(t, 0.0, 1.0)
    return a + (b - a) * clamped_t
def cset(cur_lambda, new_cond, new_lambda):
    """Wrap two graph-building thunks into one conditional thunk.

    Returns a zero-argument callable. When called, it builds
    tf.cond(new_cond, new_lambda, cur_lambda): `new_lambda` is the
    branch taken when `new_cond` is true, `cur_lambda` otherwise.
    Nothing is added to the graph until the returned callable is invoked.
    """
    def conditional():
        return tf.cond(new_cond, new_lambda, cur_lambda)
    return conditional
#----------------------------------------------------------------------------
# Get/create weight tensor for a convolutional or fully-connected layer.
def get_weight(shape, gain=np.sqrt(2), use_wscale=False):
    """Get or create the 'weight' variable for a conv or dense layer.

    shape:      [kernel, kernel, fmaps_in, fmaps_out] or [in, out].
    gain:       Numerator of the He-init standard deviation.
    use_wscale: If True, the variable uses the default random_normal
                initializer and is multiplied by a constant 'wscale'
                equal to the He std; otherwise the He std is passed to
                the initializer directly.
    """
    # Fan-in is the product of every dimension except the output one.
    he_std = gain / np.sqrt(np.prod(shape[:-1]))

    if not use_wscale:
        return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, he_std))

    # Equalized learning rate: apply the scale in the graph, not in the initializer.
    runtime_scale = tf.constant(np.float32(he_std), name='wscale')
    unscaled = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal())
    return unscaled * runtime_scale
#----------------------------------------------------------------------------
# Fully-connected layer.
def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False):
    """Fully-connected layer without bias: returns matmul(x, weight).

    Inputs with more than two dimensions are first flattened to
    [-1, prod(remaining static dims)]. The weight comes from get_weight
    and is cast to the dtype of `x`.
    """
    rank = len(x.shape)
    if rank > 2:
        flat_size = np.prod([dim.value for dim in x.shape[1:]])
        x = tf.reshape(x, [-1, flat_size])
    in_features = x.shape[1].value
    weight = get_weight([in_features, fmaps], gain=gain, use_wscale=use_wscale)
    weight = tf.cast(weight, x.dtype)
    return tf.matmul(x, weight)
#----------------------------------------------------------------------------
# Convolutional layer.
def conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
    """Stride-1, SAME-padded NCHW convolution without bias.

    kernel must be a positive odd integer. The number of input feature
    maps is read from the static shape of `x` (axis 1).
    """
    assert kernel >= 1 and kernel % 2 == 1
    in_fmaps = x.shape[1].value
    filt = get_weight([kernel, kernel, in_fmaps, fmaps], gain=gain, use_wscale=use_wscale)
    filt = tf.cast(filt, x.dtype)
    return tf.nn.conv2d(x, filt, strides=[1, 1, 1, 1], padding='SAME', data_format='NCHW')
#----------------------------------------------------------------------------
# Apply bias to the given activation tensor.
def apply_bias(x):
    """Add a zero-initialized 'bias' variable along axis 1 of `x`.

    For 2-D input the bias is added directly; for any other rank it is
    reshaped to [1, -1, 1, 1] so that it broadcasts over an NCHW tensor.
    """
    bias = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros())
    bias = tf.cast(bias, x.dtype)
    if len(x.shape) != 2:
        bias = tf.reshape(bias, [1, -1, 1, 1])
    return x + bias
#----------------------------------------------------------------------------
# Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16.
def leaky_relu(x, alpha=0.2):
    """Leaky ReLU computed as max(x * alpha, x).

    The slope constant is created in the dtype of `x`, which is what
    lets this work for FP16 inputs.
    """
    with tf.name_scope('LeakyRelu'):
        slope = tf.constant(alpha, dtype=x.dtype, name='alpha')
        return tf.maximum(x * slope, x)
#----------------------------------------------------------------------------
# Nearest-neighbor upscaling layer.
def upscale2d(x, factor=2):
    """Nearest-neighbor upscaling of an NCHW tensor by an integer factor.

    factor == 1 returns `x` unchanged without adding any ops.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    with tf.variable_scope('Upscale2D'):
        shape = x.shape
        channels, height, width = shape[1], shape[2], shape[3]
        # Insert unit axes after H and W, tile them, then fold them back in.
        expanded = tf.reshape(x, [-1, channels, height, 1, width, 1])
        tiled = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
        return tf.reshape(tiled, [-1, channels, height * factor, width * factor])
#----------------------------------------------------------------------------
# Fused upscale2d + conv2d.
# Faster and uses less memory than performing the operations separately.
def upscale2d_conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
    """Fused 2x upscale + convolution, built as a stride-2 transposed conv.

    kernel must be a positive odd integer. The filter is zero-padded by
    one on each spatial side and the four one-pixel-shifted copies are
    summed, which grows the spatial kernel size by one.
    """
    assert kernel >= 1 and kernel % 2 == 1
    in_fmaps = x.shape[1].value
    filt = get_weight([kernel, kernel, in_fmaps, fmaps], gain=gain, use_wscale=use_wscale)
    filt = tf.transpose(filt, [0, 1, 3, 2])  # [kernel, kernel, fmaps_out, fmaps_in]
    filt = tf.pad(filt, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT')
    shifted = [filt[1:, 1:], filt[:-1, 1:], filt[1:, :-1], filt[:-1, :-1]]
    filt = tf.add_n(shifted)
    filt = tf.cast(filt, x.dtype)
    out_shape = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2]
    return tf.nn.conv2d_transpose(x, filt, out_shape, strides=[1, 1, 2, 2], padding='SAME', data_format='NCHW')
#----------------------------------------------------------------------------
# Box filter downscaling layer.
def downscale2d(x, factor=2):
    """Box-filter downscaling of an NCHW tensor by an integer factor.

    Implemented as non-overlapping average pooling with window and
    stride both equal to `factor`. factor == 1 returns `x` unchanged.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    with tf.variable_scope('Downscale2D'):
        window = [1, 1, factor, factor]
        # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True
        return tf.nn.avg_pool(x, ksize=window, strides=window, padding='VALID', data_format='NCHW')
#----------------------------------------------------------------------------
# Fused conv2d + downscale2d.
# Faster and uses less memory than performing the operations separately.
def conv2d_downscale2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
    """Fused convolution + 2x downscale, built as a stride-2 convolution.

    kernel must be a positive odd integer. The filter is zero-padded by
    one on each spatial side and the four one-pixel-shifted copies are
    averaged (sum * 0.25) before the strided convolution.
    """
    assert kernel >= 1 and kernel % 2 == 1
    in_fmaps = x.shape[1].value
    filt = get_weight([kernel, kernel, in_fmaps, fmaps], gain=gain, use_wscale=use_wscale)
    filt = tf.pad(filt, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT')
    shifted = [filt[1:, 1:], filt[:-1, 1:], filt[1:, :-1], filt[:-1, :-1]]
    filt = tf.add_n(shifted) * 0.25
    filt = tf.cast(filt, x.dtype)
    return tf.nn.conv2d(x, filt, strides=[1, 1, 2, 2], padding='SAME', data_format='NCHW')
#----------------------------------------------------------------------------
# Pixelwise feature vector normalization.
def pixel_norm(x, epsilon=1e-8):
    """Normalize each feature vector along axis 1.

    Computes x * rsqrt(mean(x^2 over axis 1) + epsilon).
    """
    with tf.variable_scope('PixelNorm'):
        mean_square = tf.reduce_mean(tf.square(x), axis=1, keepdims=True)
        return x * tf.rsqrt(mean_square + epsilon)
#----------------------------------------------------------------------------
# Minibatch standard deviation.
def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
    """Append minibatch standard-deviation statistics as extra feature maps.

    The batch is split into groups of `group_size` samples (capped at the
    batch size) and the channels into `num_new_features` channel groups.
    The per-group stddev is averaged over channels and pixels, then
    broadcast back and concatenated to `x` along axis 1.
    """
    with tf.variable_scope('MinibatchStddev'):
        # Minibatch must be divisible by (or smaller than) group_size.
        group = tf.minimum(group_size, tf.shape(x)[0])
        in_shape = x.shape                                                  # [NCHW]  Input shape.
        channels_per_feature = in_shape[1] // num_new_features
        # [GMncHW] Split minibatch into M groups of size G and channels into n groups of c.
        grouped = tf.reshape(x, [group, -1, num_new_features, channels_per_feature, in_shape[2], in_shape[3]])
        grouped = tf.cast(grouped, tf.float32)                              # [GMncHW] Compute statistics in FP32.
        centered = grouped - tf.reduce_mean(grouped, axis=0, keepdims=True) # [GMncHW] Subtract mean over group.
        variance = tf.reduce_mean(tf.square(centered), axis=0)              # [MncHW]  Variance over group.
        stddev = tf.sqrt(variance + 1e-8)                                   # [MncHW]  Stddev over group.
        stat = tf.reduce_mean(stddev, axis=[2, 3, 4], keepdims=True)        # [Mn111]  Average over fmaps and pixels.
        stat = tf.reduce_mean(stat, axis=[2])                               # [Mn11]   Drop the singleton channel axis.
        stat = tf.cast(stat, x.dtype)                                       # [Mn11]   Cast back to original data type.
        stat = tf.tile(stat, [group, 1, in_shape[2], in_shape[3]])          # [NnHW]   Replicate over group and pixels.
        return tf.concat([x, stat], axis=1)                                 # [NCHW]   Append as new fmaps.
#----------------------------------------------------------------------------
# Networks used in the ProgressiveGAN paper.
def G_paper(
    latents_in,                         # First input: Latent vectors [minibatch, latent_size].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 1,            # Number of output color channels. Overridden based on dataset.
    resolution          = 32,           # Output resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    latent_size         = None,         # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max).
    normalize_latents   = True,         # Normalize latent vectors before feeding them to the network?
    use_wscale          = True,         # Enable equalized learning rate?
    use_pixelnorm       = True,         # Enable pixelwise feature vector normalization?
    pixelnorm_epsilon   = 1e-8,         # Constant epsilon for pixelwise feature vector normalization.
    use_leakyrelu       = True,         # True = leaky ReLU, False = ReLU.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    fused_scale         = True,         # True = use fused upscale2d + conv2d, False = separate upscale2d layers.
    structure           = None,         # 'linear' = human-readable, 'recursive' = efficient, None = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.
    """Build the progressively-grown generator graph.

    Latents and labels are concatenated and passed through one block per
    resolution from 4x4 up to `resolution`. Each block has a ToRGB layer,
    and a non-trainable 'lod' variable selects and blends which
    resolution contributes to the output.

    Returns the output tensor, wrapped in tf.identity with the name
    'images_out' and with dtype equal to `dtype`.

    Raises:
        AssertionError: if `resolution` is not a power of two >= 4.
        ValueError: if `structure` is not None, 'linear' or 'recursive'.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
    def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x
    if latent_size is None: latent_size = nf(0)
    if structure is None: structure = 'linear' if is_template_graph else 'recursive'
    # Fail early, before any variables are created: an unknown structure would
    # otherwise leave images_out as None and crash on `.dtype` below.
    if structure not in ('linear', 'recursive'):
        raise ValueError("structure must be 'linear', 'recursive' or None, got %r" % (structure,))
    act = leaky_relu if use_leakyrelu else tf.nn.relu

    latents_in.set_shape([None, latent_size])
    labels_in.set_shape([None, label_size])
    combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype)
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
    images_out = None

    # Building blocks.
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if res == 2: # 4x4
                if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon)
                with tf.variable_scope('Dense'):
                    x = dense(x, fmaps=nf(res-1)*16, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation
                    x = tf.reshape(x, [-1, nf(res-1), 4, 4])
                    x = PN(act(apply_bias(x)))
                with tf.variable_scope('Conv'):
                    x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
            else: # 8x8 and up
                if fused_scale:
                    with tf.variable_scope('Conv0_up'):
                        x = PN(act(apply_bias(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
                else:
                    x = upscale2d(x)
                    with tf.variable_scope('Conv0'):
                        x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
                with tf.variable_scope('Conv1'):
                    x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))
            return x
    def torgb(x, res): # res = 2..resolution_log2
        lod = resolution_log2 - res
        with tf.variable_scope('ToRGB_lod%d' % lod):
            return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale))

    # Linear structure: simple but inefficient.
    if structure == 'linear':
        x = block(combo_in, 2)
        images_out = torgb(x, 2)
        for res in range(3, resolution_log2 + 1):
            lod = resolution_log2 - res
            x = block(x, res)
            img = torgb(x, res)
            images_out = upscale2d(images_out)
            with tf.variable_scope('Grow_lod%d' % lod):
                images_out = lerp_clip(img, images_out, lod_in - lod)

    # Recursive structure: complex but efficient.
    if structure == 'recursive':
        def grow(x, res, lod):
            y = block(x, res)
            # Default: output this resolution, upscaled to the full size.
            img = lambda: upscale2d(torgb(y, res), 2**lod)
            # While lod_in > lod: blend with the upscaled previous resolution.
            if res > 2: img = cset(img, (lod_in > lod), lambda: upscale2d(lerp(torgb(y, res), upscale2d(torgb(x, res - 1)), lod_in - lod), 2**lod))
            # While lod_in < lod: defer to the next, higher-resolution block.
            if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))
            return img()
        images_out = grow(combo_in, 2, resolution_log2 - 2)

    assert images_out.dtype == tf.as_dtype(dtype)
    images_out = tf.identity(images_out, name='images_out')
    return images_out
def D_paper(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 1,            # Number of input color channels. Overridden based on dataset.
    resolution          = 32,           # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    use_wscale          = True,         # Enable equalized learning rate?
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    fused_scale         = True,         # True = use fused conv2d + downscale2d, False = separate downscale2d layers.
    structure           = None,         # 'linear' = human-readable, 'recursive' = efficient, None = select automatically
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.
    """Build the progressively-grown discriminator graph.

    Images enter through a FromRGB layer at the resolution selected by
    the non-trainable 'lod' variable and pass through one block per
    resolution down to 4x4. The 4x4 block ends in a dense layer with a
    single output.

    `labels_in` has its shape set and is cast to `dtype`, but is not used
    afterwards in this function.

    Returns the output tensor, wrapped in tf.identity with the name
    'scores_out' and with dtype equal to `dtype`.

    Raises:
        AssertionError: if `resolution` is not a power of two >= 4.
        ValueError: if `structure` is not None, 'linear' or 'recursive'.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
    if structure is None: structure = 'linear' if is_template_graph else 'recursive'
    # Fail early, before any variables are created: an unknown structure would
    # otherwise leave scores_out as None and crash on `.dtype` below.
    if structure not in ('linear', 'recursive'):
        raise ValueError("structure must be 'linear', 'recursive' or None, got %r" % (structure,))
    act = leaky_relu

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
    scores_out = None

    # Building blocks.
    def fromrgb(x, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
            return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale)))
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if res >= 3: # 8x8 and up
                with tf.variable_scope('Conv0'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))
                if fused_scale:
                    with tf.variable_scope('Conv1_down'):
                        x = act(apply_bias(conv2d_downscale2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))
                else:
                    with tf.variable_scope('Conv1'):
                        x = act(apply_bias(conv2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))
                    x = downscale2d(x)
            else: # 4x4
                if mbstd_group_size > 1:
                    x = minibatch_stddev_layer(x, mbstd_group_size)
                with tf.variable_scope('Conv'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))
                with tf.variable_scope('Dense0'):
                    x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale)))
                with tf.variable_scope('Dense1'):
                    x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale))
            return x

    # Linear structure: simple but inefficient.
    if structure == 'linear':
        img = images_in
        x = fromrgb(img, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            lod = resolution_log2 - res
            x = block(x, res)
            img = downscale2d(img)
            y = fromrgb(img, res - 1)
            with tf.variable_scope('Grow_lod%d' % lod):
                x = lerp_clip(x, y, lod_in - lod)
        scores_out = block(x, 2)

    # Recursive structure: complex but efficient.
    if structure == 'recursive':
        def grow(res, lod):
            # Default: enter the network at this resolution.
            x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)
            # While lod_in < lod: take the output of the higher-resolution block instead.
            if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
            x = block(x(), res)
            y = lambda: x
            # While lod_in > lod: blend with the FromRGB of the next lower resolution.
            if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))
            return y()
        scores_out = grow(2, resolution_log2 - 2)

    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
#----------------------------------------------------------------------------
|