import os
import tempfile
from datetime import timedelta
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from finetune.data.args import DataArgs, InstructArgs
from finetune.data.data_loader import build_data_loader
from finetune.distributed import get_rank, get_world_size
from finetune.utils import set_random_seed


def is_float_equal(a, b, precision=5e-3):
    """Return True if ``a`` and ``b`` are equal up to the given absolute tolerance."""
    return abs(a - b) < precision


MODEL_PATH = os.getenv("DUMMY_MODEL", "")
assert MODEL_PATH != "", "Provide a path to a dummy model via the DUMMY_MODEL environment variable"

DATA_PATH = "tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1"
EVAL_DATA_PATH = "tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl"
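# Note: each entry in DATA_PATH has the form "<path>:<weight>"; the ".1" suffixes are
# assumed to be per-file sampling weights, while EVAL_DATA_PATH lists the same fixture
# files without weights. (This reading of the format is an assumption, not stated here.)

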
_MODEL_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None


def rmf(filename: str) -> None:
    """Remove a file like rm -f."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass


def test_runner(
    rank: int, test_func: Callable, deterministic: bool = False, *args: Any, **kwargs: Any
) -> None:
    # Entry point executed in each spawned process: optionally force deterministic
    # cuDNN behavior, then run the actual test function with this process's rank.
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(1357)

    test_func(rank, *args, **kwargs)


def spawn_for_all_world_sizes(
    test_func: Callable, world_sizes: List[int] = [], args: Any = [], deterministic: bool = False
) -> None:
    for world_size in world_sizes:
        # Temporary files used as file:// rendezvous stores for process-group init.
        _, filename = tempfile.mkstemp()
        _, filename_rpc = tempfile.mkstemp()

        try:
            torch.multiprocessing.spawn(
                test_runner,
                args=(test_func, deterministic, world_size, filename, filename_rpc, *args),
                nprocs=world_size,
                join=True,
            )
        finally:
            rmf(filename)
            rmf(filename_rpc)


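# Example (sketch): a distributed test is typically driven as below. The worker
# receives (rank, world_size, filename, filename_rpc) from ``test_runner``; the
# names ``_worker`` and ``test_something`` are illustrative, not part of this module.
#
#     def _worker(rank: int, world_size: int, filename: str, filename_rpc: str) -> None:
#         dist_init_for_testing(rank, world_size, filename)
#         ...  # assertions against the distributed setup
#
#     def test_something() -> None:
#         spawn_for_all_world_sizes(_worker, world_sizes=[2])

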
def initialize_model_parallel(
    model_parallel_size_: int,
    pipeline_length: int = 1,
    *,
    model_parallel_backend: Optional[str] = None,
    pipeline_backend: Optional[str] = None,
    ddp_backend: Optional[str] = None,
) -> None:
    """
    Initialize model and data parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize the model.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    model_parallel_size = int(min(model_parallel_size_, world_size))
    rank = torch.distributed.get_rank()

    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))

    if torch.distributed.get_rank() == 0:
        print("> initializing model parallel with size {}".format(model_parallel_size_))
        print("> initializing ddp with size {}".format(data_parallel_size))
        print("> initializing pipeline with size {}".format(pipeline_length))

    # Lay out all ranks on a (data, pipeline, model) grid; each parallel group is a
    # slice of this grid along one axis.
    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)

    # Locate this rank's coordinates (data index, pipeline index, model index) in the grid.
    found = torch.where(groups == rank)
    assert all(len(x) == 1 for x in found)
    found = [x[0] for x in found]

    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    for j in range(pipeline_length):
        for k in range(model_parallel_size):
            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
            if j == found[1] and k == found[2]:
                _DATA_PARALLEL_GROUP = group

    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for i in range(data_parallel_size):
        for j in range(pipeline_length):
            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
            if i == found[0] and j == found[1]:
                _MODEL_PARALLEL_GROUP = group

    global _PIPELINE_PARALLEL_GROUP
    assert _PIPELINE_PARALLEL_GROUP is None, "pipeline parallel group is already initialized"
    global _PIPELINE_PARALLEL_RANKS
    assert _PIPELINE_PARALLEL_RANKS is None, "pipeline parallel ranks are already initialized"
    for i in range(data_parallel_size):
        for k in range(model_parallel_size):
            ranks = groups[i, :, k].tolist()
            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
            if i == found[0] and k == found[2]:
                _PIPELINE_PARALLEL_GROUP = group
                _PIPELINE_PARALLEL_RANKS = ranks


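# Illustration (sketch): with world_size=8, model_parallel_size=2 and
# pipeline_length=1, ``groups`` above has shape (4, 1, 2), so
#
#     groups[:, 0, 0].tolist() == [0, 2, 4, 6]   # one data parallel group
#     groups[0, 0, :].tolist() == [0, 1]         # one model parallel group
#
# which matches the grouping described in the docstring of initialize_model_parallel.

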
def setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0):
    # Initialize the default process group, pin this process to its GPU, and set up
    # the model/data/pipeline parallel groups before seeding the RNGs.
    dist_init_for_testing(rank, world_size, filename)
    torch.cuda.set_device(rank)

    backend = "nccl"
    initialize_model_parallel(
        model_parallel,
        model_parallel_backend=backend,
        pipeline_backend=backend,
        ddp_backend=backend,
    )

    set_random_seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)


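# Example (sketch): inside a spawned worker one would typically call
#
#     setup_mp_test_dist(rank, world_size, filename, model_parallel=1, seed=0)
#
# before building models or data loaders, so that process groups and RNG state are
# consistent across ranks. The argument values here are illustrative.

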
def dist_init_for_testing(
    rank: int, world_size: int, filename: str, filename_rpc: str = "", timeout: int = 30
):
    """
    Same as fairscale's testing.dist_init, but without RPC.

    filename_rpc is only here to keep the same signature as the fairscale init.
    """
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename

    backend = "nccl" if torch.cuda.is_available() else "gloo"

    if backend == "nccl" and torch.cuda.device_count() < world_size:
        raise RuntimeError(
            f"Requested world size {world_size} cannot be reached on this machine, "
            f"not enough GPUs: {torch.cuda.device_count()}"
        )

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        init_method=url,
        timeout=timedelta(seconds=timeout),
    )


def get_dataloader(
    seed: int = 0,
    seq_len: int = 10000,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
):
    """Build a non-shuffled instruct data loader over the test fixtures."""
    batch_size = 1
    rank = rank if rank is not None else get_rank()
    world_size = world_size if world_size is not None else get_world_size()

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)

    data_args = DataArgs(
        data="",
        instruct_data=DATA_PATH,
        instruct=instruct,
    )
    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=rank,
        world_size=world_size,
        is_eval=False,
    )
    return data_loader
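

# Example (sketch): iterating a few batches on a single process. ``num_batches`` is
# illustrative; the loader yields batches indefinitely over the fixture data.
#
#     loader = get_dataloader(seed=0, seq_len=512, rank=0, world_size=1)
#     num_batches = 4
#     for _, batch in zip(range(num_batches), loader):
#         ...  # inspect the batch contents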