File size: 16,558 Bytes
cb9e677 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 |
import contextlib
import tempfile
from pathlib import Path
from typing import Dict

import pytest
import torch

from finetune.args import LoraArgs
from finetune.checkpointing import Checkpointer
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.utils import TrainState
from finetune.wrapped_model import load_model
from model.transformer import (
    LoRALinear,
)
from tests.test_utils import (
    MODEL_PATH,
    get_dataloader,
    is_float_equal,
    setup_mp_test_dist,
)

from .test_utils import spawn_for_all_world_sizes
# Reproducibility: force deterministic cuDNN kernels and disable the cuDNN
# autotuner (which may pick different algorithms from run to run).
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
@pytest.mark.parametrize(
    ("world_size", "enable_lora", "dtype"),
    [
        (ws, lora, dt)
        for dt in (torch.float32, torch.bfloat16)
        for ws in (1, 2)
        for lora in (False, True)
    ],
)
def test_weights_loading(world_size, enable_lora, dtype):
    """Run `_check_weights_loading` for one (world size, LoRA, dtype) combination."""
    spawn_for_all_world_sizes(
        _check_weights_loading,
        deterministic=True,
        world_sizes=[world_size],
        args=[enable_lora, dtype],
    )
def _check_weights_loading(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    enable_lora: bool,
    dtype: torch.dtype,
):
    """Check that freshly loaded weights match the reference checkpoint.

    A state-dict hook on every ``LoRALinear`` replaces its ``weight`` entry with
    ``merge_weight()``. The test then sums all state-dict entries whose key
    contains neither "lora" nor "frozen" and compares the sum to a reference.
    With LoRA enabled, it also checks that ``lora_B`` parameters start at zero and
    ``lora_A`` parameters do not.

    Args:
        rank: Rank of this process.
        world_size: Number of ranks in the test process group.
        filename: Passed to ``setup_mp_test_dist`` (presumably the rendezvous
            file; not used otherwise).
        filename_rpc: Unused here; part of the spawn-helper calling convention.
        enable_lora: Whether to load the model with LoRA adapters.
        dtype: Parameter dtype the model is loaded in.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=dtype,
    )

    # Add a hook so that LoRA weights are automatically merged into ``weight``
    # whenever the state dict is built.
    def register_merge_lora_hook(m: torch.nn.Module):
        def merge_lora(
            module: torch.nn.Module,
            destination: Dict[str, torch.Tensor],
            prefix: str,
            *args,
        ):
            weight = module.merge_weight()
            destination[prefix + "weight"] = weight

        if isinstance(m, LoRALinear):
            m._merge_lora_handle = m._register_state_dict_hook(merge_lora)

    model.apply(register_merge_lora_hook)

    # With several ranks the parameters are sharded, so gather them first.
    # The reduction happens inside the context so the gathered tensors are
    # still valid when they are read.
    gather_ctx = (
        model.summon_full_params(model, writeback=True)
        if world_size > 1
        else contextlib.nullcontext()
    )
    with gather_ctx:
        states = {
            k: v
            for k, v in model.state_dict().items()
            if "lora" not in k and "frozen" not in k
        }
        params = sum(v.sum() for v in states.values()).item()

    EXP_PARAM_SUM = 308.9932 if dtype == torch.float32 else 308.0
    # LoRA is equal to no LoRA as LoRA weights should be init to 0.
    assert is_float_equal(params, EXP_PARAM_SUM), params

    if enable_lora:
        lora_B_params = [
            v.float().abs().sum() for k, v in model.named_parameters() if "lora_B" in k
        ]
        assert len(lora_B_params) > 0
        assert sum(lora_B_params) == 0, "Lora_B should always be zero init"

        lora_A_params = [
            v.float().abs().sum() for k, v in model.named_parameters() if "lora_A" in k
        ]
        assert len(lora_A_params) > 0
        assert sum(lora_A_params) > 0, "Lora_A should init to non-zero values"
@pytest.mark.parametrize(
    ("world_size", "enable_lora"),
    [(ws, lora) for ws in (1, 2) for lora in (False, True)],
)
def test_fsdp_logits_and_loss(world_size, enable_lora):
    """Run `_check_fsdp_logits_and_loss` for one (world size, LoRA) combination."""
    spawn_for_all_world_sizes(
        _check_fsdp_logits_and_loss,
        deterministic=True,
        world_sizes=[world_size],
        args=[enable_lora],
    )
def _check_fsdp_logits_and_loss(
    rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
):
    """Compare logits and loss of a freshly loaded bf16 model to reference values.

    The same reference numbers are used with and without LoRA and for every
    world size: at initialisation the LoRA adapters must not change the output.
    """
    mp_size = 1
    setup_mp_test_dist(rank, world_size, filename, mp_size, seed=0)

    seq_len = 100
    model = load_model(
        folder=Path(MODEL_PATH),
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=torch.bfloat16,
    )

    # Every rank reads shard 0 of a fixed 2-way split. Regardless of the actual
    # world size, all ranks therefore process the same batch and must produce
    # the same results.
    loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
    batch = next(loader)
    x, y, y_mask = (
        torch.from_numpy(arr).cuda(non_blocking=True)
        for arr in (batch.x, batch.y, batch.y_mask)
    )

    # Forward pass only.
    logits = model(
        input_ids=x,
        seqlens=batch.sizes,
    )

    # Logits must agree with the reference, with or without LoRA.
    assert logits.shape == (seq_len, model.args.vocab_size)
    logits_abs_sum = logits.abs().float().sum().item()
    expected_logits_abs_sum = 162617.625
    assert is_float_equal(
        logits_abs_sum, expected_logits_abs_sum, precision=1e1
    ), logits_abs_sum

    # The loss must agree as well, with or without LoRA.
    loss = compute_loss_with_mask(logits, y, y_mask)
    expected_loss = 10.408413887023926
    assert is_float_equal(loss.item(), expected_loss), loss.item()
@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(ws, dt) for dt in (torch.bfloat16, torch.float32) for ws in (1, 2)],
)
def test_fsdp_grads_non_lora(world_size, dtype):
    """Run `_check_fsdp_grads_non_lora` for one (world size, dtype) combination."""
    spawn_for_all_world_sizes(
        _check_fsdp_grads_non_lora,
        args=[dtype],
        world_sizes=[world_size],
        deterministic=True,
    )
def _check_fsdp_grads_non_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
    """Check per-rank gradient shards after one backward pass without LoRA.

    Every parameter is expected to receive a gradient. The local gradient element
    count must equal ``4301120 // world_size``. The local absolute gradient sum
    must match the reference for this ``(world_size, rank)``; the single-rank
    reference is the sum of the two-rank shards.

    Raises:
        AssertionError: On any mismatch, including a ``(world_size, rank)``
            with no reference value. The check never passes silently.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    seq_len = 2048
    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=False),
        checkpoint=True,
        param_dtype=dtype,
    )

    # Always read shard 0 of a 2-way split so every world size sees the same data.
    data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
    batch = next(data_loader)
    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)

    # Forward / backward.
    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )
    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    num_grad_params = sum(p.grad.numel() for p in model.parameters())
    assert (4301120 // world_size) == num_grad_params, num_grad_params

    torch.distributed.barrier()

    sharded_flat_grads = sum(
        p.grad.float().abs().sum().item() for p in model.parameters()
    )
    print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")

    EXP_GRAD_WORLD_2_RANK_0 = 95.45827150344849
    EXP_GRAD_WORLD_2_RANK_1 = 86.09188461303711
    EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1
    expected_by_config = {
        (1, 0): EXP_GRAD_WORLD_1,
        (2, 0): EXP_GRAD_WORLD_2_RANK_0,
        (2, 1): EXP_GRAD_WORLD_2_RANK_1,
    }
    config = (world_size, rank)
    # Fail loudly instead of silently skipping the check for unknown configs.
    assert config in expected_by_config, (
        f"no reference gradient sum for world_size={world_size}, rank={rank}"
    )
    assert is_float_equal(
        sharded_flat_grads, expected_by_config[config], 2.0e-1
    ), sharded_flat_grads
@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(ws, dt) for dt in (torch.bfloat16, torch.float32) for ws in (1, 2)],
)
def test_fsdp_grads_lora(world_size, dtype):
    """Run `_check_fsdp_grads_lora` for one (world size, dtype) combination."""
    spawn_for_all_world_sizes(
        _check_fsdp_grads_lora,
        args=[dtype],
        world_sizes=[world_size],
        deterministic=True,
    )
