# WSSS_ResNet50/tools/ai/optim_utils.py
# Uploaded by kittendev ("Upload 176 files", commit c20a1af, verified); 793 bytes.
import torch
from .torch_utils import *
class PolyOptimizer(torch.optim.SGD):
    """SGD whose learning rate follows a polynomial ("poly") decay schedule.

    Before each update, while ``global_step < max_step``, every param group's
    learning rate is set to::

        initial_lr * (1 - global_step / max_step) ** momentum

    where ``initial_lr`` is that group's learning rate captured at
    construction time. Once ``global_step`` reaches ``max_step`` the learning
    rates are no longer rewritten and keep their last scheduled value.

    Args:
        params: Iterable of parameters or param-group dicts, as accepted by
            ``torch.optim.SGD``.
        lr: Default learning rate for groups that do not set ``'lr'``.
        weight_decay: Default weight-decay coefficient for groups that do not
            set ``'weight_decay'``.
        max_step: Number of steps over which the learning rate decays.
        momentum: Exponent of the poly schedule. This is NOT the SGD momentum
            factor; the name is kept for backward compatibility.
        nesterov: Forwarded to ``torch.optim.SGD``. SGD requires a positive
            momentum factor for Nesterov, so pass ``sgd_momentum > 0`` with it.
        sgd_momentum: Default SGD momentum factor forwarded to
            ``torch.optim.SGD`` (0 disables momentum).
    """

    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9,
                 nesterov=False, sgd_momentum=0.0):
        # Pass everything by keyword: SGD's third positional parameter is
        # `momentum`, so passing `weight_decay` positionally silently used it
        # as the momentum factor and never applied weight decay.
        super().__init__(
            params,
            lr=lr,
            momentum=sgd_momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        self.global_step = 0
        self.max_step = max_step
        self.momentum = momentum  # exponent of the poly schedule, not SGD momentum
        # Base learning rates the schedule scales from, one per param group.
        self.__initial_lr = [group['lr'] for group in self.param_groups]

    def step(self, closure=None):
        """Set the scheduled learning rates, take one SGD step, advance the counter.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, forwarded to ``torch.optim.SGD.step``.

        Returns:
            The value returned by ``torch.optim.SGD.step`` (the closure's loss,
            or None when no closure is given).
        """
        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
            for group, initial_lr in zip(self.param_groups, self.__initial_lr):
                group['lr'] = initial_lr * lr_mult
        loss = super().step(closure)
        self.global_step += 1
        return loss