Spaces:
Sleeping
Sleeping
File size: 793 Bytes
c20a1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from .torch_utils import *
class PolyOptimizer(torch.optim.SGD):
    """SGD whose learning rate follows a polynomial ("poly") decay schedule.

    Before each of the first ``max_step`` updates, every parameter group's
    learning rate is set to::

        initial_lr * (1 - global_step / max_step) ** momentum

    where ``initial_lr`` is the group's learning rate at construction time.
    Once ``global_step`` reaches ``max_step`` the learning rate is no longer
    modified and stays at its last scheduled value.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.SGD``.
        lr: Base learning rate (individual groups may override it).
        weight_decay: L2 penalty, forwarded to ``torch.optim.SGD``.
        max_step: Number of updates over which the learning rate decays.
        momentum: Exponent of the polynomial decay. Despite its name this is
            NOT the SGD momentum factor; the name is kept for backward
            compatibility (see ``sgd_momentum``).
        nesterov: Forwarded to ``torch.optim.SGD``. SGD only accepts
            ``nesterov=True`` together with ``sgd_momentum > 0``.
        sgd_momentum: SGD momentum factor, forwarded to ``torch.optim.SGD``.
    """

    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9,
                 nesterov=False, sgd_momentum=0.0):
        # Pass everything by keyword: SGD's third positional parameter is
        # ``momentum``, so passing ``weight_decay`` positionally silently
        # used it as the momentum factor and disabled weight decay.
        super().__init__(params, lr=lr, momentum=sgd_momentum,
                         weight_decay=weight_decay, nesterov=nesterov)
        self.global_step = 0
        self.max_step = max_step
        # Poly-decay exponent (named ``momentum`` for backward compatibility).
        self.momentum = momentum
        # Each group's starting lr; the schedule always scales from these.
        self.__initial_lr = [group['lr'] for group in self.param_groups]

    def step(self, closure=None):
        """Apply the poly-decayed lr, take one SGD step, advance the counter.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss, as for ``torch.optim.SGD.step``.

        Returns:
            Whatever ``torch.optim.SGD.step`` returns (the closure's loss,
            or ``None`` when no closure is given).
        """
        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
            for i, group in enumerate(self.param_groups):
                group['lr'] = self.__initial_lr[i] * lr_mult
        loss = super().step(closure)
        self.global_step += 1
        return loss
|