|
import copy |
|
import json |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import pytest |
|
from mistral_common.protocol.instruct.messages import FinetuningAssistantMessage |
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer |
|
|
|
from finetune.data.args import DataArgs, InstructArgs |
|
from finetune.data.data_loader import build_data_loader |
|
from finetune.data.dataset import ( |
|
DataFile, |
|
SampleType, |
|
get_dataset_iterator, |
|
get_rng, |
|
lazy_load_and_yield, |
|
maybe_chunk_lines, |
|
parse_data_sources, |
|
preload_and_yield, |
|
) |
|
from finetune.data.tokenize import build_instruct_sample, encode |
|
|
|
from .test_utils import spawn_for_all_world_sizes |
|
|
|
|
|
# Reference values for `_check_data_loader_dist` and
# `_check_data_loader_equal_fsdp`.
#
# Each table is read as TABLE[is_instruct][dp_world_size - 1][dp_rank]:
#   * first index  -- 0 when `is_instruct` is False, 1 when it is True
#   * second index -- data-parallel world size minus one
#   * third index  -- data-parallel rank
# The innermost list holds one entry per batch pulled from the loader.

# Expected `batch.x.sum()` per batch.
EXPECTED_X = [
    [  # is_instruct=False
        [  # dp_world_size=1
            [2051851, 1961139, 2000184, 2081307, 2341123, 1225437, 1739008, 724695, 570810, 632094],
        ],
        [  # dp_world_size=2
            [2020745, 1938377, 2244286, 2042079, 1824023],
            [2103241, 2032118, 1868430, 1093072, 770996],
        ],
    ],
    [  # is_instruct=True
        [  # dp_world_size=1
            [1379941, 1438894, 965536, 1019713, 889921, 999322, 1647173, 941080, 1281597, 1584884],
        ],
        [  # dp_world_size=2
            [1379941, 1438894, 889899, 1005451, 876854],
            [1034325, 999322, 982295, 941080, 725946],
        ],
    ],
]

# Expected `batch.y.sum()` per batch.
EXPECTED_Y = [
    [  # is_instruct=False
        [  # dp_world_size=1
            [2081367, 1961098, 1970714, 2110856, 2334822, 1251057, 1745267, 699854, 571600, 660015],
        ],
        [  # dp_world_size=2
            [2021840, 1966833, 2223275, 2063077, 1824011],
            [2132793, 2002569, 1870876, 1122569, 757126],
        ],
    ],
    [  # is_instruct=True
        [  # dp_world_size=1
            [1409448, 1430886, 937609, 1019339, 889921, 970976, 1660330, 942631, 1308399, 1583658],
        ],
        [  # dp_world_size=2
            [1409448, 1430886, 895531, 990091, 863522],
            [1041462, 970976, 991091, 942631, 737311],
        ],
    ],
]