def _check_fsdp_grads_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
    """Check per-rank gradient shards after one backward pass with LoRA.

    Only parameters that received a gradient are counted. Their local element
    count must equal ``40960 // world_size``. Their local absolute gradient sum
    must match the reference for this ``(world_size, rank)``; the single-rank
    reference is the sum of the two-rank shards.

    Raises:
        AssertionError: On any mismatch, including a ``(world_size, rank)``
            with no reference value. The check never passes silently.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    seq_len = 2048
    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=True),
        checkpoint=True,
        param_dtype=dtype,
    )

    # Always read shard 0 of a 2-way split so every world size sees the same data.
    data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
    batch = next(data_loader)
    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)

    # Forward / backward.
    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )
    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    # Frozen parameters have no gradient; skip them.
    num_grad_params = sum(
        p.grad.numel() for p in model.parameters() if p.grad is not None
    )
    assert (40960 // world_size) == num_grad_params, num_grad_params

    torch.distributed.barrier()

    sharded_flat_grads = sum(
        p.grad.float().abs().sum().item()
        for p in model.parameters()
        if p.grad is not None
    )
    print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")

    EXP_GRAD_WORLD_2_RANK_0 = 3.0742580661177635
    EXP_GRAD_WORLD_2_RANK_1 = 3.074301045779139
    EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1
    expected_by_config = {
        (1, 0): EXP_GRAD_WORLD_1,
        (2, 0): EXP_GRAD_WORLD_2_RANK_0,
        (2, 1): EXP_GRAD_WORLD_2_RANK_1,
    }
    config = (world_size, rank)
    # Fail loudly instead of silently skipping the check for unknown configs.
    assert config in expected_by_config, (
        f"no reference gradient sum for world_size={world_size}, rank={rank}"
    )
    assert is_float_equal(
        sharded_flat_grads, expected_by_config[config], 2.0e-1
    ), sharded_flat_grads
@pytest.mark.parametrize(
    ("world_size", "dtype"),
    [(ws, dt) for dt in (torch.bfloat16, torch.float32) for ws in (1, 2)],
)
def test_grad_update_lora(world_size, dtype):
    """Run `_check_grad_update_lora` for one (world size, dtype) combination."""
    spawn_for_all_world_sizes(
        _check_grad_update_lora,
        deterministic=True,
        world_sizes=[world_size],
        args=[dtype],
    )
def _lora_and_frozen_abs_sums(model: torch.nn.Module):
    """Return ``(trainable_sum, frozen_sum)`` of absolute parameter values.

    Parameters whose name contains "lora" or "norm" count as trainable and must
    have a gradient. All other parameters count as frozen and must not have one.

    Raises:
        AssertionError: If a parameter's gradient presence does not match its
            trainable/frozen classification (the message is the parameter name).
    """
    trainable_sum = 0
    frozen_sum = 0
    for name, param in model.named_parameters():
        if "lora" in name or "norm" in name:
            assert param.grad is not None, name
            trainable_sum += param.data.float().abs().sum()
        else:
            assert param.grad is None, name
            frozen_sum += param.data.float().abs().sum()
    return trainable_sum, frozen_sum


def _check_grad_update_lora(
    rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
    """Check that one optimizer step updates only the trainable (LoRA/norm) weights.

    Weight sums are recorded before and after ``optimizer.step()``. The
    "lora"/"norm" sum must change and the sum over all other parameters must
    stay the same.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    seq_len = 1000
    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=True),
        checkpoint=True,
        param_dtype=dtype,
    )
    optimizer = torch.optim.AdamW(model.parameters())

    data_loader = get_dataloader(seq_len=seq_len)
    batch = next(data_loader)
    x = torch.from_numpy(batch.x).cuda(non_blocking=True)
    y = torch.from_numpy(batch.y).cuda(non_blocking=True)
    y_mask = (
        torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
        if batch.y_mask is not None
        else None
    )

    # Forward / backward.
    output = model(
        input_ids=x,
        seqlens=batch.sizes,
    )
    mb_loss = compute_loss_with_mask(output, y, y_mask)
    mb_loss.backward()

    # "lora" sums also include the (trainable) norm parameters.
    lora_weight_sum, non_lora_weight_sum = _lora_and_frozen_abs_sums(model)

    # Update weights.
    optimizer.step()

    new_lora_weight_sum, new_non_lora_weight_sum = _lora_and_frozen_abs_sums(model)

    # Make sure that LoRA weights changed, but non-LoRA weights stayed the same.
    assert not is_float_equal(
        new_lora_weight_sum, lora_weight_sum, 1e-4
    ), f"New: {new_lora_weight_sum}, Old: {lora_weight_sum}"
    assert is_float_equal(
        new_non_lora_weight_sum, non_lora_weight_sum, 1e-4
    ), f"New: {new_non_lora_weight_sum}, Old: {non_lora_weight_sum}"
