Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod, abstractproperty | |
import torch | |
class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register a hook to a PyTorch module using the `register` method like:
        hook_register.register(module)

    Subclasses should add/overwrite the following methods:
        - __init__
        - hook
        - hook_type
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function handed to the module's hook registration method."""

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should overwrite this property to return a string value
        in {`forward`, `forward_pre`, `backward`}.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (pytorch module): the module to register the hook.

        Returns:
            handle (torch.utils.hooks.RemovableHandle): a handle to remove
                the hook by calling handle.remove()

        Raises:
            AssertionError: if `module` is not a `torch.nn.Module`.
            ValueError: if `hook_type` is not one of `forward`,
                `forward_pre` or `backward`.
        """
        assert isinstance(module, torch.nn.Module), \
            f'module must be a torch.nn.Module, but got {type(module)}'

        hook_type = self.hook_type
        # `hook_type` is meant to be a property, but a subclass may define
        # it as a plain method; call it so the comparisons below see the
        # string rather than a bound method.
        if callable(hook_type):
            hook_type = hook_type()

        if hook_type == 'forward':
            return module.register_forward_hook(self.hook)
        if hook_type == 'forward_pre':
            return module.register_forward_pre_hook(self.hook)
        if hook_type == 'backward':
            return module.register_backward_hook(self.hook)

        # Report the offending hook type (not the hook callable).
        raise ValueError(f'Invalid hook type {hook_type}')
class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    The module's parameter will be clipped to a given maximum norm before
    each forward pass.

    Args:
        max_norm (float): The maximum norm of the parameter.
        module_param_names (str|list|tuple): The parameter name (or
            sequence of names) to apply weight norm clip.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        # Normalize to a list of names; a single name is wrapped, and a
        # list/tuple is copied so later caller-side mutation has no effect.
        if isinstance(module_param_names, (list, tuple)):
            self.module_param_names = list(module_param_names)
        else:
            self.module_param_names = [module_param_names]
        self.max_norm = max_norm

    @property
    def hook_type(self):
        """Hook type: this hook runs before the module's forward pass."""
        return 'forward_pre'

    def hook(self, module, _input):
        """Rescale the configured parameters in place to at most max_norm.

        Args:
            module (torch.nn.Module): the module whose parameters are
                clipped.
            _input: the forward input supplied by the hook mechanism
                (unused).

        Raises:
            AssertionError: if a configured name is not a parameter of
                `module`.
        """
        for name in self.module_param_names:
            assert name in module._parameters, f'{name} is not a parameter' \
                f' of the module {type(module)}'
            param = module._parameters[name]
            # A parameter can be registered as None (e.g. a disabled bias);
            # there is nothing to clip in that case.
            if param is None:
                continue

            # The rescaling is a constraint on the weights, not part of the
            # computation graph, so it is done without gradient tracking.
            with torch.no_grad():
                m = param.norm().item()
                if m > self.max_norm:
                    # The small epsilon guards the division.
                    param.mul_(self.max_norm / (m + 1e-6))