# Expected `batch.y_mask.sum()` per batch (0 when the batch has no mask).
EXPECTED_MASKS = [
    [  # is_instruct=False
        [  # dp_world_size=1
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        [  # dp_world_size=2
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
    ],
    [  # is_instruct=True
        [  # dp_world_size=1
            [47, 0, 34, 0, 0, 82, 0, 0, 0, 0],
        ],
        [  # dp_world_size=2
            [47, 0, 4, 0, 0],
            [19, 82, 0, 0, 23],
        ],
    ],
]
|
|
|
# Reference values for `_check_data_loader_dist_fn_call`.
#
# Each table is read as TABLE[dp_world_size - 1][dp_rank]; the innermost list
# holds one entry per batch pulled from the loader.

# Expected `batch.x.sum()` per batch.
EXPECTED_X_FUNC = [
    [  # dp_world_size=1
        [1005531, 1551735, 1261711, 1531024, 1280259, 1069883, 858107, 1021583, 1203265, 1242999],
    ],
    [  # dp_world_size=2
        [985281, 1217766, 1442139, 1533790, 1253607],
        [1005531, 1551735, 1261711, 1531024, 1280259],
    ],
]

# Expected `batch.y.sum()` per batch.
EXPECTED_Y_FUNC = [
    [  # dp_world_size=1
        [977126, 1580120, 1233326, 1559463, 1280241, 1042456, 879031, 994127, 1196263, 1270581],
    ],
    [  # dp_world_size=2
        [957934, 1218899, 1441783, 1533011, 1224541],
        [977126, 1580120, 1233326, 1559463, 1280241],
    ],
]

# Expected `batch.y_mask.sum()` per batch (0 when the batch has no mask).
EXPECTED_MASKS_FUNC = [
    [  # dp_world_size=1
        [91, 0, 0, 0, 0, 77, 0, 0, 53, 0],
    ],
    [  # dp_world_size=2
        [16, 47, 0, 86, 98],
        [91, 0, 0, 0, 0],
    ],
]
|
|
|
|
|
|
|
class MockTokenizer:
    """Tokenizer stand-in whose `encode` hands the given text straight back."""

    def encode(self, content: str, *args, **kwargs) -> str:
        """Return `content` unchanged; extra arguments are accepted and ignored."""
        return content
|
|
|
|
|
class MockInstructTokenizerBaseBase:
    """Instruct-tokenizer stand-in that keeps message text as plain strings.

    Both message encoders return the message's `content` attribute untouched,
    so the "tokens" produced through this mock are the original strings.
    """

    def __init__(self):
        # Wrapped tokenizer is the identity mock.
        self.tokenizer = MockTokenizer()

    def start(self):
        """Return an empty list (no leading tokens)."""
        return []

    def encode_user_message(self, message, *args, **kwargs):
        """Return the user message's `content`; extra arguments are ignored."""
        return message.content

    def encode_assistant_message(self, message, *args, **kwargs):
        """Return the assistant message's `content`; extra arguments are ignored."""
        return message.content
|
|
|
|
|
def stringify(samples):
    """Concatenate each sample's `tokens` into a single string.

    Returns a list with one joined string per element of `samples`, in order.
    """
    return ["".join(item.tokens) for item in samples]
|
|
|
|
|
@pytest.mark.parametrize(
    ("world_size", "model_parallel", "is_instruct"),
    [
        # Same (world_size, model_parallel) layouts for both data modes.
        (world_size, model_parallel, is_instruct)
        for is_instruct in (False, True)
        for world_size, model_parallel in ((1, 1), (2, 1), (2, 2))
    ],
)
def test_data_loader_dist(world_size, model_parallel, is_instruct):
    """Run `_check_data_loader_dist` through the spawn helper for one layout."""
    check_args = [model_parallel, is_instruct]
    spawn_for_all_world_sizes(
        _check_data_loader_dist,
        world_sizes=[world_size],
        args=check_args,
        deterministic=True,
    )
|
|
|
|
|
def _check_data_loader_dist(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
    model_parallel: int,
    is_instruct: bool,
):
    """Check the training loader's per-batch sums against the reference tables.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes.
        filename: Unused here (presumably supplied by the spawn helper).
        filename_rpc: Unused here (presumably supplied by the spawn helper).
        model_parallel: Number of ranks that share one data-parallel slot.
        is_instruct: Selects the instruct-only data mix when True, otherwise
            the pretrain files plus a small instruct weight.

    Raises:
        AssertionError: If the observed sums differ from `EXPECTED_X`,
            `EXPECTED_Y` or `EXPECTED_MASKS`.
    """
    # Ranks in the same model-parallel group map to the same data-parallel rank.
    dp_world_size = world_size // model_parallel
    dp_rank = rank // model_parallel

    seed = 0
    seq_len = 100
    batch_size = 1

    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)

    if is_instruct:
        data_args = DataArgs(
            data="",
            instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
            instruct=instruct,
        )
    else:
        data_args = DataArgs(
            data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0",
            instruct_data="tests/fixtures/sample_instruct.jsonl:.01",
            instruct=instruct,
        )

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=dp_rank,
        world_size=dp_world_size,
        is_eval=False,
    )

    x_sums = []
    y_sums = []
    masks = []

    # Ten batches in total, split evenly across the data-parallel ranks.
    num_samples = 10 // dp_world_size

    for _ in range(num_samples):
        batch = next(data_loader)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
        # A missing mask is recorded as 0.
        masks.append(batch.y_mask.sum() if batch.y_mask is not None else 0)

    # `is_instruct` (a bool) doubles as the 0/1 index of the first table axis.
    expected_x_sums = EXPECTED_X[is_instruct][dp_world_size - 1][dp_rank]
    expected_y_sums = EXPECTED_Y[is_instruct][dp_world_size - 1][dp_rank]
    expected_masks = EXPECTED_MASKS[is_instruct][dp_world_size - 1][dp_rank]

    print(f"rank: {rank}, world_size: {world_size}, x: {x_sums}")
    print(f"rank: {rank}, world_size: {world_size}, y: {y_sums}")
    # Label fixed: this line reports the mask sums, not a shape of x.
    print(f"rank: {rank}, world_size: {world_size}, masks: {masks}")

    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
    assert masks == expected_masks, masks
