# core/optimizers/adafactor.py
import math
import torch
from torch.optim.optimizer import Optimizer
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from torch import Tensor
Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
Betas2 = Tuple[float, float]
State = Dict[str, Any]
OptFloat = Optional[float]
Nus2 = Tuple[float, float]
Eps2 = Tuple[float, float]
ParamGroup = Dict[str, Any]
class Adafactor_dev(Optimizer):
"""Implements Adafactor algorithm.
It has been proposed in: `Adafactor: Adaptive Learning Rates with
Sublinear Memory Cost`__.
Arguments:
params: iterable of parameters to optimize or dicts defining
parameter groups
lr: external learning rate (default: None)
eps2: regularization constans for square gradient
and parameter scale respectively (default: (1e-30, 1e-3))
clip_threshold: threshold of root mean square of
final gradient update (default: 1.0)
decay_rate: coefficient used to compute running averages of square
gradient (default: -0.8)
beta1: coefficient used for computing running averages of gradient
(default: None)
weight_decay: weight decay (L2 penalty) (default: 0)
scale_parameter: if true, learning rate is scaled by root mean square
of parameter (default: True)
relative_step: if true, time-dependent learning rate is computed
instead of external learning rate (default: True)
warmup_init: time-dependent learning rate computation depends on
whether warm-up initialization is being used (default: False)
Example:
>>> import torch_optimizer as optim
>>> optimizer = optim.Adafactor(model.parameters())
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ https://arxiv.org/abs/1804.04235
Note:
Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py # noqa
"""
def __init__(
self,
params: Params,
lr: OptFloat = None,
eps2: Eps2 = (1e-30, 1e-3),
clip_threshold: float = 1.0,
decay_rate: float = -0.8,
beta1: OptFloat = None,
weight_decay: float = 0.0,
scale_parameter: bool = True,
relative_step: bool = True,
warmup_init: bool = False,
        clip_beta2: Union[bool, float] = False,
):
if lr is not None and lr <= 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
if weight_decay < 0.0:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay)
)
defaults = dict(
lr=lr,
eps2=eps2,
clip_threshold=clip_threshold,
decay_rate=decay_rate,
beta1=beta1,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step=relative_step,
warmup_init=warmup_init,
clip_beta2=clip_beta2,
)
super(Adafactor_dev, self).__init__(params, defaults)
def _get_lr(self, param_group: ParamGroup, param_state: State) -> float:
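        """Compute the step size for a parameter group.

        With ``relative_step`` the step size follows the paper's schedule
        ``min(min_step, 1 / sqrt(step))``, where ``min_step`` depends on
        ``warmup_init``; otherwise the external ``lr`` is used. With
        ``scale_parameter`` the result is additionally scaled by
        ``max(eps2[1], RMS(param))``.
        """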
rel_step_sz = param_group['lr']
if param_group['relative_step']:
min_step = (
1e-6 * param_state['step']
if param_group['warmup_init']
else 1e-2
)
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
param_scale = 1.0
if param_group['scale_parameter']:
param_scale = max(param_group['eps2'][1], param_state['RMS'])
return param_scale * rel_step_sz
def _get_options(
self, param_group: ParamGroup, param_shape: Tuple[int, ...]
) -> Tuple[bool, bool]:
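        """Decide whether to use the factored second moment (only for
        parameters with 2 or more dimensions) and whether to track a first
        moment (only when ``beta1`` is given).
        """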
factored = len(param_shape) >= 2
use_first_moment = param_group['beta1'] is not None
return factored, use_first_moment
def _rms(self, tensor: torch.Tensor) -> float:
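        """Root mean square of a tensor: its 2-norm divided by sqrt(numel)."""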
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(
self,
exp_avg_sq_row: torch.Tensor,
exp_avg_sq_col: torch.Tensor,
output: torch.Tensor,
) -> None:
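        """Reconstruct the factored second-moment estimate.

        Writes the elementwise reciprocal square root of the rank-1
        approximation built from the row and column statistics into
        ``output``, i.e. the ``1 / sqrt(v_hat)`` factor used to scale the
        gradient.
        """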
r_factor = (
(exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
.rsqrt_()
.unsqueeze(-1)
)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)
def step(self, closure: OptLossClosure = None) -> OptFloat:
r"""Performs a single optimization step.
Arguments:
closure: A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'Adafactor does not support sparse gradients.'
)
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(
group, grad_shape
)
# State Initialization
if len(state) == 0:
state['step'] = 0
if use_first_moment:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)
if factored:
state['exp_avg_sq_row'] = torch.zeros(
grad_shape[:-1]
).type_as(grad)
state['exp_avg_sq_col'] = torch.zeros(
grad_shape[:-2] + grad_shape[-1:]
).type_as(grad)
else:
state['exp_avg_sq'] = torch.zeros_like(
grad, memory_format=torch.preserve_format
)
state['RMS'] = 0
state['step'] += 1
state['RMS'] = self._rms(p.data)
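                # Step size and second-moment decay schedule from the paper:
                # beta2_t = 1 - step ** decay_rate (optionally capped by clip_beta2)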
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                if group['clip_beta2'] is not False:
beta2t = min(beta2t, group['clip_beta2'])
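                # Squared gradient plus the first eps term for numerical stability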
update = (grad ** 2) + group['eps2'][0]
if factored:
exp_avg_sq_row = state['exp_avg_sq_row']
exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(
update.mean(dim=-1), alpha=1.0 - beta2t
)
exp_avg_sq_col.mul_(beta2t).add_(
update.mean(dim=-2), alpha=1.0 - beta2t
)
# Approximation of exponential moving average of square
# of gradient
self._approx_sq_grad(
exp_avg_sq_row, exp_avg_sq_col, update
)
update.mul_(grad)
else:
exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
torch.rsqrt(exp_avg_sq, out=update).mul_(grad)
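                # Clip the update so that its RMS does not exceed clip_threshold,
                # then scale it by the step size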
update.div_(
max(1.0, self._rms(update) / group['clip_threshold'])
)
update.mul_(lr)
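                # Optionally smooth the update with a first-moment (momentum) average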
if use_first_moment:
exp_avg = state['exp_avg']
exp_avg.mul_(group['beta1']).add_(
update, alpha=1 - group['beta1']
)
update = exp_avg
if group['weight_decay'] != 0:
p.data.add_(p.data, alpha=-group['weight_decay'] * lr)
p.data.add_(-update)
return loss
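

# The snippet below is a minimal usage sketch, not part of the original module:
# it fits an illustrative torch.nn.Linear model with Adafactor_dev using the
# default relative-step schedule (lr=None). The model, data, and training loop
# are assumptions chosen for demonstration only.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = torch.nn.Linear(16, 4)
    # lr=None with relative_step=True uses the time-dependent step size
    optimizer = Adafactor_dev(model.parameters())
    inputs = torch.randn(32, 16)
    targets = torch.randn(32, 4)
    for _ in range(5):
        optimizer.zero_grad()
        loss = ((model(inputs) - targets) ** 2).mean()
        loss.backward()
        optimizer.step()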