Spaces:
Build error
Build error
File size: 10,416 Bytes
8f87579 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Loss functions."""
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary
#----------------------------------------------------------------------------
# Convenience func that casts all of its arguments to tf.float32.
def fp32(*values):
    """Cast all arguments to tf.float32.

    The values may be given as separate positional arguments or as a single
    tuple. When two or more values are cast, a tuple of the results is
    returned; otherwise the single cast value is returned on its own.
    """
    # A lone tuple argument stands in for the whole argument list.
    unpack_single_tuple = len(values) == 1 and isinstance(values[0], tuple)
    items = values[0] if unpack_single_tuple else values
    casted = tuple(tf.cast(item, tf.float32) for item in items)
    if len(casted) >= 2:
        return casted
    return casted[0]
#----------------------------------------------------------------------------
# WGAN & WGAN-GP loss functions.
def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    """WGAN generator loss: the negated float32 output of D on G's output.

    Latents are drawn with tf.random_normal using the shape
    [minibatch_size] + G.input_shapes[0][1:], and labels come from
    training_set.get_random_labels_tf(minibatch_size). `opt` is not used.
    No reduction over the minibatch is applied here.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    random_labels = training_set.get_random_labels_tf(minibatch_size)
    generated = G.get_output_for(z, random_labels, is_training=True)
    scores_fake = fp32(D.get_output_for(generated, random_labels, is_training=True))
    return -scores_fake
def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument
    wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}.
    """WGAN discriminator loss with an epsilon (drift) penalty.

    The loss is (score on G's output) - (score on reals), plus
    wgan_epsilon times the squared score on reals. Both scores are cast
    to float32 and routed through autosummary under 'Loss/scores/real' and
    'Loss/scores/fake'. `opt` and `training_set` are not used.
    No reduction over the minibatch is applied here.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    critic_loss = scores_fake - scores_real

    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(scores_real))
    return critic_loss + drift * wgan_epsilon
def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \epsilon_{drift}.
    wgan_target     = 1.0):     # Target value for gradient magnitudes.
    """WGAN discriminator loss with gradient penalty and epsilon (drift) penalty.

    Starts from (score on G's output) - (score on reals). It then adds
    (wgan_lambda / wgan_target**2) * (norm - wgan_target)**2, where `norm`
    is the root of the squared gradients (summed over axes 1, 2, 3) of the
    summed D scores with respect to images mixed from reals and G's output.
    Finally it adds wgan_epsilon times the squared score on reals.
    `training_set` is not used. No reduction over the minibatch is applied.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    loss = scores_fake - scores_real

    with tf.name_scope('GradientPenalty'):
        image_dtype = generated.dtype
        # One uniform mixing coefficient in [0, 1) per minibatch entry.
        alpha = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=image_dtype)
        reals_cast = tf.cast(reals, image_dtype)
        # Presumably a linear interpolation between reals and generated images -- confirm in tflib.
        blended = tflib.lerp(reals_cast, generated, alpha)
        scores_mixed = fp32(D.get_output_for(blended, labels, is_training=True))
        scores_mixed = autosummary('Loss/scores/mixed', scores_mixed)
        # The summed score goes through opt's loss scaling before differentiation
        # and the gradients have that scaling undone afterwards.
        scaled_sum = opt.apply_loss_scaling(tf.reduce_sum(scores_mixed))
        raw_grads = tf.gradients(scaled_sum, [blended])[0]
        grads = opt.undo_loss_scaling(fp32(raw_grads))
        squared_sum = tf.reduce_sum(tf.square(grads), axis=[1,2,3])
        grad_norms = autosummary('Loss/mixed_norms', tf.sqrt(squared_sum))
        gradient_penalty = tf.square(grad_norms - wgan_target)
    # NOTE(review): gradient_penalty has one entry per minibatch item; if D's scores
    # are shaped [N, 1] this addition broadcasts to [N, N] -- confirm D's output shape.
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(scores_real))
    loss += drift * wgan_epsilon
    return loss
#----------------------------------------------------------------------------
# Hinge loss functions. (Use G_wgan with these)
def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument
    """Hinge discriminator loss: max(0, 1 + fake_score) + max(0, 1 - real_score).

    Scores are D's float32 outputs for G's output and for `reals`, each
    routed through autosummary under 'Loss/scores/fake' and
    'Loss/scores/real'. `opt` and `training_set` are not used.
    No reduction over the minibatch is applied here.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    fake_term = tf.maximum(0., 1.+scores_fake)
    real_term = tf.maximum(0., 1.-scores_real)
    return fake_term + real_term
def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument
    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.
    wgan_target     = 1.0):     # Target value for gradient magnitudes.
    """Hinge discriminator loss with a gradient penalty.

    Starts from max(0, 1 + fake_score) + max(0, 1 - real_score). It then
    adds (wgan_lambda / wgan_target**2) * (norm - wgan_target)**2, where
    `norm` is the root of the squared gradients (summed over axes 1, 2, 3)
    of the summed D scores with respect to images mixed from reals and G's
    output. `training_set` is not used. No reduction over the minibatch is
    applied here.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    fake_term = tf.maximum(0., 1.+scores_fake)
    real_term = tf.maximum(0., 1.-scores_real)
    loss = fake_term + real_term

    with tf.name_scope('GradientPenalty'):
        image_dtype = generated.dtype
        # One uniform mixing coefficient in [0, 1) per minibatch entry.
        alpha = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=image_dtype)
        reals_cast = tf.cast(reals, image_dtype)
        # Presumably a linear interpolation between reals and generated images -- confirm in tflib.
        blended = tflib.lerp(reals_cast, generated, alpha)
        scores_mixed = fp32(D.get_output_for(blended, labels, is_training=True))
        scores_mixed = autosummary('Loss/scores/mixed', scores_mixed)
        # The summed score goes through opt's loss scaling before differentiation
        # and the gradients have that scaling undone afterwards.
        scaled_sum = opt.apply_loss_scaling(tf.reduce_sum(scores_mixed))
        raw_grads = tf.gradients(scaled_sum, [blended])[0]
        grads = opt.undo_loss_scaling(fp32(raw_grads))
        squared_sum = tf.reduce_sum(tf.square(grads), axis=[1,2,3])
        grad_norms = autosummary('Loss/mixed_norms', tf.sqrt(squared_sum))
        gradient_penalty = tf.square(grad_norms - wgan_target)
    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))
    return loss
