import tempfile
from pathlib import Path
from typing import Dict

import pytest
import torch

from finetune.args import LoraArgs
from finetune.checkpointing import Checkpointer
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.utils import TrainState
from finetune.wrapped_model import load_model
from model.transformer import LoRALinear
from tests.test_utils import (
    MODEL_PATH,
    get_dataloader,
    is_float_equal,
    setup_mp_test_dist,
    spawn_for_all_world_sizes,
)
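
# Pin cuDNN to deterministic kernels: the tests below compare against
# hard-coded expected sums, which only reproduce with deterministic math.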
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


@pytest.mark.parametrize(
    ("world_size", "enable_lora", "dtype"),
    [
        (1, False, torch.float32),
        (1, True, torch.float32),
        (2, False, torch.float32),
        (2, True, torch.float32),
        (1, False, torch.bfloat16),
        (1, True, torch.bfloat16),
        (2, False, torch.bfloat16),
        (2, True, torch.bfloat16),
    ],
)
def test_weights_loading(world_size, enable_lora, dtype):
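    """Spawn one process per rank and verify that checkpoint weights load
    identically with and without LoRA, in fp32 and bf16."""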
    spawn_for_all_world_sizes(
        _check_weights_loading,
        world_sizes=[world_size],
        args=[enable_lora, dtype],
        deterministic=True,
    )


def _check_weights_loading(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    enable_lora: bool,
    dtype: torch.dtype,
):
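    """Worker: load the model under FSDP and check parameter sums and LoRA init."""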
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=dtype,
    )
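
    # The hook below makes state_dict() emit merged dense weights for every
    # LoRALinear. A sketch of what merge_weight() is assumed to compute
    # (the exact scaling factor lives in model.transformer):
    #
    #     W_merged = W_frozen + (lora_B @ lora_A) * scaling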
    def register_merge_lora_hook(m: torch.nn.Module):
        def merge_lora(
            m: torch.nn.Module,
            destination: Dict[str, torch.Tensor],
            prefix: str,
            *args,
        ):
            weight = m.merge_weight()
            destination[prefix + "weight"] = weight

        if isinstance(m, LoRALinear):
            m._merge_lora_handle = m._register_state_dict_hook(merge_lora)

    model.apply(register_merge_lora_hook)

    # LoRA and frozen-copy tensors are excluded: the merged "weight" entries
    # already account for them.
    if world_size > 1:
        with model.summon_full_params(model, writeback=True):
            states = {
                k: v
                for k, v in model.state_dict().items()
                if "lora" not in k and "frozen" not in k
            }
    else:
        states = {
            k: v
            for k, v in model.state_dict().items()
            if "lora" not in k and "frozen" not in k
        }

    EXP_PARAM_SUM = 308.9932 if dtype == torch.float32 else 308.0
    params = sum([v.sum() for v in states.values()]).item()

    assert is_float_equal(params, EXP_PARAM_SUM), params
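
    # LoRA convention: A is randomly initialized, B starts at zero, so the
    # adapter delta (B @ A) is zero and the wrapped model initially matches
    # the base model exactly.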
    if enable_lora:
        lora_B_params = [
            v.float().abs().sum() for k, v in model.named_parameters() if "lora_B" in k
        ]

        assert len(lora_B_params) > 0
        assert sum(lora_B_params) == 0, "lora_B should always be zero-initialized"

        lora_A_params = [
            v.float().abs().sum() for k, v in model.named_parameters() if "lora_A" in k
        ]

        assert len(lora_A_params) > 0
        assert sum(lora_A_params) > 0, "lora_A should be initialized to non-zero values"


@pytest.mark.parametrize(
    ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)]
)
def test_fsdp_logits_and_loss(world_size, enable_lora):
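    """Check that forward logits and the masked loss match hard-coded references
    for every world size, i.e. FSDP sharding does not change the math."""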
    spawn_for_all_world_sizes(
        _check_fsdp_logits_and_loss,
        world_sizes=[world_size],
        args=[enable_lora],
        deterministic=True,
    )


def _check_fsdp_logits_and_loss(
    rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
):
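    """Worker: run one forward pass and compare logits and loss to references."""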
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 100

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=torch.bfloat16,
    )

    # Fixed rank/world_size so every rank sees the identical batch regardless
    # of the actual world size under test.
    data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)

    batch = next(data_loader)

    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)

    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )

    assert output.shape == (seq_len, model.args.vocab_size)
    output_sum = output.abs().float().sum().item()

    EXP_OUTPUT_WORLD_1 = 162617.625

    assert is_float_equal(output_sum, EXP_OUTPUT_WORLD_1, precision=1e1), output_sum

    mb_loss = compute_loss_with_mask(output, y, y_mask)

    EXPECTED_LOSS = 10.408413887023926

    assert is_float_equal(mb_loss.item(), EXPECTED_LOSS), mb_loss.item()


@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_fsdp_grads_non_lora(world_size, dtype):
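    """Check that full fine-tuning produces gradients for every parameter and
    that per-rank gradient sums match hard-coded references."""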
    spawn_for_all_world_sizes(
        _check_fsdp_grads_non_lora,
        world_sizes=[world_size],
        deterministic=True,
        args=[dtype],
    )


def _check_fsdp_grads_non_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
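    """Worker: one forward/backward pass without LoRA; all parameters get grads."""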
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 2048

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=False),
        checkpoint=True,
        param_dtype=dtype,
    )

    data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)

    batch = next(data_loader)

    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)

    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )

    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    # FSDP shards the flat parameters, so each rank holds 1/world_size of them.
    num_grad_params = sum([p.grad.numel() for p in model.parameters()])

    assert (4301120 // world_size) == num_grad_params, num_grad_params

    torch.distributed.barrier()

    sharded_flat_grads = sum(
        [p.grad.float().abs().sum().item() for p in model.parameters()]
    )

    print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")

    # Each rank sees the same batch, so the world-size-1 gradient sum must
    # equal the sum of the two per-rank shards.
    EXP_GRAD_WORLD_2_RANK_0 = 95.45827150344849
    EXP_GRAD_WORLD_2_RANK_1 = 86.09188461303711
    EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1

    if world_size == 1:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1
        ), sharded_flat_grads
    elif world_size == 2 and rank == 0:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1
        ), sharded_flat_grads
    elif world_size == 2 and rank == 1:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1
        ), sharded_flat_grads


@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_fsdp_grads_lora(world_size, dtype):
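    """Check that with LoRA enabled only the adapter parameters get gradients
    and that per-rank gradient sums match hard-coded references."""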
    spawn_for_all_world_sizes(
        _check_fsdp_grads_lora,
        world_sizes=[world_size],
        deterministic=True,
        args=[dtype],
    )


def _check_fsdp_grads_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
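    """Worker: one forward/backward pass with LoRA; only adapters get grads."""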
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 2048

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=True),
        checkpoint=True,
        param_dtype=dtype,
    )

    data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)

    batch = next(data_loader)

    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)

    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )

    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    # Only the LoRA adapters receive gradients; the frozen base weights do not.
    num_grad_params = sum(
        [p.grad.numel() for p in model.parameters() if p.grad is not None]
    )

    assert (40960 // world_size) == num_grad_params, num_grad_params

    torch.distributed.barrier()

    sharded_flat_grads = sum(
        [
            p.grad.float().abs().sum().item()
            for p in model.parameters()
            if p.grad is not None
        ]
    )

    print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")

    EXP_GRAD_WORLD_2_RANK_0 = 3.0742580661177635
    EXP_GRAD_WORLD_2_RANK_1 = 3.074301045779139
    EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1

    if world_size == 1:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1
        ), sharded_flat_grads
    elif world_size == 2 and rank == 0:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1
        ), sharded_flat_grads
    elif world_size == 2 and rank == 1:
        assert is_float_equal(
            sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1
        ), sharded_flat_grads


@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_grad_update_lora(world_size, dtype):
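    """Check that an optimizer step moves the LoRA (and norm) weights while
    leaving the frozen base weights untouched."""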
    spawn_for_all_world_sizes(
        _check_grad_update_lora,
        world_sizes=[world_size],
        args=[dtype],
        deterministic=True,
    )


def _check_grad_update_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
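    """Worker: one backward pass plus one AdamW step; compare weight sums
    before and after."""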
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 1000

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=True),
        checkpoint=True,
        param_dtype=dtype,
    )
    optimizer = torch.optim.AdamW(model.parameters())

    data_loader = get_dataloader(seq_len=seq_len)

    batch = next(data_loader)

    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = (
        torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
        if batch.y_mask is not None
        else None
    )

    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )

    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    # Only LoRA and norm parameters are trainable; everything else is frozen.
    lora_weight_sum = 0
    non_lora_weight_sum = 0
    for name, param in model.named_parameters():
        if "lora" in name or "norm" in name:
            assert param.grad is not None, name
            lora_weight_sum += param.data.float().abs().sum()
        else:
            assert param.grad is None, name
            non_lora_weight_sum += param.data.float().abs().sum()

    optimizer.step()

    new_lora_weight_sum = 0
    new_non_lora_weight_sum = 0
    for name, param in model.named_parameters():
        if "lora" in name or "norm" in name:
            assert param.grad is not None, name
            new_lora_weight_sum += param.data.float().abs().sum()
        else:
            assert param.grad is None, name
            new_non_lora_weight_sum += param.data.float().abs().sum()

    # The step must change the trainable weights but not the frozen ones.
    assert not is_float_equal(
        new_lora_weight_sum, lora_weight_sum, 1e-4
    ), f"New: {new_lora_weight_sum}, Old: {lora_weight_sum}"
    assert is_float_equal(
        new_non_lora_weight_sum, non_lora_weight_sum, 1e-4
    ), f"New: {new_non_lora_weight_sum}, Old: {non_lora_weight_sum}"


@pytest.mark.parametrize(
    ("enable_lora", "param_dtype"),
    [
        (False, torch.float32),
        (True, torch.float32),
        (False, torch.bfloat16),
        (True, torch.bfloat16),
    ],
)
def test_grads_fsdp_mp(enable_lora, param_dtype):
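    """Train for a few steps at world sizes 1 and 2 with mixed precision and
    check that the two saved checkpoints agree."""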
    with tempfile.TemporaryDirectory() as tmpdirname:
        for world_size in [1, 2]:
            spawn_for_all_world_sizes(
                _check_grads_fsdp_mp,
                world_sizes=[world_size],
                deterministic=True,
                args=[tmpdirname, enable_lora, param_dtype],
            )

        w1_sd = torch.load(Path(tmpdirname) / "params_w1.pt", map_location="cpu")
        w2_sd = torch.load(Path(tmpdirname) / "params_w2.pt", map_location="cpu")

        for k in w1_sd.keys():
            assert w1_sd[k].shape == w2_sd[k].shape, k
            # Tolerance on the summed element-wise difference, looser for bf16.
            atol = 10 if param_dtype == torch.float32 else 100
            assert (w1_sd[k] - w2_sd[k]).sum().abs().item() < atol


def _check_grads_fsdp_mp(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    tmpdirname: str,
    enable_lora: bool,
    param_dtype: torch.dtype,
):
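    """Worker: run a short mixed-precision training loop and save a checkpoint
    named after the world size."""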
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
    seq_len = 4096

    optim_dtype = torch.float32

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=param_dtype,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

    steps = 4
    state = TrainState(max_steps=steps)

    run_dir = Path(tmpdirname)

    checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)

    # At world size 1 the single rank consumes both data shards; at world size 2
    # each rank consumes its own, so the global batch is identical in both runs.
    dataloaders = [
        get_dataloader(seq_len=seq_len, rank=rank + i, world_size=2)
        for i in range(2 - world_size + 1)
    ]

    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )

    for _ in range(steps):
        state.start_step()
        optimizer.zero_grad()

        for data_loader in dataloaders:
            torch.manual_seed(0)
            batch = next(data_loader)

            x = torch.from_numpy(batch.x).cuda()
            y = torch.from_numpy(batch.y).cuda()
            y_mask = (
                torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
                if batch.y_mask is not None
                else None
            )

            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )

            mb_loss = compute_loss_with_mask(output, y, y_mask)
            mb_loss.backward()

            assert model.params[0].dtype == param_dtype

            print(f"rank: {rank}, world_size: {world_size}, x: {x.abs().sum()}")
            print(f"rank: {rank}, world_size: {world_size}, y: {y.abs().sum()}")
            print(f"rank: {rank}, world_size: {world_size}, x shape: {x.shape}")

            if y_mask is not None:
                print(
                    f"rank: {rank}, world_size: {world_size}, y_mask: {y_mask.abs().sum()}"
                )
            print(f"rank: {rank}, world_size: {world_size}, loss: {mb_loss}")

        # Average gradients over the number of micro-batches on this rank.
        for p in model.parameters():
            if p.requires_grad:
                assert p.grad is not None
                p.grad.div_(len(dataloaders))

        max_norm = 1.0
        model.clip_grad_norm_(max_norm=max_norm)

        # Optimizer math runs in fp32: upcast params and grads, step, then
        # downcast the params back to the training dtype.
        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        optimizer.step()

        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

    save_dict = checkpointer.retrieve_save_states(
        save_only_lora=enable_lora, save_dtype=torch.float32
    )

    path = "params_w1.pt" if world_size == 1 else "params_w2.pt"
    torch.save(save_dict, Path(tmpdirname) / path)