@pytest.mark.parametrize(
    ("enable_lora", "param_dtype"),
    [
        (False, torch.float32),
        (True, torch.float32),
        (False, torch.bfloat16),
        (True, torch.bfloat16),
    ],
)
def test_grads_fsdp_mp(enable_lora, param_dtype):
    """Check that a few mixed-precision training steps give matching weights on 1 and 2 ranks.

    ``_check_grads_fsdp_mp`` runs once per world size. Each run writes its final
    weights into a shared temporary directory (``params_w1.pt`` /
    ``params_w2.pt``), and the two files are then compared entry by entry.
    """
    # bfloat16 rounding is much coarser, hence the looser tolerance.
    atol = 10 if param_dtype == torch.float32 else 100

    with tempfile.TemporaryDirectory() as tmpdirname:
        for world_size in [1, 2]:
            spawn_for_all_world_sizes(
                _check_grads_fsdp_mp,
                world_sizes=[world_size],
                deterministic=True,
                args=[tmpdirname, enable_lora, param_dtype],
            )

        w1_sd = torch.load(Path(tmpdirname) / Path("params_w1.pt"), map_location="cpu")
        w2_sd = torch.load(Path(tmpdirname) / Path("params_w2.pt"), map_location="cpu")

        # Both runs must save exactly the same set of tensors.
        assert w1_sd.keys() == w2_sd.keys(), sorted(set(w1_sd) ^ set(w2_sd))

        for k in w1_sd:
            assert w1_sd[k].shape == w2_sd[k].shape, k
            # NOTE(review): this bounds the *signed* net difference, so
            # deviations of opposite sign can cancel out. A stricter
            # `.abs().sum()` check would need re-tuned tolerances.
            assert (w1_sd[k] - w2_sd[k]).sum().abs().item() < atol
