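"""Distributed test for the mixed-precision training utilities.

Spawns one process per rank, loads the model in bfloat16 (optionally with LoRA),
and checks that prepare/upcast/downcast keep parameters, gradients, and the
optim_dtype master copies (`_mp_param`) in the expected dtypes around each
optimizer step.
"""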
from pathlib import Path

import pytest
import torch

from finetune.args import LoraArgs
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.wrapped_model import load_model
from tests.test_utils import MODEL_PATH, get_dataloader, setup_mp_test_dist

from .test_utils import spawn_for_all_world_sizes

@pytest.mark.parametrize(
    ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)]
)
def test_mixed_precision(world_size, enable_lora):
    spawn_for_all_world_sizes(
        _check_mixed_precision,
        world_sizes=[world_size],
        args=[enable_lora],
        deterministic=True,
    )
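
# Per-rank worker spawned by the test above; runs the actual mixed-precision checks on one process.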
def _check_mixed_precision(
    rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
):
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 100

    folder = Path(MODEL_PATH)

    # mixed precision
    param_dtype = torch.bfloat16
    optim_dtype = torch.float32

    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=param_dtype,
    )

    optimizer = torch.optim.AdamW(model.parameters())

    # initialize mixed precision training for TP
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )
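    # prepare_mixed_precision is expected to attach a master copy of each trainable
    # parameter in optim_dtype precision as `_mp_param`; the checks below verify this.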

    data_loader = get_dataloader(seq_len=seq_len)

    # ensure every parameter that requires a grad has a _mp_param of optim_dtype precision
    for param in model.parameters():
        assert param.dtype == param_dtype
        if param.requires_grad:
            assert param._mp_param.dtype == optim_dtype
            assert (
                param._mp_param.tolist() == param.data.to(optim_dtype).tolist()
            ), "mp param has to match param in optim dtype precision"
        else:
            assert not hasattr(param, "_mp_param")

    # test three train steps
    for _ in range(3):
        optimizer.zero_grad()

        # micro-batching
        for _ in range(2):
            batch = next(data_loader)

            x = torch.from_numpy(batch.x).cuda(non_blocking=True)
            y = torch.from_numpy(batch.y).cuda(non_blocking=True)
            y_mask = (
                torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
                if batch.y_mask is not None
                else None
            )

            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )
            mb_loss = compute_loss_with_mask(output, y, y_mask)

            mb_loss.backward()
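
        # gradients from both micro-batches have accumulated on the bfloat16 params at this point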
        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        # ensure all params are upcasted correctly and mp param equals param
        param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype, param.dtype
                assert (
                    param._mp_param.tolist() == param.data.tolist()
                ), "mp param and param should point to the same data"
                assert param.grad.dtype == optim_dtype
                assert param._temp.dtype == param_dtype
                param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype
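
        # params and grads are now in optim_dtype, so the optimizer step runs in full precision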
        optimizer.step()

        # ensure that after the optimizer step params are still in optim dtype precision
        new_param_sum = 0
        for param in model.parameters():
            if param.requires_grad:
                assert param.dtype == optim_dtype
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == optim_dtype
                new_param_sum += param.data.float().abs().sum()
            else:
                assert param.dtype == param_dtype

        assert new_param_sum != param_sum, "Make sure parameters are updated"
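
        # drop back to bfloat16 params for the next forward pass; the optim_dtype master copy stays on `_mp_param`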
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        # ensure that before the next forward pass params are downcast to param dtype
        for param in model.parameters():
            assert param.dtype == param_dtype
            if param.requires_grad:
                assert param._mp_param.dtype == optim_dtype
                assert param.grad.dtype == param_dtype
                assert (
                    param._mp_param.to(param_dtype).tolist() == param.data.tolist()
                ), "mp param has to match param in param dtype precision"