|
|
|
|
|
@pytest.mark.parametrize("world_size", [1, 2])
def test_data_loader_dist_fn_call(world_size):
    """Run `_check_data_loader_dist_fn_call` through the spawn helper."""
    sizes = [world_size]
    spawn_for_all_world_sizes(
        _check_data_loader_dist_fn_call,
        world_sizes=sizes,
        deterministic=True,
    )
|
|
|
|
|
def _check_data_loader_dist_fn_call(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Check per-batch sums for the function-calling fixture.

    The loader is built with shuffling and dynamic function-call chunking
    enabled, and its first batches are compared with the `*_FUNC` reference
    tables. `filename` and `filename_rpc` are not used in this body.
    """
    # No model parallelism here: every process is its own data-parallel rank.
    dp_rank, dp_world_size = rank, world_size

    seed, seq_len, batch_size = 0, 100, 1

    data_args = DataArgs(
        data="",
        instruct_data="tests/fixtures/sample_instruct_fn_call_short.jsonl:.3",
        instruct=InstructArgs(shuffle=True, dynamic_chunk_fn_call=True),
    )

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=dp_rank,
        world_size=dp_world_size,
        is_eval=False,
    )

    x_sums, y_sums, masks = [], [], []

    # Ten batches in total, split evenly across ranks.
    for _ in range(10 // dp_world_size):
        batch = next(data_loader)
        x_sums.append(batch.x.sum())
        y_sums.append(batch.y.sum())
        if batch.y_mask is None:
            # A missing mask is recorded as 0.
            masks.append(0)
        else:
            masks.append(batch.y_mask.sum())

    table_row = dp_world_size - 1
    expected_x_sums = EXPECTED_X_FUNC[table_row][dp_rank]
    expected_y_sums = EXPECTED_Y_FUNC[table_row][dp_rank]
    expected_masks = EXPECTED_MASKS_FUNC[table_row][dp_rank]

    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
    assert masks == expected_masks, masks
|
|
|
|
|
def test_data_loader_equal_fsdp():
    """Run `_check_data_loader_equal_fsdp` through the spawn helper with world size 2."""
    spawn_for_all_world_sizes(
        _check_data_loader_equal_fsdp,
        deterministic=True,
        world_sizes=[2],
    )
|
|
|
|
|
def _check_data_loader_equal_fsdp(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Check that rank-0 and rank-1 loaders built in one process match the references.

    Every calling process builds both loaders itself (explicit ranks 0 and 1,
    with `world_size` as given), draws batches from them alternately, and
    compares the sums with the two-rank instruct entries of `EXPECTED_X` and
    `EXPECTED_Y` interleaved in the same order.

    Args:
        rank: Rank of the calling process (not used to build the loaders).
        world_size: Total number of processes; passed to both loaders.
        filename: Unused here (presumably supplied by the spawn helper).
        filename_rpc: Unused here (presumably supplied by the spawn helper).

    Raises:
        AssertionError: If the observed sums differ from the expected ones.
    """
    # The former `model_parallel` local and the two bare floor-division
    # expressions had no effect and were removed.
    seed = 0
    seq_len = 100
    batch_size = 1

    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)

    data_args = DataArgs(
        data="",
        instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
        instruct=instruct,
    )

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    # One loader per rank, built in rank order.
    data_loaders = [
        build_data_loader(
            instruct_tokenizer,
            data_args,
            batch_size,
            seq_len,
            seed=seed,
            rank=loader_rank,
            world_size=world_size,
            is_eval=False,
        )
        for loader_rank in (0, 1)
    ]

    x_sums = []
    y_sums = []

    num_samples = 10 // 2

    # Alternate between the loaders: rank 0 first, then rank 1, per step.
    for _ in range(num_samples):
        for data_loader in data_loaders:
            batch = next(data_loader)
            x_sums.append(batch.x.sum())
            y_sums.append(batch.y.sum())

    # Interleave the rank-0 and rank-1 reference rows to mirror the draw order.
    expected_x_sums = [
        value for pair in zip(EXPECTED_X[1][1][0], EXPECTED_X[1][1][1]) for value in pair
    ]
    expected_y_sums = [
        value for pair in zip(EXPECTED_Y[1][1][0], EXPECTED_Y[1][1][1]) for value in pair
    ]

    assert x_sums == expected_x_sums, x_sums
    assert y_sums == expected_y_sums, y_sums
