|
""" Lion Optimizer |
|
Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675 |
|
Original Impl: https://github.com/google/automl/tree/master/lion |
|
""" |

from typing import List, Optional

import torch
from torch.optim.optimizer import Optimizer

|
class Lion(Optimizer):
    r"""Implements the Lion algorithm.

    Lion keeps a single momentum buffer (an exponential moving average of the
    gradient) per parameter and takes sign-based update steps with decoupled
    weight decay.
    """

    def __init__(
            self,
            params,
            lr=1e-4,
            betas=(0.9, 0.99),
            weight_decay=0.0,
            maximize=False,
            foreach=None,
    ):
"""Initialize the hyperparameters. |
|
|
|
Args: |
|
params (iterable): iterable of parameters to optimize or dicts defining |
|
parameter groups |
|
lr (float, optional): learning rate (default: 1e-4) |
|
betas (Tuple[float, float], optional): coefficients used for computing |
|
running averages of gradient and its square (default: (0.9, 0.99)) |
|
weight_decay (float, optional): weight decay coefficient (default: 0) |
|
""" |
|
|
|
if not 0.0 <= lr: |
|
raise ValueError('Invalid learning rate: {}'.format(lr)) |
|
if not 0.0 <= betas[0] < 1.0: |
|
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) |
|
if not 0.0 <= betas[1] < 1.0: |
|
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) |
|
defaults = dict( |
|
lr=lr, |
|
betas=betas, |
|
weight_decay=weight_decay, |
|
foreach=foreach, |
|
maximize=maximize, |
|
) |
|
super().__init__(params, defaults) |
|
|
|
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Lion does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization: Lion only needs a momentum buffer (EMA of the gradient)
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])

            lion(
                params_with_grad,
                grads,
                exp_avgs,
                beta1=beta1,
                beta2=beta2,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                maximize=group['maximize'],
                foreach=group['foreach'],
            )

        return loss


def lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        # kept positional with defaults (kw-only args with defaults are not
        # supported by torchscript-compiled functions)
        maximize: bool = False,
        foreach: Optional[bool] = None,
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
):
    r"""Functional API that performs Lion algorithm computation."""
    if foreach is None:
        # default to the single-tensor implementation when not explicitly requested
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_lion
    else:
        func = _single_tensor_lion

    func(
        params,
        grads,
        exp_avgs,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
    )


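# The functional form above can also be called directly, e.g. from a custom
# update loop (illustrative sketch; `param` is assumed to be a tensor whose
# `.grad` is populated and `exp_avg` a zero-initialized buffer of the same shape):
#
#   lion([param], [param.grad], [exp_avg],
#        beta1=0.9, beta2=0.99, lr=1e-4, weight_decay=0.0)

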
def _single_tensor_lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
        maximize: bool,
):
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            param = torch.view_as_real(param)

        # Perform stepweight (decoupled) weight decay
        param.mul_(1 - lr * weight_decay)

        # Weight update: sign of the interpolation between momentum and gradient
        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
        param.add_(torch.sign(update), alpha=-lr)

        # Decay the momentum running average: exp_avg = beta2 * exp_avg + (1 - beta2) * grad
        exp_avg.lerp_(grad, 1 - beta2)


def _multi_tensor_lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
        maximize: bool,
):
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(tuple(grads))

    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]

    # Perform stepweight (decoupled) weight decay
    torch._foreach_mul_(params, 1 - lr * weight_decay)

    # Weight update: sign of the interpolation between momentum and gradient
    updates = torch._foreach_mul(exp_avgs, beta1)
    torch._foreach_add_(updates, grads, alpha=1 - beta1)

    updates = [u.sign() for u in updates]
    torch._foreach_add_(params, updates, alpha=-lr)

    # Decay the momentum running average: exp_avg = beta2 * exp_avg + (1 - beta2) * grad
    torch._foreach_mul_(exp_avgs, beta2)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)
|
|
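# A minimal, self-contained usage sketch (illustrative only; not part of the
# original module): construct the optimizer and take a few steps on a toy model.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    print('final loss:', loss.item())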