import pytest
import torch

from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init


@pytest.mark.parametrize(
    ["torch_module", "layer_args"],
    [
        (torch.nn.Linear, {"in_features": 10, "out_features": 20}),
        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
    ],
)
def test_skip_torch_weight_init(torch_module, layer_args):
    """Test the interactions between `skip_torch_weight_init()` and various torch modules."""
    seed = 123

    # Initialize a torch layer *before* applying `skip_torch_weight_init()`.
    reset_params_fn_before = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_before = torch_module(**layer_args)

    # Initialize a torch layer while `skip_torch_weight_init()` is applied.
    with skip_torch_weight_init():
        reset_params_fn_during = torch_module.reset_parameters
        torch.manual_seed(seed)
        layer_during = torch_module(**layer_args)

    # Initialize a torch layer *after* applying `skip_torch_weight_init()`.
    reset_params_fn_after = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_after = torch_module(**layer_args)

    # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
    assert reset_params_fn_during is _no_op
    assert not torch.allclose(layer_before.weight, layer_during.weight)
    if hasattr(layer_before, "bias"):
        assert not torch.allclose(layer_before.bias, layer_during.bias)

    # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
    assert reset_params_fn_before is reset_params_fn_after
    assert torch.allclose(layer_before.weight, layer_after.weight)
    if hasattr(layer_before, "bias"):
        assert torch.allclose(layer_before.bias, layer_after.bias)

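
# A minimal conceptual sketch (NOT the InvokeAI implementation) of a context manager with the
# behavior exercised above: swap `reset_parameters` for a no-op on each targeted class, then
# restore the original attribute state on exit. The class name `_SkipWeightInitSketch` and the
# exact module list are assumptions for illustration only.
class _SkipWeightInitSketch:
    _torch_modules = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Embedding]

    def __enter__(self):
        # Record whether each class defines `reset_parameters` itself or only inherits it
        # (e.g. Conv1d inherits it from _ConvNd), then install the no-op.
        self._saved = [(cls, cls.__dict__.get("reset_parameters")) for cls in self._torch_modules]
        for cls in self._torch_modules:
            cls.reset_parameters = _no_op

    def __exit__(self, *exc_info):
        for cls, original in self._saved:
            if original is None:
                # The class only inherited `reset_parameters`; deleting the override restores
                # inheritance, so later base-class monkey patches still reach it.
                del cls.reset_parameters
            else:
                cls.reset_parameters = original
        return False
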

def test_skip_torch_weight_init_restores_base_class_behavior():
    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
    its child classes (like `Conv1d`).
    """
    with skip_torch_weight_init():
        # There is no need to do anything while the context manager is active; we're just testing that the original
        # behavior is restored correctly.
        pass

    # Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
    # expects the patch to affect all of its subclasses (e.g. `torch.nn.Conv1d`, `torch.nn.Conv2d`, etc.).
    called_monkey_patched_fn = False

    def monkey_patched_fn(*args, **kwargs):
        nonlocal called_monkey_patched_fn
        called_monkey_patched_fn = True

    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
    try:
        torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
        _ = torch.nn.Conv1d(10, 20, 3)
    finally:
        # Restore the original function even if construction fails, so the patch can't leak into other tests.
        torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn

    assert called_monkey_patched_fn
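

# Illustrative sketch of the pitfall the test above guards against: if `reset_parameters` is
# copied from a base class into a subclass's own namespace, a later monkey patch of the base
# class no longer reaches the subclass, because attribute lookup finds the stale copy first.
# `Base` and `Child` are hypothetical stand-ins, not torch classes.
def _demo_subclass_copy_shadows_base_class_patch():
    class Base:
        def reset_parameters(self):
            return "base"

    class Child(Base):
        pass

    # Copy the inherited method onto the subclass (what the previously buggy restore logic did).
    Child.reset_parameters = Base.reset_parameters

    # Monkey patch the base class afterwards, as another library might.
    Base.reset_parameters = lambda self: "patched"

    # The subclass still resolves to the stale copy instead of the patched base-class method.
    assert Child().reset_parameters() == "base"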