Spaces:
Build error
Build error
File size: 2,706 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod, abstractproperty
import torch
class PytorchModuleHook(metaclass=ABCMeta):
    """Base class for PyTorch module hook registers.

    An instance of a subclass of PytorchModuleHook can be used to
    register a hook to a PyTorch module via the ``register`` method::

        hook_register.register(module)

    Subclasses should add/overwrite the following members:

    - ``__init__``
    - ``hook``
    - ``hook_type``
    """

    @abstractmethod
    def hook(self, *args, **kwargs):
        """Hook function."""

    @property
    @abstractmethod
    def hook_type(self) -> str:
        """Hook type.

        Subclasses should override this property to return one of the
        strings ``'forward'``, ``'forward_pre'`` or ``'backward'``.
        """

    def register(self, module):
        """Register the hook function to the module.

        Args:
            module (torch.nn.Module): The module to register the hook on.

        Returns:
            torch.utils.hooks.RemovableHandle: A handle that removes the
            hook when ``handle.remove()`` is called.

        Raises:
            ValueError: If ``hook_type`` is not one of ``'forward'``,
                ``'forward_pre'`` or ``'backward'``.
        """
        assert isinstance(module, torch.nn.Module)
        # Read the (possibly computed) property once.
        hook_type = self.hook_type
        if hook_type == 'forward':
            handle = module.register_forward_hook(self.hook)
        elif hook_type == 'forward_pre':
            handle = module.register_forward_pre_hook(self.hook)
        elif hook_type == 'backward':
            # NOTE(review): PyTorch deprecates `register_backward_hook` in
            # favor of `register_full_backward_hook`; kept as-is so existing
            # hook semantics are unchanged.
            handle = module.register_backward_hook(self.hook)
        else:
            # Report the offending hook type itself, not the bound `hook`
            # method, so the error message shows the actual bad value.
            raise ValueError(f'Invalid hook type {hook_type}')
        return handle
class WeightNormClipHook(PytorchModuleHook):
    """Apply weight norm clip regularization.

    Before each forward pass, every configured parameter of the module whose
    norm exceeds ``max_norm`` is rescaled in place so that its norm is
    (marginally below) ``max_norm``.

    Args:
        max_norm (float): The maximum norm of the parameter. Default: 1.0.
        module_param_names (str | list[str] | tuple[str, ...]): The
            parameter name (or iterable of names) to apply weight norm
            clip to. Default: ``'weight'``.
    """

    def __init__(self, max_norm=1.0, module_param_names='weight'):
        # A single name becomes a one-element list; any other iterable of
        # names (list, tuple, ...) is copied into a list. Only special-casing
        # `list` would wrap a tuple whole and yield one bogus "name".
        if isinstance(module_param_names, str):
            self.module_param_names = [module_param_names]
        else:
            self.module_param_names = list(module_param_names)
        self.max_norm = max_norm

    @property
    def hook_type(self):
        """str: Always ``'forward_pre'``; clipping runs before forward."""
        return 'forward_pre'

    def hook(self, module, _input):
        """Clip the configured parameters of ``module`` in place.

        Args:
            module (torch.nn.Module): The module about to run forward.
            _input: The forward inputs (unused; left unchanged since this
                hook returns None).
        """
        for name in self.module_param_names:
            assert name in module._parameters, f'{name} is not a parameter' \
                f' of the module {type(module)}'
            param = module._parameters[name]
            # The in-place rescale must not be recorded by autograd.
            with torch.no_grad():
                # `.item()` copies the scalar to the host (a sync on GPU).
                norm = param.norm().item()
                if norm > self.max_norm:
                    # The epsilon guards the division and leaves the new
                    # norm marginally below max_norm.
                    param.mul_(self.max_norm / (norm + 1e-6))
|