HaMeR / mmpose /core /utils /regularizations.py
geopavlakos's picture
Initial commit
d7a991a
raw
history blame
2.71 kB
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod, abstractproperty
import torch
class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register a hook to a pytorch module using the `register` method like:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function.

        This is the callable handed to the module's hook registration
        method by `register`.
        """

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this property to return a string value
        in {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (pytorch module): the module to register the hook.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): a handle to remove
                the hook by calling handle.remove()

        Raises:
            AssertionError: if `module` is not a `torch.nn.Module`.
            ValueError: if `hook_type` is not one of `forward`,
                `forward_pre` or `backward`.
        """
        assert isinstance(module, torch.nn.Module)

        # Read the property once so the dispatch and the error message are
        # guaranteed to talk about the same value.
        hook_type = self.hook_type
        if hook_type == 'forward':
            h = module.register_forward_hook(self.hook)
        elif hook_type == 'forward_pre':
            h = module.register_forward_pre_hook(self.hook)
        elif hook_type == 'backward':
            # NOTE(review): recent PyTorch releases deprecate
            # `register_backward_hook` in favour of
            # `register_full_backward_hook`; kept as-is to preserve behavior
            # -- confirm against the minimum supported torch version.
            h = module.register_backward_hook(self.hook)
        else:
            # Report the offending type string, not the bound hook method.
            raise ValueError(f'Invalid hook type {hook_type}')

        return h
class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameter will be clipped to a given maximum norm before
    each forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list|tuple): The parameter name (or a
            list/tuple of names) to apply weight norm clip.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        # Accept tuples as well as lists; a bare name is wrapped so that
        # `hook` can always iterate over a list of names.
        if isinstance(module_param_names, (list, tuple)):
            self.module_param_names = list(module_param_names)
        else:
            self.module_param_names = [module_param_names]
        self.max_norm = max_norm

    @property
    def hook_type(self):
        """Return `forward_pre` so the clip runs before the forward pass."""
        return 'forward_pre'

    def hook(self, module, _input):
        """Rescale the configured parameters in place if their norm is too big.

        Args:
            module (torch.nn.Module): the module whose parameters are
                clipped.
            _input: the forward inputs passed by the pre-forward hook
                mechanism; unused.

        Raises:
            AssertionError: if a configured name is not a key of
                `module._parameters`.
        """
        for name in self.module_param_names:
            assert name in module._parameters, f'{name} is not a parameter' \
                f' of the module {type(module)}'
            param = module._parameters[name]

            # A parameter can be registered with value None (for instance
            # the bias of a layer built without bias); there is nothing to
            # clip in that case.
            if param is None:
                continue

            # The in-place rescale must not be tracked by autograd.
            with torch.no_grad():
                m = param.norm().item()
                if m > self.max_norm:
                    # The small epsilon keeps the division well-defined.
                    param.mul_(self.max_norm / (m + 1e-6))