# Our3D/lib/masked_adam.py
import os

import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA Adam kernels on first import (torch caches the build,
# so subsequent imports reuse it unless the sources change).
parent_dir = os.path.dirname(os.path.abspath(__file__))
sources = ['cuda/adam_upd.cpp', 'cuda/adam_upd_kernel.cu']
adam_upd_cuda = load(
    name='adam_upd_cuda',
    sources=[os.path.join(parent_dir, path) for path in sources],
    verbose=True)
''' Extended Adam optimizer
1. supports a per-voxel learning rate
2. masked update (skips parameters whose gradient is zero), which speeds up training
'''
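
# A rough pure-PyTorch sketch of the update the masked CUDA kernel used below
# is expected to perform: one bias-corrected Adam step applied only where the
# gradient is nonzero. This function is an illustration only (its name is not
# part of the original API, and nothing here calls it); the authoritative
# semantics live in cuda/adam_upd_kernel.cu.
def _masked_adam_reference(param, grad, exp_avg, exp_avg_sq, step, beta1, beta2, lr, eps):
    mask = grad != 0
    exp_avg[mask] = beta1 * exp_avg[mask] + (1 - beta1) * grad[mask]
    exp_avg_sq[mask] = beta2 * exp_avg_sq[mask] + (1 - beta2) * grad[mask] ** 2
    # Fold both bias corrections into the step size.
    step_size = lr * (1 - beta2 ** step) ** 0.5 / (1 - beta1 ** step)
    param[mask] -= step_size * exp_avg[mask] / (exp_avg_sq[mask].sqrt() + eps)
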
class MaskedAdam(torch.optim.Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps)
        # Optional per-voxel learning-rate scale tensors (None until set).
        self.per_lr = None
        self.f_per_lr = None
        super().__init__(params, defaults)
    def __setstate__(self, state):
        super().__setstate__(state)
    def set_pervoxel_lr(self, count):
        assert self.param_groups[0]['params'][0].shape == count.shape
        # Per-voxel LR scale: each voxel's count normalized by the maximum count.
        self.per_lr = count.float() / count.max()
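
    # `f_per_lr` is consumed in step() but no setter for it is defined in this
    # file. A hypothetical setter mirroring set_pervoxel_lr (the method name
    # and its normalization are assumptions, not the original API):
    def set_f_pervoxel_lr(self, count):
        self.f_per_lr = count.float() / count.max()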
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            # Not in `defaults`, so every param group passed to the optimizer
            # must carry this flag explicitly.
            skip_zero_grad = group['skip_zero_grad']
            for param in group['params']:
                if param.grad is not None:
                    state = self.state[param]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    state['step'] += 1
                    # Dispatch to the CUDA kernels: the per-voxel-LR variant when a
                    # scale tensor of matching shape is set, otherwise the masked
                    # or plain Adam update.
                    if self.per_lr is not None and param.shape == self.per_lr.shape:
                        adam_upd_cuda.adam_upd_with_perlr(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'], self.per_lr,
                            state['step'], beta1, beta2, lr, eps)
                    elif self.f_per_lr is not None and param.shape == self.f_per_lr.shape:
                        adam_upd_cuda.adam_upd_with_perlr(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'], self.f_per_lr,
                            state['step'], beta1, beta2, lr, eps)
                    elif skip_zero_grad:
                        adam_upd_cuda.masked_adam_upd(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'],
                            state['step'], beta1, beta2, lr, eps)
                    else:
                        adam_upd_cuda.adam_upd(
                            param, param.grad, state['exp_avg'], state['exp_avg_sq'],
                            state['step'], beta1, beta2, lr, eps)
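

if __name__ == '__main__':
    # Minimal usage sketch (an illustration, not part of the original file).
    # It assumes a CUDA device and a successfully built `adam_upd_cuda`
    # extension; the tensor shapes and learning rates are arbitrary.
    grid = torch.nn.Parameter(torch.zeros(8, 8, 8, device='cuda'))
    dense = torch.nn.Parameter(torch.randn(16, 16, device='cuda'))
    optimizer = MaskedAdam([
        # step() reads `skip_zero_grad` from every group, so each group
        # must define it explicitly.
        {'params': [grid], 'lr': 1e-1, 'skip_zero_grad': True},
        {'params': [dense], 'lr': 1e-3, 'skip_zero_grad': False},
    ])
    # Per-voxel LR scaling from a count tensor shaped like the grid.
    optimizer.set_pervoxel_lr(torch.randint(1, 100, grid.shape, device='cuda'))
    loss = grid.sum() + dense.sum()
    loss.backward()
    optimizer.step()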