Spaces:
Runtime error
Runtime error
File size: 18,782 Bytes
a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae bb0f5a9 a0bcaae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
# Copyright (c) SenseTime Research. All rights reserved.
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html
"""Helper wrapper for a Tensorflow optimizer."""
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import List, Union
from . import autosummary
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
try:
    # TensorFlow 1.13: NCCL ops live under tensorflow.python.ops.
    from tensorflow.python.ops import nccl_ops
except ImportError:
    # Older TensorFlow versions ship them in contrib. Only a missing module
    # should trigger this fallback; other errors (and KeyboardInterrupt /
    # SystemExit, which a bare `except:` would also swallow) must propagate.
    import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
    """A wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Intended call order (enforced by assertions): register_gradients() once per
    device, then apply_updates() exactly once to build the combined training op.
    No gradients can be registered after apply_updates() has been called.
    """

    def __init__(self,
                 # Name string that will appear in TensorFlow graph.
                 name: str = "Train",
                 # Underlying optimizer class, given as a dotted name resolved by util.get_obj_by_name().
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 # Learning rate. Can vary over time.
                 learning_rate: TfExpressionEx = 0.001,
                 # Treat N consecutive minibatches as one by accumulating gradients.
                 minibatch_multiplier: TfExpressionEx = None,
                 # Share internal state with a previously created optimizer?
                 share: "Optimizer" = None,
                 # Enable dynamic loss scaling for robust mixed-precision training?
                 use_loss_scaling: bool = False,
                 # Log2 of initial loss scaling factor.
                 loss_scaling_init: float = 64.0,
                 # Log2 of per-minibatch loss scaling increment when there is no overflow.
                 loss_scaling_inc: float = 0.0005,
                 # Log2 of per-minibatch loss scaling decrement when there is an overflow.
                 loss_scaling_dec: float = 1.0,
                 # Report fine-grained memory usage statistics in TensorBoard?
                 report_mem_usage: bool = False,
                 # Remaining keyword arguments are forwarded to the optimizer class constructor.
                 **kwargs):
        """Store configuration; per-device TensorFlow state is created lazily by _get_device().

        If `share` is given, it must use the same optimizer class, the identical
        learning_rate object (checked with `is`), and equal extra kwargs. The
        per-device optimizer table then becomes the same dict object in both
        wrappers, so the underlying optimizer instances (and their state) are shared.
        """
        # Public fields.
        self.name = name
        self.learning_rate = learning_rate
        self.minibatch_multiplier = minibatch_multiplier
        # "/" replaced by "." -- presumably so the id forms a single name component in scopes/summaries.
        self.id = self.name.replace("/", ".")
        # Graph-unique scope name derived from the id.
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        # Private fields.
        self._updates_applied = False  # Set once apply_updates() has run.
        self._devices = OrderedDict()  # device_name => EasyDict()
        self._shared_optimizers = OrderedDict()  # device_name => optimizer_class instance
        self._gradient_shapes = None  # [shape, ...]
        self._report_mem_usage = report_mem_usage
        # Validate arguments.
        assert callable(self.optimizer_class)
        # Share internal state if requested.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers  # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device, creating it on first use.

        On first use this creates the device's optimizer instance (or reuses the
        one already in the shared table) and, if loss scaling is enabled, a
        non-trainable loss-scaling variable. Returns a util.EasyDict.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]
        # Initialize fields.
        device = util.EasyDict()
        device.name = device_name
        device.optimizer = None  # Underlying optimizer: optimizer_class instance
        device.loss_scaling_var = None  # Log2 of loss scaling: tf.Variable (stays None if disabled)
        # Raw gradients: var => [grad, ...] (one entry per register_gradients() call; may be None)
        device.grad_raw = OrderedDict()
        device.grad_clean = OrderedDict()  # Clean gradients: var => grad
        # Accumulation sums: var => tf.Variable
        device.grad_acc_vars = OrderedDict()
        device.grad_acc_count = None  # Accumulation counter: tf.Variable
        device.grad_acc = OrderedDict()  # Accumulated gradients: var => grad
        # Setup TensorFlow objects.
        # control_dependencies(None) creates them outside any enclosing control dependencies.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace(
                    "/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(
                    name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(
                    self.loss_scaling_init), trainable=False, name="loss_scaling_var")
        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.

        Intended to be called once per GPU, before apply_updates(). `loss` is cast
        to float32 and multiplied by the device's loss-scaling factor (when
        enabled) before differentiation. `trainable_vars` may be a list or a dict
        (its values are used). Every variable must be placed on the same device
        as `loss`, and the shapes (in order) must match those from the first call.
        Gradients are only recorded here; they are combined in apply_updates().
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)
        # Validate trainables.
        if isinstance(trainable_vars, dict):
            # allow passing in Network.trainables as vars
            trainable_vars = list(trainable_vars.values())
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr)
                   for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)
        # Validate shapes.
        # The first call fixes the expected shapes; later calls (other devices) must match.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list()
                                     for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var,
                   var_shape in zip(trainable_vars, self._gradient_shapes))
        # Report memory usage if requested.
        # Attempted only once: the flag is cleared before trying.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(
                        self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                # Best effort: skip the statistic (presumably the memory_stats op is unavailable).
                pass
        # Compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(
                loss=loss, var_list=trainable_vars, gate_gradients=gate)
        # Register gradients.
        # grad may be None for variables disconnected from loss; apply_updates() filters those out.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once. For each registered device:
        1. Clean: drop None gradients, cast to float32, sum them, then scale by
           1 / (#registrations * #devices [* minibatch_multiplier]) and undo loss scaling.
        2. If several devices are registered, sum gradients across them with NCCL.
        3. If minibatch_multiplier is set, accumulate gradients over that many calls.
        4. Apply gradients only on a completed step whose accumulated gradients
           are all finite; adjust the loss-scaling variable on completed steps.
        Also reports statistics and runs initializers (via tfutil.run) for the
        optimizer state and the variables created here.

        Args:
            allow_no_op: if True and no device state exists yet (e.g.
                register_gradients() was never called), return a plain no-op.

        Returns:
            A grouped op named "TrainingOp" under this optimizer's scope.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []
        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')
        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():
                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]
                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        # Single gradient => use as is.
                        grad = grad[0]
                    else:
                        # Multiple gradients => sum.
                        grad = tf.add_n(grad)
                    # Scale as needed.
                    # The divisor counts every registration of `var` on this device,
                    # including those whose gradient was None and was filtered out above.
                    scale = 1.0 / \
                        len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale
        # Sum gradients across devices.
        if len(self._devices) > 1:
            # tf.device(None) ignores the enclosing device scope for the NCCL ops.
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                # Variables are paired across devices by position, so every device
                # must have registered its variables in the same order.
                for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                    # NCCL does not support zero-sized tensors.
                    if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()):
                        all_grads = [device.grad_clean[var] for device, var in zip(
                            self._devices.values(), all_vars)]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                            device.grad_clean[var] = grad
        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # The closures below capture loop variables; this is safe because
                # tf.cond invokes them immediately while building the graph.
                # pylint: disable=cell-var-from-loop
                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every call is a completed step.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    # control_dependencies(None) creates them outside any enclosing control dependencies.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(
                                tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(
                            tf.zeros([]), trainable=False, name="grad_acc_count")
                    # Track counter.
                    # acc_ok becomes True once the counter reaches minibatch_multiplier;
                    # the counter is then reset to zero, otherwise it is incremented.
                    count_cur = device.grad_acc_count + 1.0
                    def count_inc_op(): return tf.assign(device.grad_acc_count, count_cur)
                    def count_reset_op(): return tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(
                        self.minibatch_multiplier, tf.float32))
                    all_ops.append(
                        tf.cond(acc_ok, count_reset_op, count_inc_op))
                    # Track gradients.
                    # grad_acc is the running sum including this step; the accumulator
                    # variable is zeroed on a completed step, otherwise it stores the sum.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            def acc_inc_op(): return tf.assign(acc_var, acc_cur)
                            def acc_reset_op(): return tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(
                                tf.cond(acc_ok, acc_reset_op, acc_inc_op))
                # No overflow => apply gradients.
                # all_ok requires a completed step and finite values in every accumulated gradient.
                all_ok = tf.reduce_all(tf.stack(
                    [acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                def apply_op(): return device.optimizer.apply_gradients(
                    [(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
                # Adjust loss scaling.
                # Only on completed steps: raise the log2 scale when gradients are finite,
                # lower it on overflow.
                if self.use_loss_scaling:
                    def ls_inc_op(): return tf.assign_add(
                        device.loss_scaling_var, self.loss_scaling_inc)
                    def ls_dec_op(): return tf.assign_sub(
                        device.loss_scaling_var, self.loss_scaling_dec)
                    def ls_update_op(): return tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(
                        self.id + "/learning_rate", self.learning_rate))
                    # 1 when the update was skipped, 0 when applied; condition=acc_ok
                    # presumably restricts recording to completed steps (see autosummary).
                    all_ops.append(autosummary.autosummary(
                        self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(
                            self.id + "/loss_scaling_log2", device.loss_scaling_var))
        # Initialize variables.
        # Runs initializers now via tfutil.run, which re-initializes the state of every
        # registered device's optimizer, including optimizers shared with other wrappers.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars(
                [device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(
                device.grad_acc_vars.values()) + [device.grad_acc_count]])
        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs (via tfutil.run) the initializer of every variable reported by
        `optimizer.variables()` for each device registered so far.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values()
                    for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Creates the device's state via _get_device() if it does not exist yet.
        Returns None when loss scaling is disabled.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Multiplies `value` by tfutil.exp2 of the log2 scale variable for value's
        device; returns `value` unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Multiplies `value` by tfutil.exp2 of the negated log2 scale variable for
        value's device; returns `value` unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
class SimpleAdam:
    """Minimal re-implementation of tf.train.AdamOptimizer.

    Exposes only what dnnlib.tflib.Optimizer uses -- variables(),
    compute_gradients() and apply_gradients() -- and is meant to behave
    identically to tf.train.AdamOptimizer in that setting.
    """

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.name, self.learning_rate = name, learning_rate
        self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon
        # Every state variable created by apply_gradients(), across all calls.
        self.all_state_vars = []

    def variables(self):
        """Return the live (not copied) list of state variables created so far."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Pair each variable in `var_list` with the gradient of `loss` w.r.t. it.

        Only GATE_NONE gating is supported; anything else fails the assertion.
        Returns a list of (gradient, variable) tuples.
        """
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return [(g, v) for g, v in zip(gradients, var_list)]

    def apply_gradients(self, grads_and_vars):
        """Build one Adam step for every (gradient, variable) pair.

        Each call creates a fresh set of state variables (the beta1/beta2 power
        accumulators plus per-variable first and second moments), appends them
        to all_state_vars, and returns all update ops grouped into a single op.
        """
        with tf.name_scope(self.name):
            new_vars = []
            ops = []

            # Bias correction: track beta1^t and beta2^t in scalar variables.
            with tf.control_dependencies(None):
                beta1_pow = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                beta2_pow = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
            new_vars.extend((beta1_pow, beta2_pow))
            beta1_pow_next = beta1_pow * self.beta1
            beta2_pow_next = beta2_pow * self.beta2
            ops.extend((tf.assign(beta1_pow, beta1_pow_next), tf.assign(beta2_pow, beta2_pow_next)))
            # Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
            step_size = self.learning_rate * tf.sqrt(1 - beta2_pow_next) / (1 - beta1_pow_next)

            # Moment updates and parameter step, one variable at a time.
            for gradient, variable in grads_and_vars:
                with tf.control_dependencies(None):
                    first_moment = tf.Variable(
                        dtype=tf.float32, initial_value=tf.zeros_like(variable), trainable=False)
                    second_moment = tf.Variable(
                        dtype=tf.float32, initial_value=tf.zeros_like(variable), trainable=False)
                new_vars.extend((first_moment, second_moment))
                first_next = self.beta1 * first_moment + (1 - self.beta1) * gradient
                second_next = self.beta2 * second_moment + (1 - self.beta2) * tf.square(gradient)
                delta = step_size * first_next / (tf.sqrt(second_next) + self.epsilon)
                ops.extend((
                    tf.assign(first_moment, first_next),
                    tf.assign(second_moment, second_next),
                    tf.assign_sub(variable, delta),
                ))

            # Publish the new state variables and bundle every update together.
            self.all_state_vars.extend(new_vars)
            return tf.group(*ops)
|