Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,233 Bytes
e52682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import norm_except_dim
from typing import Any, TypeVar
import warnings
from torch.nn.modules import Module
import torch
class WeightNorm:
    """Forward pre-hook that expresses a weight as ``g * v / ||v||``.

    The hook stores the name of the reparameterized attribute and the
    dimension that is kept when the norm is computed.  Instances are meant
    to be created through :meth:`apply`, which splits the parameter
    ``<name>`` into ``<name>_g`` (magnitude) and ``<name>_v`` (direction) and
    registers the hook so the weight is rebuilt before every forward call.
    """

    # Name of the module attribute being reparameterized (e.g. ``'weight'``).
    name: str
    # Dimension kept when computing the norm; ``-1`` means the norm is taken
    # over the entire tensor (this is what ``dim=None`` is mapped to).
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        """Store the attribute name and norm dimension.

        Args:
            name: name of the weight attribute on the module.
            dim: dimension kept by the norm; ``None`` is accepted and
                stored as ``-1`` (norm over the whole tensor).
        """
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    def compute_weight(self, module: Module) -> torch.Tensor:
        """Rebuild the weight from ``<name>_g`` and ``<name>_v``.

        The direction tensor is normalized in float32 and cast back to its
        original dtype before being scaled by the magnitude.

        Args:
            module: module holding the ``<name>_g`` and ``<name>_v`` attributes.

        Returns:
            The tensor ``g * v / sqrt(sum(v ** 2) + 1e-6)``.
        """
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        input_dtype = v.dtype
        # Normalize in float32 regardless of the stored dtype.
        v = v.to(torch.float32)
        reduce_dims = list(range(v.dim()))
        # ``-1`` is the sentinel for "norm over the entire tensor" (it is what
        # ``apply`` passes to ``norm_except_dim`` to build ``_g``), so in that
        # case every dimension is reduced.  Otherwise the kept dimension is
        # excluded from the reduction.
        if self.dim != -1:
            reduce_dims.pop(self.dim)
        variance = v.pow(2).sum(reduce_dims, keepdim=True)
        # The small constant keeps rsqrt finite when ``v`` is all zeros.
        v = v * torch.rsqrt(variance + 1e-6)
        return g * v.to(input_dtype)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        """Reparameterize ``module.<name>`` and register the recompute hook.

        Args:
            module: module owning the parameter ``name``.
            name: name of the parameter to reparameterize.
            dim: dimension kept by the norm, or ``None`` for a norm over the
                entire tensor.

        Returns:
            The :class:`WeightNorm` hook registered on ``module``.

        Raises:
            RuntimeError: if a weight-norm hook for ``name`` is already
                registered on ``module``.
            ValueError: if the parameter is still uninitialized.
        """
        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}")

        # ``norm_except_dim`` below needs an integer, so map ``None`` here too.
        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')

        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        """Fold ``<name>_g`` and ``<name>_v`` back into a single parameter.

        The current weight is computed, the two helper parameters are dropped
        and ``<name>`` is re-registered as a regular ``Parameter``.
        """
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        """Forward pre-hook entry point: refresh ``module.<name>``."""
        setattr(module, self.name, self.compute_weight(module))
# Type variable bound to ``Module`` so that ``weight_norm`` can be annotated as
# returning the same module subclass it was given.
T_module = TypeVar(
    "T_module",
    bound=Module,
)
def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Reparameterize one parameter of ``module`` using weight normalization.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization separates the magnitude of a weight tensor from its
    direction. The parameter named by :attr:`name` (e.g. ``'weight'``) is
    replaced by two parameters: a magnitude (e.g. ``'weight_g'``) and a
    direction (e.g. ``'weight_v'``). A hook rebuilds the weight tensor from
    these two parameters before each :meth:`~Module.forward` call.

    With the default ``dim=0`` the norm is computed independently per output
    channel/plane. Pass ``dim=None`` to compute a single norm over the whole
    weight tensor.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use
        :func:`torch.nn.utils.parametrizations.weight_norm`, which is built on
        the modern parametrization API. The new ``weight_norm`` is compatible
        with a ``state_dict`` generated by the old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now
          exposed as ``parametrizations.weight.original0`` and
          ``parametrizations.weight.original1`` respectively. If this is
          bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To undo the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead,
          it is recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    # The hook object returned by ``apply`` is not needed here; the caller
    # gets the (now reparameterized) module back.
    WeightNorm.apply(module, name, dim)

    return module
|