|
""" PyTorch MADGRAD optimizer |
|
|
|
MADGRAD: https://arxiv.org/abs/2101.11075 |
|
|
|
Code from: https://github.com/facebookresearch/madgrad |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from typing import TYPE_CHECKING, Any, Callable, Optional |
|
|
|
import torch |
|
import torch.optim |
|
|
|
if TYPE_CHECKING: |
|
from torch.optim.optimizer import _params_t |
|
else: |
|
_params_t = Any |
|
|
|
|
|
class MADGRAD(torch.optim.Optimizer): |
|
""" |
|
MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic |
|
Optimization. |
|
|
|
.. _MADGRAD: https://arxiv.org/abs/2101.11075 |
|
|
|
    MADGRAD is a general-purpose optimizer that can be used in place of SGD or

    Adam, and may converge faster and generalize better. Currently GPU-only.
|
Typically, the same learning rate schedule that is used for SGD or Adam may |
|
be used. The overall learning rate is not comparable to either method and |
|
should be determined by a hyper-parameter sweep. |
|
|
|
MADGRAD requires less weight decay than other methods, often as little as |
|
zero. Momentum values used for SGD or Adam's beta1 should work here also. |
|
|
|
On sparse problems both weight_decay and momentum should be set to 0. |
|
|
|
Arguments: |
|
params (iterable): |
|
Iterable of parameters to optimize or dicts defining parameter groups. |
|
lr (float): |
|
Learning rate (default: 1e-2). |
|
momentum (float): |
|
Momentum value in the range [0,1) (default: 0.9). |
|
weight_decay (float): |
|
            Weight decay, i.e. an L2 penalty (default: 0).
|
eps (float): |
|
            Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-6).
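        decoupled_decay (bool):
            Apply decoupled (AdamW-style) weight decay, scaling the weights directly by
            ``1 - lr * weight_decay`` instead of adding the decay term to the gradient (default: False).

    Example (a minimal usage sketch; ``model``, ``loss_fn``, ``inputs`` and ``targets`` are
    placeholders for your own objects)::

        optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()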
|
""" |
|
|
|
def __init__( |
|
self, |
|
params: _params_t, |
|
lr: float = 1e-2, |
|
momentum: float = 0.9, |
|
weight_decay: float = 0, |
|
eps: float = 1e-6, |
|
decoupled_decay: bool = False, |
|
): |
|
if momentum < 0 or momentum >= 1: |
|
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
|
if lr <= 0: |
|
raise ValueError(f"Learning rate {lr} must be positive") |
|
if weight_decay < 0: |
|
raise ValueError(f"Weight decay {weight_decay} must be non-negative") |
|
if eps < 0: |
|
            raise ValueError(f"Eps {eps} must be non-negative")
|
|
|
defaults = dict( |
|
lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) |
|
super().__init__(params, defaults) |
|
|
|
@property |
|
def supports_memory_efficient_fp16(self) -> bool: |
|
return False |
|
|
|
@property |
|
def supports_flat_params(self) -> bool: |
|
return True |
|
|
|
@torch.no_grad() |
|
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: |
|
"""Performs a single optimization step. |
|
|
|
Arguments: |
|
closure (callable, optional): A closure that reevaluates the model and returns the loss. |
|
""" |
|
loss = None |
|
if closure is not None: |
|
with torch.enable_grad(): |
|
loss = closure() |
|
|
|
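        # Dense MADGRAD update per parameter, with lambda_k = lr * sqrt(k + 1) and c_k = 1 - momentum:
        #   v_{k+1} = v_k + lambda_k * g_k^2
        #   s_{k+1} = s_k + lambda_k * g_k
        #   z_{k+1} = x_0 - s_{k+1} / (v_{k+1}^(1/3) + eps)
        #   x_{k+1} = (1 - c_k) * x_k + c_k * z_{k+1}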
for group in self.param_groups: |
|
eps = group['eps'] |
|
lr = group['lr'] + eps |
|
weight_decay = group['weight_decay'] |
|
momentum = group['momentum'] |
|
ck = 1 - momentum |
|
|
|
for p in group["params"]: |
|
if p.grad is None: |
|
continue |
|
grad = p.grad |
|
if momentum != 0.0 and grad.is_sparse: |
|
raise RuntimeError("momentum != 0 is not compatible with sparse gradients") |
|
|
|
state = self.state[p] |
|
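                # Lazy state initialization on the first step for this parameter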
if len(state) == 0: |
|
state['step'] = 0 |
|
state['grad_sum_sq'] = torch.zeros_like(p) |
|
state['s'] = torch.zeros_like(p) |
|
if momentum != 0: |
|
state['x0'] = torch.clone(p).detach() |
|
|
|
state['step'] += 1 |
|
grad_sum_sq = state['grad_sum_sq'] |
|
s = state['s'] |
|
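                # Dual averaging weight lambda_k = lr * sqrt(k + 1); state['step'] already equals k + 1 here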
lamb = lr * math.sqrt(state['step']) |
|
|
|
|
|
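                # Apply weight decay: decoupled (scale the weights directly) or coupled (add an L2 term to the gradient)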
if weight_decay != 0: |
|
if group['decoupled_decay']: |
|
p.mul_(1.0 - group['lr'] * weight_decay) |
|
else: |
|
if grad.is_sparse: |
|
raise RuntimeError("weight_decay option is not compatible with sparse gradients") |
|
grad.add_(p, alpha=weight_decay) |
|
|
|
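                # Sparse gradients: update only the non-zero entries through masked views of the dense state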
if grad.is_sparse: |
|
grad = grad.coalesce() |
|
grad_val = grad._values() |
|
|
|
p_masked = p.sparse_mask(grad) |
|
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad) |
|
s_masked = s.sparse_mask(grad) |
|
|
|
|
|
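                    # Reconstruct x_0 for the masked entries from the current p, s and the pre-update denominator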
rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps) |
|
x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1) |
|
|
|
|
|
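                    # Accumulate lambda_k * g^2 into both the dense buffer and its masked view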
grad_sq = grad * grad |
|
grad_sum_sq.add_(grad_sq, alpha=lamb) |
|
grad_sum_sq_masked.add_(grad_sq, alpha=lamb) |
|
|
|
rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps) |
|
|
|
s.add_(grad, alpha=lamb) |
|
s_masked._values().add_(grad_val, alpha=lamb) |
|
|
|
|
|
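                    # New iterate for the masked entries: x_0 - s / (v^(1/3) + eps)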
p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) |
|
|
|
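                    # Form (old - new) in p_masked, then subtract it from p to apply the update on the masked entries only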
p_masked._values().add_(p_kp1_masked_vals, alpha=-1) |
|
p.add_(p_masked, alpha=-1) |
|
else: |
|
if momentum == 0: |
|
|
|
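                        # No momentum: reconstruct x_0 = p + s / (v^(1/3) + eps) instead of storing it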
rms = grad_sum_sq.pow(1 / 3).add_(eps) |
|
x0 = p.addcdiv(s, rms, value=1) |
|
else: |
|
x0 = state['x0'] |
|
|
|
|
|
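                    # Accumulate the second-moment sum and form the cube-root denominator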
grad_sum_sq.addcmul_(grad, grad, value=lamb) |
|
rms = grad_sum_sq.pow(1 / 3).add_(eps) |
|
|
|
|
|
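                    # Accumulate the weighted gradient sum s_{k+1} = s_k + lambda_k * g_k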
s.add_(grad, alpha=lamb) |
|
|
|
|
|
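                    # Take the step: set p directly when momentum is 0, otherwise move p toward z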
if momentum == 0: |
|
p.copy_(x0.addcdiv(s, rms, value=-1)) |
|
else: |
|
z = x0.addcdiv(s, rms, value=-1) |
|
|
|
|
|
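                        # p is an exponential moving average of z with weight ck = 1 - momentum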
p.mul_(1 - ck).add_(z, alpha=ck) |
|
|
|
return loss |
|
|