Spaces:
Sleeping
Sleeping
import torch | |
from .torch_utils import * | |
class PolyOptimizer(torch.optim.SGD):
    """SGD whose learning rate follows a polynomial ("poly") decay schedule.

    Before every update, each param group's learning rate is reset to
    ``initial_lr * (1 - global_step / max_step) ** momentum``. Once
    ``global_step`` reaches ``max_step`` the rate is no longer recomputed and
    keeps its last scheduled value.

    Args:
        params: Iterable of parameters or param-group dicts, as accepted by
            ``torch.optim.SGD``.
        lr: Default learning rate (param groups may override it).
        weight_decay: Default weight decay, forwarded to SGD's
            ``weight_decay`` (param groups may override it).
        max_step: Number of steps over which the rate decays.
        momentum: Exponent of the poly schedule. Despite its name this is
            NOT SGD momentum; the name is kept for backward compatibility.
        nesterov: Forwarded to SGD's ``nesterov``.
        sgd_momentum: SGD momentum factor (keyword-only). ``None`` keeps the
            legacy value, which equals ``weight_decay`` (see ``__init__``).
    """

    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9,
                 nesterov=False, *, sgd_momentum=None):
        # The former call ``super().__init__(params, lr, weight_decay, ...)``
        # bound ``weight_decay`` to SGD's third positional parameter, which is
        # ``momentum``; the requested decay never reached SGD's
        # ``weight_decay``. Everything is now passed by keyword. Groups that
        # omit 'weight_decay' will now actually receive it.
        #
        # To leave existing momentum dynamics unchanged (and keep
        # ``nesterov=True`` valid, since SGD rejects nesterov with zero
        # momentum), the legacy momentum value is kept unless the caller
        # supplies ``sgd_momentum`` explicitly.
        if sgd_momentum is None:
            sgd_momentum = weight_decay
        super().__init__(
            params,
            lr=lr,
            momentum=sgd_momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        self.global_step = 0
        self.max_step = max_step
        self.momentum = momentum  # poly-decay exponent, not SGD momentum

        # Learning rate each group started with; the schedule scales these.
        self.__initial_lr = [group['lr'] for group in self.param_groups]

    def step(self, closure=None):
        """Apply the poly schedule, run one SGD step, then advance the counter.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, forwarded to ``torch.optim.SGD.step``.

        Returns:
            Whatever ``torch.optim.SGD.step`` returns (the closure's loss when
            a closure is given, otherwise ``None``).
        """
        # Groups added with add_param_group() after construction start their
        # schedule from the lr they were added with (previously: IndexError).
        for group in self.param_groups[len(self.__initial_lr):]:
            self.__initial_lr.append(group['lr'])

        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
            for group, initial_lr in zip(self.param_groups, self.__initial_lr):
                group['lr'] = initial_lr * lr_mult

        loss = super().step(closure)
        self.global_step += 1
        return loss