from pathlib import Path
import pytest
import torch
from finetune.args import LoraArgs
from finetune.checkpointing import Checkpointer
from finetune.utils import TrainState
from finetune.wrapped_model import load_model
from tests.test_utils import MODEL_PATH, is_float_equal, setup_mp_test_dist
from utils.merge_lora import merge_checkpoints
from .test_utils import spawn_for_all_world_sizes
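
# expected parameter names of the 2-layer test model: the full (non-LoRA)
# state dict and the LoRA A/B adapter weights for every attention and
# feed-forward projection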
# fmt: off
EXPECTED_NON_LORA_KEYS = sorted(['layers.0.attention.wk.weight', 'layers.0.attention.wo.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wv.weight', 'layers.0.attention_norm.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wo.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wv.weight', 'layers.1.attention_norm.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.ffn_norm.weight', 'norm.weight', 'output.weight', 'tok_embeddings.weight'])
EXPECTED_LORA_KEYS = sorted(['layers.0.attention.wq.lora_A.weight', 'layers.0.attention.wq.lora_B.weight', 'layers.0.attention.wk.lora_A.weight', 'layers.0.attention.wk.lora_B.weight', 'layers.0.attention.wv.lora_A.weight', 'layers.0.attention.wv.lora_B.weight', 'layers.0.attention.wo.lora_A.weight', 'layers.0.attention.wo.lora_B.weight', 'layers.0.feed_forward.w1.lora_A.weight', 'layers.0.feed_forward.w1.lora_B.weight', 'layers.0.feed_forward.w2.lora_A.weight', 'layers.0.feed_forward.w2.lora_B.weight', 'layers.0.feed_forward.w3.lora_A.weight', 'layers.0.feed_forward.w3.lora_B.weight', 'layers.1.attention.wq.lora_A.weight', 'layers.1.attention.wq.lora_B.weight', 'layers.1.attention.wk.lora_A.weight', 'layers.1.attention.wk.lora_B.weight', 'layers.1.attention.wv.lora_A.weight', 'layers.1.attention.wv.lora_B.weight', 'layers.1.attention.wo.lora_A.weight', 'layers.1.attention.wo.lora_B.weight', 'layers.1.feed_forward.w1.lora_A.weight', 'layers.1.feed_forward.w1.lora_B.weight', 'layers.1.feed_forward.w2.lora_A.weight', 'layers.1.feed_forward.w2.lora_B.weight', 'layers.1.feed_forward.w3.lora_A.weight', 'layers.1.feed_forward.w3.lora_B.weight'])
# fmt: on
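

# Spawns `world_size` processes and checks checkpoint path resolution and
# save-state retrieval with and without LoRA; the save_only_lora cases cover
# saving only the LoRA weights under FSDP.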
@pytest.mark.parametrize(
("world_size", "save_only_lora", "enable_lora"),
[
(1, False, False),
(2, False, False),
(1, False, True),
(2, False, True),
(1, True, True),
        (2, True, True),  # the most important case: FSDP with LoRA-only saving
],
)
def test_states_retrieval(world_size, enable_lora, save_only_lora):
spawn_for_all_world_sizes(
_check_states_retrieval,
world_sizes=[world_size],
args=[enable_lora, save_only_lora],
deterministic=True,
)
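

# worker executed in every spawned rank: loads the (optionally LoRA-wrapped)
# model and verifies checkpoint paths as well as the dtypes, keys and value
# sums of the retrieved save states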
def _check_states_retrieval(
rank: int,
world_size: int,
filename: str,
filename_rpc: str,
enable_lora: bool,
save_only_lora: bool,
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=enable_lora),
checkpoint=True,
param_dtype=torch.bfloat16,
)
# mock a train state that has done three steps
step = 3
state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here
# mock run_dir as we won't save anything in this test
run_dir = Path("dir")
    use_sf = True  # save consolidated checkpoints in safetensors format
checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)
prefix = "lora" if enable_lora else "consolidated"
assert checkpointer.dst_dir == Path(
f"dir/checkpoints/checkpoint_00000{step}/consolidated"
), checkpointer.dst_dir
assert checkpointer.consolidated_path(
checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
) == Path(
f"dir/checkpoints/checkpoint_00000{step}/consolidated/{prefix}.safetensors"
), checkpointer.consolidated_path(
checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
)
# increase step by one
state.start_step()
assert checkpointer.dst_dir == Path(
f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated"
), checkpointer.dst_dir
assert checkpointer.consolidated_path(
checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
) == Path(
f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated/{prefix}.safetensors"
), checkpointer.consolidated_path(
checkpointer.dst_dir, use_sf, save_only_lora=enable_lora
)
assert all("lora" in k for k in EXPECTED_LORA_KEYS), EXPECTED_LORA_KEYS
for save_dtype in [torch.float16, torch.bfloat16, torch.float32]:
save_dict = checkpointer.retrieve_save_states(
save_only_lora=save_only_lora, save_dtype=save_dtype
)
for k, v in save_dict.items():
            assert v.dtype == save_dtype, f"{k}: {v.dtype}"
if save_only_lora:
assert sorted(save_dict.keys()) == EXPECTED_LORA_KEYS, save_dict.keys()
else:
assert sorted(save_dict.keys()) == EXPECTED_NON_LORA_KEYS, save_dict.keys()
EXPECTED_NON_LORA_VALUES = 34909.7500
EXPECTED_LORA_VALUES = 984.4179840087891
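        # checksum: compare the absolute sum of all retrieved tensors against reference values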
values_sum = sum(v.abs().float().sum().item() for v in save_dict.values())
if save_only_lora:
assert is_float_equal(
values_sum, EXPECTED_LORA_VALUES, 5e-1
), f"{values_sum} for {save_dtype}"
else:
assert is_float_equal(
values_sum, EXPECTED_NON_LORA_VALUES, 1e-1
), f"{values_sum} for {save_dtype}"
@pytest.mark.parametrize("world_size", [1, 2])
def test_lora_merge_equal(world_size):
spawn_for_all_world_sizes(
_check_lora_merge_equal,
world_sizes=[world_size],
deterministic=True,
)
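

# worker executed in every spawned rank for the LoRA merge test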
def _check_lora_merge_equal(
rank: int, world_size: int, filename: str, filename_rpc: str
):
model_parallel = 1
enable_lora = True
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
folder = Path(MODEL_PATH)
step = 3
state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here
run_dir = Path("dir")
non_lora_model = load_model(
folder=folder,
lora=LoraArgs(enable=False),
checkpoint=True,
param_dtype=torch.bfloat16,
)
non_lora_checkpointer = Checkpointer(
non_lora_model, state, run_dir=run_dir, num_ckpt_keep=None
)
orig_model = non_lora_checkpointer.retrieve_save_states(
save_only_lora=False, save_dtype=torch.float32
)
scaling = 2.0
model = load_model(
folder=folder,
lora=LoraArgs(enable=enable_lora, scaling=scaling),
checkpoint=True,
param_dtype=torch.bfloat16,
)
state_dict = model.state_dict()
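    # perturb the LoRA weights so that merging them has a visible effect on the base weights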
state_dict = {k: v + 0.01 if "lora" in k else v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
    # checkpointer for the LoRA model, reusing the mocked train state from above
checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)
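    # for every supported save dtype, merge the LoRA weights into a copy of the
    # original weights and compare against the consolidated state dict of the
    # LoRA checkpointer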
for save_dtype in [torch.float16, torch.bfloat16, torch.float32]:
model_dict = {
k: torch.empty_like(v).copy_(v).to(save_dtype)
for k, v in orig_model.items()
}
merged_save_dict = checkpointer.retrieve_save_states(
save_only_lora=False, save_dtype=save_dtype
)
lora_save_dict = checkpointer.retrieve_save_states(
save_only_lora=True, save_dtype=save_dtype
)
merge_checkpoints(
model_dict, lora_save_dict, scaling=scaling, save_dtype=save_dtype
)
        for k in model_dict.keys():
            assert torch.allclose(
                model_dict[k].cpu(), merged_save_dict[k].cpu(), atol=1e-3, rtol=1e-3
            ), f"{k}: manual merge and checkpointer merge disagree"

        for k in model_dict.keys():
            # make sure that the merged model differs from the original model on
            # the LoRA-adapted projections (norm weights receive no LoRA update)
            if ("attention" in k or "feed_forward" in k) and "norm" not in k:
                assert not torch.allclose(
                    orig_model[k].to(save_dtype).cpu(),
                    merged_save_dict[k].cpu(),
                    atol=1e-3,
                    rtol=1e-3,
                )