|
|
|
|
|
def test_dynamic_fn_call_chunk():
    """Check how many extra samples `maybe_chunk_lines` yields for the short fixture.

    For every sample flagged `only_last`, the expected number of extra chunks
    is its count of `FinetuningAssistantMessage` messages minus one.
    """
    jsonl_file = Path("tests/fixtures/sample_instruct_fn_call_short.jsonl")

    with jsonl_file.open() as file_handle:
        non_chunked_samples = [
            build_instruct_sample(json.loads(row)) for row in file_handle
        ]

    num_expected_chunks = sum(
        sum(isinstance(msg, FinetuningAssistantMessage) for msg in sample.messages) - 1
        for sample in non_chunked_samples
        if sample.only_last
    )

    with jsonl_file.open() as file_handle:
        extra_lines = maybe_chunk_lines(file_handle.readlines())
        chunked_samples = [
            build_instruct_sample(json.loads(row)) for row in extra_lines
        ]

    assert num_expected_chunks == len(chunked_samples)
|
|
|
|
|
def test_dynamic_fn_call_chunk_integration():
    """Check that the "single" fixture plus its chunks equals the reversed "multi" fixture."""
    multi_path = Path("tests/fixtures/sample_instruct_fn_call_multi.jsonl")
    with multi_path.open() as file_handle:
        multi_samples = [
            build_instruct_sample(json.loads(row)) for row in file_handle
        ]

    single_path = Path("tests/fixtures/sample_instruct_fn_call_single.jsonl")

    # First the samples exactly as stored in the file ...
    with single_path.open() as file_handle:
        chunked_samples = [
            build_instruct_sample(json.loads(row)) for row in file_handle
        ]

    # ... followed by the extra samples produced by chunking the same lines.
    with single_path.open() as file_handle:
        extra_lines = maybe_chunk_lines(file_handle.readlines())
        chunked_samples.extend(
            build_instruct_sample(json.loads(row)) for row in extra_lines
        )

    assert multi_samples[::-1] == chunked_samples
