from pathlib import Path

import pytest
import torch

from finetune.args import LoraArgs
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.wrapped_model import load_model
from tests.test_utils import MODEL_PATH, get_dataloader, setup_mp_test_dist

from .test_utils import spawn_for_all_world_sizes


@pytest.mark.parametrize(
    ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)]
)
def test_mixed_precision(world_size, enable_lora):
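    """Spawn the distributed workers and run the mixed-precision dtype checks
    for the given world size, with and without LoRA."""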
    spawn_for_all_world_sizes(
        _check_mixed_precision,
        world_sizes=[world_size],
        args=[enable_lora],
        deterministic=True,
    )


def _check_mixed_precision(
    rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
):
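    """Run a few training steps and assert that params, grads and the fp32
    master copies (``_mp_param``) carry the expected dtypes around each
    prepare / upcast / optimizer-step / downcast cycle."""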
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 100

    folder = Path(MODEL_PATH)
    # mixed-precision dtypes: bf16 params, fp32 master/optimizer copies
    param_dtype = torch.bfloat16
    optim_dtype = torch.float32

    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=param_dtype,
    )

    optimizer = torch.optim.AdamW(model.parameters())

    # initialize mixed precision training for TP
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )

    data_loader = get_dataloader(seq_len=seq_len)

    # ensure every parameter that requires a grad has a _mp_param of optim_dtype precision
    for param in model.parameters():
        assert param.dtype == param_dtype
        if param.requires_grad:
            assert param._mp_param.dtype == optim_dtype
            assert (
                param._mp_param.tolist() == param.data.to(optim_dtype).tolist()
            ), "mp param has to match param in optim dtype precision"
        else:
            assert not hasattr(param, "_mp_param")

    # test three train steps
    for _ in range(3):
        optimizer.zero_grad()

        # micro-batching
        for _ in range(2):
            batch = next(data_loader)

            x = torch.from_numpy(batch.x).cuda(non_blocking=True)
            y = torch.from_numpy(batch.y).cuda(non_blocking=True)
            y_mask = (
                torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
                if batch.y_mask is not None
                else None
            )

            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )

            mb_loss = compute_loss_with_mask(output, y, y_mask)
            mb_loss.backward()

        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        # ensure all params are upcasted correctly and mp param equals param
        param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype, param.dtype
                assert (
                    param._mp_param.tolist() == param.data.tolist()
                ), "mp param and param should point to the same data"
                assert param.grad.dtype == optim_dtype
                assert param._temp.dtype == param_dtype
                param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype

        optimizer.step()

        # ensure that after optimizer step params are still in optim dtype precision
        new_param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == optim_dtype
                new_param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype

        assert new_param_sum != param_sum, "parameters should have been updated by the optimizer step"

        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        # ensure that before new forward pass params are downcasted to param dtype
        for param in model.parameters():
            assert param.dtype == param_dtype
            if param.requires_grad:
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == param_dtype
                assert (
                    param._mp_param.to(param_dtype).tolist() == param.data.tolist()
                ), "downcasted mp param has to match param in param dtype precision"