# tests/test_utils.py
import os
import tempfile
from datetime import timedelta
from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from finetune.data.args import DataArgs, InstructArgs
from finetune.data.data_loader import build_data_loader
from finetune.distributed import get_rank, get_world_size
from finetune.utils import set_random_seed


def is_float_equal(a: float, b: float, precision: float = 5e-3) -> bool:
    return abs(a - b) < precision


# os.getenv returns None when the variable is unset, so check truthiness rather
# than inequality with the empty string.
MODEL_PATH = os.getenv("DUMMY_MODEL")
assert MODEL_PATH, "Provide a path to a dummy model via the DUMMY_MODEL environment variable"
# Training entries use the "<path>:<weight>" form; the ".1" suffix gives each
# file a sampling weight when mixing datasets. Eval paths carry no weights.
DATA_PATH = "tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1"
EVAL_DATA_PATH = "tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl"
# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
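
# Hypothetical accessor sketch (not part of this module's API) showing how such
# module globals are typically consumed, with a guard against use before
# initialize_model_parallel has run:
#
#     def get_model_parallel_group() -> dist.ProcessGroup:
#         assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
#         return _MODEL_PARALLEL_GROUP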


def rmf(filename: str) -> None:
    """Remove a file like `rm -f`, ignoring missing files."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass


def test_runner(
    rank: int, test_func: Callable, deterministic: bool = False, *args: Any, **kwargs: Any
) -> None:
    # At this point we're in a freshly spawned process, so torch options need
    # to be set again.
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(1357)

    test_func(rank, *args, **kwargs)


def spawn_for_all_world_sizes(
    test_func: Callable, world_sizes: Optional[List[int]] = None, args: Any = (), deterministic: bool = False
) -> None:
    # Run test_func once per requested world size, each time in freshly spawned
    # processes that rendezvous over temporary files.
    for world_size in world_sizes or []:
        _, filename = tempfile.mkstemp()
        _, filename_rpc = tempfile.mkstemp()

        try:
            torch.multiprocessing.spawn(
                test_runner,
                args=(test_func, deterministic, world_size, filename, filename_rpc, *args),
                nprocs=world_size,
                join=True,
            )
        finally:
            rmf(filename)
            rmf(filename_rpc)
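

# How the argument plumbing above works: torch.multiprocessing.spawn invokes
# test_runner(i, *args) with i being the process index, so each spawned test
# function ends up called as
#
#     test_func(rank, world_size, filename, filename_rpc, *extra_args)
#
# i.e. a test passed to spawn_for_all_world_sizes should accept at least
# (rank, world_size, filename, filename_rpc).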


def initialize_model_parallel(
    model_parallel_size_: int,
    pipeline_length: int = 1,
    *,
    model_parallel_backend: Optional[str] = None,
    pipeline_backend: Optional[str] = None,
    ddp_backend: Optional[str] = None,
) -> None:
    """
    Initialize model, data, and pipeline parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize the model.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs
    to parallelize the model. The present function will create 4 model parallel
    groups and 2 data parallel groups:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    Note that for efficiency, the caller should make sure adjacent ranks are on
    the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of
    16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to
    the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    model_parallel_size = int(min(model_parallel_size_, world_size))
    rank = torch.distributed.get_rank()

    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))

    if rank == 0:
        print("> initializing model parallel with size {}".format(model_parallel_size_))
        print("> initializing ddp with size {}".format(data_parallel_size))
        print("> initializing pipeline with size {}".format(pipeline_length))

    # Lay all ranks out on a (data, pipeline, model) grid and locate the
    # current rank's coordinates in it.
    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
    found = torch.where(groups == rank)
    assert all(len(x) == 1 for x in found)
    found = [x[0] for x in found]

    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    for j in range(pipeline_length):
        for k in range(model_parallel_size):
            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
            if j == found[1] and k == found[2]:
                _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for i in range(data_parallel_size):
        for j in range(pipeline_length):
            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
            if i == found[0] and j == found[1]:
                _MODEL_PARALLEL_GROUP = group

    # Build the pipeline parallel groups.
    global _PIPELINE_PARALLEL_GROUP
    assert _PIPELINE_PARALLEL_GROUP is None, "pipeline parallel group is already initialized"
    global _PIPELINE_PARALLEL_RANKS
    assert _PIPELINE_PARALLEL_RANKS is None, "pipeline parallel ranks are already initialized"
    for i in range(data_parallel_size):
        for k in range(model_parallel_size):
            ranks = groups[i, :, k].tolist()
            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
            if i == found[0] and k == found[2]:
                _PIPELINE_PARALLEL_GROUP = group
                _PIPELINE_PARALLEL_RANKS = ranks
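

def _example_parallel_grid() -> None:
    """Illustrative sketch, not used by the tests: how the (data, pipeline,
    model) grid above partitions ranks for a hypothetical world_size=8 with
    pipeline_length=2 and model_parallel_size=2 (so data_parallel_size=2)."""
    groups = torch.LongTensor(range(8)).reshape(2, 2, 2)
    # Ranks sharing a (pipeline, model) coordinate form a data parallel group.
    assert groups[:, 0, 0].tolist() == [0, 4]
    # Ranks sharing a (data, pipeline) coordinate form a model parallel group.
    assert groups[0, 0, :].tolist() == [0, 1]
    # Ranks sharing a (data, model) coordinate form a pipeline parallel group.
    assert groups[0, :, 0].tolist() == [0, 2]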


def setup_mp_test_dist(rank: int, world_size: int, filename: str, model_parallel: int, seed: int = 0) -> None:
    dist_init_for_testing(rank, world_size, filename)
    torch.cuda.set_device(rank)

    # Initialize all parallel groups over NCCL.
    backend = "nccl"
    initialize_model_parallel(
        model_parallel,
        model_parallel_backend=backend,
        pipeline_backend=backend,
        ddp_backend=backend,
    )
    set_random_seed(seed)
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)  # type: ignore
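

# A hypothetical end-to-end usage sketch, assuming a 2-GPU machine (the test
# body and its name are assumptions, not code from this repository):
#
#     def _dummy_mp_test(rank, world_size, filename, filename_rpc):
#         setup_mp_test_dist(rank, world_size, filename, model_parallel=2)
#         ...  # exercise the model/data/pipeline parallel groups
#
#     spawn_for_all_world_sizes(_dummy_mp_test, world_sizes=[2])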


def dist_init_for_testing(
    rank: int, world_size: int, filename: str, filename_rpc: str = "", timeout: int = 30
) -> None:
    """
    Same as fairscale's testing.dist_init, but without RPC.

    filename_rpc is only here to keep the same signature as the fairscale init.
    """
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    # file:// init uses a shared file (here, a mkstemp file passed in by
    # spawn_for_all_world_sizes) as the rendezvous point.
    url = "file://" + filename

    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if backend == "nccl" and torch.cuda.device_count() < world_size:
        raise RuntimeError(
            f"Requested world size {world_size} cannot be reached on this machine, "
            f"not enough GPUs ({torch.cuda.device_count()} available)"
        )

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size,
        init_method=url,
        timeout=timedelta(seconds=timeout),
    )


def get_dataloader(
    seed: int = 0,
    seq_len: int = 10000,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
):
    batch_size = 1
    rank = rank if rank is not None else get_rank()
    world_size = world_size if world_size is not None else get_world_size()

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer
    instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False)
    data_args = DataArgs(
        data="",
        instruct_data=DATA_PATH,  # same weighted fixture mix as the module-level constant
        instruct=instruct,
    )
    data_loader = build_data_loader(
        instruct_tokenizer,
        data_args,
        batch_size,
        seq_len,
        seed=seed,
        rank=rank,
        world_size=world_size,
        is_eval=False,
    )
    return data_loader
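

# A minimal usage sketch, assuming a single process (rank 0 of world size 1)
# and that the returned loader yields batches when iterated:
#
#     loader = get_dataloader(seed=0, seq_len=512, rank=0, world_size=1)
#     batch = next(iter(loader))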