|
|
|
|
|
def test_fn_call():
    """Check that only the final message of each fixture sample contributes to the loss.

    The eval loader is run over the function-calling fixture, the target
    tokens whose mask equals 1 are decoded per sample, and the results are
    compared with strings rebuilt from the raw JSONL: the last interaction's
    `content`, or a JSON dump of its tool calls as
    `[{"name": ..., "arguments": ...}, ...]`.
    """
    batch_size = 1
    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct_fn_call.jsonl",
    )
    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    seq_len = 10000

    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=None,
        rank=0,
        world_size=1,
        is_eval=True,
    )

    all_loss_strings = []
    for batch in data_loader:
        # Without a mask every position counts towards the loss.
        y_mask = (
            np.asarray(batch.y_mask, int)
            if batch.y_mask is not None
            else np.ones_like(batch.x)
        )

        # `batch.sizes` gives the length of each sample packed into the batch.
        start_index = 0
        for size in batch.sizes:
            end_index = start_index + size
            tokens = batch.y[start_index:end_index]
            mask = y_mask[start_index:end_index]
            start_index = end_index

            tokens_for_loss = [
                int(token) for token, keep in zip(tokens, mask) if keep == 1
            ]

            decoded = instruct_tokenizer.tokenizer.decode(tokens_for_loss)
            if len(decoded) > 0:
                all_loss_strings.append(decoded)

    expected_loss_strings = []
    with open(data_args.eval_instruct_data, "r") as f:
        for line in f:
            data = json.loads(line)
            last_message = data["interactions"][-1]
            if "content" in last_message:
                expected_loss_strings.append(last_message["content"])
            elif "tool_calls" in last_message:
                # Each call contributes its OWN arguments; previously the first
                # call's arguments were reused for every call in the list.
                calls = [
                    {
                        "name": call["function"]["name"],
                        "arguments": json.loads(call["function"]["arguments"]),
                    }
                    for call in last_message["tool_calls"]
                ]
                expected_loss_strings.append(json.dumps(calls))

    assert expected_loss_strings == all_loss_strings
|
|
|
|
|
def test_data_weighting():
    """Check how a `weight` on the last interaction affects the encoded masks.

    The first fixture line is encoded as-is, with weight 0 and with weight 1.
    Tokens must be the same in all three cases; the masks must match the
    unweighted sample for weight 1 and be all-falsy for weight 0.
    """
    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct.jsonl",
    )

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    # Only the first line of the fixture is used.
    with Path(data_args.eval_instruct_data).open() as file_handle:
        data = json.loads(next(file_handle))

    def encode_with_last_weight(weight):
        # Work on a deep copy so the base sample stays untouched.
        weighted = copy.deepcopy(data)
        weighted["interactions"][-1]["weight"] = weight
        return encode(weighted, instruct_tokenizer, SampleType.INSTRUCT)

    token_sample = encode(data, instruct_tokenizer, SampleType.INSTRUCT)
    token_sample_weight_0 = encode_with_last_weight(0)
    token_sample_weight_1 = encode_with_last_weight(1)

    assert (
        token_sample.tokens
        == token_sample_weight_0.tokens
        == token_sample_weight_1.tokens
    )
    assert token_sample.masks == token_sample_weight_1.masks
    assert token_sample.masks != token_sample_weight_0.masks
    assert not any(token_sample_weight_0.masks)
|
|
|
|
|
def test_eval_dataloader():
    """Check that eval-loader totals stay the same across world sizes and sequence lengths.

    For each (world_size, seq_len) pair one eval loader per rank is built and
    fully consumed; the masked sums of x and y and the mask total, summed over
    all ranks, must equal fixed reference values, and every batch must have
    length `seq_len`.
    """
    batch_size = 1

    data_args = DataArgs(
        data="",
        instruct_data="",
        eval_instruct_data="tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl",
    )

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    for world_size in [1, 2, 8]:
        for seq_len in [10, 100, 1000, 10000]:
            x_sums, y_sums, y_masks = [], [], []

            # Build every rank's loader before consuming any of them.
            data_loaders = [
                build_data_loader(
                    instruct_tokenizer,
                    data_args,
                    batch_size,
                    seq_len,
                    seed=None,
                    rank=rank,
                    world_size=world_size,
                    is_eval=True,
                )
                for rank in range(world_size)
            ]

            for data_loader in data_loaders:
                for batch in data_loader:
                    if batch.y_mask is None:
                        # Without a mask every position counts.
                        mask = np.ones_like(batch.x)
                    else:
                        mask = np.asarray(batch.y_mask, int)

                    x_sums.append((batch.x * mask).sum())
                    y_sums.append((batch.y * mask).sum())
                    y_masks.append(mask.sum())

                    assert len(batch.x) == len(mask) == len(batch.y) == seq_len

            # Totals over all ranks for this (world_size, seq_len) pair.
            assert sum(x_sums) == 71404835
            assert sum(y_sums) == 71404795
            assert sum(y_masks) == 5538
