from pathlib import Path

import pytest
import torch

from finetune.args import LoraArgs
from finetune.checkpointing import Checkpointer
from finetune.utils import TrainState
from finetune.wrapped_model import load_model
from tests.test_utils import MODEL_PATH, is_float_equal, setup_mp_test_dist
from utils.merge_lora import merge_checkpoints

from .test_utils import spawn_for_all_world_sizes


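# Expected parameter names for the two-layer test model at MODEL_PATH: every
# base weight on the one hand, and the lora_A / lora_B adapter weights attached
# to each attention and feed-forward projection on the other.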
EXPECTED_NON_LORA_KEYS = sorted(
    [
        "layers.0.attention.wk.weight",
        "layers.0.attention.wo.weight",
        "layers.0.attention.wq.weight",
        "layers.0.attention.wv.weight",
        "layers.0.attention_norm.weight",
        "layers.0.feed_forward.w1.weight",
        "layers.0.feed_forward.w2.weight",
        "layers.0.feed_forward.w3.weight",
        "layers.0.ffn_norm.weight",
        "layers.1.attention.wk.weight",
        "layers.1.attention.wo.weight",
        "layers.1.attention.wq.weight",
        "layers.1.attention.wv.weight",
        "layers.1.attention_norm.weight",
        "layers.1.feed_forward.w1.weight",
        "layers.1.feed_forward.w2.weight",
        "layers.1.feed_forward.w3.weight",
        "layers.1.ffn_norm.weight",
        "norm.weight",
        "output.weight",
        "tok_embeddings.weight",
    ]
)
EXPECTED_LORA_KEYS = sorted(
    [
        "layers.0.attention.wq.lora_A.weight",
        "layers.0.attention.wq.lora_B.weight",
        "layers.0.attention.wk.lora_A.weight",
        "layers.0.attention.wk.lora_B.weight",
        "layers.0.attention.wv.lora_A.weight",
        "layers.0.attention.wv.lora_B.weight",
        "layers.0.attention.wo.lora_A.weight",
        "layers.0.attention.wo.lora_B.weight",
        "layers.0.feed_forward.w1.lora_A.weight",
        "layers.0.feed_forward.w1.lora_B.weight",
        "layers.0.feed_forward.w2.lora_A.weight",
        "layers.0.feed_forward.w2.lora_B.weight",
        "layers.0.feed_forward.w3.lora_A.weight",
        "layers.0.feed_forward.w3.lora_B.weight",
        "layers.1.attention.wq.lora_A.weight",
        "layers.1.attention.wq.lora_B.weight",
        "layers.1.attention.wk.lora_A.weight",
        "layers.1.attention.wk.lora_B.weight",
        "layers.1.attention.wv.lora_A.weight",
        "layers.1.attention.wv.lora_B.weight",
        "layers.1.attention.wo.lora_A.weight",
        "layers.1.attention.wo.lora_B.weight",
        "layers.1.feed_forward.w1.lora_A.weight",
        "layers.1.feed_forward.w1.lora_B.weight",
        "layers.1.feed_forward.w2.lora_A.weight",
        "layers.1.feed_forward.w2.lora_B.weight",
        "layers.1.feed_forward.w3.lora_A.weight",
        "layers.1.feed_forward.w3.lora_B.weight",
    ]
)


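# Note: save_only_lora=True is only exercised together with enable_lora=True;
# the matrix deliberately omits saving LoRA-only states without LoRA enabled.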
@pytest.mark.parametrize(
    ("world_size", "save_only_lora", "enable_lora"),
    [
        (1, False, False),
        (2, False, False),
        (1, False, True),
        (2, False, True),
        (1, True, True),
        (2, True, True),
    ],
)
def test_states_retrieval(world_size, enable_lora, save_only_lora):
    spawn_for_all_world_sizes(
        _check_states_retrieval,
        world_sizes=[world_size],
        args=[enable_lora, save_only_lora],
        deterministic=True,
    )


def _check_states_retrieval(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    enable_lora: bool,
    save_only_lora: bool,
):
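    """Per-rank check that Checkpointer paths track the training step and that
    retrieve_save_states returns the expected keys, dtypes, and value sums."""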
    model_parallel = 1
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    folder = Path(MODEL_PATH)
    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora),
        checkpoint=True,
        param_dtype=torch.bfloat16,
    )

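    # The checkpointer derives its destination directory from the current
    # training step, so start from a known step.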
    step = 3
    state = TrainState(max_steps=10, step=step)

    run_dir = Path("dir")
    use_sf = True

    checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)
    prefix = "lora" if enable_lora else "consolidated"

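    # At the current step, the destination directory and the consolidated file
    # name must encode the step number and the LoRA / non-LoRA prefix.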
    assert checkpointer.dst_dir == Path(
        f"dir/checkpoints/checkpoint_00000{step}/consolidated"
    ), checkpointer.dst_dir
    assert checkpointer.consolidated_path(
        checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
    ) == Path(
        f"dir/checkpoints/checkpoint_00000{step}/consolidated/{prefix}.safetensors"
    ), checkpointer.consolidated_path(
        checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
    )

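    # Advancing the training step must move the checkpoint destination with it.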
    state.start_step()

    assert checkpointer.dst_dir == Path(
        f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated"
    ), checkpointer.dst_dir
    assert checkpointer.consolidated_path(
        checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
    ) == Path(
        f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated/{prefix}.safetensors"
    ), checkpointer.consolidated_path(
        checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
    )

    assert all("lora" in k for k in EXPECTED_LORA_KEYS), EXPECTED_LORA_KEYS

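    # retrieve_save_states must cast every saved tensor to the requested dtype
    # and return exactly the LoRA or the full key set, depending on
    # save_only_lora.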
    for save_dtype in [torch.float16, torch.bfloat16, torch.float32]:
        save_dict = checkpointer.retrieve_save_states(
            save_only_lora=save_only_lora, save_dtype=save_dtype
        )

        for k, v in save_dict.items():
            assert v.dtype == save_dtype, f"{k}: {v.dtype}"

        if save_only_lora:
            assert sorted(save_dict.keys()) == EXPECTED_LORA_KEYS, save_dict.keys()
        else:
            assert sorted(save_dict.keys()) == EXPECTED_NON_LORA_KEYS, save_dict.keys()

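        # Reference sums of absolute parameter values for the deterministic
        # test model; the tolerances absorb rounding differences across dtypes.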
        EXPECTED_NON_LORA_VALUES = 34909.7500
        EXPECTED_LORA_VALUES = 984.4179840087891

        values_sum = sum(v.abs().float().sum().item() for v in save_dict.values())

        if save_only_lora:
            assert is_float_equal(
                values_sum, EXPECTED_LORA_VALUES, 5e-1
            ), f"{values_sum} for {save_dtype}"
        else:
            assert is_float_equal(
                values_sum, EXPECTED_NON_LORA_VALUES, 1e-1
            ), f"{values_sum} for {save_dtype}"


@pytest.mark.parametrize("world_size", [1, 2])
def test_lora_merge_equal(world_size):
    spawn_for_all_world_sizes(
        _check_lora_merge_equal,
        world_sizes=[world_size],
        deterministic=True,
    )


def _check_lora_merge_equal(
    rank: int, world_size: int, filename: str, filename_rpc: str
):
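    """Per-rank check that merge_checkpoints applied to LoRA-only save states
    reproduces the consolidated (already merged) states of the LoRA model."""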
    model_parallel = 1
    enable_lora = True
    setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)

    folder = Path(MODEL_PATH)

    step = 3
    state = TrainState(max_steps=10, step=step)
    run_dir = Path("dir")

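    # Load the base model without LoRA; its consolidated states serve as the
    # reference weights for the merge below.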
    non_lora_model = load_model(
        folder=folder,
        lora=LoraArgs(enable=False),
        checkpoint=True,
        param_dtype=torch.bfloat16,
    )

    non_lora_checkpointer = Checkpointer(
        non_lora_model, state, run_dir=run_dir, num_ckpt_keep=None
    )
    orig_model = non_lora_checkpointer.retrieve_save_states(
        save_only_lora=False, save_dtype=torch.float32
    )

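    # Reload the same weights with LoRA enabled and perturb the adapter weights
    # so that merging produces weights that measurably differ from the base
    # model (checked at the end of this test).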
    scaling = 2.0

    model = load_model(
        folder=folder,
        lora=LoraArgs(enable=enable_lora, scaling=scaling),
        checkpoint=True,
        param_dtype=torch.bfloat16,
    )

    state_dict = model.state_dict()
    state_dict = {k: v + 0.01 if "lora" in k else v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)

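    # For each save dtype, merging the LoRA-only states into a copy of the base
    # weights must reproduce the consolidated (already merged) save states.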
    for save_dtype in [torch.float16, torch.bfloat16, torch.float32]:
        model_dict = {
            k: torch.empty_like(v).copy_(v).to(save_dtype)
            for k, v in orig_model.items()
        }
        merged_save_dict = checkpointer.retrieve_save_states(
            save_only_lora=False, save_dtype=save_dtype
        )

        lora_save_dict = checkpointer.retrieve_save_states(
            save_only_lora=True, save_dtype=save_dtype
        )

        merge_checkpoints(
            model_dict, lora_save_dict, scaling=scaling, save_dtype=save_dtype
        )

        for k in model_dict.keys():
            assert torch.allclose(
                model_dict[k].cpu(), merged_save_dict[k].cpu(), atol=1e-3, rtol=1e-3
            ), k

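        # Conversely, the perturbed adapters must have changed every attention
        # and feed-forward weight relative to the original base model.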
        for k in model_dict.keys():
            if "attention" in k or "feed_forward" in k:
                assert not torch.allclose(
                    orig_model[k].to(save_dtype).cpu(),
                    merged_save_dict[k].cpu(),
                    atol=1e-3,
                    rtol=1e-3,
                ), k