#----------------------------------------------------------------------------
# Loss functions advocated by the paper
# "Which Training Methods for GANs do actually Converge?"
def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    """Saturating logistic generator loss: -softplus(fake_score).

    This equals log(1 - logistic(fake_score)), where fake_score is D's
    float32 output for G's output on normally distributed latents and labels
    from training_set.get_random_labels_tf. `opt` is not used.
    No reduction over the minibatch is applied here.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    random_labels = training_set.get_random_labels_tf(minibatch_size)
    generated = G.get_output_for(z, random_labels, is_training=True)
    scores_fake = fp32(D.get_output_for(generated, random_labels, is_training=True))
    return -tf.nn.softplus(scores_fake)  # log(1 - logistic(scores_fake))
def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    """Non-saturating logistic generator loss: softplus(-fake_score).

    This equals -log(logistic(fake_score)), where fake_score is D's float32
    output for G's output on normally distributed latents and labels from
    training_set.get_random_labels_tf. `opt` is not used.
    No reduction over the minibatch is applied here.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    random_labels = training_set.get_random_labels_tf(minibatch_size)
    generated = G.get_output_for(z, random_labels, is_training=True)
    scores_fake = fp32(D.get_output_for(generated, random_labels, is_training=True))
    return tf.nn.softplus(-scores_fake)  # -log(logistic(scores_fake))
def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument
    """Logistic discriminator loss: softplus(fake_score) + softplus(-real_score).

    Scores are D's float32 outputs for G's output and for `reals`, each
    routed through autosummary under 'Loss/scores/fake' and
    'Loss/scores/real'. `opt` and `training_set` are not used.
    No reduction over the minibatch is applied here.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    fake_term = tf.nn.softplus(scores_fake)  # -log(1 - logistic(scores_fake))
    real_term = tf.nn.softplus(-scores_real)  # -log(logistic(scores_real)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type
    return fake_term + real_term
def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument
    """Logistic discriminator loss with optional R1 and R2 gradient penalties.

    Starts from softplus(fake_score) + softplus(-real_score).
    When r1_gamma is non-zero, adds 0.5 * r1_gamma times the squared
    gradients (summed over axes 1, 2, 3) of the summed real scores with
    respect to `reals`. When r2_gamma is non-zero, adds the same kind of
    term for the summed fake scores with respect to G's output, weighted
    by 0.5 * r2_gamma. `training_set` is not used.
    No reduction over the minibatch is applied here.
    """
    z = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    generated = G.get_output_for(z, labels, is_training=True)
    scores_real = fp32(D.get_output_for(reals, labels, is_training=True))
    scores_fake = fp32(D.get_output_for(generated, labels, is_training=True))
    scores_real = autosummary('Loss/scores/real', scores_real)
    scores_fake = autosummary('Loss/scores/fake', scores_fake)
    fake_term = tf.nn.softplus(scores_fake)  # -log(1 - logistic(scores_fake))
    real_term = tf.nn.softplus(-scores_real)  # -log(logistic(scores_real)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type
    loss = fake_term + real_term

    if r1_gamma != 0.0:
        with tf.name_scope('R1Penalty'):
            # The summed score goes through opt's loss scaling before differentiation
            # and the gradients have that scaling undone afterwards.
            scaled_real_sum = opt.apply_loss_scaling(tf.reduce_sum(scores_real))
            raw_real_grads = tf.gradients(scaled_real_sum, [reals])[0]
            grads_wrt_reals = opt.undo_loss_scaling(fp32(raw_real_grads))
            r1_term = tf.reduce_sum(tf.square(grads_wrt_reals), axis=[1,2,3])
            r1_term = autosummary('Loss/r1_penalty', r1_term)
        loss += r1_term * (r1_gamma * 0.5)

    if r2_gamma != 0.0:
        with tf.name_scope('R2Penalty'):
            scaled_fake_sum = opt.apply_loss_scaling(tf.reduce_sum(scores_fake))
            raw_fake_grads = tf.gradients(scaled_fake_sum, [generated])[0]
            grads_wrt_fakes = opt.undo_loss_scaling(fp32(raw_fake_grads))
            r2_term = tf.reduce_sum(tf.square(grads_wrt_fakes), axis=[1,2,3])
            r2_term = autosummary('Loss/r2_penalty', r2_term)
        loss += r2_term * (r2_gamma * 0.5)
    return loss
#----------------------------------------------------------------------------
|