|
|
|
|
|
def test_shuffle_data():
    """Check that each pass over the dataset is a reordering of the previous one.

    Four passes of `num_lines` samples are drawn from an iterator created with
    `shuffle_at_epoch=True`. After each pass the stringified samples must
    differ in order from the previous pass (the first pass is compared with
    file order) while containing the same items.
    """
    instruct_tokenizer = MockInstructTokenizerBaseBase()

    data_args = DataArgs(data="", instruct_data="", eval_instruct_data="")

    data_file = Path("tests/fixtures/sample_instruct_long_1.jsonl")

    dataset_iterator = get_dataset_iterator(
        source=DataFile(path=data_file, sample_type=SampleType.INSTRUCT),
        instruct_args=data_args.instruct,
        instruct_tokenizer=instruct_tokenizer,
        rank=0,
        world_size=1,
        is_finite=False,
        seed=0,
        shuffle_at_epoch=True,
    )

    # Baseline: the file's samples encoded in file order.
    with data_file.open() as f:
        raw_lines = f.readlines()

    encoded = [
        encode(
            json.loads(raw),
            instruct_tokenizer=instruct_tokenizer,
            as_type=SampleType.INSTRUCT,
        )
        for raw in raw_lines
    ]
    prev_lines = stringify(encoded)
    num_lines = len(prev_lines)

    samples = []
    for drawn in range(1, 4 * num_lines + 1):
        samples.append(next(dataset_iterator))

        if drawn % num_lines != 0:
            # Still inside the current pass.
            continue

        # A full pass has been collected: compare it with the previous one.
        lines = stringify(samples)
        assert lines != prev_lines, "No shuffling - make sure dataset is shuffled!"
        assert sorted(lines) == sorted(
            prev_lines
        ), "datasets need to match at every epoch"

        prev_lines = lines
        samples = []
|
|
|
|
|
@pytest.mark.parametrize("world_size", [1, 2])
def test_shuffle_data_same_as_no_shuffle(world_size):
    """Run `_check_shuffle_data_same_as_no_shuffle` through the spawn helper."""
    sizes = [world_size]
    spawn_for_all_world_sizes(
        _check_shuffle_data_same_as_no_shuffle,
        world_sizes=sizes,
        deterministic=True,
    )
|
|
|
|
|
def _check_shuffle_data_same_as_no_shuffle(
    rank: int,
    world_size: int,
    filename: str,
    filename_rpc: str,
):
    """Check that preloaded and lazily loaded iteration yield the same samples.

    Every JSONL file of every parsed source is read once with
    `preload_and_yield` and once with `lazy_load_and_yield`; the sorted
    stringified samples of both reads must be equal. `filename` and
    `filename_rpc` are not used in this body.
    """
    instruct_tokenizer = MockInstructTokenizerBaseBase()

    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)

    data_args = DataArgs(
        data="tests/fixtures/sample_pretrain_1.jsonl:1.0,tests/fixtures/sample_pretrain_2.jsonl:1.0",
        instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1",
        instruct=instruct,
    )

    sources, _ = parse_data_sources(
        pretrain_data=data_args.data, instruct_data=data_args.instruct_data
    )

    seed = 0
    # One generator is shared by all files read in this process.
    rng = get_rng(seed, rank)

    for source in sources:
        sample_type = source.sample_type
        chunk_dataset = sample_type == SampleType.INSTRUCT

        for jsonl_file in source.jsonl_files:
            # The path is logged once per loading mode.
            print(jsonl_file)
            preloaded = list(
                preload_and_yield(
                    jsonl_file,
                    chunk_dataset=chunk_dataset,
                    rank=rank,
                    world_size=world_size,
                    rng=rng,
                    instruct_tokenizer=instruct_tokenizer,
                    sample_type=sample_type,
                )
            )

            print(jsonl_file)
            lazily_loaded = list(
                lazy_load_and_yield(
                    jsonl_file,
                    rank=rank,
                    world_size=world_size,
                    instruct_tokenizer=instruct_tokenizer,
                    sample_type=sample_type,
                )
            )

            # Order may differ between the two modes; content must not.
            strings_0 = sorted(stringify(lazily_loaded))
            strings_1 = sorted(stringify(preloaded))

            assert strings_0 == strings_1
|
|