def _check_grads_fsdp_mp(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    tmpdirname: str,
    enable_lora: bool,
    param_dtype: torch.dtype,
):
    """Train a few mixed-precision steps and dump the final weights to disk.

    Parameters are kept in ``param_dtype`` and upcast to float32 for each
    optimizer step. Gradients accumulated over the local micro-batches are
    averaged and clipped before the step. The final state from
    ``Checkpointer.retrieve_save_states`` is written by rank 0 to
    ``params_w1.pt`` (world_size 1) or ``params_w2.pt`` (world_size 2) in
    ``tmpdirname``.

    Args:
        rank: Rank of this process.
        world_size: Number of ranks (1 or 2).
        filename: Passed to ``setup_mp_test_dist``.
        filename_rpc: Unused here; part of the spawn-helper calling convention.
        tmpdirname: Directory the weights file is written to.
        enable_lora: Train LoRA adapters only (and save only LoRA weights).
        param_dtype: Dtype parameters are stored in between optimizer steps.
    """
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    seq_len = 4096
    optim_dtype = torch.float32
    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=param_dtype,
    )

    # High learning rate so that numerical differences between runs show up.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

    # Train state whose max_steps matches the number of steps run below.
    steps = 4
    state = TrainState(max_steps=steps)

    # Checkpointer needs a run_dir; this test only calls retrieve_save_states
    # on it and writes the weights file itself at the end.
    run_dir = Path(tmpdirname)
    checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)

    # Make both configurations see the same data. With world_size == 1 the single
    # rank reads both shards of a 2-way split as two micro-batches; with
    # world_size == 2 each rank reads its own shard.
    dataloaders = [
        get_dataloader(seq_len=seq_len, rank=rank + i, world_size=2)
        for i in range(2 - world_size + 1)
    ]

    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )

    for _ in range(steps):
        state.start_step()
        optimizer.zero_grad()

        for data_loader in dataloaders:
            torch.manual_seed(0)
            batch = next(data_loader)

            x = torch.from_numpy(batch.x).cuda()
            y = torch.from_numpy(batch.y).cuda()
            y_mask = (
                torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
                if batch.y_mask is not None
                else None
            )

            # Forward / backward.
            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )
            mb_loss = compute_loss_with_mask(output, y, y_mask)
            mb_loss.backward()

            assert model.params[0].dtype == param_dtype

            print(f"rank: {rank}, world_size: {world_size}, x: {x.abs().sum()}")
            print(f"rank: {rank}, world_size: {world_size}, y: {y.abs().sum()}")
            print(f"rank: {rank}, world_size: {world_size}, x shape: {x.shape}")
            if y_mask is not None:
                print(
                    f"rank: {rank}, world_size: {world_size}, y_mask: {y_mask.abs().sum()}"
                )
            print(f"rank: {rank}, world_size: {world_size}, loss: {mb_loss}")

        # Average the gradients accumulated over the local micro-batches
        # (presumably mirroring FSDP's cross-rank averaging when world_size == 2).
        for p in model.parameters():
            if p.requires_grad:
                assert p.grad is not None
                p.grad.div_(len(dataloaders))

        max_norm = 1.0
        model.clip_grad_norm_(max_norm=max_norm)

        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)
        optimizer.step()
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

    # All ranks call retrieve_save_states, since it may involve collectives.
    # Only rank 0 writes: every rank targets the same path, so concurrent
    # torch.save calls would race, and non-zero ranks may not hold the full state.
    save_dict = checkpointer.retrieve_save_states(
        save_only_lora=enable_lora, save_dtype=torch.float32
    )
    if rank == 0:
        path = "params_w1.pt" if world_size == 1 else "params_w2.pt"
        torch.save(save_dict, Path(tmpdirname) / Path(path))
|