# Copyright (c) OpenMMLab. All rights reserved. import numpy as np import torch from mmpose.core import WeightNormClipHook def test_weight_norm_clip(): torch.manual_seed(0) module = torch.nn.Linear(2, 2, bias=False) module.weight.data.fill_(2) WeightNormClipHook(max_norm=1.0).register(module) x = torch.rand(1, 2).requires_grad_() _ = module(x) weight_norm = module.weight.norm().item() np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6)